
Caffe Source Code: The Solver Class

Date: 2019-12-22 16:16:17


Overview of the Solver class

The Net class implements the network's forward/backward computation and parameter updates. The Solver class wraps this one level higher: it provides Step(), which trains the network iteration by iteration, and Solve(), which drives the full optimization, along with interface functions for saving and restoring snapshots of the network model.

solver.cpp source

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;    // Register the callback that reports the requested solver action
}

template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {  // Return the requested solver action
  if (action_request_function_) {
    // If the external request function has been set, call it.
    return action_request_function_();    // Invoke the callback, which returns the requested action
  }
  return SolverAction::NONE;
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)   // Constructor: initialize the solver from a SolverParameter message
    : net_(), callbacks_(), requested_early_exit_(false) {
  Init(param);    // Initialize this solver with the param message
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
    : net_(), callbacks_(), requested_early_exit_(false) {  // Constructor: read the solver parameters from a text-format proto file
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);    // Parse the message data from param_file into param
  Init(param);    // Initialize the solver
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {    // Solver initialization
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
    << std::endl << param.DebugString();    // Log only in the root solver (main thread)
  param_ = param;
  // Length of the sliding window used to smooth the loss: the mean of the most recent average_loss iterations
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
  CheckSnapshotWritePermissions();          // Verify that snapshot files can be created
  if (param_.random_seed() >= 0) {          // A random seed was set in the SolverParameter message
    Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());    // Seed the RNG
  }
  // Scaffolding code
  InitTrainNet();   // Initialize the training net
  InitTestNets();   // Initialize all test nets (there is one train net, but there may be several test nets)
  if (Caffe::root_solver()) {
    LOG(INFO) << "Solver scaffolding done.";    // Printed only by the root solver
  }
  iter_ = 0;      // Initialize counters
  current_step_ = 0;
}

