
Reading CSV Data in Batches with TensorFlow

Posted: 2017-05-12 11:35:49


Straight to the code:

# -*- coding:utf-8 -*-
import tensorflow as tf

def read_data(file_queue):
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(file_queue)
    # The type of each default sets the column type (int Id, four float features, string Species)
    defaults = [[0], [0.], [0.], [0.], [0.], ['']]
    Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species = tf.decode_csv(value, defaults)

    # Because this is the Iris dataset, the Species string (y) must be mapped to an integer class id;
    # any unexpected value falls through to the default branch and becomes -1
    preprocess_op = tf.case({
        tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0),
        tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1),
        tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2),
    }, lambda: tf.constant(-1), exclusive=True)

    return tf.stack([SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm]), preprocess_op

def create_pipeline(filename, batch_size, num_epochs=None):
    file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
    example, label = read_data(file_queue)

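    # shuffle_batch draws random examples from a buffer that keeps at least
    # min_after_dequeue elements after each dequeue; capacity bounds the buffer size
    min_after_dequeue = 1000
    capacity = min_after_dequeue + batch_size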
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue
    )

    return example_batch, label_batch

x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=1000)
# The test pipeline uses num_epochs=None, so it cycles through the file indefinitely
x_test, y_test = create_pipeline('Iris-test.csv', 60)

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()  # required: num_epochs is tracked in a local variable
with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)

    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # Fetch training batches until the num_epochs limit is reached
    try:
        while not coord.should_stop():
            example, label = sess.run([x_train_batch, y_train_batch])
            print(example)
            print(label)
    except tf.errors.OutOfRangeError:
        # Raised once all num_epochs passes over the training file have been consumed
        print('Done reading')
    finally:
        coord.request_stop()

    coord.join(threads)
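In create_pipeline, num_epochs makes the filename queue signal the end of input after the file has been read that many times, and shuffle_batch serves random examples from a buffer that keeps at least min_after_dequeue elements after each dequeue. The plain-Python sketch below models that behavior so it can be tried without TensorFlow. The function names epochs and shuffle_batches are hypothetical, and this is a simplified model of the idea, not TensorFlow's actual queue implementation.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def epochs(records: List[T], num_epochs: Optional[int]) -> Iterator[T]:
    """Yield the records num_epochs times (forever if None), like num_epochs in string_input_producer."""
    if not records:
        return
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        yield from records
        epoch += 1


def shuffle_batches(stream: Iterable[T], batch_size: int, min_after_dequeue: int,
                    seed: Optional[int] = None) -> Iterator[List[T]]:
    """Simplified model of shuffle_batch: draw random batches from a bounded buffer."""
    rng = random.Random(seed)
    capacity = min_after_dequeue + batch_size
    buffer: List[T] = []
    source = iter(stream)
    exhausted = False
    while True:
        # Top the buffer up to capacity, so that while input remains at least
        # min_after_dequeue elements are still buffered after each batch is removed.
        while not exhausted and len(buffer) < capacity:
            try:
                buffer.append(next(source))
            except StopIteration:
                exhausted = True
        # Once input is exhausted, keep draining; a final partial batch is dropped,
        # like shuffle_batch with allow_smaller_final_batch=False.
        if len(buffer) < batch_size:
            return
        batch = []
        for _ in range(batch_size):
            # Swap a random element to the end and pop it: O(1) random removal.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
        yield batch


if __name__ == "__main__":
    rows = list(range(150))  # stand-in for 150 parsed CSV rows
    batches = list(shuffle_batches(epochs(rows, 2), batch_size=50, min_after_dequeue=100, seed=0))
    print(len(batches))  # → 6: 2 epochs x 150 rows / 50 per batch
```

By the same counting, num_epochs=1000 with batch_size=50 on a training file of N data rows provides 1000·N records, which should come out as exactly 20·N full batches before the loop above hits OutOfRangeError, since 1000·N is divisible by 50 and no partial batch is left over.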

 

The data is the Iris dataset, which you can download on your own. Here is a sample of the format:

Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
21,5.4,3.4,1.7,0.2,Iris-setosa
22,5.1,3.7,1.5,0.4,Iris-setosa
23,4.6,3.6,1.0,0.2,Iris-setosa
24,5.1,3.3,1.7,0.5,Iris-setosa
25,4.8,3.4,1.9,0.2,Iris-setosa
26,5.0,3.0,1.6,0.2,Iris-setosa
27,5.0,3.4,1.6,0.4,Iris-setosa
28,5.2,3.5,1.5,0.2,Iris-setosa
29,5.2,3.4,1.4,0.2,Iris-setosa
30,4.7,3.2,1.6,0.2,Iris-setosa
31,4.8,3.1,1.6,0.2,Iris-setosa
32,5.4,3.4,1.5,0.4,Iris-setosa
33,5.2,4.1,1.5,0.1,Iris-setosa
34,5.5,4.2,1.4,0.2,Iris-setosa
35,4.9,3.1,1.5,0.1,Iris-setosa
36,5.0,3.2,1.2,0.2,Iris-setosa
37,5.5,3.5,1.3,0.2,Iris-setosa
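Each of these lines goes through read_data: decode_csv splits it using record_defaults (whose types set the column types and whose values fill empty fields), the Id column is dropped, and tf.case maps Species to 0, 1 or 2, with -1 as the fallback. Below is a stdlib-only sketch of that per-line decoding, which can be handy for checking a CSV file before feeding it to TensorFlow. The names decode_line, parse_csv, DEFAULTS and LABELS are hypothetical, and the code only mimics the behavior described above rather than reproducing decode_csv exactly.

```python
import csv
from typing import List, Tuple

# Mirrors the defaults in read_data: the type of each default sets the column type,
# and the default value is used when a field is empty.
DEFAULTS = [0, 0.0, 0.0, 0.0, 0.0, ""]
LABELS = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}


def decode_line(line: str) -> Tuple[List[float], int]:
    fields = next(csv.reader([line]))
    if len(fields) != len(DEFAULTS):
        raise ValueError(f"expected {len(DEFAULTS)} fields, got {len(fields)}")
    values = [type(d)(f) if f != "" else d for f, d in zip(fields, DEFAULTS)]
    features = values[1:5]  # values[0] is Id, which read_data also drops
    label = LABELS.get(values[5], -1)  # -1 matches the default branch of tf.case
    return features, label


def parse_csv(text: str, skip_header_lines: int = 1) -> List[Tuple[List[float], int]]:
    # Like TextLineReader(skip_header_lines=1): skip the header, then decode each row.
    lines = text.splitlines()[skip_header_lines:]
    return [decode_line(line) for line in lines if line.strip()]
```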

 


Original article: http://www.cnblogs.com/hunttown/p/6844477.html
