标签:
代码:
import numpy as np
import csv
import math as mt


def hypo(tt, xx):
    """Logistic (sigmoid) hypothesis: h(x) = e^z / (1 + e^z), z = theta . x.

    Only the first four components (the Iris features) enter the dot
    product; index 4 of a sample row holds the class label.
    """
    z = 0.0
    for i in range(4):
        z += tt[i] * xx[i]
    # math.exp(z) replaces the original mt.pow(mt.e, z): same value,
    # clearer intent.
    ez = mt.exp(z)
    return ez / (1.0 + ez)


def GDA(tt, sweep):
    """One stochastic-gradient-ascent sweep over all training samples.

    Updates parameter vector ``tt`` in place; ``sweep`` is the sweep
    counter fed to the learning-rate schedule ``iternum``.
    """
    for i in range(1, num + 1):
        h = hypo(tt, x[i])
        for j in range(4):
            tt[j] += iternum(sweep, i) * (y[i] - h) * x[i][j]


def likeli(tt):
    """Log-likelihood of the training set under parameters ``tt``."""
    total = 0.0
    for i in range(1, num + 1):
        h = hypo(tt, x[i])
        total += y[i] * mt.log(h) + (1 - y[i]) * mt.log(1 - h)
    return total


def iternum(k, j):
    """Learning-rate schedule: decays with sweep ``k`` and sample index ``j``."""
    return 0.1 / (k + j + 1) + 0.1


# (x, y): training samples; (dx, dy): test samples.
# Rows are filled starting at index 1 (the original's 1-based convention).
x = np.zeros((105, 5), float)
y = np.zeros(105, float)
ty = np.zeros(5, float)  # parameter vector theta
dx = np.zeros((105, 5), float)
dy = np.zeros(105, float)
num = 0   # number of training rows loaded
dnum = 0  # number of test rows loaded


def _load(path, xs, ys):
    """Read CSV rows from ``path`` into 1-indexed ``xs``/``ys``; return row count."""
    n = 0
    # ``with`` guarantees the file is closed (the original leaked the
    # handle on error); Python 3 csv files are opened in text mode with
    # newline="" per the csv module docs.
    with open(path, newline="") as f:
        for row in csv.reader(f):
            n += 1
            xs[n] = row
            ys[n] = xs[n][4]  # column 4 is the class label
    return n


def main():
    """Train logistic regression on train.csv, then classify test.csv."""
    global num, dnum
    num = _load("train.csv", x, y)
    dnum = _load("test.csv", dx, dy)

    for i in range(1, num + 1):
        print(x[i], y[i])
    print(" ----- ")
    for i in range(1, dnum + 1):
        print(dx[i], dy[i])

    # Iterate full sweeps until the log-likelihood improvement drops
    # below 0.01.  (``sweep`` replaces the original variable ``iter``,
    # which shadowed the builtin.)
    sweep = 0
    prev = 99999.0
    cur = likeli(ty)
    while mt.fabs(cur - prev) > 0.01:
        print(sweep, likeli(ty), ty)
        prev = cur
        GDA(ty, sweep)
        sweep += 1
        cur = likeli(ty)

    print(" ")

    # Classify test samples with threshold 0.5.
    part = 0.5
    for i in range(1, dnum + 1):
        p = hypo(ty, dx[i])
        ans = 0 if p < part else 1
        print(dy[i], p, ans)


if __name__ == "__main__":
    main()
