Overview:

The COVID-19 vaccine offers the best protection against severe illness and the best means of preventing future outbreaks; however, its efficacy depends on widespread uptake. Currently, the national vaccination rate approaches the ‘guide’ for herd immunity, but 45.8% of the population has yet to be fully vaccinated. Beyond the interests of public health, individuals resistant to the idea of getting the vaccine could benefit from policies that encourage its uptake. Using a predictive machine learning model that predicts vaccination rates by county, we uncover which factors appear to influence vaccine uptake the most. Understanding the sociological background of those unwilling to obtain the vaccine can aid policymakers in designing incentive programs for those groups, as well as offer information for future pandemics.

Methods:

We decided to use regression-based methods of varying complexity to predict the vaccination rate. We selected Ordinary Least Squares (OLS), Elastic Net, K-Nearest Neighbor, and Random Forest models. To tune hyperparameters, we used 5-fold cross-validation and tuned the following hyperparameters: penalty, mixture, neighbors, mtry, and min_n. The cross-validated RMSE selected the following values (ordered as above): 0.01, 1, 30, 8, and 6. In most cases, the MAE or R² selected the same hyperparameter values.

Setting up the data

Setting up: Vaccination Data

# Load CDC COVID-19 vaccination data for US counties over the course of the
# pandemic. The file is updated weekly and includes all past entries, so it is
# narrowed to a single snapshot date below.
vaccine <- read_csv("county-vaccination.csv", show_col_types = FALSE)

# Rename variables for ease of use
vaccine %<>%
  rename(
    "County" = "Recip_County",
    "State" = "Recip_State",
    "Population" = "Census2019",
    "Vax12Plus" = "Series_Complete_12PlusPop_Pct",
    "Population18" = "Census2019_18PlusPop",
    "Population65" = "Census2019_65PlusPop"
  )

# Isolate one set of observations: 1/1/22
vaccine <- subset(vaccine, Date == "1/1/22")

# Remove counties with an unknown FIPS code (FIPS is the merge key used for
# every later join)
vaccine <- subset(vaccine, FIPS != "UNK")

# Remove rows with missing vaccine data (rows where Vax12Plus is 0; subset()
# also drops rows where the comparison is NA)
vaccine <- subset(vaccine, Vax12Plus != 0)

# Remove Guam, Puerto Rico, and the Virgin Islands.
# `%in%` checks every row against the whole set. The previous
# `State != c("GU", "PR", "VI")` recycled the three codes down the column, so
# a territory row was only removed when it happened to line up with its own
# code. The is.na() guard keeps the earlier behaviour of dropping rows whose
# State is missing.
vaccine <- subset(vaccine, !is.na(State) & !(State %in% c("GU", "PR", "VI")))

# Left-pad FIPS codes with "0" to the standard 5 characters
vaccine$FIPS <- str_pad(vaccine$FIPS, width = 5, side = "left", pad = "0")

# Create a dummy variable for Metro area (1 = metro, 0 = non-metro).
# There is only 1 missing value for Metro_status, and that county is in a
# metro area, so the missing value is coded as 1.
# The NA has to be tested explicitly: `NA == "Metro"` evaluates to NA, so a
# plain comparison inside ifelse() returns NA for that row and never reaches
# the fallback value.
vaccine %<>% mutate(
  Metro = ifelse(!is.na(Metro_status) & Metro_status == "Non-metro", 0, 1)
)

# Derive the age-structure shares (in percent, rounded to 2 digits) from the
# census counts, then keep only the columns used downstream.
vaccine <- vaccine %>%
  mutate(
    # Percent under 18: the complement of the 18-and-over share
    PopUnder18 = round(100 - ((Population18 / Population) * 100), digits = 2),
    # Percent aged 65 and over (a high-risk group)
    Pop65Plus = round((Population65 / Population) * 100, digits = 2),
    # Percent between 18 and 65: what is left after the two rounded shares
    Pop18to65 = round(100 - PopUnder18 - Pop65Plus, digits = 2)
  ) %>%
  subset(select = c("FIPS", "County", "State", "Vax12Plus", "Metro",
                    "Population", "Pop18to65", "Pop65Plus", "PopUnder18"))

Setting up: Education Data

