TensorFlow 2.0 new features: a small exercise

Date: 2019-06-04 22:21:53

This is a practice run on the Kaggle Google Street View character classification project using TF 2.0, as a way to get familiar with the new features in TF 2.0.

The dataset downloaded from Kaggle looks like this:

[image] [image]

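All training images are in the train folder, and the labels are in the trainLabels.csv file. The label file has the following format:

[image]

Each row pairs a label with the name of its image.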

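Start by preprocessing the data. There are 61 classes in total, drawn from a-z, A-Z, 0-9 and so on (note that these three ranges alone span 62 characters, so check len(labels) on your copy of the data). The code is as follows: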

from __future__ import absolute_import, division, print_function, unicode_literals
import csv
import pathlib

import numpy as np
import tensorflow as tf
from tensorflow import keras

keras.backend.clear_session()
csv_filepath = 'E:\\work\\Kaggle\\street-view-getting-started-with-julia\\trainLabels.csv'
data_root_path = 'E:\\work\\Kaggle\\street-view-getting-started-with-julia\\train'
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Read the (image name, class) rows and drop the header row.
with open(csv_filepath, 'r') as f:
    label_container = list(csv.reader(f))[1:]

# Sorted unique class names, then a class name -> integer index lookup.
labels = np.sort(list({row[1] for row in label_container}))
labels_to_index = dict((name, index) for index, name in enumerate(labels))

# Image name (file name without extension) -> class name.
name_to_label = {row[0]: row[1] for row in label_container}

data_root = pathlib.Path(data_root_path)
all_image_paths = []
all_image_labels = []

# Fill both lists in the same loop so that paths and labels stay aligned.
# Sorting only makes the order repeatable between runs.
for item in sorted(data_root.iterdir()):
    name = item.stem  # same as item.name[:-4] for a three-letter extension
    if name in name_to_label:
        all_image_paths.append(str(item))
        all_image_labels.append(labels_to_index[name_to_label[name]])
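The label-to-index step can be tried on its own without the dataset. The following is a minimal sketch using only the standard library. It assumes the label file has a header row followed by rows of image name and class. The sample rows, the `ID,Class` header and the helper name `build_label_index` are made up for illustration.

```python
import csv
import io

# Made-up rows; the "ID,Class" header and layout are assumptions.
SAMPLE_CSV = "ID,Class\n1,n\n2,8\n3,T\n4,n\n"


def build_label_index(csv_text):
    """Return (class name -> index, image name -> index) from CSV text."""
    rows = list(csv.reader(io.StringIO(csv_text)))[1:]  # drop the header row
    class_names = sorted({row[1] for row in rows})  # unique names, fixed order
    name_to_index = {name: i for i, name in enumerate(class_names)}
    id_to_index = {row[0]: name_to_index[row[1]] for row in rows}
    return name_to_index, id_to_index


name_to_index, id_to_index = build_label_index(SAMPLE_CSV)
print(name_to_index)
print(id_to_index)
```

Sorting the class names before enumerating them keeps the integer assigned to each class the same from run to run, which matters once a trained model's outputs have to be mapped back to characters.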

The resulting all_image_paths and all_image_labels contain the following:

[image]

Because the test set has no labels, split the training set into three parts and wrap each part in a tf.data.Dataset:

train_img_path = all_image_paths[:4000]
val_img_path = all_image_paths[4000:5000]
test_img_path = all_image_paths[5000:]

train_img_labels = all_image_labels[:4000]
val_img_labels = all_image_labels[4000:5000]
test_img_labels = all_image_labels[5000:]


raw_train_ds = tf.data.Dataset.from_tensor_slices((train_img_path, train_img_labels))
raw_val_ds = tf.data.Dataset.from_tensor_slices((val_img_path, val_img_labels))
raw_test_ds = tf.data.Dataset.from_tensor_slices((test_img_path, test_img_labels))
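The slices above take the images in listing order, so what ends up in each part depends on how the files happen to be ordered. One option is to shuffle paths and labels together before cutting. This is a minimal standard-library sketch; `shuffled_split`, its parameters and the tiny stand-in data are hypothetical and not part of the original code.

```python
import random


def shuffled_split(paths, labels, n_train, n_val, seed=0):
    """Shuffle (path, label) pairs together, then cut into train/val/test."""
    if len(paths) != len(labels):
        raise ValueError("paths and labels must have the same length")
    pairs = list(zip(paths, labels))
    random.Random(seed).shuffle(pairs)  # fixed seed keeps the split repeatable
    train = pairs[:n_train]
    val = pairs[n_train:n_train + n_val]
    test = pairs[n_train + n_val:]
    return train, val, test


# Tiny stand-in data; the slices in the post use 4000 / 1000 / the remainder.
paths = ["img_%d.Bmp" % i for i in range(10)]
labels = [i % 3 for i in range(10)]
train, val, test = shuffled_split(paths, labels, n_train=6, n_val=2)
print(len(train), len(val), len(test))
```

Each part can be turned back into two parallel lists with `train_paths, train_labels = map(list, zip(*train))` before being passed to `tf.data.Dataset.from_tensor_slices`.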

Scale the images and apply data augmentation to them, so that the model achieves translation invariance.
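The idea behind this step can be sketched in NumPy: scale pixel values to [0, 1], then pad the image and cut a randomly placed window of the original size, which shifts the content by a few pixels. This is only an illustration under assumptions, not the code from the original post. The function names, the 20x20 image size and `max_shift` are all hypothetical.

```python
import numpy as np


def scale_pixels(image):
    """Convert uint8 pixels in [0, 255] to float32 values in [0, 1]."""
    return image.astype(np.float32) / 255.0


def random_translate(image, max_shift, rng):
    """Shift an (H, W, C) image by up to max_shift pixels, zero-filling the border."""
    h, w = image.shape[:2]
    pad = ((max_shift, max_shift), (max_shift, max_shift), (0, 0))
    padded = np.pad(image, pad, mode="constant")
    # Pick the top-left corner of an (H, W) window inside the padded image.
    top = rng.integers(0, 2 * max_shift + 1)
    left = rng.integers(0, 2 * max_shift + 1)
    return padded[top:top + h, left:left + w]


rng = np.random.default_rng(0)
image = np.full((20, 20, 3), 255, dtype=np.uint8)  # stand-in for a loaded image
augmented = random_translate(scale_pixels(image), max_shift=2, rng=rng)
print(augmented.shape, augmented.dtype)
```

In a tf.data pipeline the same two steps would normally be written with TensorFlow ops inside a function passed to `Dataset.map`, so that they run per element as the data is read.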

 

Original post: https://www.cnblogs.com/ChrisInsistPy/p/10976509.html
