Overview:
The COVID-19 vaccine offers the best protection against severe illness and the best chance of preventing future outbreaks; however, its effectiveness depends on widespread uptake. The national vaccination rate currently approaches the level commonly cited as a guide for herd immunity, but 45.8% of the population has yet to be fully vaccinated. Beyond the public health interest, individuals resistant to getting the vaccine could themselves benefit from policies that encourage uptake. Using predictive machine learning models, we identify which factors appear to influence vaccine uptake most by predicting vaccination rates by county. Understanding the sociological background of those unwilling to obtain the vaccine can help policymakers design incentive programs for those groups, and it offers information for future pandemics.
Methods:
We used regression-based methods of varying complexity to predict vaccination rates, selecting Ordinary Least Squares (OLS), Elastic Net, K-Nearest Neighbors (KNN), and Random Forest models. Hyperparameters (penalty, mixture, neighbors, mtry, and min_n) were tuned with 5-fold cross-validation. The cross-validated RMSE selected the following values (ordered as above): 0.01, 1, 30, 8, and 6. In most cases, the MAE or R² selected the same hyperparameter values.
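All four models follow the same tidymodels tuning pattern. As a compact, hedged sketch of that pattern (covid_recipe, covid_cv, and covid_split are defined in the Split & Recipe section below; the elastic net is used here only as the example, and the example_* names are placeholders rather than objects from the project):

# Sketch of the shared tuning workflow: define a model with tune() placeholders,
# cross-validate over a grid, then keep the RMSE-best candidate
example_model = linear_reg(penalty = tune(), mixture = tune()) %>%
  set_engine("glmnet") %>%
  set_mode("regression")
example_wf = workflow() %>%
  add_model(example_model) %>%
  add_recipe(covid_recipe)
example_cv = example_wf %>%
  tune_grid(resamples = covid_cv,
            grid = expand_grid(penalty = 10^seq(2, -3, length.out = 20),
                               mixture = seq(0, 1, by = 0.1)),
            metrics = yardstick::metric_set(yardstick::rmse, yardstick::rsq, yardstick::mae))
example_fit = example_wf %>%
  finalize_workflow(select_best(example_cv, metric = "rmse")) %>%
  last_fit(covid_split)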
Setting up the data
Setting up: Vaccination Data
# Load CDC COVID-19 vaccination data for US counties over the course of the pandemic.
# This data is updated weekly and the file includes all past entries.
vaccine = read_csv("county-vaccination.csv", show_col_types = FALSE)

# Rename variables for ease
vaccine %<>%
  rename(
    "County" = "Recip_County",
    "State" = "Recip_State",
    "Population" = "Census2019",
    "Vax12Plus" = "Series_Complete_12PlusPop_Pct",
    "Population18" = "Census2019_18PlusPop",
    "Population65" = "Census2019_65PlusPop"
  )

# Subset the data to isolate one set of observations - 1/1/22
vaccine = subset(vaccine, Date == '1/1/22')

# Remove counties with missing FIPS codes (essential for identification)
vaccine = subset(vaccine, FIPS != "UNK")

# Remove rows with missing vaccine data (coded as 0)
vaccine = subset(vaccine, Vax12Plus != 0)

# Remove Guam, Puerto Rico, and the Virgin Islands
# (%in% is required here; != with a vector would recycle rather than exclude all three)
vaccine = subset(vaccine, !(State %in% c("GU", "PR", "VI")))

# Pad FIPS codes to 5 digits with leading zeros
vaccine$FIPS = str_pad(vaccine$FIPS, width = 5, side = "left", pad = "0")
# Create a dummy variable for metro area.
# There is only one missing value for Metro_status; since that county is in a
# metro area, the missing value is coded as 1.
vaccine %<>% mutate(Metro = ifelse(is.na(Metro_status), 1,
                            ifelse(Metro_status == 'Metro', 1, 0)))

# Create a variable for the percent of the population under 18
# by finding the percent of the population 18 and over and taking the difference
vaccine %<>% mutate(PopUnder18 = 100 - (Population18 / Population) * 100)
# Round to 2 digits
vaccine$PopUnder18 = round(vaccine$PopUnder18, digits = 2)

# Create a variable for the percent of the population over 65 (a high-risk group)
vaccine %<>% mutate(Pop65Plus = (Population65 / Population) * 100)
# Round to 2 digits
vaccine$Pop65Plus = round(vaccine$Pop65Plus, digits = 2)

# Create a variable for the percent of the population between 18 and 65
vaccine %<>% mutate(Pop18to65 = 100 - PopUnder18 - Pop65Plus)
# Round to 2 digits
vaccine$Pop18to65 = round(vaccine$Pop18to65, digits = 2)

# Subset the data with only the relevant variables
vaccine = subset(vaccine,
                 select = c("FIPS", "County", "State", "Vax12Plus", "Metro",
                            "Population", "Pop18to65", "Pop65Plus", "PopUnder18"))
Setting up: Education Data
# Load the education dataset from the USDA Economic Research Service
education = read_csv("Education.csv", show_col_types = FALSE)

# Pad FIPS codes
education$FIPS = str_pad(education$FIPS, width = 5, side = "left", pad = "0")

# Remove County, as it is formatted differently than in the vaccine data
education = select(education, -c(County))

# Merge the vaccine data and the education data by FIPS code and State
df = merge(vaccine, education, by = c("FIPS", "State"))
Setting up: Work From Home Data
# Load the work-from-home dataset from the USDA Economic Research Service
WFH = read_csv("WFH.csv", show_col_types = FALSE)

# Pad FIPS codes. The raw file did not have a FIPS column - it was created in Excel.
WFH$FIPS = str_pad(WFH$FIPS, width = 5, side = "left", pad = "0")

# Round population density
WFH$PopDensity = round(WFH$PopDensity, digits = 2)

# Subset with relevant variables
WFH = subset(WFH, select = c("FIPS", "PopDensity", "WFH"))

# Merge to df by FIPS
df = merge(df, WFH, by = c("FIPS"))

# Create the percent of the population that works from home
df %<>% mutate(WFH = (WFH / Pop18to65) * 100)
Setting up: Poverty Data
# Load the poverty dataset from the USDA Economic Research Service
poverty = read_csv("Poverty.csv", show_col_types = FALSE)

# Remove the state and county codes, as those variables are formatted differently
poverty = select(poverty, -c(STATEFP, COUNTYFP))

# Pad FIPS codes
poverty$FIPS = str_pad(poverty$FIPS, width = 5, side = "left", pad = "0")

# Merge to df by FIPS
df = merge(df, poverty, by = c("FIPS"))
Setting up: Presidential Election Data
Presidential Election Data - 2012
# Load the dataset - it includes elections from 2000-2020 and information on each party.
# We are interested in elections within the past 10 years.
electionsall = read_csv("countypres_2000-2020.csv", show_col_types = FALSE)

# Remove unnecessary variables
electionsall = select(electionsall, -c(state, county_name, office, version, candidate, mode))

# Rename variables for ease
electionsall %<>% rename("State" = "state_po",
                         "FIPS" = "county_fips")

# Pad FIPS codes
electionsall$FIPS = str_pad(electionsall$FIPS, width = 5, side = "left", pad = "0")

# Subset the 2012 election
election2012 = subset(electionsall, year == '2012')

# Subset the Democratic party for the 2012 election
dem2012 = subset(election2012, party == 'DEMOCRAT')
# Create a variable for the percent of total votes cast for Democrats in the 2012 election
dem2012 %<>% mutate(dem2012pct = candidatevotes / totalvotes * 100)
# Round
dem2012$dem2012pct = round(dem2012$dem2012pct, digits = 2)
# Subset to remove irrelevant variables
dem2012 = select(dem2012, -c(year, totalvotes, candidatevotes, party))
# Merge to df by FIPS and State
df = merge(df, dem2012, by = c("FIPS", "State"))

# Subset the Republican party for the 2012 election
rep2012 = subset(election2012, party == 'REPUBLICAN')
# Create a variable for the percent of total votes cast for Republicans in the 2012 election
rep2012 %<>% mutate(rep2012pct = candidatevotes / totalvotes * 100)
# Round
rep2012$rep2012pct = round(rep2012$rep2012pct, digits = 2)
# Subset to remove irrelevant variables
rep2012 = select(rep2012, -c(year, totalvotes, candidatevotes, party))
# Merge to df by FIPS and State
df = merge(df, rep2012, by = c("FIPS", "State"))

# Create a dummy variable for a Democratic majority in the 2012 election
df %<>% mutate(dem2012 = ifelse(dem2012pct > rep2012pct, 1, 0))
Presidential Election Data - 2016
# Subset the 2016 election
election2016 = subset(electionsall, year == '2016')

# Subset the Democratic party for the 2016 election
dem2016 = subset(election2016, party == 'DEMOCRAT')

# Create a variable for the percent of total votes cast for Democrats in the 2016 election
dem2016 = dem2016 %>% mutate(dem2016pct = candidatevotes / totalvotes * 100)
# Round
dem2016$dem2016pct = round(dem2016$dem2016pct, digits = 2)

# Subset to remove irrelevant variables
dem2016 = select(dem2016, -c(year, totalvotes, candidatevotes, party))

# Merge to df by FIPS and State
df = merge(df, dem2016, by = c("FIPS", "State"))

# Subset the Republican party for the 2016 election
rep2016 = subset(election2016, party == 'REPUBLICAN')

# Create a variable for the percent of total votes cast for Republicans in the 2016 election
rep2016 %<>% mutate(rep2016pct = candidatevotes / totalvotes * 100)
# Round
rep2016$rep2016pct = round(rep2016$rep2016pct, digits = 2)

# Subset to remove irrelevant variables
rep2016 = select(rep2016, -c(year, totalvotes, candidatevotes, party))

# Merge to df by FIPS and State
df = merge(df, rep2016, by = c("FIPS", "State"))

# Create a dummy variable for a Democratic majority in the 2016 election
df %<>% mutate(dem2016 = ifelse(dem2016pct > rep2016pct, 1, 0))
Presidential Election Data - 2020
# Subset the 2020 election
election2020 = subset(electionsall, year == '2020')

# Subset the Democratic party for the 2020 election
dem2020 = subset(election2020, party == 'DEMOCRAT')

# Within the 2020 election, several counties report their results across multiple
# rows, which is an issue. Fix the formatting by combining the candidate votes for
# counties with multiple entries while keeping total votes the same.
dem2020 %<>%
  group_by(FIPS) %>%
  dplyr::summarise(candidatevotes = sum(candidatevotes),
                   totalvotes = first(totalvotes),
                   State = first(State)) %>%
  # Now we have the accurate candidatevotes value
  as.data.frame()

# Remove any duplicated FIPS codes
dem2020 = dem2020[!duplicated(dem2020$FIPS), ]

# Create a variable for the percent of total votes cast for Democrats in the 2020 election
dem2020 %<>% mutate(dem2020pct = candidatevotes / totalvotes * 100)
# Round
dem2020$dem2020pct = round(dem2020$dem2020pct, digits = 2)

# Remove irrelevant variables
dem2020 = select(dem2020, -c(totalvotes, candidatevotes))

# Merge to df by FIPS and State
df = merge(df, dem2020, by = c("FIPS", "State"))

# Subset the Republican party for the 2020 election
rep2020 = subset(election2020, party == 'REPUBLICAN')

# Fix the same multi-row formatting for the Republican results
rep2020 %<>%
  group_by(FIPS) %>%
  dplyr::summarise(candidatevotes = sum(candidatevotes),
                   totalvotes = first(totalvotes),
                   State = first(State)) %>%
  as.data.frame()

# Remove any duplicated FIPS codes
rep2020 = rep2020[!duplicated(rep2020$FIPS), ]

# Create a variable for the percent of total votes cast for Republicans in the 2020 election
rep2020 %<>% mutate(rep2020pct = candidatevotes / totalvotes * 100)
# Round
rep2020$rep2020pct = round(rep2020$rep2020pct, digits = 2)

# Remove irrelevant variables
rep2020 = select(rep2020, -c(totalvotes, candidatevotes))

# Merge to df by FIPS and State
df = merge(df, rep2020, by = c("FIPS", "State"))

# Create a dummy variable for a Democratic majority in the 2020 election
df %<>% mutate(dem2020 = ifelse(dem2020pct > rep2020pct, 1, 0))
Presidential Election Data - Voting Patterns Over Time
# Create variables for voting patterns

# Number of the past three presidential elections in which the Democrat won the county
df %<>% mutate(pastvoting_dem = dem2012 + dem2016 + dem2020)

# Average percent of votes Republican in the past 3 presidential elections
df %<>% mutate(avgpct_rep = (rep2012pct + rep2016pct + rep2020pct) / 3)

# Average percent of votes Democrat in the past 3 presidential elections
df %<>% mutate(avgpct_dem = (dem2012pct + dem2016pct + dem2020pct) / 3)

# Round
df$pastvoting_dem = round(df$pastvoting_dem, digits = 2)
df$avgpct_rep = round(df$avgpct_rep, digits = 2)
df$avgpct_dem = round(df$avgpct_dem, digits = 2)

# Remove unnecessary variables
df = subset(df, select = -c(dem2012pct, dem2016pct, dem2020pct, rep2012pct,
                            rep2016pct, rep2020pct, Pop18to65))

# Filter out counties with missing vote shares
df %<>% filter(!is.na(avgpct_dem))

# Rename df
covid = df

# Write a new CSV so the data preparation does not have to be rerun every time
# write.csv(df, "524Project.csv", row.names = FALSE)
Exploring Data
Exploring Data: Map
# County COVID vaccination rates across the US
covid_map = covid %>% rename(fips = FIPS)

covid_map %<>%
  mutate(above_med = ifelse(Vax12Plus >= 54.3, "Above Median", "Below Median"))

# Graph vaccination rates
plot_usmap(data = covid_map, values = "above_med") +
  scale_fill_manual(values = c("Above Median" = "darkcyan", "Below Median" = "pink"),
                    name = "Median Vaccination Rate (Jan. 2022)") +
  ggtitle("COVID-19 Median Vaccination Rates by County") +
  theme(legend.position = "right",
        legend.text = element_text(size = 14),
        legend.title = element_text(size = 16),
        panel.background = element_rect(fill = "lavender"),
        plot.title = element_text(hjust = 0.5,
                                  size = 20,
                                  margin = margin(t = 10, b = 10)))
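The 54.3 cutoff above appears to be the sample median of Vax12Plus; as a hedged alternative, the cutoff could be computed rather than hard-coded (a small sketch, not part of the original code):

# Compute the median vaccination rate and classify counties relative to it
med_vax = median(covid$Vax12Plus, na.rm = TRUE)
covid_map %<>%
  mutate(above_med = ifelse(Vax12Plus >= med_vax, "Above Median", "Below Median"))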
Exploring Data: Correlation Heat Map
# Correlation heat map
# Create the correlation matrix
covid1 = covid %>% subset(select = -c(1, 2, 3, 17, 18, 19))
# Round
cormat = round(cor(covid1), 2)

# Load reshape2 for melt()
p_load(reshape2)

# Keep only the upper triangle of the correlation matrix
get_upper_tri <- function(cormat){
  cormat[lower.tri(cormat)] <- NA
  return(cormat)
}
upper_tri <- get_upper_tri(cormat)
melted_cormat <- melt(upper_tri, na.rm = TRUE)

# Plot heatmap
ggplot(data = melted_cormat,
       aes(Var2, Var1, fill = value)) +
  coord_fixed() +
  geom_tile(color = "white") +
  scale_fill_gradient2(low = "darkcyan", high = "indianred1", mid = "white",
                       midpoint = 0, limit = c(-1, 1), space = "Lab",
                       name = "Correlation") +
  theme_minimal() +
  labs(x = "Variables", y = "Variables",
       title = "Correlation Matrix Heatmap") +
  theme(axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1),
        plot.title = element_text(hjust = 0.5,
                                  size = 14,
                                  margin = margin(t = 10, b = 10)))
Working with the data
Split & Recipe
# Set seed
set.seed(1)

# Remove the county and state columns
covid %<>% subset(select = -c(2, 3))

# Make the split
covid_split = covid %>% initial_split(prop = 0.8, strata = Vax12Plus)

# Separate training and testing data
covid_train = covid_split %>% training()
covid_test = covid_split %>% testing()

# Define the recipe:
covid_recipe =
  recipe(Vax12Plus ~ ., data = covid_train) %>%
  # Normalize numeric predictors
  step_normalize(all_predictors() & all_numeric()) %>%
  # KNN imputation for categorical predictors (likely unnecessary here, but kept for safety)
  step_impute_knn(all_predictors() & all_nominal(), neighbors = 5) %>%
  # Create dummies for categorical variables
  step_dummy(all_predictors() & all_nominal()) %>%
  # Keep FIPS as an identifier rather than a predictor
  update_role(FIPS, new_role = "id variable")

# Define the 5-fold cross-validation split
covid_cv = covid_train %>% vfold_cv(v = 5)
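To sanity-check the preprocessing before modeling, the recipe can be prepped and baked on the training data (a minimal sketch using standard recipes functions; not part of the original workflow):

# Inspect the processed training data: normalized numeric predictors plus any dummies
covid_recipe %>%
  prep(training = covid_train) %>%
  bake(new_data = NULL) %>%
  glimpse()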
Linear Regression
# Set seed
set.seed(1)

# Set the model
model_linear = linear_reg() %>%
  set_engine("lm")

# Create the workflow
workflow_linear = workflow() %>%
  add_model(model_linear) %>%
  add_recipe(covid_recipe)

# Fit the cross-validated samples and extract RMSE, R-squared, and MAE
cv_linear =
  workflow_linear %>%
  fit_resamples(
    resamples = covid_cv,
    metrics = yardstick::metric_set(yardstick::rmse, yardstick::rsq, yardstick::mae))

# Examine metrics
cv_linear %>% collect_metrics()

# Choose the model with the best RMSE and finalize the workflow
final_lin =
  workflow_linear %>%
  finalize_workflow(select_best(cv_linear, metric = "rmse"))

# Fit the model on the training data
final_fit_lin = final_lin %>% fit(data = covid_train)

# Fit the model on the split (train on the training set, evaluate on the test set)
final_fit_lin = final_lin %>% last_fit(covid_split)

# Get test-sample predictions
final_fit_lin %>% collect_predictions() %>% head()
Linear Regression Results
# Get the metrics of the final linear regression fit
final_fit_lin %>% collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 9.63 Preprocessor1_Model1
## 2 rsq standard 0.520 Preprocessor1_Model1
Elasticnet
# Set seed
set.seed(1)

# Set tuning values for mixture (alphas) and penalty (lambdas)
alphas = seq(from = 0, to = 1, by = 0.1)
lambdas = 10^seq(from = 2, to = -3, length.out = 1e2)

# Define the elasticnet model
model_net =
  linear_reg(penalty = tune(),
             mixture = tune()) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

# Define the workflow
workflow_net = workflow() %>%
  add_model(model_net) %>%
  add_recipe(covid_recipe)

# Cross-validate the elasticnet over the lambdas and alphas
cv_net =
  workflow_net %>%
  tune_grid(
    resamples = covid_cv,
    grid = expand_grid(mixture = alphas, penalty = lambdas),
    metrics = yardstick::metric_set(yardstick::rmse, yardstick::rsq, yardstick::mae))

# Collect the best hyperparameter values for each metric
cv_net %>% show_best(metric = "rmse", n = 3)
cv_net %>% show_best(metric = "rsq", n = 3)
cv_net %>% show_best(metric = "mae", n = 3)

# The elasticnet chooses lasso (mixture = 1), with penalty = 0.001 selected by RMSE and RSQ
# Finalize the workflow with the best model by RMSE
final_net =
  workflow_net %>%
  finalize_workflow(select_best(cv_net, metric = "rmse"))

# Fit the final model to the training data
final_fit_net = final_net %>% fit(data = covid_train)
# Fit the final model to the data split
final_fit_net = final_net %>% last_fit(covid_split)
# Get test-sample predictions
final_fit_net %>% collect_predictions() %>% head()
Elasticnet Results
# Retrieve the elasticnet metrics of the final fit
final_fit_net %>% collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 9.63 Preprocessor1_Model1
## 2 rsq standard 0.521 Preprocessor1_Model1
Elasticnet Graph of Metrics
# Graph of elasticnet metrics
autoplot(cv_net, metric = "rmse") +
  theme_bw() +
  scale_x_continuous(limits = c(0, 1)) +
  scale_y_continuous(limits = c(10.25, 10.8)) +
  labs(x = "Amount of Regularization", y = "RMSE",
       title = "Elasticnet Results - Amount of Regularization vs. RMSE") +
  theme(plot.background = element_rect(fill = "lavender"),
        plot.title = element_text(hjust = 0.5,
                                  size = 14,
                                  margin = margin(t = 10, b = 10)),
        axis.text.x = element_text(color = "black",
                                   size = 10,
                                   margin = margin(t = 5, b = 5)),
        axis.text.y = element_text(color = "black",
                                   size = 10,
                                   margin = margin(r = 5, l = 5)),
        axis.title = element_text(size = 13,
                                  margin = margin(t = 12, r = 12, b = 12, l = 12)))

autoplot(cv_net, metric = "mae") +
  theme_bw() +
  scale_x_continuous(limits = c(0, 1)) +
  scale_y_continuous(limits = c(7.3, 7.9)) +
  labs(x = "Amount of Regularization", y = "MAE",
       title = "Elasticnet Results - Amount of Regularization vs. MAE") +
  theme(plot.background = element_rect(fill = "lavender"),
        plot.title = element_text(hjust = 0.5,
                                  size = 14,
                                  margin = margin(t = 10, b = 10)),
        axis.text.x = element_text(color = "black",
                                   size = 10,
                                   margin = margin(t = 5, b = 5)),
        axis.text.y = element_text(color = "black",
                                   size = 10,
                                   margin = margin(r = 5, l = 5)),
        axis.title = element_text(size = 13,
                                  margin = margin(t = 12, r = 12, b = 12, l = 12)))
Random Forest
# Set seed
set.seed(1)

# Create the model, tuning mtry and min_n
tune_spec <- rand_forest(
    mtry = tune(),
    trees = 200,
    min_n = tune()
  ) %>%
  set_mode("regression") %>%
  set_engine("ranger", importance = "impurity")

# Define the workflow
tune_wf <- workflow() %>%
  add_recipe(covid_recipe) %>%
  add_model(tune_spec)

# Create a grid of possible mtry and min_n values
rf_grid <- grid_regular(
  mtry(range = c(4, 10)),
  min_n(range = c(2, 8)),
  levels = 5)

# Cross-validate to tune the hyperparameters
regular_res <- tune_grid(
  tune_wf,
  resamples = covid_cv,
  grid = rf_grid,
  metrics = yardstick::metric_set(yardstick::rmse, yardstick::rsq, yardstick::mae))

# Collect the best hyperparameters for each metric
regular_res %>% show_best(metric = "rmse")
regular_res %>% show_best(metric = "rsq")
regular_res %>% show_best(metric = "mae")

# Finalize the workflow with the best model determined by RMSE
final_rf =
  tune_wf %>%
  finalize_workflow(select_best(regular_res, metric = "rmse"))

# Fit the model on the training data
final_fit_rf = final_rf %>% fit(data = covid_train)

# Fit the model on the split
final_fit_rf = final_rf %>% last_fit(covid_split)

# Get test-sample predictions
final_fit_rf %>% collect_predictions() %>% head()
Random Forest Results
# Retrieve the metrics of the random forest final fit
final_fit_rf %>% collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 9.17 Preprocessor1_Model1
## 2 rsq standard 0.569 Preprocessor1_Model1
Random Forest Graph of Metrics
# Graph of random forest metrics
autoplot(regular_res, metric = "rmse") +
  theme_bw() +
  labs(x = "# of Randomly Selected Predictors", y = "RMSE",
       title = "Random Forest Results - # of Predictors vs. RMSE") +
  theme(plot.background = element_rect(fill = "lavender"),
        plot.title = element_text(hjust = 0.5,
                                  size = 14,
                                  margin = margin(t = 10, b = 10)),
        axis.text.x = element_text(color = "black",
                                   size = 10,
                                   margin = margin(t = 5, b = 5)),
        axis.text.y = element_text(color = "black",
                                   size = 10,
                                   margin = margin(r = 5, l = 5)),
        axis.title = element_text(size = 13,
                                  margin = margin(t = 12, r = 12, b = 12, l = 12)))

autoplot(regular_res, metric = "mae") +
  theme_bw() +
  labs(x = "# of Randomly Selected Predictors", y = "MAE",
       title = "Random Forest Results - # of Predictors vs. MAE") +
  theme(plot.background = element_rect(fill = "lavender"),
        plot.title = element_text(hjust = 0.5,
                                  size = 14,
                                  margin = margin(t = 10, b = 10)),
        axis.text.x = element_text(color = "black",
                                   size = 10,
                                   margin = margin(t = 5, b = 5)),
        axis.text.y = element_text(color = "black",
                                   size = 10,
                                   margin = margin(r = 5, l = 5)),
        axis.title = element_text(size = 13,
                                  margin = margin(t = 12, r = 12, b = 12, l = 12)))
Random Forest Variable Importance
# Variable importance
p_load('vip')

# Extract importance from the final fitted workflow
final_fit_rf %>%
  pluck(".workflow", 1) %>%
  extract_fit_parsnip() %>%
  vip(num_features = 20)
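For a tabular version of the same importance scores (a small sketch assuming vip's vi() helper, alongside the plot above):

# Importance scores as a sorted tibble rather than a plot
final_fit_rf %>%
  pluck(".workflow", 1) %>%
  extract_fit_parsnip() %>%
  vip::vi()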
KNN
# Set seed
set.seed(1)

# Create the KNN model, tuning neighbors
model_knn =
  nearest_neighbor(neighbors = tune()) %>%
  set_mode("regression") %>%
  set_engine("kknn")

# Define the KNN workflow
workflow_knn =
  workflow() %>%
  add_model(model_knn) %>%
  add_recipe(covid_recipe)

# Cross-validate to find the best number of neighbors
fit_knn_cv =
  workflow_knn %>%
  tune_grid(
    resamples = covid_cv,
    grid = data.frame(neighbors = c(1, 5, seq(10, 100, 10))),
    metrics = yardstick::metric_set(yardstick::rmse, yardstick::rsq, yardstick::mae)
  )

# Collect the best number of neighbors for each metric
fit_knn_cv %>% show_best(metric = "rmse", n = 3)
fit_knn_cv %>% show_best(metric = "rsq", n = 3)
fit_knn_cv %>% show_best(metric = "mae", n = 3)

# Finalize the KNN workflow, selecting the best model by RMSE
final_knn =
  workflow_knn %>%
  finalize_workflow(select_best(fit_knn_cv, metric = "rmse"))

# Fit on the training data
final_fit_knn = final_knn %>% fit(data = covid_train)
# Fit on the split
final_fit_knn = final_knn %>% last_fit(covid_split)
# Get test-sample predictions
final_fit_knn %>% collect_predictions() %>% head()
KNN Results
# Collect the KNN test metrics
final_fit_knn %>% collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 9.42 Preprocessor1_Model1
## 2 rsq standard 0.549 Preprocessor1_Model1
KNN Graph of Metrics
# Graph the KNN cross-validation metrics (averaged across folds)
fit_knn_cv %>%
  collect_metrics(summarize = T) %>%
  ggplot(aes(x = neighbors, y = mean)) +
  geom_line(size = 0.7, alpha = 0.6) +
  geom_point(size = 2.5) +
  facet_wrap(~ toupper(.metric), scales = "free", nrow = 1) +
  scale_x_continuous("Neighbors (k)", labels = scales::label_number()) +
  scale_y_continuous("Estimate") +
  theme_minimal() +
  theme(legend.position = "bottom")

# Show each KNN fold's metrics
fit_knn_cv %>%
  collect_metrics(summarize = F) %>%
  ggplot(aes(x = neighbors, y = .estimate, color = id)) +
  geom_line(size = 0.7, alpha = 0.6) +
  geom_point(size = 2.5) +
  facet_wrap(~ toupper(.metric), scales = "free", nrow = 1) +
  scale_x_continuous("Neighbors (k)", labels = scales::label_number()) +
  scale_y_continuous("Estimate") +
  scale_color_viridis_d("CV Folds:") +
  theme_minimal() +
  theme(legend.position = "bottom")