# Read the education dataset from the USDA Economic Research Service,
# left-pad its FIPS codes with "0" to 5 characters, and drop its County
# column, which is formatted differently from the one in the vaccine data.
education <- read_csv("Education.csv", show_col_types = FALSE) %>%
  mutate(FIPS = str_pad(FIPS, width = 5, side = "left", pad = "0")) %>%
  select(-c(County))

# Merge the vaccine data and the education data by FIPS code and State
df <- merge(vaccine, education, by = c("FIPS", "State"))

Setting up: Work From Home Data

# Read the work-from-home dataset from the USDA Economic Research Service.
# The source file had no FIPS column; one was added by hand in Excel.
# FIPS is left-padded with "0" to 5 characters, population density is rounded
# to 2 digits, and only the columns needed downstream are kept.
WFH <- read_csv("WFH.csv", show_col_types = FALSE) %>%
  mutate(
    FIPS = str_pad(FIPS, width = 5, side = "left", pad = "0"),
    PopDensity = round(PopDensity, digits = 2)
  ) %>%
  subset(select = c("FIPS", "PopDensity", "WFH"))

# Attach the work-from-home columns to df, matching on FIPS
df <- merge(df, WFH, by = "FIPS")

# Create the percent of the 18-to-65 population that works from home.
# Pop18to65 is a percentage of Population (it is computed as
# 100 - PopUnder18 - Pop65Plus), so it is converted back to a head count
# before being used as the denominator; dividing by the percentage directly
# does not yield a share of people.
# NOTE(review): assumes the WFH column holds a count of people - confirm
# against WFH.csv.
df %<>% mutate(WFH = (WFH / (Population * Pop18to65 / 100)) * 100)

Setting up: Poverty Data

# Read the poverty dataset from the USDA Economic Research Service, drop its
# state and county code columns (formatted differently from the rest of the
# data), and left-pad FIPS with "0" to 5 characters.
poverty <- read_csv("Poverty.csv", show_col_types = FALSE) %>%
  select(-c(STATEFP, COUNTYFP)) %>%
  mutate(FIPS = str_pad(FIPS, width = 5, side = "left", pad = "0"))

# Attach the poverty columns to df, matching on FIPS
df <- merge(df, poverty, by = "FIPS")

Setting up: Presidential Election Data

Presidential Election Data - 2012

# Read the county-level presidential returns. The file covers the elections
# from 2000 to 2020 with one or more rows per county and party; only the
# elections within the past 10 years are used further down.
# In one pass: drop the columns that are not needed, rename the state and
# county identifiers to match df, and left-pad FIPS with "0" to 5 characters.
electionsall <- read_csv("countypres_2000-2020.csv", show_col_types = FALSE) %>%
  select(-c(state, county_name, office, version, candidate, mode)) %>%
  rename("State" = "state_po", "FIPS" = "county_fips") %>%
  mutate(FIPS = str_pad(FIPS, width = 5, side = "left", pad = "0"))

# Rows for the 2012 election only
election2012 <- subset(electionsall, year == "2012")

# Democrat share of the total votes cast in 2012 (percent, 2 digits); the raw
# vote columns are dropped so only FIPS, State and the share are merged in
dem2012 <- election2012 %>%
  subset(party == "DEMOCRAT") %>%
  mutate(dem2012pct = round(candidatevotes / totalvotes * 100, digits = 2)) %>%
  select(-c(year, totalvotes, candidatevotes, party))
df <- merge(df, dem2012, by = c("FIPS", "State"))

# Republican share of the total votes cast in 2012, built the same way
rep2012 <- election2012 %>%
  subset(party == "REPUBLICAN") %>%
  mutate(rep2012pct = round(candidatevotes / totalvotes * 100, digits = 2)) %>%
  select(-c(year, totalvotes, candidatevotes, party))
df <- merge(df, rep2012, by = c("FIPS", "State"))

# Dummy variable: 1 when the Democrat share exceeded the Republican share
df %<>% mutate(dem2012 = ifelse(dem2012pct > rep2012pct, 1, 0))

Presidential Election Data - 2016

# Rows for the 2016 election only
election2016 <- subset(electionsall, year == "2016")

# Democrat share of the total votes cast in 2016 (percent, 2 digits); the raw
# vote columns are dropped so only FIPS, State and the share are merged in
dem2016 <- election2016 %>%
  subset(party == "DEMOCRAT") %>%
  mutate(dem2016pct = round(candidatevotes / totalvotes * 100, digits = 2)) %>%
  select(-c(year, totalvotes, candidatevotes, party))
