码迷,mamicode.com
首页 > 其他好文 > 详细

Caffe源码-DataTransformer类

时间:2019-12-26 00:07:35      阅读:110      评论:0      收藏:0      [点我收藏+]

标签:into   erro   状态   initial   format   char   tran   等于   参数   

DataTransformer类简介

DataTransformer类中主要用于图像预处理操作,layer中可设置TransformationParameter类型的消息来对输入图像进行减均值、随机镜像、随机裁剪或缩放。DataTransformer类中主要包含重载函数Transform(),可以对各种类型的图像数据进行预处理,并存入到Blob类型的数据中。类中还包含了以下变量。

TransformationParameter param_; //预处理参数
shared_ptr<Caffe::RNG> rng_;    //随机数生成器
Phase phase_;                   //网络状态,TRAIN还是TEST
Blob<Dtype> data_mean_;         //数据的均值,从mean_file中读取到的均值数据
vector<Dtype> mean_values_;     //均值数值,以mean_value形式设置一系列数据

其中TransformationParameter消息中包含的内容如下。

// Message that stores parameters used to apply transformation to the data layer's data
message TransformationParameter {
  // For data pre-processing, we can do simple scaling and subtracting the data mean,
  // if provided. Note that the mean subtraction is always carried out before scaling.
  optional float scale = 1 [default = 1];       //数值缩放系数    //缩放操作总是在减均值之后进行
  // Specify if we want to randomly mirror data.
  optional bool mirror = 2 [default = false];   //预处理时是否需要随机镜像
  // Specify if we would like to randomly crop an image.
  optional uint32 crop_size = 3 [default = 0];  //裁剪后的图像尺寸,非0值表示预处理时需要裁剪图像
  // mean_file and mean_value cannot be specified at the same time
  optional string mean_file = 4;                //均值文件的路径,均值文件为二进制proto类型文件
  // if specified can be repeated once (would subtract it from all the channels)
  // or can be repeated the same number of times as channels
  // (would subtract them from the corresponding channel)
  // mean_file与mean_value不能同时设置
  repeated float mean_value = 5;                //均值数值,mean_value的个数等于1或图像通道数
  // Force the decoded image to have 3 color channels.
  optional bool force_color = 6 [default = false];  //编码数据解码时强制转化为3通道彩色图
  // Force the decoded image to have 1 color channels.
  optional bool force_gray = 7 [default = false];   //编码数据解码时强制转化为单通道灰度图
}

data_transformer.cpp源码

template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, Phase phase)
    : param_(param), phase_(phase) {    //构造函数,读取均值文件中的数据或者均值数值
  // check if we want to use mean_file
  if (param_.has_mean_file()) {         //设置了均值文件
    //TransformationParameter消息中不能同时设置mean_file和mean_value参数
    CHECK_EQ(param_.mean_value_size(), 0) << "Cannot specify mean_file and mean_value at the same time";
    const string& mean_file = param.mean_file();    //均值文件名
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_file;   //主线程中打印文件名
    }
    BlobProto blob_proto;
    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); //从该二进制proto文件中读取数据到blob_proto消息中
    data_mean_.FromProto(blob_proto);   //将BlobProto类型的消息中的数据拷贝到Blob类型的变量中
  }
  // check if we want to use mean_value
  if (param_.mean_value_size() > 0) {   //如果设置了均值数值
    CHECK(param_.has_mean_file() == false) <<
      "Cannot specify mean_file and mean_value at the same time"; //同样先检查不能同时设置
    for (int c = 0; c < param_.mean_value_size(); ++c) {
      mean_values_.push_back(param_.mean_value(c));   //将设置的值全部保存到类中
    }
  }
}

