标签：with open, pytorch, https, 参数, str, __init__, else, import
image_test.py:
import argparse
import numpy as np
import sys
import os
import csv
from imagenet_test_base import TestKit
import torch


class TestTorch(TestKit):
    """Evaluate a converted PyTorch ImageNet model with the shared TestKit harness.

    ``TestKit`` (from ``imagenet_test_base``, not shown here) is relied on for
    ``self.args`` (the parsed command line), ``self.truth``, ``self.MainModel``
    and the framework-independent ``preprocess`` / ``print_result`` /
    ``print_intermediate_result`` helpers.
    """

    # Directory that the entries of the image list file are resolved against.
    DEFAULT_IMAGE_DIR = "../data/imagenet/small_imagenet/"

    def __init__(self):
        super(TestTorch, self).__init__()

        # Reference top-5 (class index, score) pairs for inception_v3, keyed by
        # source framework (the -s choices). Presumably checked by the base
        # class's test_truth().
        self.truth['tensorflow']['inception_v3'] = [
            (22, 9.6691055), (24, 4.3524747), (25, 3.5957973),
            (132, 3.5657473), (23, 3.346283)]
        self.truth['keras']['inception_v3'] = [
            (21, 0.93430489), (23, 0.002883445), (131, 0.0014781791),
            (24, 0.0014518998), (22, 0.0014435351)]

        # Build the generated network from the weights file given with -w and
        # switch it to inference mode (dropout off, batch-norm running stats).
        self.model = self.MainModel.KitModel(self.args.w)
        self.model.eval()

    def preprocess(self, image_path):
        """Load *image_path* and store it in ``self.data`` as a 1xCxHxW tensor.

        The base-class preprocess is assumed to return an HxWxC array (hence
        the transpose) -- TODO confirm against imagenet_test_base.
        """
        x = super(TestTorch, self).preprocess(image_path)
        x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
        # Add the batch axis; copy() produces a fresh contiguous array.
        x = np.expand_dims(x, 0).copy()
        self.data = torch.from_numpy(x)
        # Kept for compatibility with pre-0.4 PyTorch; on newer versions this
        # simply returns a tensor.
        self.data = torch.autograd.Variable(self.data, requires_grad=False)

    def print_result(self, image_name, top1, top5):
        """Run the model on ``self.data`` and let the base class score it.

        Returns whatever the base-class ``print_result`` returns; ``inference``
        unpacks it as the updated (top1, top5) hit counts.
        """
        output = self.model(self.data)
        # .data drops autograd history so the result can be converted to NumPy.
        predict = output.data.numpy()
        return super(TestTorch, self).print_result(predict, image_name, top1, top5)

    def print_intermediate_result(self, layer_name, if_transpose=False):
        """Hand the model's ``test`` attribute to the base-class printer.

        ``layer_name`` is accepted for interface compatibility but is not used.
        The generated model is assumed to store the tensor of interest in
        ``self.model.test`` -- confirm against the generated network code.
        """
        intermediate_output = self.model.test.data.numpy()
        super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose)

    def inference(self, images, image_dir=None):
        """Classify every image named in the list file *images*; print accuracy.

        Args:
            images: path of a text file holding one image file name per line;
                blank lines are skipped.
            image_dir: directory the names are resolved against; defaults to
                ``DEFAULT_IMAGE_DIR``.
        """
        if image_dir is None:
            image_dir = self.DEFAULT_IMAGE_DIR

        top1 = 0.0
        top5 = 0.0
        image_count = 0
        with open(images) as fp_images:
            for line in fp_images:
                image_name = line.strip()
                if not image_name:
                    # A blank line would otherwise crash the old csv-based
                    # reader (it yields an empty row).
                    continue
                image_count += 1
                self.preprocess(os.path.join(image_dir, image_name))
                top1, top5 = self.print_result(image_name, top1, top5)

        if image_count == 0:
            # Avoid ZeroDivisionError when the list file is empty.
            print("No images listed in %s" % images)
            return

        print("top1's accuracy : %f" % (top1 / image_count))
        print("top5's accuracy : %f" % (top5 / image_count))
        # Optional debugging hooks, disabled by default:
        #   self.print_intermediate_result(None, False)
        #   self.test_truth()

    def dump(self, path=None):
        """Save the whole model object with ``torch.save``.

        Args:
            path: destination file; defaults to the --dump argument.
        """
        if path is None:
            path = self.args.dump
        torch.save(self.model, path)
        print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
            path, self.args.n, self.args.w))


if __name__ == '__main__':
    tester = TestTorch()
    if tester.args.dump:
        tester.dump()
    else:
        tester.inference(tester.args.image)
imagenet_test_base.py:
请见上传的代码。 下载地址:https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz
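执行 image_test.py 时，需要在终端传入以下命令行参数（参数解析代码位于 imagenet_test_base.py 中），例如：python image_test.py -n kit_imagenet -w <权重文件> -i ../data/file_list.txt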
# Command-line options. This is an excerpt from imagenet_test_base.py
# (presumably the TestKit constructor): `self` and `_text_type` are defined there.
parser = argparse.ArgumentParser()

# Model preprocessing type.
parser.add_argument('-p', '--preprocess', type=_text_type, help='Model Preprocess Type')

# Name of the generated network-structure module (default: kit_imagenet).
parser.add_argument('-n', type=_text_type, default='kit_imagenet',
                    help='Network structure file name.')

# Source framework whose reference results in self.truth are used
# (e.g. tensorflow, keras).
parser.add_argument('-s', type=_text_type, help='Source Framework Type',
                    choices=self.truth.keys())

# Weights file used to build the network (required).
parser.add_argument('-w', type=_text_type, required=True,
                    help='Network weights file name')

# Despite the help text, image_test.py's inference() reads this as a list
# file of image names, one per line.
parser.add_argument('--image', '-i', type=_text_type, help='Test image path.',
                    default="../data/file_list.txt")

# Ground-truth label file for the test set.
parser.add_argument('-l', '--label', type=_text_type, default='../data/val.txt',
                    help='Path of label.')

# If given, image_test.py saves the converted PyTorch model here instead of
# running inference.
parser.add_argument('--dump', type=_text_type, default=None,
                    help='Target model path.')

# Where to write detection results.
parser.add_argument('--detect', type=_text_type, default=None,
                    help='Model detection result path.')

# TensorFlow only: which kind of model to dump.
parser.add_argument('--dump_tag', type=_text_type, default=None,
                    help='Tensorflow model dump type',
                    choices=['SERVING', 'TRAINING'])
标签：with open, pytorch, https, 参数, str, __init__, else, import
原文地址:https://www.cnblogs.com/jzcbest1016/p/9780356.html