码迷,mamicode.com
首页 > Web开发 > 详细

pytorch imagenet测试代码

时间:2018-10-12 23:43:16      阅读:608      评论:0      收藏:0      [点我收藏+]

标签:with open   pytorch   https   参数   str   __init__   else   import   

image_test.py

import argparse
import numpy as np
import sys
import os
import csv
from imagenet_test_base import TestKit
import torch


class TestTorch(TestKit):
    """Evaluate a converted PyTorch model on a list of ImageNet images.

    Most of the machinery lives in the ``TestKit`` base class (imported from
    ``imagenet_test_base``, not shown here). It is expected to do the following:

    * parse the command line into ``self.args``;
    * expose the generated network module as ``self.MainModel``;
    * provide the ``self.truth`` table;
    * implement the shared ``preprocess`` / ``print_result`` /
      ``print_intermediate_result`` logic that this class wraps.
    """

    def __init__(self):
        super(TestTorch, self).__init__()

        # Five reference (index, score) pairs for inception_v3, keyed by the
        # source framework. NOTE(review): presumably consumed by the base
        # class's truth check -- confirm in imagenet_test_base.
        self.truth['tensorflow']['inception_v3'] = [
            (22, 9.6691055), (24, 4.3524747), (25, 3.5957973),
            (132, 3.5657473), (23, 3.346283),
        ]
        self.truth['keras']['inception_v3'] = [
            (21, 0.93430489), (23, 0.002883445), (131, 0.0014781791),
            (24, 0.0014518998), (22, 0.0014435351),
        ]

        # Build the network from the weights file given with ``-w`` and put
        # it in evaluation mode, since this class only runs inference.
        self.model = self.MainModel.KitModel(self.args.w)
        self.model.eval()

    def preprocess(self, image_path):
        """Preprocess one image and store it in ``self.data`` as model input.

        The base-class result is assumed to be channels-last (H, W, C)
        (TODO confirm). It is transposed to (C, H, W) and given a leading
        batch axis of size 1.

        :param image_path: path of the image file to load.
        """
        x = super(TestTorch, self).preprocess(image_path)
        x = np.transpose(x, (2, 0, 1))
        # copy() gives the tensor its own C-contiguous buffer instead of a
        # strided view of the base-class array.
        x = np.expand_dims(x, 0).copy()
        # On PyTorch >= 0.4 ``Variable`` just returns the tensor. It is kept
        # so the script still runs on the older releases it was written for.
        self.data = torch.autograd.Variable(torch.from_numpy(x),
                                            requires_grad=False)

    def print_result(self, image_name, top1, top5):
        """Run the model on ``self.data`` and score the prediction.

        :param image_name: name of the image currently held in ``self.data``.
        :param top1: running top-1 hit count.
        :param top5: running top-5 hit count.
        :returns: whatever the base-class ``print_result`` returns. The
            caller unpacks it as the updated ``(top1, top5)`` pair.
        """
        output = self.model(self.data)
        # ``.data`` detaches the result from autograd so ``.numpy()`` works.
        scores = output.data.numpy()
        return super(TestTorch, self).print_result(scores, image_name, top1, top5)

    def print_intermediate_result(self, layer_name, if_transpose=False):
        """Print the tensor stored in ``self.model.test`` via the base class.

        ``layer_name`` is not used by this implementation.
        NOTE(review): ``self.model.test`` must be set by the generated
        network (presumably a debugging hook) -- confirm before relying on it.

        :returns: the base-class result, for consistency with ``print_result``.
        """
        intermediate_output = self.model.test.data.numpy()
        return super(TestTorch, self).print_intermediate_result(
            intermediate_output, if_transpose)

    def inference(self, images, image_dir="../data/imagenet/small_imagenet/"):
        """Compute and print top-1 / top-5 accuracy over an image list.

        :param images: path of a text file with one image file name per
            line. Blank lines are skipped.
        :param image_dir: directory that the listed names are relative to.
        :raises ValueError: if *images* contains no image names, because
            the accuracy would otherwise divide by zero.
        """
        top1 = 0.0
        top5 = 0.0
        image_count = 0
        with open(images) as fp_images:
            for line in fp_images:
                image_name = line.strip()
                if not image_name:
                    # A blank line (e.g. a trailing newline) carries no
                    # image and used to crash the loop with IndexError.
                    continue
                image_count += 1
                self.preprocess(os.path.join(image_dir, image_name))
                top1, top5 = self.print_result(image_name, top1, top5)

        if image_count == 0:
            raise ValueError("no image names found in %s" % images)
        print("top1's accuracy : %f" % (top1 / image_count))
        print("top5's accuracy : %f" % (top5 / image_count))

    def dump(self, path=None):
        """Save the whole model object (not just its state_dict) with torch.save.

        :param path: output file. Defaults to the ``--dump`` argument.
        """
        if path is None:
            path = self.args.dump
        torch.save(self.model, path)
        print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
              path, self.args.n, self.args.w))


if __name__ == '__main__':
    # Either save the model to the --dump path, or run the accuracy test
    # over the image list given with --image.
    tester = TestTorch()
    if tester.args.dump:
        tester.dump()
    else:
        tester.inference(tester.args.image)

 

image_test_base.py:

  完整代码请见上传的文件,下载地址:https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz

运行 image_test.py 时,需要在终端通过命令行传入参数。这些参数在 image_test_base.py 中用 argparse 定义如下:

 parser = argparse.ArgumentParser()

 # Model preprocessing type. NOTE(review): the original post described this
 # option as "the PyTorch test program, here image_test.py" -- confirm.
 parser.add_argument('-p', '--preprocess', type=_text_type,
                     help='Model Preprocess Type')

 # Name of the generated network-structure .py file, without the extension.
 parser.add_argument('-n', type=_text_type, default='kit_imagenet',
                     help='Network structure file name.')

 # Source framework of the converted model; must be a key of self.truth.
 parser.add_argument('-s', type=_text_type, help='Source Framework Type',
                     choices=self.truth.keys())

 # Network weights file.
 parser.add_argument('-w', type=_text_type, required=True,
                     help='Network weights file name')

 # Test image list; image_test.py reads one image file name per line.
 parser.add_argument('--image', '-i', type=_text_type,
                     help='Test image path.',
                     default="../data/file_list.txt")

 # Ground-truth label file for the test set.
 parser.add_argument('-l', '--label', type=_text_type,
                     default='../data/val.txt',
                     help='Path of label.')

 # Output path for the converted target model.
 parser.add_argument('--dump', type=_text_type, default=None,
                     help='Target model path.')

 parser.add_argument('--detect', type=_text_type, default=None,
                     help='Model detection result path.')

 # TensorFlow only: which kind of model to dump.
 parser.add_argument('--dump_tag', type=_text_type, default=None,
                     help='Tensorflow model dump type',
                     choices=['SERVING', 'TRAINING'])

 

pytorch imagenet测试代码

标签:with open   pytorch   https   参数   str   __init__   else   import   

原文地址:https://www.cnblogs.com/jzcbest1016/p/9780356.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!