df <- merge(df, dem2016, by = c("FIPS", "State"))


# Republican share of the total votes cast in 2016, built the same way
rep2016 <- election2016 %>%
  subset(party == "REPUBLICAN") %>%
  mutate(rep2016pct = round(candidatevotes / totalvotes * 100, digits = 2)) %>%
  select(-c(year, totalvotes, candidatevotes, party))
df <- merge(df, rep2016, by = c("FIPS", "State"))

# Dummy variable: 1 when the Democrat share exceeded the Republican share
df %<>% mutate(dem2016 = ifelse(dem2016pct > rep2016pct, 1, 0))

Presidential Election Data - 2020

# Subset 2020 elections
election2020 <- subset(electionsall, year == "2020")

# Within the 2020 elections, several counties have multiple lines for their
# results. Collapse them to one row per FIPS: candidate votes are summed
# across the lines, while totalvotes and State are taken from the county's
# first line.
# Using first() makes summarise() return exactly one row per county, so the
# separate "remove duplicated FIPS" step is no longer needed; the values kept
# are the same ones that step used to keep. Only the first operator of a
# chain should be the assignment pipe, so the later steps use %>%.
dem2020 <- election2020 %>%
  subset(party == "DEMOCRAT") %>%
  group_by(FIPS) %>%
  dplyr::summarise(
    candidatevotes = sum(candidatevotes),
    totalvotes = dplyr::first(totalvotes),
    State = dplyr::first(State)
  ) %>%
  as.data.frame()

# Percent of total votes cast for Democrats in the 2020 election (2 digits)
dem2020 %<>% mutate(
  dem2020pct = round(candidatevotes / totalvotes * 100, digits = 2)
)

# Remove the raw vote columns; only FIPS, State and the share are merged in
dem2020 <- select(dem2020, -c(totalvotes, candidatevotes))

# Merge to df by FIPS and State
df <- merge(df, dem2020, by = c("FIPS", "State"))


# Republican rows of the 2020 election. Several counties have multiple lines
# for their results, so they are collapsed to one row per FIPS: candidate
# votes are summed across the lines, while totalvotes and State are taken
# from the county's first line.
# Using first() makes summarise() return exactly one row per county, so the
# separate "remove duplicated FIPS" step is no longer needed; the values kept
# are the same ones that step used to keep. Only the first operator of a
# chain should be the assignment pipe, so the later steps use %>%.
rep2020 <- election2020 %>%
  subset(party == "REPUBLICAN") %>%
  group_by(FIPS) %>%
  dplyr::summarise(
    candidatevotes = sum(candidatevotes),
    totalvotes = dplyr::first(totalvotes),
    State = dplyr::first(State)
  ) %>%
  as.data.frame()

# Percent of total votes cast for Republicans in the 2020 election (2 digits)
rep2020 %<>% mutate(
  rep2020pct = round(candidatevotes / totalvotes * 100, digits = 2)
)

# Remove the raw vote columns; only FIPS, State and the share are merged in
rep2020 <- select(rep2020, -c(totalvotes, candidatevotes))

# Merge to df by FIPS and State
df <- merge(df, rep2020, by = c("FIPS", "State"))

# Dummy variable: 1 when the Democrat share exceeded the Republican share
df %<>% mutate(dem2020 = ifelse(dem2020pct > rep2020pct, 1, 0))

Presidential Election Data - Voting Patterns Over Time

# Summarise voting patterns over the three elections, drop the per-election
# shares (and Pop18to65) that are no longer needed, and discard rows where the
# average Democrat share could not be computed.
df <- df %>%
  mutate(
    # Number of the three elections with a Democrat majority
    pastvoting_dem = round(dem2012 + dem2016 + dem2020, digits = 2),
    # Average percent of votes Republican in the past 3 presidential elections
    avgpct_rep = round((rep2012pct + rep2016pct + rep2020pct) / 3, digits = 2),
    # Average percent of votes Democrat in the past 3 presidential elections
    avgpct_dem = round((dem2012pct + dem2016pct + dem2020pct) / 3, digits = 2)
  ) %>%
  subset(select = -c(dem2012pct, dem2016pct, dem2020pct, rep2012pct,
                     rep2016pct, rep2020pct, Pop18to65)) %>%
  filter(!is.na(avgpct_dem))

