Skip to content

Model tuning via grid search

Usage

melodie_grid(object, ...)

# S3 method for class 'model_spec'
melodie_grid(
  object,
  preprocessor,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  eval_time = NULL,
  control = control_grid()
)

# S3 method for class 'workflow'
melodie_grid(
  object,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  eval_time = NULL,
  control = control_grid()
)

Arguments

object

A parsnip model specification or an unfitted workflow(). Arguments to be tuned should be marked with tune(), either in the model specification or in the preprocessor (see the Examples below).

...

Not currently used.

preprocessor

A traditional model formula or a recipe created using recipes::recipe().

resamples

An rset resampling object created from an rsample function, such as rsample::vfold_cv().

param_info

A dials::parameters() object or NULL. If none is given, a parameter set is derived from other arguments. Passing this argument can be useful when parameter ranges need to be customized.

grid

A data frame of tuning combinations or a positive integer. The data frame should have one column for each parameter being tuned, named by the parameter's id (e.g. a column `disp` for `tune("disp")`), and one row per tuning parameter candidate. An integer denotes the number of candidate parameter sets to be created automatically.

metrics

A yardstick::metric_set(), or NULL to compute a standard set of metrics.

eval_time

A numeric vector of time points where dynamic event time metrics should be computed (e.g., the time-dependent ROC curve). The values must be non-negative and should probably be no greater than the largest event time in the training set (see Details below).

control

An object used to modify the tuning process, typically created by control_grid().

Examples

library(recipes)
library(rsample)
library(parsnip)
library(workflows)
library(ggplot2)

# ---------------------------------------------------------------------------

# Resampling scheme used by both examples below: 5-fold cross-validation of
# mtcars. Setting the seed fixes the fold assignment, so the printed results
# are reproducible.
set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

# ---------------------------------------------------------------------------

# tuning recipe parameters:

# Both spline steps tune their `deg_free` argument. Giving each tune() its own
# id ("disp", "wt") keeps the two parameters distinct; these ids are the column
# names a manual grid must use and the columns reported by show_best().
spline_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp, deg_free = tune("disp")) |>
  step_spline_natural(wt, deg_free = tune("wt"))

# The model has no tune() arguments; only the recipe is tuned in this example.
lin_mod <-
  linear_reg() |>
  set_engine("lm")

# manually create a grid: 4 x 4 = 16 candidate combinations, with columns
# named after the tune() ids used in the recipe
spline_grid <- expand.grid(disp = 2:5, wt = 2:5)

# Warnings may occur from making spline terms on the holdout data that are
# extrapolations. (In the run shown below, every `.notes` entry is empty.)
spline_res <-
  melodie_grid(lin_mod, spline_rec, resamples = folds, grid = spline_grid)
# Each fold's `.metrics` tibble has 32 rows: 2 per candidate (presumably one
# row per default regression metric).
spline_res
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [32 × 6]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [32 × 6]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [32 × 6]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [32 × 6]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [32 × 6]> <tibble [0 × 3]>


# Rank candidates by their RMSE averaged over the 5 folds (smallest first);
# the top 5 candidates are shown.
show_best(spline_res, metric = "rmse")
#> # A tibble: 5 × 8
#>    disp    wt .metric .estimator  mean     n std_err .config         
#>   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>           
#> 1     3     2 rmse    standard    2.54     5   0.207 pre05_mod0_post0
#> 2     3     3 rmse    standard    2.64     5   0.234 pre06_mod0_post0
#> 3     4     3 rmse    standard    2.82     5   0.456 pre10_mod0_post0
#> 4     4     2 rmse    standard    2.93     5   0.489 pre09_mod0_post0
#> 5     4     4 rmse    standard    3.01     5   0.475 pre11_mod0_post0

# ---------------------------------------------------------------------------

# tune model parameters only (example requires the `kernlab` package)

# The recipe has no tune() arguments; it only normalizes the predictors.
car_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_normalize(all_predictors())

# Both tuned parameters (`cost`, `rbf_sigma`) live in the model specification.
svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) |>
  set_engine("kernlab") |>
  set_mode("regression")

# Use a space-filling design with 7 points
# (`grid = 7` asks for 7 automatically generated candidates; the seed makes
# the generated design reproducible.)
set.seed(3254)
svm_res <- melodie_grid(svm_mod, car_rec, resamples = folds, grid = 7)
# Each fold's `.metrics` tibble has 14 rows: 2 per candidate (presumably one
# row per default regression metric).
svm_res
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [14 × 6]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [14 × 6]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [14 × 6]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [14 × 6]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [14 × 6]> <tibble [0 × 3]>