-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] decision surface, regression and KNN
- Loading branch information
Showing
4 changed files
with
144 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# | ||
from sklearn.datasets import load_iris | ||
iris = load_iris() | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from sklearn.datasets import load_iris | ||
from sklearn.inspection import DecisionBoundaryDisplay | ||
from sklearn.tree import DecisionTreeClassifier | ||
|
||
# Parameters | ||
n_classes = 3 | ||
plot_colors = "ryb" | ||
plot_step = 0.02 | ||
|
||
|
||
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]): | ||
# We only take the two corresponding features | ||
X = iris.data[:, pair] | ||
y = iris.target | ||
|
||
# Train | ||
clf = DecisionTreeClassifier().fit(X, y) | ||
|
||
# Plot the decision boundary | ||
ax = plt.subplot(2, 3, pairidx + 1) | ||
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) | ||
DecisionBoundaryDisplay.from_estimator( | ||
clf, | ||
X, | ||
cmap=plt.cm.RdYlBu, | ||
response_method="predict", | ||
ax=ax, | ||
xlabel=iris.feature_names[pair[0]], | ||
ylabel=iris.feature_names[pair[1]], | ||
) | ||
|
||
# Plot the training points | ||
for i, color in zip(range(n_classes), plot_colors): | ||
idx = np.where(y == i) | ||
plt.scatter( | ||
X[idx, 0], | ||
X[idx, 1], | ||
c=color, | ||
label=iris.target_names[i], | ||
cmap=plt.cm.RdYlBu, | ||
edgecolor="black", | ||
s=15, | ||
) | ||
|
||
plt.suptitle("Decision surface of decision trees trained on pairs of features") | ||
plt.legend(loc="lower right", borderpad=0, handletextpad=0) | ||
_ = plt.axis("tight") | ||
|
||
|
||
# | ||
from sklearn.tree import plot_tree | ||
|
||
plt.figure() | ||
clf = DecisionTreeClassifier().fit(iris.data, iris.target) | ||
plot_tree(clf, filled=True) | ||
plt.title("Decision tree trained on all the iris features") | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# 决策树 | ||
# Import the necessary modules and libraries | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from sklearn.tree import DecisionTreeRegressor | ||
|
||
# Create a random dataset | ||
rng = np.random.RandomState(1) | ||
X = np.sort(5 * rng.rand(80, 1), axis=0) | ||
y = np.sin(X).ravel() | ||
y[::5] += 3 * (0.5 - rng.rand(16)) | ||
|
||
# Fit regression model | ||
regr_1 = DecisionTreeRegressor(max_depth=2) | ||
regr_2 = DecisionTreeRegressor(max_depth=5) | ||
regr_1.fit(X, y) | ||
regr_2.fit(X, y) | ||
|
||
# Predict | ||
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis] | ||
y_1 = regr_1.predict(X_test) | ||
y_2 = regr_2.predict(X_test) | ||
|
||
# Plot the results | ||
plt.figure() | ||
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data") | ||
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2) | ||
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2) | ||
plt.xlabel("data") | ||
plt.ylabel("target") | ||
plt.title("Decision Tree Regression") | ||
plt.legend() | ||
plt.show() | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
zhu78244
via email
Author
Collaborator
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#鸢尾KNN算法分类 | ||
#https://blog.csdn.net/hqllqh/article/details/108914072 | ||
|
||
import numpy as np | ||
from sklearn.datasets import load_iris | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.neighbors import KNeighborsClassifier | ||
|
||
iris_data = load_iris() | ||
This comment has been minimized.
Sorry, something went wrong. |
||
# 该函数返回一个Bunch对象,它直接继承自Dict类,与字典类似,由键值对组成。 | ||
# 可以使用bunch.keys(),bunch.values(),bunch.items()等方法。 | ||
print(type(iris_data)) | ||
# data里面是花萼长度、花萼宽度、花瓣长度、花瓣宽度的测量数据,格式为 NumPy数组 | ||
print(iris_data['data']) # 花的样本数据 | ||
print("花的样本数量:{}".format(iris_data['data'].shape)) | ||
print("花的前5个样本数据:{}".format(iris_data['data'][:5])) | ||
|
||
# 0 代表 setosa, 1 代表 versicolor,2 代表 virginica | ||
print(iris_data['target']) # 类别 | ||
print(iris_data['target_names']) # 花的品种 | ||
|
||
# 构造训练数据和测试数据 | ||
X_train,X_test,y_train,y_test = train_test_split(\ | ||
iris_data['data'],iris_data['target'],random_state=0) | ||
print("训练样本数据的大小:{}".format(X_train.shape)) | ||
print("训练样本标签的大小:{}".format(y_train.shape)) | ||
print("测试样本数据的大小:{}".format(X_test.shape)) | ||
print("测试样本标签的大小:{}".format(y_test.shape)) | ||
|
||
# 构造KNN模型 | ||
knn = KNeighborsClassifier(n_neighbors=1) | ||
# knn = KNeighborsClassifier(n_neighbors=3) | ||
|
||
# 训练模型 | ||
knn.fit(X_train,y_train) | ||
y_pred = knn.predict(X_test) | ||
|
||
# 评估模型 | ||
print("模型精度:{:.2f}".format(np.mean(y_pred==y_test))) | ||
print("模型精度:{:.2f}".format(knn.score(X_test,y_test))) | ||
|
||
# 做出预测 | ||
X_new = np.array([[1.1,5.9,1.4,2.2]]) | ||
prediction = knn.predict(X_new) | ||
print("预测的目标类别是:{}".format(prediction)) | ||
print("预测的目标类别花名是:{}".format(iris_data['target_names'][prediction])) |
图片可以保存成文件,插入到论文或者留给老师看,或者插入到README.md中