Obtain and format results produced by tuning functions

collect_predictions(x, summarize = FALSE, parameters = NULL)

collect_metrics(x, summarize = TRUE)

Arguments

x

The results of tune_grid(), tune_bayes(), fit_resamples(), or last_fit(). For collect_predictions(), the control option save_pred = TRUE should have been used.
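
For example, a control object along these lines keeps the out-of-sample predictions with the results (a minimal sketch; wflow and folds are hypothetical placeholders for a workflow and a resampling object):

library(tune)
# save_pred = TRUE retains the holdout predictions for each resample
ctrl <- control_grid(save_pred = TRUE)
res <- tune_grid(wflow, resamples = folds, grid = 10, control = ctrl)
collect_predictions(res)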

summarize

A logical; should metrics be summarized over resamples (TRUE) or should the values for each individual resample be returned (FALSE)? Note that, if x is created by last_fit(), summarize has no effect. For the other object types, the method of summarizing predictions is detailed below.

parameters

An optional tibble of tuning parameter values that can be used to filter the predicted values before processing. This tibble should only have columns for each tuning parameter identifier (e.g. "my_param" if tune("my_param") was used).
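
As a sketch of the expected format, if tune("my_param") was used somewhere in the workflow, the filter tibble would have a single column named my_param (res stands in for a tuning result saved with predictions):

library(tibble)
# One column per tuning parameter identifier; here only "my_param" was tuned
keep <- tibble(my_param = c(0.1, 0.5))
# Return only the predictions made with these parameter values
collect_predictions(res, parameters = keep)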

Value

A tibble. The column names depend on the results and the mode of the model.

For collect_metrics() and collect_predictions(), when unsummarized, there are columns for each tuning parameter (using the id from tune(), if any). collect_metrics() also has columns .metric and .estimator. When the results are summarized, there are columns for mean, n, and std_err. When not summarized, there are additional columns for the resampling identifier(s) and .estimate.

For collect_predictions(), there are additional columns for the resampling identifier(s), columns for the predicted values (e.g., .pred, .pred_class, etc.), and a column for the outcome(s) using the original column name(s) in the data.

collect_predictions() can summarize the various results over replicate out-of-sample predictions. For example, when using the bootstrap, each row in the original training set has multiple holdout predictions (across assessment sets). To convert these results to a format where every training set sample has a single predicted value, the results are averaged over the replicate predictions.

For regression cases, the numeric predictions are simply averaged. Classification models are more complex: when class probabilities are available, these are averaged across replicates and then re-normalized so that they sum to one. If hard class predictions are also present in the data, they are re-derived from the summarized probability estimates (so that the two stay consistent). If only hard class predictions are in the results, the mode of the replicate predictions is used.
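
To illustrate the classification logic with a standalone sketch (not tune's internal code), suppose a single training set row received three replicate probability predictions for classes "a" and "b":

probs <- data.frame(
  .pred_a = c(0.60, 0.70, 0.40),
  .pred_b = c(0.40, 0.30, 0.60)
)
# Average each class probability over the replicate predictions
avg <- colMeans(probs)        # .pred_a ~ 0.567, .pred_b ~ 0.433
# Re-normalize as a safeguard so the summarized probabilities sum to one
avg <- avg / sum(avg)
# A matching hard class prediction comes from the summarized probabilities
names(avg)[which.max(avg)]    # ".pred_a", i.e., class "a"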

Examples

# \donttest{
data("example_ames_knn")

# The parameters for the model:
parameters(ames_wflow)
#> Collection of 5 parameters for tuning
#> 
#>           id parameter type object class
#>            K      neighbors    nparam[+]
#>  weight_func    weight_func    dparam[+]
#>   dist_power     dist_power    nparam[+]
#>          lon       deg_free    nparam[+]
#>          lat       deg_free    nparam[+]
#> 
# Summarized over resamples
collect_metrics(ames_grid_search)
#> # A tibble: 20 x 11
#>        K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>    <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#>  1    35 optimal         1.32       8     1 rmse    standard   0.0785    10
#>  2    35 optimal         1.32       8     1 rsq     standard   0.823     10
#>  3    35 rank            1.29       3    13 rmse    standard   0.0809    10
#>  4    35 rank            1.29       3    13 rsq     standard   0.814     10
#>  5    21 cos             0.626      1     4 rmse    standard   0.0746    10
#>  6    21 cos             0.626      1     4 rsq     standard   0.836     10
#>  7     4 biweight        0.311      8     4 rmse    standard   0.0777    10
#>  8     4 biweight        0.311      8     4 rsq     standard   0.814     10
#>  9    32 triangular      0.165      9    15 rmse    standard   0.0770    10
#> 10    32 triangular      0.165      9    15 rsq     standard   0.826     10
#> 11     3 rank            1.86      10    15 rmse    standard   0.0875    10
#> 12     3 rank            1.86      10    15 rsq     standard   0.762     10
#> 13    40 triangular      0.167     11     7 rmse    standard   0.0778    10
#> 14    40 triangular      0.167     11     7 rsq     standard   0.822     10
#> 15    12 epanechnik…     1.53       4     7 rmse    standard   0.0774    10
#> 16    12 epanechnik…     1.53       4     7 rsq     standard   0.820     10
#> 17     5 rank            0.411      2     7 rmse    standard   0.0740    10
#> 18     5 rank            0.411      2     7 rsq     standard   0.833     10
#> 19    33 triweight       0.511     10     3 rmse    standard   0.0728    10
#> 20    33 triweight       0.511     10     3 rsq     standard   0.842     10
#> # … with 2 more variables: std_err <dbl>, .config <chr>
# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)
#> # A tibble: 200 x 10
#>    id        K weight_func dist_power   lon   lat .metric .estimator .estimate
#>    <chr> <int> <chr>            <dbl> <int> <int> <chr>   <chr>          <dbl>
#>  1 Fold…    35 optimal         1.32       8     1 rmse    standard      0.0859
#>  2 Fold…    35 optimal         1.32       8     1 rsq     standard      0.810
#>  3 Fold…    35 rank            1.29       3    13 rmse    standard      0.0878
#>  4 Fold…    35 rank            1.29       3    13 rsq     standard      0.802
#>  5 Fold…    21 cos             0.626      1     4 rmse    standard      0.0812
#>  6 Fold…    21 cos             0.626      1     4 rsq     standard      0.826
#>  7 Fold…     4 biweight        0.311      8     4 rmse    standard      0.0886
#>  8 Fold…     4 biweight        0.311      8     4 rsq     standard      0.780
#>  9 Fold…    32 triangular      0.165      9    15 rmse    standard      0.0825
#> 10 Fold…    32 triangular      0.165      9    15 rsq     standard      0.823
#> # … with 190 more rows, and 1 more variable: .config <chr>
# ---------------------------------------------------------------------------
library(parsnip)
library(rsample)
library(dplyr)
#> 
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#> 
#>     filter, lag
#> The following objects are masked from ‘package:base’:
#> 
#>     intersect, setdiff, setequal, union
library(recipes)
#> 
#> Attaching package: ‘recipes’
#> The following object is masked from ‘package:stats’:
#> 
#>     step
library(tibble)

lm_mod <- linear_reg() %>% set_engine("lm")

set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE)

spline_rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp, deg_free = tune("df"))

grid <- tibble(df = 3:6)

resampled <- tune_grid(spline_rec, lm_mod, resamples = car_folds, control = ctrl, grid = grid)
#> Warning: `tune_grid.recipe()` is deprecated as of lifecycle 0.1.0.
#> The first argument to `tune_grid()` should be either a model or a workflow. In the future, you can use:
#>   tune_grid(lm_mod, spline_rec, resamples = car_folds, grid = grid,
#>     control = ctrl)
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_warnings()` to see where this warning was generated.
collect_predictions(resampled) %>% arrange(.row)
#> # A tibble: 384 x 7
#>    id      id2   .pred  .row    df   mpg .config
#>    <chr>   <chr> <dbl> <int> <int> <dbl> <chr>  
#>  1 Repeat1 Fold2  16.5     1     3    21 Recipe1
#>  2 Repeat1 Fold2  15.1     1     4    21 Recipe2
#>  3 Repeat1 Fold2  17.9     1     5    21 Recipe3
#>  4 Repeat1 Fold2  15.1     1     6    21 Recipe4
#>  5 Repeat2 Fold1  19.0     1     3    21 Recipe1
#>  6 Repeat2 Fold1  17.7     1     4    21 Recipe2
#>  7 Repeat2 Fold1  18.3     1     5    21 Recipe3
#>  8 Repeat2 Fold1  15.5     1     6    21 Recipe4
#>  9 Repeat3 Fold1  20.0     1     3    21 Recipe1
#> 10 Repeat3 Fold1  20.1     1     4    21 Recipe2
#> # … with 374 more rows
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#> # A tibble: 128 x 5
#>     .row    df   mpg .config .pred
#>    <int> <int> <dbl> <chr>   <dbl>
#>  1     1     3  21   Recipe1  18.5
#>  2     1     4  21   Recipe2  17.6
#>  3     1     5  21   Recipe3  18.9
#>  4     1     6  21   Recipe4  16.7
#>  5     2     3  21   Recipe1  19.4
#>  6     2     4  21   Recipe2  19.0
#>  7     2     5  21   Recipe3  18.7
#>  8     2     6  21   Recipe4  16.4
#>  9     3     3  22.8 Recipe1  31.8
#> 10     3     4  22.8 Recipe2  23.8
#> # … with 118 more rows
collect_predictions(resampled, summarize = TRUE, grid[1,]) %>% arrange(.row)
#> # A tibble: 32 x 5
#>     .row    df   mpg .config .pred
#>    <int> <int> <dbl> <chr>   <dbl>
#>  1     1     3  21   Recipe1  18.5
#>  2     2     3  21   Recipe1  19.4
#>  3     3     3  22.8 Recipe1  31.8
#>  4     4     3  21.4 Recipe1  20.2
#>  5     5     3  18.7 Recipe1  18.4
#>  6     6     3  18.1 Recipe1  20.6
#>  7     7     3  14.3 Recipe1  13.5
#>  8     8     3  24.4 Recipe1  19.2
#>  9     9     3  22.8 Recipe1  34.8
#> 10    10     3  19.2 Recipe1  16.6
#> # … with 22 more rows
# }