码迷,mamicode.com
首页 > 其他好文 > 详细

机器学习实战笔记--朴素贝叶斯

时间:2017-01-07 19:37:40      阅读:421      评论:0      收藏:0      [点我收藏+]

标签:text   top   range   xtend   cut   email   并集   odi   for   

  1 #encoding:utf-8
  2 from numpy import *
  3 import feedparser
  4 
  5 #加载数据集
  6 def loadDataSet():
  7     postingList = [[my, dog, has, flea, problems, help, please],
  8                    [maybe, not, take, him, to, dog, park, stupid],
  9                    [my, dalmation, is, so, cute, I, love, him],
 10                    [stop, posting, stupid, worthless, garbage],
 11                    [mr, licks, ate, my, steak, how, to, stop, him],
 12                    [quit, buying, worthless, dog, food, stupid]]
 13     classVec = [0, 1, 0, 1, 0, 1]  # 1表示侮辱性言论,0表示正常言论
 14     return postingList, classVec
 15 
 16 
 17 def createVocabList(dataSet):  #得到词汇集合
 18     vocabSet = set([])
 19     for document in dataSet:
 20         vocabSet = vocabSet | set(document)  #两个集合的并集
 21     return list(vocabSet)
 22 
 23 
 24 def setOfWords2Vec(vocabList, inputSet):  #内容转化为向量
 25     returnVec = [0] * len(vocabList)   #所有元素都为0的向量
 26     for word in inputSet:
 27         if word in vocabList:
 28             returnVec[vocabList.index(word)] = 1   #如果有这个词条,则置1
 29         else:
 30             print "这个词条: %s 不在我的词典中!" % word
 31     return returnVec
 32 
 33 
 34 def trainNB0(trainMatrix, trainCategory):   #朴素贝叶斯分类器训练函数
 35     numTrainDocs = len(trainMatrix)          #文档数量
 36     numWords = len(trainMatrix[0])           #每篇文档的词条数
 37     pAbusive = sum(trainCategory) / float(numTrainDocs)   #侮辱性文档的比例
 38     p0Num = ones(numWords)
 39     p1Num = ones(numWords)  # change to ones()
 40     p0Denom = 2.0
 41     p1Denom = 2.0  # change to 2.0
 42     for i in range(numTrainDocs):
 43         if trainCategory[i] == 1:
 44             p1Num += trainMatrix[i]
 45             p1Denom += sum(trainMatrix[i])   #侮辱性文档的总词数
 46         else:
 47             p0Num += trainMatrix[i]
 48             p0Denom += sum(trainMatrix[i])  #正常文档的总词数
 49     p1Vect = log(p1Num / p1Denom)  # change to log()  表示p(wi/c1)
 50     p0Vect = log(p0Num / p0Denom)  # change to log()  表示p(wi/c0)
 51     return p0Vect, p1Vect, pAbusive
 52 
 53 
 54 def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):    #对输入的vec2Classify进行分类
 55     p1 = sum(vec2Classify * p1Vec) + log(pClass1)  # element-wise mult
 56     p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
 57     if p1 > p0:
 58         return 1
 59     else:
 60         return 0
 61 
 62 
 63 def testingNB():   #是对以上方法的整合,方便运行和调试
 64     listOPosts, listClasses = loadDataSet()   #加载数据
 65     myVocabList = createVocabList(listOPosts)   #得到字典
 66     trainMat = []
 67     for postinDoc in listOPosts:
 68         trainMat.append(setOfWords2Vec(myVocabList, postinDoc))     #转化为0,1向量
 69     p0V, p1V, pAb = trainNB0(array(trainMat), array(listClasses))
 70     testEntry = [love, my, dalmation]           #代分类的输入
 71     thisDoc = array(setOfWords2Vec(myVocabList, testEntry))    #转化
 72     print testEntry, classified as: , classifyNB(thisDoc, p0V, p1V, pAb)   #进行分类
 73     testEntry = [stupid, garbage]
 74     thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
 75     print testEntry, classified as: , classifyNB(thisDoc, p0V, p1V, pAb)
 76 
 77 
 78 def bagOfWords2VecMN(vocabList, inputSet):
 79     returnVec = [0] * len(vocabList)
 80     for word in inputSet:
 81         if word in vocabList:
 82             returnVec[vocabList.index(word)] += 1
 83     return returnVec
 84 
 85 
 86 def textParse(bigString):  # 处理字符串得到字符串列表,并过滤
 87     import re
 88     listOfTokens = re.split(r\W*, bigString)
 89     return [tok.lower() for tok in listOfTokens if len(tok) > 2]
 90 
 91 #垃圾邮件测试函数
 92 def spamTest():
 93     docList = [];
 94     classList = [];
 95     fullText = []
 96     for i in range(1, 26):
 97         wordList = textParse(open(email/spam/%d.txt % i).read())   #读取文件
 98         docList.append(wordList)         #注意append和extend
 99         fullText.extend(wordList)
