码迷,mamicode.com
首页 > 编程语言 > 详细

python实现简单的梯度下降法

时间:2018-07-15 17:33:09      阅读:137      评论:0      收藏:0      [点我收藏+]

标签:png   try   简单   alt   init   plot   turn   迭代   show   

代码如下:

# 梯度下降法模拟
import  numpy as np 
import matplotlib.pyplot as plt 
plot_x = np.linspace(-1,6,141)  


# 计算损失函数对应的导数,即对y=(x-2.5)**2-1求导
def dJ(theda):
    return 2*(theda-2.5)
# 计算theda对应的损失函数值
def J(theda):
    try:
        return (theda-2.5)**2-1
    except:
        return float(inf)

# 梯度下降法开始
theda_history = [] # 用来记录梯度下降的过程
theda = 0.0 # 以0作为开始点
eta = 0.1  # 设置学习率
# epsilon = 1e-8  由于导数可能达不到0,
# 所以设置epsilon,表示损失函数值每次减小不足1e-8就认为已经达到最小值了

# n_itera 用来限制迭代的次数,默认为10000次
# 梯度下降函数
def gradient_descent(initial_theda,eta,n_itera=1e4,epsilon=1e-8):
    theda = initial_theda
    theda_history.append(initial_theda)
    i_itera = 0
    while i_itera<n_itera:
        gradient = dJ(theda)
        last_theda = theda
        theda = theda - eta * gradient
        theda_history.append(theda)
        if(abs(J(theda)-J(last_theda))<epsilon):
            break
        i_itera += 1
def plot_theda_history():
    plt.plot(plot_x,J(plot_x))
    plt.plot(np.array(theda_history),J(np.array(theda_history)),color=r,marker=+)
    plt.show()

gradient_descent(theda,eta)
plot_theda_history()

效果图:

技术分享图片

 

python实现简单的梯度下降法

标签:png   try   简单   alt   init   plot   turn   迭代   show   

原文地址:https://www.cnblogs.com/ncuhwxiong/p/9313857.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!