码迷,mamicode.com
首页 > 其他好文 > 详细

[Exercise]随机梯度下降、logistic回归

时间:2015-06-14 18:11:23      阅读:141      评论:0      收藏:0      [点我收藏+]

标签:

代码:

技术分享
 1 import numpy as np
 2 import csv
 3 import math as mt
 4 
 5 def hypo(tt,xx):        #hypothesis函数
 6     exp=mt.e
 7     tmp=0.0
 8     for i in range(0,4):
 9         tmp+=tt[i]*xx[i]
10     ans=mt.pow(exp,tmp)/(1+mt.pow(exp,tmp))
11     return ans
12 
def GDA(tt, iter):
    """One stochastic-gradient-descent sweep over the training set.

    Updates the parameter vector tt in place; the step size comes from
    iternum(iter, sample_index). Reads module-level globals x, y, num.
    """
    for sample in range(1, num + 1):        # samples [1..100]
        predicted = hypo(tt, x[sample])
        residual = y[sample] - predicted
        step = iternum(iter, sample)
        for feat in range(0, 4):            # features [0..3]
            tt[feat] += step * residual * x[sample][feat]
18 
def likeli(tt):
    """Log-likelihood of the training set under parameters tt.

    Sums y*log(h) + (1-y)*log(1-h) over samples x[1..num]; reads the
    module-level globals x, y, num.
    """
    total = 0.0
    for i in range(1, num + 1):
        h = hypo(tt, x[i])
        total += y[i] * mt.log(h)
        total += (1 - y[i]) * mt.log(1 - h)
    return total
25 
def iternum(k, j):
    """Learning-rate schedule: decays with iteration k and sample index j,
    with a 0.1 floor so updates never vanish entirely."""
    return 0.1 / (k + j + 1) + 0.1
29 
# --- Load the CSV data ---------------------------------------------------
# Bug fix: the original called file(train.csv, rb) with unquoted names,
# which raises NameError at runtime; filenames and modes must be string
# literals.  open() replaces the Python-2-only file() builtin.
# NOTE(review): "rb" matches Python 2's csv module; on Python 3 use
# open(..., newline="") in text mode instead.
trainfile = open("train.csv", "rb")
trainread = csv.reader(trainfile)
testfile = open("test.csv", "rb")
testread = csv.reader(testfile)

x = np.zeros((105, 5), float)           # (x, y): training samples (rows 1..num)
y = np.zeros(105, float)
ty = np.zeros(5, float)                 # parameter vector being learned
dx = np.zeros((105, 5), float)          # (dx, dy): test samples (rows 1..dnum)
dy = np.zeros(105, float)

# Rows are stored 1-based: x[1..num]; column 4 is the class label.
num = 0
for line in trainread:
    num += 1
    x[num] = line
    y[num] = x[num][4]
trainfile.close()

dnum = 0
for line in testread:
    dnum += 1
    dx[dnum] = line
    dy[dnum] = dx[dnum][4]
testfile.close()

# Echo the loaded data for a sanity check.
for i in range(1, num + 1):
    print(x[i], y[i])
print(" ----- ")
for i in range(1, dnum + 1):
    print(dx[i], dy[i])

# --- Train: repeat SGD sweeps until the log-likelihood converges ---------
step = 0                                # renamed from `iter` (shadowed builtin)
lx = 99999.0
ly = likeli(ty)
while mt.fabs(ly - lx) > 0.01:          # stop when the improvement is tiny
    print(step, likeli(ty), ty)
    lx = ly
    GDA(ty, step)
    step += 1
    ly = likeli(ty)

print("  ")

# --- Classify the test samples at threshold 0.5 --------------------------
part = 0.5
for i in range(1, dnum + 1):
    tmp = hypo(ty, dx[i])
    if tmp < part:
        ans = 0
    else:
        ans = 1
    print(dy[i], tmp, ans)
View Code

 

运行结果:

技术分享
 1 //训练样本过程:
 2 //(id,似然函数值,[parameter])
 3 (0, -69.314718055994589, array([ 0.,  0.,  0.,  0.,  0.]))
 4 (1, -179.19591664172614, array([ 0.47418735,  0.06857472,  0.6080708 ,  0.21795486,  0.        ]))
 5 (2, -121.62179146512122, array([ 0.33221938, -0.24682766,  0.97199997,  0.37909118,  0.        ]))
 6 (3, -80.005102980506564, array([ 0.22184855, -0.51616823,  1.29138288,  0.51894067,  0.        ]))
 7 (4, -52.113434710918014, array([ 0.13470815, -0.74770815,  1.57528938,  0.64244598,  0.        ]))
 8 (5, -31.590025749624512, array([ 0.04771585, -0.95553279,  1.81895741,  0.74985357,  0.        ]))
 9 (6, -17.452182808491344, array([-0.04598164, -1.14217819,  2.02004087,  0.84120529,  0.        ]))
