标签:on() argmax decode 没有 records oss 序号 fine har
1,分割
2,整体识别 一张图片n个字母 即不再是一个目标值,是n个
? 例:NZPP
? N ------>[0.01,0.02,0.03.......] 概率 N------->[0,0,0,0,1.......] one-hot编码
? Z-------->[0.01,0.02,0.03.......] Z------->[0,1,0,0,0......]
? 最后得出n*26个概率
? 交叉熵计算
? 目标值:[0,0,0,1......] [0,0,0,1......] [0,0,0,1......] [0,0,0,1......]
? 经过网络输出:[1.2,2.3,5.6......]
? 经过softmax得出概率:[0.01,0.02,.......]
? 损失:104 个目标值m与概率n值相乘 mlogn,one-hot由于很多为0 最可得出一个样本的损失值为:1log()+1log()+1log()+1log()=损失值
1,处理数据 图片-----标签 图片路径-->序号-->验证码值----->转化为数字
使用tfrecords
2,识别验证码
x = [100,20*80*3]
? w=[20*80*3,4*26] bias=[4*26]
? y_predict = [100,4*26](这得出概率值)
先把目标值[100,4] 转化为one_hot编码 ---->[100,4,26]
? Api:tf.ont_hot([[13,25,11,24],[....],[...]],depth=26,axis=2,on_value=1.0)
? 将目标值转化为[100,4*26] 与预测值交叉熵计算
梯度下降优化
精确值
tf.equal(tf.argmax(y_true,2),tf.argmax(y_predict,2)) [100,4,26] 2 ---->26
再求平均值
import tensorflow as tf
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("captcha_dir", "dataPath/.tfrecords", "验证码文件路径")
tf.flags.DEFINE_string("batch_size", 100, "每批次训练的样本数")
tf.flags.DEFINE_string("letter_num", 26, "目标值的分类")
tf.flags.DEFINE_string("label_num", 4, "目标值个数")
tf.flags._FlagValuesWrapper
# 定义一个权重初始化函数
def weight_variables(shape):
weight = tf.Variable(tf.random_normal(mean=0.0, stddev=1.0, shape=shape), name="w")
return weight
# 定义一个权重初始化函数
def bias_Variables(shape):
bias = tf.Variable(tf.constant(0.0, shape=shape), name="b")
return bias
def read_and_decode():
"""
读取数据API
:return:image_batch,label_batch
"""
# 1,构建文件队列
file_queue = tf.train.start_queue_runners([FLAGS.captcha_dir])
# 2,构建阅读器,读取文件内容,默认一个样本
reader = tf.TFRecordReader()
# 3,读取文件内容
key, value = reader.read(file_queue)
# tfrecords 格式需要解析
features = tf.parse_single_example(value, features={
"image": tf.FixedLengthFeature([], tf.string),
"label": tf.FixedLengthFeature([], tf.string),
})
# 解析内容,字符串内容
image = tf.decode_rwa(features["image"], tf.uint8)
label = tf.decode_rwa(features["label"], tf.uint8)
# 改变形状
image_reshape = tf.reshape(image, [20, 80, 3])
label_reshape = tf.reshape(label, [FLAGS.label_num])
# 进行批处理,每次读取的样本数,就是每次训练的样本数
image_batch, label_batch = tf.train.batch([image_reshape, label_reshape], batch_size=FLAGS.batch_size,num_threads=1, capacity=FLAGS.batch_size)
return image_batch, label_batch
def fc_model(image):
"""
进行预测结果
:param image:特征值
:return:y_predict 【100,104】
"""
with tf.variable_scope("model"):
# 转化图片为二维
image_reshape = tf.reshape(image, [-1, 20 * 80 * 3])
# 1,随机初始化权重,偏值
weight = weight_variables([20 * 80 * 3, FLAGS.label_num * FLAGS.letter_num])
bias = bias_Variables([FLAGS.label_num * FLAGS.letter_num])
# 2,进行全连接层计算 没有使用卷积神经网络 [100,4*26]
y_predict = tf.matmul(tf.cast(image_reshape, tf.float32), weight) + bias
return y_predict
def transform_to_onthot(label):
"""
将读取文件的目标值转化为ont-hot编码
:param label:【100,4】 【【1,2,3,4】,【5,6,7,8】.....】
:return:[100,4,26]
"""
label_onehot = tf.one_hot(label, depth=FLAGS.letter_num, on_value=1.0, asix=2)
return label_onehot
def captcharec():
"""验证码识别程序"""
# 1,读取验证码的数据文件
image_batch, label_batch = read_and_decode()
# 2,输入图片特征数据建立模型,得出预测结果
# 一层全连接神经网络进行预测
# matrix【100,20,80,3】---》转化为二维*【20*80*3,4*26】+【4*26】=【100,104】
y_predict = fc_model(image_batch)
# 3,转换目标值为ont-hot编码 [100,4,26]
y_true = transform_to_onthot(label_batch)
# 4,交叉熵损失计算
with tf.variable_scope("sotf_cross"):
# y_true[100,4,26] ---> [100,104]
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
labels=tf.reshape(y_true, [FLAGS.batch_size, FLAGS.label_num * FLAGS.letter_num]), logits=y_predict))
# 5,优化损失
with tf.variable_scope("opt"):
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# 6,准确率计算
with tf.variables_scope("acc"):
# 比较目标值与与与测试 4个位置是否全部相同, 三维比较 需要将预测值转换为三维
equal_list = tf.equal(tf.argmax(y_true, 2),
tf.argmax(tf.reshape(y_predict, [FLAGS.batch_size, FLAGS.label_num, FLAGS.letter_num]),2))
accuracy = tf.reduce_mean(tf.cast(equal_list.tf.float32))
# 7,开启会话训练
init_op = tf.golbal_variables_initilizer()
with tf.Session() as sess:
sess.run(init_op)
# 定义线程协调器,开启线程
coord = tf.train.Coordinator()
# 开启线程取读取文件
threads = tf.train.start_queue_runners(sess, coord=coord)
# 训练
for i in range(5000):
sess.run(train_op)
print("第%d次准确率是:%f" % (i, accuracy.eval()))
# 回收
coord.request_stop()
coord.join(threads)
标签:on() argmax decode 没有 records oss 序号 fine har
原文地址:https://www.cnblogs.com/Dean0731/p/11788568.html