码迷,mamicode.com
首页 > 其他好文 > 详细

莫烦keras学习自修第三天【回归问题】

时间:2018-09-02 21:43:21      阅读:110      评论:0      收藏:0      [点我收藏+]

标签:info   get   技术   val   reg   问题   array   span   测试数据   

1. 代码实战

#!/usr/bin/env python
#!_*_ coding:UTF-8 _*_

import numpy as np
# 这句话不知道是什么意思
np.random.seed(1337)
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt

# 创建一些训练数据
# 生成-1 到 1 之间的float64的200个数的列表
X = np.linspace(-1, 1, 200)

# 打乱列表为无序状态
np.random.shuffle(X)

# 根据X的数据生成Y,并且系数为0.5, 偏置为0到0.05
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, ))

# 使用plt工具画图
plt.scatter(X, Y)
plt.show()

# 将生成的前面160条数据作为训练数据
X_train, Y_train = X[:160], Y[:160]
# 将生成的后面40条数据作为测试数据
X_test, Y_test = X[160:], Y[160:]

# 创建一个训练模型
model = Sequential()

# 为训练模型添加隐藏层
model.add(Dense(units=1, input_dim=1))

# 为训练模型进行编译,使用均方误差损失函数
model.compile(loss=mse, optimizer=sgd)

# 开始进行训练,
for step in range(301):
    # 该训练函数每次训练均返回cost损失
    cost = model.train_on_batch(X_train, Y_train)
    if step % 100 == 0:
        print(train cost: , cost)

# 测试训练过的模型
# 该批量测试函数返回损失值
cost = model.evaluate(X_test, Y_test, batch_size=40)
print(test cost:, cost)
# 打印W和b这些参数
W, b = model.layers[0].get_weights()
print(Weights=, W, \nbiases=, b)

# 打印测试值和预测值
Y_pred = model.predict(X_test)
# 使用散点图绘制测试值
plt.scatter(X_test, Y_test)
# 使用直线图绘制预测值
plt.plot(X_test, Y_pred)
plt.show()

结果:

/Users/liudaoqiang/PycharmProjects/numpy/venv/bin/python /Users/liudaoqiang/Project/python_project/keras_day02/regressor.py
Using Theano backend.
(train cost: , array(4.190890312194824, dtype=float32))
(train cost: , array(0.10415506362915039, dtype=float32))
(train cost: , array(0.011512807570397854, dtype=float32))
(train cost: , array(0.004584408365190029, dtype=float32))

40/40 [==============================] - 0s 5us/step
(test cost:, 0.0053740302100777626)
(Weights=, array([[ 0.56634265]], dtype=float32), \nbiases=, array([ 2.00106311], dtype=float32))

Process finished with exit code 0

技术分享图片

莫烦keras学习自修第三天【回归问题】

标签:info   get   技术   val   reg   问题   array   span   测试数据   

原文地址:https://www.cnblogs.com/liuzhiqaingxyz/p/9575174.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!