码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow数据加载方式

时间:2018-02-24 13:13:37      阅读:166      评论:0      收藏:0      [点我收藏+]

标签:tensorflow 数据加载

tensorflow当前具有三种读取数据的方式:
1.预加载(preloaded):在构建tensorflow流图时直接定义常量数据,由于数据是直接镶嵌在流图中,所以当数据量很大时将占用大量内存

import tensorflow as tf
a = tf.constant([1,2,3],name=‘input_a‘)
b = tf.constant([4,5,6],name=‘input_b‘)
c = tf.add(a,b,name=‘sums‘)
sess = tf.Session()
x = sess.run(c)
print(x)

2.填充(feeding):将python产生的数据直接填充到后端,这种方式同样存在数据量大时消耗内存的问题,同时数据类型转换也会增加一些开销

import tensorflow as tf
a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
c = tf.add(a,b)
p_a = [1,2,3]
p_b = [4,5,6]
with tf.Session() as sess:
    print(sess.run(c, feed_dict={a:p_a, b:p_b}))

3.从文件读取(reading from file):相较于上面两种,这种方式处理量大的数据具有很大优势。tensorflow在从文件中读取数据时主要分两步:
(1)将数据写入TFRecords二进制文件;

‘‘‘创建转换函数,将数据填入到tf.train.Example协议缓冲区中,同时将缓冲区序列化为字符串,
  再通过tf.python_io.TFRecordWriter写入TFRecords文件‘‘‘
import os
import tensorflow as tf
def int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))
def bytes_feature(data):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
def convert_tfrecords(data, name):
    images = data.images
    labels = data.labels
    num_examples = data.num_examples
    if images.shape[0] != num_examples:
        raise ValueError(u‘图片数量与标签数量不一致,分别为%d和%d‘ %(images.shape[0],num_examples))
    rows = images.shape[1]
    width = images.shape[2]
    depth = images.shape[3]
    filename = os.path.join(os.path.dirname(__file__), name + ‘.tfrecores‘)
    writer = tf.python_io.TFRecoredWriter(filename)
    for i in range(num_examples):
        image_raw = images[i].tostring()
        example = tf.train.Example(features = tf.train.Features(feature = {
                            ‘height‘: int64_feature(rows), ‘width‘:int64_feature(width),
                            ‘depth‘:int64_feature(depth),‘label‘:int64_feature(labels),
                            ‘image_raw‘:bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
    writer.close()

(2)使用队列从二进制文件中读取数据。

tensorflow数据加载方式

标签:tensorflow 数据加载

原文地址:http://blog.51cto.com/abezoo/2072567

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!