运行结果:
1 //训练样本过程: 2 //(id,似然函数值,[parameter]) 3 (0, -69.314718055994589, array([ 0., 0., 0., 0., 0.])) 4 (1, -179.19591664172614, array([ 0.47418735, 0.06857472, 0.6080708 , 0.21795486, 0. ])) 5 (2, -121.62179146512122, array([ 0.33221938, -0.24682766, 0.97199997, 0.37909118, 0. ])) 6 (3, -80.005102980506564, array([ 0.22184855, -0.51616823, 1.29138288, 0.51894067, 0. ])) 7 (4, -52.113434710918014, array([ 0.13470815, -0.74770815, 1.57528938, 0.64244598, 0. ])) 8 (5, -31.590025749624512, array([ 0.04771585, -0.95553279, 1.81895741, 0.74985357, 0. ])) 9 (6, -17.452182808491344, array([-0.04598164, -1.14217819, 2.02004087, 0.84120529, 0. ])) 10 (7, -10.267976018117292, array([-0.1248766 , -1.29754958, 2.1927993 , 0.92091932, 0. ])) 11 (8, -6.6580675453339095, array([-0.18789329, -1.42677758, 2.34511257, 0.9918262 , 0. ])) 12 (9, -4.6188250760341996, array([-0.24093233, -1.5375812 , 2.47958712, 1.05499614, 0. ])) 13 (10, -3.3744821853240419, array([-0.2866801 , -1.63418248, 2.59854375, 1.11126847, 0. ])) 14 (11, -2.5747996461385223, array([-0.32646178, -1.71915875, 2.7042384 , 1.16151128, 0. ])) 15 (12, -2.0385459675596458, array([-0.36122587, -1.79441754, 2.79869276, 1.20656554, 0. ])) 16 (13, -1.6650980025135245, array([-0.39176125, -1.86149956, 2.88364276, 1.24718908, 0. ])) 17 (14, -1.3960367006166241, array([-0.41873723, -1.92167844, 2.96054485, 1.28403573, 0. ])) 18 (15, -1.1961863968105035, array([-0.442717 , -1.97600919, 3.03060539, 1.31765663, 0. ])) 19 (16, -1.0436583937366917, array([-0.46416987, -2.02536393, 3.09481781, 1.34851106, 0. ])) 20 (17, -0.92441126768741677, array([-0.48348429, -2.07046263, 3.15399946, 1.37698019, 0. ])) 21 (18, -0.82917964441403547, array([-0.50098067, -2.11190002, 3.20882445, 1.40338032, 0. ])) 22 (19, -0.75168763375440906, array([-0.51692311, -2.15016851, 3.25985117, 1.42797471, 0. ])) 23 (20, -0.68758023105857857, array([-0.53152965, -2.18567731, 3.3075448 , 1.45098335, 0. 
])) 24 (21, -0.633767252199817, array([-0.5449808 , -2.21876813, 3.35229525, 1.47259108, 0. ])) 25 (22, -0.5880105475158135, array([-0.55742672, -2.2497278 , 3.39443149, 1.4929541 , 0. ])) 26 (23, -0.54865823718178242, array([-0.56899284, -2.27879848, 3.43423286, 1.51220514, 0. ])) 27 (24, -0.51446987092368923, array([-0.57978456, -2.30618579, 3.47193808, 1.53045774, 0. ])) 28 (25, -0.48449905631347406, array([-0.58989091, -2.33206535, 3.5077524 , 1.54780952, 0. ])) 29 (26, -0.45801316245868812, array([-0.59938755, -2.35658801, 3.5418533 , 1.56434493, 0. ])) 30 (27, -0.43443740873146985, array([-0.60833908, -2.37988406, 3.57439518, 1.58013741, 0. ])) 31 (28, -0.41331528565404385, array([-0.61680104, -2.40206665, 3.60551298, 1.59525118, 0. ])) 32 (29, -0.39428010255982965, array([-0.62482134, -2.42323451, 3.63532529, 1.60974263, 0. ])) 33 (30, -0.37703423805211778, array([-0.6324416 , -2.44347427, 3.6639368 , 1.62366156, 0. ])) 34 (31, -0.36133380365882289, array([-0.63969811, -2.46286221, 3.69144031, 1.6370521 , 0. ])) 35 (32, -0.34697716567487236, array([-0.64662266, -2.48146587, 3.71791846, 1.64995354, 0. ])) 36 (33, -0.33379625351555026, array([-0.65324325, -2.49934526, 3.74344514, 1.662401 , 0. ])) 37 (34, -0.32164990574178071, array([-0.65958462, -2.5165539 , 3.76808662, 1.67442599, 0. ])) 38 (35, -0.31041872365028822, array([-0.66566872, -2.53313973, 3.79190257, 1.68605687, 0. ])) 39 (36, -0.30000105253087223, array([-0.67151512, -2.54914582, 3.8149469 , 1.69731925, 0. ])) 40 41 42 //预测新数据结果: 43 //(原类别,hypothesis值,分类结果) 44 (0.0, 0.004116287123555463, 0) 45 (0.0, 0.004491299234282269, 0) 46 (0.0, 0.001997774439620067, 0) 47 (0.0, 9.711727014021101e-05, 0) 48 (1.0, 0.9986958360885878, 1) 49 (1.0, 0.999907833813241, 1) 50 (1.0, 0.998089176390621, 1) 51 (1.0, 0.9999771709114254, 1) 52 (1.0, 0.9998452542803238, 1) 53 54 Process finished with exit code 0
训练样本及数据下载:
original:http://archive.ics.uci.edu/ml/datasets/Iris
csv文件:http://files.cnblogs.com/files/pdev/stochastic_GDA.zip
标签:
原文地址:http://www.cnblogs.com/pdev/p/4575354.html