# The assembled dataset is used under the name `covid` from here on
covid <- df

# Write new CSV so the set-up code does not have to be run every time
# write.csv(df,"524Project.csv", row.names = FALSE)

Exploring Data

Exploring Data: Map

# County COVID-19 vaccination rates across the US, shown as above/below a
# fixed cut-off.
# FIPS is renamed to `fips` for plot_usmap() (presumably the column name it
# looks for - confirm), and every county is labelled by whether its rate is
# at least 54.3.
# NOTE(review): 54.3 is hard-coded; presumably the median of Vax12Plus at the
# time of the analysis - confirm.
covid_map <- covid %>%
  rename(fips = FIPS) %>%
  mutate(
    above_med = ifelse(Vax12Plus >= 54.3, "Above Median", "Below Median")
  )

# Map the two groups
plot_usmap(data = covid_map, values = "above_med") +
  scale_fill_manual(
    name = "Median Vaccination Rate (Jan. 2022)",
    values = c("Above Median" = "dark cyan", "Below Median" = "pink")
  ) +
  ggtitle("COVID-19 Median Vaccination Rates by County") +
  theme(
    legend.position = "right",
    legend.text = element_text(size = 14),
    legend.title = element_text(size = 16),
    panel.background = element_rect(fill = "lavender"),
    plot.title = element_text(
      hjust = 0.5,
      size = 20,
      margin = margin(t = 10, b = 10)
    )
  )

Exploring Data: Correlation Heat Map

# Correlation heat map
# Columns 1-3 and 17-19 are dropped by position before cor() is called.
# NOTE(review): positional selection depends on the column order produced by
# the merges - confirm these are the columns that should be excluded.
covid1 <- subset(covid, select = -c(1, 2, 3, 17, 18, 19))

# Pairwise correlations, rounded to 2 digits
cormat <- round(cor(covid1), 2)

# reshape2 provides melt()
p_load(reshape2)

# Blank out the lower triangle of a correlation matrix so that each pair of
# variables appears only once. Returns the matrix with those cells set to NA.
get_upper_tri <- function(cormat) {
  cormat[lower.tri(cormat)] <- NA
  cormat
}

# Long format (Var1, Var2, value) without the blanked cells
upper_tri <- get_upper_tri(cormat)
melted_cormat <- melt(upper_tri, na.rm = TRUE)

# Plot heatmap
ggplot(data = melted_cormat, aes(Var2, Var1, fill = value)) +
  geom_tile(color = "white") +
  coord_fixed() +
  scale_fill_gradient2(
    low = "dark cyan", mid = "white", high = "indianred1",
    midpoint = 0, limit = c(-1, 1), space = "Lab",
    name = "Correlation"
  ) +
  theme_minimal() +
  labs(
    x = "Variables", y = "Variables",
    title = "Correlation Matrix Heatmap"
  ) +
  theme(
    axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1),
    plot.title = element_text(
      hjust = 0.5,
      size = 14,
      margin = margin(t = 10, b = 10)
    )
  )

Working with the data

Split & Recipe

# Fix the random-number stream so the split and the folds are reproducible
set.seed(1)

# Drop columns 2 and 3 by position (the county and state columns)
covid <- covid %>% subset(select = -c(2, 3))

# 80/20 train/test split, stratified on the outcome
covid_split <- initial_split(covid, prop = 0.8, strata = Vax12Plus)

# Separate training and testing data
covid_train <- training(covid_split)
covid_test <- testing(covid_split)

# Define the recipe: Vax12Plus is the outcome and every other column starts
# out as a predictor.
covid_recipe <- recipe(Vax12Plus ~ ., data = covid_train) %>%
  # FIPS only identifies the county, so it is given an id role, not predictor
  update_role(FIPS, new_role = "id variable") %>%
  # Normalize numeric predictors
  step_normalize(all_predictors() & all_numeric()) %>%
  # KNN imputation for categorical predictors (noted as unnecessary by the
  # original author)
  step_impute_knn(all_predictors() & all_nominal(), neighbors = 5) %>%
  # Create dummies for categorical predictors
  step_dummy(all_predictors() & all_nominal())

