4 - Evaluating models

Machine learning with tidymodels

Looking at predictions

augment(taxi_fit, new_data = taxi_train) %>%
  relocate(tip, .pred_class, .pred_yes, .pred_no)
#> # A tibble: 7,045 × 10
#>    tip   .pred_class .pred_yes .pred_no distance company local dow   month  hour
#>    <fct> <fct>           <dbl>    <dbl>    <dbl> <fct>   <fct> <fct> <fct> <int>
#>  1 no    no             0.0625   0.937      5.39 Flash … no    Sat   Mar      12
#>  2 no    yes            0.924    0.0758    18.4  Sun Ta… no    Sat   Apr       6
#>  3 no    no             0.391    0.609      5.8  other   no    Tue   Jan      10
#>  4 no    no             0.112    0.888      6.85 Flash … no    Fri   Apr       8
#>  5 no    no             0.129    0.871      9.5  City S… no    Wed   Jan       7
#>  6 no    no             0.326    0.674     12    other   no    Fri   Apr      11
#>  7 no    no             0.0917   0.908      8.9  Taxi A… no    Mon   Feb      14
#>  8 no    yes            0.902    0.0980     1.38 other   no    Fri   Apr      16
#>  9 no    no             0.0917   0.908      9.12 Flash … no    Wed   Apr       9
#> 10 no    yes            0.933    0.0668     2.28 City S… no    Thu   Apr      16
#> # ℹ 7,035 more rows

Confusion matrix

augment(taxi_fit, new_data = taxi_train) %>%
  conf_mat(truth = tip, estimate = .pred_class)
#>           Truth
#> Prediction  yes   no
#>        yes 4639  660
#>        no   337 1409

Confusion matrix

augment(taxi_fit, new_data = taxi_train) %>%
  conf_mat(truth = tip, estimate = .pred_class) %>%
  autoplot(type = "heatmap")

Metrics for model performance

augment(taxi_fit, new_data = taxi_train) %>%
  accuracy(truth = tip, estimate = .pred_class)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.858

Dangers of accuracy

Be careful with accuracy(): with imbalanced data, a model can look “good” by always predicting the majority class

augment(taxi_fit, new_data = taxi_train) %>%
  mutate(.pred_class = factor("yes", levels = c("yes", "no"))) %>%
  accuracy(truth = tip, estimate = .pred_class)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.706

Metrics for model performance

augment(taxi_fit, new_data = taxi_train) %>%
  sensitivity(truth = tip, estimate = .pred_class)
#> # A tibble: 1 × 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 sensitivity binary         0.932

augment(taxi_fit, new_data = taxi_train) %>%
  specificity(truth = tip, estimate = .pred_class)
#> # A tibble: 1 × 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 specificity binary         0.681

Metrics for model performance

We can use metric_set() to combine multiple metrics into a single function

taxi_metrics <- metric_set(accuracy, specificity, sensitivity)

augment(taxi_fit, new_data = taxi_train) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
#> # A tibble: 3 × 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 accuracy    binary         0.858
#> 2 specificity binary         0.681
#> 3 sensitivity binary         0.932

Metrics for model performance

taxi_metrics <- metric_set(accuracy, specificity, sensitivity)

augment(taxi_fit, new_data = taxi_train) %>%
  group_by(local) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
#> # A tibble: 6 × 4
#>   local .metric     .estimator .estimate
#>   <fct> <chr>       <chr>          <dbl>
#> 1 yes   accuracy    binary         0.840
#> 2 no    accuracy    binary         0.862
#> 3 yes   specificity binary         0.346
#> 4 no    specificity binary         0.719
#> 5 yes   sensitivity binary         0.969
#> 6 no    sensitivity binary         0.925

Two class data

These metrics assume that we know the threshold for converting “soft” probability predictions into “hard” class predictions.

Is a 50% threshold good?

What happens if we say that we need to be 80% sure to declare an event?

  • sensitivity ⬇️, specificity ⬆️

What happens for a 20% threshold?

  • sensitivity ⬆️, specificity ⬇️

Varying the threshold
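
Below is a minimal sketch (not from the original slides; it assumes tidymodels is loaded and that taxi_fit, taxi_train, and taxi_metrics exist from earlier slides): recompute the hard class predictions at a stricter 80% threshold and watch sensitivity drop while specificity rises. The 0.8 cutoff is only illustrative.

augment(taxi_fit, new_data = taxi_train) %>%
  mutate(
    # declare an event only when we are at least 80% sure
    .pred_class = factor(if_else(.pred_yes >= 0.8, "yes", "no"),
                         levels = levels(tip))
  ) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)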

ROC curves

To make an ROC (receiver operating characteristic) curve, we:

  • calculate the sensitivity and specificity for all possible thresholds

  • plot false positive rate (x-axis) versus true positive rate (y-axis)

Sensitivity is the true positive rate and specificity is the true negative rate, so 1 - specificity is the false positive rate.

We can use the area under the ROC curve as a classification metric:

  • ROC AUC = 1 💯
  • ROC AUC = 1/2 😢

ROC curves

# Assumes _first_ factor level is event; there are options to change that
augment(taxi_fit, new_data = taxi_train) %>% 
  roc_curve(truth = tip, .pred_yes) %>%
  slice(1, 20, 50)
#> # A tibble: 3 × 3
#>   .threshold specificity sensitivity
#>        <dbl>       <dbl>       <dbl>
#> 1    -Inf          0           1    
#> 2       0.25       0.486       0.972
#> 3       0.6        0.705       0.920

augment(taxi_fit, new_data = taxi_train) %>% 
  roc_auc(truth = tip, .pred_yes)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.868

ROC curve plot

augment(taxi_fit, new_data = taxi_train) %>% 
  roc_curve(truth = tip, .pred_yes) %>%
  autoplot()

Your turn

Compute and plot an ROC curve for your current model.

What data are being used for this ROC curve plot?

05:00

⚠️ DANGERS OF OVERFITTING ⚠️

Dangers of overfitting ⚠️

taxi_fit %>%
  augment(taxi_train)
#> # A tibble: 7,045 × 10
#>    tip   distance company local dow   month  hour .pred_class .pred_yes .pred_no
#>    <fct>    <dbl> <fct>   <fct> <fct> <fct> <int> <fct>           <dbl>    <dbl>
#>  1 no        5.39 Flash … no    Sat   Mar      12 no             0.0625   0.937 
#>  2 no       18.4  Sun Ta… no    Sat   Apr       6 yes            0.924    0.0758
#>  3 no        5.8  other   no    Tue   Jan      10 no             0.391    0.609 
#>  4 no        6.85 Flash … no    Fri   Apr       8 no             0.112    0.888 
#>  5 no        9.5  City S… no    Wed   Jan       7 no             0.129    0.871 
#>  6 no       12    other   no    Fri   Apr      11 no             0.326    0.674 
#>  7 no        8.9  Taxi A… no    Mon   Feb      14 no             0.0917   0.908 
#>  8 no        1.38 other   no    Fri   Apr      16 yes            0.902    0.0980
#>  9 no        9.12 Flash … no    Wed   Apr       9 no             0.0917   0.908 
#> 10 no        2.28 City S… no    Thu   Apr      16 yes            0.933    0.0668
#> # ℹ 7,035 more rows

We call this “resubstitution” or “repredicting the training set”

Dangers of overfitting ⚠️

taxi_fit %>%
  augment(taxi_train) %>%
  accuracy(tip, .pred_class)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.858

We call this a “resubstitution estimate”

Dangers of overfitting ⚠️

taxi_fit %>%
  augment(taxi_train) %>%
  accuracy(tip, .pred_class)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.858
taxi_fit %>%
  augment(taxi_test) %>%
  accuracy(tip, .pred_class)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.795

⚠️ Remember that we’re demonstrating overfitting

⚠️ Don’t use the test set until the end of your modeling analysis

Your turn

Use augment() and a metric function to compute a classification metric like brier_class().

Compute the metrics for both training and testing data to demonstrate overfitting!

Notice the evidence of overfitting! ⚠️

05:00

Dangers of overfitting ⚠️

taxi_fit %>%
  augment(taxi_train) %>%
  brier_class(tip, .pred_yes)
#> # A tibble: 1 × 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 brier_class binary         0.113
taxi_fit %>%
  augment(taxi_test) %>%
  brier_class(tip, .pred_yes)
#> # A tibble: 1 × 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 brier_class binary         0.152

What if we want to compare more models?

And/or more model configurations?

And we want to understand if these are important differences?

The testing data are precious 💎

How can we use the training data to compare and evaluate different models? 🤔

Cross-validation

Your turn

If we use 10 folds, what percent of the training data

  • ends up in analysis
  • ends up in assessment

for each fold?

03:00

Cross-validation

vfold_cv(taxi_train) # v = 10 is default
#> #  10-fold cross-validation 
#> # A tibble: 10 × 2
#>    splits             id    
#>    <list>             <chr> 
#>  1 <split [6340/705]> Fold01
#>  2 <split [6340/705]> Fold02
#>  3 <split [6340/705]> Fold03
#>  4 <split [6340/705]> Fold04
#>  5 <split [6340/705]> Fold05
#>  6 <split [6341/704]> Fold06
#>  7 <split [6341/704]> Fold07
#>  8 <split [6341/704]> Fold08
#>  9 <split [6341/704]> Fold09
#> 10 <split [6341/704]> Fold10

Cross-validation

What is in this?

taxi_folds <- vfold_cv(taxi_train)
taxi_folds$splits[1:3]
#> [[1]]
#> <Analysis/Assess/Total>
#> <6340/705/7045>
#> 
#> [[2]]
#> <Analysis/Assess/Total>
#> <6340/705/7045>
#> 
#> [[3]]
#> <Analysis/Assess/Total>
#> <6340/705/7045>

Cross-validation

vfold_cv(taxi_train, v = 5)
#> #  5-fold cross-validation 
#> # A tibble: 5 × 2
#>   splits              id   
#>   <list>              <chr>
#> 1 <split [5636/1409]> Fold1
#> 2 <split [5636/1409]> Fold2
#> 3 <split [5636/1409]> Fold3
#> 4 <split [5636/1409]> Fold4
#> 5 <split [5636/1409]> Fold5

Cross-validation

vfold_cv(taxi_train, strata = tip)
#> #  10-fold cross-validation using stratification 
#> # A tibble: 10 × 2
#>    splits             id    
#>    <list>             <chr> 
#>  1 <split [6340/705]> Fold01
#>  2 <split [6340/705]> Fold02
#>  3 <split [6340/705]> Fold03
#>  4 <split [6340/705]> Fold04
#>  5 <split [6340/705]> Fold05
#>  6 <split [6340/705]> Fold06
#>  7 <split [6341/704]> Fold07
#>  8 <split [6341/704]> Fold08
#>  9 <split [6341/704]> Fold09
#> 10 <split [6342/703]> Fold10

Stratification often helps, with very little downside.
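
As a quick sketch (an illustration, not from the original slides), you can check that the event rate stays nearly constant across the analysis sets when stratifying:

vfold_cv(taxi_train, strata = tip) %>%
  # proportion of tipped rides in each fold's analysis set
  mutate(prop_yes = map_dbl(splits, ~ mean(analysis(.x)$tip == "yes")))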

Cross-validation

We’ll use this setup:

set.seed(123)
taxi_folds <- vfold_cv(taxi_train, v = 10, strata = tip)
taxi_folds
#> #  10-fold cross-validation using stratification 
#> # A tibble: 10 × 2
#>    splits             id    
#>    <list>             <chr> 
#>  1 <split [6340/705]> Fold01
#>  2 <split [6340/705]> Fold02
#>  3 <split [6340/705]> Fold03
#>  4 <split [6340/705]> Fold04
#>  5 <split [6340/705]> Fold05
#>  6 <split [6340/705]> Fold06
#>  7 <split [6341/704]> Fold07
#>  8 <split [6341/704]> Fold08
#>  9 <split [6341/704]> Fold09
#> 10 <split [6342/703]> Fold10

Set the seed when creating resamples

We are equipped with metrics and resamples!

Fit our model to the resamples

taxi_res <- fit_resamples(taxi_wflow, taxi_folds)
taxi_res
#> # Resampling results
#> # 10-fold cross-validation using stratification 
#> # A tibble: 10 × 4
#>    splits             id     .metrics         .notes          
#>    <list>             <chr>  <list>           <list>          
#>  1 <split [6340/705]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]>
#>  2 <split [6340/705]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]>
#>  3 <split [6340/705]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]>
#>  4 <split [6340/705]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]>
#>  5 <split [6340/705]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]>
#>  6 <split [6340/705]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]>
#>  7 <split [6341/704]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]>
#>  8 <split [6341/704]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]>
#>  9 <split [6341/704]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]>
#> 10 <split [6342/703]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]>

Evaluating model performance

taxi_res %>%
  collect_metrics()
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.793    10 0.00293 Preprocessor1_Model1
#> 2 roc_auc  binary     0.809    10 0.00461 Preprocessor1_Model1

We can reliably measure performance using only the training data 🎉

Comparing metrics

How do the metrics from resampling compare to the metrics from training and testing?

taxi_res %>%
  collect_metrics() %>% 
  select(.metric, mean, n)
#> # A tibble: 2 × 3
#>   .metric   mean     n
#>   <chr>    <dbl> <int>
#> 1 accuracy 0.793    10
#> 2 roc_auc  0.809    10

The ROC AUC previously was

  • 0.87 for the training set
  • 0.81 for the test set

Remember that:

⚠️ the training set gives you overly optimistic metrics

⚠️ the test set is precious

Evaluating model performance

# Save the assessment set results
ctrl_taxi <- control_resamples(save_pred = TRUE)
taxi_res <- fit_resamples(taxi_wflow, taxi_folds, control = ctrl_taxi)

taxi_preds <- collect_predictions(taxi_res)
taxi_preds
#> # A tibble: 7,045 × 7
#>    id     .pred_yes .pred_no  .row .pred_class tip   .config             
#>    <chr>      <dbl>    <dbl> <int> <fct>       <fct> <chr>               
#>  1 Fold01    0.936    0.0638    10 yes         no    Preprocessor1_Model1
#>  2 Fold01    0.898    0.102     20 yes         no    Preprocessor1_Model1
#>  3 Fold01    0.898    0.102     47 yes         no    Preprocessor1_Model1
#>  4 Fold01    0.101    0.899     51 no          no    Preprocessor1_Model1
#>  5 Fold01    0.871    0.129     59 yes         no    Preprocessor1_Model1
#>  6 Fold01    0.0815   0.918     60 no          no    Preprocessor1_Model1
#>  7 Fold01    0.162    0.838     92 no          no    Preprocessor1_Model1
#>  8 Fold01    0.26     0.74      97 no          no    Preprocessor1_Model1
#>  9 Fold01    0.274    0.726     98 no          no    Preprocessor1_Model1
#> 10 Fold01    0.804    0.196    104 yes         no    Preprocessor1_Model1
#> # ℹ 7,035 more rows

Evaluating model performance

taxi_preds %>% 
  group_by(id) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
#> # A tibble: 30 × 4
#>    id     .metric  .estimator .estimate
#>    <chr>  <chr>    <chr>          <dbl>
#>  1 Fold01 accuracy binary         0.793
#>  2 Fold02 accuracy binary         0.8  
#>  3 Fold03 accuracy binary         0.786
#>  4 Fold04 accuracy binary         0.804
#>  5 Fold05 accuracy binary         0.796
#>  6 Fold06 accuracy binary         0.789
#>  7 Fold07 accuracy binary         0.793
#>  8 Fold08 accuracy binary         0.808
#>  9 Fold09 accuracy binary         0.783
#> 10 Fold10 accuracy binary         0.780
#> # ℹ 20 more rows

Where are the fitted models?

taxi_res
#> # Resampling results
#> # 10-fold cross-validation using stratification 
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .predictions
#>    <list>             <chr>  <list>           <list>           <list>      
#>  1 <split [6340/705]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  2 <split [6340/705]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  3 <split [6340/705]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  4 <split [6340/705]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  5 <split [6340/705]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  6 <split [6340/705]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  7 <split [6341/704]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  8 <split [6341/704]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#>  9 <split [6341/704]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
#> 10 <split [6342/703]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>

The fitted models themselves are not kept; they are discarded once the predictions and metrics are computed 🗑️
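
If you do need something from each fitted workflow, one option is the extract argument of control_resamples(). Here is a sketch (the object names are just for illustration) that keeps each resample’s parsnip fit in an .extracts list column:

ctrl_extract <- control_resamples(extract = extract_fit_parsnip)
taxi_res_fits <- fit_resamples(taxi_wflow, taxi_folds, control = ctrl_extract)

# one tibble of extracted fits per resample
taxi_res_fits$.extracts[[1]]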

Alternate resampling schemes

Bootstrapping

set.seed(3214)
bootstraps(taxi_train)
#> # Bootstrap sampling 
#> # A tibble: 25 × 2
#>    splits              id         
#>    <list>              <chr>      
#>  1 <split [7045/2561]> Bootstrap01
#>  2 <split [7045/2577]> Bootstrap02
#>  3 <split [7045/2648]> Bootstrap03
#>  4 <split [7045/2616]> Bootstrap04
#>  5 <split [7045/2616]> Bootstrap05
#>  6 <split [7045/2599]> Bootstrap06
#>  7 <split [7045/2654]> Bootstrap07
#>  8 <split [7045/2593]> Bootstrap08
#>  9 <split [7045/2624]> Bootstrap09
#> 10 <split [7045/2615]> Bootstrap10
#> # ℹ 15 more rows

The whole game - status update

Your turn

Create:

  • Monte Carlo Cross-Validation sets
  • validation set

(use the reference guide to find the function)

Don’t forget to set a seed when you resample!

05:00

Monte Carlo Cross-Validation

set.seed(322)
mc_cv(taxi_train, times = 10)
#> # Monte Carlo cross-validation (0.75/0.25) with 10 resamples  
#> # A tibble: 10 × 2
#>    splits              id        
#>    <list>              <chr>     
#>  1 <split [5283/1762]> Resample01
#>  2 <split [5283/1762]> Resample02
#>  3 <split [5283/1762]> Resample03
#>  4 <split [5283/1762]> Resample04
#>  5 <split [5283/1762]> Resample05
#>  6 <split [5283/1762]> Resample06
#>  7 <split [5283/1762]> Resample07
#>  8 <split [5283/1762]> Resample08
#>  9 <split [5283/1762]> Resample09
#> 10 <split [5283/1762]> Resample10

Validation set

set.seed(853)
validation_split(taxi_train, strata = tip)
#> # Validation Set Split (0.75/0.25)  using stratification 
#> # A tibble: 1 × 2
#>   splits              id        
#>   <list>              <chr>     
#> 1 <split [5283/1762]> validation

A validation set is just another type of resample
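
For example (a sketch, reusing the workflow from earlier; taxi_val is just a name for this sketch), a validation split drops straight into fit_resamples(), just like the cross-validation folds did:

set.seed(853)
taxi_val <- validation_split(taxi_train, strata = tip)

# one "fold": fit on the analysis portion, assess on the validation portion
fit_resamples(taxi_wflow, taxi_val)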

Decision tree 🌳

Random forest 🌳🌲🌴🌵🌴🌳🌳🌴🌲🌵🌴🌲🌳🌴🌳🌵🌵🌴🌲🌲🌳🌴🌳🌴🌲🌴🌵🌴🌲🌴🌵🌲🌵🌴🌲🌳🌴🌵🌳🌴🌳

Random forest 🌳🌲🌴🌵🌳🌳🌴🌲🌵🌴🌳🌵

  • Ensemble many decision tree models

  • All the trees vote! 🗳️

  • Bootstrap aggregating + random predictor sampling

  • Often works well without tuning hyperparameters (more on this tomorrow!), as long as there are enough trees

Create a random forest model

rf_spec <- rand_forest(trees = 1000, mode = "classification")
rf_spec
#> Random Forest Model Specification (classification)
#> 
#> Main Arguments:
#>   trees = 1000
#> 
#> Computational engine: ranger

Create a random forest model

rf_wflow <- workflow(tip ~ ., rf_spec)
rf_wflow
#> ══ Workflow ══════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: rand_forest()
#> 
#> ── Preprocessor ──────────────────────────────────────────────────────
#> tip ~ .
#> 
#> ── Model ─────────────────────────────────────────────────────────────
#> Random Forest Model Specification (classification)
#> 
#> Main Arguments:
#>   trees = 1000
#> 
#> Computational engine: ranger

Your turn

Use fit_resamples() and rf_wflow to:

  • keep predictions
  • compute metrics
08:00

Evaluating model performance

ctrl_taxi <- control_resamples(save_pred = TRUE)

# Random forest uses random numbers so set the seed first

set.seed(2)
rf_res <- fit_resamples(rf_wflow, taxi_folds, control = ctrl_taxi)
collect_metrics(rf_res)
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.813    10 0.00305 Preprocessor1_Model1
#> 2 roc_auc  binary     0.832    10 0.00513 Preprocessor1_Model1

How can we compare multiple model workflows at once?

Evaluate a workflow set

workflow_set(list(tip ~ .), list(tree_spec, rf_spec))
#> # A workflow set/tibble: 2 × 4
#>   wflow_id              info             option    result    
#>   <chr>                 <list>           <list>    <list>    
#> 1 formula_decision_tree <tibble [1 × 4]> <opts[0]> <list [0]>
#> 2 formula_rand_forest   <tibble [1 × 4]> <opts[0]> <list [0]>

Evaluate a workflow set

workflow_set(list(tip ~ .), list(tree_spec, rf_spec)) %>%
  workflow_map("fit_resamples", resamples = taxi_folds)
#> # A workflow set/tibble: 2 × 4
#>   wflow_id              info             option    result   
#>   <chr>                 <list>           <list>    <list>   
#> 1 formula_decision_tree <tibble [1 × 4]> <opts[1]> <rsmp[+]>
#> 2 formula_rand_forest   <tibble [1 × 4]> <opts[1]> <rsmp[+]>

Evaluate a workflow set

workflow_set(list(tip ~ .), list(tree_spec, rf_spec)) %>%
  workflow_map("fit_resamples", resamples = taxi_folds) %>%
  rank_results()
#> # A tibble: 4 × 9
#>   wflow_id          .config .metric  mean std_err     n preprocessor model  rank
#>   <chr>             <chr>   <chr>   <dbl>   <dbl> <int> <chr>        <chr> <int>
#> 1 formula_rand_for… Prepro… accura… 0.813 0.00339    10 formula      rand…     1
#> 2 formula_rand_for… Prepro… roc_auc 0.833 0.00528    10 formula      rand…     1
#> 3 formula_decision… Prepro… accura… 0.793 0.00293    10 formula      deci…     2
#> 4 formula_decision… Prepro… roc_auc 0.809 0.00461    10 formula      deci…     2

The first metric in the metric set is used for ranking. Use the rank_metric argument to change that.

Lots more available with workflow sets, like collect_metrics(), autoplot() methods, and more!
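
For example, a quick sketch of both ideas (taxi_wf_set is just a name for this sketch):

taxi_wf_set <- workflow_set(list(tip ~ .), list(tree_spec, rf_spec)) %>%
  workflow_map("fit_resamples", resamples = taxi_folds)

# rank by ROC AUC instead of the first metric in the metric set
rank_results(taxi_wf_set, rank_metric = "roc_auc")

# plot the resampled metrics for each workflow
autoplot(taxi_wf_set)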

Your turn

When do you think a workflow set would be useful?

03:00

The whole game - status update

The final fit

Suppose that we are happy with our random forest model.

Let’s fit the model on the training set and verify our performance using the test set.

We’ve shown you fit() and predict() (+ augment()) but there is a shortcut:

# taxi_split has train + test info
final_fit <- last_fit(rf_wflow, taxi_split) 

final_fit
#> # Resampling results
#> # Manual resampling 
#> # A tibble: 1 × 6
#>   splits              id               .metrics .notes   .predictions .workflow 
#>   <list>              <chr>            <list>   <list>   <list>       <list>    
#> 1 <split [7045/1762]> train/test split <tibble> <tibble> <tibble>     <workflow>

What is in final_fit?

collect_metrics(final_fit)
#> # A tibble: 2 × 4
#>   .metric  .estimator .estimate .config             
#>   <chr>    <chr>          <dbl> <chr>               
#> 1 accuracy binary         0.810 Preprocessor1_Model1
#> 2 roc_auc  binary         0.817 Preprocessor1_Model1

These are metrics computed with the test set

What is in final_fit?

collect_predictions(final_fit)
#> # A tibble: 1,762 × 7
#>    id               .pred_yes .pred_no  .row .pred_class tip   .config          
#>    <chr>                <dbl>    <dbl> <int> <fct>       <fct> <chr>            
#>  1 train/test split     0.732   0.268     10 yes         no    Preprocessor1_Mo…
#>  2 train/test split     0.827   0.173     29 yes         yes   Preprocessor1_Mo…
#>  3 train/test split     0.899   0.101     35 yes         yes   Preprocessor1_Mo…
#>  4 train/test split     0.914   0.0856    42 yes         yes   Preprocessor1_Mo…
#>  5 train/test split     0.911   0.0889    47 yes         no    Preprocessor1_Mo…
#>  6 train/test split     0.848   0.152     54 yes         yes   Preprocessor1_Mo…
#>  7 train/test split     0.580   0.420     59 yes         yes   Preprocessor1_Mo…
#>  8 train/test split     0.912   0.0876    62 yes         yes   Preprocessor1_Mo…
#>  9 train/test split     0.810   0.190     63 yes         yes   Preprocessor1_Mo…
#> 10 train/test split     0.960   0.0402    69 yes         yes   Preprocessor1_Mo…
#> # ℹ 1,752 more rows

What is in final_fit?

extract_workflow(final_fit)
#> ══ Workflow [trained] ════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: rand_forest()
#> 
#> ── Preprocessor ──────────────────────────────────────────────────────
#> tip ~ .
#> 
#> ── Model ─────────────────────────────────────────────────────────────
#> Ranger result
#> 
#> Call:
#>  ranger::ranger(x = maybe_data_frame(x), y = y, num.trees = ~1000,      num.threads = 1, verbose = FALSE, seed = sample.int(10^5,          1), probability = TRUE) 
#> 
#> Type:                             Probability estimation 
#> Number of trees:                  1000 
#> Sample size:                      7045 
#> Number of independent variables:  6 
#> Mtry:                             2 
#> Target node size:                 10 
#> Variable importance mode:         none 
#> Splitrule:                        gini 
#> OOB prediction error (Brier s.):  0.1373147

Use this fitted workflow for prediction on new data, like when deploying the model
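
For example, a sketch where new_rides is a hypothetical data frame with the same predictor columns as taxi_train (distance, company, local, dow, month, hour):

deployed_wflow <- extract_workflow(final_fit)

predict(deployed_wflow, new_data = new_rides)
augment(deployed_wflow, new_data = new_rides)  # predictions plus the original columns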

The whole game

Your turn

End of the day discussion!

Which model do you think you would decide to use?

What surprised you the most?

What is one thing you are looking forward to for tomorrow?

05:00