10 (7, -10.267976018117292, array([-0.1248766 , -1.29754958,  2.1927993 ,  0.92091932,  0.        ]))
11 (8, -6.6580675453339095, array([-0.18789329, -1.42677758,  2.34511257,  0.9918262 ,  0.        ]))
12 (9, -4.6188250760341996, array([-0.24093233, -1.5375812 ,  2.47958712,  1.05499614,  0.        ]))
13 (10, -3.3744821853240419, array([-0.2866801 , -1.63418248,  2.59854375,  1.11126847,  0.        ]))
14 (11, -2.5747996461385223, array([-0.32646178, -1.71915875,  2.7042384 ,  1.16151128,  0.        ]))
15 (12, -2.0385459675596458, array([-0.36122587, -1.79441754,  2.79869276,  1.20656554,  0.        ]))
16 (13, -1.6650980025135245, array([-0.39176125, -1.86149956,  2.88364276,  1.24718908,  0.        ]))
17 (14, -1.3960367006166241, array([-0.41873723, -1.92167844,  2.96054485,  1.28403573,  0.        ]))
18 (15, -1.1961863968105035, array([-0.442717  , -1.97600919,  3.03060539,  1.31765663,  0.        ]))
19 (16, -1.0436583937366917, array([-0.46416987, -2.02536393,  3.09481781,  1.34851106,  0.        ]))
20 (17, -0.92441126768741677, array([-0.48348429, -2.07046263,  3.15399946,  1.37698019,  0.        ]))
21 (18, -0.82917964441403547, array([-0.50098067, -2.11190002,  3.20882445,  1.40338032,  0.        ]))
22 (19, -0.75168763375440906, array([-0.51692311, -2.15016851,  3.25985117,  1.42797471,  0.        ]))
23 (20, -0.68758023105857857, array([-0.53152965, -2.18567731,  3.3075448 ,  1.45098335,  0.        ]))
24 (21, -0.633767252199817, array([-0.5449808 , -2.21876813,  3.35229525,  1.47259108,  0.        ]))
25 (22, -0.5880105475158135, array([-0.55742672, -2.2497278 ,  3.39443149,  1.4929541 ,  0.        ]))
26 (23, -0.54865823718178242, array([-0.56899284, -2.27879848,  3.43423286,  1.51220514,  0.        ]))
27 (24, -0.51446987092368923, array([-0.57978456, -2.30618579,  3.47193808,  1.53045774,  0.        ]))
28 (25, -0.48449905631347406, array([-0.58989091, -2.33206535,  3.5077524 ,  1.54780952,  0.        ]))
29 (26, -0.45801316245868812, array([-0.59938755, -2.35658801,  3.5418533 ,  1.56434493,  0.        ]))
30 (27, -0.43443740873146985, array([-0.60833908, -2.37988406,  3.57439518,  1.58013741,  0.        ]))
31 (28, -0.41331528565404385, array([-0.61680104, -2.40206665,  3.60551298,  1.59525118,  0.        ]))
32 (29, -0.39428010255982965, array([-0.62482134, -2.42323451,  3.63532529,  1.60974263,  0.        ]))
33 (30, -0.37703423805211778, array([-0.6324416 , -2.44347427,  3.6639368 ,  1.62366156,  0.        ]))
34 (31, -0.36133380365882289, array([-0.63969811, -2.46286221,  3.69144031,  1.6370521 ,  0.        ]))
35 (32, -0.34697716567487236, array([-0.64662266, -2.48146587,  3.71791846,  1.64995354,  0.        ]))
36 (33, -0.33379625351555026, array([-0.65324325, -2.49934526,  3.74344514,  1.662401  ,  0.        ]))
37 (34, -0.32164990574178071, array([-0.65958462, -2.5165539 ,  3.76808662,  1.67442599,  0.        ]))
38 (35, -0.31041872365028822, array([-0.66566872, -2.53313973,  3.79190257,  1.68605687,  0.        ]))
39 (36, -0.30000105253087223, array([-0.67151512, -2.54914582,  3.8149469 ,  1.69731925,  0.        ]))
40   
41 
42  //预测新数据结果:
43  //(原类别,hypothesis值,分类结果)
44 (0.0, 0.004116287123555463, 0)
45 (0.0, 0.004491299234282269, 0)
46 (0.0, 0.001997774439620067, 0)
47 (0.0, 9.711727014021101e-05, 0)
48 (1.0, 0.9986958360885878, 1)
49 (1.0, 0.999907833813241, 1)
50 (1.0, 0.998089176390621, 1)
51 (1.0, 0.9999771709114254, 1)
52 (1.0, 0.9998452542803238, 1)
53 
54 Process finished with exit code 0
View Code

 

 

训练样本及数据下载:

original:http://archive.ics.uci.edu/ml/datasets/Iris

csv文件:http://files.cnblogs.com/files/pdev/stochastic_GDA.zip

[Exercise]随机梯度下降、logistic回归

标签:

原文地址:http://www.cnblogs.com/pdev/p/4575354.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!