码迷,mamicode.com
首页 > 其他好文 > 详细

支持向量机

时间:2016-08-01 08:05:59      阅读:140      评论:0      收藏:0      [点我收藏+]

标签:

  1 # encoding: utf-8
  2 import numpy as np
  3 import matplotlib.pyplot as plt
  4 
  5 
  6 class SVC(object):
  7     def __init__(self, c=1.0, delta=0.001):  # 初始化
  8         self.N = 0
  9         self.delta = delta
 10         self.X = None
 11         self.y = None
 12         self.w = None
 13         self.wn = 0
 14         self.K = np.zeros((self.N, self.N))
 15         self.a = np.zeros((self.N, 1))
 16         self.b = 0
 17         self.C = c
 18         self.stop=1
 19 
 20     def kernel_function(self,x1, x2):  # 核函数
 21         return np.dot(x1, x2)
 22 
 23     def kernel_matrix(self, x):  # 核矩阵
 24         for i in range(0, len(x)):
 25             for j in range(i, len(x)):
 26                 self.K[j][i] = self.K[i][j] = self.kernel_function(self.X[i], self.X[j])
 27 
 28     def get_w(self):  # 计算更新w
 29         ay = self.a * self.y
 30         w = np.zeros((1, self.wn))
 31         for i in range(0, self.N):
 32             w += self.X[i] * ay[i]
 33         return w
 34 
 35     def get_b(self, a1, a2, a1_old, a2_old):  # 计算更新B
 36         y1 = self.y[a1]
 37         y2 = self.y[a2]
 38         a1_new = self.a[a1]
 39         a2_new = self.a[a2]
 40         b1_new = -self.E[a1] - y1 * self.K[a1][a1] * (a1_new - a1_old) - y2 * self.K[a2][a1] * (
 41             a2_new - a2_old) + self.b
 42         b2_new = -self.E[a2] - y1 * self.K[a1][a2] * (a1_new - a1_old) - y2 * self.K[a2][a2] * (
 43             a2_new - a2_old) + self.b
 44         if (0 < a1_new) and (a1_new < self.C) and (0 < a2_new) and (a2_new < self.C):
 45             return b1_new[0]
 46         else:
 47             return (b1_new[0] + b2_new[0]) / 2.0
 48 
 49     def gx(self, x):  # 判别函数g(x)
 50         return np.dot(self.w, x) + self.b
 51 
 52     def satisfy_kkt(self, a):  # 判断样本点是否满足kkt条件
 53         index = a[1]
 54         if a[0] == 0 and self.y[index] * self.gx(self.X[index]) > 1:
 55             return 1
 56         elif a[0] < self.C and self.y[index] * self.gx(self.X[index]) == 1:
 57             return 1
 58         elif a[0] == self.C and self.y[index] * self.gx(self.X[index]) < 1:
 59             return 1
 60         return 0
 61 
 62     def clip_func(self, a_new, a1_old, a2_old, y1, y2):  # 拉格朗日乘子的裁剪函数
 63         if (y1 == y2):
 64             L = max(0, a1_old + a2_old - self.C)
 65             H = min(self.C, a1_old + a2_old)
 66         else:
 67             L = max(0, a2_old - a1_old)
 68             H = min(self.C, self.C + a2_old - a1_old)
 69         if a_new < L:
 70             a_new = L
 71         if a_new > H:
 72             a_new = H
 73         return a_new
 74 
 75     def update_a(self, a1, a2):  # 更新a1,a2
 76         partial_a2 = self.K[a1][a1] + self.K[a2][a2] - 2 * self.K[a1][a2]
 77         if partial_a2 <= 1e-9:
 78             print "error:", partial_a2
 79         a2_new_unc = self.a[a2] + (self.y[a2] * ((self.E[a1] - self.E[a2]) / partial_a2))
 80         a2_new = self.clip_func(a2_new_unc, self.a[a1], self.a[a2], self.y[a1], self.y[a2])
 81         a1_new = self.a[a1] + self.y[a1] * self.y[a2] * (self.a[a2] - a2_new)
 82         if abs(a1_new - self.a[a1]) < self.delta:
 83             return 0
 84         self.a[a1] = a1_new
 85         self.a[a2] = a2_new
 86         self.is_update = 1
 87         return 1
 88 
 89     def update(self, first_a):  # 更新拉格朗日乘子
 90         for second_a in range(0, self.N):
 91             if second_a == first_a:
 92                 continue
 93             a1_old = self.a[first_a]
 94             a2_old = self.a[second_a]
 95             if self.update_a(first_a, second_a) == 0:
 96                 return
 97             self.b= self.get_b(first_a, second_a, a1_old, a2_old)
 98             self.w = self.get_w()
 99             self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)]
100             self.stop=0
101 
102     def train(self, x, y, max_iternum=100):  # SVC
103         x_len = len(x)
104         self.X = x
105         self.N = x_len
106         self.wn = len(x[0])
107         self.y = np.array(y).reshape((self.N, 1))
108         self.K = np.zeros((self.N, self.N))
109         self.kernel_matrix(self.X)
110         self.b = 0
111         self.a = np.zeros((self.N, 1))
112         self.w = self.get_w()
113         self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)]
114         self.is_update = 0
115         for i in range(0, max_iternum):
116             self.stop=1
117             data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a))) if x > 0 and x< self.C]
118             if len(data_on_bound) == 0:
119                 data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]
120             for data in data_on_bound:
121                 if self.satisfy_kkt(data) != 1:
122                     self.update(data[1])
123             if self.is_update == 0:
124                 for data in [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]:
125                     if self.satisfy_kkt(data) != 1:
126                         self.update(data[1])
127             if self.stop:
128                 break
129         print self.w, self.b
130 
131     def draw_p(self):  # 作图
132         min_x = min(min(self.X[:,0]),min(self.X[:,1])) - 0.1
133         max_x = max(max(self.X[:,0]), max(self.X[:,1])) +0.1
134         w = -self.w[0][0]/self.w[0][1]
135         b = -self.b/self.w[0][1]
136         r = 1/self.w[0][1]
137         x_line = (min_x, max_x)
138         plt.plot(x_line, [w*x+b for x in x_line], "b")
139         plt.plot(x_line, [w*x+b+r for x in x_line], "b--")
140         plt.plot(x_line, [w*x+b-r for x in x_line], "r--")
141         [plt.plot(self.X[i, 0], self.X[i, 1], "ob" if self.y[i] == 1 else "or") for i in range(0, self.N)]
142         plt.show()
143 
if __name__ == "__main__":
    # Demo: two Gaussian blobs centred at (-2, -2) and (+2, +2), labelled -1 / +1.
    svc = SVC()
    np.random.seed(0)
    X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
    Y = [-1] * 20 + [1] * 20
    svc.train(X, Y)
    svc.draw_p()

 

支持向量机

标签:

原文地址:http://www.cnblogs.com/qw12/p/5724519.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!