//对Datum类中的图像进行预处理操作(减均值/裁剪/镜像/数值缩放),将处理后的图像数据存入缓冲区中
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum, Dtype* transformed_data) {
  const string& data = datum.data();            //图像原始数据
  const int datum_channels = datum.channels();  //原始图像的通道数
  const int datum_height = datum.height();      //原始图像高度
  const int datum_width = datum.width();        //原始图像宽度

  const int crop_size = param_.crop_size();     //裁剪后的尺寸,非0为有效值
  const Dtype scale = param_.scale();           //数值缩放系数
  const bool do_mirror = param_.mirror() && Rand(2);    //是否需要镜像, mirror()为是否需要随机镜像,Rand(2)会返回0或1的值
  const bool has_mean_file = param_.has_mean_file();    //是否设置了均值文件
  const bool has_uint8 = data.size() > 0;               //datum中uint8数据的个数是否不为空
  const bool has_mean_values = mean_values_.size() > 0; //是否设置了均值数值

  CHECK_GT(datum_channels, 0);        //有效性检查,图像通道数是否大于0
  CHECK_GE(datum_height, crop_size);  //原始图像高度大于等于裁剪后的尺寸
  CHECK_GE(datum_width, crop_size);   //原始图像宽度大于等于裁剪后的尺寸

  Dtype* mean = NULL;
  if (has_mean_file) {
    //设置了均值文件,则检查均值文件中数据的channel/height/width与原始图像的是否匹配
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();   //最后返回均值文件的数据指针
  }
  if (has_mean_values) {
    //设置了均值数值,则设置的数值的个数要么为1(图像的所有通道都减去相同的值),要么设置的个数与图像的通道数相等
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < datum_channels; ++c) {  //设置的数值的个数为1,但是图像通道数个数不为1
        mean_values_.push_back(mean_values_[0]);  //将每个通道对应的均值均设置为该值mean_values_[0]
      }
    }
  }

  int height = datum_height;    //height/width为预处理后图像的长宽,初始时为原图尺寸
  int width = datum_width;

  int h_off = 0;  //裁剪时的h/w方向的偏移量
  int w_off = 0;
  if (crop_size) {
    height = crop_size; //如果设置了裁剪的尺寸,则更新
    width = crop_size;
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {    //训练模式下,随机得到裁剪的h和w方向的偏移
      h_off = Rand(datum_height - crop_size + 1); //返回一个 0 ~ datum_height - crop_size 之间的随机数
      w_off = Rand(datum_width - crop_size + 1);
    } else {                  //测试模式下,固定为中心裁剪
      h_off = (datum_height - crop_size) / 2;     //中心裁剪的h/w的偏移
      w_off = (datum_width - crop_size) / 2;
    }
  }

  //datum内只存有一张图像,num=1,n=0
  //top_index为输出图像的某个点的在输出图像中的索引,data_index为该点在原始图像中的索引
  //datum_element为该点在原始图像中的值
  Dtype datum_element;
  int top_index, data_index;
  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        //原始图像中的(0, c, h_off + h, w_off + w)点
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {                                          //此处可以看出镜像为width方向的镜像
          top_index = (c * height + h) * width + (width - 1 - w); //镜像,则对应输出图像的(0,c,h,width - 1 - w)点
        } else {
          top_index = (c * height + h) * width + w;       //无需镜像,则对应输出的(0,c,h,w)点
        }
        if (has_uint8) {    //如果datum中存在uint8数据
          datum_element = static_cast<Dtype>(static_cast<uint8_t>(data[data_index])); //原始图像上该点的值
        } else {            //如果datum中不存在uint8数据,则从float_data中读取float类型的数据
          datum_element = datum.float_data(data_index);   //同样,该点的值
        }
        if (has_mean_file) {
          //设置了均值文件,则每个数据都有个对应的均值mean[data_index],减去均值后乘上数值缩放系数,得到输出的值
          transformed_data[top_index] = (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {
            //设置了均值数值,则图像每个通道上的数据都存在一个均值,减均值乘上缩放系数
            transformed_data[top_index] = (datum_element - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = datum_element * scale;    //未设置均值,直接缩放
          }
        }
      }
    }
  }
}

