标签:
# encoding: utf-8
import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; training only needs numpy
    plt = None


def _scalar(value):
    """Return ``value`` (a Python number or a size-1 numpy array) as a float."""
    return float(np.asarray(value).item())


class SVC(object):
    """Linear soft-margin SVM classifier trained with a simple SMO loop.

    Attributes filled in by ``train``:
        X    -- (N, wn) float array of training samples.
        y    -- (N, 1) float array of labels (-1 / +1, as in the demo below).
        K    -- (N, N) kernel (Gram) matrix.
        a    -- (N, 1) array of Lagrange multipliers.
        w, b -- primal weight row vector of shape (1, wn) and the bias.
        E    -- length-N array of errors g(x_i) - y_i.
    """

    def __init__(self, c=1.0, delta=0.001, tol=1e-3):
        """Create an untrained classifier.

        c     -- box constraint C (upper bound of every multiplier).
        delta -- smallest multiplier change that counts as progress.
        tol   -- tolerance for the KKT checks; exact float comparisons
                 practically never hold, so a tolerance is required.
        """
        self.N = 0
        self.delta = delta
        self.tol = tol
        self.X = None
        self.y = None
        self.w = None
        self.wn = 0
        self.K = np.zeros((self.N, self.N))
        self.a = np.zeros((self.N, 1))
        self.b = 0
        self.E = None
        self.C = c
        self.stop = 1
        self.is_update = 0

    def kernel_function(self, x1, x2):
        """Linear kernel: the dot product of two samples."""
        return np.dot(x1, x2)

    def kernel_matrix(self, x):
        """Fill ``self.K`` with the symmetric kernel matrix of the samples ``x``."""
        n = len(x)
        self.K = np.zeros((n, n))
        for i in range(n):
            for j in range(i, n):
                self.K[i][j] = self.K[j][i] = self.kernel_function(x[i], x[j])

    def get_w(self):
        """Return w = sum_i a_i * y_i * x_i as a (1, wn) row vector.

        This primal form is only valid because the kernel is linear.
        """
        return np.dot((self.a * self.y).T, self.X)

    def get_b(self, a1, a2, a1_old, a2_old):
        """Return the bias after multipliers ``a1``/``a2`` were updated.

        a1, a2         -- indices of the two multipliers just optimised.
        a1_old, a2_old -- their values *before* the update (plain numbers or
                          size-1 arrays; they must not alias ``self.a``).
        Uses b1 if a1 is strictly inside (0, C), otherwise b2 if a2 is,
        otherwise the mean of both (Platt's SMO rule).
        """
        y1 = _scalar(self.y[a1])
        y2 = _scalar(self.y[a2])
        a1_new = _scalar(self.a[a1])
        a2_new = _scalar(self.a[a2])
        d1 = y1 * (a1_new - _scalar(a1_old))
        d2 = y2 * (a2_new - _scalar(a2_old))
        b1 = -_scalar(self.E[a1]) - d1 * self.K[a1][a1] - d2 * self.K[a2][a1] + self.b
        b2 = -_scalar(self.E[a2]) - d1 * self.K[a1][a2] - d2 * self.K[a2][a2] + self.b
        if 0 < a1_new < self.C:
            return b1
        if 0 < a2_new < self.C:
            return b2
        return (b1 + b2) / 2.0

    def gx(self, x):
        """Decision function g(x) = w . x + b."""
        return np.dot(self.w, x) + self.b

    def satisfy_kkt(self, a):
        """Return 1 if sample ``a[1]`` with multiplier ``a[0]`` meets KKT, else 0.

        Conditions (each within ``self.tol``):
            a == 0      ->  y * g(x) >= 1
            0 < a < C   ->  y * g(x) == 1
            a == C      ->  y * g(x) <= 1
        """
        alpha = _scalar(a[0])
        index = a[1]
        margin = _scalar(self.y[index]) * _scalar(self.gx(self.X[index]))
        if alpha <= 0:
            return 1 if margin >= 1 - self.tol else 0
        if alpha >= self.C:
            return 1 if margin <= 1 + self.tol else 0
        return 1 if abs(margin - 1) <= self.tol else 0

    def clip_func(self, a_new, a1_old, a2_old, y1, y2):
        """Clip the unconstrained new a2 into its feasible interval [L, H]."""
        if y1 == y2:
            low = max(0, a1_old + a2_old - self.C)
            high = min(self.C, a1_old + a2_old)
        else:
            low = max(0, a2_old - a1_old)
            high = min(self.C, self.C + a2_old - a1_old)
        return min(max(a_new, low), high)

    def update_a(self, a1, a2):
        """Jointly optimise multipliers ``a1`` and ``a2``.

        Returns 1 and stores the new values if the change is at least
        ``self.delta``; returns 0 (leaving ``self.a`` untouched) otherwise,
        including when eta = K11 + K22 - 2*K12 is not positive (e.g. two
        identical samples), where the analytic step is undefined.
        """
        eta = self.K[a1][a1] + self.K[a2][a2] - 2 * self.K[a1][a2]
        if eta <= 1e-9:
            return 0
        y1 = _scalar(self.y[a1])
        y2 = _scalar(self.y[a2])
        alpha1 = _scalar(self.a[a1])
        alpha2 = _scalar(self.a[a2])
        a2_new_unc = alpha2 + y2 * (_scalar(self.E[a1]) - _scalar(self.E[a2])) / eta
        a2_new = self.clip_func(a2_new_unc, alpha1, alpha2, y1, y2)
        a1_new = alpha1 + y1 * y2 * (alpha2 - a2_new)
        if abs(a1_new - alpha1) < self.delta:
            return 0
        self.a[a1] = a1_new
        self.a[a2] = a2_new
        self.is_update = 1
        return 1

    def _refresh_errors(self):
        """Recompute E_i = g(x_i) - y_i for every sample (linear kernel)."""
        self.E = np.dot(self.X, np.ravel(self.w)) + self.b - self.y.ravel()

    def update(self, first_a):
        """Pair multiplier ``first_a`` with every other one and optimise each pair."""
        for second_a in range(self.N):
            if second_a == first_a:
                continue
            # Copy as floats: self.a[i] is a view that update_a overwrites in
            # place, which would make every delta seen by get_b zero.
            a1_old = _scalar(self.a[first_a])
            a2_old = _scalar(self.a[second_a])
            if self.update_a(first_a, second_a) == 0:
                continue  # no progress with this partner; try the next one
            self.b = self.get_b(first_a, second_a, a1_old, a2_old)
            self.w = self.get_w()
            self._refresh_errors()
            self.stop = 0

    def _sweep(self, indices):
        """Optimise every sample in ``indices`` that currently violates KKT."""
        for i in indices:
            # read the multiplier now: earlier updates in this pass may have changed it
            if self.satisfy_kkt([self.a[i, 0], i]) != 1:
                self.update(i)

    def train(self, x, y, max_iternum=100):
        """Fit the classifier with SMO and print the resulting w and b.

        x           -- N samples, each a sequence of wn numbers.
        y           -- N labels (-1 / +1).
        max_iternum -- maximum number of outer passes.
        Raises ValueError if ``x`` is not a non-empty 2-D collection.
        """
        self.X = np.asarray(x, dtype=float)
        if self.X.ndim != 2 or len(self.X) == 0:
            raise ValueError("x must be a non-empty 2-D array of samples")
        self.N = len(self.X)
        self.wn = self.X.shape[1]
        self.y = np.asarray(y, dtype=float).reshape((self.N, 1))
        self.kernel_matrix(self.X)
        self.b = 0
        self.a = np.zeros((self.N, 1))
        self.w = self.get_w()
        self._refresh_errors()
        self.is_update = 0
        for _ in range(max_iternum):
            self.stop = 1
            self.is_update = 0  # reset so the full-sweep fallback can trigger every pass
            # Prefer the non-bound multipliers (0 < a < C); fall back to all samples.
            candidates = [i for i in range(self.N) if 0 < self.a[i, 0] < self.C]
            if not candidates:
                candidates = list(range(self.N))
            self._sweep(candidates)
            if self.is_update == 0:
                self._sweep(range(self.N))
            if self.stop:
                break
        print("%s %s" % (self.w, self.b))

    def draw_p(self):
        """Plot the samples, the decision boundary and the two margin lines.

        Only meaningful for 2-D samples (it reads columns 0 and 1 of X).
        Raises ImportError if matplotlib is not installed.
        """
        if plt is None:
            raise ImportError("matplotlib is required for draw_p()")
        lo = min(self.X[:, 0].min(), self.X[:, 1].min()) - 0.1
        hi = max(self.X[:, 0].max(), self.X[:, 1].max()) + 0.1
        x_line = (lo, hi)
        w0, w1 = self.w[0][0], self.w[0][1]
        # level 0 is the decision boundary, +1 / -1 are the margin lines
        styles = ((0, "b"), (1, "b--"), (-1, "r--"))
        if w1 != 0:
            # solve w0*u + w1*v + b = level for v
            for level, style in styles:
                plt.plot(x_line, [(level - self.b - w0 * u) / w1 for u in x_line], style)
        else:
            # vertical boundary: u = (level - b) / w0
            for level, style in styles:
                u = (level - self.b) / w0
                plt.plot((u, u), x_line, style)
        for xi, yi in zip(self.X, self.y[:, 0]):
            plt.plot(xi[0], xi[1], "ob" if yi == 1 else "or")
        plt.show()


if __name__ == "__main__":
    svc = SVC()
    np.random.seed(0)
    X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
    Y = [-1] * 20 + [1] * 20
    svc.train(X, Y)
    svc.draw_p()
标签:
原文地址:http://www.cnblogs.com/qw12/p/5724519.html