码迷,mamicode.com
首页 > 编程语言 > 详细

《统计学习方法》-逻辑回归笔记和python源码

时间:2016-05-12 15:23:00      阅读:346      评论:0      收藏:0      [点我收藏+]

标签:

逻辑回归(Logistic regression)

逻辑回归是统计学习中的经典分类方法。其多用在二分类{0,1}问题上。

定义1:

设X是连续随机变量,X服从逻辑回归分布是指X具有下列分布函数与密度函数:

技术分享

技术分享

分布函数属于逻辑斯谛函数,其图形是一条S形曲线。


定义2:

二项逻辑斯谛回归模型是如下条件概率分布:

技术分享

从上式可以看出,逻辑回归对线性回归经行了归一化操作,将输出范围规定在{0,1}。

现在来看,逻辑回归的的特点,几率,指一件事件发生的概率与不发生的概率的比值。对上式分别求对数,我们可得如下式子。

技术分享

这就是说,在逻辑回归模型中,输出Y=1的对数几率是输入x的线性函数。

对输入x经行分类的线性函数w*x,其值域为实数域。通过逻辑回归模型可以将线性函数转化为概率,

技术分享

这就意味着,线性函数值越接近正无穷,概率越接近1;线性函数值越接近负无穷,概率值越接近0。这样的模型称为逻辑回归模型。


损失函数:

如同,在感知机一节中一样,我们需要构造损失函数,更新权值参数。我们利用极大似然估计法估计模型参数,即w。极大似然估计法是已经知道结果,然后寻求使该结果成立的最大可能条件(条件即模型参数)

似然函数:

技术分享

对数似然函数:


技术分享

这样子,我们有了损失函数,这里我们只要将该函数极大化即可,求其最大值时的w即可。

优化求解:

梯度下降法

总是朝着负方向改变,直到找到极小值。初中数学中,对一个函数求导可以得到函

数在某一点的斜率k(表示函数的增长速率,朝着正方向改变),如果我们将斜率取负号

-k,那么就得到了朝着负方向增长的速率。在这里,由于我们要极大化对数似然函数,所

以在这里不用加负号。

更新公式:

技术分享

其中,alpha是学习率。

python源码:


#coding=utf-8
#author=altman
import numpy as np
import matplotlib.pyplot as plt
def loadData():  
    train_x = []  
    train_y = []  
    fileIn = open('data.txt')
    for line in fileIn.readlines():  
        lineArr = line.strip().split()  
        train_x.append([1.0, float(lineArr[0]), float(lineArr[1])])  
        train_y.append(float(lineArr[2]))  
    train_x = np.array(train_x)
    train_y = np.array(train_y).T
    return train_x,train_y
def sigmod(x):
    return 1.0/(1.0+np.exp(-x))
def train(matrix,labels):
    size = matrix.shape[1]
    w = np.ones(size)
    while True:
        x = np.dot(matrix,w)
        y = sigmod(x)
        diff = labels - y
        tmpW = w + 0.01*np.dot(matrix.T,diff)
        diff2 = (tmpW-w)**2
        sum_diff2 = sum(diff2)
        sq = sum_diff2**0.5
        if sq < 0.001:
            break
        else:
            w = tmpW
    return w
def test(matrix,labels,w):
    x = np.dot(matrix,w)
    y = sigmod(x)
    error = 0.0
    for i,result in enumerate(y):
        if result > 0.5:
            predict = 1.0
            if predict != labels[i]:
                error +=1
        else:
            predict = 0.0
            if predict != labels[i]:
                error +=1
    print("错误率:%3.2f" %(error/100.0))
def show(data,labels,w):
    x1=[]
    y1=[]
    x2=[]
    y2=[]
    for i in range(len(labels)):
        if labels[i] == 0:
            x1.append(data[i,1])
            y1.append(data[i,2])
        else:
            x2.append(data[i,1])
            y2.append(data[i,2])
    plt.scatter(x1,y1,edgecolors='r')
    plt.scatter(x2,y2,edgecolors='k')
    max_x = (np.max(data[:,1]))
    min_x = (np.min(data[:,1]))
    y_min_x = float(-w[0] - w[1] * min_x) / w[2]
    y_max_x = float(-w[0] - w[1] * max_x) / w[2]
    plt.plot([min_x, max_x], [y_min_x, y_max_x], '-g')
    plt.show()
def main():
    matrix,labels = loadData()
    weights = train(matrix,labels)
    test(matrix,labels,weights)
    show(matrix,labels,weights)
if __name__ == '__main__':
    main()
实验结果图:

技术分享

实验数据集:

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0



《统计学习方法》-逻辑回归笔记和python源码

标签:

原文地址:http://blog.csdn.net/v_victor/article/details/51362984

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!