//对Datum类中的图像进行预处理操作(减均值/裁剪/镜像/数值缩放),将处理后的图像数据存入Blob类型的数据中
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum, Blob<Dtype>* transformed_blob) {
  // If datum is encoded, decode and transform the cv::image.
  if (datum.encoded()) {    //如果数据为编码过的数据,则需要使用opencv进行解码
#ifdef USE_OPENCV
    //force_color表示解码后的数据为3通道彩色图,force_gray表示解码后的图像为单通道的灰度图,两者不能同时设置
    CHECK(!(param_.force_color() && param_.force_gray())) << "cannot set both force_color and force_gray";
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
    // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color()); //从内存缓冲区中读取一张图像
    } else {
      cv_img = DecodeDatumToCVMatNative(datum); //未设置force_color/force_gray,则按原始格式读取图像
    }
    // Transform the cv::image into blob.
    return Transform(cv_img, transformed_blob); //将读取的图像进行预处理,然后存入transformed_blob中
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  } else {
    //未编码数据,不能设置force_color或force_gray,否则报错
    if (param_.force_color() || param_.force_gray()) {
      LOG(ERROR) << "force_color and force_gray only for encoded datum";
    }
  }

  const int crop_size = param_.crop_size();     //裁剪后的尺寸
  const int datum_channels = datum.channels();  //原始图像的通道数/高度/宽度
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Check dimensions.
  const int channels = transformed_blob->channels();  //输出blob的通道数/高度/宽度/个数
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, datum_channels); //检查原始图像与输出图像的尺寸是否匹配
  CHECK_LE(height, datum_height);
  CHECK_LE(width, datum_width);
  CHECK_GE(num, 1);

  if (crop_size) {
    CHECK_EQ(crop_size, height);    //需要裁剪,则原始图像的宽高大于等于输出的图像的宽高
    CHECK_EQ(crop_size, width);
  } else {
    CHECK_EQ(datum_height, height); //无需裁剪,则两者相等
    CHECK_EQ(datum_width, width);
  }

  Dtype* transformed_data = transformed_blob->mutable_cpu_data(); //输出blob的数据指针
  Transform(datum, transformed_data); //预处理图像,并将数据存入transformed_data中
}

//对datum_vector中的多张图像进行预处理,并将结果存入Blob类型的数据中
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int datum_num = datum_vector.size();    //原始图像数据的个数
  const int num = transformed_blob->num();      //输出blob的各个维度的值
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  //检查输入的原始图像的个数,大于0,不超过blob的num维度的值
  CHECK_GT(datum_num, 0) << "There is no datum to add";
  CHECK_LE(datum_num, num) << "The size of datum_vector must be no greater than transformed_blob->num()";
  Blob<Dtype> uni_blob(1, channels, height, width);   //用于存放单个图像数据
  for (int item_id = 0; item_id < datum_num; ++item_id) {
    int offset = transformed_blob->offset(item_id);   //(n=item_id, c=0, h=0, w=0)点的偏移量,用于存放一张新的图像
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset); //将uni_blob的数据指针指向transformed_blob的缓冲区
    Transform(datum_vector[item_id], &uni_blob);      //预处理,并将预处理后的图像保存在uni_blob中
  }
}

//对mat_vector中的多张图像进行预处理,并将结果存入Blob类型的数据中
#ifdef USE_OPENCV
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int mat_num = mat_vector.size();            //输入图像的个数
  const int num = transformed_blob->num();          //输出blob的个数/通道数/高度/宽度
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  //同样检查输入图像的个数大于0,小于输出blob的num维度的值
  CHECK_GT(mat_num, 0) << "There is no MAT to add";
  CHECK_EQ(mat_num, num) << "The size of mat_vector must be equals to transformed_blob->num()";
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int item_id = 0; item_id < mat_num; ++item_id) {
    int offset = transformed_blob->offset(item_id);   //(n=item_id, c=0, h=0, w=0)的偏移
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset); //将uni_blob的数据指针指向transformed_blob的缓冲区
    Transform(mat_vector[item_id], &uni_blob);    //预处理图像,结果存入uni_blob中
  }
}

