标签:dict 像素 tde feed from 输入数据 平均值 优化器 min
最近在学习黄文坚的TensorFlow书籍,希望对学习做一个总结。
softmax regression算法原理:当我们对一张图片进行预测时,会计算每一个数字的可能性,如3的概率是3%,5的概率是6%,1的概率是80%,则返回1.
TensorFlow版本:0.8.0
# 导入手写识别数据,TensorFlow提供了手写识别库
from tensorflow.examples.tutorials.mnist import input_data
# 读取手写识别数据 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# 训练集数据的维度是(55000,784),训练集标签的维度是(55000,10)
# 测试集数据的维度是(10000,784),测试集标签的维度是(10000,10)
# 验证集数据的维度是(5000,784),验证集标签的维度是(5000,10)
# 为什么训练数据的维度是784?因为tensorflow提供的数据集的图片像素是28*28=784
# 为什么标签的维度是10,标签做了处理,每个预期结果变成了只包含0和1的10维数据。
# 例如标签5就表示为[0,0,0,0,0,1,0,0,0,0],这种方法叫one-hot编码
print(mnist.train.images.shape,mnist.train.labels.shape) print(mnist.test.images.shape,mnist.test.labels.shape) print(mnist.validation.images.shape,mnist.validation.labels.shape)
# 导入TensorFlow库 import tensorflow as tf
# 将session注册为默认的session,运算都在session里跑。placeholder为输入数据的地方
# placeholder的第一个参数表示数据类型,第2个参数表示数据的维度,None表示任意长度的数据 sess=tf.InteractiveSession() x = tf.placeholder(tf.float32,[None,784])
# Variable用于存储参数,它是持久化的,可以长期存在,每次迭代都会更新 # 数据的维度是784,类别的维度经过one-hot编码后变成了10维,所以W的参数为[784,10]
# b为[10]维,W和b全部初始化为0,简单模型的初始值不重要
W = tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) # softmax函数用于定义softmax regression算法
# matmul用于向量乘法
y=tf.nn.softmax(tf.matmul(x,W)+b)
# 求损失函数cross-entropy,先定义一个placeholder,输入的真实label
# cross_entropy定义了损失函数的计算方法,通过reduce_sum求熵的和,reduce_mean求每个batch的熵的平均值 y_=tf.placeholder(tf.float32,[None,10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
# 定义一个优化器,GradientDescentOptimizer为优化器,学习率为0.5,优化目标设定为cross_entropy train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 全局参数初始化并执行run tf.initialize_all_variables().run()
# 每次取100个样本,并feed给placeholder,执行1000次,train_step对数据进行训练 for i in range(1000): batch_xs,batch_ys = mnist.train.next_batch(100) train_step.run({x:batch_xs,y_:batch_ys})
# 求出概率最大的数字,判断是否与实际标签相符合,y是预测数据,y_是实际数据 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# 求计算精度 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
总的来说,TensorFlow感觉还是比较简单的,也许这只是个最简单的模型吧。
涉及的概念也只有session,variable,placeholder,GradientDescentOptimizer。
梯度下降等复杂的方法都进行了封装,用python不到30行的代码就实现了手写识别,虽然识别正确率只有92%左右。
tensorflow使用softmax regression算法实现手写识别
标签:dict 像素 tde feed from 输入数据 平均值 优化器 min
原文地址:http://www.cnblogs.com/eagle-1024/p/7739711.html