Acknowledgement to Stanford CS229.
Generative modeling is itself a kind of unsupervised learning task[1]. Given unlabelled data,
To estimate the parameters, we can write the likelihood as
which is also
The EM algorithm can solve this pdf estimation iteratively.
An example is provided here. The data points are drawn from 2 gaussian distributions.
1 import numpy as np 2 import operator 3 np.random.seed(0) 4 x0=np.random.normal(0,1,50) 5 x0=np.concatenate((x0,np.random.normal(2,1,50)), 6 axis=0) 7 8 mus0=np.array([ 9 0,1 10 ]) 11 sigmas0=np.array([ 12 2,2 13 ]) 14 def gauss(x,mu,sigma): 15 """ 16 17 :param x: 18 :param mu: 19 :param sigma: 20 :return: pdf(x) 21 """ 22 # if np.abs((x-mu)/sigma)<1e-5: 23 # return 24 # numerator=np.exp( 25 # -(x-mu)**2/(2*sigma**2) 26 # ) 27 numerator=np.exp( 28 -0.5*((x-mu)/sigma)**2 29 ) 30 denominator=np.sqrt(2*np.pi*sigma**2) 31 return numerator/denominator 32 def e_step(mus=mus0,sigmas=sigmas0,x=x0,priors=np.ones(len(mus0))/len(mus0)): 33 """ 34 35 :param mus: gaussian centers, an array of shape (m,) 36 :param sigmas: gaussian standard deviations, an array of shape (m,) 37 :param x: n samples with no labels 38 :return: m by n array, where m is # classes 39 """ 40 assert len(mus)==len(sigmas),"mus and sigmas doesn‘t have the same length" 41 m=len(mus) 42 n=len(x) 43 w=np.zeros(shape=(m,n)) 44 for j in range(m): 45 for i in range(n): 46 w[j][i]=gauss(x=x[i],mu=mus[j],sigma=sigmas[j])*priors[j] 47 w_sum_wrt_j=np.sum(w,axis=0)#note j is the row index 48 for j in range(m): 49 w[j,:]=w[j,:]/w_sum_wrt_j 50 return w 51 def m_step(w,current_mus,x=x0): 52 """ 53 54 :param w: m by n array, where m is # classes 55 :return: mus: gaussian centers, an array of shape (m,) 56 sigmas: gaussian standard deviations, an array of shape (m,) 57 """ 58 m,n=w.shape 59 mus=np.zeros(shape=(m)) 60 sigmas=np.zeros(shape=(m)) 61 for j in range(m): 62 mus[j]=np.dot( 63 w[j,:],x 64 ) 65 mus/=np.sum(w,axis=1) 66 for j in range(m): 67 sigmas[j]=np.sqrt(np.dot( 68 w[j, :], (x-current_mus[j])**2 69 )) 70 sigmas/=np.sqrt(np.sum(w,axis=1)) 71 72 priors=np.zeros(shape=(len(mus))) 73 for i in range(n): 74 tmp=list(map( 75 gauss,[x[i]]*m,mus,sigmas 76 )) 77 tmpmaxindex,tmpmax=max( 78 enumerate(tmp),key=operator.itemgetter(1) 79 ) 80 # print(tmp) 81 # print(tmpmaxindex) 82 priors[tmpmaxindex]+=1/n 83 return mus,sigmas,priors 84 def solve(x=x0,priors=np.ones(len(mus0))/len(mus0)): 85 # print("priors={}".format(priors)) 86 mus=mus0 87 sigmas=sigmas0 88 for k in range(500): 89 w=e_step(mus=mus,sigmas=sigmas,x=x,priors=priors) 90 mus,sigmas,priors=m_step(w,current_mus=mus,x=x0) 91 print("k={},mus={},sigmas={},priors={}".format(k,mus,sigmas,priors)) 92 93 if __name__ == ‘__main__‘: 94 solve()
After 100 iterations, we get an approximation of the real model.
1 /usr/local/bin/python3.5 /home/csdl/review/fulcrum/gmm/gmm.py 2 k=0,mus=[ 0.81734343 1.27122747],sigmas=[ 1.60216343 1.33905931],priors=[ 0.32 0.68] 3 k=1,mus=[ 0.73393663 1.21263431],sigmas=[ 1.48989073 1.27140643],priors=[ 0.35 0.65] 4 k=2,mus=[ 0.72025041 1.24207148],sigmas=[ 1.47760392 1.25840835],priors=[ 0.36 0.64] 5 k=3,mus=[ 0.69405554 1.2656654 ],sigmas=[ 1.47453155 1.2480128 ],priors=[ 0.36 0.64] 6 k=4,mus=[ 0.65993336 1.28545741],sigmas=[ 1.47454151 1.238417 ],priors=[ 0.36 0.64] 7 k=5,mus=[ 0.62512005 1.30527053],sigmas=[ 1.4739642 1.22830782],priors=[ 0.36 0.64] 8 k=6,mus=[ 0.59009448 1.32522573],sigmas=[ 1.47230468 1.21788024],priors=[ 0.36 0.64] 9 k=7,mus=[ 0.55504913 1.34523959],sigmas=[ 1.46932309 1.20732602],priors=[ 0.36 0.64] 10 k=8,mus=[ 0.52016003 1.36521637],sigmas=[ 1.46489424 1.19678812],priors=[ 0.36 0.64] 11 k=9,mus=[ 0.4855794 1.38507002],sigmas=[ 1.45897578 1.18636451],priors=[ 0.36 0.64] 12 k=10,mus=[ 0.45142496 1.4047313 ],sigmas=[ 1.45158393 1.17611449],priors=[ 0.36 0.64] 13 k=11,mus=[ 0.41777707 1.42414967],sigmas=[ 1.44277296 1.16606539],priors=[ 0.36 0.64] 14 k=12,mus=[ 0.38468177 1.44329208],sigmas=[ 1.43261873 1.15621962],priors=[ 0.37 0.63] 15 k=13,mus=[ 0.36595587 1.46867892],sigmas=[ 1.41990091 1.14409883],priors=[ 0.37 0.63] 16 k=14,mus=[ 0.33654056 1.48870368],sigmas=[ 1.4082571 1.13343039],priors=[ 0.37 0.63] 17 k=15,mus=[ 0.30597566 1.50763142],sigmas=[ 1.39543174 1.12335036],priors=[ 0.37 0.63] 18 k=16,mus=[ 0.27568252 1.52609593],sigmas=[ 1.38137316 1.11360905],priors=[ 0.37 0.63] 19 k=17,mus=[ 0.24588996 1.54419117],sigmas=[ 1.36625212 1.10407072],priors=[ 0.37 0.63] 20 k=18,mus=[ 0.21664299 1.56192385],sigmas=[ 1.35022455 1.09464684],priors=[ 0.37 0.63] 21 k=19,mus=[ 0.18796432 1.57928065],sigmas=[ 1.33342219 1.0852798 ],priors=[ 0.37 0.63] 22 k=20,mus=[ 0.1598861 1.59623644],sigmas=[ 1.31596643 1.07593506],priors=[ 0.37 0.63] 23 k=21,mus=[ 0.13245872 1.61275414],sigmas=[ 1.2979812 1.06659735],priors=[ 0.39 0.61] 24 k=22,mus=[ 0.13549936 1.64420575],sigmas=[ 1.28163006 1.05238461],priors=[ 0.4 0.6] 25 k=23,mus=[ 0.13362832 1.67212388],sigmas=[ 1.26763647 1.03761078],priors=[ 0.4 0.6] 26 k=24,mus=[ 0.11750175 1.69125511],sigmas=[ 1.25330002 1.02525138],priors=[ 0.41 0.59] 27 k=25,mus=[ 0.11224826 1.71494289],sigmas=[ 1.23950504 1.01230038],priors=[ 0.42 0.58] 28 k=26,mus=[ 0.11153847 1.7395662 ],sigmas=[ 1.22728752 0.99889453],priors=[ 0.42 0.58] 29 k=27,mus=[ 0.0999276 1.75644918],sigmas=[ 1.21474604 0.98770556],priors=[ 0.43 0.57] 30 k=28,mus=[ 0.09911993 1.77770261],sigmas=[ 1.20375615 0.97601043],priors=[ 0.43 0.57] 31 k=29,mus=[ 0.08991339 1.79234269],sigmas=[ 1.19274093 0.96620904],priors=[ 0.43 0.57] 32 k=30,mus=[ 0.07854133 1.80401995],sigmas=[ 1.18163507 0.95803992],priors=[ 0.43 0.57] 33 k=31,mus=[ 0.06708472 1.81391145],sigmas=[ 1.1708709 0.9510143],priors=[ 0.43 0.57] 34 k=32,mus=[ 0.05629168 1.82248392],sigmas=[ 1.16077864 0.94483468],priors=[ 0.43 0.57] 35 k=33,mus=[ 0.04644144 1.8299628 ],sigmas=[ 1.15153082 0.93934709],priors=[ 0.43 0.57] 36 k=34,mus=[ 0.03761987 1.83648519],sigmas=[ 1.14319449 0.93447086],priors=[ 0.43 0.57] 37 k=35,mus=[ 0.02982246 1.84215374],sigmas=[ 1.13577403 0.93015559],priors=[ 0.43 0.57] 38 k=36,mus=[ 0.02299928 1.84705693],sigmas=[ 1.12923679 0.92636035],priors=[ 0.43 0.57] 39 k=37,mus=[ 0.01707735 1.85127645],sigmas=[ 1.12352817 0.92304533],priors=[ 0.43 0.57] 40 k=38,mus=[ 0.01197298 1.85488949],sigmas=[ 1.11858109 0.92016939],priors=[ 0.43 0.57] 41 k=39,mus=[ 0.00759925 1.85796875],sigmas=[ 1.11432244 0.91769023],priors=[ 0.43 0.57] 42 k=40,mus=[ 0.00387068 1.86058202],sigmas=[ 1.11067765 0.91556544],priors=[ 0.43 0.57] 43 k=41,mus=[ 7.06082311e-04 1.86279150e+00],sigmas=[ 1.10757393 0.91375369],priors=[ 0.43 0.57] 44 k=42,mus=[-0.00196965 1.86465346],sigmas=[ 1.10494246 0.91221583],priors=[ 0.43 0.57] 45 k=43,mus=[-0.00422464 1.86621814],sigmas=[ 1.10271971 0.91091554],priors=[ 0.43 0.57] 46 k=44,mus=[-0.00611974 1.86752982],sigmas=[ 1.10084819 0.9098198 ],priors=[ 0.43 0.57] 47 k=45,mus=[-0.00770859 1.86862714],sigmas=[ 1.09927671 0.90889909],priors=[ 0.43 0.57] 48 k=46,mus=[-0.00903796 1.86954354],sigmas=[ 1.09796019 0.90812732],priors=[ 0.43 0.57] 49 k=47,mus=[-0.01014832 1.87030773],sigmas=[ 1.09685943 0.90748172],priors=[ 0.43 0.57] 50 k=48,mus=[-0.01107441 1.87094421],sigmas=[ 1.09594057 0.90694261],priors=[ 0.43 0.57] 51 k=49,mus=[-0.01184586 1.87147378],sigmas=[ 1.09517461 0.90649307],priors=[ 0.43 0.57] 52 k=50,mus=[-0.01248783 1.87191401],sigmas=[ 1.09453685 0.90611867],priors=[ 0.43 0.57] 53 k=51,mus=[-0.01302159 1.87227973],sigmas=[ 1.09400634 0.90580718],priors=[ 0.43 0.57] 54 k=52,mus=[-0.01346505 1.87258336],sigmas=[ 1.09356541 0.90554823],priors=[ 0.43 0.57] 55 k=53,mus=[-0.01383328 1.87283531],sigmas=[ 1.09319917 0.90533313],priors=[ 0.43 0.57] 56 k=54,mus=[-0.01413888 1.87304431],sigmas=[ 1.09289515 0.90515454],priors=[ 0.43 0.57] 57 k=55,mus=[-0.0143924 1.87321761],sigmas=[ 1.09264288 0.90500635],priors=[ 0.43 0.57] 58 k=56,mus=[-0.01460264 1.87336127],sigmas=[ 1.09243365 0.90488343],priors=[ 0.43 0.57] 59 k=57,mus=[-0.01477693 1.87348033],sigmas=[ 1.09226016 0.9047815 ],priors=[ 0.43 0.57] 60 k=58,mus=[-0.01492139 1.87357899],sigmas=[ 1.09211635 0.90469701],priors=[ 0.43 0.57] 61 k=59,mus=[-0.0150411 1.87366073],sigmas=[ 1.09199717 0.90462698],priors=[ 0.43 0.57] 62 k=60,mus=[-0.01514028 1.87372844],sigmas=[ 1.09189842 0.90456896],priors=[ 0.43 0.57] 63 k=61,mus=[-0.01522245 1.87378452],sigmas=[ 1.09181661 0.90452088],priors=[ 0.43 0.57] 64 k=62,mus=[-0.01529051 1.87383097],sigmas=[ 1.09174884 0.90448106],priors=[ 0.43 0.57] 65 k=63,mus=[-0.01534687 1.87386944],sigmas=[ 1.0916927 0.90444807],priors=[ 0.43 0.57] 66 k=64,mus=[-0.01539356 1.87390129],sigmas=[ 1.09164621 0.90442075],priors=[ 0.43 0.57] 67 k=65,mus=[-0.01543222 1.87392767],sigmas=[ 1.09160771 0.90439813],priors=[ 0.43 0.57] 68 k=66,mus=[-0.01546423 1.87394951],sigmas=[ 1.09157583 0.90437939],priors=[ 0.43 0.57] 69 k=67,mus=[-0.01549074 1.87396759],sigmas=[ 1.09154943 0.90436388],priors=[ 0.43 0.57] 70 k=68,mus=[-0.01551269 1.87398257],sigmas=[ 1.09152757 0.90435103],priors=[ 0.43 0.57] 71 k=69,mus=[-0.01553086 1.87399496],sigmas=[ 1.09150947 0.9043404 ],priors=[ 0.43 0.57] 72 k=70,mus=[-0.0155459 1.87400523],sigmas=[ 1.09149449 0.90433159],priors=[ 0.43 0.57] 73 k=71,mus=[-0.01555836 1.87401373],sigmas=[ 1.09148208 0.9043243 ],priors=[ 0.43 0.57] 74 k=72,mus=[-0.01556868 1.87402076],sigmas=[ 1.09147181 0.90431826],priors=[ 0.43 0.57] 75 k=73,mus=[-0.01557722 1.87402659],sigmas=[ 1.0914633 0.90431327],priors=[ 0.43 0.57] 76 k=74,mus=[-0.01558428 1.87403141],sigmas=[ 1.09145626 0.90430913],priors=[ 0.43 0.57] 77 k=75,mus=[-0.01559014 1.8740354 ],sigmas=[ 1.09145043 0.9043057 ],priors=[ 0.43 0.57] 78 k=76,mus=[-0.01559498 1.87403871],sigmas=[ 1.09144561 0.90430287],priors=[ 0.43 0.57] 79 k=77,mus=[-0.01559899 1.87404144],sigmas=[ 1.09144161 0.90430052],priors=[ 0.43 0.57] 80 k=78,mus=[-0.01560232 1.87404371],sigmas=[ 1.0914383 0.90429857],priors=[ 0.43 0.57] 81 k=79,mus=[-0.01560506 1.87404558],sigmas=[ 1.09143556 0.90429696],priors=[ 0.43 0.57] 82 k=80,mus=[-0.01560734 1.87404714],sigmas=[ 1.0914333 0.90429563],priors=[ 0.43 0.57] 83 k=81,mus=[-0.01560923 1.87404842],sigmas=[ 1.09143142 0.90429453],priors=[ 0.43 0.57] 84 k=82,mus=[-0.01561079 1.87404948],sigmas=[ 1.09142987 0.90429362],priors=[ 0.43 0.57] 85 k=83,mus=[-0.01561208 1.87405037],sigmas=[ 1.09142858 0.90429286],priors=[ 0.43 0.57] 86 k=84,mus=[-0.01561315 1.8740511 ],sigmas=[ 1.09142751 0.90429223],priors=[ 0.43 0.57] 87 k=85,mus=[-0.01561403 1.8740517 ],sigmas=[ 1.09142663 0.90429172],priors=[ 0.43 0.57] 88 k=86,mus=[-0.01561476 1.8740522 ],sigmas=[ 1.0914259 0.90429129],priors=[ 0.43 0.57] 89 k=87,mus=[-0.01561537 1.87405261],sigmas=[ 1.0914253 0.90429093],priors=[ 0.43 0.57] 90 k=88,mus=[-0.01561587 1.87405295],sigmas=[ 1.0914248 0.90429064],priors=[ 0.43 0.57] 91 k=89,mus=[-0.01561629 1.87405324],sigmas=[ 1.09142438 0.90429039],priors=[ 0.43 0.57] 92 k=90,mus=[-0.01561663 1.87405347],sigmas=[ 1.09142404 0.90429019],priors=[ 0.43 0.57] 93 k=91,mus=[-0.01561692 1.87405367],sigmas=[ 1.09142376 0.90429003],priors=[ 0.43 0.57] 94 k=92,mus=[-0.01561715 1.87405383],sigmas=[ 1.09142352 0.90428989],priors=[ 0.43 0.57] 95 k=93,mus=[-0.01561735 1.87405396],sigmas=[ 1.09142333 0.90428977],priors=[ 0.43 0.57] 96 k=94,mus=[-0.01561751 1.87405407],sigmas=[ 1.09142317 0.90428968],priors=[ 0.43 0.57] 97 k=95,mus=[-0.01561764 1.87405416],sigmas=[ 1.09142303 0.9042896 ],priors=[ 0.43 0.57] 98 k=96,mus=[-0.01561775 1.87405424],sigmas=[ 1.09142292 0.90428954],priors=[ 0.43 0.57] 99 k=97,mus=[-0.01561785 1.8740543 ],sigmas=[ 1.09142283 0.90428948],priors=[ 0.43 0.57] 100 k=98,mus=[-0.01561792 1.87405435],sigmas=[ 1.09142276 0.90428944],priors=[ 0.43 0.57] 101 k=99,mus=[-0.01561799 1.8740544 ],sigmas=[ 1.09142269 0.9042894 ],priors=[ 0.43 0.57] 102 103 Process finished with exit code 0
In addition, a scikit-learn example can be found at http://scikit-learn.org/stable/modules/mixture.html
[1] Ian Goodfellow. https://www.quora.com/Why-could-generative-models-help-with-unsupervised-learning/answer/Ian-Goodfellow?srid=hTUVm