Skip to content

For objects produced by the tune_*() functions, only a subset of the tuning parameter combinations may be of interest. For large data sets, it can be helpful to remove the other results. This function removes unwanted results from the .metrics column, as well as from the .predictions and .extracts columns (if they were requested).

Usage

filter_parameters(x, ..., parameters = NULL)

Arguments

x

An object of class tune_results that has multiple tuning parameters.

...

Expressions that return a logical value, and are defined in terms of the tuning parameter values. If multiple expressions are included, they are combined with the & operator. Only rows for which all conditions evaluate to TRUE are kept.

parameters

A tibble of tuning parameter values used to select which parameter combinations are retained. This tibble should only have columns for tuning parameter identifiers (e.g. "my_param" if tune("my_param") was used). There can be multiple rows and one or more columns. If used, this argument must be named.

Value

A version of x where the list columns only retain the parameter combinations in parameters or those satisfied by the filtering logic.

Details

Removing some parameter combinations might affect the results of autoplot() for the object.

Examples

library(dplyr)
library(tibble)

# For grid search:
data("example_ames_knn")

## -----------------------------------------------------------------------------
# select all combinations using the 'rank' weighting scheme

ames_grid_search %>%
  collect_metrics()
#> # A tibble: 20 × 11
#>        K weight_func  dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>             <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal           1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal           1.32      8     1 rsq     standard   0.823 
#>  3    35 rank              1.29      3    13 rmse    standard   0.0809
#>  4    35 rank              1.29      3    13 rsq     standard   0.814 
#>  5    21 cos               0.626     1     4 rmse    standard   0.0746
#>  6    21 cos               0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight          0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight          0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular        0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular        0.165     9    15 rsq     standard   0.826 
#> 11     3 rank              1.86     10    15 rmse    standard   0.0875
#> 12     3 rank              1.86     10    15 rsq     standard   0.762 
#> 13    40 triangular        0.167    11     7 rmse    standard   0.0778
#> 14    40 triangular        0.167    11     7 rsq     standard   0.822 
#> 15    12 epanechnikov      1.53      4     7 rmse    standard   0.0774
#> 16    12 epanechnikov      1.53      4     7 rsq     standard   0.820 
#> 17     5 rank              0.411     2     7 rmse    standard   0.0740
#> 18     5 rank              0.411     2     7 rsq     standard   0.833 
#> 19    33 triweight         0.511    10     3 rmse    standard   0.0728
#> 20    33 triweight         0.511    10     3 rsq     standard   0.842 
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

filter_parameters(ames_grid_search, weight_func == "rank") %>%
  collect_metrics()
#> # A tibble: 6 × 11
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    35 rank             1.29      3    13 rmse    standard   0.0809    10
#> 2    35 rank             1.29      3    13 rsq     standard   0.814     10
#> 3     3 rank             1.86     10    15 rmse    standard   0.0875    10
#> 4     3 rank             1.86     10    15 rsq     standard   0.762     10
#> 5     5 rank             0.411     2     7 rmse    standard   0.0740    10
#> 6     5 rank             0.411     2     7 rsq     standard   0.833     10
#> # ℹ 2 more variables: std_err <dbl>, .config <chr>

rank_only <- tibble::tibble(weight_func = "rank")
filter_parameters(ames_grid_search, parameters = rank_only) %>%
  collect_metrics()
#> # A tibble: 6 × 11
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    35 rank             1.29      3    13 rmse    standard   0.0809    10
#> 2    35 rank             1.29      3    13 rsq     standard   0.814     10
#> 3     3 rank             1.86     10    15 rmse    standard   0.0875    10
#> 4     3 rank             1.86     10    15 rsq     standard   0.762     10
#> 5     5 rank             0.411     2     7 rmse    standard   0.0740    10
#> 6     5 rank             0.411     2     7 rsq     standard   0.833     10
#> # ℹ 2 more variables: std_err <dbl>, .config <chr>

## -----------------------------------------------------------------------------
# Keep only the results from the numerically best combination

ames_iter_search %>%
  collect_metrics()
#> # A tibble: 40 × 12
#>        K weight_func dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal          1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal          1.32      8     1 rsq     standard   0.823 
#>  3    35 rank             1.29      3    13 rmse    standard   0.0809
#>  4    35 rank             1.29      3    13 rsq     standard   0.814 
#>  5    21 cos              0.626     1     4 rmse    standard   0.0746
#>  6    21 cos              0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight         0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight         0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular       0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular       0.165     9    15 rsq     standard   0.826 
#> # ℹ 30 more rows
#> # ℹ 4 more variables: n <int>, std_err <dbl>, .config <chr>, .iter <int>

best_param <- select_best(ames_iter_search, metric = "rmse")
ames_iter_search %>%
  filter_parameters(parameters = best_param) %>%
  collect_metrics()
#> # A tibble: 2 × 12
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    33 triweight        0.511    10     3 rmse    standard   0.0728    10
#> 2    33 triweight        0.511    10     3 rsq     standard   0.842     10
#> # ℹ 3 more variables: std_err <dbl>, .config <chr>, .iter <int>