1. Import the required modules
```python
import numpy as np               # numpy for the numerical work below
import matplotlib.pyplot as plt  # matplotlib, mainly for plotting

# The caffe module must be on the Python path
import sys
caffe_root = '../'  # this file should be run from {caffe_root}/examples (otherwise change this line)
sys.path.insert(0, caffe_root + 'python')

import caffe
```
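3. Prepare the model (we use a model that has already been trained)
We use the model bvlc_reference_caffenet.caffemodel, a variant of AlexNet, placed in the models/bvlc_reference_caffenet/ folder.
Next, check whether the model can be found, and download it if it cannot: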
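```python
import os
import subprocess
import sys

model_file = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'
if os.path.isfile(model_file):
    print('CaffeNet found.')
else:
    print('Downloading pre-trained CaffeNet model...')
    # In a Jupyter notebook this is: !../scripts/download_model_binary.py ../models/bvlc_reference_caffenet
    subprocess.check_call([sys.executable,
                           caffe_root + 'scripts/download_model_binary.py',
                           caffe_root + 'models/bvlc_reference_caffenet'])
```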
4. Initialize the network
```python
caffe.set_mode_cpu()  # run on the CPU; to use the GPU, call caffe.set_mode_gpu() instead

model_def = caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt'  # network definition used for the forward pass
model_weights = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'  # trained weights

net = caffe.Net(model_def,      # defines the structure of the model
                model_weights,  # contains the trained weights
                caffe.TEST)     # use test mode (e.g., don't perform dropout)
# caffe.Net, caffe.TEST, etc. are defined in the compiled _caffe extension module
```
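Why caffe.TEST matters: dropout randomly zeroes activations only during training. The NumPy sketch below illustrates inverted dropout, which is our understanding of what Caffe's Dropout layer does (kept units are scaled by 1/(1 - ratio) in the TRAIN phase, and the input passes through unchanged in the TEST phase). The helper dropout_sketch and its arguments are hypothetical, for illustration only, not part of the Caffe API.

```python
import numpy as np

def dropout_sketch(x, ratio=0.5, train=True, rng=None):
    """Hypothetical helper illustrating inverted dropout (not a Caffe API)."""
    if not train:
        # TEST phase: activations pass through unchanged
        return x.copy()
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(x.shape) >= ratio  # keep each unit with probability 1 - ratio
    return x * keep / (1.0 - ratio)      # scale kept units so the expected value is unchanged

x = np.ones(8, dtype=np.float32)
print(dropout_sketch(x, train=False))                               # unchanged input
print(dropout_sketch(x, train=True, rng=np.random.default_rng(0)))  # a random mix of 0.0 and 2.0
```

In TRAIN mode the output changes from run to run, which is why a deterministic prediction needs the network in TEST mode.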
We also need to subtract the mean. The input preprocessing is configured as follows:
```python
# Load the ImageNet mean image: a 3-D array of shape (3, 256, 256) in BGR order
# (mu[0] is B, mu[1] is G, mu[2] is R)
mu = np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy')
mu = mu.mean(1).mean(1)  # average over pixels to obtain the mean (BGR) pixel values: a 3-element vector
print('mean-subtracted values:', list(zip('BGR', mu)))  # zip pairs up the sequences; list() is needed in Python 3

# Create a transformer for the input called 'data' (the source is in caffe/io.py)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})  # the dict is stored in self.inputs
# Transformer.__init__ from caffe/io.py:
# def __init__(self, inputs):
#     self.inputs = inputs
#     self.transpose = {}
#     self.channel_swap = {}
#     self.raw_scale = {}
#     self.mean = {}
#     self.input_scale = {}

transformer.set_transpose('data', (2, 0, 1))     # move image channels to outermost dimension; stored in self.transpose['data']
transformer.set_mean('data', mu)                 # subtract the per-channel mean; stored in self.mean['data']
transformer.set_raw_scale('data', 255)           # rescale from [0, 1] to [0, 255]: load_image returns [0, 1], the model expects [0, 255]
transformer.set_channel_swap('data', (2, 1, 0))  # swap channels from RGB to BGR
```
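To make the transformer settings concrete, here is a NumPy-only sketch of what they do to one image, in the order we read caffe/io.py applying them (transpose, channel swap, raw scale, then mean subtraction). It assumes the image is already at the network's input size, so the resize step is omitted. The function preprocess_sketch and the mean values 104/117/123 are illustrative assumptions, not the values stored in ilsvrc_2012_mean.npy.

```python
import numpy as np

def preprocess_sketch(image, mean_bgr, raw_scale=255.0):
    """Hypothetical NumPy re-creation of the transformer settings (no resizing).

    image:    H x W x 3 float array, RGB order, values in [0, 1]
              (the layout caffe.io.load_image returns)
    mean_bgr: 3 per-channel means in BGR order, like mu
    """
    x = image.astype(np.float32)
    x = x.transpose(2, 0, 1)   # set_transpose (2, 0, 1): H x W x C -> C x H x W
    x = x[[2, 1, 0], :, :]     # set_channel_swap (2, 1, 0): RGB -> BGR
    x = x * raw_scale          # set_raw_scale 255: [0, 1] -> [0, 255]
    mean = np.asarray(mean_bgr, dtype=np.float32)
    x = x - mean[:, np.newaxis, np.newaxis]  # set_mean: broadcast over every pixel
    return x

# Reduce a (3, H, W) mean image to a per-channel vector, as mu.mean(1).mean(1) does
mean_image = np.stack([np.full((4, 4), 104.0), np.full((4, 4), 117.0), np.full((4, 4), 123.0)])
mu_demo = mean_image.mean(1).mean(1)  # same as mean_image.mean(axis=(1, 2))

img = np.zeros((2, 2, 3), dtype=np.float32)
img[..., 0] = 1.0  # a pure-red RGB image
out = preprocess_sketch(img, mu_demo)
print(out.shape)   # (3, 2, 2)
print(out[:, 0, 0])  # B, G, R of the first pixel after preprocessing
```

Because the channel swap comes after the transpose, it indexes the first (channel) axis; and because raw scaling comes before mean subtraction, the mean must be given in the [0, 255] range, as the ImageNet mean is.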
5. Classify a test image
```python
# Set the size of the input (we can skip this if we're happy
# with the default; we can also change it later, e.g., for different batch sizes)
net.blobs['data'].reshape(50,        # batch size
                          3,         # 3-channel (BGR) images
                          227, 227)  # image size is 227x227
```

Load the image and preprocess it:
```python
image = caffe.io.load_image(caffe_root + 'examples/images/cat.jpg')  # image shape is (360, 480, 3)
# The transformer already carries all the settings made in step 4; the 'data' argument
# is the key into the dictionaries created in Transformer.__init__ (see the source above)
transformed_image = transformer.preprocess('data', image)

# Copy the image data into the memory allocated for the net
net.blobs['data'].data[...] = transformed_image

### perform classification
output = net.forward()  # forward pass

# output['prob'] is actually a (50, 1000) array because we set the batch size to 50.
# Assigning one image to data[...] broadcasts it into all 50 slots, so all 50 rows
# are identical; take the first row, i.e. the first image.
output_prob = output['prob'][0]  # the output probability vector for the first image in the batch

print('predicted class is:', output_prob.argmax())  # index of the most probable class; prints 281 here
```
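The reason all 50 rows of output['prob'] are identical is NumPy broadcasting: assigning one (3, 227, 227) image to data[...] of shape (50, 3, 227, 227) writes the same image into every batch slot. The sketch below shows this with small stand-in arrays (shrunk sizes, no Caffe needed); batch, single and fake_prob are hypothetical names.

```python
import numpy as np

batch = np.zeros((50, 3, 4, 4), dtype=np.float32)  # stand-in for net.blobs['data'].data
single = np.random.default_rng(0).random((3, 4, 4)).astype(np.float32)  # stand-in for transformed_image

batch[...] = single  # broadcasting copies the one image into all 50 slots
print(all(np.array_equal(batch[i], single) for i in range(50)))  # True

# A fake probability output with identical rows, as output['prob'] has here
fake_prob = np.tile(np.array([0.1, 0.7, 0.2], dtype=np.float32), (50, 1))
print(fake_prob.shape)        # (50, 3)
print(fake_prob[0].argmax())  # 1
```

To classify 50 different images in one pass, each preprocessed image would instead be written into its own slot, e.g. batch[i] = ...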
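```python
import subprocess

# Load ImageNet labels
labels_file = caffe_root + 'data/ilsvrc12/synset_words.txt'
if not os.path.exists(labels_file):
    # In a Jupyter notebook this is: !../data/ilsvrc12/get_ilsvrc_aux.sh
    subprocess.check_call(['bash', caffe_root + 'data/ilsvrc12/get_ilsvrc_aux.sh'])

labels = np.loadtxt(labels_file, str, delimiter='\t')
print('output label:', labels[output_prob.argmax()])
```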
Output: output label: n02123045 tabby, tabby cat
```python
# Sort the top five predictions from the softmax output.
# argsort() returns indices sorted from lowest to highest; [::-1] reverses them to highest first
top_inds = output_prob.argsort()[::-1][:5]  # reverse sort and take the five largest items

print('probabilities and labels:')  # the five highest-ranked classes
print(list(zip(output_prob[top_inds], labels[top_inds])))
```
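The argsort()[::-1][:5] idiom can be checked on a small toy vector. The probabilities and the class_i labels below are made up for illustration. The second half shows np.argpartition, an alternative the original does not use: it finds the top k without a full sort, but returns them unordered, so those k entries are sorted afterwards.

```python
import numpy as np

# Hypothetical probabilities and placeholder labels, for illustration only
probs = np.array([0.05, 0.31, 0.02, 0.24, 0.12, 0.10, 0.16], dtype=np.float32)
labels_demo = np.array(['class_%d' % i for i in range(len(probs))])

top_inds = probs.argsort()[::-1][:5]  # indices of the five largest probabilities, highest first
print(top_inds)  # [1 3 6 4 5]

# Alternative: partial selection with argpartition, then order just those k entries
k = 5
part = np.argpartition(probs, -k)[-k:]
part = part[np.argsort(probs[part])[::-1]]
print(part)  # [1 3 6 4 5]

print(list(zip(probs[top_inds].tolist(), labels_demo[top_inds].tolist())))
```

For a 1000-class vector the full sort is already cheap, so the original argsort idiom is perfectly adequate; argpartition only pays off for much larger outputs.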
[(0.31243637, 'n02123045 tabby, tabby cat'), (0.2379719, 'n02123159 tiger cat'), (0.12387239, 'n02124075 Egyptian cat'), (0.10075711, 'n02119022 red fox, Vulpes vulpes'), (0.070957087, 'n02127052 lynx, catamount')]
Original article: http://blog.csdn.net/ghlwxt/article/details/51329374