Writing logistic regression in PySpark

Date: 2015-07-03 20:40:40

import random as rd
import math

class LogisticRegressionPySpark:
    def __init__(self,MaxItr=100,eps=0.01,c=0.1):
        self.max_itr = MaxItr
        self.eps = eps
        self.c = c
    
    def train(self,data):
        # data is an RDD; the last item of each record is the class label, 0 or 1
        k = len(data.take(1)[0])
        # Initialize w
        self.w = [rd.uniform(0,1) for i in range(k)]  # the first entry is the intercept b
        n = data.count()
        
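        for itr in range(self.max_itr):
            # reduce() is an action that already returns a plain list to the driver, so no collect() is needed
            wadd = data.map(self.gradientDescent).reduce(lambda a,b:[a[j]+b[j] for j in range(k)])
            for i in range(k):
                # The intercept b is not regularized, hence the (i>0) factor
                self.w[i] += (wadd[i]/n-self.c*self.w[i]*(i>0))*self.eps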
        
        return self.w
            
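    def gradientDescent(self,x):
        # Per-record gradient of the log-likelihood: (y - h) * [1, x_1, ..., x_m]
        # The intercept self.w[0] belongs inside the exponent together with the weighted sum
        z = self.w[0]+sum(x[i]*self.w[i+1] for i in range(len(x)-1))
        h = 1/(1+math.exp(-z))
        err = x[len(x)-1]-h
        return [err if i==0 else err*x[i-1] for i in range(len(x))]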
    
    def predict(self,data):
        # data is an RDD of feature vectors without the label; returns P(label = 1) for each record
        return data.map(lambda x:1/(1+math.exp(-(self.w[0]+sum(self.w[i+1]*x[i] for i in range(len(x)))))))


Original source: http://www.cnblogs.com/porco/p/4619580.html
