标签:orm 会话 loss star pre ast on() red code
git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/
1 import tensorflow as tf
2 from numpy.random import RandomState
3
4 ‘‘‘
5 模拟一个回归案例
6 自定义一个损失函数为:
7 当真实值y_更大的时候 loss = a(y_ - y)
8 当预测值y更大的时候 loss = b(y - y_)
9
10
11 loss_less = 10
12 loss_more = 1
13 loss = tf.reduce_sum(
14 tf.where(
15 tf.greater(y, y_),
16 (y - y_) * loss_more,
17 (y_ - y) * loss_less
18 ))
19
20 tf.reduce_sum() 求平均数
21 tf.where(condition, a, b) condition为真时返回a 否则返回b
22 tf.grater(a, b) a>b时候返回真 否则返回假
23
24 ‘‘‘
25
26 # 一批运算的数据数量
27 batch_size = 8
28
29 # 输入数据有两列特征
30 x = tf.placeholder(tf.float32, shape=[None, 2], name="x-input")
31 # 输入的真实值
32 y_ = tf.placeholder(tf.float32, shape=[None, 1], name="y-input")
33
34 # 定义一个单层神经网络 前向传播的过程
35 # 权重变量 2*1维度 方差为1 均值为0 种子变量使得每次运行生成同样的随机数
36 w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
37
38 # 计算过程
39 y = tf.matmul(x, w1)
40
41 # 自定义损失函数部分
42 loss_less = 10
43 loss_more = 1
44 loss = tf.reduce_sum(
45 tf.where(
46 tf.greater(y, y_),
47 (y - y_) * loss_more,
48 (y_ - y) * loss_less
49 ))
50
51 # 训练内容 训练速度0.001 让loss最小
52 train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
53
54
55 # 生成随机数作为训练数据
56 rdm = RandomState(1)
57 dataset_size = 128
58 X = rdm.rand(dataset_size, 2)
59 # 预测的正确至设置为两个特征加和 加上一个噪音
60 # 不设置噪音 预测的意义就不大了
61 # 噪音设置为均值为0的极小量
62 Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
63
64 # 开启会话训练
65 with tf.Session() as sess:
66 init_op = tf.initialize_all_variables()
67 sess.run(init_op)
68 STEPS = 5000
69 for i in range(STEPS):
70 start = (i * batch_size) % dataset_size
71 end = min(start + batch_size, dataset_size)
72 sess.run(
73 train_step,
74 feed_dict={
75 x: X[start: end],
76 y_: Y[start: end],
77 }
78 )
79 print(sess.run(w1))
80
81 ‘‘‘
82 [[1.019347 ]
83 [1.0428089]]
84 ‘‘‘
41 # 自定义损失函数部分 42 loss_less = 10 43 loss_more = 1 44 loss = tf.reduce_sum( 45 tf.where( 46 tf.greater(y, y_), 47 (y - y_) * loss_more, 48 (y_ - y) * loss_less 49 ))
这里自定义损失的时候,如果结果少了损失权重为10, 多了损失权重为1
预测结果 w1 为 [[1.02],[1.04]] , 所以结果预测偏向多于x1+x2, 因为多的话,损失少
标签:orm 会话 loss star pre ast on() red code
原文地址:https://www.cnblogs.com/Lin-Yi/p/9147674.html