//对cv_img(单张图像)进行预处理,并将结果存入Blob类型的数据中
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();     //裁剪后的图像尺寸
  const int img_channels = cv_img.channels();   //原始图像的通道数/高度/宽度
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // Check dimensions.
  const int channels = transformed_blob->channels();  //输出blob的通道数/高度/宽度/个数
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, img_channels); //检查输入图像与输出blob的各个维度是否匹配
  CHECK_LE(height, img_height);
  CHECK_LE(width, img_width);
  CHECK_GE(num, 1);

  //cv_img中的图像数据必须为uint8类型
  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

  const Dtype scale = param_.scale();                   //设置的数值缩放系数
  const bool do_mirror = param_.mirror() && Rand(2);    //是否镜像
  const bool has_mean_file = param_.has_mean_file();    //是否设置了均值文件
  const bool has_mean_values = mean_values_.size() > 0; //是否设置了均值数值

  CHECK_GT(img_channels, 0);        //检查输入图像的维度/高度/宽度是否有效
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    //存在均值文件,则还会检查均值blob的数据的形状与输入图像的形状是否匹配
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();   //均值数据指针
  }
  if (has_mean_values) {
    //如果设置了均值数值,则会检查均值数值的个数是否为1或者等于输入图像的通道数
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
     "Specify either 1 mean_value or as many as channels: " << img_channels;
    if (img_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < img_channels; ++c) {
        //均值数值的个数为1,图像通道数不为0,则将该均值mean_values_[0]作为每个通道的均值
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int h_off = 0;    //裁剪的h/w方向的偏移
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_img;  //裁剪后的图像,初始设置为原始图像
  if (crop_size) {
    CHECK_EQ(crop_size, height);    //检查输出blob的尺寸是否等于裁剪后的图像
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {          //同样,训练模式下会随机得到裁剪时h/w方向的偏移值
      h_off = Rand(img_height - crop_size + 1);
      w_off = Rand(img_width - crop_size + 1);
    } else {                        //测试模式下会使用中心裁剪方式得到h/w方向的偏移
      h_off = (img_height - crop_size) / 2;
      w_off = (img_width - crop_size) / 2;
    }
    cv::Rect roi(w_off, h_off, crop_size, crop_size); //设置图像兴趣区域的位置
    cv_cropped_img = cv_img(roi);   //得到裁剪后的图像
  } else {
    CHECK_EQ(img_height, height);   //非裁剪模式,检查输入图像的尺寸与输入blob的形状是否一致
    CHECK_EQ(img_width, width);
  }

  CHECK(cv_cropped_img.data);   //裁剪后的图像数据不为空

  //此处注意opencv中图像是以(h,w,c)形式存放的
  Dtype* transformed_data = transformed_blob->mutable_cpu_data(); //输出blob的数据指针
  int top_index;
  for (int h = 0; h < height; ++h) {
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);  //裁剪图像的第h行数据的指针
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < img_channels; ++c) {
        if (do_mirror) {    //镜像模式下
          top_index = (c * height + h) * width + (width - 1 - w); //得到裁剪图像上(h,w,c)点在输出blob上的索引(c,h,width - 1 - w)
        } else {
          top_index = (c * height + h) * width + w;         //裁剪图像上(h,w,c)点对应输出blob上的(c,h,w)点
        }
        // int top_index = (c * height + h) * width + w;
        Dtype pixel = static_cast<Dtype>(ptr[img_index++]); //裁剪图像上(h,w,c)点的值
        if (has_mean_file) {
          //裁剪图像上(h,w,c)点对应均值文件上的(c, h_off + h, w_off + w)点
          int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
          transformed_data[top_index] = (pixel - mean[mean_index]) * scale;   //减均值,缩放
        } else {
          if (has_mean_values) {
            //裁剪图像上(h,w,c)点对应均值数值的mean_values_[c]
            transformed_data[top_index] = (pixel - mean_values_[c]) * scale;  //减均值,缩放
          } else {
            transformed_data[top_index] = pixel * scale;  //未设置均值,直接缩放
          }
        }
      }
    }
  }
}
#endif  // USE_OPENCV