100         classList.append(1)
101         wordList = textParse(open(email/ham/%d.txt % i).read())
102         docList.append(wordList)
103         fullText.extend(wordList)
104         classList.append(0)
105     vocabList = createVocabList(docList)  # 建立字典
106     trainingSet = range(50);
107     testSet = []
108     for i in range(10):
109         randIndex = int(random.uniform(0, len(trainingSet)))   #随机取10个作为测试集
110         testSet.append(trainingSet[randIndex])
111         del (trainingSet[randIndex])    #从训练集中删除
112     trainMat = [];
113     trainClasses = []
114     for docIndex in trainingSet:  # 训练
115         trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
116         trainClasses.append(classList[docIndex])
117     p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
118     errorCount = 0
119     for docIndex in testSet:  # 分类
120         #wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
121         wordVector = setOfWords2Vec(vocabList, docList[docIndex])
122         if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
123             errorCount += 1
124             print "classification error", docList[docIndex]
125     print the error rate is: , float(errorCount) / len(testSet)
126     # return vocabList,fullText
127 
128 
def calcMostFreq(vocabList, fullText, topN=30):
    """Return the *topN* most frequent vocabulary words in *fullText*.

    vocabList -- words to count.
    fullText  -- flat list of every token in the corpus.
    topN      -- number of (word, count) pairs to return (default 30).

    Returns a list of (word, count) tuples sorted by count, highest first.
    """
    from collections import Counter
    import operator

    tokenCounts = Counter(fullText)  # one pass instead of list.count() per word
    freqDict = dict((token, tokenCounts[token]) for token in vocabList)
    sortedFreq = sorted(freqDict.items(), key=operator.itemgetter(1), reverse=True)
    return sortedFreq[:topN]
136 
137 
def localWords(feed1, feed0):
    """Train and evaluate a classifier that tells two RSS feeds apart.

    feed1, feed0 -- parsed feeds (mappings with an 'entries' list whose items
                    have a 'summary'); feed1 entries are class 1, feed0
                    entries class 0.

    Uses the summaries of the first min(len(feed1 entries), len(feed0
    entries)) entries of each feed, removes the 30 most frequent words from
    the vocabulary (presumably to discard stop-word-like tokens), holds out
    20 random documents for testing, prints the error rate and returns
    (vocabList, p0V, p1V).

    Raises ValueError when there are 20 or fewer documents in total, because
    no training documents would remain after the hold-out.
    """
    docList = []
    classList = []
    fullText = []
    minLen = min(len(feed1['entries']), len(feed0['entries']))
    for i in range(minLen):
        # Take one entry from each feed per step (one RSS summary per document).
        for feed, label in ((feed1, 1), (feed0, 0)):
            wordList = textParse(feed['entries'][i]['summary'])
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(label)

    vocabList = createVocabList(docList)
    # Drop the 30 most frequent words from the vocabulary.
    topWords = set(pair[0] for pair in calcMostFreq(vocabList, fullText))
    vocabList = [word for word in vocabList if word not in topWords]

    numTest = 20
    if len(docList) <= numTest:
        raise ValueError('need more than %d documents to hold out %d for testing, got %d'
                         % (numTest, numTest, len(docList)))
    # Random hold-out split; list() because a Python 3 range cannot be mutated.
    trainingSet = list(range(len(docList)))
    testSet = []
    for _ in range(numTest):
        randIndex = int(random.uniform(0, len(trainingSet)))
        testSet.append(trainingSet.pop(randIndex))

    trainMat = [bagOfWords2VecMN(vocabList, docList[d]) for d in trainingSet]
    trainClasses = [classList[d] for d in trainingSet]
    p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))

    errorCount = 0
    for docIndex in testSet:  # classify the held-out documents
        wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
        if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
            errorCount += 1
    print('the error rate is: %s' % (float(errorCount) / len(testSet)))
    return vocabList, p0V, p1V
176 
177 
def getTopWords(ny, sf, threshold=-6.0):
    """Print the words most characteristic of each of two RSS feeds.

    ny, sf    -- parsed feeds, passed as localWords(ny, sf); ny is therefore
                 class 1 and sf class 0.
    threshold -- a word is listed for a class only if its log conditional
                 probability for that class exceeds this value (default
                 -6.0, the previously hard-coded cut-off).

    Prints the SF words and then the NY words, each from most to least
    probable.
    """
    import operator
    vocabList, p0V, p1V = localWords(ny, sf)
    byProb = operator.itemgetter(1)
    topSF = [(word, p) for word, p in zip(vocabList, p0V) if p > threshold]
    topNY = [(word, p) for word, p in zip(vocabList, p1V) if p > threshold]
    sortedSF = sorted(topSF, key=byProb, reverse=True)
    print("SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**")
    for item in sortedSF:
        print(item[0])
    sortedNY = sorted(topNY, key=byProb, reverse=True)
    print("NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**")
    for item in sortedNY:
        print(item[0])
194 
if __name__ == '__main__':
    # Other demos in this module: testingNB() classifies the toy posts from
    # loadDataSet(); spamTest() evaluates the filter on the email/ corpus.
    ny = feedparser.parse('http://newyork.craigslist.org/stp/index.rss')
    sf = feedparser.parse('http://sfbay.craigslist.org/stp/index.rss')
    getTopWords(ny, sf)

 

机器学习实战笔记--朴素贝叶斯

标签:text   top   range   xtend   cut   email   并集   odi   for   

原文地址:http://www.cnblogs.com/yzwhykd/p/6259945.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!