# Define the 5-fold cross-validation split of the training data
covid_cv <- vfold_cv(covid_train, v = 5)

Linear Regression

# Set seed
set.seed(1)

# Ordinary least squares via lm()
model_linear <- linear_reg() %>%
  set_engine("lm")

# Workflow = model + shared preprocessing recipe
workflow_linear <- workflow() %>%
  add_model(model_linear) %>%
  add_recipe(covid_recipe)

# Fit the cross-validated samples and compute rmse, rsq and mae.
# The workflow already carries the model and the recipe, so it is the only
# object handed to fit_resamples(). Passing `object =` and `preprocessor =` as
# well pushed the piped workflow into `...`, where it was not used.
cv_linear <- workflow_linear %>%
  fit_resamples(
    resamples = covid_cv,
    metrics = yardstick::metric_set(
      yardstick::rmse, yardstick::rsq, yardstick::mae
    )
  )

# Examine metrics
cv_linear %>% collect_metrics()

# OLS has no hyperparameters to tune, so the workflow is already final
final_lin <- workflow_linear

# last_fit() fits the workflow on the training portion of the split and
# evaluates it on the test portion. The separate fit on covid_train that used
# to precede it was overwritten immediately, so it has been removed.
final_fit_lin <- final_lin %>% last_fit(covid_split)

# Get test-sample predictions
final_fit_lin %>% collect_predictions() %>% head()

Linear Regression Results

# Test-set metrics of the final linear regression fit
# (the ## lines are the output as recorded in the source document)
collect_metrics(final_fit_lin)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       9.63  Preprocessor1_Model1
## 2 rsq     standard       0.520 Preprocessor1_Model1

Elasticnet

# Set seed
set.seed(1)

# Candidate values: mixture (alpha) from 0 to 1 in steps of 0.1, and 100
# penalties (lambda) spaced evenly on the log10 scale from 10^2 down to 10^-3
alphas <- seq(from = 0, to = 1, by = 0.1)
lambdas <- 10^seq(from = 2, to = -3, length.out = 100)

# Define the elasticnet model with both hyperparameters left to tune
model_net <-
  linear_reg(penalty = tune(), mixture = tune()) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

# Workflow = model + shared preprocessing recipe
workflow_net <- workflow() %>%
  add_model(model_net) %>%
  add_recipe(covid_recipe)

# Cross-validate elasticnet over every combination of alphas and lambdas
cv_net <-
  workflow_net %>%
  tune_grid(
    covid_cv,
    grid = expand_grid(mixture = alphas, penalty = lambdas),
    metrics = yardstick::metric_set(
      yardstick::rmse, yardstick::rsq, yardstick::mae
    )
  )

# Show the best hyperparameter values for each metric.
# show_best() is used for all three; collect_metrics() has no `metric`
# argument, so the earlier MAE call did not filter by metric.
cv_net %>% show_best(metric = "rmse", n = 3)
cv_net %>% show_best(metric = "rsq", n = 3)
cv_net %>% show_best(metric = "mae", n = 3)

# Elasticnet chooses lasso for this and penalty = 0.001 chosen by RMSE and RSQ


# Finalize workflow with the best-RMSE hyperparameters
final_net <-
  workflow_net %>%
  finalize_workflow(select_best(cv_net, metric = "rmse"))

# last_fit() fits the final workflow on the training portion of the split and
# evaluates it on the test portion. The separate fit on covid_train that used
# to precede it was overwritten immediately, so it has been removed.
final_fit_net <- final_net %>% last_fit(covid_split)

# Get test-sample predictions
final_fit_net %>% collect_predictions() %>% head()

Elasticnet Results

# Test-set metrics of the final elasticnet fit
# (the ## lines are the output as recorded in the source document)
collect_metrics(final_fit_net)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       9.63  Preprocessor1_Model1
## 2 rsq     standard       0.521 Preprocessor1_Model1

Elasticnet Graph of Metrics

# Graphs of the elasticnet cross-validation metrics.
# autoplot() draws the tuned parameter on the x axis and the metric on the
# y axis, so the axis labels follow that orientation (they were previously
# swapped).

