
Mixtures of Gaussians and the EM algorithms

Date: 2018-01-18 23:06:51


Acknowledgement to Stanford CS229.

Generative modeling is itself a kind of unsupervised learning task[1]. Given unlabelled data, 

[equation image not preserved]
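The equation that belonged here was an image and is not available. Assuming the standard mixture-of-Gaussians setup from the CS229 notes the post acknowledges, and borrowing the index convention of the code in this post (n samples indexed by i, m components indexed by j), the model would read roughly as follows. The symbols \phi_j, \mu_j and \sigma_j are assumptions chosen to match the code's priors, mus and sigmas.

```latex
\[
z^{(i)} \sim \mathrm{Multinomial}(\phi), \qquad
x^{(i)} \mid z^{(i)} = j \sim \mathcal{N}(\mu_j, \sigma_j^2), \qquad
i = 1, \dots, n, \quad j = 1, \dots, m,
\]
\[
\phi_j \ge 0, \qquad \sum_{j=1}^{m} \phi_j = 1 .
\]
```

Here z^{(i)} is the unobserved label saying which Gaussian produced sample x^{(i)}.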

To estimate the parameters, we can write the likelihood as 

[equation image not preserved]

which is also

[equation image not preserved]
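The two likelihood expressions were images and are not available. As a hedged reconstruction in the usual CS229 form, with n samples x^{(i)}, m components, latent labels z^{(i)}, mixing priors \phi_j, means \mu_j and standard deviations \sigma_j (all symbol names are assumptions), the log-likelihood and its expansion over the latent label would be:

```latex
\[
\ell(\phi, \mu, \sigma)
  = \sum_{i=1}^{n} \log p\bigl(x^{(i)}; \phi, \mu, \sigma\bigr)
  = \sum_{i=1}^{n} \log \sum_{j=1}^{m}
      p\bigl(x^{(i)} \mid z^{(i)} = j; \mu, \sigma\bigr)\,
      p\bigl(z^{(i)} = j; \phi\bigr)
  = \sum_{i=1}^{n} \log \sum_{j=1}^{m}
      \phi_j \, \mathcal{N}\bigl(x^{(i)}; \mu_j, \sigma_j^2\bigr).
\]
```

The sum inside the logarithm is what prevents a closed-form maximizer and motivates an iterative method.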

The EM algorithm can solve this pdf estimation iteratively.

[five equation images not preserved]
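The update equations were images and are not available. A sketch of the textbook EM updates for a one-dimensional Gaussian mixture, in the notation assumed from the code in this post (w is the m by n array of posterior class probabilities, \phi the priors, \mu the means, \sigma the standard deviations), would be:

```latex
% E-step: responsibility of component j for sample i
\[
w_j^{(i)} = p\bigl(z^{(i)} = j \mid x^{(i)}; \phi, \mu, \sigma\bigr)
  = \frac{\phi_j \, \mathcal{N}\bigl(x^{(i)}; \mu_j, \sigma_j^2\bigr)}
         {\sum_{l=1}^{m} \phi_l \, \mathcal{N}\bigl(x^{(i)}; \mu_l, \sigma_l^2\bigr)}
\]
% M-step: weighted maximum-likelihood estimates
\[
\phi_j = \frac{1}{n} \sum_{i=1}^{n} w_j^{(i)}, \qquad
\mu_j = \frac{\sum_{i} w_j^{(i)} x^{(i)}}{\sum_{i} w_j^{(i)}}, \qquad
\sigma_j^2 = \frac{\sum_{i} w_j^{(i)} \bigl(x^{(i)} - \mu_j\bigr)^2}{\sum_{i} w_j^{(i)}}
\]
```

Note that the code in this post updates the priors by counting, for each sample, the component with the highest density, which is not the same as the \phi_j update written here.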

An example is provided here. The data points are drawn from two Gaussian distributions, N(0, 1) and N(2, 1), with 50 samples each.

import operator

import numpy as np

np.random.seed(0)
# 100 unlabelled samples: 50 from N(0, 1) followed by 50 from N(2, 1)
x0 = np.concatenate((np.random.normal(0, 1, 50),
                     np.random.normal(2, 1, 50)), axis=0)

# initial guesses for the component means and standard deviations
mus0 = np.array([0, 1])
sigmas0 = np.array([2, 2])


def gauss(x, mu, sigma):
    """
    Univariate Gaussian density.

    :param x: point at which to evaluate the density
    :param mu: mean
    :param sigma: standard deviation
    :return: pdf(x)
    """
    numerator = np.exp(-0.5 * ((x - mu) / sigma) ** 2)
    denominator = np.sqrt(2 * np.pi * sigma ** 2)
    return numerator / denominator


def e_step(mus=mus0, sigmas=sigmas0, x=x0,
           priors=np.ones(len(mus0)) / len(mus0)):
    """
    Compute the posterior probability of each class for each sample.

    :param mus: gaussian centers, an array of shape (m,)
    :param sigmas: gaussian standard deviations, an array of shape (m,)
    :param x: n samples with no labels
    :param priors: class priors, an array of shape (m,)
    :return: m by n array, where m is # classes
    """
    assert len(mus) == len(sigmas), "mus and sigmas don't have the same length"
    m = len(mus)
    n = len(x)
    w = np.zeros(shape=(m, n))
    for j in range(m):
        for i in range(n):
            w[j][i] = gauss(x=x[i], mu=mus[j], sigma=sigmas[j]) * priors[j]
    w_sum_wrt_j = np.sum(w, axis=0)  # note j is the row index
    for j in range(m):
        w[j, :] = w[j, :] / w_sum_wrt_j
    return w


def m_step(w, current_mus, x=x0):
    """
    Re-estimate the parameters from the posterior class probabilities.

    :param w: m by n array, where m is # classes
    :param current_mus: means from the previous iteration, shape (m,)
    :param x: n samples with no labels
    :return: mus: gaussian centers, an array of shape (m,)
             sigmas: gaussian standard deviations, an array of shape (m,)
             priors: class priors, an array of shape (m,)
    """
    m, n = w.shape
    mus = np.zeros(shape=(m))
    sigmas = np.zeros(shape=(m))
    for j in range(m):
        mus[j] = np.dot(w[j, :], x)
    mus /= np.sum(w, axis=1)
    # note: the spread is measured around the previous means (current_mus),
    # not around the means just computed above
    for j in range(m):
        sigmas[j] = np.sqrt(np.dot(w[j, :], (x - current_mus[j]) ** 2))
    sigmas /= np.sqrt(np.sum(w, axis=1))

    # note: each prior is the fraction of samples whose density is highest
    # under that component (a hard count), not the mean of w over the samples
    priors = np.zeros(shape=(len(mus)))
    for i in range(n):
        tmp = list(map(gauss, [x[i]] * m, mus, sigmas))
        tmpmaxindex, tmpmax = max(enumerate(tmp), key=operator.itemgetter(1))
        priors[tmpmaxindex] += 1 / n
    return mus, sigmas, priors


def solve(x=x0, priors=np.ones(len(mus0)) / len(mus0)):
    mus = mus0
    sigmas = sigmas0
    # 100 iterations, matching the log shown below
    for k in range(100):
        w = e_step(mus=mus, sigmas=sigmas, x=x, priors=priors)
        mus, sigmas, priors = m_step(w, current_mus=mus, x=x)
        print("k={},mus={},sigmas={},priors={}".format(k, mus, sigmas, priors))
    return mus, sigmas, priors


if __name__ == "__main__":
    solve()
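For comparison, here is a minimal sketch of the textbook EM iteration for the same kind of problem. It is not the author's method: it takes each prior as the mean posterior probability of its component and measures the variance around the freshly updated mean, whereas the listing above counts hard assignments and uses the previous means. The names normal_pdf, em_step and fit are hypothetical, and the data come from Python's random module rather than NumPy, so the numbers will not match the log in this post.

```python
import math
import random


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def em_step(xs, mus, sigmas, priors):
    """One EM iteration for a 1-D Gaussian mixture; returns updated parameters."""
    m, n = len(mus), len(xs)
    # E-step: w[j][i] = P(z_i = j | x_i) under the current parameters.
    w = [[priors[j] * normal_pdf(x, mus[j], sigmas[j]) for x in xs]
         for j in range(m)]
    for i in range(n):
        total = sum(w[j][i] for j in range(m))
        for j in range(m):
            w[j][i] /= total
    # M-step: closed-form weighted maximum-likelihood updates.
    new_mus, new_sigmas, new_priors = [], [], []
    for j in range(m):
        nj = sum(w[j])  # effective number of samples in component j
        mu = sum(wi * x for wi, x in zip(w[j], xs)) / nj
        var = sum(wi * (x - mu) ** 2 for wi, x in zip(w[j], xs)) / nj
        new_mus.append(mu)
        new_sigmas.append(math.sqrt(var))
        new_priors.append(nj / n)  # soft count, not a hard argmax count
    return new_mus, new_sigmas, new_priors


def fit(xs, mus, sigmas, priors, iterations=100):
    """Run a fixed number of EM iterations from the given starting point."""
    for _ in range(iterations):
        mus, sigmas, priors = em_step(xs, mus, sigmas, priors)
    return mus, sigmas, priors


# Assumed toy data: 50 draws from N(0, 1) and 50 from N(2, 1).
rng = random.Random(0)
data = [rng.gauss(0, 1) for _ in range(50)] + [rng.gauss(2, 1) for _ in range(50)]
mus, sigmas, priors = fit(data, [0.0, 1.0], [2.0, 2.0], [0.5, 0.5])
print(mus, sigmas, priors)
```

With these updates the priors always sum to one and the prior-weighted average of the means equals the sample mean, which makes for a quick sanity check on any implementation.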

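After 100 iterations the estimates settle at means of about -0.016 and 1.874, standard deviations of about 1.09 and 0.90, and priors of 0.43 and 0.57, a rough approximation of the real model (means 0 and 2, unit standard deviations, equal priors).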

/usr/local/bin/python3.5 /home/csdl/review/fulcrum/gmm/gmm.py
k=0,mus=[ 0.81734343  1.27122747],sigmas=[ 1.60216343  1.33905931],priors=[ 0.32  0.68]
k=1,mus=[ 0.73393663  1.21263431],sigmas=[ 1.48989073  1.27140643],priors=[ 0.35  0.65]
k=2,mus=[ 0.72025041  1.24207148],sigmas=[ 1.47760392  1.25840835],priors=[ 0.36  0.64]
k=3,mus=[ 0.69405554  1.2656654 ],sigmas=[ 1.47453155  1.2480128 ],priors=[ 0.36  0.64]
k=4,mus=[ 0.65993336  1.28545741],sigmas=[ 1.47454151  1.238417  ],priors=[ 0.36  0.64]
k=5,mus=[ 0.62512005  1.30527053],sigmas=[ 1.4739642   1.22830782],priors=[ 0.36  0.64]
k=6,mus=[ 0.59009448  1.32522573],sigmas=[ 1.47230468  1.21788024],priors=[ 0.36  0.64]
k=7,mus=[ 0.55504913  1.34523959],sigmas=[ 1.46932309  1.20732602],priors=[ 0.36  0.64]
k=8,mus=[ 0.52016003  1.36521637],sigmas=[ 1.46489424  1.19678812],priors=[ 0.36  0.64]
k=9,mus=[ 0.4855794   1.38507002],sigmas=[ 1.45897578  1.18636451],priors=[ 0.36  0.64]
k=10,mus=[ 0.45142496  1.4047313 ],sigmas=[ 1.45158393  1.17611449],priors=[ 0.36  0.64]
k=11,mus=[ 0.41777707  1.42414967],sigmas=[ 1.44277296  1.16606539],priors=[ 0.36  0.64]
k=12,mus=[ 0.38468177  1.44329208],sigmas=[ 1.43261873  1.15621962],priors=[ 0.37  0.63]
k=13,mus=[ 0.36595587  1.46867892],sigmas=[ 1.41990091  1.14409883],priors=[ 0.37  0.63]
k=14,mus=[ 0.33654056  1.48870368],sigmas=[ 1.4082571   1.13343039],priors=[ 0.37  0.63]
k=15,mus=[ 0.30597566  1.50763142],sigmas=[ 1.39543174  1.12335036],priors=[ 0.37  0.63]
k=16,mus=[ 0.27568252  1.52609593],sigmas=[ 1.38137316  1.11360905],priors=[ 0.37  0.63]
k=17,mus=[ 0.24588996  1.54419117],sigmas=[ 1.36625212  1.10407072],priors=[ 0.37  0.63]
k=18,mus=[ 0.21664299  1.56192385],sigmas=[ 1.35022455  1.09464684],priors=[ 0.37  0.63]
k=19,mus=[ 0.18796432  1.57928065],sigmas=[ 1.33342219  1.0852798 ],priors=[ 0.37  0.63]
k=20,mus=[ 0.1598861   1.59623644],sigmas=[ 1.31596643  1.07593506],priors=[ 0.37  0.63]
k=21,mus=[ 0.13245872  1.61275414],sigmas=[ 1.2979812   1.06659735],priors=[ 0.39  0.61]
k=22,mus=[ 0.13549936  1.64420575],sigmas=[ 1.28163006  1.05238461],priors=[ 0.4  0.6]
k=23,mus=[ 0.13362832  1.67212388],sigmas=[ 1.26763647  1.03761078],priors=[ 0.4  0.6]
k=24,mus=[ 0.11750175  1.69125511],sigmas=[ 1.25330002  1.02525138],priors=[ 0.41  0.59]
k=25,mus=[ 0.11224826  1.71494289],sigmas=[ 1.23950504  1.01230038],priors=[ 0.42  0.58]
k=26,mus=[ 0.11153847  1.7395662 ],sigmas=[ 1.22728752  0.99889453],priors=[ 0.42  0.58]
k=27,mus=[ 0.0999276   1.75644918],sigmas=[ 1.21474604  0.98770556],priors=[ 0.43  0.57]
k=28,mus=[ 0.09911993  1.77770261],sigmas=[ 1.20375615  0.97601043],priors=[ 0.43  0.57]
k=29,mus=[ 0.08991339  1.79234269],sigmas=[ 1.19274093  0.96620904],priors=[ 0.43  0.57]
k=30,mus=[ 0.07854133  1.80401995],sigmas=[ 1.18163507  0.95803992],priors=[ 0.43  0.57]
k=31,mus=[ 0.06708472  1.81391145],sigmas=[ 1.1708709  0.9510143],priors=[ 0.43  0.57]
k=32,mus=[ 0.05629168  1.82248392],sigmas=[ 1.16077864  0.94483468],priors=[ 0.43  0.57]
k=33,mus=[ 0.04644144  1.8299628 ],sigmas=[ 1.15153082  0.93934709],priors=[ 0.43  0.57]
k=34,mus=[ 0.03761987  1.83648519],sigmas=[ 1.14319449  0.93447086],priors=[ 0.43  0.57]
k=35,mus=[ 0.02982246  1.84215374],sigmas=[ 1.13577403  0.93015559],priors=[ 0.43  0.57]
k=36,mus=[ 0.02299928  1.84705693],sigmas=[ 1.12923679  0.92636035],priors=[ 0.43  0.57]
k=37,mus=[ 0.01707735  1.85127645],sigmas=[ 1.12352817  0.92304533],priors=[ 0.43  0.57]
k=38,mus=[ 0.01197298  1.85488949],sigmas=[ 1.11858109  0.92016939],priors=[ 0.43  0.57]
k=39,mus=[ 0.00759925  1.85796875],sigmas=[ 1.11432244  0.91769023],priors=[ 0.43  0.57]
k=40,mus=[ 0.00387068  1.86058202],sigmas=[ 1.11067765  0.91556544],priors=[ 0.43  0.57]
k=41,mus=[  7.06082311e-04   1.86279150e+00],sigmas=[ 1.10757393  0.91375369],priors=[ 0.43  0.57]
k=42,mus=[-0.00196965  1.86465346],sigmas=[ 1.10494246  0.91221583],priors=[ 0.43  0.57]
k=43,mus=[-0.00422464  1.86621814],sigmas=[ 1.10271971  0.91091554],priors=[ 0.43  0.57]
k=44,mus=[-0.00611974  1.86752982],sigmas=[ 1.10084819  0.9098198 ],priors=[ 0.43  0.57]
k=45,mus=[-0.00770859  1.86862714],sigmas=[ 1.09927671  0.90889909],priors=[ 0.43  0.57]
k=46,mus=[-0.00903796  1.86954354],sigmas=[ 1.09796019  0.90812732],priors=[ 0.43  0.57]
k=47,mus=[-0.01014832  1.87030773],sigmas=[ 1.09685943  0.90748172],priors=[ 0.43  0.57]
k=48,mus=[-0.01107441  1.87094421],sigmas=[ 1.09594057  0.90694261],priors=[ 0.43  0.57]
k=49,mus=[-0.01184586  1.87147378],sigmas=[ 1.09517461  0.90649307],priors=[ 0.43  0.57]
k=50,mus=[-0.01248783  1.87191401],sigmas=[ 1.09453685  0.90611867],priors=[ 0.43  0.57]
k=51,mus=[-0.01302159  1.87227973],sigmas=[ 1.09400634  0.90580718],priors=[ 0.43  0.57]
k=52,mus=[-0.01346505  1.87258336],sigmas=[ 1.09356541  0.90554823],priors=[ 0.43  0.57]
k=53,mus=[-0.01383328  1.87283531],sigmas=[ 1.09319917  0.90533313],priors=[ 0.43  0.57]
k=54,mus=[-0.01413888  1.87304431],sigmas=[ 1.09289515  0.90515454],priors=[ 0.43  0.57]
k=55,mus=[-0.0143924   1.87321761],sigmas=[ 1.09264288  0.90500635],priors=[ 0.43  0.57]
k=56,mus=[-0.01460264  1.87336127],sigmas=[ 1.09243365  0.90488343],priors=[ 0.43  0.57]
k=57,mus=[-0.01477693  1.87348033],sigmas=[ 1.09226016  0.9047815 ],priors=[ 0.43  0.57]
k=58,mus=[-0.01492139  1.87357899],sigmas=[ 1.09211635  0.90469701],priors=[ 0.43  0.57]
k=59,mus=[-0.0150411   1.87366073],sigmas=[ 1.09199717  0.90462698],priors=[ 0.43  0.57]
k=60,mus=[-0.01514028  1.87372844],sigmas=[ 1.09189842  0.90456896],priors=[ 0.43  0.57]
k=61,mus=[-0.01522245  1.87378452],sigmas=[ 1.09181661  0.90452088],priors=[ 0.43  0.57]
k=62,mus=[-0.01529051  1.87383097],sigmas=[ 1.09174884  0.90448106],priors=[ 0.43  0.57]
k=63,mus=[-0.01534687  1.87386944],sigmas=[ 1.0916927   0.90444807],priors=[ 0.43  0.57]
k=64,mus=[-0.01539356  1.87390129],sigmas=[ 1.09164621  0.90442075],priors=[ 0.43  0.57]
k=65,mus=[-0.01543222  1.87392767],sigmas=[ 1.09160771  0.90439813],priors=[ 0.43  0.57]
k=66,mus=[-0.01546423  1.87394951],sigmas=[ 1.09157583  0.90437939],priors=[ 0.43  0.57]
k=67,mus=[-0.01549074  1.87396759],sigmas=[ 1.09154943  0.90436388],priors=[ 0.43  0.57]
k=68,mus=[-0.01551269  1.87398257],sigmas=[ 1.09152757  0.90435103],priors=[ 0.43  0.57]
k=69,mus=[-0.01553086  1.87399496],sigmas=[ 1.09150947  0.9043404 ],priors=[ 0.43  0.57]
k=70,mus=[-0.0155459   1.87400523],sigmas=[ 1.09149449  0.90433159],priors=[ 0.43  0.57]
k=71,mus=[-0.01555836  1.87401373],sigmas=[ 1.09148208  0.9043243 ],priors=[ 0.43  0.57]
k=72,mus=[-0.01556868  1.87402076],sigmas=[ 1.09147181  0.90431826],priors=[ 0.43  0.57]
k=73,mus=[-0.01557722  1.87402659],sigmas=[ 1.0914633   0.90431327],priors=[ 0.43  0.57]
k=74,mus=[-0.01558428  1.87403141],sigmas=[ 1.09145626  0.90430913],priors=[ 0.43  0.57]
k=75,mus=[-0.01559014  1.8740354 ],sigmas=[ 1.09145043  0.9043057 ],priors=[ 0.43  0.57]
k=76,mus=[-0.01559498  1.87403871],sigmas=[ 1.09144561  0.90430287],priors=[ 0.43  0.57]
k=77,mus=[-0.01559899  1.87404144],sigmas=[ 1.09144161  0.90430052],priors=[ 0.43  0.57]
k=78,mus=[-0.01560232  1.87404371],sigmas=[ 1.0914383   0.90429857],priors=[ 0.43  0.57]
k=79,mus=[-0.01560506  1.87404558],sigmas=[ 1.09143556  0.90429696],priors=[ 0.43  0.57]
k=80,mus=[-0.01560734  1.87404714],sigmas=[ 1.0914333   0.90429563],priors=[ 0.43  0.57]
k=81,mus=[-0.01560923  1.87404842],sigmas=[ 1.09143142  0.90429453],priors=[ 0.43  0.57]
k=82,mus=[-0.01561079  1.87404948],sigmas=[ 1.09142987  0.90429362],priors=[ 0.43  0.57]
k=83,mus=[-0.01561208  1.87405037],sigmas=[ 1.09142858  0.90429286],priors=[ 0.43  0.57]
k=84,mus=[-0.01561315  1.8740511 ],sigmas=[ 1.09142751  0.90429223],priors=[ 0.43  0.57]
k=85,mus=[-0.01561403  1.8740517 ],sigmas=[ 1.09142663  0.90429172],priors=[ 0.43  0.57]
k=86,mus=[-0.01561476  1.8740522 ],sigmas=[ 1.0914259   0.90429129],priors=[ 0.43  0.57]
k=87,mus=[-0.01561537  1.87405261],sigmas=[ 1.0914253   0.90429093],priors=[ 0.43  0.57]
k=88,mus=[-0.01561587  1.87405295],sigmas=[ 1.0914248   0.90429064],priors=[ 0.43  0.57]
k=89,mus=[-0.01561629  1.87405324],sigmas=[ 1.09142438  0.90429039],priors=[ 0.43  0.57]
k=90,mus=[-0.01561663  1.87405347],sigmas=[ 1.09142404  0.90429019],priors=[ 0.43  0.57]
k=91,mus=[-0.01561692  1.87405367],sigmas=[ 1.09142376  0.90429003],priors=[ 0.43  0.57]
k=92,mus=[-0.01561715  1.87405383],sigmas=[ 1.09142352  0.90428989],priors=[ 0.43  0.57]
k=93,mus=[-0.01561735  1.87405396],sigmas=[ 1.09142333  0.90428977],priors=[ 0.43  0.57]
k=94,mus=[-0.01561751  1.87405407],sigmas=[ 1.09142317  0.90428968],priors=[ 0.43  0.57]
k=95,mus=[-0.01561764  1.87405416],sigmas=[ 1.09142303  0.9042896 ],priors=[ 0.43  0.57]
k=96,mus=[-0.01561775  1.87405424],sigmas=[ 1.09142292  0.90428954],priors=[ 0.43  0.57]
k=97,mus=[-0.01561785  1.8740543 ],sigmas=[ 1.09142283  0.90428948],priors=[ 0.43  0.57]
k=98,mus=[-0.01561792  1.87405435],sigmas=[ 1.09142276  0.90428944],priors=[ 0.43  0.57]
k=99,mus=[-0.01561799  1.8740544 ],sigmas=[ 1.09142269  0.9042894 ],priors=[ 0.43  0.57]

Process finished with exit code 0
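The log shows the parameters changing less and less from one iteration to the next. One common way to turn that observation into a stopping rule, sketched here under assumptions, is to track the log-likelihood of the data and stop once its gain drops below a tolerance. The function names and the tolerance value are hypothetical, and the sketch assumes all priors are strictly positive.

```python
import math


def log_likelihood(xs, mus, sigmas, priors):
    """Log-likelihood of 1-D data under a Gaussian mixture, via log-sum-exp."""
    total = 0.0
    for x in xs:
        # log(prior_j * N(x; mu_j, sigma_j^2)) for each component j;
        # math.log(p) requires every prior to be strictly positive
        terms = [
            math.log(p) - math.log(s) - 0.5 * math.log(2.0 * math.pi)
            - 0.5 * ((x - mu) / s) ** 2
            for mu, s, p in zip(mus, sigmas, priors)
        ]
        top = max(terms)
        total += top + math.log(sum(math.exp(t - top) for t in terms))
    return total


def has_converged(previous, current, tol=1e-8):
    """True once the log-likelihood gain falls below tol (an assumed threshold)."""
    return abs(current - previous) < tol


# Single point at the mean of a single standard normal component.
ll = log_likelihood([0.0], [0.0], [1.0], [1.0])
print(ll)
```

Working in log space with the largest term factored out avoids underflow when a sample lies far from every component.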

In addition, a scikit-learn example can be found at http://scikit-learn.org/stable/modules/mixture.html

[1] Ian Goodfellow. https://www.quora.com/Why-could-generative-models-help-with-unsupervised-learning/answer/Ian-Goodfellow?srid=hTUVm


Original post: https://www.cnblogs.com/cxxszz/p/8313163.html
