码迷,mamicode.com
首页 > Web开发 > 详细

caffe mnist实例 --lenet_train_test.prototxt 网络配置详解

时间:2017-10-28 20:26:35      阅读:520      评论:0      收藏:0      [点我收藏+]

标签:inter   number   解压缩   protoc   als   efi   脚本   dia   cti   

 

1.mnist实例

##1.数据下载 获得mnist的数据包,在caffe根目录下执行./data/mnist/get_mnist.sh脚本。 get_mnist.sh脚本先下载样本库并进行解压缩,得到四个文件。 技术分享

2.生成LMDB

成功解压缩下载的样本库后,然后执行./examples/mnist/create_mnist.sh。 create_mnist.sh脚本先利用caffe-master/build/examples/mnist/目录下的convert_mnist_data.bin工具,将mnist data转化为caffe可用的lmdb格式文件,然后将生成的mnist-train-lmdb和mnist-test-lmdb两个文件放在caffe-master/example/mnist目录下面。

3.网络配置

LeNet网络定义在./examples/mnist/lenet_train_test.prototxt 文件中。

name: "LeNet"
layer {
  name: "mnist"    //输入层的名称mnist
  type: "Data"     //输入层的类型为Data  top: "data"      //本层下一场连接data层和label blob空间
  top: "label"
  include {
    phase: TRAIN   //训练阶段
  }
  transform_param {
    scale: 0.00390625  //输入图片像素归一到[0,1].1除以256为0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_train_lmdb"  //从mnist_train_lmdb中读入数据
    batch_size: 64    //batch大小为64,一次训练64条数据
    backend: LMDB
  }
}
layer {
  name: "mnist"    //输入层的名称mnist
  type: "Data"     //输入层的类型为Data  top: "data"      //本层下一场连接data层和label blob空间
  top: "label"
  include {
    phase: TEST   //测试阶段
  }
  transform_param {
    scale: 0.00390625  //输入图片像素归一到[0,1].1除以256为0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_test_lmdb"  //从mnist_test_lmdb中读入数据
    batch_size: 100    //batch大小为100,一次训练100条数据
    backend: LMDB
  }
}
layer {
  name: "conv1"    //卷积层名称conv1
  type: "Convolution"    //层类型为卷积层
  bottom: "data"    //本层使用上一层的data,生成下一层conv1的blob
  top: "conv1"
  param {
    lr_mult: 1    //权重参数w的学习率倍数
  }
  param {
    lr_mult: 2    //偏置参数b的学习率倍数
  }
  convolution_param {
    num_output: 20    //输出单元数20
    kernel_size: 5    //卷积核大小为5*5
    stride: 1         //步长为1
    weight_filler {   //允许用随机值初始化权重和偏置值
      type: "xavier"  //使用xavier算法自动确定基于输入—输出神经元数量的初始规模
    }
    bias_filler {
      type: "constant"    //偏置值初始化为常数,默认为0
    }
  }
}
layer {
  name: "pool1"      //层名称为pool1
  type: "Pooling"    //层类型为pooling
  bottom: "conv1"    //本层的上一层是conv1,生成下一层pool1的blob
  top: "pool1"
  pooling_param {    //pooling层的参数
    pool: MAX        //pooling的方式是MAX
    kernel_size: 2   //pooling核是2*2
    stride: 2        //pooling步长是2
  }
}
layer {
  name: "conv2"    //第二个卷积层,同第一个卷积层相同,只是卷积核为50
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"     //第二个pooling层,与第一个pooling层相同
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {            //全连接层
  name: "ip1"      //全连接层名称ip1
  type: "InnerProduct"    //层类型为全连接层
  bottom: "pool2"
  top: "ip1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {     //全连接层的参数
    num_output: 500         //输出500个节点
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"       //ReLU  type: "ReLU"        //层名称为relu1
  bottom: "ip1"       //层类型为ReLU
  top: "ip1"
}
layer {
  name: "ip2"         //第二个全连接层
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10     //输出10个单元
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "accuracy"
  type: "Accuracy"
  bottom: "ip2"
  bottom: "label"
  top: "accuracy"
  include {
    phase: TEST
  }
}
layer {        //loss层,softmax_loss层实现softmax和多项Logistic损失
  name: "loss"
  type: "SoftmaxWithLoss"
  bottom: "ip2"
  bottom: "label"
  top: "loss"
}

4.训练网络

运行./examples/mnist/train_lenet.sh。 执行此脚本是,实际运行的是lenet_solver.prototxt中的定义。

# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"    //网络具体定义
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100    //test迭代次数,若batch_size=100,则100张图一批,训练100次,可覆盖1000张图
# Carry out testing every 500 training iterations.
test_interval: 500    //训练迭代500次,测试一次
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01    //网络参数:学习率,动量,权重的衰减
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy    //学习策略:有固定学习率和每步递减学习率
lr_policy: "inv"    //当前使用递减学习率
gamma: 0.0001
power: 0.75
# Display every 100 iterations    //每迭代100次显示一次
display: 100
# The maximum number of iterations   //最大迭代数
max_iter: 10000
# snapshot intermediate results    //每5000次迭代存储一次数据
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: CPU    //本例用CPU训练

数据训练结束后,会生成以下四个文件: 技术分享

5.测试网络

运行./build/tools/caffe.bin test -model=examples/mnist/lenet_train_test.prototxt -weights=examples/mnist/lenet_iter_10000.caffemodel

test:表示对训练好的模型进行Testing,而不是training。其他参数包括train, time, device_query。

-model=XXX:指定模型prototxt文件,这是一个文本文件,详细描述了网络结构和数据集信息。 技术分享

从上面的打印输出可看出,测试数据中的accruacy平均成功率为98%。

mnist手写测试

手写数字的图片必须满足以下条件:

  • 必须是256位黑白色
  • 必须是黑底白字
  • 像素大小必须是28*28
  • 数字在图片中间,上下左右没有过多的空白。

测试图片

技术分享 技术分享 技术分享 技术分享 技术分享

手写数字识别脚本

import os
import sys
import numpy as np
import matplotlib.pyplot as plt

caffe_root = ‘/home/lynn/caffe/‘
sys.path.insert(0, caffe_root + ‘python‘)
import caffe

MODEL_FILE = ‘/home/lynn/caffe/examples/mnist/lenet.prototxt‘
PRETRAINED = ‘/home/lynn/caffe/examples/mnist/lenet_iter_10000.caffemodel‘

IMAGE_FILE = ‘/home/lynn/test.bmp‘
input_image = caffe.io.load_image(IMAGE_FILE, color=False)

#print input_image
net = caffe.Classifier(MODEL_FILE, PRETRAINED)
prediction = net.predict([input_image], oversample = False)
caffe.set_mode_cpu()
print ‘predicted class: ‘, prediction[0].argmax()

测试结果

技术分享 技术分享
技术分享 技术分享
技术分享

caffe mnist实例 --lenet_train_test.prototxt 网络配置详解

标签:inter   number   解压缩   protoc   als   efi   脚本   dia   cti   

原文地址:http://www.cnblogs.com/is-Tina/p/7747844.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!