Obtain and format results produced by tuning functions
Usage
collect_predictions(x, ...)
# S3 method for default
collect_predictions(x, ...)
# S3 method for tune_results
collect_predictions(x, summarize = FALSE, parameters = NULL, ...)
collect_metrics(x, ...)
# S3 method for tune_results
collect_metrics(x, summarize = TRUE, ...)
collect_notes(x, ...)
# S3 method for tune_results
collect_notes(x, ...)
collect_extracts(x, ...)
# S3 method for tune_results
collect_extracts(x, ...)
Arguments
- x
The results of tune_grid(), tune_bayes(), fit_resamples(), or last_fit(). For collect_predictions(), the control option save_pred = TRUE should have been used.
- ...
Not currently used.
- summarize
A logical; should metrics be summarized over resamples (TRUE) or returned for each individual resample (FALSE)? Note that, if x is created by last_fit(), summarize has no effect. For the other object types, the method of summarizing predictions is detailed below.
- parameters
An optional tibble of tuning parameter values that can be used to filter the predicted values before processing. This tibble should only have columns for each tuning parameter identifier (e.g. "my_param" if tune("my_param") was used).
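A brief, hypothetical sketch of how these arguments fit together (grid_res, the penalty parameter, and the tuning call behind them are assumed to come from an earlier tune_grid() run; they are not defined in the examples below):
ctrl <- control_grid(save_pred = TRUE)  # needed for collect_predictions()
# grid_res is assumed to be the output of tune_grid(..., control = ctrl)
# for a model with a parameter tagged tune("penalty")
library(tibble)
keep <- tibble(penalty = 0.01)
# return only the unsummarized predictions for that single candidate value
collect_predictions(grid_res, parameters = keep)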
Value
A tibble. The column names depend on the results and the mode of the model.
For collect_metrics() and collect_predictions(), when unsummarized, there are columns for each tuning parameter (using the id from tune(), if any).
collect_metrics() also has columns .metric and .estimator. When the results are summarized, there are columns for mean, n, and std_err. When not summarized, there are additional columns for the resampling identifier(s) and .estimate.
For collect_predictions(), there are additional columns for the resampling identifier(s), columns for the predicted values (e.g., .pred, .pred_class, etc.), and a column for the outcome(s) using the original column name(s) in the data.
collect_predictions() can summarize the various results over replicate out-of-sample predictions. For example, when using the bootstrap, each row in the original training set has multiple holdout predictions (across assessment sets). To convert these results to a format where every training set sample has a single predicted value, the results are averaged over replicate predictions.
For regression cases, the numeric predictions are simply averaged. For classification models, the problem is more complex. When class probabilities are used, these are averaged and then re-normalized to make sure that they add to one. If hard class predictions also exist in the data, then these are determined from the summarized probability estimates (so that they match). If only hard class predictions are in the results, then the mode is used to summarize.
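For intuition, here is a minimal dplyr sketch (not the tune internals) of the class-probability case, in which replicate predictions for the same row are averaged and then re-normalized:
library(dplyr)
# two holdout predictions (e.g., from different assessment sets) per training row
probs <- tibble(
  .row    = c(1, 1, 2, 2),
  .pred_a = c(0.70, 0.60, 0.20, 0.40),
  .pred_b = c(0.30, 0.40, 0.80, 0.60)
)
probs %>%
  group_by(.row) %>%
  summarize(across(starts_with(".pred_"), mean), .groups = "drop") %>%
  # re-normalize so each row still sums to one (a no-op here, but guards against rounding)
  mutate(total = .pred_a + .pred_b, across(starts_with(".pred_"), ~ .x / total)) %>%
  select(-total)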
collect_notes() returns a tibble with columns for the resampling indicators, the location (preprocessor, model, etc.), type (error or warning), and the notes.
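A common follow-up, sketched here with a hypothetical result object res (any tune_grid() or fit_resamples() output), is to tabulate the notes by where they occurred:
library(dplyr)
collect_notes(res) %>%
  count(location, type)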
collect_extracts() returns a tibble with columns for the resampling indicators, the location (preprocessor, model, etc.), and objects extracted from workflows via the extract argument to control functions.
Examples
data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)
#> Collection of 5 parameters for tuning
#>
#> identifier type object
#> K neighbors nparam[+]
#> weight_func weight_func dparam[+]
#> dist_power dist_power nparam[+]
#> lon deg_free nparam[+]
#> lat deg_free nparam[+]
#>
# Summarized over resamples
collect_metrics(ames_grid_search)
#> # A tibble: 20 × 11
#> K weight_func dist_power lon lat .metric .estimator mean
#> <int> <chr> <dbl> <int> <int> <chr> <chr> <dbl>
#> 1 35 optimal 1.32 8 1 rmse standard 0.0785
#> 2 35 optimal 1.32 8 1 rsq standard 0.823
#> 3 35 rank 1.29 3 13 rmse standard 0.0809
#> 4 35 rank 1.29 3 13 rsq standard 0.814
#> 5 21 cos 0.626 1 4 rmse standard 0.0746
#> 6 21 cos 0.626 1 4 rsq standard 0.836
#> 7 4 biweight 0.311 8 4 rmse standard 0.0777
#> 8 4 biweight 0.311 8 4 rsq standard 0.814
#> 9 32 triangular 0.165 9 15 rmse standard 0.0770
#> 10 32 triangular 0.165 9 15 rsq standard 0.826
#> 11 3 rank 1.86 10 15 rmse standard 0.0875
#> 12 3 rank 1.86 10 15 rsq standard 0.762
#> 13 40 triangular 0.167 11 7 rmse standard 0.0778
#> 14 40 triangular 0.167 11 7 rsq standard 0.822
#> 15 12 epanechnikov 1.53 4 7 rmse standard 0.0774
#> 16 12 epanechnikov 1.53 4 7 rsq standard 0.820
#> 17 5 rank 0.411 2 7 rmse standard 0.0740
#> 18 5 rank 0.411 2 7 rsq standard 0.833
#> 19 33 triweight 0.511 10 3 rmse standard 0.0728
#> 20 33 triweight 0.511 10 3 rsq standard 0.842
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)
#> # A tibble: 200 × 10
#> id K weight_func dist_power lon lat .metric .estimator
#> <chr> <int> <chr> <dbl> <int> <int> <chr> <chr>
#> 1 Fold01 35 optimal 1.32 8 1 rmse standard
#> 2 Fold01 35 optimal 1.32 8 1 rsq standard
#> 3 Fold02 35 optimal 1.32 8 1 rmse standard
#> 4 Fold02 35 optimal 1.32 8 1 rsq standard
#> 5 Fold03 35 optimal 1.32 8 1 rmse standard
#> 6 Fold03 35 optimal 1.32 8 1 rsq standard
#> 7 Fold04 35 optimal 1.32 8 1 rmse standard
#> 8 Fold04 35 optimal 1.32 8 1 rsq standard
#> 9 Fold05 35 optimal 1.32 8 1 rmse standard
#> 10 Fold05 35 optimal 1.32 8 1 rsq standard
#> # ℹ 190 more rows
#> # ℹ 2 more variables: .estimate <dbl>, .config <chr>
# ---------------------------------------------------------------------------
library(parsnip)
library(rsample)
library(dplyr)
#>
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#>
#> filter, lag
#> The following objects are masked from ‘package:base’:
#>
#> intersect, setdiff, setequal, union
library(recipes)
#>
#> Attaching package: ‘recipes’
#> The following object is masked from ‘package:stats’:
#>
#> step
library(tibble)
lm_mod <- linear_reg() %>% set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE, extract = extract_fit_engine)
spline_rec <-
recipe(mpg ~ ., data = mtcars) %>%
step_ns(disp, deg_free = tune("df"))
grid <- tibble(df = 3:6)
resampled <-
lm_mod %>%
tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)
collect_predictions(resampled) %>% arrange(.row)
#> # A tibble: 384 × 7
#> id id2 .pred .row df mpg .config
#> <chr> <chr> <dbl> <int> <int> <dbl> <chr>
#> 1 Repeat1 Fold2 16.5 1 3 21 Preprocessor1_Model1
#> 2 Repeat2 Fold1 19.0 1 3 21 Preprocessor1_Model1
#> 3 Repeat3 Fold1 20.0 1 3 21 Preprocessor1_Model1
#> 4 Repeat1 Fold2 15.1 1 4 21 Preprocessor2_Model1
#> 5 Repeat2 Fold1 17.7 1 4 21 Preprocessor2_Model1
#> 6 Repeat3 Fold1 20.1 1 4 21 Preprocessor2_Model1
#> 7 Repeat1 Fold2 17.9 1 5 21 Preprocessor3_Model1
#> 8 Repeat2 Fold1 18.3 1 5 21 Preprocessor3_Model1
#> 9 Repeat3 Fold1 20.4 1 5 21 Preprocessor3_Model1
#> 10 Repeat1 Fold2 15.1 1 6 21 Preprocessor4_Model1
#> # ℹ 374 more rows
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#> # A tibble: 128 × 5
#> .row df mpg .config .pred
#> <int> <int> <dbl> <chr> <dbl>
#> 1 1 3 21 Preprocessor1_Model1 18.5
#> 2 1 4 21 Preprocessor2_Model1 17.6
#> 3 1 5 21 Preprocessor3_Model1 18.9
#> 4 1 6 21 Preprocessor4_Model1 16.7
#> 5 2 3 21 Preprocessor1_Model1 19.4
#> 6 2 4 21 Preprocessor2_Model1 19.0
#> 7 2 5 21 Preprocessor3_Model1 18.7
#> 8 2 6 21 Preprocessor4_Model1 16.4
#> 9 3 3 22.8 Preprocessor1_Model1 31.8
#> 10 3 4 22.8 Preprocessor2_Model1 23.8
#> # ℹ 118 more rows
collect_predictions(resampled, summarize = TRUE, grid[1, ]) %>% arrange(.row)
#> # A tibble: 32 × 5
#> .row df mpg .config .pred
#> <int> <int> <dbl> <chr> <dbl>
#> 1 1 3 21 Preprocessor1_Model1 18.5
#> 2 2 3 21 Preprocessor1_Model1 19.4
#> 3 3 3 22.8 Preprocessor1_Model1 31.8
#> 4 4 3 21.4 Preprocessor1_Model1 20.2
#> 5 5 3 18.7 Preprocessor1_Model1 18.4
#> 6 6 3 18.1 Preprocessor1_Model1 20.6
#> 7 7 3 14.3 Preprocessor1_Model1 13.5
#> 8 8 3 24.4 Preprocessor1_Model1 19.2
#> 9 9 3 22.8 Preprocessor1_Model1 34.8
#> 10 10 3 19.2 Preprocessor1_Model1 16.6
#> # ℹ 22 more rows
collect_extracts(resampled)
#> # A tibble: 24 × 5
#> id id2 df .extracts .config
#> <chr> <chr> <int> <list> <chr>
#> 1 Repeat1 Fold1 3 <lm> Preprocessor1_Model1
#> 2 Repeat1 Fold1 4 <lm> Preprocessor2_Model1
#> 3 Repeat1 Fold1 5 <lm> Preprocessor3_Model1
#> 4 Repeat1 Fold1 6 <lm> Preprocessor4_Model1
#> 5 Repeat1 Fold2 3 <lm> Preprocessor1_Model1
#> 6 Repeat1 Fold2 4 <lm> Preprocessor2_Model1
#> 7 Repeat1 Fold2 5 <lm> Preprocessor3_Model1
#> 8 Repeat1 Fold2 6 <lm> Preprocessor4_Model1
#> 9 Repeat2 Fold1 3 <lm> Preprocessor1_Model1
#> 10 Repeat2 Fold1 4 <lm> Preprocessor2_Model1
#> # ℹ 14 more rows
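# A possible follow-up (not part of the original example output): because
# extract = extract_fit_engine was used above, each element of .extracts is the
# underlying lm fit, so standard lm accessors apply.
library(purrr)
collect_extracts(resampled) %>%
  mutate(sigma = map_dbl(.extracts, ~ summary(.x)$sigma))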