//对input_blob中的所有图像进行预处理,并将结果存入transformed_blob中
template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();           //裁剪后的尺寸
  const int input_num = input_blob->num();            //输入图像的个数
  const int input_channels = input_blob->channels();  //输入图像的通道数/高度/宽度
  const int input_height = input_blob->height();
  const int input_width = input_blob->width();

  if (transformed_blob->count() == 0) {   //如果输出blob为空,则先按照输出图像的尺寸调整blob的形状
    // Initialize transformed_blob with the right shape.
    if (crop_size) {    //设置了裁剪尺寸    //调整形状,在实际访问内部数据的之后便会为其分配相应的空间
      transformed_blob->Reshape(input_num, input_channels, crop_size, crop_size);
    } else {
      transformed_blob->Reshape(input_num, input_channels, input_height, input_width);
    }
  }

  const int num = transformed_blob->num();    //输出图像的个数/通道数/高度/宽度
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int size = transformed_blob->count(); //输出blob的大小

  CHECK_LE(input_num, num);           //输入图像个数不超过输出图像个数
  CHECK_EQ(input_channels, channels); //输入输出图像通道数相同
  CHECK_GE(input_height, height);     //输入图像尺寸不小于输出图像尺寸
  CHECK_GE(input_width, width);

  const Dtype scale = param_.scale();                   //设置的数值缩放系数
  const bool do_mirror = param_.mirror() && Rand(2);    //是否镜像
  const bool has_mean_file = param_.has_mean_file();    //是否设置了均值文件
  const bool has_mean_values = mean_values_.size() > 0; //是否设置了均值数值

  int h_off = 0;    //裁剪图像时h/w方向的偏移量
  int w_off = 0;
  if (crop_size) {  //需要裁剪
    CHECK_EQ(crop_size, height);  //输出图像与裁剪尺寸一致
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {        //训练模式,随机获取裁剪时h/w方向的偏移量
      h_off = Rand(input_height - crop_size + 1);
      w_off = Rand(input_width - crop_size + 1);
    } else {                      //测试模式,获取中心裁剪时h/w方向的偏移量
      h_off = (input_height - crop_size) / 2;
      w_off = (input_width - crop_size) / 2;
    }
  } else {
    CHECK_EQ(input_height, height); //非裁剪模式,检查输入图像与输出图像尺寸是否一致
    CHECK_EQ(input_width, width);
  }

  Dtype* input_data = input_blob->mutable_cpu_data();   //输入blob的数据指针
  if (has_mean_file) {
    CHECK_EQ(input_channels, data_mean_.channels());  //设置了均值文件,则检查均值文件中的blob与输入blob的c/h/w是否一致
    CHECK_EQ(input_height, data_mean_.height());
    CHECK_EQ(input_width, data_mean_.width());
    for (int n = 0; n < input_num; ++n) {
      int offset = input_blob->offset(n);   //输入blob中第n张图像数据的起始偏移
      caffe_sub(data_mean_.count(), input_data + offset,
            data_mean_.cpu_data(), input_data + offset);  //相减,(input_data + offset)[] -= data_mean_cpp_data[]
    }
  }

  if (has_mean_values) {    //设置了均值数值
    //同样,检查均值数值的个数等于1或等于通道数
    CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<
     "Specify either 1 mean_value or as many as channels: " << input_channels;
    if (mean_values_.size() == 1) {
      caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);  //input_data[i] += -(mean_values_[0])
    } else {
      for (int n = 0; n < input_num; ++n) {
        for (int c = 0; c < input_channels; ++c) {
          int offset = input_blob->offset(n, c);    //输入blob的第n张图的第c通道的起始偏移,同一通道需减去相同的均值数值
          // (input_data + offset)[i] += -(mean_values_[c])
          caffe_add_scalar(input_height * input_width, -(mean_values_[c]), input_data + offset);
        }
      }
    }
  }

  Dtype* transformed_data = transformed_blob->mutable_cpu_data(); //输出blob的数据指针

  for (int n = 0; n < input_num; ++n) {
    int top_index_n = n * channels;     //计算输出偏移的中间量,不好描述,大致可理解为输出blob的(n, ?, ?, ?)点的偏移
    int data_index_n = n * channels;    //输入blob的的(n, ?, ?, ?)点的偏移
    for (int c = 0; c < channels; ++c) {
      int top_index_c = (top_index_n + c) * height;                   //输出blob的(n, c, ?, ?)点的偏移
      int data_index_c = (data_index_n + c) * input_height + h_off;   //输入blob的(n, c, h_off, ?)点的偏移
      for (int h = 0; h < height; ++h) {
        int top_index_h = (top_index_c + h) * width;                  //输出blob的(n, c, h, ?)点的偏移
        int data_index_h = (data_index_c + h) * input_width + w_off;  //输入blob的(n, c, h_off + h, w_off)点的偏移
        if (do_mirror) {  //需要镜像
          int top_index_w = top_index_h + width - 1;                  //输出blob的(n, c, h, width - 1)点的偏移
          for (int w = 0; w < width; ++w) {
            //输出blob的(n, c, h, width - 1 - w)点对应输入blob的(n, c, h_off + h, w_off + w)点
            transformed_data[top_index_w-w] = input_data[data_index_h + w];
          }
        } else {
          for (int w = 0; w < width; ++w) {
            //输出blob的(n, c, h, w)点对应输入blob的(n, c, h_off + h, w_off + w)点
            transformed_data[top_index_h + w] = input_data[data_index_h + w];
          }
        }
      }
    }
  }
  if (scale != Dtype(1)) {    //非1,则还需缩放数据
    DLOG(INFO) << "Scale: " << scale;
    caffe_scal(size, scale, transformed_data);  //transformed_data[] *= scale
  }
}

