码迷,mamicode.com
首页 > 其他好文 > 详细

5.非线性回归

时间:2019-09-28 23:54:48      阅读:216      评论:0      收藏:0      [点我收藏+]

标签:lib   pyplot   随机   feed   erro   port   esc   desc   die   

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# numpy生成200个随机点
x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis]
noise = np.random.normal(0,0.02,x_data.shape)
y_data = np.square(x_data) + noise

plt.scatter(x_data, y_data)
plt.show()

技术图片

# 定义两个placeholder
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1])

# 神经网络结构:1-30-1
w1 = tf.Variable(tf.random_normal([1,30]))
b1 = tf.Variable(tf.zeros([30]))
wx_plus_b_1 = tf.matmul(x,w1) + b1
l1 = tf.nn.tanh(wx_plus_b_1)

w2 = tf.Variable(tf.random_normal([30,1]))
b2 = tf.Variable(tf.zeros([1]))
wx_plus_b_2 = tf.matmul(l1,w2) + b2
prediction = tf.nn.tanh(wx_plus_b_2)

# 二次代价函数
loss = tf.losses.mean_squared_error(y,prediction)
# 使用梯度下降法最小化loss
train = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())
    for _ in range(3000):
        sess.run(train,feed_dict={x:x_data,y:y_data})
        
    # 获得预测值
    prediction_value = sess.run(prediction,feed_dict={x:x_data})
    # 画图
    plt.scatter(x_data, y_data)
    plt.plot(x_data, prediction_value, r-, lw=5)
    plt.show()

技术图片

5.非线性回归

标签:lib   pyplot   随机   feed   erro   port   esc   desc   die   

原文地址:https://www.cnblogs.com/liuwenhua/p/11605364.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!