码迷,mamicode.com
首页 > 其他好文 > 详细

机器学习~用于机器学习中的分类边界、决策树等可视化的模块

时间:2018-07-26 18:38:19      阅读:419      评论:0      收藏:0      [点我收藏+]

标签:points   x86   region   import   lse   数据   维数   bin   led   

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
import graphviz
from sklearn.tree import export_graphviz
import matplotlib.patches as mpatches


def plot_decision_tree(clf, feature_names, class_names):
    r"""Visualize a fitted decision tree as a Graphviz graph.

    Requirements:
      1. The Graphviz program must be installed and its ``bin`` directory
         added to the PATH environment variable, e.g.
         ``C:\Program Files (x86)\Graphviz2.38\bin``; restart Jupyter or
         the system for the change to take effect.
      2. The ``graphviz`` Python module (``pip install graphviz``).

    Args:
        clf: Fitted tree estimator handed to ``export_graphviz``.
        feature_names: Feature names forwarded to ``export_graphviz``.
        class_names: Class names forwarded to ``export_graphviz``.

    Returns:
        A ``graphviz.Source`` object built from the exported DOT text.
    """
    # With out_file=None export_graphviz returns the DOT source as a string,
    # so no temporary .dot file is written to (and left in) the working
    # directory just to be read back.
    dot_graph = export_graphviz(
        clf,
        out_file=None,
        feature_names=feature_names,
        class_names=class_names,
        filled=True,
        impurity=False,
    )
    return graphviz.Source(dot_graph)


def plot_feature_importances(clf, feature_names):
    """Draw a horizontal bar chart of a model's feature importances.

    The chart is drawn through ``matplotlib.pyplot``; this function does not
    call ``plt.show()``.

    Args:
        clf: Object exposing a ``feature_importances_`` sequence.
        feature_names: Labels for the y-axis ticks, one per importance value.

    Raises:
        ValueError: If the number of feature names differs from the number
            of importance values.
    """
    importances = clf.feature_importances_
    c_features = len(feature_names)
    # Fail early with a readable message instead of a broadcasting error
    # raised from inside matplotlib.
    if len(importances) != c_features:
        raise ValueError(
            "feature_names has {} entries but clf.feature_importances_ has {}".format(
                c_features, len(importances)
            )
        )

    positions = np.arange(c_features)
    plt.barh(positions, importances)
    plt.xlabel('Feature importance')
    plt.ylabel('Feature name')
    plt.yticks(positions, feature_names)


