码迷,mamicode.com
首页 > 其他好文 > 详细

使用tensorflow中的Dataset来读取制作好的tfrecords文件

时间:2019-09-02 09:24:22      阅读:80      评论:0      收藏:0      [点我收藏+]

标签:charm   序列   queue   序列化   example   制作   shape   atp   创建   

上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中可能就无法使用了,因此不建议大家使用这个函数,好了,下面就来看看是如何进行读取的吧。

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 #定义可以一次获得多张图像的函数
 5 def show_image(image_dir):
 6     plt.imshow(image_dir)
 7     plt.axis(on)
 8     plt.show()
 9 
10 #单个record的解析函数
11 def decode_example(example):#,resize_height,resize_width,labels_nums):
12     features=tf.io.parse_single_example(example,features={
13         image_raw:tf.io.FixedLenFeature([],tf.string),
14         label:tf.io.FixedLenFeature([],tf.int64)
15     })
16     tf_image=tf.decode_raw(features[image_raw],tf.uint8)#这个其实就是图像的像素模式,之前我们使用矩阵来表示图像
17     tf_image=tf.reshape(tf_image,shape=[224,224,3])#对图像的尺寸进行调整,调整成三通道图像
18     tf_image=tf.cast(tf_image,tf.float32)*(1./255)#对图像进行归一化以便保持和原图像有相同的精度
19     tf_label=tf.cast(features[label],tf.int32)
20     tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#将label转化成用one_hot编码的格式
21     return tf_image,tf_label
22 
23 def batch_test(tfrecords_file):
24     dataset=tf.data.TFRecordDataset(tfrecords_file)
25     dataset=dataset.map(decode_example)
26     dataset=dataset.shuffle(100).batch(4)
27     iterator=tf.compat.v1.data.make_one_shot_iterator(dataset)
28     batch_images,batch_labels=iterator.get_next()
29 
30     init_op=tf.compat.v1.global_variables_initializer()
31     with tf.compat.v1.Session() as sess:
32         sess.run(init_op)
33         coord=tf.train.Coordinator()
34         threads=tf.train.start_queue_runners(coord=coord)
35         for i in range(4):
36             images,labels=sess.run([batch_images,batch_labels])
37             show_image(images[1,:,:,:])
38             print(shape:{},tpye:{},labels:{}.format(images.shape, images.dtype, labels))
39 
40         coord.request_stop()
41         coord.join(threads)
42 
43 if __name__==__main__:
44     tfrecords_file=D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords
45     resize_height=224
46     resize_width=224
47     batch_test(tfrecords_file)

我为了测试,写了batch_test这个函数,因为我想试一试看我做的tfrecords能不能被解析成功,如果你不想测试只想训练,那你直接把images_batch,和labels_batch放到网络中进行训练就可以了,还有一点要注意的,tf.global_variables_initializer()已经被tf.compat.v1.global_variables_initializer()所取代了,我做的时候不知道所以报了一个warning提示,同时tf.Sesssion()已经被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已经被tf.compat.v1.data.make_one_shot_iterator(dataset)  所代替,这些异常要注意,然后我只是将每个batch的第二张图片显示出来了,你也可以显示其他的,但是意义不大,反正只是测试一下解析成功与否,成功了我们就不需要纠结别的了。好啦,就是这样,接下来我会把这些东西放到网络中进行训练,再更新我的学习,就酱。

使用tensorflow中的Dataset来读取制作好的tfrecords文件

标签:charm   序列   queue   序列化   example   制作   shape   atp   创建   

原文地址:https://www.cnblogs.com/daremosiranaihana/p/11444705.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!