# Look shared by both plots: black-and-white base theme, lavender background,
# centred title, and padded axis text and titles
net_plot_theme <- theme_bw() +
  theme(
    plot.background = element_rect(fill = "lavender"),
    plot.title = element_text(
      hjust = 0.5,
      size = 14,
      margin = margin(t = 10, b = 10)
    ),
    axis.text.x = element_text(
      color = "black",
      size = 10,
      margin = margin(t = 5, b = 5)
    ),
    axis.text.y = element_text(
      color = "black",
      size = 10,
      margin = margin(r = 5, l = 5)
    ),
    axis.title = element_text(
      size = 13,
      margin = margin(t = 12, r = 12, b = 12, l = 12)
    )
  )

# RMSE, with x restricted to [0, 1] and y to [10.25, 10.8]
autoplot(cv_net, metric = "rmse") +
  scale_x_continuous(limits = c(0, 1)) +
  scale_y_continuous(limits = c(10.25, 10.8)) +
  labs(x = "Amount of Regularization", y = "RMSE",
       title = "Elasticnet Results - Amount of Regularization vs. RMSE") +
  net_plot_theme

# MAE, with x restricted to [0, 1] and y to [7.3, 7.9]
autoplot(cv_net, metric = "mae") +
  scale_x_continuous(limits = c(0, 1)) +
  scale_y_continuous(limits = c(7.3, 7.9)) +
  labs(x = "Amount of Regularization", y = "MAE",
       title = "Elasticnet Results - Amount of Regularization vs. MAE") +
  net_plot_theme

Random Forest

# Fix the random-number stream for the random forest section
set.seed(1)

# Random forest with 200 trees; mtry and min_n are left to be tuned.
# Impurity-based importance is requested from ranger so that variable
# importance can be read from the fitted model later.
tune_spec <- rand_forest(mtry = tune(), trees = 200, min_n = tune()) %>%
  set_mode("regression") %>%
  set_engine("ranger", importance = "impurity")

# Workflow = shared preprocessing recipe + model
tune_wf <- workflow() %>%
  add_recipe(covid_recipe) %>%
  add_model(tune_spec)

# Regular grid: 5 levels each of mtry in [4, 10] and min_n in [2, 8]
rf_grid <- grid_regular(
  mtry(range = c(4, 10)),
  min_n(range = c(2, 8)),
  levels = 5
)

# Cross-validate every grid point on the 5 folds
regular_res <- tune_wf %>%
  tune_grid(
    resamples = covid_cv,
    grid = rf_grid,
    metrics = yardstick::metric_set(
      yardstick::rmse, yardstick::rsq, yardstick::mae
    )
  )

# Best hyperparameter combinations according to each metric
show_best(regular_res, metric = "rmse")
show_best(regular_res, metric = "rsq")
show_best(regular_res, metric = "mae")

# Final workflow with the best-RMSE hyperparameters filled in
final_rf <- finalize_workflow(
  tune_wf,
  select_best(regular_res, metric = "rmse")
)

# Fit on the training data. This result is replaced by last_fit() on the next
# statement; the call is kept as in the original because removing it may
# shift the random-number stream used by the fit that follows.
final_fit_rf <- fit(final_rf, data = covid_train)

# Fit on the training portion of the split and evaluate on the test portion
final_fit_rf <- last_fit(final_rf, covid_split)

# Get test-sample predictions
final_fit_rf %>%
  collect_predictions() %>%
  head()

Random Forest Results

# Test-set metrics of the final random forest fit
# (the ## lines are the output as recorded in the source document)
collect_metrics(final_fit_rf)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       9.17  Preprocessor1_Model1
## 2 rsq     standard       0.569 Preprocessor1_Model1

Random Forest Graph of Metrics

# Graphs of the random forest cross-validation metrics.
# autoplot() draws the tuned parameter on the x axis and the metric on the
# y axis, so the axis labels follow that orientation (they were previously
# swapped).

# Look shared by both plots: black-and-white base theme, lavender background,
# centred title, and padded axis text and titles
rf_plot_theme <- theme_bw() +
  theme(
    plot.background = element_rect(fill = "lavender"),
    plot.title = element_text(
      hjust = 0.5,
      size = 14,
      margin = margin(t = 10, b = 10)
    ),
    axis.text.x = element_text(
      color = "black",
      size = 10,
      margin = margin(t = 5, b = 5)
    ),
    axis.text.y = element_text(
      color = "black",
      size = 10,
      margin = margin(r = 5, l = 5)
    ),
    axis.title = element_text(
      size = 13,
      margin = margin(t = 12, r = 12, b = 12, l = 12)
    )
  )

