码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow之回归训练

时间:2020-07-11 21:02:11      阅读:68      评论:0      收藏:0      [点我收藏+]

标签:div   random   line   error   col   histogram   ali   model   infer   

1、函数

             matmul(a,b,name=None):

     功能: 矩阵乘运算"""Multiplies matrix `a` by matrix `b`, producing `a` * `b`.

             tf.square(error)

  功能:平方

     tf.reduce_mean(error)

  功能:均值

     tf.train.GradientDescentOptimizer(0.1).minimize(loss)

  功能:梯度下降优化器优化+最小损失
 
  1 import tensorflow as tf
  2 
  3 #实现线性回归训练
  4 class MyLinearRegression(object):
  5 
  6     """实现线性回归训练"""
  7 
  8     def __init__(self):
  9         self.learning_rate = 0.2
 10 
 11     def inputs(self):
 12         """
 13         1、准备数据的特征值和目标值 inputs
 14         :return x_data y_true
 15         """
 16         with tf.compat.v1.variable_scope("ytruexdata", reuse=tf.compat.v1.AUTO_REUSE):
 17             x_data = tf.random.normal([100,1],mean=0.0,stddev=1.0,name=x_data)#100个样本,每个样本一个特征值
 18             y_true = tf.matmul(x_data, [[0.7]]) + 0.8
 19             return    x_data ,y_true
 20 
 21     def inference(self,feature):
 22         """
 23         2、根据特征值建立线性回归模型(确定参数个数形状) inference
 24         :param  feature[100,1]
 25         :return y_predict
 26         """
 27 
 28         #定义一个命名空间(变量命名空间作用域,注:不是变量作用域),式代码结构更加清晰,Tensorboard图结构清楚
 29         with tf.compat.v1.variable_scope("y_predict", reuse=tf.compat.v1.AUTO_REUSE):
 30             # 建立回归模型,分析别人数据的特征数量--->权重数量, 偏置b
 31             # 随机初始化权重和偏置
 32             # 权重和偏置必须使用tf.Variable去定义,因为只有Variable才能被梯度下降所训练
 33             # 由于有梯度下降算法优化,所以一开始给随机的参数,权重和偏置
 34             # y_predict = x_data[100,1] * self.weight[1,1] +self.bias
 35             self.weight = tf.Variable(initial_value=tf.random.normal([1,1], mean=0.0, stddev=1.0),
 36                                       trainable=True, name=weight)
 37             self.bias = tf.Variable(0.0,trainable=True, name=bias)
 38             y_predict = tf.matmul(feature, self.weight) + self.bias
 39 
 40         return  y_predict
 41 
 42     def loss(self, y_true, y_predict):
 43         """
 44         3、根据模型得出预测结果,建立损失 loss
 45         :param   y_predict  y_true
 46         :return loss
 47         """
 48         with tf.compat.v1.variable_scope("losses", reuse=tf.compat.v1.AUTO_REUSE):
 49             loss =tf.reduce_mean(tf.square(y_true - y_predict),name=loss)
 50             return loss
 51 
 52     def sgd_op(self, loss):
 53         """
 54         4、梯度下降优化器优化损失 sgd_op
 55         :return   train_op
 56         """
 57         with tf.compat.v1.variable_scope("optimizate", reuse=tf.compat.v1.AUTO_REUSE):
 58         # 填充学习率:0 ~ 1    学习率是非常小,
 59         # 学习率越大,训练到较好结果的步长越小;学习率越小,训练到较好结果的步长越大。
 60         # 最小化损失
 61             train_op = tf.compat.v1.train.GradientDescentOptimizer(self.learning_rate,name=train_op).minimize(loss)
 62 
 63             return train_op
 64 
 65     def merge_summary(self, loss):
 66 
 67         # 1、收集张量的值
 68         tf.summary.scalar("show_loss", loss)
 69         tf.summary.histogram("show_weight", self.weight)
 70         tf.summary.histogram(show_bias, self.bias)
 71 
 72         # 2、合并变量
 73         merged = tf.summary.merge_all()
 74 
 75         return merged
 76 
 77     def train(self):
 78         """
 79         训练模型
 80         :param loss:
 81         :return:
 82         """
 83         with tf.compat.v1.variable_scope("train_model", reuse=tf.compat.v1.AUTO_REUSE):
 84             x_data,y_true = self.inputs()
 85             y_predict =self.inference(x_data)
 86             loss = self.loss(y_true, y_predict)
 87             train_op = self.sgd_op(loss)
 88             merged = self.merge_summary(loss)
 89 
 90         #在会话中训练
 91         with  tf.compat.v1.Session() as sess:
 92 
 93             sess.run(tf.compat.v1.global_variables_initializer())
 94 
 95             for i in range(10):
 96                 _,summary = sess.run([train_op,merged])
 97                 print("第%d次的损失为%f,权重为%f,偏置为%f" % (i + 1,loss.eval(),self.weight.eval(),self.bias.eval()))
 98 
 99             # tf.summary.FileWriter("./tmp/", graph=ss.graph)
100 
101 
102                 file_writer = tf.compat.v1.summary.FileWriter(./tmp/summary/,graph=sess.graph)
103 
104                 file_writer.add_summary(summary, i+1)
105 
106 
107 MyLinearRegression().train()

 

tensorflow之回归训练

标签:div   random   line   error   col   histogram   ali   model   infer   

原文地址:https://www.cnblogs.com/yangjingshixinlingdechuanghu/p/13283250.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!