-
Notifications
You must be signed in to change notification settings - Fork 0
/
clase_06_03_KNN_simple.py
73 lines (51 loc) · 1.69 KB
/
clase_06_03_KNN_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Universidad Adolfo Ibañez
Facultad de Ingeniería y Ciencias
TICS 585 - Reconocimiento de Patrones en imágenes
Algoritmo KNN en biblioteca SKLEARN
Autor:. Miguel Carrasco (16-08-2021)
rev.1.0
"""
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.datasets import make_blobs
from pandas import DataFrame
from sklearn.neighbors import NearestNeighbors
r_seed = 42
# Generamos tres clusters 2D con 1000 puntos
X, y = make_blobs(n_samples=1000, centers=5, n_features=2, random_state=r_seed)
random.seed(r_seed)
random_pts = []
# Generamos puntos aleatorios que se agregan a los datos
for i in range(50):
random_pts.append([random.randint(-10, 10), random.randint(-10, 10)])
X = np.append(X, random_pts, axis=0)
#generamos un dataframe
df = DataFrame(dict(x=X[:,0], y=X[:,1]))
df.plot(kind='scatter', x='x', y='y')
plt.show()
#aplicamos el algoritmo KNN
k = 5 #número de vecinos
knn = NearestNeighbors(n_neighbors=k)
knn.fit(X)
# Calcula la distancia a los k vecinos
neighbors_and_distances = knn.kneighbors(X)
knn_distances = neighbors_and_distances[0]
# distancia promedio a los k vecinos más cercanos
tnn_distance = np.mean(knn_distances, axis=1)
plt.figure()
plt.plot(np.sort(tnn_distance))
plt.title('Punto codo para seleccionar el umbral (Threshold)')
plt.ylabel('Distancias entre vecinos por punto')
plt.show()
#definir un umbral (el umbral se define en función del punto codo)
umbral = 1
indices = tnn_distance > umbral
print(knn_distances)
#PCM = df.plot(kind='scatter', x='x', y='y', c=tnn_distance, colormap='viridis')
plt.scatter(X[:, 0], X[:,1], color='red')
plt.scatter(X[indices, 0], X[indices,1], color='blue')
plt.xlabel('X')
plt.xlabel('Y')
plt.show()