码迷,mamicode.com
首页 > 数据库 > 详细

caffe神经网络框架的辅助工具(将图片转换为leveldb格式)

时间:2015-05-05 15:52:42      阅读:595      评论:0      收藏:0      [点我收藏+]

标签:

caffe神经网络框架的辅助工具(将图片转换为leveldb格式)

这应该是 比较老的版本的caffe了,直接拿来应该不能用了,但是可以参考下

caffe中负责整个网络输入的datalayer是从leveldb里读取数据的,是一个google实现的非常高效的kv数据库。

因此我们训练网络必须先把数据转成leveldb的格式。

这里我实现的是把一个文件夹的所有图片转成leveldb的格式。

 

工具使用命令格格式:convert_imagedata src_dir dst_dir attach_dir channel width height

样例:./convert_imagedata.bin /home/linger/imdata/collar_train/ /home/linger/linger/testfile/crop_train_db/ /home/linger/linger/testfile/crop_train_attachment/ 3 50 50

 

源代码:

 

[cpp] view plaincopy技术分享技术分享
 
  1. #include <google/protobuf/text_format.h>  
  2. #include <glog/logging.h>  
  3. #include <leveldb/db.h>  
  4.   
  5. #include <stdint.h>  
  6. #include <fstream>  // NOLINT(readability/streams)  
  7. #include <string>  
  8. #include <set>  
  9. #include <stdio.h>  
  10. #include <string.h>  
  11. #include <stdlib.h>  
  12. #include <dirent.h>  
  13. #include <sys/stat.h>  
  14. #include <unistd.h>  
  15. #include <sys/types.h>  
  16. #include "caffe/proto/caffe.pb.h"  
  17. #include <opencv2/highgui/highgui.hpp>  
  18. #include <opencv2/highgui/highgui_c.h>  
  19. #include <opencv2/imgproc/imgproc.hpp>  
  20.   
  21. using std::string;  
  22. using namespace std;  
  23.   
  24.   
  25. set<string> all_class_name;  
  26. map<string,int> class2id;  
  27.   
  28.   
  29. /** 
  30.  * path:目录 
  31.  * files:用于保存文件名的vector 
  32.  * r:是否需要遍历子目录 
  33.  * return:文件名,不包含路径 
  34.  */  
  35. void list_dir(const char *path,vector<string>& files,bool r = false)  
  36. {  
  37.     DIR *pDir;  
  38.     struct dirent *ent;  
  39.     char childpath[512];  
  40.     pDir = opendir(path);  
  41.     memset(childpath, 0, sizeof(childpath));  
  42.     while ((ent = readdir(pDir)) != NULL)  
  43.     {  
  44.         if (ent->d_type & DT_DIR)  
  45.         {  
  46.   
  47.             if (strcmp(ent->d_name, ".") == 0 || strcmp(ent->d_name, "..") == 0)  
  48.             {  
  49.                 continue;  
  50.             }  
  51.             if(r) //如果需要遍历子目录  
  52.             {  
  53.                 sprintf(childpath, "%s/%s", path, ent->d_name);  
  54.                 list_dir(childpath,files);  
  55.             }  
  56.         }  
  57.         else  
  58.         {  
  59.             files.push_back(ent->d_name);  
  60.         }  
  61.     }  
  62.     sort(files.begin(),files.end());//排序  
  63.   
  64. }  
  65.   
  66. string get_classname(string path)  
  67. {  
  68.     int index = path.find_last_of(‘_‘);  
  69.     return path.substr(0, index);  
  70. }  
  71.   
  72.   
  73. int get_labelid(string fileName)  
  74. {  
  75.     string class_name_tmp = get_classname(fileName);  
  76.     all_class_name.insert(class_name_tmp);  
  77.     map<string,int>::iterator name_iter_tmp = class2id.find(class_name_tmp);  
  78.     if (name_iter_tmp == class2id.end())  
  79.     {  
  80.         int id = class2id.size();  
  81.         class2id.insert(name_iter_tmp, std::make_pair(class_name_tmp, id));  
  82.         return id;  
  83.     }  
  84.     else  
  85.     {  
  86.         return name_iter_tmp->second;  
  87.     }  
  88. }  
  89.   
  90. void loadimg(string path,char* buffer)  
  91. {  
  92.     cv::Mat img = cv::imread(path, CV_LOAD_IMAGE_COLOR);  
  93.     string val;  
  94.     int rows = img.rows;  
  95.     int cols = img.cols;  
  96.     int pos=0;  
  97.     for (int c = 0; c < 3; c++)  
  98.     {  
  99.         for (int row = 0; row < rows; row++)  
  100.         {  
  101.             for (int col = 0; col < cols; col++)  
  102.             {  
  103.                 buffer[pos++]=img.at<cv::Vec3b>(row,col)[c];  
  104.             }  
  105.         }  
  106.     }  
  107.   
  108. }  
  109. void convert(string imgdir,string outputdb,string attachdir,int channel,int width,int height)  
  110. {  
  111.     leveldb::DB* db;  
  112.     leveldb::Options options;  
  113.     options.create_if_missing = true;  
  114.     options.error_if_exists = true;  
  115.     caffe::Datum datum;  
  116.     datum.set_channels(channel);  
  117.     datum.set_height(height);  
  118.     datum.set_width(width);  
  119.     int image_size = channel*width*height;  
  120.     char buffer[image_size];  
  121.   
  122.     string value;  
  123.     CHECK(leveldb::DB::Open(options, outputdb, &db).ok());  
  124.     vector<string> filenames;  
  125.     list_dir(imgdir.c_str(),filenames);  
  126.     string img_log = attachdir+"image_filename";  
  127.     ofstream writefile(img_log.c_str());  
  128.     for(int i=0;i<filenames.size();i++)  
  129.     {  
  130.         string path= imgdir;  
  131.         path.append(filenames[i]);//算出绝对路径  
  132.   
  133.         loadimg(path,buffer);  
  134.   
  135.         int labelid = get_labelid(filenames[i]);  
  136.   
  137.         datum.add_label(labelid);  
  138.         datum.set_data(buffer,image_size);  
  139.         datum.SerializeToString(&value);  
  140.         snprintf(buffer, image_size, "%05d", i);  
  141.         printf("\nclassid:%d classname:%s abspath:%s",labelid,get_classname(filenames[i]).c_str(),path.c_str());  
  142.         db->Put(leveldb::WriteOptions(),string(buffer),value);  
  143.         //printf("%d %s\n",i,fileNames[i].c_str());  
  144.   
  145.         assert(writefile.is_open());  
  146.         writefile<<i<<" "<<filenames[i]<<"\n";  
  147.   
  148.     }  
  149.     delete db;  
  150.     writefile.close();  
  151.   
  152.     img_log = attachdir+"image_classname";  
  153.     writefile.open(img_log.c_str());  
  154.     set<string>::iterator iter = all_class_name.begin();  
  155.     while(iter != all_class_name.end())  
  156.     {  
  157.         assert(writefile.is_open());  
  158.         writefile<<(*iter)<<"\n";  
  159.         //printf("%s\n",(*iter).c_str());  
  160.         iter++;  
  161.     }  
  162.     writefile.close();  
  163.   
  164. }  
  165.   
  166. int main(int argc, char** argv)  
  167. {  
  168.     if (argc < 6)  
  169.     {  
  170.         LOG(ERROR) << "convert_imagedata src_dir dst_dir attach_dir channel width height";  
  171.         return 0;  
  172.     }  
  173. //./convert_imagedata.bin  /home/linger/imdata/collarTest/ /home/linger/linger/testfile/dbtest/  /home/linger/linger/testfile/test_attachment/ 3 250 250  
  174.     //   ./convert_imagedata.bin /home/linger/imdata/collar_train/ /home/linger/linger/testfile/crop_train_db/ /home/linger/linger/testfile/crop_train_attachment/ 3 50 50  
  175.     google::InitGoogleLogging(argv[0]);  
  176.     string src_dir = argv[1];  
  177.     string src_dst = argv[2];  
  178.     string attach_dir = argv[3];  
  179.     int channel = atoi(argv[4]);  
  180.     int width = atoi(argv[5]);  
  181.     int height = atoi(argv[6]);  
  182.   
  183.     //for test  
  184.     /* 
  185.     src_dir = "/home/linger/imdata/collarTest/"; 
  186.     src_dst = "/home/linger/linger/testfile/dbtest/"; 
  187.     attach_dir = "/home/linger/linger/testfile/"; 
  188.     channel = 3; 
  189.     width = 250; 
  190.     height = 250; 
  191.      */  
  192.   
  193.     convert(src_dir,src_dst,attach_dir,channel,width,height);  
  194.   
  195.   
  196.   
  197. }  

caffe神经网络框架的辅助工具(将图片转换为leveldb格式)

标签:

原文地址:http://www.cnblogs.com/yymn/p/4479124.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!