---
title: "Naive Bayes - Employee Attrition"
output:
  html_document:
    toc: yes
    toc_float: yes
    code_folding: hide
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
## Introduction
Given the potential disruption to the work environment and the required resources to attract, acquire, and train new talent, understanding factors that influence employee attrition is important to human resource departments. In this exercise, we'll explore the IBM Human Resources Analytics dataset, which contains data on employee attrition (whether an employee will leave the company). Throughout this exercise, we'll review basic data wrangling and visualization, as well as classification algorithms such as Naive Bayes, k-Nearest Neighbors, and Support Vector Machines. Additionally, we will compare the models using classification metrics and explain why certain metrics can be misleading.
You'll notice on the right-hand side of this document a number of dropdown menus that can be used to unhide the code behind each output. Questions throughout the document will guide you through the analysis. Please use the accompanying practice RMarkdown file to try to generate the correct outputs on your own before peeking! If you get stuck, we have added links throughout the document to point you in the right direction.
## Packages
First, let's have a look at some of the packages that will be used in this exercise:
- ```rsample``` contains the IBM attrition data set, ```attrition```, and is also helpful for splitting and resampling data
- ```dplyr``` is a popular package for manipulating data in R
- ```tidyr``` reshapes data (e.g., ```gather()```) into a format suitable for plotting
- ```ggplot2``` is a robust and flexible data visualization package
- ```caret``` is a statistical / machine learning wrapper, or framework, for other ML packages
- ```corrplot``` is a visualization library specifically for correlation matrices
If you don't have these installed, please install them now. If you're not sure how to do this, try the ```install.packages()``` function with the package name in quotes.
```{r install_packages, warning = FALSE, message = FALSE}
# install.packages("knitr")
# install.packages("tibble")
# install.packages("rsample")
# install.packages("dplyr")
# install.packages("ggplot2")
# install.packages("corrplot")
# install.packages("corrr")
# install.packages("caret")
```
Once installed, we need to load the packages into our session using the ```library()``` function. Be sure to watch the messages when loading libraries! Sometimes two packages share a common function name, and loading both causes a conflict known as "masking", where the most recently loaded package's function takes precedence.
```{r load_packages, warning = FALSE, message = FALSE}
library(tibble)    # modern data frames
library(dplyr)     # data transformation
library(tidyr)     # data reshaping (gather)
library(rsample)   # data splitting
library(ggplot2)   # data visualization
library(corrplot)  # correlation plots
library(caret)     # machine learning framework
library(corrr)     # correlation analysis
library(knitr)     # report tables (kable)
```
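As a quick illustration of masking: once ```dplyr``` is attached, its ```filter()``` masks ```stats::filter()```. The ```package::function()``` syntax lets you reach whichever version you intend (a small sketch using built-in data):
```{r, warning = FALSE, message = FALSE}
# After library(dplyr), calling filter() unqualified refers to dplyr's version.
# The package::function() syntax makes the intended function explicit.
stats::filter(1:10, rep(1/3, 3))    # moving-average filter from base R's stats package
dplyr::filter(mtcars, mpg > 25)     # row filtering from dplyr
```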
## Data Exploration
The data that we will be using was already loaded into memory when we loaded the ```rsample``` library. Reading data into R is beyond the scope of this document, but interested students should have a look at [this online tutorial](https://www.datacamp.com/community/tutorials/r-data-import-tutorial).
Now that the data is available to us, let's start by exploring the data set a bit. There are a number of ways to do this, but using the ```dim()```, ```names()```, and ```str()``` functions are a good place to start. If you're not sure what they mean, try typing ```?dim``` in the console to read the documentation.
What is the dimension of the dataset?
```{r, warning = FALSE, message = FALSE}
attrition %>% dim()
```
**Note:** For those following along in RStudio, you may be wondering what the ```%>%``` is doing here. This is called the "pipe operator", which passes whatever is on the left into whatever is on the right. While it's not immediately obvious why this might be useful, the [following tutorial](https://www.datacamp.com/community/tutorials/pipe-r-tutorial) provides some helpful information about its historical use in programming languages and how it works in R. At their core, pipe operators make your code more readable by avoiding nested functions.
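For example, the two calls below are equivalent; the pipe simply feeds ```attrition``` in as the first argument of ```dim()```.
```{r, eval = FALSE}
# The pipe passes the left-hand side as the first argument of the function on the right,
# so these two calls produce exactly the same result.
attrition %>% dim()
dim(attrition)
```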
So we have a dataset with 1,470 rows and 31 columns, but what exactly does the dataset look like? What are the names of the columns in our data? *Hint:* Try the ```names()``` function.
```{r, warning = FALSE, message = FALSE}
attrition %>% names() %>% kable()
```
Now we know which features are available to us, but what *kind* of data is it? We could use the ```class()``` function to return the data type of an individual column, but the ```str()``` function tells us the data type of every column in the dataset, including the levels of categorical variables, a small sample of each column, and other helpful pieces of information. Similarly, try the ```glimpse()``` function for another view.
**Question:** Based on the output below, are there any data types that you think need to be changed? Let's have a look at ```JobLevel```, for example, using the ```class()``` function on the column.
```{r, warning = FALSE, message = FALSE}
class(attrition$JobLevel)
```
Should ```JobLevel``` really be an integer? Or does it feel more like a categorical variable? The ```str()``` function below outputs the data type of each variable - can you spot any others that you might want to change?
```{r, warning = FALSE, message = FALSE}
attrition %>% str()
```
Let's clean some of these up using the ```factor()``` function, and then check the class of ```JobLevel``` to make sure it worked.
```{r, warning = FALSE, message = FALSE}
attrition <- attrition %>%
  mutate(
    JobLevel = factor(JobLevel),
    StockOptionLevel = factor(StockOptionLevel),
    TrainingTimesLastYear = factor(TrainingTimesLastYear),
    Attrition = factor(Attrition, levels = c("Yes", "No"))
  )
class(attrition$JobLevel)
```
**Note:** If this is your first time using the ```dplyr``` library, you might be wondering what ```mutate``` does. You can think of this function as just another way to create a variable in a dataframe. The ```dplyr``` library is one of the best ways in R to wrangle / transform your data so that it's in a format that is easily digestible by models or plotting code. The [following course](https://www.datacamp.com/courses/dplyr-data-manipulation-r-tutorial) is a great introduction to the many different transformations you can do.
Let's see if we can spot some basic patterns. Try using the ```table()``` function to calculate a basic count table from two categorical variables. You can then pass the count table into the ```prop.table()``` function to convert the counts into proportions; we'll show an example of this after the tables below.
**Question:** Have a look at the table below. How does ```JobSatisfaction``` impact ```Attrition```?
```{r}
table(attrition$JobSatisfaction, attrition$Attrition) %>%
kable(align = "c")
```
**Question:** What about ```WorkLifeBalance```?
```{r}
table(attrition$WorkLifeBalance, attrition$Attrition) %>%
kable(align = "c")
```
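As mentioned above, passing a count table into ```prop.table()``` converts counts to proportions. Using ```margin = 1``` gives row-wise proportions, i.e. the attrition rate within each ```JobSatisfaction``` level, which is often easier to compare (a quick sketch):
```{r}
# Row-wise proportions: the share of "Yes" / "No" within each JobSatisfaction level.
table(attrition$JobSatisfaction, attrition$Attrition) %>%
  prop.table(margin = 1) %>%
  round(2) %>%
  kable(align = "c")
```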
## Data Visualization
Oftentimes you'll be working with datasets that have many, many variables, and looking for patterns one variable at a time is not feasible. Let's use some visual techniques to identify patterns more efficiently. Do any interesting patterns emerge?
**Note:** Visualizing categorical and numerical data can be accomplished in many different ways, but using density plots for numeric data and bar plots for categorical data is a good place to start. The ```ggplot2``` library uses a philosophy known as the "grammar of graphics", and while very powerful once understood, it does come with its own learning curve. Thankfully, there is a [three-part series](https://www.datacamp.com/courses/data-visualization-with-ggplot2-1) on the topic to get you started. Until then, here is one visualization approach that hopefully gets you excited to learn the library.
There are several ways to select columns based on their class; for example, ```select_if(data_frame, is.numeric)``` would return all numeric columns. In our case, though, we want to return all numeric columns *as well as* the ```Attrition``` column, which is categorical. One way to do this is to create a variable containing the list of column names you would like to select, and then pass it to ```select(data_frame, one_of(column_list))```. Let's try it that way for the numeric variables, and then leverage ```select_if(data_frame, is.factor)``` when visualizing categorical data, since the ```Attrition``` column will be picked up by ```is.factor()```. Additionally, for visual purposes, let's only look at categorical variables with fewer than five levels, using the ```nlevels()``` function.
```{r, warning = FALSE, message = FALSE}
numeric_col_names <- names(attrition)[sapply(attrition, is.numeric)]
numeric_col_names <- append("Attrition", numeric_col_names)
kable(numeric_col_names, col.names = "Numeric Features Plus Attrition")
```
**Numerical Variables:**
```{r, warning = FALSE, message = FALSE, fig.height = 8}
numeric_attrition <- attrition %>%
  select(one_of(numeric_col_names))

numeric_attrition %>%
  gather(metric, value, -Attrition) %>%
  ggplot(aes(value, fill = Attrition)) +
  geom_density(show.legend = TRUE, alpha = 0.75) +
  facet_wrap(~ metric, scales = "free", ncol = 3) +
  theme_bw() +
  labs(x = "", y = "")
```
**Categorical Variables:**
```{r,warning = FALSE, message = FALSE, fig.height = 8}
categorical_col_names <- names(attrition)[sapply(attrition, nlevels) < 5]
categoric_attrition <- attrition %>%
  select_if(is.factor) %>%
  select(one_of(categorical_col_names))

categoric_attrition %>%
  gather(metric, value, -Attrition) %>%
  ggplot(aes(value, fill = Attrition)) +
  geom_bar(position = "dodge", col = "black") +
  facet_wrap(~ metric, ncol = 3, scales = "free") +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 90)) +
  labs(x = "", y = "")
```
Another way to look at the data is to use a correlation plot, which shows how each variable is correlated to every other variable. For simplicity, we'll only be looking at the correlation between numeric data, but [this resource](https://www.r-bloggers.com/to-eat-or-not-to-eat-thats-the-question-measuring-the-association-between-categorical-variables/) provides an approach for categorical data.
For the numerical correlation plot, we need to first calculate the correlation matrix using the ```cor()``` function. Then, we pass the resulting matrix into ```corrplot()```, which outputs our correlation plot.
**Question:** Does anything from the correlation plot stand out? Try combining your observations from the previous visualizations and see if you can come up with some hypotheses about why employees leave. Remember these for later on when we get to the results!
```{r, message = FALSE, warning = FALSE}
numeric_attrition %>%
  select(-Attrition) %>%
  cor() %>%
  corrplot(method = "shade", type = "lower")
```
## Modeling
### Naive Bayes
Before we use the Naive Bayes algorithm on the entire dataset, let's practice it on a small subset of the data. Let's pretend we only have the first thirty rows of two explanatory variables of interest: ```Gender``` and ```BusinessTravel```.
**Question:** Given this dataset, what is the probability that a **male who travels rarely** will leave the company? To answer this, you'll need the following information. If you get stuck, try following [this tutorial](https://monkeylearn.com/blog/practical-explanation-naive-bayes-classifier/) to point you in the right direction. Good luck!
**Naive Bayes Tables**
```{r}
subset <- attrition %>% select(Attrition, Gender, BusinessTravel) %>% head(30)
kable(subset, align = "c")
table(subset$Attrition)
table(subset$Attrition, subset$Gender)
table(subset$Attrition, subset$BusinessTravel)
```
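Once you have worked through the calculation by hand, here is one way to check your answer by combining the tables above (a minimal sketch; ```"Male"``` and ```"Travel_Rarely"``` are factor levels already present in the data):
```{r}
# Naive Bayes treats Gender and BusinessTravel as independent given Attrition:
# P(Attrition | Male, Travel_Rarely) is proportional to
#   P(Attrition) * P(Male | Attrition) * P(Travel_Rarely | Attrition)
prior    <- prop.table(table(subset$Attrition))
p_gender <- prop.table(table(subset$Attrition, subset$Gender), margin = 1)
p_travel <- prop.table(table(subset$Attrition, subset$BusinessTravel), margin = 1)

# (with only 30 rows some counts may be zero; real implementations add Laplace smoothing)
unnormalized <- prior * p_gender[, "Male"] * p_travel[, "Travel_Rarely"]
round(unnormalized / sum(unnormalized), 2)   # normalized posterior over Yes / No
```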
### Implementation
The ```caret``` package is one of the most popular machine learning libraries in R because it provides a common way (framework) to train machine learning models. It acts as a wrapper around packages with different algorithm implementations, so instead of having to learn the syntax for every single package, you can just use common functions in ```caret```. The author of the package, Max Kuhn, maintains an [active documentation / tutorial](https://topepo.github.io/caret/) website that contains everything you need to know about the library. He also happens to be the author of your textbook, "Applied Predictive Modeling", which *also* has an entire package dedicated to it, aptly named ```AppliedPredictiveModeling```.
Before we start building our model, we need to split the data into training and testing sets. Here we'll use the ```initial_split()``` function, which essentially creates an index from which the ```training()``` and ```testing()``` functions extract the data.
**Question:** Have a look at the code chunk below - what do you think the ```strata = ``` argument is doing? *Hint:* Take a look at the proportions of the ```Attrition``` categories in both the training / testing set below.
```{r, message = FALSE, warning = FALSE}
set.seed(1234)
split <- initial_split(attrition, prop = .7, strata = "Attrition")
train <- training(split)
test <- testing(split)
```
```{r, message = FALSE, warning = FALSE}
table(train$Attrition) %>%
prop.table() %>%
kable(col.names = c("Training Set: Attrition", "Freq (%)"), align = "c")
```
```{r, message = FALSE, warning = FALSE}
table(test$Attrition) %>%
prop.table() %>%
kable(col.names = c("Testing Set: Attrition", "Freq (%)"), align = "c")
```
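To see what stratification buys you, you could compare against an unstratified split (a quick sketch; the seed below is arbitrary, and the exact proportions will vary with it):
```{r, message = FALSE, warning = FALSE}
# Without strata, the Yes/No proportions in the training set can drift away from
# the overall attrition rate (~16% "Yes"), especially for smaller samples.
set.seed(5678)
unstratified_split <- initial_split(attrition, prop = .7)
table(training(unstratified_split)$Attrition) %>%
  prop.table() %>%
  round(3)
```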
### Results
In ```caret```, the ```train()``` function is the primary way to build machine learning models, and the argument ```method``` is used to specify which implementation you want to use. Under the hood, the ```method``` argument retrieves the desired package implementation, which is why you may be prompted to install a package if you're trying a new algorithm. [Here's](https://topepo.github.io/caret/available-models.html) a list of all 237 possible algorithm implementations - today, we'll be using the ```naive_bayes``` method. Once built, we will use the ```predict()``` function on the test set to get our estimates, and then use the ```confusionMatrix()``` function to evaluate performance. For the sake of comparison, we've also trained a k-Nearest Neighbours (kNN) and a Support Vector Machine (SVM) model, which you will learn about a little later on in the lecture.
**Question:** The results of the NB, kNN, and SVM models are presented below. Calculate the accuracy of each model from its confusion matrix (a quick check follows the three matrices). Which model would you select? Does anything in particular stand out from these confusion matrices?
**Naive Bayes Test Confusion Matrix:**
```{r, message = FALSE, warning = FALSE}
nb.m1 <- train(
  Attrition ~ .,
  data = train,
  method = "naive_bayes"
)
predictions_nb <- predict(nb.m1, test)
actuals <- test$Attrition
# Testing Results
conf_matrix_nb <- confusionMatrix(predictions_nb, actuals)
conf_matrix_nb$table
```
**k-Nearest Neighbours Test Confusion Matrix:**
```{r, message = FALSE, warning = FALSE}
knn.m1 <- train(
  Attrition ~ .,
  data = train,
  method = "knn"
)
predictions_knn <- predict(knn.m1, test)
# Testing Results
conf_matrix_knn <- confusionMatrix(predictions_knn, actuals)
conf_matrix_knn$table
```
**Support Vector Machine Test Confusion Matrix:**
```{r, message = FALSE, warning = FALSE}
svm.m1 <- train(
  Attrition ~ .,
  data = train,
  method = "svmLinear"
)
predictions_svm <- predict(svm.m1, test)
# Testing Results
conf_matrix_svm <- confusionMatrix(predictions_svm, actuals)
conf_matrix_svm$table
```
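If you want to check your hand calculations, accuracy is simply the sum of each confusion matrix's diagonal divided by the total number of test observations:
```{r, message = FALSE, warning = FALSE}
# Accuracy = correct predictions (the diagonal) / all predictions, for each model.
sapply(
  list(NaiveBayes = conf_matrix_nb, kNN = conf_matrix_knn, SVM = conf_matrix_svm),
  function(cm) sum(diag(cm$table)) / sum(cm$table)
) %>%
  round(3)
```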
**Question:** Let's have a closer look at just the Naive Bayes model now. Think back to the proportions of the target variable ```Attrition```, provided below for convenience. What do you think the model has learned? **Hint:** Can you think of a trivial model that would give you an accuracy of ~84%?
```{r, message = FALSE, warning = FALSE}
table(attrition$Attrition) %>% prop.table() %>% round(2)
```
Of course, a model that always picks "No" for ```Attrition``` would have an accuracy of roughly 84%. Since that is the best you could do with no additional information, this baseline accuracy is referred to as the "No-Information Rate". We can examine the model in more detail by printing the ```conf_matrix_nb``` object we created earlier with ```caret```'s ```confusionMatrix()``` function, which reports several metrics of interest, such as Accuracy, No-Information Rate, Kappa, Sensitivity, and so on.
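In fact, the No-Information Rate reported by ```caret``` is just the proportion of the majority class in the test set, which you can check directly:
```{r, message = FALSE, warning = FALSE}
# The accuracy of always predicting the majority class ("No") on the test set.
max(prop.table(table(test$Attrition)))
```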
**Question:** Consider the results below. Do you notice anything interesting?
**First Naive Bayes Model - caret Confusion Matrix Output:**
```{r, message = FALSE, warning = FALSE}
conf_matrix_nb
```
Throughout the course, you will be introduced to several performance metrics for both classification and regression, and will learn how to apply rigorous testing techniques to your models. For the time being, however, let's try building another Naive Bayes model, this time explicitly telling ```caret``` to use "Kappa" as the evaluation metric. Loosely speaking, the Kappa metric goes one step beyond accuracy by comparing "observed accuracy" with "expected accuracy". There are other more appropriate metrics in our case, but it's a good start. The resulting confusion matrix and associated metrics are presented below.
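To make that concrete, here is a rough sketch of the Kappa calculation using the first model's confusion matrix: observed accuracy is the share of the diagonal, expected accuracy is the agreement you would get by chance from the row and column totals, and Kappa measures how far the former exceeds the latter.
```{r, message = FALSE, warning = FALSE}
# Cohen's Kappa computed from the first Naive Bayes confusion matrix.
cm <- conf_matrix_nb$table
observed <- sum(diag(cm)) / sum(cm)                      # observed accuracy
expected <- sum(rowSums(cm) * colSums(cm)) / sum(cm)^2   # accuracy expected by chance
(observed - expected) / (1 - expected)                   # Kappa
```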
**Naive Bayes with Kappa Metric - Test Confusion Matrix:**
```{r, message = FALSE, warning = FALSE}
nb.m2 <- train(
  Attrition ~ .,
  data = train,
  method = "nb",
  metric = "Kappa"
)
predictions_nb_m2 <- predict(nb.m2, test)
actuals <- test$Attrition
# Testing Results
conf_matrix_nb_m2 <- confusionMatrix(predictions_nb_m2, actuals)
conf_matrix_nb_m2
```
## Discussion
**Question:** Which model would you select and why? If you need to familiarize yourself with sensitivity / specificity, the [Wikipedia article](https://en.wikipedia.org/wiki/Sensitivity_and_specificity) is a good place to start.
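If sensitivity and specificity are new to you, they can be read straight off a confusion matrix. Here is a quick sketch using the first Naive Bayes model (recall that ```caret``` treats ```"Yes"```, the first factor level of ```Attrition```, as the positive class):
```{r, message = FALSE, warning = FALSE}
# Sensitivity: of the employees who actually left ("Yes"), how many did we catch?
# Specificity: of the employees who stayed ("No"), how many did we correctly identify?
cm <- conf_matrix_nb$table
cm["Yes", "Yes"] / sum(cm[, "Yes"])   # sensitivity
cm["No", "No"] / sum(cm[, "No"])      # specificity
```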
```{r, message = FALSE, warning = FALSE}
nb_m1_results <- data.frame(Model = "Trained_for_Accuracy_Metric",
                            Accuracy = conf_matrix_nb$overall[1],
                            Sensitivity = conf_matrix_nb$byClass[1],
                            Specificity = conf_matrix_nb$byClass[2])
nb_m2_results <- data.frame(Model = "Trained_for_Kappa_Metric",
                            Accuracy = conf_matrix_nb_m2$overall[1],
                            Sensitivity = conf_matrix_nb_m2$byClass[1],
                            Specificity = conf_matrix_nb_m2$byClass[2])
all_results <- rbind(nb_m1_results, nb_m2_results)
all_results %>%
  gather(Metric, Value, -Model) %>%
  mutate_if(is.character, as.factor) %>%
  ggplot(aes(x = Metric, y = Value, fill = Model)) +
  geom_col(position = "dodge", col = "black") +
  theme_bw() +
  theme(legend.position = "top",
        axis.title.x = element_blank()) +
  labs(fill = "", x = "",
       y = "Metric Performance",
       title = "Comparison of Different Naive Bayes Models")
```
**Question:** The table below shows the relative importance of each feature in the Naive Bayes model as calculated by ```caret```. Does the importance of these features surprise you? Can you come up with a hypothesis as to why employees leave?
```{r, message = FALSE, warning = FALSE}
varImp(nb.m2)$importance %>%
  rownames_to_column("Feature") %>%
  select(Feature, Importance = Yes) %>%
  mutate(Importance = round(Importance, 2)) %>%
  arrange(desc(Importance)) %>%
  kable(col.names = c("Feature", "Relative Importance (%)"), align = "c")
```