码迷,mamicode.com
首页 > 其他好文 > 详细

TFRecords文件的生成和读取(样例实现)

时间:2018-04-01 11:55:27      阅读:5858      评论:0      收藏:0      [点我收藏+]

标签:session   []   读取   rds   需要   信息   string   特征   max   

参考:https://blog.csdn.net/u012222949/article/details/72875281

参考:https://blog.csdn.net/chengshuhao1991/article/details/78656724

tfrecords文件的存储:

将其他数据存储为tfrecord文件的时候,需要进行两个步骤:

建立tfrecord存储器

构造每个样本的Example模块

1、构建tfrecord存储器

实现建立存储器的函数为:

tf.python_io.TFRecordWriter(path)  
#写入tfrecord文件
#path为tfrecord的存储路径

2、构造每个样本的example模块

Example协议块的规则如下:

message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

其中实现的几个函数如下所示:

tf.train.Example(features = None)  
#用于写入tfrecords文件
#features : tf.train.Features类型的特征实例
#返回example协议格式块
tf.train.Features(feature = None)
#用于构造每个样本的信息键值对
#feature : 字典数据,key为要保存的名字,value为tf.train.Feature实例
#返回Features类型
tf.train.Feature(**options) 
#options可选的三种数据格式:
bytes_list = tf.train.BytesList(value = [Bytes])
int64_list = tf.train.Int64List(value = [Value])
float_list = tf.trian.FloatList(value = [Value])

最终将图片数据转换成tfrecords的例子,即对每个样本都作如下处理:

example = tf.train.Example(feature = tf.train.Features(feature= {"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image(bytes)]))
,"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label(int)]))}))

例1、将图片文件转换成tfrecord文件(具体代码实现):

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd
def get_label_from_filename(filename):
    return 1
filenames = tf.train.match_filenames_once(C:/Users/1/Desktop/3/*.jpg)
writer = tf.python_io.TFRecordWriter(C:/Users/1/Desktop/png_train.tfrecords)
with tf.Session() as sess:        #使用match_filenames_once函数需要用tf.local_variables_initializer()函数来实现变量的初始化
    sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
    filenames=(sess.run(filenames))      
print(filenames)        
#获取的字符串为前面带b:bytes的字符串,类似于字符串前带u:unicode的字符串
#其中从字符串转化成unicode编码的过称为:str.decode(‘utf-8‘),从unicode转化成字符串为:str.encode(‘utf-8‘),因此对如下做同样操作
for filename in filenames: img=mpimg.imread(filename.decode(utf-8)) print("{} shape is {}".format(filename, img.shape)) img_raw = img.tostring() label = get_label_from_filename(filename) example = tf.train.Example( features=tf.train.Features( feature={ "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])), "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) } ) ) writer.write(record=example.SerializeToString()) writer.close()

 

 

 

 

 

 

glob包的介绍:

用于获取所有匹配的文件路径列表

import glob
glob.glob("/home/zikong/doc/*.doc")
#返回结果如下:
/home/zikong/doc/file1.doc     /home/zikong/doc/file2.doc

例2、tfrecord文件的生成:

from random import shuffle  
import numpy as np  
import glob  
import tensorflow as tf  
import cv2  
import sys  
import os  
os.environ[
TF_CPP_MIN_LOG_LEVEL] = 2 shuffle_data = True image_path = /path/to/image/*.jpg # 取得该路径下所有图片的路径,type(addrs)= list addrs = glob.glob(image_path) # 标签数据的获得具体情况具体分析,type(labels)= list labels = ... # 这里是打乱数据的顺序 if shuffle_data: c = list(zip(addrs, labels)) #将两列元素进行组合 shuffle(c) #random包的shuffle函数进行打乱处理 addrs, labels = zip(*c) #将组合后的元素再进行拆分 # 按需分割数据集 train_addrs = addrs[0:int(0.7*len(addrs))] train_labels = labels[0:int(0.7*len(labels))] val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] test_addrs = addrs[int(0.9*len(addrs)):] test_labels = labels[int(0.9*len(labels)):] # 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 def load_image(addr): # A function to Load image img = cv2.imread(addr) img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 这里/255是为了将像素值归一化到[0,1] img = img / 255. img = img.astype(np.float32) return img # 将数据转化成对应的属性 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value): return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) # 下面这段就开始把数据写入TFRecods文件 train_filename = /path/to/train.tfrecords # 输出文件地址 # 创建一个writer来写 TFRecords 文件 writer = tf.python_io.TFRecordWriter(train_filename) for i in range(len(train_addrs)): # 这是写入操作可视化处理 if not i % 1000: print(Train data: {}/{}.format(i, len(train_addrs))) sys.stdout.flush() # 加载图片 img = load_image(train_addrs[i]) label = train_labels[i] # 创建一个属性(feature) feature = {train/label: _int64_feature(label), train/image: _bytes_feature(tf.compat.as_bytes(img.tostring()))} # 创建一个 example protocol buffer example = tf.train.Example(features=tf.train.Features(feature=feature)) # 将上面的example protocol buffer写入文件 writer.write(example.SerializeToString()) writer.close() sys.stdout.flush()

 例3、从MNIST输入数据转化为TFRecord的格式,以及将如何读取TFRecords文件中的数据

从MNIST输入数据转化为TFRecord格式:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
mnist=input_data.read_data_sets(C:/Users/1/Desktop/data,dtype=tf.uint8,one_hot=True)
images=mnist.train.images
labels=mnist.train.labels
pixels=images.shape[1]
num_examples=mnist.train.num_examples
#输出TFRecord文件地址
filename=C:/Users/1/Desktop/data/output.tfrecords
writer=tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
    image_raw=images[index].tostring()
    example=tf.train.Example(features=tf.train.Features(feature={pixels:_int64_feature(pixels),label:_int64_feature(np.argmax(labels[index])),image_raw:_bytes_feature(image_raw)}))
    writer.write(example.SerializeToString())
writer.close()

以上程序部分将MNIST数据集中所有的训练数据存储到TFRecord文件中,当数据量较大时,也可以将数据写入多个TFRecord文件

以下程序给出了如何读取TFRecord文件中的数据:

import tensorflow as tf
reader=tf.TFRecordReader()
filename_queue=tf.train.string_input_producer([C:/Users/1/Desktop/data/output.tfrecords])  
_,serialized_example=reader.read(filename_queue)  #从文件中读取一个样例
features=tf.parse_single_example(serialized_example,features={image_raw:tf.FixedLenFeature([],tf.string),pixels:tf.FixedLenFeature([],tf.int64),label:tf.FixedLenFeature([],tf.int64)})
#tf.FixedLenFeature()函数解析得到的结果是一个Tensor
images=tf.decode_raw(features[image_raw],tf.uint8)
labels=tf.cast(features[label],tf.int32)   #将目标变量转换成tf.int32格式
pixels=tf.cast(features[pixels],tf.int32)
#tf.decode_raw可以将字符串解析成图像对应的像素数组
sess=tf.Session()
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(10):
    image,label,pixel=sess.run([images,labels,pixels])

 

TFRecords文件的生成和读取(样例实现)

标签:session   []   读取   rds   需要   信息   string   特征   max   

原文地址:https://www.cnblogs.com/xiaochouk/p/8685909.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!