This is one of the 100+ free recipes of the IPython Cookbook, Second Edition, by Cyrille Rossant, a guide to numerical computing and data science in the Jupyter Notebook. The ebook and printed book are available for purchase at Packt Publishing.
▶ Text on GitHub with a CC-BY-NC-ND license
▶ Code on GitHub with an MIT license
In this recipe, we will see how to recognize handwritten digits with a K-nearest neighbors (K-NN) classifier. This classifier is a simple but powerful model, well-adapted to complex, highly nonlinear datasets such as images. We will explain how it works later in this recipe.
- We import the modules:
import numpy as np
import sklearn
import sklearn.datasets as ds
import sklearn.model_selection as ms
import sklearn.neighbors as nb
import matplotlib.pyplot as plt
%matplotlib inline
- Let's load the digits dataset, part of the datasets module of scikit-learn. This dataset contains handwritten digits that have been manually labeled:
digits = ds.load_digits()
X = digits.data
y = digits.target
print((X.min(), X.max()))
print(X.shape)
(0.0, 16.0)
(1797, 64)
In the matrix X, each row contains the 8 * 8 = 64 pixels of one image (grayscale values between 0 and 16), flattened in row-major order.
- Let's display some of the images along with their labels:
nrows, ncols = 2, 5
fig, axes = plt.subplots(nrows, ncols,
figsize=(6, 3))
for i in range(nrows):
    for j in range(ncols):
        # Image index
        k = j + i * ncols
        ax = axes[i, j]
        ax.matshow(digits.images[k, ...],
                   cmap=plt.cm.gray)
        ax.set_axis_off()
        ax.set_title(digits.target[k])
- Now, let's fit a K-nearest neighbors classifier on the data:
(X_train, X_test, y_train, y_test) = \
ms.train_test_split(X, y, test_size=.25)
knc = nb.KNeighborsClassifier()
knc.fit(X_train, y_train)
- Let's evaluate the score of the trained classifier on the test dataset:
knc.score(X_test, y_test)
0.987
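As a quick check, we could also compute a confusion matrix on the test set to see which digits are confused with which (a minimal sketch using scikit-learn's metrics module):
import sklearn.metrics as met
# Rows are true labels, columns are predicted labels.
y_pred = knc.predict(X_test)
print(met.confusion_matrix(y_test, y_pred))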
- Now, let's see if our classifier can recognize a handwritten digit:
# Let's draw a 1.
one = np.zeros((8, 8))
one[1:-1, 4] = 16 # The image values are in [0, 16].
one[2, 3] = 16
fig, ax = plt.subplots(1, 1, figsize=(2, 2))
ax.imshow(one, interpolation='none',
cmap=plt.cm.gray)
ax.grid(False)
ax.set_axis_off()
ax.set_title("One")
# We need to pass a (1, D) array.
knc.predict(one.reshape((1, -1)))
array([1])
Good job!
This example illustrates how to deal with images in scikit-learn. An image is a 2D array of pixels; before it can be fed to a standard classifier, it needs to be flattened into a 1D feature vector. Here, every 8x8 image becomes a 64-element row of the data matrix X.
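For example, we can check that the first 8x8 image, flattened in row-major order, is exactly the first row of the data matrix (a quick sanity check on the digits.images array loaded earlier):
# Each flattened image corresponds to one row of digits.data.
print(np.array_equal(digits.images[0].ravel(),
                     digits.data[0]))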
The idea of K-nearest neighbors is as follows: given a new point in the feature space, find the K closest points from the training set and assign the label of the majority of those points.
The distance is generally the Euclidean distance, but other distances can be used too.
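To make the idea concrete, here is a minimal NumPy sketch of this procedure (the knn_predict function below is just an illustration; it is not how scikit-learn implements the neighbor search):
def knn_predict(X_train, y_train, x, k=5):
    # Euclidean distances from x to every training point.
    dist = np.sqrt(((X_train - x) ** 2).sum(axis=1))
    # Labels of the k closest training points.
    nearest = y_train[np.argsort(dist)[:k]]
    # Majority vote among these labels.
    return np.bincount(nearest).argmax()

# For instance, on the handwritten 1 drawn above:
knn_predict(X_train, y_train, one.ravel())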
The following image, obtained from the scikit-learn documentation at http://scikit-learn.org/stable/modules/neighbors.html, shows the space partition obtained with a 15-nearest-neighbors classifier on a toy dataset (with three labels):
The number of neighbors, K, is a hyperparameter of the model. If K is too small, the classifier is sensitive to noise in the training data; if it is too large, the decision boundaries become overly smooth and local structure is lost. K is commonly chosen by cross-validation.
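For instance, a grid search over a few candidate values of K could look as follows (a sketch using GridSearchCV; the grid of values is arbitrary):
grid = ms.GridSearchCV(nb.KNeighborsClassifier(),
                       {'n_neighbors': [1, 3, 5, 7, 9, 11]},
                       cv=5)
grid.fit(X_train, y_train)
print(grid.best_params_)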
It should be noted that the K-nearest neighbors algorithm learns no model; the classifier just stores all training points and compares any new query point against them. This is an example of instance-based learning. It is in contrast to other classifiers such as the logistic regression model, which explicitly learns a simple mathematical model on the training data.
The K-nearest neighbors method works well on complex classification problems that have irregular decision boundaries. However, it might be computationally intensive with large training datasets because a large number of distances have to be computed for testing. Dedicated tree-based data structures such as K-D trees or ball trees can be used to accelerate the search of nearest neighbors.
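In scikit-learn, the search structure is controlled by the algorithm keyword of the classifier (by default, 'auto' lets the library choose; the snippet below simply forces a ball tree):
knc_bt = nb.KNeighborsClassifier(algorithm='ball_tree')
knc_bt.fit(X_train, y_train)
knc_bt.score(X_test, y_test)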
The K-nearest neighbors method can be used for classification, as done here, and also for regression problems, where the model predicts the average of the target values of the nearest neighbors. In both cases, different weighting strategies can be used.
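For example, a regression variant with distance-based weighting could look like this (a minimal sketch on toy 1D data, unrelated to the digits dataset):
# Toy regression problem: y = x^2 with a bit of noise.
rng = np.random.RandomState(0)
x = rng.uniform(0, 1, size=(100, 1))
y_reg = x.ravel() ** 2 + rng.normal(scale=.01, size=100)
# Closer neighbors get a larger weight in the average.
knr = nb.KNeighborsRegressor(n_neighbors=5,
                             weights='distance')
knr.fit(x, y_reg)
knr.predict([[.5]])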
Here are a few references:
- The K-NN algorithm in scikit-learn's documentation, available at http://scikit-learn.org/stable/modules/neighbors.html
- The K-NN algorithm on Wikipedia, available at https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm
- Blog post about how to choose the K hyperparameter, available at http://datasciencelab.wordpress.com/2013/12/27/finding-the-k-in-k-means-clustering/
- Instance-based learning on Wikipedia, available at https://en.wikipedia.org/wiki/Instance-based_learning
- Predicting who will survive on the Titanic with logistic regression
- Using support vector machines for classification tasks