tf.train.string_input_producer()

Posted: 2018-11-30 16:41:52


Reading data from files

Official description

[Image: screenshot of the official tf.train.string_input_producer documentation]
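As a rough mental model only (plain Python, not TensorFlow code): in TF 1.x, tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, ...) feeds the given strings, typically filenames, into a queue. With shuffle=True the order is shuffled within each epoch; with num_epochs=None it cycles forever; with an integer num_epochs, further reads fail with tf.errors.OutOfRangeError once every string has been produced that many times. The sketch below mimics only that ordering logic. filename_producer, dequeue and the OutOfRangeError class are hypothetical stand-ins, and queue capacity and threading are ignored.

```python
import random
from typing import Iterator, List, Optional


class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def filename_producer(filenames: List[str], num_epochs: Optional[int] = None,
                      shuffle: bool = True, seed: Optional[int] = None) -> Iterator[str]:
    """Yield filenames epoch by epoch (a model of the ordering, not TensorFlow code)."""
    rng = random.Random(seed)
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        order = list(filenames)
        if shuffle:
            rng.shuffle(order)  # shuffle within the current epoch only
        yield from order
        epoch += 1


def dequeue(producer: Iterator[str]) -> str:
    """Return the next filename; raise OutOfRangeError once all epochs are used up."""
    try:
        return next(producer)
    except StopIteration:
        raise OutOfRangeError("epoch limit reached") from None


if __name__ == "__main__":
    q = filename_producer(["file0.csv", "file1.csv"], num_epochs=2, shuffle=False)
    seen = []
    try:
        while True:
            seen.append(dequeue(q))
    except OutOfRangeError:
        print("done:", seen)
    # prints: done: ['file0.csv', 'file1.csv', 'file0.csv', 'file1.csv']
```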

Basic usage

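The example reads CSV files. To read TFRecord files instead, replace tf.TextLineReader with tf.TFRecordReader (and decode the records with tf.parse_single_example rather than tf.decode_csv).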
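The example assumes that file0.csv and file1.csv already exist and that each line holds five integer fields (four features followed by a label), matching record_defaults. Their contents are not given, so the hypothetical helper write_sample_csvs below writes small sample files to try the code with. It writes no header row, because every line is decoded as data, and uses plain "\n" line endings.

```python
import csv
import random
from typing import List


def write_sample_csvs(paths: List[str], rows_per_file: int = 10, seed: int = 0) -> None:
    """Write hypothetical sample data: four integer features plus an integer label per line."""
    rng = random.Random(seed)
    for path in paths:
        with open(path, "w", newline="") as f:
            # No header row, and "\n" line endings, so every line is one data record.
            writer = csv.writer(f, lineterminator="\n")
            for _ in range(rows_per_file):
                features = [rng.randint(0, 9) for _ in range(4)]
                label = rng.randint(0, 1)
                writer.writerow(features + [label])


if __name__ == "__main__":
    write_sample_csvs(["file0.csv", "file1.csv"])
```

Running it once in the working directory creates both files with 10 rows each.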

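import tensorflow as tf

# Queue of filenames; with the defaults (shuffle=True, num_epochs=None) it cycles through them forever in shuffled order.
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

# TextLineReader returns one line per read() call; key identifies the source file and line.
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)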

# Default values, in case of empty columns. Also specifies the type of the decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(12):
        # Retrieve a single instance:
        example, label = sess.run([features, col5])
        print(example, label)

    coord.request_stop()
    coord.join(threads)

Output:

[Image: screenshot of the program output]
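The record_defaults argument in the example does two jobs, as its comment says: each entry provides the value used when a column is empty, and its type ([1] is an integer) sets the type of the decoded column. The toy function below models just that rule in plain Python. decode_csv_line is a hypothetical name, and the sketch ignores quoting, other delimiters and required columns (an empty default), all of which the real tf.decode_csv handles.

```python
from typing import List, Sequence


def decode_csv_line(line: str, record_defaults: Sequence[List]) -> list:
    """Toy model of decode_csv's defaults rule (hypothetical helper, not TensorFlow code)."""
    fields = line.split(",")
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    decoded = []
    for field, default in zip(fields, record_defaults):
        value = default[0]
        if field == "":
            decoded.append(value)               # empty column: use the default value
        else:
            decoded.append(type(value)(field))  # otherwise parse with the default's type
    return decoded


record_defaults = [[1], [1], [1], [1], [1]]
print(decode_csv_line("3,,7,2,0", record_defaults))  # [3, 1, 7, 2, 0]
```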

Combining with batching

import tensorflow as tf

def read_my_file_format(filename_queue):
    # Generic pattern: reader = tf.SomeReader(); here, a line-based CSV reader.
    reader = tf.TextLineReader()
    key, record_string = reader.read(filename_queue)
    # Generic pattern: example, label = tf.some_decoder(record_string); here, decode_csv.
    record_defaults = [[1], [1], [1], [1], [1]]
    col1, col2, col3, col4, col5 = tf.decode_csv(record_string, record_defaults=record_defaults)
    # Generic pattern: processed_example = some_processing(example); here, stack the four feature columns.
    features = tf.stack([col1, col2, col3, col4])
    return features, col5

def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
    example, label = read_my_file_format(filename_queue)
    # min_after_dequeue is how many elements stay in the queue after a dequeue:
    # larger values shuffle better but start more slowly and use more memory.
    # Recommended capacity: min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch

# Batches of 5 examples, reading the file list for 4 epochs.
x, y = input_pipeline(["file0.csv", "file1.csv"], batch_size=5, num_epochs=4)

sess = tf.Session()
# num_epochs creates a local epoch counter, so local variables must be initialized as well.
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        example, label = sess.run([x, y])
        print(example, label)
except tf.errors.OutOfRangeError:
    print("Done training -- epoch limit reached")
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

Output:

[Image: screenshot of the program output]
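Two numbers in input_pipeline can be checked by hand. The capacity follows the rule of thumb in the code comment, min_after_dequeue + (num_threads + margin) * batch_size; since tf.train.shuffle_batch uses num_threads=1 by default, the factor 3 corresponds to a margin of 2. And because num_epochs=4 is set and allow_smaller_final_batch defaults to False, the loop ends with OutOfRangeError after (4 × total rows) // batch_size full batches, with any leftover examples dropped. The helper functions and the row counts below are hypothetical; the sketch only does the arithmetic.

```python
from typing import Sequence


def recommended_capacity(min_after_dequeue: int, batch_size: int,
                         num_threads: int = 1, margin: int = 2) -> int:
    """min_after_dequeue + (num_threads + margin) * batch_size, the rule of thumb in the comment."""
    return min_after_dequeue + (num_threads + margin) * batch_size


def full_batches(rows_per_file: Sequence[int], num_epochs: int, batch_size: int) -> int:
    """Full batches produced before OutOfRangeError when the final partial batch is dropped."""
    total_examples = sum(rows_per_file) * num_epochs
    return total_examples // batch_size


print(recommended_capacity(100, 5))  # 115, i.e. 100 + 3 * 5 as in input_pipeline
print(full_batches([10, 10], num_epochs=4, batch_size=5))  # 16 (hypothetical 10 rows per file)
```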

Original article: https://www.cnblogs.com/helloworld0604/p/10044748.html
