码迷,mamicode.com
首页 > 其他好文 > 详细

TensorFlow keras 迁移学习

时间:2020-02-06 01:43:15      阅读:86      评论:0      收藏:0      [点我收藏+]

标签:rom   app   kernel   列表   获取   image   port   ima   cto   

技术图片

 

 技术图片

 

 技术图片

 

 技术图片

 

 技术图片

 

 技术图片

 

数据的读取

 

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator

class TransferModel(object):

    def __init__(self):
        #标准化和数据增强
        self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
        #指定训练集数据和测试集数据目录
        self.train_dir = "./data/train"
        self.test_dir = "./data/test"
        self.image_size = (224,224)
        self.batch_size = 32

    def get_loacl_data(self):
        ‘‘‘
        读取本地的图片数据以及类别
        :return: 
        ‘‘‘
        train_gen = self.train_generator.flow_from_directory(self.train_dir,
                                                 target_size=self.image_size,
                                                 batch_size=self.batch_size,
                                                 class_mode=‘binary‘,
                                                 shuffle=True)
        test_gen = self.test_generator.flow_from_directory(self.test_dir,
                                                           target_size=self.image_size,
                                                           batch_size=self.batch_size,
                                                           class_mode=‘binary‘,
                                                           shuffle=True)


        return train_gen,test_gen

if __name__ == ‘__main__‘:
    tm = TransferModel()
    train_gen,test_gen = tm.get_loacl_data()
    print(train_gen)

 

  迁移学习完整代码

 

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np


class TransferModel(object):

    def __init__(self):

        # 定义训练和测试图片的变化方法,标准化以及数据增强
        self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0)

        # 指定训练数据和测试数据的目录
        self.train_dir = "./data/train"
        self.test_dir = "./data/test"

        # 定义图片训练相关网络参数
        self.image_size = (224, 224)
        self.batch_size = 32

        # 定义迁移学习的基类模型
        # 不包含VGG当中3个全连接层的模型加载并且加载了参数
        # vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
        self.base_model = VGG16(weights=‘imagenet‘, include_top=False)

        self.label_dict = {
            ‘0‘: ‘汽车‘,
            ‘1‘: ‘恐龙‘,
            ‘2‘: ‘大象‘,
            ‘3‘: ‘花‘,
            ‘4‘: ‘马‘
        }

    def get_local_data(self):
        """
        读取本地的图片数据以及类别
        :return: 训练数据和测试数据迭代器
        """
        # 使用flow_from_derectory
        train_gen = self.train_generator.flow_from_directory(self.train_dir,
                                                             target_size=self.image_size,
                                                             batch_size=self.batch_size,
                                                             class_mode=‘binary‘,
                                                             shuffle=True)
        test_gen = self.test_generator.flow_from_directory(self.test_dir,
                                                           target_size=self.image_size,
                                                           batch_size=self.batch_size,
                                                           class_mode=‘binary‘,
                                                           shuffle=True)
        return train_gen, test_gen

    def refine_base_model(self):
        """
        微调VGG结构,5blocks后面+全局平均池化(减少迁移学习的参数数量)+两个全连接层
        :return:
        """
        # 1、获取原notop模型得出
        # [?, ?, ?, 512]
        x = self.base_model.outputs[0]

        # 2、在输出后面增加我们结构
        # [?, ?, ?, 512]---->[?, 1 * 1 * 512]
        x = keras.layers.GlobalAveragePooling2D()(x)

        # 3、定义新的迁移模型
        x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
        y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x)

        # model定义新模型
        # VGG 模型的输入, 输出:y_predict
        transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict)

        return transfer_model

    def freeze_model(self):
        """
        冻结VGG模型(5blocks)
        冻结VGG的多少,根据你的数据量
        :return:
        """
        # self.base_model.layers 获取所有层,返回层的列表
        for layer in self.base_model.layers:
            layer.trainable = False

    def compile(self, model):
        """
        编译模型
        :return:
        """
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.sparse_categorical_crossentropy,
                      metrics=[‘accuracy‘])
        return None

    def fit_generator(self, model, train_gen, test_gen):
        """
        训练模型,model.fit_generator()不是选择model.fit()
        :return:
        """
        # 每一次迭代准确率记录的h5文件
        modelckpt = keras.callbacks.ModelCheckpoint(‘./ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5‘,
                                                     monitor=‘val_acc‘,
                                                     save_weights_only=True,
                                                     save_best_only=True,
                                                     mode=‘auto‘,
                                                     period=1)

        model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt])

        return None

    def predict(self, model):
        """
        预测类别
        :return:
        """

        # 加载模型,transfer_model
        model.load_weights("./ckpt/transfer_02-0.93.h5")

        # 读取图片,处理
        image = load_img("./1.jpg", target_size=(224, 224))
        image.show()
        image = img_to_array(image)
        # print(image.shape)
        # 四维(224, 224, 3)---》(1, 224, 224, 3)
        img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
        # print(img)
        # model.predict()

        # 预测结果进行处理
        image = preprocess_input(img)
        predictions = model.predict(image)
        print(predictions)
        res = np.argmax(predictions, axis=1)
        print("所预测的类别是:",self.label_dict[str(res[0])])


if __name__ == ‘__main__‘:
    tm = TransferModel()
    # 训练
    # train_gen, test_gen = tm.get_local_data()
    # # print(train_gen)
    # # for data in train_gen:
    # #     print(data[0].shape, data[1].shape)
    # # print(tm.base_model.summary())
    # model = tm.refine_base_model()
    # # print(model)
    # tm.freeze_model()
    # tm.compile(model)
    #
    # tm.fit_generator(model, train_gen, test_gen)

    # 测试
    model = tm.refine_base_model()

    tm.predict(model)

 

  

 

TensorFlow keras 迁移学习

标签:rom   app   kernel   列表   获取   image   port   ima   cto   

原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12267402.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!