Skip to content

Commit

Permalink
add svm model
Browse files Browse the repository at this point in the history
  • Loading branch information
Phil9S committed Mar 21, 2024
1 parent 1723066 commit 77b934d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
2 changes: 1 addition & 1 deletion R/fitInteractive.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fitInteractive <- function(data=NULL,metadata=NULL,autoSave=FALSE,autoSaveInt=15

if (!requireNamespace("colourpicker", quietly = TRUE)) {
stop(
"Package \"shinyjs\" must be installed to use interactive fitting",
"Package \"colourpicker\" must be installed to use interactive fitting",
call. = FALSE
)
}
Expand Down
68 changes: 67 additions & 1 deletion inst/misc_ml_fitting.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ cellQcLong <- cellQC %>%

```{r datapre}
## Data preview
head(cellQC)
dim(cellQC)
head(cellQC)
```

```{r plots}
Expand Down Expand Up @@ -99,6 +99,65 @@ summary(modelRecipe)
folds <- vfold_cv(trainingData, v = 10,strata = use)
```

## linear SVM

```{r linsvmModel}
#set linear svm
linsvm_mod <- svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
set_mode("classification") %>%
set_engine("kernlab")
```

```{r linsvmworkflow}
linsvm_WF <- workflow() %>%
add_model(linsvm_mod) %>%
add_recipe(modelRecipe)
```

```{r linsvmtune}
#lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))
linsvm_res <- linsvm_WF %>%
tune_grid(resamples = folds,
#grid = lr_reg_grid,
control = control_grid(save_pred = TRUE),
metrics = metric_set(accuracy))
```

```{r linsvmselectTuning}
linsvm_best <- linsvm_res %>%
select_best("accuracy",n = 15)
linsvm_WF <- finalize_workflow(linsvm_WF,linsvm_best)
```

```{r linsvmlastfit}
linsvm_res_final <- last_fit(linsvm_WF,dataSplit)
collect_metrics(linsvm_res_final)
```

```{r linsvmpred}
linsv_auc <- linsvm_res_final %>%
collect_predictions() %>%
roc_curve(use, .pred_FALSE) %>%
mutate(model = "SVM (linear)")
```

```{r modelPerflinsvm}
linsvm_res_final %>%
collect_metrics()
# linsvm_res_final %>%
# extract_fit_parsnip() %>%
# vip(num_features = 20) +
# theme_bw()
linsvm_res_final %>%
collect_predictions() %>%
roc_curve(use, .pred_FALSE) %>%
autoplot()
```

## logistic regression

```{r lgModel}
Expand Down Expand Up @@ -294,11 +353,18 @@ final_xgb_res %>%

```{r compare}
metrics <- metric_set(precision,accuracy,recall,f_meas,roc_auc)
lr_metrics <- lr_res_final %>%
collect_predictions() %>%
metrics(truth = use,estimate = .pred_class,.pred_FALSE) %>%
mutate(model = "logistic regression")
linsvm_metrics <- linsvm_res_final %>%
collect_predictions() %>%
metrics(truth = use,estimate = .pred_class,.pred_FALSE) %>%
mutate(model = "SVM (linear)")
rf_metrics <- final_rf_res %>%
collect_predictions() %>%
metrics(truth = use,estimate = .pred_class,.pred_FALSE) %>%
Expand Down

0 comments on commit 77b934d

Please sign in to comment.