def plot_class_regions_for_classifier(clf, X, y, X_test=None, y_test=None, title=None,
                                      target_names=None, plot_decision_regions=True):
    """Plot a classifier's predicted class regions together with the data.

    Only usable for data with two features: the first two columns of ``X``
    are used as the plot axes and the classifier is asked to predict on a
    two-column mesh.

    Args:
        clf: Fitted classifier providing ``predict`` (and ``score`` when a
            test set is supplied).
        X: Training samples, indexed as ``X[:, 0]`` / ``X[:, 1]``.
        y: Training labels; ``max(y) + 1`` is taken as the number of classes.
        X_test: Optional test samples, drawn with triangle markers.
        y_test: Labels for ``X_test``; required whenever ``X_test`` is given.
        title: Optional plot title. When a test set is supplied, a line with
            the train and test scores is appended to it.
        target_names: Optional class names used to build a legend.
        plot_decision_regions: If true, fill the predicted regions.

    Raises:
        ValueError: If ``X_test`` is given without ``y_test``.
    """
    # Reject an incomplete test set before any figure is created.
    if X_test is not None and y_test is None:
        raise ValueError("y_test must be provided when X_test is given")

    # int() keeps the slices below valid when the labels are stored as floats.
    num_classes = int(np.amax(y)) + 1
    color_list_light = ['#FFFFAA', '#EFEFEF', '#AAFFAA', '#AAAAFF']
    color_list_bold = ['#EEEE00', '#000000', '#00CC00', '#0000CC']
    cmap_light = ListedColormap(color_list_light[:num_classes])
    cmap_bold = ListedColormap(color_list_bold[:num_classes])

    mesh_step = 0.03
    mesh_margin = 0.5
    x_plot_adjust = 0.1
    y_plot_adjust = 0.1
    plot_symbol_size = 50

    x_min, x_max = X[:, 0].min(), X[:, 0].max()
    y_min, y_max = X[:, 1].min(), X[:, 1].max()
    x2, y2 = np.meshgrid(
        np.arange(x_min - mesh_margin, x_max + mesh_margin, mesh_step),
        np.arange(y_min - mesh_margin, y_max + mesh_margin, mesh_step),
    )

    # Predict a class for every mesh point, then restore the grid shape.
    P = clf.predict(np.c_[x2.ravel(), y2.ravel()]).reshape(x2.shape)

    plt.figure()
    if plot_decision_regions:
        plt.contourf(x2, y2, P, cmap=cmap_light, alpha=0.8)

    # Pin the colour normalisation to the full label range so that a set that
    # happens to lack some class is still drawn in the same colours as the
    # other set and the legend.
    scatter_style = dict(cmap=cmap_bold, s=plot_symbol_size, edgecolor='black',
                         vmin=0, vmax=num_classes - 1)

    plt.scatter(X[:, 0], X[:, 1], c=y, **scatter_style)
    plt.xlim(x_min - x_plot_adjust, x_max + x_plot_adjust)
    plt.ylim(y_min - y_plot_adjust, y_max + y_plot_adjust)

    if X_test is not None:
        plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, marker='^', **scatter_style)
        train_score = clf.score(X, y)
        test_score = clf.score(X_test, y_test)
        score_line = "Train score = {:.2f}, Test score = {:.2f}".format(train_score, test_score)
        # title defaults to None; concatenating onto it would raise TypeError.
        title = score_line if title is None else title + "\n" + score_line

    if target_names is not None:
        legend_handles = [
            mpatches.Patch(color=color_list_bold[i], label=name)
            for i, name in enumerate(target_names)
        ]
        plt.legend(loc=0, handles=legend_handles)

    if title is not None:
        plt.title(title)
    plt.show()


def plot_fruit_knn(X, y, n_neighbors):
    """Fit a kNN classifier on the fruit data's height/width and plot it.

    Trains ``KNeighborsClassifier`` on the ``height`` and ``width`` columns
    of ``X``, colours a mesh by the predicted class, overlays the training
    points and shows the figure.

    Args:
        X: pandas DataFrame containing ``height`` and ``width`` columns.
        y: pandas Series of class labels for the rows of ``X``.
        n_neighbors: Number of neighbours passed to ``KNeighborsClassifier``.
    """
    # as_matrix() was removed in pandas 1.0; to_numpy() is its replacement.
    X_mat = X[['height', 'width']].to_numpy()
    y_mat = y.to_numpy()

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

    clf = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors)
    clf.fit(X_mat, y_mat)

    # Plot the decision boundary by assigning a color in the color map
    # to each mesh point.
    mesh_step_size = .01  # step size in the mesh
    plot_symbol_size = 50

    x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
    y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                         np.arange(y_min, y_max, mesh_step_size))

    # Put the predictions back into the mesh shape for a color plot
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

    # Plot training points, coloured by the same label array used for fitting
    plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y_mat, cmap=cmap_bold,
                edgecolor='black')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    # NOTE(review): the legend is hard-coded; it presumably matches the label
    # order of the fruit dataset (apple, mandarin, orange, lemon) - confirm.
    legend_entries = [
        ('#FF0000', 'apple'),
        ('#00FF00', 'mandarin'),
        ('#0000FF', 'orange'),
        ('#AFAFAF', 'lemon'),
    ]
    plt.legend(handles=[mpatches.Patch(color=color, label=label)
                        for color, label in legend_entries])

    plt.xlabel('height (cm)')
    plt.ylabel('width (cm)')

    plt.show()

机器学习~用于机器学习中的分类边界、决策树等可视化的模块

标签:points   x86   region   import   lse   数据   维数   bin   led   

原文地址:https://www.cnblogs.com/arthur-54271/p/9372761.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!