码迷,mamicode.com
首页 > 其他好文 > 详细

caffe源码修改:抽取任意一张图片的特征

时间:2014-09-19 17:38:25      阅读:333      评论:0      收藏:0      [点我收藏+]

标签:deep learning   caffe   机器学习   源码修改   


caffe源码修改:抽取任意一张图片的特征

目前caffe不是很完善,输入的图片数据需要在prototxt指定路径。但是我们往往有这么一个需求:训练后得到一个模型文件,我们想拿这个模型文件来对一张图片抽取特征或者预测分类等。如果非得在prototxt指定路径,就很不方便。因此,这样的工具才是我们需要的:给一个可执行文件通过命令行来传递图片路径,然后caffe读入图片数据,进行一次正向传播。


因此我做了这么一个工具,用来抽取任意一张图片的特征。

这工具的使用方法如下:


extract_one_feature.bin ./model/caffe_reference_imagenet_model ./examples/_temp/imagenet_val.prototxt fc7 ./examples/_temp/features /media/G/imageset/clothing/针织衫/针织衫_426.jpg CPU

参数1:./model/caffe_reference_imagenet_model是训练后的模型文件

参数2:./examples/_temp/imagenet_val.prototxt 网络配置文件(其输入层必须命名为 data,且类型为 MY_IMAGE_DATA)

参数3:fc7是blob的名字

参数4:./examples/_temp/features 将该图片的特征保存在该文件

参数5:图片路径

参数6(可选):GPU或者CPU模式,默认为CPU;使用GPU时还可以追加参数7指定设备号(默认为0)


(其实还可以做得更好:让该可执行文件以监听模式常驻,通过某种方式给该进程传递图片路径,进程接到任务就执行。

这样子的话,就不需要每次抽一张图片都要申请内存空间。(*^__^*) 嘻嘻……)


下面给出初步修改方法,大家可以根据自己需求再修改。


extract_one_feature.cpp(该文件基于源码中的extract_features.cpp修改)

#include <stdio.h>  // for snprintf
#include <cstdlib>  // atoi
#include <cstring>  // strcmp

#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h"
#include "leveldb/db.h"
#include "leveldb/write_batch.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

using namespace caffe;  // NOLINT(build/namespaces)

// Declared up front so main() can call it; the definition follows writeDb.
template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv);

/// Program entry point: runs the single-image feature extractor with float
/// blobs (a double-precision run would instantiate the template with double).
int main(int argc, char** argv) {
  const int exit_code = feature_extraction_pipeline<float>(argc, argv);
  return exit_code;
}

/// Minimal text "db" used instead of leveldb: each feature file is a plain
/// std::ofstream and every value/string is appended with operator<<.
template<typename Dtype>
class writeDb
{
public:
  /// Opens (truncating) the output file @p dbName.
  /// @throws std::runtime_error if the file cannot be opened. Previously a
  ///         bad path was ignored and every later write silently failed.
  void open(const std::string& dbName)
  {
    db.open(dbName.c_str());
    if (!db.is_open())
      throw std::runtime_error("writeDb: cannot open output file " + dbName);
  }
  /// Appends one numeric value using the stream's default formatting.
  void write(const Dtype &data)
  {
    db<<data;
  }
  /// Appends raw text.
  void write(const std::string &str)
  {
    db<<str;
  }
  virtual ~writeDb()
  {
    // std::ofstream would also close itself; the explicit close just makes
    // the end of the file's lifetime obvious.
    db.close();
  }
private:
  std::ofstream db;
};

template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  const int num_required_args = 6;
  if (argc < num_required_args) {
    LOG(ERROR)<<
    "This program takes in a trained network and an input data layer, and then"
    " extract features of the input data produced by the net.\n"
    "Usage: extract_features  pretrained_net_param"
    "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
    "  save_feature_leveldb_name1[,name2,...]  img_path  [CPU/GPU]"
    "  [DEVICE_ID=0]\n"
    "Note: you can extract multiple features in one pass by specifying"
    " multiple feature blob names and leveldb names seperated by ','."
    " The names cannot contain white space characters and the number of blobs"
    " and leveldbs must be equal.";
    return 1;
  }
  int arg_pos = num_required_args;

  arg_pos = num_required_args;
  if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
    LOG(ERROR)<< "Using GPU";
    uint device_id = 0;
    if (argc > arg_pos + 1) {
      device_id = atoi(argv[arg_pos + 1]);
      CHECK_GE(device_id, 0);
    }
    LOG(ERROR) << "Using Device_id=" << device_id;
    Caffe::SetDevice(device_id);
    Caffe::set_mode(Caffe::GPU);
  } else {
    LOG(ERROR) << "Using CPU";
    Caffe::set_mode(Caffe::CPU);
  }
  Caffe::set_phase(Caffe::TEST);

  arg_pos = 0;  // the name of the executable
  string pretrained_binary_proto(argv[++arg_pos]);//网络模型参数文件

  string feature_extraction_proto(argv[++arg_pos]);

  shared_ptr<Net<Dtype> > feature_extraction_net(
      new Net<Dtype>(feature_extraction_proto));

  feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);//将网络参数load进内存


  string extract_feature_blob_names(argv[++arg_pos]);
  vector<string> blob_names;//要抽取特征的layer的名字,可以是多个
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));

  string save_feature_leveldb_names(argv[++arg_pos]);
  vector<string> leveldb_names;// 这里我改写成一个levedb为一个文件,数据格式不使用真正的levedb,而是自定义
  boost::split(leveldb_names, save_feature_leveldb_names,
               boost::is_any_of(","));
  CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
      " the number of blob names and leveldb names must be equal";
  size_t num_features = blob_names.size();

  for (size_t i = 0; i < num_features; i++) {
    CHECK(feature_extraction_net->has_blob(blob_names[i]))  //检测blob的名字在网络中是否存在
        << "Unknown feature blob name " << blob_names[i]
        << " in the network " << feature_extraction_proto;
  }


  vector<shared_ptr<writeDb<Dtype> > > feature_dbs;
  for (size_t i = 0; i < num_features; ++i) //打开db,准备写入数据
  {
    LOG(INFO)<< "Opening db " << leveldb_names[i];
    writeDb<Dtype>* db = new writeDb<Dtype>();
    db->open(leveldb_names[i]);
    feature_dbs.push_back(shared_ptr<writeDb<Dtype> >(db));
  }



  LOG(ERROR)<< "Extacting Features";

  const shared_ptr<Layer<Dtype> > layer = feature_extraction_net->layer_by_name("data");//获取第一层
  MyImageDataLayer<Dtype>* my_layer = (MyImageDataLayer<Dtype>*)layer.get();
  my_layer->setImgPath(argv[++arg_pos],1);//"/media/G/imageset/clothing/针织衫/针织衫_1.jpg"
  //设置图片路径

  vector<Blob<float>*> input_vec;
  vector<int> image_indices(num_features, 0);
  int num_mini_batches = 1;//atoi(argv[++arg_pos]);//共多少次迭代。  每次迭代的数量在prototxt用batchsize指定
  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) //共num_mini_batches次迭代
  {
    feature_extraction_net->Forward(input_vec);//一次正向传播
    for (int i = 0; i < num_features; ++i) //多层特征
    {
      const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
          ->blob_by_name(blob_names[i]);
      int batch_size = feature_blob->num();
      int dim_features = feature_blob->count() / batch_size;

      Dtype* feature_blob_data;

      for (int n = 0; n < batch_size; ++n)
      {
        feature_blob_data = feature_blob->mutable_cpu_data() +
            feature_blob->offset(n);
        feature_dbs[i]->write("3 ");
        for (int d = 0; d < dim_features; ++d)
        {
          feature_dbs[i]->write((Dtype)(d+1));
          feature_dbs[i]->write(":");
          feature_dbs[i]->write(feature_blob_data[d]);
          feature_dbs[i]->write(" ");
        }
        feature_dbs[i]->write("\n");

      }  // for (int n = 0; n < batch_size; ++n)
    }  // for (int i = 0; i < num_features; ++i)
  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)


  LOG(ERROR)<< "Successfully extracted the features!";
  return 0;
}

my_data_layer.cpp(参考image_data_layer修改)

#include <fstream>  // NOLINT(readability/streams)
#include <iostream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {


template <typename Dtype>
MyImageDataLayer<Dtype>::~MyImageDataLayer() {
  // No explicit cleanup: the members (a vector, scalars and Blob objects)
  // are destroyed automatically.
}


/// Makes (path, label) the only entry of the image list, so the next
/// fetchData() call loads this image.
template <typename Dtype>
void MyImageDataLayer<Dtype>::setImgPath(string path,int label)
{
  lines_.assign(1, std::make_pair(path, label));
}


/// Shapes the top blobs and staging buffers for a single image.
///
/// The network is built before the real image path is known, so an
/// initialisation image is read here only to learn channels/height/width.
/// That image is the first "path label" entry of image_data_param.source when
/// the field is set and readable. Otherwise it is the hard-coded fallback
/// path below, which must then exist on disk. The batch size is fixed to 1.
/// Images later passed through setImgPath() must have the same geometry
/// after the optional new_height/new_width resize.
template <typename Dtype>
void MyImageDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
  Layer<Dtype>::SetUp(bottom, top);
  const ImageDataParameter& image_data_param =
      this->layer_param_.image_data_param();
  const int new_height = image_data_param.new_height();
  const int new_width = image_data_param.new_width();
  CHECK((new_height == 0 && new_width == 0) ||
      (new_height > 0 && new_width > 0)) << "Current implementation requires "
      "new_height and new_width to be set at the same time.";

  // Choose the initialisation image: prefer the prototxt's source list and
  // fall back to the author's hard-coded path.
  std::pair<string, int> init_entry("/home/linger/init.jpg", 1);
  if (image_data_param.has_source()) {
    const string& source = image_data_param.source();
    LOG(INFO) << "Opening file " << source;
    std::ifstream infile(source.c_str());
    string filename;
    int label;
    if (infile >> filename >> label) {
      init_entry = std::make_pair(filename, label);
    } else {
      LOG(WARNING) << "No \"path label\" entry read from " << source
                   << "; using " << init_entry.first;
    }
  }
  lines_.push_back(init_entry);

  lines_id_ = 0;
  // Read the initialisation image; only its geometry is used here.
  Datum datum;
  CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                         new_height, new_width, &datum));
  // image
  const int crop_size = image_data_param.crop_size();
  const int batch_size = 1;  // image_data_param.batch_size() is ignored
  const string& mean_file = image_data_param.mean_file();
  if (crop_size > 0) {
    (*top)[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
    prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
  } else {
    (*top)[0]->Reshape(batch_size, datum.channels(), datum.height(),
                       datum.width());
    prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(),
        datum.width());
  }
  LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
      << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
      << (*top)[0]->width();
  // label
  (*top)[1]->Reshape(batch_size, 1, 1, 1);
  prefetch_label_.Reshape(batch_size, 1, 1, 1);
  // Geometry of the initialisation image; fetchData() relies on it.
  datum_channels_ = datum.channels();
  datum_height_ = datum.height();
  datum_width_ = datum.width();
  datum_size_ = datum.channels() * datum.height() * datum.width();
  // A crop equal to the image size is valid (the offsets become 0), so only
  // crops strictly larger than the image are rejected.
  CHECK_GE(datum_height_, crop_size);
  CHECK_GE(datum_width_, crop_size);
  // check if we want to have mean
  if (image_data_param.has_mean_file()) {
    BlobProto blob_proto;
    LOG(INFO) << "Loading mean file from" << mean_file;
    ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);
    data_mean_.FromProto(blob_proto);
    CHECK_EQ(data_mean_.num(), 1);
    CHECK_EQ(data_mean_.channels(), datum_channels_);
    CHECK_EQ(data_mean_.height(), datum_height_);
    CHECK_EQ(data_mean_.width(), datum_width_);
  } else {
    // Simply initialize an all-empty mean.
    data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
  }
  // Touch the CPU buffers once up front. This was inherited from
  // ImageDataLayer, where it preceded starting a prefetch thread; this layer
  // starts no thread.
  prefetch_data_.mutable_cpu_data();
  prefetch_label_.mutable_cpu_data();
  data_mean_.cpu_data();
}

//------------------------------ Read the data of one image ------------------------------
/// Loads the image at lines_[lines_id_] into the staging blobs
/// prefetch_data_ / prefetch_label_ (batch size fixed to 1).
///
/// Pixels are mean-subtracted (data_mean_) and multiplied by
/// image_data_param.scale(). When crop_size > 0, a centred
/// crop_size x crop_size window is taken. Aborts via CHECK if the image
/// cannot be read, or if its geometry differs from the one recorded in
/// SetUp(), because the blobs and the mean were sized from that image.
template <typename Dtype>
void MyImageDataLayer<Dtype>::fetchData() {
  Datum datum;
  CHECK(prefetch_data_.count());
  Dtype* top_data = prefetch_data_.mutable_cpu_data();
  Dtype* top_label = prefetch_label_.mutable_cpu_data();
  const ImageDataParameter& image_data_param =
      this->layer_param_.image_data_param();
  const Dtype scale = image_data_param.scale();
  const int batch_size = 1;  // image_data_param.batch_size() is ignored: one image only

  const int crop_size = image_data_param.crop_size();
  const bool mirror = image_data_param.mirror();
  const int new_height = image_data_param.new_height();
  const int new_width = image_data_param.new_width();

  if (mirror && crop_size == 0) {
    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
        << "set at the same time.";
  }
  // Geometry recorded from the initialisation image in SetUp().
  const int channels = datum_channels_;
  const int height = datum_height_;
  const int width = datum_width_;
  const int size = datum_size_;
  const int lines_size = lines_.size();
  const Dtype* mean = data_mean_.cpu_data();

  for (int item_id = 0; item_id < batch_size; ++item_id) {
    CHECK_GT(lines_size, lines_id_);
    const string& img_path = lines_[lines_id_].first;
    // Fail loudly. The former `continue` left stale or zeroed data in the
    // blob and produced meaningless features for an unreadable image.
    CHECK(ReadImageToDatum(img_path, lines_[lines_id_].second,
                           new_height, new_width, &datum))
        << "Could not read image " << img_path;
    // The indexing below assumes the SetUp() geometry. A differently sized
    // image (e.g. with new_height/new_width left at 0) would be read out of
    // bounds.
    CHECK_EQ(datum.channels(), channels) << "Channel mismatch for " << img_path;
    CHECK_EQ(datum.height(), height) << "Height mismatch for " << img_path
        << "; set new_height/new_width to resize inputs";
    CHECK_EQ(datum.width(), width) << "Width mismatch for " << img_path
        << "; set new_height/new_width to resize inputs";

    const string& data = datum.data();
    if (crop_size) {
      CHECK(data.size()) << "Image cropping only support uint8 data";
      // Deterministic centre crop; no random offsets or mirroring here.
      const int h_off = (height - crop_size) / 2;
      const int w_off = (width - crop_size) / 2;
      for (int c = 0; c < channels; ++c) {
        for (int h = 0; h < crop_size; ++h) {
          for (int w = 0; w < crop_size; ++w) {
            const int top_index = ((item_id * channels + c) * crop_size + h)
                                  * crop_size + w;
            const int data_index = (c * height + h + h_off) * width + w + w_off;
            const Dtype datum_element =
                static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
            top_data[top_index] = (datum_element - mean[data_index]) * scale;
          }
        }
      }
    } else if (data.size()) {
      // Whole image, stored as bytes.
      for (int j = 0; j < size; ++j) {
        const Dtype datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[j]));
        top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
      }
    } else {
      // Whole image, stored as floats.
      for (int j = 0; j < size; ++j) {
        top_data[item_id * size + j] =
            (datum.float_data(j) - mean[j]) * scale;
      }
    }
    top_label[item_id] = datum.label();  // label of this image
  }
}

/// Loads the current image (fetchData) and publishes it on the two tops:
/// top[0] receives the image data, top[1] the label. Always returns zero.
template <typename Dtype>
Dtype MyImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
  fetchData();

  Blob<Dtype>* const data_top = (*top)[0];
  Blob<Dtype>* const label_top = (*top)[1];
  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
             data_top->mutable_cpu_data());
  caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
             label_top->mutable_cpu_data());

  return Dtype(0.);
}

// No CPU_ONLY GPU stub is needed here. MyImageDataLayer declares no
// Forward_gpu, so it presumably falls back to the base Layer's Forward_gpu
// (assumed to dispatch to Forward_cpu; confirm for your Caffe version). The
// former `STUB_GPU_FORWARD(ImageDataLayer, Forward);` named the wrong class,
// ImageDataLayer, not this layer.
INSTANTIATE_CLASS(MyImageDataLayer);

}  // namespace caffe


在data_layers.hpp中添加以下代码(参考ImageDataLayer编写):

/// Single-image variant of ImageDataLayer, used for feature extraction.
///
/// The batch size is fixed to 1. The image to process is injected at run
/// time with setImgPath() rather than read from a list file. Settings are
/// read from layer_param_.image_data_param(). The layer has two top blobs:
/// the (mean-subtracted, scaled, optionally centre-cropped) image and its
/// label.
template <typename Dtype>
class MyImageDataLayer : public Layer<Dtype>  {
 public:
  // Scalars get defined values, so the object is never in an indeterminate
  // state before SetUp() runs.
  explicit MyImageDataLayer(const LayerParameter& param)
      : Layer<Dtype>(param),
        lines_id_(0),
        datum_channels_(0),
        datum_height_(0),
        datum_width_(0),
        datum_size_(0),
        phase_(Caffe::TEST) {}
  virtual ~MyImageDataLayer();
  /// Reads an initialisation image to size the top blobs (batch size 1).
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

  virtual inline LayerParameter_LayerType type() const {
    return LayerParameter_LayerType_MY_IMAGE_DATA;
  }
  virtual inline int ExactNumBottomBlobs() const { return 0; }
  virtual inline int ExactNumTopBlobs() const { return 2; }
  /// Loads lines_[lines_id_] into prefetch_data_ / prefetch_label_.
  void fetchData();
  /// Replaces the image list with the single entry (path, label).
  void setImgPath(string path,int label);
 protected:
  /// Calls fetchData() and copies the staging blobs into the tops.
  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

  // No-op: nothing is back-propagated through this input layer.
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}


  vector<std::pair<std::string, int> > lines_;  // (image path, label) entries
  int lines_id_;        // index into lines_ of the image to load
  int datum_channels_;  // geometry of the image read in SetUp()
  int datum_height_;
  int datum_width_;
  int datum_size_;      // channels * height * width
  Blob<Dtype> prefetch_data_;   // staging buffer for the image
  Blob<Dtype> prefetch_label_;  // staging buffer for the label
  Blob<Dtype> data_mean_;       // mean image; only reshaped if no mean_file
  Caffe::Phase phase_;  // NOTE(review): not referenced by the code shown
};


修改caffe.proto,在适当的位置添加下面信息,也是参考image_data写的。


// Add to enum LayerParameter.LayerType (the generated C++ name is
// LayerParameter_LayerType_MY_IMAGE_DATA). The value must not clash with an
// existing enum value; check your caffe.proto.
MY_IMAGE_DATA = 36;


// Presumably added to message LayerParameter next to image_data_param. Field
// number 36 must be unused there; confirm against your caffe.proto.
// NOTE(review): the MyImageDataLayer code reads image_data_param(), not this
// field.
optional MyImageDataParameter my_image_data_param = 36;


// Parameters intended for MyImageDataLayer; the fields mirror
// ImageDataParameter.
// NOTE(review): the MyImageDataLayer implementation reads
// layer_param_.image_data_param(), not my_image_data_param. As written,
// these fields are therefore not consulted by that layer.
message MyImageDataParameter {
  // Specify the data source (for ImageDataLayer, a list file of
  // "path label" lines).
  optional string source = 1;
  // For data pre-processing, we can do simple scaling and subtracting the
  // data mean, if provided. Note that the mean subtraction is always carried
  // out before scaling.
  optional float scale = 2 [default = 1];
  optional string mean_file = 3;
  // Specify the batch size. (MyImageDataLayer always uses a batch of 1.)
  optional uint32 batch_size = 4;
  // Crop size; 0 disables cropping. (MyImageDataLayer takes a centre crop,
  // not a random one.)
  optional uint32 crop_size = 5 [default = 0];
  // Specify if we want to randomly mirror data.
  optional bool mirror = 6 [default = false];
  // The rand_skip variable is for the data layer to skip a few data points
  // to avoid all asynchronous sgd clients to start at the same point. The skip
  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
  // be larger than the number of keys in the leveldb.
  optional uint32 rand_skip = 7 [default = 0];
  // Whether or not ImageLayer should shuffle the list of files at every epoch.
  optional bool shuffle = 8 [default = false];
  // It will also resize images if new_height or new_width are not zero.
  optional uint32 new_height = 9 [default = 0];
  optional uint32 new_width = 10 [default = 0];
}


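以上三处代码在caffe.proto中并不在同一位置:MY_IMAGE_DATA 加在 LayerParameter 的 LayerType 枚举中,my_image_data_param 字段加在 LayerParameter 中,MyImageDataParameter 则是一个独立的 message。具体位置可以参照 image_data 对应的三处定义,并注意编号 36 不能与已有编号冲突。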



本文作者:linger

本文链接:http://blog.csdn.net/lingerlanlan/article/details/39400375



caffe源码修改:抽取任意一张图片的特征

标签:deep learning   caffe   机器学习   源码修改   

原文地址:http://blog.csdn.net/lingerlanlan/article/details/39400375

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!