// Load weights from the caffemodel(s) specified in "weights" solver parameter
// into the train and test nets.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net, const std::string& model_list) {   // Load weight files
  std::vector<std::string> model_names;
  boost::split(model_names, model_list, boost::is_any_of(",")); // Split model_list into file names on ","
  for (int i = 0; i < model_names.size(); ++i) {
    boost::trim(model_names[i]);    // Strip leading/trailing whitespace
    LOG(INFO) << "Finetuning from " << model_names[i];  // Log the weight file name
    net->CopyTrainedLayersFrom(model_names[i]);   // Copy blob data from the file into the net's identically named parameters
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {    // Initialize the training net: configure its parameters and load pretrained models
  // The train net may be specified by any one of the four SolverParameter fields
  // train_net_param, train_net, net_param and net
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();    // Number of these four fields that are set
  const string field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;          // At least one must be set
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names; // At most one may be set (exactly one of the four)
  NetParameter net_param;
  if (param_.has_train_net_param()) {   // The train net is given inline in train_net_param
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in train_net_param.";  // Root solver only
    net_param.CopyFrom(param_.train_net_param());   // Copy the net parameters from the NetParameter message into net_param
  } else if (param_.has_train_net()) {  // The train net is given as a file name in train_net
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); // Read the net parameters from the proto file
  }
  if (param_.has_net_param()) {         // The net is given inline in net_param
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param()); // Copy the net parameters from the NetParameter message
  }
  if (param_.has_net()) {               // The net is given as a file name in net
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);   // Read the net parameters from the proto file
  }
  // Set the correct NetState.  We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  // Message::MergeFrom() semantics: singular fields are overwritten, nested messages are merged
  // recursively, and repeated fields are concatenated.
  // Message::CopyFrom() semantics: clear the current message, then MergeFrom() the given message into it.
  // So the state starts from the defaults, is then overridden by the NetState read from one of the four
  // sources above, and finally overridden by the train_state set in this solver's SolverParameter.
  // The state set in the net definition has lower precedence and is overridden by the state set in the solver.
  NetState net_state;
  net_state.set_phase(TRAIN);   // Set the net's phase to TRAIN
  net_state.MergeFrom(net_param.state());     // First merge the state read from the file or message above
  net_state.MergeFrom(param_.train_state());  // Then merge the train state set in this solver
  net_param.mutable_state()->CopyFrom(net_state);   // Store the final state back into the net parameters
  net_.reset(new Net<Dtype>(net_param));            // Build the net from these parameters and store it in net_
  for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) { // Number of weights entries
    LoadNetWeights(net_, param_.weights(w_idx));    // Load the pretrained model(s) listed in each entry into net_
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {    // Initialize the test nets
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;    // Whether an inline net and/or a net file was given
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";   // At most one of the two may be given
  const int num_test_net_params = param_.test_net_param_size(); // Number of inline test net definitions
  const int num_test_net_files = param_.test_net_size();        // Number of test net file names
  const int num_test_nets = num_test_net_params + num_test_net_files;   // Total
  if (num_generic_nets) {
      // test_iter gives the number of test iterations for each test net, so the number of test_iter entries
      // must match the number of test nets. If net_param or net is set, it may define additional test nets,
      // so here the number of test_iter entries must be >= num_test_nets.
      CHECK_GE(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  } else {
      // Without net_param or net, all test nets come from test_net_param and test_net,
      // so the counts must match exactly.
      CHECK_EQ(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;  // Number of test nets defined via net_param or net
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;   // Total number of test nets, i.e. param_.test_iter_size()
  if (param_.test_state_size()) {   // If test_state is set, its count must match the number of test nets
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);  // The test interval (in iterations) must be positive
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  // caffe.proto documents the precedence of test net sources: (1) test_net_param, (2) test_net, (3) net_param/net.
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net_param";    // Record where this test net was defined
      net_params[test_net_id].CopyFrom(param_.test_net_param(i)); // Copy the net parameters from the NetParameter message
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net file: " + param_.test_net(i);  // Record the source, including the file name
      ReadNetParamsFromTextFileOrDie(param_.test_net(i),
          &net_params[test_net_id]);    // Read the net parameters from the proto file
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;  // Number of nets defined via net_param or net
  if (has_net_param) {  // net_param was given, so the remaining test nets are all defined there
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param()); // Copy the net parameters
    }
  }
  if (has_net_file) {    // Likewise, read the remaining test nets from the file named by net
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);    // Resize
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState.  We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    // Same as in InitTrainNet(): start from the defaults, override with the state in the net parameters,
    // then override with the test state set in the solver to obtain the final test net state.
    NetState net_state;
    net_state.set_phase(TEST);    // Set the phase to TEST
    net_state.MergeFrom(net_params[i].state()); // First merge the state from the net parameters
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));  // Then merge the test state set in the solver
    }
    net_params[i].mutable_state()->CopyFrom(net_state); // Store the final test state into net_params[i]
    LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i];  // Log the recorded source
    test_nets_[i].reset(new Net<Dtype>(net_params[i])); // Build the net from net_params[i] and store it in test_nets_
    test_nets_[i]->set_debug_info(param_.debug_info()); // Propagate the solver's debug_info setting to the net
    for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
      LoadNetWeights(test_nets_[i], param_.weights(w_idx)); // Each test net also tries to load all pretrained model files
    }
  }
}

