
[Exercise] Softmax Regression

Date: 2015-06-20 00:19:19


Softmax regression solves K-class classification problems; it is a direct generalization of logistic regression.
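For reference, the model can be written out as follows. This is the standard softmax formulation rather than something quoted from the post; the notation is an assumption chosen to mirror the code (θ for `tt`, K for `CLS`, m for the number of training samples).

```latex
P(y = k \mid x; \theta) = \frac{\exp(\theta_k^{\top} x)}{\sum_{j=1}^{K} \exp(\theta_j^{\top} x)}, \qquad k = 1, \dots, K

\ell(\theta) = \sum_{i=1}^{m} \log P\bigl(y^{(i)} \mid x^{(i)}; \theta\bigr)

\frac{\partial \ell}{\partial \theta_k} = \sum_{i=1}^{m} \Bigl(1\{y^{(i)} = k\} - P\bigl(y^{(i)} = k \mid x^{(i)}; \theta\bigr)\Bigr)\, x^{(i)}
```

The first line is the hypothesis, the second the log-likelihood that training maximizes, and the third the gradient used for the update of each row θ_k.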

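Notes:

1. The sigmoid function g(x) works best when its input lies roughly in [-1, 1] (outside that range it saturates), so the sample data is normalized first. In this example every feature value is simply divided by 10.

2. The parameter θ is no longer a one-dimensional vector but a two-dimensional matrix: tt[1..CLS][0..LEN-1] (CLS is the number of classes, LEN is the dimension of each sample). The dot product tt[i]*xx[j] is the score of sample j for class i; the probability that sample j belongs to class i is exp(tt[i]*xx[j]) divided by the sum of exp(tt[k]*xx[j]) over all classes k. (From this it is easy to see that logistic regression is just a special case of softmax regression.)

3. Because θ is more complex, the gradient update has to be adjusted accordingly: every row tt[k] gets its own update.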
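The first two notes can be checked numerically with a small sketch. The function names `softmax_probs` and `sigmoid` and the values in `theta` are hypothetical and only for illustration; `raw` is the unscaled form of the first training sample shown in the run output.

```python
import numpy as np

def softmax_probs(theta, x):
    """Class probabilities for one sample x; theta has one row per class."""
    scores = theta @ x                  # one linear score per class
    scores = scores - scores.max()      # shift for stability; probabilities are unchanged
    e = np.exp(scores)
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

raw = np.array([5.1, 3.5, 1.4, 0.2])
x = raw / 10.0                          # note 1: divide every feature by 10

# Hypothetical parameter matrix with CLS=2 rows and LEN=4 columns
theta = np.array([[0.5, 0.1, 0.7, 0.3],
                  [-0.5, 0.0, -0.7, -0.3]])

p = softmax_probs(theta, x)

# Note 2: with two classes, softmax is logistic regression
# applied to the difference of the two parameter rows.
p_logistic = sigmoid((theta[0] - theta[1]) @ x)
print(p, p_logistic)
```

The first entry of `p` and `p_logistic` agree up to floating-point error, which is the sense in which logistic regression is the two-class special case.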

 

Code:

import csv
import math as mt

import numpy as np

def epow(x):        # e**x
    return mt.exp(x)

def vmul(aa,bb,x):        # dot product over indices [0..x-1]
    tmp=0.0
    for i in range(0,x):
        tmp+=aa[i]*bb[i]
    return tmp

def iternum(k,j):        # step size; decays towards 0.1 as pass k and sample index j grow
    return 0.1/(k+j+1)+0.1

def hypo(x,q):        # returns h(x)[q] = P(y=q | x; tt)
    ty=0.0
    for j in range(1,CLS+1):        # classes [1..CLS]
        ty+=epow(vmul(tt[j],x,LEN))
    return epow(vmul(tt[q],x,LEN))/ty

def likeli(t):        # log-likelihood of the training set under parameters t
    tmp=0.0
    for i in range(1,num+1):        # samples [1..num]
        l=int(yy[i])                # true class of sample i
        ty=0.0
        for j in range(1,CLS+1):
            ty+=epow(vmul(t[j],xx[i],LEN))        # LEN: length of the vector xx[i]
        tmp+=mt.log(epow(vmul(t[l],xx[i],LEN))/ty)
    return tmp

def GDA(epoch):        # one pass of per-sample gradient ascent on the log-likelihood
    for j in range(1,num+1):        # samples [1..num]
        for k in range(1,CLS+1):
            tm=0.0
            if(yy[j]==k):
                tm=1.0
            tm-=hypo(xx[j],k)        # 1{y=k} - P(y=k | x)
            for i in range(0,LEN):
                tt[k][i]+=iternum(epoch,j)*tm*xx[j][i]

LEN=4                                 # features per sample, indices [0..LEN-1]
CLS=2                                 # divide the data into 2 classes
xx=np.zeros((105,LEN+1),float)        # (xx,yy): training data
yy=np.zeros(105,float)                # rows [1..num]; columns [0..LEN-1] features, [LEN] label
tt=np.zeros((CLS+1,LEN+1),float)      # tt[1..CLS][0..LEN-1]; row 0 and column LEN stay unused
dx=np.zeros((105,LEN+1),float)        # (dx,dy): test data
dy=np.zeros(105,float)                # same layout as xx / yy

num=0
with open("train.csv",newline="") as trainfile:
    for line in csv.reader(trainfile):
        num+=1
        xx[num]=[float(v) for v in line]
        for i in range(0,LEN):
            xx[num][i]=xx[num][i]/10        # normalize: divide each feature by 10
        yy[num]=xx[num][LEN]

dnum=0
with open("test.csv",newline="") as testfile:
    for line in csv.reader(testfile):
        dnum+=1
        dx[dnum]=[float(v) for v in line]
        for i in range(0,LEN):
            dx[dnum][i]=dx[dnum][i]/10      # same scaling as the training data (row dnum, not num)
        dy[dnum]=dx[dnum][LEN]

for i in range(1,num+1):
    print(xx[i],yy[i])
print(" ----- ")
for i in range(1,dnum+1):
    print(dx[i],dy[i])

# train until the log-likelihood changes by at most 0.001 between passes
epoch=0
lx=99999.0
ly=likeli(tt)
while(mt.fabs(ly-lx)>0.001):
    print(epoch,ly,tt)
    lx=ly
    GDA(epoch)
    epoch+=1
    ly=likeli(tt)

# print the predicted probability of every class for each test sample
for i in range(1,dnum+1):
    print("DATA ",i,dy[i])
    for j in range(1,CLS+1):
        tmp=hypo(dx[i],j)
        print(j,"  ",tmp)
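As a point of comparison, the same model can be trained with NumPy matrix operations instead of explicit loops. The sketch below is not the post's method: it uses batch gradient ascent with a fixed step size, whereas the listing above updates after every sample with a decaying step. The names `softmax_rows`, `log_likelihood` and `fit`, the 0-based labels, and the `lr`, `tol` and `max_iter` defaults are all assumptions made for illustration.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def log_likelihood(theta, X, y):
    """Sum over samples of log P(y_i | x_i; theta), with 0-based labels y."""
    P = softmax_rows(X @ theta.T)
    return float(np.log(P[np.arange(len(y)), y]).sum())

def fit(X, y, n_classes, lr=0.1, tol=1e-3, max_iter=10000):
    """Batch gradient ascent; stops when the log-likelihood gain drops to tol."""
    theta = np.zeros((n_classes, X.shape[1]))
    Y = np.eye(n_classes)[y]                  # one-hot labels, shape (m, n_classes)
    prev = log_likelihood(theta, X, y)
    for _ in range(max_iter):
        P = softmax_rows(X @ theta.T)         # shape (m, n_classes)
        theta += lr * (Y - P).T @ X           # gradient of the log-likelihood
        cur = log_likelihood(theta, X, y)
        if abs(cur - prev) <= tol:
            break
        prev = cur
    return theta

# Four already-scaled training rows copied from the run output;
# labels are shifted to 0-based here (class 1 -> 0, class 2 -> 1).
X = np.array([[0.51, 0.35, 0.14, 0.02],
              [0.49, 0.30, 0.14, 0.02],
              [0.70, 0.32, 0.47, 0.14],
              [0.64, 0.32, 0.45, 0.15]])
y = np.array([1, 1, 0, 0])

theta = fit(X, y, n_classes=2)
print(theta)
print(softmax_rows(X @ theta.T))
```

Starting from zeros, the two rows of `theta` come out as exact negatives of each other, because the per-class gradients sum to zero for every sample. The run output shows nearly the same pattern for `tt[1]` and `tt[2]`; there it is only approximate because the per-sample loop updates `tt[1]` before computing the update for `tt[2]`.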

 

Output:

 

  1 D:\Anaconda\python.exe D:/pycharmprojects/softmax/1.py
  2 // training samples below
  3 (array([ 0.51,  0.35,  0.14,  0.02,  2.  ]), 2.0)
  4 (array([ 0.49,  0.3 ,  0.14,  0.02,  2.  ]), 2.0)
  5 (array([ 0.47,  0.32,  0.13,  0.02,  2.  ]), 2.0)
  6 (array([ 0.46,  0.31,  0.15,  0.02,  2.  ]), 2.0)
  7 (array([ 0.5 ,  0.36,  0.14,  0.02,  2.  ]), 2.0)
  8 (array([ 0.54,  0.39,  0.17,  0.04,  2.  ]), 2.0)
  9 (array([ 0.46,  0.34,  0.14,  0.03,  2.  ]), 2.0)
 10 (array([ 0.5 ,  0.34,  0.15,  0.02,  2.  ]), 2.0)
 11 (array([ 0.44,  0.29,  0.14,  0.02,  2.  ]), 2.0)
 12 (array([ 0.49,  0.31,  0.15,  0.01,  2.  ]), 2.0)
 13 (array([ 0.54,  0.37,  0.15,  0.02,  2.  ]), 2.0)
 14 (array([ 0.48,  0.34,  0.16,  0.02,  2.  ]), 2.0)
 15 (array([ 0.48,  0.3 ,  0.14,  0.01,  2.  ]), 2.0)
 16 (array([ 0.43,  0.3 ,  0.11,  0.01,  2.  ]), 2.0)
 17 (array([ 0.58,  0.4 ,  0.12,  0.02,  2.  ]), 2.0)
 18 (array([ 0.57,  0.44,  0.15,  0.04,  2.  ]), 2.0)
 19 (array([ 0.54,  0.39,  0.13,  0.04,  2.  ]), 2.0)
 20 (array([ 0.51,  0.35,  0.14,  0.03,  2.  ]), 2.0)
 21 (array([ 0.57,  0.38,  0.17,  0.03,  2.  ]), 2.0)
 22 (array([ 0.51,  0.38,  0.15,  0.03,  2.  ]), 2.0)
 23 (array([ 0.54,  0.34,  0.17,  0.02,  2.  ]), 2.0)
 24 (array([ 0.51,  0.37,  0.15,  0.04,  2.  ]), 2.0)
 25 (array([ 0.46,  0.36,  0.1 ,  0.02,  2.  ]), 2.0)
 26 (array([ 0.51,  0.33,  0.17,  0.05,  2.  ]), 2.0)
 27 (array([ 0.48,  0.34,  0.19,  0.02,  2.  ]), 2.0)
 28 (array([ 0.5 ,  0.3 ,  0.16,  0.02,  2.  ]), 2.0)
 29 (array([ 0.5 ,  0.34,  0.16,  0.04,  2.  ]), 2.0)
 30 (array([ 0.52,  0.35,  0.15,  0.02,  2.  ]), 2.0)
 31 (array([ 0.52,  0.34,  0.14,  0.02,  2.  ]), 2.0)
 32 (array([ 0.47,  0.32,  0.16,  0.02,  2.  ]), 2.0)
 33 (array([ 0.48,  0.31,  0.16,  0.02,  2.  ]), 2.0)
 34 (array([ 0.54,  0.34,  0.15,  0.04,  2.  ]), 2.0)
 35 (array([ 0.52,  0.41,  0.15,  0.01,  2.  ]), 2.0)
 36 (array([ 0.55,  0.42,  0.14,  0.02,  2.  ]), 2.0)
 37 (array([ 0.49,  0.31,  0.15,  0.01,  2.  ]), 2.0)
 38 (array([ 0.5 ,  0.32,  0.12,  0.02,  2.  ]), 2.0)
 39 (array([ 0.55,  0.35,  0.13,  0.02,  2.  ]), 2.0)
 40 (array([ 0.49,  0.31,  0.15,  0.01,  2.  ]), 2.0)
 41 (array([ 0.44,  0.3 ,  0.13,  0.02,  2.  ]), 2.0)
 42 (array([ 0.51,  0.34,  0.15,  0.02,  2.  ]), 2.0)
 43 (array([ 0.5 ,  0.35,  0.13,  0.03,  2.  ]), 2.0)
 44 (array([ 0.45,  0.23,  0.13,  0.03,  2.  ]), 2.0)
 45 (array([ 0.44,  0.32,  0.13,  0.02,  2.  ]), 2.0)
 46 (array([ 0.5 ,  0.35,  0.16,  0.06,  2.  ]), 2.0)
 47 (array([ 0.51,  0.38,  0.19,  0.04,  2.  ]), 2.0)
 48 (array([ 0.48,  0.3 ,  0.14,  0.03,  2.  ]), 2.0)
 49 (array([ 0.51,  0.38,  0.16,  0.02,  2.  ]), 2.0)
 50 (array([ 0.46,  0.32,  0.14,  0.02,  2.  ]), 2.0)
 51 (array([ 0.53,  0.37,  0.15,  0.02,  2.  ]), 2.0)
 52 (array([ 0.5 ,  0.33,  0.14,  0.02,  2.  ]), 2.0)
 53 (array([ 0.7 ,  0.32,  0.47,  0.14,  1.  ]), 1.0)
 54 (array([ 0.64,  0.32,  0.45,  0.15,  1.  ]), 1.0)
 55 (array([ 0.69,  0.31,  0.49,  0.15,  1.  ]), 1.0)
 56 (array([ 0.55,  0.23,  0.4 ,  0.13,  1.  ]), 1.0)
 57 (array([ 0.65,  0.28,  0.46,  0.15,  1.  ]), 1.0)
 58 (array([ 0.57,  0.28,  0.45,  0.13,  1.  ]), 1.0)
 59 (array([ 0.63,  0.33,  0.47,  0.16,  1.  ]), 1.0)
 60 (array([ 0.49,  0.24,  0.33,  0.1 ,  1.  ]), 1.0)
 61 (array([ 0.66,  0.29,  0.46,  0.13,  1.  ]), 1.0)
 62 (array([ 0.52,  0.27,  0.39,  0.14,  1.  ]), 1.0)
 63 (array([ 0.5 ,  0.2 ,  0.35,  0.1 ,  1.  ]), 1.0)
 64 (array([ 0.59,  0.3 ,  0.42,  0.15,  1.  ]), 1.0)
 65 (array([ 0.6 ,  0.22,  0.4 ,  0.1 ,  1.  ]), 1.0)
 66 (array([ 0.61,  0.29,  0.47,  0.14,  1.  ]), 1.0)
 67 (array([ 0.56,  0.29,  0.36,  0.13,  1.  ]), 1.0)
 68 (array([ 0.67,  0.31,  0.44,  0.14,  1.  ]), 1.0)
 69 (array([ 0.56,  0.3 ,  0.45,  0.15,  1.  ]), 1.0)
 70 (array([ 0.58,  0.27,  0.41,  0.1 ,  1.  ]), 1.0)
 71 (array([ 0.62,  0.22,  0.45,  0.15,  1.  ]), 1.0)
 72 (array([ 0.56,  0.25,  0.39,  0.11,  1.  ]), 1.0)
 73 (array([ 0.59,  0.32,  0.48,  0.18,  1.  ]), 1.0)
 74 (array([ 0.61,  0.28,  0.4 ,  0.13,  1.  ]), 1.0)
 75 (array([ 0.63,  0.25,  0.49,  0.15,  1.  ]), 1.0)
 76 (array([ 0.61,  0.28,  0.47,  0.12,  1.  ]), 1.0)
 77 (array([ 0.64,  0.29,  0.43,  0.13,  1.  ]), 1.0)
 78 (array([ 0.66,  0.3 ,  0.44,  0.14,  1.  ]), 1.0)
 79 (array([ 0.68,  0.28,  0.48,  0.14,  1.  ]), 1.0)
 80 (array([ 0.67,  0.3 ,  0.5 ,  0.17,  1.  ]), 1.0)
 81 (array([ 0.6 ,  0.29,  0.45,  0.15,  1.  ]), 1.0)
 82 (array([ 0.57,  0.26,  0.35,  0.1 ,  1.  ]), 1.0)
 83 (array([ 0.55,  0.24,  0.38,  0.11,  1.  ]), 1.0)
 84 (array([ 0.55,  0.24,  0.37,  0.1 ,  1.  ]), 1.0)
 85 (array([ 0.58,  0.27,  0.39,  0.12,  1.  ]), 1.0)
 86 (array([ 0.6 ,  0.27,  0.51,  0.16,  1.  ]), 1.0)
 87 (array([ 0.54,  0.3 ,  0.45,  0.15,  1.  ]), 1.0)
 88 (array([ 0.6 ,  0.34,  0.45,  0.16,  1.  ]), 1.0)
 89 (array([ 0.67,  0.31,  0.47,  0.15,  1.  ]), 1.0)
 90 (array([ 0.63,  0.23,  0.44,  0.13,  1.  ]), 1.0)
 91 (array([ 0.56,  0.3 ,  0.41,  0.13,  1.  ]), 1.0)
 92 (array([ 0.55,  0.25,  0.4 ,  0.13,  1.  ]), 1.0)
 93 (array([ 0.55,  0.26,  0.44,  0.12,  1.  ]), 1.0)
 94 (array([ 0.61,  0.3 ,  0.46,  0.14,  1.  ]), 1.0)
 95 (array([ 0.58,  0.26,  0.4 ,  0.12,  1.  ]), 1.0)
 96 (array([ 0.5 ,  0.23,  0.33,  0.1 ,  1.  ]), 1.0)
 97 (array([ 0.56,  0.27,  0.42,  0.13,  1.  ]), 1.0)
 98 (array([ 0.57,  0.3 ,  0.42,  0.12,  1.  ]), 1.0)
 99 (array([ 0.57,  0.29,  0.42,  0.13,  1.  ]), 1.0)
100 (array([ 0.62,  0.29,  0.43,  0.13,  1.  ]), 1.0)
101 (array([ 0.51,  0.25,  0.3 ,  0.11,  1.  ]), 1.0)
102 (array([ 0.57,  0.28,  0.41,  0.13,  1.  ]), 1.0)
103  ----- // test samples below
104 (array([ 4.8,  3.4,  1.6,  0.2,  2. ]), 2.0)
105 (array([ 4.8,  3. ,  1.4,  0.1,  2. ]), 2.0)
106 (array([ 4.3,  3. ,  1.1,  0.1,  2. ]), 2.0)
107 (array([ 5.8,  4. ,  1.2,  0.2,  2. ]), 2.0)
108 (array([ 5.6,  2.5,  3.9,  1.1,  1. ]), 1.0)
109 (array([ 5.9,  3.2,  4.8,  1.8,  1. ]), 1.0)
110 (array([ 6.1,  2.8,  4. ,  1.3,  1. ]), 1.0)
111 (array([ 6.3,  2.5,  4.9,  1.5,  1. ]), 1.0)
112 (array([ 6.1,  2.8,  4.7,  1.2,  1. ]), 1.0)
113 // format: (iteration count, log-likelihood value, parameters tt)
114 (0, -2.772588722239781, array([[ 0.,  0.,  0.,  0.,  0.],
115        [ 0.,  0.,  0.,  0.,  0.],
116        [ 0.,  0.,  0.,  0.,  0.]]))
117 (1, -4.514315589179653, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
118        [ 0.5145982 ,  0.05042281,  0.73611002,  0.26706586,  0.        ],
119        [-0.50370467, -0.04718208, -0.72477409, -0.26317434,  0.        ]]))
120 (2, -4.227215487909576, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
121        [ 0.45433865, -0.23090121,  1.18513854,  0.45716201,  0.        ],
122        [-0.43905776,  0.23361395, -1.16573056, -0.45025988,  0.        ]]))
123 (3, -3.7663141540304017, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
124        [ 0.34235568, -0.52470169,  1.5748469 ,  0.62663847,  0.        ],
125        [-0.32372301,  0.52655101, -1.54839518, -0.6170723 ,  0.        ]]))
126 (4, -3.351323062397454, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
127        [ 0.23350681, -0.79912339,  1.9323152 ,  0.782718  ,  0.        ],
128        [-0.21195882,  0.80012599, -1.89955448, -0.77075573,  0.        ]]))
129 (5, -2.9952661316895712, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
130        [ 0.13281698, -1.05286174,  2.26314675,  0.92733284,  0.        ],
131        [-0.1086993 ,  1.05307502, -2.22474157, -0.91322043,  0.        ]]))
132 (6, -2.6911645875438817, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
133        [ 0.04008866, -1.28766414,  2.57038587,  1.06173426,  0.        ],
134        [-0.0137102 ,  1.28714809, -2.52694564, -1.04569818,  0.        ]]))
135 (7, -2.4310033411415093, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
136        [-0.04540335, -1.50544051,  2.85652835,  1.18698672,  0.        ],
137        [ 0.07376713,  1.50425253, -2.80860351, -1.16923217,  0.        ]]))
138 (8, -2.2076937297751518, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
139        [-0.12438972, -1.70794838,  3.12376352,  1.30403084,  0.        ],
140        [ 0.15449676,  1.7061424 , -3.07184422, -1.28474108,  0.        ]]))
141 (9, -2.0152221051169965, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
142        [-0.19754228, -1.89676316,  3.37402437,  1.41369945,  0.        ],
143        [ 0.22918248,  1.89438939, -3.31854265, -1.39303654,  0.        ]]))
144 (10, -1.8485615066553978, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
145        [-0.26546623, -2.07328532,  3.60901584,  1.5167286 ,  0.        ],
146        [ 0.29845873,  2.07039023, -3.55034999, -1.49483488,  0.        ]]))
147 (11, -1.7035420403682222, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
148        [-0.32870183, -2.23875443,  3.83024003,  1.61376783,  0.        ],
149        [ 0.36289166,  2.23538082, -3.76872006, -1.59076787,  0.        ]]))
150 (12, -1.5767187419565711, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
151        [-0.38772876, -2.39426506,  4.03901997,  1.70538999,  0.        ],
152        [ 0.42298345,  2.39045216, -3.9749333 , -1.68139268,  0.        ]]))
153 (13, -1.465249925219289, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
154        [-0.44297145, -2.54078266,  4.23652149,  1.79210051,  0.        ],
155        [ 0.47917776,  2.53656633, -4.17011849, -1.76720095,  0.        ]]))
156 (14, -1.366790841468854, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
157        [-0.49480477, -2.67915874,  4.4237731 ,  1.87434569,  0.        ],
158        [ 0.53186576,  2.67457169, -4.35527213, -1.84862711,  0.        ]]))
159 (15, -1.279403600912318, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
160        [-0.54355954, -2.81014484,  4.60168369,  1.95252024,  0.        ],
161        [ 0.581392  ,  2.80521684, -4.53127561, -1.92605557,  0.        ]]))
162 (16, -1.201482384930097, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
163        [-0.58952775, -2.93440516,  4.77105806,  2.0269739 ,  0.        ],
164        [ 0.62806003,  2.92916325, -4.69891011, -1.99982726,  0.        ]]))
165 (17, -1.1316921449530692, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
166        [-0.63296735, -3.05252766,  4.93261044,  2.09801724,  0.        ],
167        [ 0.67213751,  3.04699645, -4.85886965, -2.07024512,  0.        ]]))
168 (18, -1.0689187527177835, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
169        [-0.67410646, -3.16503394,  5.08697632,  2.1659267 ,  0.        ],
170        [ 0.71386074,  3.15923578, -5.01177236, -2.13757905,  0.        ]]))
171 (19, -1.0122286447337796, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
172        [-0.71314722, -3.27238775,  5.23472253,  2.23094896,  0.        ],
173        [ 0.75343874,  3.26634296, -5.15817015, -2.2020701 ,  0.        ]]))
174 (20, -0.9608362196301394, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
175        [-0.75026908, -3.37500243,  5.37635605,  2.29330469,  0.        ],
176        [ 0.79105679,  3.36872949, -5.29855715, -2.26393407,  0.        ]]))
177 (21, -0.9140775062980213, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
178        [-0.78563169, -3.4732473 ,  5.51233148,  2.35319186,  0.        ],
179        [ 0.82687947,  3.46676303, -5.43337686, -2.3233647 ,  0.        ]]))
180 (22, -0.8713888752894932, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
181        [-0.81937742, -3.56745317,  5.64305758,  2.41078849,  0.        ],
182        [ 0.86105337,  3.56077293, -5.56302845, -2.38053633,  0.        ]]))
183 (23, -0.8322897941423284, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
184        [-0.85163352, -3.65791719,  5.76890282,  2.46625507,  0.        ],
185        [ 0.89370933,  3.65105496, -5.68787203, -2.43560628,  0.        ]]))
186 (24, -0.7963688219230861, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
187        [-0.88251403, -3.74490689,  5.89020025,  2.51973669,  0.        ],
188        [ 0.92496448,  3.73787547, -5.80823335, -2.48871686,  0.        ]]))
189 (25, -0.7632721993305688, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
190        [-0.91212138, -3.82866377,  6.00725161,  2.57136484,  0.        ],
191        [ 0.9549239 ,  3.82147486, -5.9244078 , -2.53999708,  0.        ]]))
192 (26, -0.7326945214643822, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
193        [-0.9405478 , -3.90940638,  6.12033097,  2.62125893,  0.        ],
194        [ 0.98368212,  3.90207072, -6.03666385, -2.58956423,  0.        ]]))
195 (27, -0.7043710852590455, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
196        [-0.96787656, -3.98733297,  6.22968786,  2.66952771,  0.        ],
197        [ 1.01132441,  3.97986039, -6.14524613, -2.63752515,  0.        ]]))
198 (28, -0.6780715870959158, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
199        [-0.99418297, -4.06262381,  6.33554998,  2.71627042,  0.        ],
200        [ 1.03792783,  4.05502337, -6.25037798, -2.6839774 ,  0.        ]]))
201 (29, -0.653594912296152, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
202        [-1.01953536, -4.13544319,  6.43812558,  2.76157782,  0.        ],
203        [ 1.06356223,  4.12772321, -6.35226379, -2.72901025,  0.        ]]))
204 (30, -0.6307648105411697, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
205        [-1.04399582, -4.20594117,  6.53760551,  2.80553311,  0.        ],
206        [ 1.08829103,  4.19810933, -6.45109101, -2.77270558,  0.        ]]))
207 (31, -0.6094262926302728, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
208        [-1.06762093, -4.27425512,  6.63416507,  2.84821269,  0.        ],
209        [ 1.112172  ,  4.2663185 , -6.54703188, -2.81513859,  0.        ]]))
210 (32, -0.5894426166862736, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
211        [-1.09046233, -4.34051104,  6.72796557,  2.88968689,  0.        ],
212        [ 1.13525782,  4.3324762 , -6.64024498, -2.85637857,  0.        ]]))
213 (33, -0.5706927578062163, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
214        [-1.11256726, -4.40482474,  6.81915574,  2.93002053,  0.        ],
215        [ 1.15759668,  4.39669774, -6.73087659, -2.89648939,  0.        ]]))
216 (34, -0.5530692756852702, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
217        [-1.13397905, -4.46730288,  6.90787295,  2.96927353,  0.        ],
218        [ 1.17923271,  4.45908935, -6.81906192, -2.93553009,  0.        ]]))
219 (35, -0.5364765110641555, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
220        [-1.15473749, -4.52804386,  6.99424436,  3.0075013 ,  0.        ],
221        [ 1.20020643,  4.51974901, -6.90492611, -2.97355532,  0.        ]]))
222 (36, -0.5208290548620741, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
223        [-1.17487917, -4.58713862,  7.07838779,  3.04475523,  0.        ],
224        [ 1.22055513,  4.57876731, -6.98858522, -3.01061578,  0.        ]]))
225 (37, -0.5060504442603247, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
226        [-1.19443786, -4.6446714 ,  7.16041268,  3.08108302,  0.        ],
227        [ 1.24031316,  4.63622814, -7.07014706, -3.04675854,  0.        ]]))
228 (38, -0.49207204834551277, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
229        [-1.21344474, -4.70072028,  7.24042077,  3.11652906,  0.        ],
230        [ 1.25951222,  4.69220928, -7.14971191, -3.08202739,  0.        ]]))
231 (39, -0.4788321126350646, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
232        [-1.23192864, -4.75535784,  7.31850685,  3.15113466,  0.        ],
233        [ 1.27818165,  4.74678302, -7.22737322, -3.11646314,  0.        ]]))
234 (40, -0.4662749372279032, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
235        [-1.2499163 , -4.80865158,  7.39475932,  3.18493839,  0.        ],
236        [ 1.2963486 ,  4.80001661, -7.30321818, -3.15010386,  0.        ]]))
237 (41, -0.4543501677137507, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
238        [-1.26743253, -4.86066441,  7.46926076,  3.21797624,  0.        ],
239        [ 1.3140383 ,  4.85197271, -7.37732826, -3.18298513,  0.        ]]))
240 (42, -0.4430121815430045, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
241        [-1.28450038, -4.91145502,  7.54208842,  3.25028192,  0.        ],
242        [ 1.33127417,  4.90270981, -7.44977968, -3.21514024,  0.        ]]))
243 (43, -0.43221955546938995, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
244        [-1.30114133, -4.96107829,  7.61331464,  3.28188698,  0.        ],
245        [ 1.34807801,  4.95228256, -7.52064386, -3.24660038,  0.        ]]))
246 (44, -0.42193460205884403, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
247        [-1.3173754 , -5.00958552,  7.68300728,  3.31282101,  0.        ],
248        [ 1.36447014,  5.00074211, -7.5899878 , -3.27739481,  0.        ]]))
249 (45, -0.4121229652129941, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
250        [-1.33322127, -5.05702483,  7.75123004,  3.3431118 ,  0.        ],
251        [ 1.38046952,  5.04813637, -7.65787443, -3.307551  ,  0.        ]]))
252 (46, -0.4027532662657923, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
253        [-1.34869642, -5.10344132,  7.81804283,  3.37278549,  0.        ],
254        [ 1.3960939 ,  5.0945103 , -7.72436291, -3.33709479,  0.        ]]))
255 (47, -0.3937967935422819, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
256        [-1.3638172 , -5.14887738,  7.88350201,  3.40186665,  0.        ],
257        [ 1.41135985,  5.13990615, -7.78950894, -3.36605052,  0.        ]]))
258 (48, -0.3852272293713247, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
259        [-1.37859895, -5.19337285,  7.94766071,  3.43037848,  0.        ],
260        [ 1.42628293,  5.18436361, -7.85336503, -3.39444111,  0.        ]]))
261 (49, -0.3770204094610595, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
262        [-1.39305604, -5.23696526,  8.01056902,  3.45834282,  0.        ],
263        [ 1.44087771,  5.22792008, -7.9159807 , -3.4222882 ,  0.        ]]))
264 (50, -0.36915411031061895, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
265        [-1.407202  , -5.27968996,  8.07227425,  3.48578031,  0.        ],
266        [ 1.45515791,  5.2706108 , -7.97740273, -3.44961221,  0.        ]]))
267 (51, -0.3616078609712865, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
268        [-1.42104953, -5.32158031,  8.13282113,  3.51271047,  0.        ],
269        [ 1.4691364 ,  5.31246903, -8.03767533, -3.47643246,  0.        ]]))
270 (52, -0.3543627760069359, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
271        [-1.43461061, -5.36266782,  8.19225194,  3.53915176,  0.        ],
272        [ 1.48282531,  5.35352616, -8.09684037, -3.50276723,  0.        ]]))
273 (53, -0.34740140695499805, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
274        [-1.44789651, -5.40298228,  8.25060675,  3.56512167,  0.        ],
275        [ 1.49623608,  5.39381189, -8.15493746, -3.52863384,  0.        ]]))
276 (54, -0.3407076099700812, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
277        [-1.46091788, -5.44255187,  8.30792354,  3.59063675,  0.        ],
278        [ 1.50937948,  5.4333543 , -8.21200419, -3.55404869,  0.        ]]))
279 (55, -0.3342664276544936, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
280        [-1.47368478, -5.48140328,  8.36423834,  3.61571275,  0.        ],
281        [ 1.5222657 ,  5.47218003, -8.2680762 , -3.57902736,  0.        ]]))
282 (56, -0.3280639833531262, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
283        [-1.48620672, -5.51956183,  8.41958534,  3.64036458,  0.        ],
284        [ 1.53490438,  5.51031429, -8.32318736, -3.60358465,  0.        ]]))
285 (57, -0.32208738642249457, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
286        [-1.49849269, -5.55705153,  8.47399705,  3.66460644,  0.        ],
287        [ 1.54730461,  5.54778102, -8.37736985, -3.62773461,  0.        ]]))
288 (58, -0.31632464718171643, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
289        [-1.5105512 , -5.59389518,  8.52750439,  3.68845182,  0.        ],
290        [ 1.55947503,  5.58460298, -8.43065427, -3.65149062,  0.        ]]))
291 (59, -0.31076460042247034, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
292        [-1.52239035, -5.63011446,  8.58013679,  3.71191357,  0.        ],
293        [ 1.5714238 ,  5.62080176, -8.48306977, -3.67486542,  0.        ]]))
294 (60, -0.3053968364998493, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
295        [-1.53401779, -5.66572999,  8.63192227,  3.73500393,  0.        ],
296        [ 1.5831587 ,  5.65639792, -8.5346441 , -3.69787113,  0.        ]]))
297 (61, -0.3002116391505212, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
298        [-1.5454408 , -5.70076139,  8.68288754,  3.75773457,  0.        ],
299        [ 1.59468707,  5.69141103, -8.58540374, -3.72051932,  0.        ]]))
300 (62, -0.295199929291632, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
301        [-1.5566663 , -5.73522737,  8.73305811,  3.78011661,  0.        ],
302        [ 1.60601593,  5.72585974, -8.63537393, -3.74282102,  0.        ]]))
303 (63, -0.2903532141462281, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
304        [-1.56770086, -5.76914573,  8.7824583 ,  3.80216067,  0.        ],
305        [ 1.61715193,  5.75976182, -8.68457878, -3.76478678,  0.        ]]))
306 (64, -0.2856635411207813, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
307        [-1.57855076, -5.8025335 ,  8.83111134,  3.82387692,  0.        ],
308        [ 1.62810141,  5.79313423, -8.73304132, -3.78642665,  0.        ]]))
309 (65, -0.2811234559294959, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
310        [-1.58922195, -5.8354069 ,  8.87903946,  3.84527505,  0.        ],
311        [ 1.63887039,  5.82599316, -8.78078356, -3.80775026,  0.        ]]))
312 (66, -0.27672596452004006, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
313        [-1.59972013, -5.86778144,  8.9262639 ,  3.86636434,  0.        ],
314        [ 1.64946463,  5.85835408, -8.82782658, -3.82876682,  0.        ]]))
315 (67, -0.2724644984075098, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
316        [-1.6100507 , -5.89967195,  8.972805  ,  3.88715368,  0.        ],
317        [ 1.65988962,  5.89023177, -8.87419051, -3.84948514,  0.        ]]))
318 (68, -0.2683328830688281, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
319        [-1.62021887, -5.93109262,  9.01868222,  3.90765158,  0.        ],
320        [ 1.67015059,  5.92164038, -8.91989468, -3.86991366,  0.        ]]))
321 (69, -0.26432530908949137, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
322        [-1.63022956, -5.96205701,  9.06391425,  3.92786618,  0.        ],
323        [ 1.68025254,  5.95259345, -8.96495759, -3.89006047,  0.        ]]))
324 (70, -0.2604363057891946, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
325        [-1.6400875 , -5.99257812,  9.10851896,  3.94780531,  0.        ],
326        [ 1.69020025,  5.98310394, -9.00939698, -3.90993334,  0.        ]]))
327 (71, -0.25666071708328275, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
328        [-1.64979722, -6.02266842,  9.15251352,  3.96747647,  0.        ],
329        [ 1.69999829,  6.01318428, -9.05322986, -3.92953969,  0.        ]]))
330 (72, -0.2529936793636331, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
331        [-1.65936305, -6.05233984,  9.1959144 ,  3.98688685,  0.        ],
332        [ 1.70965104,  6.04284639, -9.09647258, -3.94888668,  0.        ]]))
333 (73, -0.249430601206015, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
334        [-1.66878912, -6.08160386,  9.23873743,  4.00604338,  0.        ],
335        [ 1.71916269,  6.0721017 , -9.13914083, -3.96798118,  0.        ]]))
336 (74, -0.2459671447315847, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
337        [-1.6780794 , -6.11047146,  9.2809978 ,  4.02495269,  0.        ],
338        [ 1.72853724,  6.10096119, -9.1812497 , -3.98682977,  0.        ]]))
339 (75, -0.24259920846844113, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
340        [-1.6872377 , -6.13895322,  9.32271014,  4.04362117,  0.        ],
341        [ 1.73777854,  6.12943538, -9.22281367, -4.0054388 ,  0.        ]]))
342 (76, -0.23932291157518165, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
343        [-1.69626767, -6.16705929,  9.3638885 ,  4.06205498,  0.        ],
344        [ 1.74689028,  6.15753442, -9.2638467 , -4.02381437,  0.        ]]))
345 (77, -0.2361345793026871, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
346        [-1.70517282, -6.19479942,  9.4045464 ,  4.08026001,  0.        ],
347        [ 1.755876  ,  6.18526803, -9.30436221, -4.04196235,  0.        ]]))
348 (78, -0.2330307295829455, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
349        [-1.7139565 , -6.222183  ,  9.44469686,  4.09824197,  0.        ],
350        [ 1.7647391 ,  6.21264559, -9.34437312, -4.0598884 ,  0.        ]]))
351 (79, -0.2300080606449215, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
352        [-1.72262195, -6.24921907,  9.48435244,  4.11600634,  0.        ],
       [ 1.77348285,  6.2396761 , -9.38389187, -4.07759796,  0.        ]]))
(80, -0.22706343956745145, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.73117227, -6.27591632,  9.5235252 ,  4.13355841,  0.        ],
       [ 1.78211038,  6.26636824, -9.42293046, -4.09509628,  0.        ]]))
(81, -0.2241938916879589, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.73961045, -6.30228314,  9.56222681,  4.15090327,  0.        ],
       [ 1.7906247 ,  6.29273038, -9.46150046, -4.11238842,  0.        ]]))
(82, -0.2213965907937252, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.74793936, -6.32832759,  9.6004685 ,  4.16804583,  0.        ],
       [ 1.79902874,  6.31877056, -9.49961302, -4.12947926,  0.        ]]))
(83, -0.21866885002945297, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.75616177, -6.35405745,  9.63826112,  4.18499085,  0.        ],
       [ 1.80732526,  6.34449656, -9.53727889, -4.14637351,  0.        ]]))
(84, -0.2160081134611847, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.76428034, -6.37948024,  9.67561512,  4.20174288,  0.        ],
       [ 1.81551698,  6.36991586, -9.57450848, -4.16307572,  0.        ]]))
(85, -0.21341194824226611, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.77229764, -6.40460319,  9.71254062,  4.21830637,  0.        ],
       [ 1.82360648,  6.3950357 , -9.61131181, -4.17959029,  0.        ]]))
(86, -0.21087803733207908, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.78021614, -6.4294333 ,  9.74904737,  4.23468557,  0.        ],
       [ 1.83159625,  6.41986303, -9.64769858, -4.19592145,  0.        ]]))
(87, -0.20840417272282616, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.78803823, -6.45397732,  9.78514481,  4.25088462,  0.        ],
       [ 1.83948871,  6.44440462, -9.68367815, -4.21207331,  0.        ]]))
(88, -0.20598824913369043, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.79576621, -6.47824176,  9.82084207,  4.26690751,  0.        ],
       [ 1.84728619,  6.46866695, -9.71925958, -4.22804983,  0.        ]]))
(89, -0.20362825813536678, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.80340229, -6.50223294,  9.85614795,  4.28275809,  0.        ],
       [ 1.85499092,  6.49265633, -9.75445163, -4.24385485,  0.        ]]))
(90, -0.20132228267125782, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.81094862, -6.52595694,  9.89107099,  4.29844011,  0.        ],
       [ 1.86260507,  6.51637883, -9.78926278, -4.25949208,  0.        ]]))
(91, -0.19906849194458476, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.81840728, -6.54941965,  9.92561945,  4.31395717,  0.        ],
       [ 1.87013072,  6.53984033, -9.82370122, -4.27496511,  0.        ]]))
(92, -0.19686513664336477, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.82578026, -6.57262679,  9.95980133,  4.32931278,  0.        ],
       [ 1.87756991,  6.56304653, -9.8577749 , -4.29027741,  0.        ]]))
(93, -0.1947105444776009, array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.83306949, -6.59558386,  9.99362437,  4.34451032,  0.        ],
       [ 1.88492458,  6.58600292, -9.89149151, -4.30543237,  0.        ]]))
(94, -0.19260311600524177, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.84027684,  -6.6182962 ,  10.02709607,   4.35955308,   0.        ],
       [  1.89219662,   6.60871485,  -9.9248585 ,  -4.32043324,   0.        ]]))
(95, -0.19054132072544686, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.84740413,  -6.640769  ,  10.06022369,   4.37444425,   0.        ],
       [  1.89938785,   6.63118748,  -9.95788309,  -4.33528319,   0.        ]]))
(96, -0.18852369341947317, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.8544531 ,  -6.66300725,  10.09301429,   4.38918691,   0.        ],
       [  1.90650005,   6.6534258 ,  -9.99057228,  -4.34998528,   0.        ]]))
(97, -0.1865488307211645, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.86142545,  -6.68501583,  10.1254747 ,   4.40378405,   0.        ],
       [  1.91353491,   6.67543467, -10.02293286,  -4.36454248,   0.        ]]))
(98, -0.184615387900454, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.8683228 ,  -6.70679943,  10.15761154,   4.41823857,   0.        ],
       [  1.92049411,   6.69721878, -10.05497142,  -4.3789577 ,   0.        ]]))
(99, -0.18272207584469313, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.87514677,  -6.72836263,  10.18943126,   4.43255329,   0.        ],
       [  1.92737923,   6.7187827 , -10.08669434,  -4.39323371,   0.        ]]))
(100, -0.18086765822379874, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.88189887,  -6.74970984,  10.22094008,   4.44673095,   0.        ],
       [  1.93419184,   6.74013084, -10.11810784,  -4.40737326,   0.        ]]))
(101, -0.17905094882634626, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.8885806 ,  -6.77084538,  10.25214407,   4.4607742 ,   0.        ],
       [  1.94093344,   6.76126749, -10.14921792,  -4.42137896,   0.        ]]))
(102, -0.17727080905475595, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.89519341,  -6.7917734 ,  10.28304912,   4.47468562,   0.        ],
       [  1.94760548,   6.78219683, -10.18003045,  -4.43525339,   0.        ]]))
(103, -0.1755261455686309, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.90173871,  -6.81249796,  10.31366093,   4.48846772,   0.        ],
       [  1.95420939,   6.80292288, -10.2105511 ,  -4.44899904,   0.        ]]))
(104, -0.1738159080661718, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.90821784,  -6.83302299,  10.34398508,   4.50212292,   0.        ],
       [  1.96074654,   6.82344957, -10.24078539,  -4.46261833,   0.        ]]))
(105, -0.1721390871943469, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.91463213,  -6.85335231,  10.37402695,   4.51565359,   0.        ],
       [  1.96721825,   6.84378073, -10.2707387 ,  -4.47611361,   0.        ]]))
(106, -0.17049471257922666, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.92098286,  -6.87348964,  10.40379179,   4.52906204,   0.        ],
       [  1.97362583,   6.86392005, -10.30041622,  -4.48948716,   0.        ]]))
(107, -0.16888185096851546, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.92727128,  -6.89343857,  10.43328472,   4.54235051,   0.        ],
       [  1.97997053,   6.88387114, -10.32982305,  -4.50274122,   0.        ]]))
(108, -0.16729960447893147, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.93349858,  -6.91320262,  10.46251068,   4.55552116,   0.        ],
       [  1.98625356,   6.90363749, -10.35896411,  -4.51587795,   0.        ]]))
(109, -0.1657471089416123, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.93966596,  -6.9327852 ,  10.49147452,   4.56857612,   0.        ],
       [  1.99247611,   6.92322251, -10.38784421,  -4.52889945,   0.        ]]))
(110, -0.16422353233923076, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.94577453,  -6.95218962,  10.52018094,   4.58151744,   0.        ],
       [  1.99863934,   6.94262952, -10.41646802,  -4.54180778,   0.        ]]))
(111, -0.16272807332897177, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.95182542,  -6.97141911,  10.5486345 ,   4.59434714,   0.        ],
       [  2.00474436,   6.96186174, -10.44484009,  -4.55460493,   0.        ]]))
(112, -0.16125995984592273, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.95781971,  -6.99047681,  10.57683967,   4.60706716,   0.        ],
       [  2.01079225,   6.9809223 , -10.47296484,  -4.56729284,   0.        ]]))
(113, -0.15981844778184623, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.96375843,  -7.00936579,  10.60480078,   4.61967941,   0.        ],
       [  2.01678408,   6.99981425, -10.50084659,  -4.5798734 ,   0.        ]]))
(114, -0.15840281973463238, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.96964261,  -7.02808901,  10.63252206,   4.63218575,   0.        ],
       [  2.02272088,   7.01854057, -10.52848955,  -4.59234846,   0.        ]]))
(115, -0.1570123838240855, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.97547324,  -7.04664937,  10.66000762,   4.64458797,   0.        ],
       [  2.02860364,   7.03710415, -10.5558978 ,  -4.60471982,   0.        ]]))
(116, -0.15564647256998015, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.98125128,  -7.06504971,  10.68726149,   4.65688784,   0.        ],
       [  2.03443334,   7.05550782, -10.58307534,  -4.61698922,   0.        ]]))
(117, -0.15430444182860753, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.98697767,  -7.08329276,  10.71428756,   4.66908708,   0.        ],
       [  2.04021092,   7.07375431, -10.61002606,  -4.62915838,   0.        ]]))
(118, -0.1529856697843115, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.99265334,  -7.10138121,  10.74108966,   4.68118735,   0.        ],
       [  2.04593732,   7.09184631, -10.63675374,  -4.64122895,   0.        ]]))
(119, -0.151689555992716, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -1.99827916,  -7.11931768,  10.76767149,   4.6931903 ,   0.        ],
       [  2.05161343,   7.10978642, -10.66326209,  -4.65320257,   0.        ]]))
(120, -0.15041552047260684, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.00385601,  -7.13710471,  10.7940367 ,   4.7050975 ,   0.        ],
       [  2.05724012,   7.12757719, -10.6895547 ,  -4.6650808 ,   0.        ]]))
(121, -0.14916300284359077, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.00938473,  -7.15474478,  10.8201888 ,   4.71691051,   0.        ],
       [  2.06281824,   7.1452211 , -10.71563511,  -4.67686521,   0.        ]]))
(122, -0.1479314615068965, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.01486615,  -7.17224032,  10.84613127,   4.72863085,   0.        ],
       [  2.06834864,   7.16272057, -10.74150674,  -4.68855729,   0.        ]]))
(123, -0.14672037286680367, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.02030106,  -7.1895937 ,  10.87186746,   4.74026   ,   0.        ],
       [  2.07383211,   7.18007796, -10.76717294,  -4.70015852,   0.        ]]))
(124, -0.1455292305903876, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.02569026,  -7.20680721,  10.89740066,   4.7517994 ,   0.        ],
       [  2.07926946,   7.19729557, -10.79263698,  -4.71167033,   0.        ]]))
(125, -0.14435754490340363, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.0310345 ,  -7.22388311,  10.92273409,   4.76325045,   0.        ],
       [  2.08466144,   7.21437565, -10.81790206,  -4.72309413,   0.        ]]))
(126, -0.14320484192026453, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.03633452,  -7.24082359,  10.94787087,   4.77461454,   0.        ],
       [  2.09000882,   7.23132039, -10.8429713 ,  -4.73443128,   0.        ]]))
(127, -0.14207066300621723, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.04159107,  -7.25763081,  10.97281408,   4.78589301,   0.        ],
       [  2.09531231,   7.24813194, -10.86784773,  -4.74568313,   0.        ]]))
(128, -0.1409545641699272, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.04680483,  -7.27430685,  10.99756669,   4.79708717,   0.        ],
       [  2.10057265,   7.26481239, -10.89253434,  -4.75685097,   0.        ]]))
(129, -0.1398561154847964, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.05197651,  -7.29085376,  11.02213164,   4.80819831,   0.        ],
       [  2.10579052,   7.28136378, -10.91703404,  -4.7679361 ,   0.        ]]))
(130, -0.1387749005374439, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.05710679,  -7.30727354,  11.04651178,   4.81922768,   0.        ],
       [  2.11096661,   7.2977881 , -10.94134965,  -4.77893974,   0.        ]]))
(131, -0.13771051590188077, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.0621963 ,  -7.32356814,  11.0707099 ,   4.8301765 ,   0.        ],
       [  2.11610158,   7.31408732, -10.96548397,  -4.78986313,   0.        ]]))
(132, -0.13666257063799675, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.06724572,  -7.33973946,  11.09472873,   4.84104597,   0.        ],
       [  2.12119608,   7.33026332, -10.98943969,  -4.80070746,   0.        ]]))
(133, -0.13563068581305257, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.07225565,  -7.35578938,  11.11857093,   4.85183728,   0.        ],
       [  2.12625075,   7.34631798, -11.01321949,  -4.8114739 ,   0.        ]]))
(134, -0.13461449404497236, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.07722672,  -7.37171972,  11.14223911,   4.86255155,   0.        ],
       [  2.1312662 ,   7.36225311, -11.03682594,  -4.82216358,   0.        ]]))
(135, -0.13361363906627216, array([[  0.        ,   0.        ,   0.        ,   0.        ,   0.        ],
       [ -2.08215952,  -7.38753225,  11.16573583,   4.87318991,   0.        ],
       [  2.13624305,   7.37807049, -11.0602616 ,  -4.83277761,   0.        ]]))
//Classification results for the samples follow:
//(DATA, sample index, true class)
//(i, probability of belonging to class i)
//......
(DATA , 1, 2.0)
(1,   , 4.549426925361446e-15)
(2,   , 0.9999999999999954)
(DATA , 2, 2.0)
(1,   , 7.437806286527796e-15)
(2,   , 0.9999999999999926)
(DATA , 3, 2.0)
(1,   , 7.722102243838512e-17)
(2,   , 0.9999999999999999)
(DATA , 4, 2.0)
(1,   , 1.249562190027437e-24)
(2,   , 1.0)
(DATA , 5, 1.0)
(1,   , 1.0)
(2,   , 9.503727473427473e-17)
(DATA , 6, 1.0)
(1,   , 1.0)
(2,   , 2.3147075539380976e-23)
(DATA , 7, 1.0)
(1,   , 0.999999999999999)
(2,   , 1.027428387144679e-15)
(DATA , 8, 1.0)
(1,   , 1.0)
(2,   , 7.955629841599608e-27)
(DATA , 9, 1.0)
(1,   , 1.0)
(2,   , 4.604559079838839e-22)

Process finished with exit code 0
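
The per-sample lines at the end of the output give, for each class i, the probability that the sample belongs to that class. Below is a minimal Python 3 sketch of how such probabilities can be computed from a weight matrix. It is not the original `hypo` function: it subtracts the largest score before exponentiating, a standard trick that leaves the result unchanged but avoids the overflow that `mt.pow(mt.e, x)` can hit for large scores. The name `softmax_probs` and the sample `x` are assumptions made for illustration; the weights are copied from iteration 135 of the log, with the all-zero first row and last column (which never change in the log) dropped.

```python
import math

def softmax_probs(theta, x):
    """Return P(y = k | x) for every class k, given one weight row per class."""
    scores = [sum(w * xi for w, xi in zip(row, x)) for row in theta]
    # Subtracting the largest score leaves the ratios unchanged but keeps
    # math.exp from overflowing when a score is large.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Weights from iteration 135 of the log (zero row and zero column dropped).
theta = [
    [-2.08215952, -7.38753225, 11.16573583, 4.87318991],   # class 1
    [2.13624305, 7.37807049, -11.0602616, -4.83277761],    # class 2
]
# Hypothetical sample for illustration; it is not taken from the original test file.
x = [0.51, 0.35, 0.14, 0.02]

probs = softmax_probs(theta, x)
for k, p in enumerate(probs, start=1):
    print(k, p)
```

The probabilities always sum to 1, so for two classes it is enough to look at either one of them.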

 

 

The test used the first two classes of the iris dataset.
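
As a rough Python 3 sketch of that data preparation: keep only the samples of the first two classes and divide every feature by 10, as described in note 1 at the top of the post. The row layout (four measurements followed by a numeric label), the function name `prepare`, and the handful of rows shown here are assumptions for illustration; the actual layout of `train.csv` and `test.csv` is not shown in the post.

```python
def prepare(rows, keep_labels=(1, 2), scale=10.0):
    """Keep rows whose label is in keep_labels and divide each feature by scale."""
    xs, ys = [], []
    for *features, label in rows:
        if int(label) in keep_labels:
            xs.append([f / scale for f in features])
            ys.append(int(label))
    return xs, ys

# Illustrative rows in the assumed layout: four measurements, then a numeric label.
rows = [
    (5.1, 3.5, 1.4, 0.2, 1),
    (4.9, 3.0, 1.4, 0.2, 1),
    (7.0, 3.2, 4.7, 1.4, 2),
    (6.3, 3.3, 6.0, 2.5, 3),  # third class, filtered out
]

xs, ys = prepare(rows)
print(len(xs), ys)
```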

 

 

Reference:

http://www.cnblogs.com/neopenx/p/4316611.html

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92

Original article: http://www.cnblogs.com/pdev/p/4589952.html
