码迷,mamicode.com
首页 > 其他好文 > 详细

1.5神经网络可视化显示(matplotlib)

时间:2017-12-21 23:03:51      阅读:463      评论:0      收藏:0      [点我收藏+]

标签:show   input   oat   ges   step   运算   .com   技术   nump   

神经网络训练+可视化显示

#添加隐层的神经网络结构+可视化显示
import tensorflow as tf

def add_layer(inputs,in_size,out_size,activation_function=None):
    #定义权重--随机生成inside和outsize的矩阵
    Weights=tf.Variable(tf.random_normal([in_size,out_size]))
    #不是矩阵,而是类似列表
    biaes=tf.Variable(tf.zeros([1,out_size])+0.1)
    Wx_plus_b=tf.matmul(inputs,Weights)+biaes
    if activation_function is  None:
        outputs=Wx_plus_b
    else:
        outputs=activation_function(Wx_plus_b)
    return outputs

import numpy as np
x_data=np.linspace(-1,1,300)[:,np.newaxis] #300行数据
noise=np.random.normal(0,0.05,x_data.shape)
y_data=np.square(x_data)-0.5+noise
#None指定sample个数,这里不限定--输出属性为1
xs=tf.placeholder(tf.float32,[None,1])  #这里需要指定tf.float32,
ys=tf.placeholder(tf.float32,[None,1])

#建造第一层layer
#输入层(1)
l1=add_layer(xs,1,10,activation_function=tf.nn.relu)
#隐层(10)
prediction=add_layer(l1,10,1,activation_function=None)
#输出层(1)
#预测prediction
loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),
                   reduction_indices=[1])) #平方误差
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init=tf.initialize_all_variables()
sess=tf.Session()
#直到执行run才执行上述操作
sess.run(init)


import matplotlib.pyplot as plt
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(x_data,y_data)
plt.ion() #图像会连续显示
#plt.show()不会终止整个函数。在2.x时候使用plt.show(block=False)
plt.show()


for i in range(1000):
    #这里假定指定所有的x_data来指定运算结果
    sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
    if i%50:
        # print (sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
        try:
            #忽略第一次的错误
            ax.lines.remove(lines[0]) #在图片中去掉lines的第1条线段,不然线会混乱
        except Exception:
            prediction_value=sess.run(prediction,feed_dict={xs:x_data})
            lines=ax.plot(x_data,prediction_value,r-,lw=5)
            # ax.lines.remove(lines[0]) 移动上上面,先移除第一条线
            plt.pause(0.2) #每次暂停0.2s

显示:

技术分享图片

 

1.5神经网络可视化显示(matplotlib)

标签:show   input   oat   ges   step   运算   .com   技术   nump   

原文地址:http://www.cnblogs.com/jackchen-Net/p/8082562.html

(1)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!