标签:bat 参考 tps taf epo range shape cpu 不同的
虽然已经走在 torch boy 的路上了, 还是把碰到的这个坑给记录一下
model.fit_generator 一般参数的配置参考官方文档就好,其中 generator, workers, use_multiprocessing 的使用有一些坑存在。
此时 generator 用一个普通的 generator去提供数据即可,类似官方提供的这种
def generate_arrays_from_file(path):
while True:
with open(path) as f:
for line in f:
# create numpy arrays of input data
# and labels, from each line in the file
x1, x2, y = process_line(line)
yield ({'input_1': x1, 'input_2': x2}, {'output': y})
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
steps_per_epoch=10000, epochs=10)
这时依然用一个 generator function 来做 generator在拟合的时候便会报错如下:
PicklingError: Can't pickle <function generator_queue.<locals>.data_generator_task at
且当 use_multiprocessing=True 时,如果你使用的是 generator function, 代码会把你的数据copy几份分给不同的worker去处理,但我们希望的是把一份数据平均分拆成几份给多个worker去处理。
怎么解决上面两个问题? keras.utils.Sequence 可以做到
很简单,继承 keras.utils.Sequence 这个类,重写自己的 len(), getitem 即可。
class SequenceData(Sequence):
def __init__(self, filePaths, batch_size):
self.filePaths = filePaths[:100].copy()
self.batch_size = batch_size
self.Y = self.getY()
def __len__(self):
return len(self.Y) // self.batch_size
def __getitem__(self, index):
batch_X = np.zeros((self.batch_size,) + IMG_DIMS, dtype='float32')
batch_Y_ = self.Y[index*self.batch_size: (index+1)*self.batch_size].copy()
batch_Y_.reset_index(drop=True, inplace=True)
assert batch_Y_.shape[0] == self.batch_size
for index, rows in batch_Y_.iterrows():
try:
img = _load_img(rows['path'])
batch_X[index, :, :, :] = img.copy()
batch_Y_.loc[index, 'valid'] = 1
except:
batch_Y_.loc[index, 'valid'] = 0
traceback.print_exc()
batch_Y = to_categorical(batch_Y_['label'], classes_num)
return batch_X, batch_Y
def __iter__(self):
for item in (self[i] for i in range(len(self))):
yield item
def getY(self):
Y = pd.DataFrame(self.filePaths, columns=['path'])
Y['class'] = Y['path'].apply(lambda x: path2class(x))
Y['label'] = Y['class'].apply(lambda x: class2label[x])
Y_biaoge = Y[Y['class']=='biaoge'].copy()
Y = Y.append(Y_biaoge)
Y = Y.append(Y_biaoge)
Y = Y.append(Y_biaoge)
Y = Y.sample(frac=1).reset_index(drop=True)
return Y
可能数据量过小,并行的效果不是太明显。
数据读取方式 | workers | use_multiprocessing | 耗时/s |
---|---|---|---|
内存读取 | 0 | True | 1797 |
keras.utils.Sequence | 0 | False | 1475 |
keras.utils.Sequence | 4 | True |
参考:
标签:bat 参考 tps taf epo range shape cpu 不同的
原文地址:https://www.cnblogs.com/Fosen/p/11953468.html