tensorflow - the best code pattern for reading and writing data

Posted: 2018-12-26 16:52:21

The best way to combine the pieces is the following code pattern:

import tensorflow as tf

# Create the graph, etc. (input pipeline, model and train_op are defined here).
# string_input_producer(..., num_epochs=N) keeps its epoch counter in a local
# variable, so initialize local variables as well as global ones.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()
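The control flow of this pattern (producer threads fill a bounded queue, the main loop consumes until the input is exhausted, then request_stop() and join()) can be tried without TensorFlow. The sketch below is a plain-Python 3 analogy using only threading and queue; Coordinator, OutOfRange, producer, dequeue and run are hypothetical stand-ins written for this illustration, not TensorFlow's classes or API.

```python
import queue
import threading


class Coordinator(object):
    """Hypothetical stand-in for tf.train.Coordinator: a shared stop flag."""

    def __init__(self):
        self._stop_event = threading.Event()

    def should_stop(self):
        return self._stop_event.is_set()

    def request_stop(self):
        self._stop_event.set()

    def join(self, threads):
        for t in threads:
            t.join()


class OutOfRange(Exception):
    """Stand-in for tf.errors.OutOfRangeError: the input is exhausted."""


_END = object()  # sentinel enqueued once the producer has finished all epochs


def _put(coord, q, item):
    # Wait for room in the bounded queue, but give up if a stop was requested.
    while not coord.should_stop():
        try:
            q.put(item, timeout=0.05)
            return True
        except queue.Full:
            pass
    return False


def producer(coord, q, records, num_epochs):
    # Plays the role of a queue-runner thread started by start_queue_runners().
    for _ in range(num_epochs):
        for record in records:
            if not _put(coord, q, record):
                return
    _put(coord, q, _END)


def dequeue(q):
    item = q.get()
    if item is _END:
        raise OutOfRange()
    return item


def run(records, num_epochs=2):
    coord = Coordinator()
    q = queue.Queue(maxsize=2)  # bounded, like a queue with a capacity
    threads = [threading.Thread(target=producer,
                                args=(coord, q, records, num_epochs))]
    for t in threads:
        t.start()
    consumed = []
    try:
        while not coord.should_stop():
            consumed.append(dequeue(q))  # stands in for sess.run(train_op)
    except OutOfRange:
        print("Done training -- epoch limit reached")
    finally:
        coord.request_stop()  # ask the producer threads to stop
    coord.join(threads)  # wait for them to finish
    return consumed


if __name__ == "__main__":
    print(run(["zhanghua", "liuzhi", "lisi"]))
```

The sentinel plays the part of the epoch limit that makes TensorFlow raise OutOfRangeError, and the finally clause guarantees request_stop() runs even if a step fails, so join() cannot wait on threads that were never told to stop.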
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 15 10:54:53 2018
@author: myhaspl
@email:myhaspl@myhaspl.com
Read a CSV file
"""
import tensorflow as tf
import os

g = tf.Graph()
with g.as_default():
    # Create the filename queue
    fileName = os.path.join(os.getcwd(), "1.csv")
    print(fileName)
    fileNameQueue = tf.train.string_input_producer([fileName])
    # Read records as (key, value) pairs, skipping the header line
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(fileNameQueue)
    # One default per column; the defaults also fix the column types (string, int32, int32)
    recordDefaults = [[""], [1], [1]]
    decoded = tf.decode_csv(value, record_defaults=recordDefaults)
    name, age, source = tf.train.shuffle_batch(decoded, batch_size=2, capacity=2, min_after_dequeue=1)
    # stack gives shape [2, batch_size]; transpose to [batch_size, 2], one row per record
    features = tf.transpose(tf.stack([age, source]))

with tf.Session(graph=g) as sess:
    # Start the queue runners, which begin filling the filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(features))

    coord.request_stop()
    coord.join(threads)

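Output of sess.run(features) (which rows are drawn varies from run to run, because shuffle_batch shuffles):
[[32 99]
 [36 75]]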
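To see why the transpose is needed, here is a plain-Python sketch (no TensorFlow required) of what decode_csv, tf.stack and tf.transpose do to one batch, using the two rows that appear in the output above. decode_csv_line, stack and transpose are hypothetical helpers written for illustration; the sketch assumes, as with record_defaults, that each default sets its column's type and fills in empty fields.

```python
import csv


def decode_csv_line(line, record_defaults):
    """Hypothetical analogue of tf.decode_csv for a single line.

    Each default's type is the column type; an empty field takes the default.
    """
    fields = next(csv.reader([line]))
    if len(fields) != len(record_defaults):
        raise ValueError("expected %d fields, got %d"
                         % (len(record_defaults), len(fields)))
    out = []
    for field, default in zip(fields, record_defaults):
        value = default[0]
        out.append(type(value)(field) if field != "" else value)
    return out


def stack(columns):
    # Like tf.stack([age, source]): shape [num_columns, batch_size].
    return [list(c) for c in columns]


def transpose(matrix):
    # Like tf.transpose on a 2-D input: shape [batch_size, num_columns].
    return [list(row) for row in zip(*matrix)]


record_defaults = [[""], [1], [1]]
batch = [decode_csv_line(l, record_defaults)
         for l in ["zhanghua,32,99", "lisi,36,75"]]
# A batch arrives column by column: (name, age, source).
name, age, source = (list(col) for col in zip(*batch))

stacked = stack([age, source])
print(stacked)             # [[32, 36], [99, 75]]  one row per column
print(transpose(stacked))  # [[32, 99], [36, 75]]  one row per record
```

Without the transpose each row holds one feature across the batch; after it, each row is one record's (age, source) pair, which is the layout most training code expects.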

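#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 15 10:54:53 2018
@author: myhaspl
@email:myhaspl@myhaspl.com
Read a CSV file
"""
import tensorflow as tf
import os

g = tf.Graph()
with g.as_default():
    # Create the filename queue
    fileName = os.path.join(os.getcwd(), "1.csv")
    fileNameQueue = tf.train.string_input_producer([fileName])
    # Read records as (key, value) pairs, skipping the header line
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(fileNameQueue)
    recordDefaults = [[""], [1], [1]]
    decoded = tf.decode_csv(value, record_defaults=recordDefaults)
    name, age, source = tf.train.shuffle_batch(decoded, batch_size=2, capacity=2, min_after_dequeue=1)
    # No transpose here: shape is [2, batch_size], one row per column
    features = tf.stack([age, source])

with tf.Session(graph=g) as sess:
    # Start the queue runners, which begin filling the filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(features))
    # Each sess.run() performs a new read (the queue-runner threads also read in
    # the background), so key and value below may come from different lines;
    # sess.run([key, value]) fetches a matching pair.
    print(sess.run(key))    # key: "file name:line number"
    print(sess.run(value))  # value: the text of one line

    coord.request_stop()
    coord.join(threads)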

Output (features without the transpose, then key, then value):
[[32 36]
 [99 75]]
/Users/xxxxx/Documents/AIstudy/tf/1.csv:3
lisi,36,75
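Because sess.run(key) and sess.run(value) are two separate runs, each one executes the read op again, so the printed key and value need not describe the same line. The plain-Python sketch below models this with a generator; text_line_reader, separate_fetches and joint_fetch are hypothetical helpers, and the key format "filename:line_number" with 1-based line numbers that count the skipped header is an assumption about TextLineReader made for this illustration.

```python
import io


def text_line_reader(filename, f, skip_header_lines=0):
    """Hypothetical analogue of tf.TextLineReader: yields (key, value) pairs.

    Assumption: key is "filename:line_number", 1-based, counting skipped lines.
    """
    for line_number, line in enumerate(f, start=1):
        if line_number <= skip_header_lines:
            continue
        yield "%s:%d" % (filename, line_number), line.rstrip("\n")


CSV_TEXT = "name,age,source\nzhanghua,32,99\nliuzhi,29,69\nlisi,36,75\n"


def separate_fetches():
    # Like sess.run(key) followed by sess.run(value): two reads, two lines.
    reader = text_line_reader("1.csv", io.StringIO(CSV_TEXT), skip_header_lines=1)
    key = next(reader)[0]
    value = next(reader)[1]
    return key, value


def joint_fetch():
    # Like sess.run([key, value]): one read, so key and value match.
    reader = text_line_reader("1.csv", io.StringIO(CSV_TEXT), skip_header_lines=1)
    key, value = next(reader)
    return key, value


print(separate_fetches())  # ('1.csv:2', 'liuzhi,29,69')
print(joint_fetch())       # ('1.csv:2', 'zhanghua,32,99')
```

In the real script the queue-runner threads also pull lines for shuffle_batch, so which lines the main thread sees is not deterministic; only a joint fetch is guaranteed to pair a key with its own value.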
$ cat 1.csv
name,age,source
zhanghua,32,99
liuzhi,29,69
lisi,36,75
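The three data rows of 1.csv are few enough to trace what shuffle_batch(batch_size=2, capacity=2, min_after_dequeue=1) does with them. Below is a single-threaded, plain-Python approximation; shuffle_batches is a hypothetical helper, and its details (one-at-a-time enqueue, an eager consumer, dropping a final partial batch) are simplifying assumptions rather than TensorFlow's exact scheduling. The idea it illustrates is that min_after_dequeue sets how many buffered items each random pick chooses from, so a larger value mixes better at the cost of memory and a slower start.

```python
import random


def shuffle_batches(records, batch_size, capacity, min_after_dequeue, seed=None):
    """Single-threaded sketch of the buffering idea behind tf.train.shuffle_batch.

    Simplifying assumptions: records are enqueued one at a time in input order
    (the buffer never exceeds `capacity`), and an eager consumer removes a
    random item whenever more than `min_after_dequeue` items are buffered.
    After the input ends the buffer is drained; a final partial batch is dropped.
    """
    if capacity <= min_after_dequeue:
        raise ValueError("capacity must be larger than min_after_dequeue")
    rng = random.Random(seed)
    buffer, batch, batches = [], [], []

    def dequeue_one():
        # Remove a uniformly random buffered item and add it to the batch.
        batch.append(buffer.pop(rng.randrange(len(buffer))))
        if len(batch) == batch_size:
            batches.append(list(batch))
            del batch[:]

    for record in records:
        buffer.append(record)  # enqueue
        while len(buffer) > min_after_dequeue:
            dequeue_one()
    while buffer:  # input exhausted: drain what is left
        dequeue_one()
    return batches


rows = ["zhanghua,32,99", "liuzhi,29,69", "lisi,36,75"]
# Same settings as the scripts above.
for b in shuffle_batches(rows, batch_size=2, capacity=2, min_after_dequeue=1, seed=0):
    print(b)
```

With min_after_dequeue=1 each pick is between just two buffered rows, so the mixing in the scripts above is weak. Also, since string_input_producer cycles through the file indefinitely when num_epochs is not set, the real pipeline never runs out of rows the way this single pass does.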

Original article: http://blog.51cto.com/13959448/2335460