Commit a312942

Add vignette
1 parent 68b0ede commit a312942

4 files changed: +52 -5 lines changed

.gitignore

+1
@@ -3,3 +3,4 @@
 *.RData
 *.DS_Store
 .Rproj.user
+inst/doc

DESCRIPTION

+3
@@ -7,3 +7,6 @@ Depends: R (>= 3.4.1)
 License: What license is it under?
 Encoding: UTF-8
 LazyData: true
+Suggests: knitr,
+    rmarkdown
+VignetteBuilder: knitr

R/cross_validation.R

+5 -5
@@ -38,12 +38,12 @@ cross_validation <- function(X, y, k = 3, shuffle = TRUE, random_state = 0) {
   split_indices <- function(X2, k2, shuffle2 = TRUE) {
     set.seed(random_state)
     length <- dim(X2)[1]
-    random_column <- sample(rep(1:k2, each=round(length/k2), len=length))
-    df <- data.frame(cbind(data_index = 1:length, groups = random_column))
+    splitting_column <- rep(1:k2, each=round(length/k2), len=length)
+    df <- data.frame(cbind(data_index = 1:length, groups = splitting_column))
     if (shuffle2 == FALSE){
-      df <- df[order(df$groups),]
-    } else {
       df
+    } else {
+      df$groups <- sample(df$groups, size=length, replace=FALSE)
     }
     indices_list <- list()
     for (number in 1:k2){
@@ -55,7 +55,7 @@ cross_validation <- function(X, y, k = 3, shuffle = TRUE, random_state = 0) {
   # Apply cross_validation here
   if (shuffle == TRUE){
     indices_list <- split_indices(X2 = X, k2 = k, shuffle2 = TRUE)
-  } else{
+  } else {
     indices_list <- split_indices(X2 = X, k2 = k, shuffle2 = FALSE)
   }
 
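
Read on its own, the `split_indices` hunk appears to build the fold labels in contiguous order first and permute them only when shuffling is requested. The following is a minimal base-R sketch of that logic, assuming nothing beyond what the hunk shows. `assign_folds()` and its `seed` argument are hypothetical names used for illustration and are not part of `CrossR`.

```r
# Standalone sketch of the fold-assignment logic on the "+" side of the diff.
# assign_folds() is a hypothetical helper, not a CrossR function.
assign_folds <- function(n, k, shuffle = TRUE, seed = 0) {
  # Contiguous labels 1,1,...,2,2,...,k,k, recycled or truncated to length n
  groups <- rep(1:k, each = round(n / k), length.out = n)
  if (shuffle) {
    set.seed(seed)
    # Permute the labels: fold membership becomes random, fold sizes do not change
    groups <- sample(groups, size = n, replace = FALSE)
  }
  # One vector of row indices per fold, named "1", "2", ..., "k"
  split(seq_len(n), groups)
}

folds_ordered  <- assign_folds(10, k = 3, shuffle = FALSE)
folds_shuffled <- assign_folds(10, k = 3, shuffle = TRUE)
```

For 10 rows and `k = 3`, the unshuffled folds are rows 1, 2, 3 and 10; rows 4 to 6; and rows 7 to 9, because `rep()` recycles the labels to reach length 10. Shuffling changes which rows land in each fold but keeps the fold sizes at 4, 3 and 3.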

vignettes/CrossR.Rmd

+43
@@ -0,0 +1,43 @@
+---
+title: "CrossR"
+output: rmarkdown::html_vignette
+vignette: >
+  %\VignetteIndexEntry{Vignette Title}
+  %\VignetteEngine{knitr::rmarkdown}
+  %\VignetteEncoding{UTF-8}
+---
+
+## Overview
+
+Cross-validation is an important technique used in model selection and hyper-parameter optimization. Cross-validation scores are a good estimate of a predictive model's score on test data or new data, as long as the IID assumption approximately holds. This package aims to provide a standardized pipeline for performing cross-validation with different modeling functions in R. In addition, it provides summary statistics of the cross-validation results.
+
+The `CrossR` package (short for _Cross_-validation in _R_) is a set of functions for implementing cross-validation inside the R environment.
+
+### Similar packages
+
+Cross-validation can be implemented with the [`caret`](https://cran.r-project.org/web/packages/caret/caret.pdf) package in R. `caret` provides `createDataPartition()` to split the data and `trainControl()` to configure cross-validation, with the resampling scheme chosen through the `method` argument. We have observed that some features of the `caret` functions make the cross-validation process cumbersome. For example, `createDataPartition()` returns the *indices* of a split, which must then be used in a separate step to divide the data into training and test sets. `CrossR` does this in a single step with `train_test_split()`.
+
+
+## Functions
+
+`CrossR` has three main functions:
+
+- `train_test_split`: This function partitions data into `k` folds and returns the partitioned indices. A random shuffling option is provided. (A `stratification` option for imbalanced representations will also be included if time allows.)
+
+- `cross_validation`: This function performs `k`-fold cross-validation using the partitioned data and a selected model. It returns the score of each validation fold. Additional cross-validation methods (such as "Leave-One-Out") will be implemented if time allows.
+
+- `summary_cv`: This function outputs summary statistics (mean, median, standard deviation) of the cross-validation scores.
+
+
+## Usage
+
+```
+library(CrossR)
+
+split_data <- train_test_split(X, y, test_size = 0.25, random_state = 0, shuffle = TRUE)
+
+scores <- cross_validation(split_data['X_train'], split_data['y_train'], k = 3, shuffle = TRUE, random_state = 0)
+
+summary_cv(scores)
+```
+
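
The vignette's usage block relies on `X`, `y` and the `CrossR` functions being available, and the diff does not show which model or score `cross_validation` uses. As a rough, self-contained illustration of the same three steps (hold out a test set, run `k`-fold cross-validation, summarise the scores), here is a base-R sketch. The simulated data, the linear model, the R-squared score and the helper name `kfold_scores()` are all assumptions made for this example, not the package's implementation.

```r
# Hypothetical base-R stand-ins for the three vignette steps; not CrossR code.
set.seed(0)
n <- 60
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- 2 * X$x1 - X$x2 + rnorm(n, sd = 0.5)

# Step 1: hold out 25% of the rows as a test set (mirrors test_size = 0.25)
test_idx <- sample(seq_len(n), size = round(0.25 * n))
X_train <- X[-test_idx, , drop = FALSE]
y_train <- y[-test_idx]

# Step 2: k-fold cross-validation, assuming a linear model scored by R-squared
kfold_scores <- function(X, y, k = 3) {
  # Shuffled fold labels with near-equal fold sizes
  folds <- sample(rep(seq_len(k), length.out = nrow(X)))
  vapply(seq_len(k), function(i) {
    held_out <- folds == i
    train_df <- cbind(X[!held_out, , drop = FALSE], y = y[!held_out])
    fit <- lm(y ~ ., data = train_df)
    pred <- predict(fit, newdata = X[held_out, , drop = FALSE])
    # R-squared on the held-out fold
    1 - sum((y[held_out] - pred)^2) / sum((y[held_out] - mean(y[held_out]))^2)
  }, numeric(1))
}
scores <- kfold_scores(X_train, y_train, k = 3)

# Step 3: summary statistics of the fold scores
score_summary <- c(mean = mean(scores), median = median(scores), sd = sd(scores))
```

Here `scores` holds one value per fold, and `score_summary` collects the same three statistics that the vignette lists for `summary_cv`.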