# RMSE
autoplot(regular_res, metric = "rmse") +
  labs(x = "# of Randomly Selected Predictors", y = "RMSE",
       title = "Random Forest Results - # of Predictors vs. RMSE") +
  rf_plot_theme

# MAE
autoplot(regular_res, metric = "mae") +
  labs(x = "# of Randomly Selected Predictors", y = "MAE",
       title = "Random Forest Results - # of Predictors vs. MAE") +
  rf_plot_theme

Random Forest Variable Importance

# Variable importance
p_load("vip")

# Plot the 20 most important features of the final random forest.
# The model is taken from the workflow stored in the last_fit() result, which
# was fitted on the training portion of the split.
# extract_fit_parsnip() replaces the deprecated pull_workflow_fit().
final_fit_rf %>%
  pluck(".workflow", 1) %>%
  extract_fit_parsnip() %>%
  vip(num_features = 20)

KNN

# Fix the random-number stream for the KNN section
set.seed(1)

# K-nearest-neighbours regression; the number of neighbours is tuned
model_knn <- nearest_neighbor(neighbors = tune()) %>%
  set_mode("regression") %>%
  set_engine("kknn")

# Workflow = model + shared preprocessing recipe
workflow_knn <- workflow() %>%
  add_model(model_knn) %>%
  add_recipe(covid_recipe)

# Cross-validate k = 1, 5, 10, 20, ..., 100 on the 5 folds
fit_knn_cv <- tune_grid(
  workflow_knn,
  covid_cv,
  grid = data.frame(neighbors = c(1, 5, seq(10, 100, 10))),
  metrics = yardstick::metric_set(
    yardstick::rmse, yardstick::rsq, yardstick::mae
  )
)

# Three best values of k according to each metric
show_best(fit_knn_cv, metric = "rmse", n = 3)
show_best(fit_knn_cv, metric = "rsq", n = 3)
show_best(fit_knn_cv, metric = "mae", n = 3)

# Final workflow with the best-RMSE number of neighbours filled in
final_knn <- finalize_workflow(
  workflow_knn,
  select_best(fit_knn_cv, metric = "rmse")
)

# Fit on the training data (this result is replaced by last_fit() below)
final_fit_knn <- fit(final_knn, data = covid_train)

# Fit on the training portion of the split and evaluate on the test portion
final_fit_knn <- last_fit(final_knn, covid_split)

# Get test-sample predictions
final_fit_knn %>%
  collect_predictions() %>%
  head()

KNN Results

# Test-set metrics of the final KNN fit
# (the ## lines are the output as recorded in the source document)
collect_metrics(final_fit_knn)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       9.42  Preprocessor1_Model1
## 2 rsq     standard       0.549 Preprocessor1_Model1

KNN Graph of Metrics

# Graph the KNN cross-validation metrics averaged over the folds, one panel
# per metric. No colour aesthetic is mapped in this plot, so the colour scale
# that was attached to it had no effect and has been dropped.
fit_knn_cv %>%
  collect_metrics(summarize = TRUE) %>%
  ggplot(aes(x = neighbors, y = mean)) +
  geom_line(size = 0.7, alpha = 0.6) +
  geom_point(size = 2.5) +
  facet_wrap(~ toupper(.metric), scales = "free", nrow = 1) +
  scale_x_continuous("Neighbors (k)", labels = scales::label_number()) +
  scale_y_continuous("Estimate") +
  theme_minimal() +
  theme(legend.position = "bottom")

# Show each fold's metrics separately, with one coloured line per fold (`id`)
fit_knn_cv %>%
  collect_metrics(summarize = FALSE) %>%
  ggplot(aes(x = neighbors, y = .estimate, color = id)) +
  geom_line(size = 0.7, alpha = 0.6) +
  geom_point(size = 2.5) +
  facet_wrap(~ toupper(.metric), scales = "free", nrow = 1) +
  scale_x_continuous("Neighbors (k)", labels = scales::label_number()) +
  scale_y_continuous("Estimate") +
  scale_color_viridis_d("CV Folds:") +
  theme_minimal() +
  theme(legend.position = "bottom")