// Run iters single iterations of the solver
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;         // Iterations completed so far
  const int stop_iter = iter_ + iters;  // Iteration count at which to stop
  int average_loss = this->param_.average_loss();   // Length of the loss smoothing window
  losses_.clear();          // Drop the loss history
  smoothed_loss_ = 0;       // Reset
  iteration_timer_.Start(); // Start the timer

  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();    // Zero the gradients of all learnable parameters in the net
    if (param_.test_interval() && iter_ % param_.test_interval() == 0   // A nonzero test interval is set and this iteration is due for testing
        && (iter_ > 0 || param_.test_initialization())) {   // Testing may also run at iteration 0
      // test_initialization() only controls whether the test nets run once at the start (iter_ == 0).
      // (iter_ % test_interval == 0) always holds then, so with the flag true every call to Step()
      // begins with a test pass; with it false, testing only happens for iter_ > 0.
      if (Caffe::root_solver()) {   // Test nets run only in the root solver
        TestAll();    // Run all test nets and log their outputs
      }
      if (requested_early_exit_) {  // An early-exit request arrived during testing; leave the loop
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {   // Solver callbacks, used to synchronize solvers in multi-GPU training
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0; // A display interval is set and this iteration is due for logging
    net_->set_debug_info(display && param_.debug_info());   // Decide whether to print debug info
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {  // One iteration runs iter_size forward/backward passes
      loss += net_->ForwardBackward();  // Run one forward and one backward pass, accumulating the loss over the iter_size passes
    }
    loss /= param_.iter_size();   // Mean loss of this iteration
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss); // Store loss in losses_ and recompute the mean smoothed_loss_
    if (display) {                // This iteration's info should be logged
      float lapse = iteration_timer_.Seconds();   // Elapsed time in seconds since the timer was (re)started
      float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);   // iterations_last_ is the iteration count when the timer was last started; this yields iterations per second
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << " (" << per_s << " iter/s, " << lapse << "s/"
          << param_.display() << " iters), loss = " << smoothed_loss_;  // Log iteration count, speed and elapsed time
      iteration_timer_.Start();     // Restart the timer
      iterations_last_ = iter_;     // Remember the current iteration count
      const vector<Blob<Dtype>*>& result = net_->output_blobs();    // All output blobs of the train net
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();    // data_ of the j-th output blob
        const string& output_name = net_->blob_names()[net_->output_blob_indices()[j]]; // Name of this output blob
        const Dtype loss_weight = net_->blob_loss_weights()[net_->output_blob_indices()[j]];  // Loss weight of this output blob
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {    // For a nonzero weight, report the weight and the weighted output
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();    // Log the output
        }
      }
    }
    // Solver callbacks invoked after the gradients are computed; likewise used to synchronize gradients in multi-GPU training
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();    // Compute the effective update from the learning rate, momentum, weight decay, etc., and apply it to the net's parameters; implemented in SGDSolver

    SolverAction::Enum request = GetRequestedAction();    // Fetch the currently requested solver action

    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {   // This iteration is due for a snapshot, or a snapshot was explicitly requested
      Snapshot();   // Write the snapshot files
    }
    if (SolverAction::STOP == request) {    // A stop was requested; exit early
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {  // Restore net and solver state from resume_file, then train the net
  CHECK(Caffe::root_solver());    // Only the root solver may do this
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();  // Log the net name and the learning rate policy

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;    // Reset the flag each time solving starts

  if (resume_file) {        // A file name was given
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);   // Restore the net parameters and the solver state from the file
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;   // Iterations completed so far
  Step(param_.max_iter() - iter_);    // max_iter is the maximum iteration count; run the remaining iterations
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    // Snapshot-after-train is enabled and the current iteration is not already due for a snapshot.
    // If param_.snapshot() && iter_ % param_.snapshot() == 0 held, Step() already saved a snapshot
    // for this iter_, so there is no need to save another one here.
    Snapshot();
  }
  if (requested_early_exit_) {  // Again check for an early-exit request
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively).  Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  // If display is due, one extra forward pass is run. The final iteration inside Step() runs forward
  // and backward passes and then updates the parameters, so the loss of the updated net is not yet
  // known; this extra forward pass computes the loss with the updated parameters.
  if (param_.display() && iter_ % param_.display() == 0) {  // Display is enabled and this iteration is due for logging
    int average_loss = this->param_.average_loss();   // Length of the loss smoothing window
    Dtype loss;
    net_->Forward(&loss);   // One forward pass

    UpdateSmoothedLoss(loss, start_iter, average_loss); // Update losses_ and recompute the smoothed loss
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;  // Log it
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {    // A test interval is set and this iteration is due for testing
    TestAll();    // Run all test nets
  }
  LOG(INFO) << "Optimization Done.";    // Solver finished
}

template <typename Dtype>
void Solver<Dtype>::TestAll() {   // Run all test nets
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_; // Unless an early exit was requested
       ++test_net_id) {
    Test(test_net_id);    // Run the test net with index test_net_id
  }
}

