From 77b934db71686f95254d1bf5c6d5d79b79cc1bc2 Mon Sep 17 00:00:00 2001 From: Philip Smith Date: Thu, 21 Mar 2024 13:35:05 +0000 Subject: [PATCH] add svm model --- R/fitInteractive.R | 2 +- inst/misc_ml_fitting.Rmd | 68 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/R/fitInteractive.R b/R/fitInteractive.R index f6028b8..3d479fb 100644 --- a/R/fitInteractive.R +++ b/R/fitInteractive.R @@ -37,7 +37,7 @@ fitInteractive <- function(data=NULL,metadata=NULL,autoSave=FALSE,autoSaveInt=15 if (!requireNamespace("colourpicker", quietly = TRUE)) { stop( - "Package \"shinyjs\" must be installed to use interactive fitting", + "Package \"colourpicker\" must be installed to use interactive fitting", call. = FALSE ) } diff --git a/inst/misc_ml_fitting.Rmd b/inst/misc_ml_fitting.Rmd index 9a9dcd7..ccddeb8 100644 --- a/inst/misc_ml_fitting.Rmd +++ b/inst/misc_ml_fitting.Rmd @@ -37,8 +37,8 @@ cellQcLong <- cellQC %>% ```{r datapre} ## Data preview -head(cellQC) dim(cellQC) +head(cellQC) ``` ```{r plots} @@ -99,6 +99,65 @@ summary(modelRecipe) folds <- vfold_cv(trainingData, v = 10,strata = use) ``` +## linear SVM + +```{r linsvmModel} +#set linear svm +linsvm_mod <- svm_rbf(cost = tune(), rbf_sigma = tune()) %>% + set_mode("classification") %>% + set_engine("kernlab") +``` + +```{r linsvmworkflow} +linsvm_WF <- workflow() %>% + add_model(linsvm_mod) %>% + add_recipe(modelRecipe) +``` + +```{r linsvmtune} +#lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30)) + +linsvm_res <- linsvm_WF %>% + tune_grid(resamples = folds, + #grid = lr_reg_grid, + control = control_grid(save_pred = TRUE), + metrics = metric_set(accuracy)) +``` + +```{r linsvmselectTuning} +linsvm_best <- linsvm_res %>% + select_best("accuracy",n = 15) + +linsvm_WF <- finalize_workflow(linsvm_WF,linsvm_best) +``` + +```{r linsvmlastfit} +linsvm_res_final <- last_fit(linsvm_WF,dataSplit) +collect_metrics(linsvm_res_final) +``` + +```{r linsvmpred} +linsv_auc <- linsvm_res_final %>% + collect_predictions() %>% + roc_curve(use, .pred_FALSE) %>% + mutate(model = "SVM (linear)") +``` + +```{r modelPerflinsvm} +linsvm_res_final %>% + collect_metrics() + +# linsvm_res_final %>% +# extract_fit_parsnip() %>% +# vip(num_features = 20) + +# theme_bw() + +linsvm_res_final %>% + collect_predictions() %>% + roc_curve(use, .pred_FALSE) %>% + autoplot() +``` + ## logistic regression ```{r lgModel} @@ -294,11 +353,18 @@ final_xgb_res %>% ```{r compare} metrics <- metric_set(precision,accuracy,recall,f_meas,roc_auc) + lr_metrics <- lr_res_final %>% collect_predictions() %>% metrics(truth = use,estimate = .pred_class,.pred_FALSE) %>% mutate(model = "logistic regression") +linsvm_metrics <- linsvm_res_final %>% + collect_predictions() %>% + metrics(truth = use,estimate = .pred_class,.pred_FALSE) %>% + mutate(model = "SVM (linear)") + + rf_metrics <- final_rf_res %>% collect_predictions() %>% metrics(truth = use,estimate = .pred_class,.pred_FALSE) %>%