标签:type style poc 文件 system plt 使用 append sage
1. 提升 PyTorch 的训练速度（开启 cuDNN benchmark，在输入尺寸固定时加速效果明显）
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Let cuDNN benchmark its convolution algorithms and cache the fastest one;
# per the PyTorch docs this pays off when input shapes stay fixed.
torch.backends.cudnn.benchmark = True
2. 若记录文件已存在，先给它加上带时间戳的 .bak 后缀进行备份，避免被新的运行直接覆盖
# from co-teaching train code
import datetime
import os

# Log file for this run, e.g. <save_dir>/<model_str>_<optimizer>.txt
txtfile = save_dir + "/" + model_str + "_%s.txt" % str(args.optimizer)
# If a log with the same name already exists, rename it to
# "<txtfile>.bak-<timestamp>" instead of silently overwriting it.
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
    # os.rename instead of os.system('mv ...'): no shell is spawned, so paths
    # containing spaces or shell metacharacters are handled correctly.
    # NOTE: ':' in the timestamp is not a legal character in Windows file names.
    os.rename(txtfile, txtfile + ".bak-%s" % nowTime)  # backup
3. 计算 Accuracy 的函数返回一个 list（每个 k 对应一项）；调用时用元组解包直接取出各个值，而不是拿到整个 list
# from co-teaching code but MixMatch_pytorch code also has it
def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        logit: model scores; top-k is taken along dim 1, so this is assumed
            to be (batch_size, num_classes) — confirm at the call site.
        target: ground-truth class indices, one per sample (batch_size
            elements).
        topk: tuple of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``. Each
        tensor holds the percentage of samples whose target is among the k
        highest-scoring classes. Unpack the list directly at the call site
        rather than indexing it::

            prec1, = accuracy(logit, labels, topk=(1,))  # trailing comma unpacks
            prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
    """
    # detach(): this is a metric, so building an autograd graph is wasted work.
    # softmax is monotonic and does not change the ranking; it is kept only to
    # match the original implementation.
    output = F.softmax(logit.detach(), dim=1)
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    # After the transpose, row i holds every sample's (i+1)-th best class.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # reshape, not view: `correct` inherits the transposed, non-contiguous
        # layout of `pred`, and view(-1) raises a RuntimeError on it for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
4. 善于利用 logger 文件来记录每一个 epoch 的实验值
# from Pytorch_MixMatch code
class Logger(object):
    """Save training process to log file with simple plot function.

    The log is a tab-separated text file. It starts with a header line of
    column names, followed by one line per ``append`` call with every value
    formatted as ``{:.4f}``. Each line, including the header, ends with a
    trailing tab. Values are also kept in memory in ``self.numbers``
    (column name -> list of values) so they can be plotted.
    """

    def __init__(self, fpath, title=None, resume=False):
        """Open, or resume, the log file at ``fpath``.

        Args:
            fpath: path of the log file. If None, no file is opened and rows
                are only kept in memory.
            title: prefix used in the plot legend.
            resume: if true, reload the header and previously logged values
                from ``fpath`` and append new rows to it. Otherwise the file
                is created or truncated.
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                self._load(fpath)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def _load(self, fpath):
        """Read the column names and logged values of an existing log file."""
        with open(fpath, 'r') as f:
            self.names = f.readline().rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in f:
                line = line.rstrip()
                if not line:
                    continue  # tolerate blank lines, e.g. at the end of the file
                # Parse as float so reloaded history has the same numeric type
                # as freshly appended values; plain strings would make plot()
                # draw a categorical axis.
                for i, value in enumerate(line.split('\t')):
                    self.numbers[self.names[i]].append(float(value))

    def set_names(self, names):
        """Declare the column names and write them as the header line.

        This resets ``self.numbers``. When resuming, the header is already in
        the file, so callers normally skip this call. Calling it anyway writes
        a second header line and discards the reloaded values.
        """
        self.names = names
        self.numbers = {name: [] for name in names}
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in names) + '\n')
            self.file.flush()

    def append(self, numbers):
        """Record one row of values, given in the same order as ``self.names``."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so a bad value leaves neither the file
        # nor the in-memory history half-updated.
        line = ''.join('{0:.4f}\t'.format(num) for num in numbers) + '\n'
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)
        if self.file is not None:
            self.file.write(line)
            self.file.flush()

    def plot(self, names=None):
        """Plot the selected columns (all by default) against their row index."""
        names = self.names if names is None else names
        for name in names:
            values = np.asarray(self.numbers[name])
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying log file, if one was opened."""
        if self.file is not None:
            self.file.close()
# usage
logger = Logger(new_folder + '/log_for_%s_WebVision1M.txt' % data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
    # ... train for one epoch and evaluate here, producing val_acc and
    # val_acc_ImageNet, then log one row per epoch ...
    logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()
5. 利用 argparse 命令行参数解析来重构代码：通过不同参数适配不同数据集、不同优化方式和不同实验设置（setting），避免维护多份高度冗余的重复代码
# 待续
标签:type style poc 文件 system plt 使用 append sage
原文地址:https://www.cnblogs.com/Gelthin2017/p/12148011.html