//推断图像在预处理之后的形状
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {    //编码过的数据
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";  //同样,force_color/force_gray不能同时设置
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
    // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color()); //读取数据,返回图像
    } else {
      cv_img = DecodeDatumToCVMatNative(datum);
    }
    // InferBlobShape using the cv::image.
    return InferBlobShape(cv_img);    //判断图像在预处理后的形状,返回
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }
  //非编码数据,直接判断
  const int crop_size = param_.crop_size();     //裁剪后的尺寸
  const int datum_channels = datum.channels();  //输入数据的通道数/高度/宽度
  const int datum_height = datum.height();
  const int datum_width = datum.width();
  // Check dimensions.
  CHECK_GT(datum_channels, 0);  //有效性检查,输入数据的通道数大于0,宽高不小于裁剪后的尺寸
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);
  // Build BlobShape.
  vector<int> shape(4);     //图像形状
  shape[0] = 1;             //单张图像,num固定为1
  shape[1] = datum_channels;
  shape[2] = (crop_size)? crop_size: datum_height;  //需要裁剪则为裁剪的尺寸,否则为原始尺寸
  shape[3] = (crop_size)? crop_size: datum_width;
  return shape;
}

//推断datum_vector中的图像在预处理之后的形状
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const vector<Datum> & datum_vector) {
  const int num = datum_vector.size();
  CHECK_GT(num, 0) << "There is no datum to in the vector"; //图像个数需大于0
  // Use first datum in the vector to InferBlobShape.
  vector<int> shape = InferBlobShape(datum_vector[0]);  //得到形状,(1, channel, height, width)
  // Adjust num to the size of the vector.
  shape[0] = num;   //以图像个数设置num维度的值
  return shape;
}

#ifdef USE_OPENCV
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) { //推断cv_img在预处理之后的图像尺寸
  const int crop_size = param_.crop_size();     //裁剪尺寸
  const int img_channels = cv_img.channels();   //输入图像的通道数/高度/宽度
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;
  // Check dimensions.
  CHECK_GT(img_channels, 0);        //同理,有效性检查
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);
  // Build BlobShape.
  vector<int> shape(4);
  shape[0] = 1;
  shape[1] = img_channels;
  shape[2] = (crop_size)? crop_size: img_height;  //输出尺寸为裁剪后的尺寸或者原始尺寸
  shape[3] = (crop_size)? crop_size: img_width;
  return shape;
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<cv::Mat> & mat_vector) {       //推断mat_vector中的图像在预处理之后的形状
  const int num = mat_vector.size();
  CHECK_GT(num, 0) << "There is no cv_img to in the vector";  //图像个数大于0
  // Use first cv_img in the vector to InferBlobShape.
  vector<int> shape = InferBlobShape(mat_vector[0]);  //得到单张图像预处理后的尺寸
  // Adjust num to the size of the vector.
  shape[0] = num;   //以图像个数设置num维度的值
  return shape;
}
#endif  // USE_OPENCV

template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {       //初始化随机数生成器
  //是否需要随机数生成器,只有设置了随机镜像或训练模式下设置了随机裁剪才需要随即操作
  const bool needs_rand = param_.mirror() || (phase_ == TRAIN && param_.crop_size());
  if (needs_rand) {
    const unsigned int rng_seed = caffe_rng_rand(); //随机得到一个随机种子
    rng_.reset(new Caffe::RNG(rng_seed));       //使用该种子创建一个随机数生成器
  } else {
    rng_.reset(); //不需要随机,释放
  }
}

template <typename Dtype>
int DataTransformer<Dtype>::Rand(int n) {   //返回一个0 ~ n-1 之间的随机数
  CHECK(rng_);
  CHECK_GT(n, 0);
  caffe::rng_t* rng = static_cast<caffe::rng_t*>(rng_->generator());  //随机数生成器
  return ((*rng)() % n);  //随机数,取余
}

小结

  1. 注意opencv中图像是以(height, width, channel)形式存放的,与caffe中的(num, channel, height,width)形式不同。
  2. caffe::RNG类中封装了boost库和CUDA的CURAND库的随机数函数,实现了跨平台编译。CURAND库的函数可参考官方提供的文档。

参考

https://docs.nvidia.com/cuda/pdf/CURAND_Library.pdf

Caffe的源码笔者是第一次阅读,一边阅读一边记录,对代码的理解和分析可能会存在错误或遗漏,希望各位读者批评指正,谢谢支持!

Caffe源码-DataTransformer类

标签:into   erro   状态   initial   format   char   tran   等于   参数   

原文地址:https://www.cnblogs.com/Relu110/p/12099629.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!