02-03 Perceptron Dual Form (Iris Classification)

Date: 2019-10-16 17:58:15

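For a more complete and up-to-date version of the Machine Learning series, plus tutorials on Python, Go, data structures and algorithms, web scraping, and artificial intelligence, see: https://www.cnblogs.com/nickchen121/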

Perceptron Dual Form (Iris Classification)

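1. Import modules

from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
# Jupyter notebook magic: show figures inline (remove this line in a plain .py script)
%matplotlib inline
# Font for the plot labels and legend; this is a macOS path, adjust it on other systems
font = FontProperties(fname='/Library/Fonts/Heiti.ttc')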

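2. Load the data

def get_data():
    # The UCI iris file has no header row
    df = pd.read_csv(
        'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)
    # Features: sepal length (column 0) and petal length (column 2) of the
    # first 100 samples (50 Iris setosa followed by 50 Iris versicolor)
    X = df.iloc[0:100, [0, 2]].values
    train_data_p = df.iloc[0:50, [0, 2, 4]].values
    train_data_n = df.iloc[50:100, [0, 2, 4]].values
    # Replace the class name in the last column with a numeric label:
    # Iris setosa -> -1, Iris versicolor -> +1
    train_data_p[:, [2]], train_data_n[:, [2]] = -1, 1
    train_data = train_data_p.tolist() + train_data_n.tolist()

    return train_data, X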

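3. Train the model

def train(num_iter, train_data, learning_rate):
    train_data = np.array(train_data)
    data_length = len(train_data)
    w = np.zeros(train_data.shape[1] - 1)
    b = 0
    # alpha[j] grows by learning_rate each time sample j is picked and found misclassified
    alpha = [0 for _ in range(data_length)]
    # Gram matrix: gram[i, j] is the inner product of feature vectors x_i and x_j
    gram = np.matmul(train_data[:, 0:-1], train_data[:, 0:-1].T)
    for _ in range(num_iter):
        # Pick one training sample at random
        i = random.randint(0, data_length - 1)
        yi = train_data[i, -1]
        # Dual-form decision value: sum_j alpha_j * y_j * (x_j . x_i) + b
        count = 0
        for j in range(data_length):
            count += alpha[j] * train_data[j, -1] * gram[i, j]
        count += b
        # Misclassified (or exactly on the boundary): update alpha_i and b
        if (yi * count <= 0):
            alpha[i] = alpha[i] + learning_rate
            b = b + learning_rate * yi
    # Recover the primal weights: w = sum_i alpha_i * y_i * x_i
    for i in range(data_length):
        w += alpha[i] * train_data[i, 0:-1] * train_data[i, -1]
    return w, b, alpha, gram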
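
The inner loop over j in train() builds the dual-form decision value sum_j alpha_j * y_j * (x_j . x_i) + b one term at a time. As a minimal sketch, the same value can be obtained with a single dot product against row i of the Gram matrix. The toy data, the alpha and b values, and the helper names decision_loop and decision_vectorized are illustrative assumptions and are not part of the original code.

```python
import numpy as np

# Hypothetical toy data: two features per row, label (+1 / -1) in the last column
train_data = np.array([
    [3.0, 3.0, 1.0],
    [4.0, 3.0, 1.0],
    [1.0, 1.0, -1.0],
])
X, y = train_data[:, :-1], train_data[:, -1]
gram = X @ X.T                      # gram[i, j] = x_i . x_j
alpha = np.array([0.2, 0.0, 0.5])   # made-up multipliers, for illustration only
b = -0.3


def decision_loop(i):
    # Same term-by-term accumulation as the inner loop in train()
    count = 0.0
    for j in range(len(y)):
        count += alpha[j] * y[j] * gram[i, j]
    return count + b


def decision_vectorized(i):
    # One dot product replaces the loop over j
    return np.dot(alpha * y, gram[i]) + b


# Both versions agree for every sample
for i in range(len(y)):
    assert np.isclose(decision_loop(i), decision_vectorized(i))
```

Whether the vectorized form is worth adopting depends on the data size; for 100 samples the explicit loop is easier to map onto the formula, while the dot product avoids Python-level iteration.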

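4. Visualization

def plot_points(w, b, X):
    plt.figure()
    # Decision boundary w[0]*x1 + w[1]*x2 + b = 0, solved for x2
    # (the 1e-10 term guards against division by zero when w[1] is 0)
    x1 = np.linspace(4, 7, 100)
    x2 = (-b - w[0] * x1) / (w[1] + 1e-10)
    plt.plot(x1, x2, color='k')
    # The first 50 rows are Iris setosa, the next 50 are Iris versicolor
    plt.scatter(X[:50, 0], X[:50, 1], color='r', s=50, marker='o', label='Iris setosa')
    plt.scatter(X[50:100, 0], X[50:100, 1], color='b',
                s=50, marker='x', label='Iris versicolor')
    plt.xlabel('Sepal length (cm)', fontproperties=font)
    plt.ylabel('Petal length (cm)', fontproperties=font)
    plt.legend(prop=font)
    plt.show()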

5. Run

train_data, X = get_data()
w, b, alpha, gram = train(
    num_iter=1000, train_data=train_data, learning_rate=0.1)
plot_points(w, b, X)

[Figure: scatter plot of the two iris classes with the fitted decision line]
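
To check the result numerically as well as by eye, one option is to classify each point by the sign of w . x + b and compare with the labels. The sketch below is only an illustration under stated assumptions: predict and accuracy are hypothetical helper names, and the w, b and toy samples are hand-picked stand-ins for the values returned by train() and get_data().

```python
import numpy as np


def predict(w, b, X):
    # Classify each row of X as +1 or -1 by the sign of w . x + b
    # (points exactly on the boundary are assigned -1 here, an arbitrary choice)
    return np.where(X @ w + b > 0, 1, -1)


def accuracy(w, b, train_data):
    # train_data rows are [feature_1, feature_2, label], as produced by get_data()
    data = np.array(train_data, dtype=float)
    return np.mean(predict(w, b, data[:, :-1]) == data[:, -1])


# Hand-picked stand-ins for the outputs of train() and get_data()
w = np.array([-1.0, 2.0])
b = 0.0
toy_data = [[5.0, 1.5, -1], [4.8, 1.4, -1], [6.0, 4.5, 1], [5.5, 4.0, 1]]

print(accuracy(w, b, toy_data))
```

With the real data, calling accuracy(w, b, train_data) on the values from the run above would report the fraction of training samples that fall on the correct side of the line. Because train() samples points at random, that fraction can differ from run to run.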

Original article: https://www.cnblogs.com/nickchen121/p/11686753.html