template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {     // Run the test net with index test_net_id
  CHECK(Caffe::root_solver());    // Test nets run only in the root solver
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";  // Log the iteration count and the test net id
  // Share the nets: point the test net's parameter blobs at the data of the train net net_.
  // Only the test net's pointers are redirected; no data is copied.
  CHECK_NOTNULL(test_nets_[test_net_id].get())->ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];  // The current test net
  Dtype loss = 0;
  // test_iter(test_net_id) is the number of iterations to run for this test net
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();    // Fetch the currently requested solver action
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {       // For any non-NONE request, perform the corresponding action
        if (SolverAction::SNAPSHOT == request) {  // Take a snapshot and keep going
          Snapshot();   // Write the snapshot files and continue
        } else if (SolverAction::STOP == request) { // Early exit
          requested_early_exit_ = true;
        }
        request = GetRequestedAction();
    }
    if (requested_early_exit_) {    // Exit without any further work
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    // Run one forward pass of test_net; the loss goes into iter_loss and result holds the net's
    // output blobs (net_output_blobs_)
    const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) { // If the mean test loss should be computed
      loss += iter_loss;              // Accumulate the loss of each pass
    }
    if (i == 0) {           // On the first pass, establish the sizes of test_score and test_score_output_id
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();  // data_ of the j-th output blob
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);    // Store every value of the output blob's data in test_score
          test_score_output_id.push_back(j);      // Remember which output blob each value came from
        }
      }
    } else {
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {   // For each output blob
        const Dtype* result_vec = result[j]->cpu_data();    // The output blob's data_
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];     // Accumulate the output values across the test iterations
        }
      }
    }
  }
  if (requested_early_exit_) {        // Early exit requested?
    LOG(INFO)     << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {   // If the mean test loss should be computed
    loss /= param_.test_iter(test_net_id);    // Mean loss over the test_iter(test_net_id) passes of this test net
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {
    // test_score[i] came from the blob net_output_blobs_[test_score_output_id[i]];
    // output_blob_index is that blob's index in blobs_
    const int output_blob_index = test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];        // The blob's name
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];   // The blob's loss weight
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id); // Divide by the iteration count to get the mean output
    if (loss_weight) {    // For a nonzero weight, report the weight and the weighted value
      loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str(); // Log the mean of every value of every output blob
  }
}

template <typename Dtype>
void Solver<Dtype>::Snapshot() {    // Write two snapshot files: the net parameters (NetParameter) and the solver state (SolverState)
  CHECK(Caffe::root_solver());      // Again, snapshots are written only by the root solver
  string model_filename;
  switch (param_.snapshot_format()) {   // Configured snapshot format
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:   // Binary proto
    model_filename = SnapshotToBinaryProto(); // Save the train net's parameters to a ".caffemodel" file; returns the file name
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:          // HDF5
    model_filename = SnapshotToHDF5();        // Save the train net's parameters to an HDF5 file; returns the file name
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  SnapshotSolverState(model_filename);      // Save the solver state (SolverState) to a file
}

template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() { // Verify that snapshot files can be created (only probes write access; no data is stored)
  if (Caffe::root_solver() && param_.snapshot()) {    // Root solver only
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";  // The snapshot file name prefix must be set
    string probe_filename = SnapshotFilename(".tempfile");    // Build a snapshot file name with the ".tempfile" extension
    std::ofstream probe_ofs(probe_filename.c_str());    // Create the temporary file
    if (probe_ofs.good()) {   // Check for errors
      probe_ofs.close();      // Close it
      std::remove(probe_filename.c_str());    // Delete the file
    } else {
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'.  Make sure "
          << "that the directory exists and is writable.";    // Creation failed; abort
    }
  }
}

// Build a snapshot file name: prefix + "_iter_" + iteration count + extension
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string& extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
    + extension;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {   // Save the train net's parameters as a binary proto file; returns the file name
  string model_filename = SnapshotFilename(".caffemodel");  // Build the file name with the ".caffemodel" extension
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;    // Log it
  NetParameter net_param;
  // Write the parameters of all layers of the train net net_ into net_param;
  // snapshot_diff() controls whether the gradients are also stored in the snapshot
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);    // Write the NetParameter message to the file
  return model_filename;    // Return the snapshot file name
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {    // Save the train net's parameters to an HDF5 file; returns the file name
  string model_filename = SnapshotFilename(".caffemodel.h5");   // Snapshot file name
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;  // Log it
  net_->ToHDF5(model_filename, param_.snapshot_diff());         // Write the parameters of each layer of net_ into the HDF5 file
  return model_filename;    // Return the file name
}

// Restore the net parameters and the training state: read the solver state from state_file;
// if that state also names a model file, the net parameters are loaded as well
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) { // Decide between HDF5 and proto from the file name; a somewhat crude check
    RestoreSolverStateFromHDF5(state_filename);   // Read from the HDF5 file
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);  // Read from the binary proto file
  }
}

