码迷,mamicode.com
首页 > 其他好文 > 详细

二月五号博客

时间:2020-02-06 01:06:09      阅读:77      评论:0      收藏:0      [点我收藏+]

标签:cer   string   read   efi   rod   runners   sed   队列   文件   

今天学了 TensorFlow 文件读取操作。以下代码基于 TensorFlow 1.x 的队列式输入管道,运行前需要先 `import os` 和 `import tensorflow as tf`。

一、读取图片文件

def read_picture():
    """
    Read the dog pictures under ``./dog`` through a queue-based input pipeline.

    Every entry of ``./dog`` is put on a filename queue, read whole, decoded as
    a 3-channel JPEG, resized to 200x200 and grouped into batches of 100.
    One batch is evaluated in a session; the symbolic tensors and the
    evaluated values are printed along the way.

    :return: None
    """
    # 1. Build the filename queue.
    # List the directory, then prefix every file name with its directory.
    filename_list = os.listdir("./dog")
    file_list = [os.path.join("./dog/", i) for i in filename_list]
    file_queue = tf.train.string_input_producer(file_list)

    # 2. Read and decode.
    # Read: each dequeued filename yields one (key, value) pair.
    reader = tf.WholeFileReader()
    key, value = reader.read(file_queue)
    print("key:\n", key)
    print("value:\n", value)

    # Decode. Force 3 channels so that the static shape [200, 200, 3] set
    # below also holds for grayscale JPEGs.
    image_decoded = tf.image.decode_jpeg(value, channels=3)
    print("image_decoded:\n", image_decoded)

    # Resize every picture to the same size.
    image_resized = tf.image.resize_images(image_decoded, [200, 200])
    print("image_resized_before:\n", image_resized)
    # Update the static shape; batching needs fully defined shapes.
    image_resized.set_shape([200, 200, 3])
    print("image_resized_after:\n", image_resized)

    # 3. Batching queue.
    image_batch = tf.train.batch([image_resized], batch_size=100, num_threads=2, capacity=100)
    print("image_batch:\n", image_batch)

    # Open the session.
    with tf.Session() as sess:
        # Create the thread coordinator and start the queue-runner threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # Run the graph.
            filename, sample, image, n_image = sess.run([key, value, image_resized, image_batch])
            print("filename:\n", filename)
            print("sample:\n", sample)
            print("image:\n", image)
            print("n_image:\n", n_image)
        finally:
            # Always stop and join the queue-runner threads, even when the
            # run above fails, so they do not outlive the session.
            coord.request_stop()
            coord.join(threads)

    return None

二、读取二进制文件

class Cifar:
    """Reader for the CIFAR-10 binary batches stored in ``./cifar-10-batches-bin``."""

    def __init__(self):
        """Set the image geometry and the derived per-record byte counts."""
        # Image dimensions.
        self.height = 32
        self.width = 32
        self.channel = 3

        # Byte counts: image payload, label, and one whole record.
        self.image = self.height * self.width * self.channel
        self.label = 1
        self.sample = self.image + self.label

    def read_binary(self):
        """
        Read one batch of 100 samples from the ``*bin`` files.

        Each fixed-length record of ``self.sample`` bytes is decoded as uint8,
        split into ``self.label`` label byte(s) followed by ``self.image``
        image bytes, and the image is rearranged from
        [channel, height, width] to [height, width, channel].

        :return: tuple ``(image_value, label_value)`` holding the evaluated
            image batch and label batch
        """
        # 1. Build the filename queue from the entries whose name ends in "bin".
        filename_list = os.listdir("./cifar-10-batches-bin")
        file_list = [
            os.path.join("./cifar-10-batches-bin/", i)
            for i in filename_list
            if i.endswith("bin")
        ]
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read and decode.
        # Read one record of self.sample bytes at a time.
        reader = tf.FixedLengthRecordReader(self.sample)
        # key: file name, value: one sample.
        key, value = reader.read(file_queue)

        # Decode the raw bytes into a uint8 vector.
        image_decoded = tf.decode_raw(value, tf.uint8)
        print("image_decoded:\n", image_decoded)

        # Slice: the label comes first, the image bytes follow it.
        label = tf.slice(image_decoded, [0], [self.label])
        image = tf.slice(image_decoded, [self.label], [self.image])
        print("label:\n", label)
        print("image:\n", image)

        # Reshape the flat image into [channel, height, width].
        image_reshaped = tf.reshape(image, [self.channel, self.height, self.width])
        print("image_reshaped:\n", image_reshaped)

        # Transpose the 3-D array into [height, width, channel].
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # 3. Build the batching queue.
        image_batch, label_batch = tf.train.batch([image_transposed, label], batch_size=100, num_threads=2, capacity=100)

        # Open the session.
        with tf.Session() as sess:

            # Start the queue-runner threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                label_value, image_value = sess.run([label_batch, image_batch])
                print("label_value:\n", label_value)
                print("image:\n", image_value)
            finally:
                # Always stop and join the queue-runner threads, even when
                # the run above fails.
                coord.request_stop()
                coord.join(threads)

        return image_value, label_value

三、读取TFRecords文件

    def read_tfrecords(self):
        """
        Read one batch of 100 samples from ``cifar10.tfrecords`` and print it.

        Each serialized example is parsed into an ``image`` bytes feature and
        a ``label`` int64 feature; the image bytes are decoded as uint8 and
        reshaped to [height, width, channel].

        :return: None
        """
        # 1. Build the filename queue.
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2. Read and decode.
        # Read one serialized record at a time.
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # Parse the example.
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)

        # Decode the image bytes into a uint8 vector.
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)
        # Restore the image shape.
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
        print("image_reshaped:\n", image_reshaped)

        # 3. Build the batching queue.
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        # Open the session.
        with tf.Session() as sess:

            # Start the queue-runner threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                image_value, label_value = sess.run([image_batch, label_batch])
                print("image_value:\n", image_value)
                print("label_value:\n", label_value)
            finally:
                # Release resources: always stop and join the queue-runner
                # threads, even when the run above fails.
                coord.request_stop()
                coord.join(threads)

        return None

 

二月五号博客

标签:cer   string   read   efi   rod   runners   sed   队列   文件   

原文地址:https://www.cnblogs.com/goubb/p/12267256.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!