标签:epo put reset tdd htm http eset equal hold
# Implement logistic regression (LR) with TensorFlow (TF 1.x graph API).
import tensorflow as tf
import numpy as np

tf.reset_default_graph()  # Clear any previously built graph.

FEATURE_NUM = 8  # Number of input features per example.

with tf.name_scope("input"):
    # Feature matrix: one row per example, FEATURE_NUM columns; batch size is dynamic.
    x = tf.placeholder(tf.float32, shape=[None, FEATURE_NUM])
    # One integer label per example (the loss below treats them as 0/1).
    y = tf.placeholder(tf.int32, shape=[None])

with tf.name_scope("lr"):
    weight_init = tf.truncated_normal(shape=[FEATURE_NUM, 1], mean=0.0, stddev=1.0)
    weight = tf.Variable(weight_init)
    bais = tf.Variable([0.0])  # Bias term (identifier spelling kept for existing references).
    # Reshape labels to (batch, 1) so they line up with the single model output column.
    y_expand = tf.reshape(y, shape=[-1, 1])
    # Keep the pre-sigmoid logits: the loss is computed from them, not from
    # the probabilities, to avoid log(0) when the sigmoid saturates.
    logits = tf.matmul(x, weight) + bais
    hypothesis = tf.sigmoid(logits)  # Predicted probability of class 1.

with tf.name_scope("loss"):
    y_float = tf.cast(y_expand, tf.float32)
    # Per-example negative log-likelihood
    #   -(y * log(h) + (1 - y) * log(1 - h)),  h = sigmoid(logits),
    # evaluated in TensorFlow's numerically stable logits form.
    likelyhood = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_float, logits=logits)
    # Mean over the batch, as a scalar.
    loss = tf.reduce_mean(likelyhood)
LEARNING_RATE = 0.02  # Step size for gradient descent.

with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
    # Running this op applies one gradient-descent step that reduces `loss`
    # (by default over the graph's trainable variables).
    training_op = optimizer.minimize(loss)
THRESHOLD = 0.5  # Decision threshold on the predicted probability.

with tf.name_scope("eval"):
    # Predict class 1 when the probability is >= THRESHOLD. These are boolean
    # comparisons rather than tf.sign() values. With tf.sign, a probability
    # of exactly THRESHOLD mapped to 0, which matched neither label, so such
    # predictions were always counted as wrong.
    predictions = tf.greater_equal(hypothesis, THRESHOLD)
    labels = tf.greater_equal(y_float, THRESHOLD)
    corrections = tf.equal(predictions, labels)
    # Fraction of examples whose predicted class matches the label.
    accuracy = tf.reduce_mean(tf.cast(corrections, tf.float32))
# Op that initializes every global variable created before this point.
init = tf.global_variables_initializer()

EPOCH = 10  # Number of training iterations.
TRAIN_BATCH_SIZE = 10  # Examples per training step.
EVAL_BATCH_SIZE = 5  # Examples per accuracy check.

with tf.Session() as sess:
    sess.run(init)
    for i in range(EPOCH):
        # Demo only: features and labels are freshly drawn random values each
        # epoch, so the printed accuracy is not expected to improve.
        _, _loss = sess.run(
            [training_op, loss],
            feed_dict={
                x: np.random.rand(TRAIN_BATCH_SIZE, FEATURE_NUM),
                y: np.random.randint(2, size=TRAIN_BATCH_SIZE),
            },
        )
        # Fetch the tensor itself, not a one-element list, so a plain value is printed.
        _accuracy = sess.run(
            accuracy,
            feed_dict={
                x: np.random.rand(EVAL_BATCH_SIZE, FEATURE_NUM),
                y: np.random.randint(2, size=EVAL_BATCH_SIZE),
            },
        )
        print("epoch:", i, _loss, _accuracy)
参考文章
https://www.cnblogs.com/jhc888007/p/10390282.html
标签:epo put reset tdd htm http eset equal hold
原文地址:https://blog.51cto.com/12597095/2507443