// start_iter is the iteration count at the start of Step().
// losses_ holds the loss history: initially (iter_ < start_iter + average_loss) it grows with each
// new loss, and stops growing once it reaches average_loss entries. After that, new losses
// overwrite the stored values cyclically from front to back.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss) {
  if (losses_.size() < average_loss) {    // The window is not full yet, so losses_ keeps growing
    losses_.push_back(loss);              // Store the loss
    int size = losses_.size();
    // smoothed_loss_ is the mean of losses_ before this loss was added; update the mean to include it
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
  } else {
    int idx = (iter_ - start_iter) % average_loss;  // Slot in losses_ that this iteration's loss replaces
    smoothed_loss_ += (loss - losses_[idx]) / average_loss; // Update the mean by swapping the oldest loss for the new one
    losses_[idx] = loss;
  }
}

solver.hpp source

/**
  * @brief Enumeration of actions that a client of the Solver may request by
  * implementing the Solver's action request function, which a
  * client may optionally provide in order to request early termination
  * or saving a snapshot without exiting. In the executable caffe, this
  * mechanism is used to allow the snapshot to be saved when stopping
  * execution with a SIGINT (Ctrl-C).
  */
namespace SolverAction {
  enum Enum {
    NONE = 0,  // Take no special action.
    STOP = 1,  // Stop training; snapshot_after_train controls whether a
               // final snapshot is created.
    SNAPSHOT = 2  // Save the current train net's parameters as a snapshot file,
                  // and keep training.
  };
}

/**
 * @brief Type of a function that returns a Solver Action enumeration.
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param);
  explicit Solver(const string& param_file);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();

  // Client of the Solver optionally may call this in order to set the function
  // that the solver uses to see what action it should take (e.g. snapshot or
  // exit training early).
  void SetActionFunction(ActionCallback func);    // Register the callback that reports the requested solver action
  SolverAction::Enum GetRequestedAction();
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  void Restore(const char* resume_file);
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  virtual ~Solver() {}
  inline const SolverParameter& param() const { return param_; }
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  int iter() const { return iter_; }

  // Invoked at specific points during an iteration
  // Callback class invoked during each iteration; its two methods are used for synchronization in multi-GPU training
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) {
    callbacks_.push_back(value);    // Append
  }

  void CheckSnapshotWritePermissions();
  /**
   * @brief Returns the solver type.
   */
  virtual inline const char* type() const { return ""; }

  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;

 protected:
  string SnapshotFilename(const string& extension);
  string SnapshotToBinaryProto();
  string SnapshotToHDF5();
  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  void DisplayOutputBlobs(const int net_id);
  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);

  SolverParameter param_;
  int iter_;                      // Current iteration count
  int current_step_;              // Current step of the learning rate schedule, used by the "step" and "multistep" policies
  shared_ptr<Net<Dtype> > net_;   // The training net
  vector<shared_ptr<Net<Dtype> > > test_nets_;    // All test nets
  vector<Callback*> callbacks_;   // Registered callbacks
  vector<Dtype> losses_;          // Losses of the most recent average_loss iterations
  Dtype smoothed_loss_;           // Mean of losses_

  // A function that can be set by a client of the Solver to provide indication
  // that it wants a snapshot saved and/or to exit early.
  ActionCallback action_request_function_;  // Callback that returns the requested solver action

  // True iff a request to stop early was received.
  bool requested_early_exit_;   // Whether an early exit was requested

  // Timing information, handy to tune e.g. nbr of GPUs
  Timer iteration_timer_;       // Timer
  float iterations_last_;       // Value of iter_ when the timer was last started

  DISABLE_COPY_AND_ASSIGN(Solver);
};

Summary

  1. The solver's action callback is set in caffe.cpp to a function pointer for SignalHandler::CheckForSignals(). When a SIGINT or SIGHUP signal arrives on a Unix system, GotSIGINT() and GotSIGHUP() return the corresponding flag and clear it, and SignalHandler::CheckForSignals() maps those flags to a solver action (NONE/STOP/SNAPSHOT); see signal_handler.cpp for the details.
  2. Step() calls ClearParamDiffs() to zero the gradients before the forward/backward passes of every iteration. This is because Caffe accumulates gradients onto the existing diff data during backpropagation, so they must be cleared manually at each iteration, just as gradients must be zeroed manually in PyTorch.

This is my first read-through of the Caffe source, recorded as I went, so the analysis may contain mistakes or omissions. Corrections from readers are very welcome. Thanks for your support!



Original: https://www.cnblogs.com/Relu110/p/12079937.html
