码迷,mamicode.com
首页 > 其他好文 > 详细

Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型

时间:2019-10-04 09:20:58      阅读:453      评论:0      收藏:0      [点我收藏+]

标签:ESS   方法   类的继承   sequence   自动   ret   callbacks   生成   执行   

 

 最近做Kaggle的图像分类比赛:RSNA Intracranial Hemorrhage Detection (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/overview)以及阅读Yolov3

源码的时候接触到深度学习训练时一个有趣的技巧,那就是构造生成器generator 并且用Keras 的fit_generator来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator是keras的Sequence类的继承类,一般要包含__len__,__getitem__, on_epoch_end等方法,例如下面的批量图片数据生成器:

class DataGenerator(keras.utils.Sequence):
      
      
      def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512), 
                   img_dir, *args, **kwargs):

         """
            self.list_IDs:存放所有需要训练的图片文件名的列表。
            self.labels:记录图片标注的分类信息的pandas.DataFrame数据类型,已经预先给定。
            self.batch_size:每次批量生成,训练的样本大小。
            self.img_size:训练的图片尺寸。
            self.img_dir:图片在电脑中存放的路径。
      
      
         """

          
          self.list_IDs = list_IDs
          self.labels = labels
          self.batch_size = batch_size
          self.img_size = img_size
          self.img_dir = img_dir
          self.on_epoch_end()

      def __len__(self):
          
          """
             返回生成器的长度,也就是总共分批生成数据的次数。
             
          """
          return int(ceil(len(self.list_IDs) / self.batch_size))

     def __getitem__(self, index):
         
         """
            该函数返回每次我们需要的经过处理的数据。
         """
         
         indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
         list_IDs_temp = [self.list_IDs[k] for k in indices]
         X, Y = self.__data_generation(list_IDs_temp)
         return X, Y

     def on_epoch_end(self):
         
         """
            该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。

         """
         self.indices = np.arange(len(self.list_IDs))
         np.random.shuffle(self.indices)

     def __data_generation(self, list_IDs_temp):

        """
           给定文件名,生成数据。
        """
        X = np.empty((self.batch_size, *self.img_size, 1))
        Y = np.empty((self.batch_size, 6), dtype=np.float32)

       for i, ID in enumerate(list_IDs_temp):
       X[i,] = mpimg.imread(self.img_dir+ID+".png")
       Y[i,] = self.labels.loc[ID].values

       return X, Y

 

有了这个生成器,我们就可以用fit_generator 方法进行训练,格式套路如下:

model.fit_generator(generator,

steps_per_epoch=...,

epochs=...,

verbose=...,

callbacks=...,

validation_data=...,

validation_steps=...,

validation_freq=...,

class_weight=None=...,

max_queue_size=...

workers=...,

use_multiprocessing=...,

)

除此以外我们还可以搞批量预测:

model.predict_generator()

Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型

标签:ESS   方法   类的继承   sequence   自动   ret   callbacks   生成   执行   

原文地址:https://www.cnblogs.com/szqfreiburger/p/11621261.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!