标签:tsp one com 计算 counter blob ons ble 需要
1 int main(int argc, char** argv) { 2 ..... 3 return GetBrewFunction(caffe::string(argv[1]))(); 4 .... 5 }
g_brew_map实现过程,首先通过 typedef定义函数指针 typedef int (*BrewFunction)(); 这个是用typedef定义函数指针方法。这个程序定义一个BrewFunction函数指针类型,在caffe.cpp 中 BrewFunction 作为GetBrewFunction()函数的返回类型,可以是 train(),test(),device_query(),time() 这四个函数指针的其中一个。在train(),test(),中可以调用solver类的函数,从而进入到net,进入到每一层,运行整个caffe程序。然后对每个函数注册。
1 RegisterBrewFunction(train) 2 RegisterBrewFunction(test) 3 RegisterBrewFunction(device_query) 4 RegisterBrewFunction(time)
如果需要,可以增加其他的方式,然后通过RegisterBrewFunction()函数注册一下即可。
接着调用train()函数,train函数中主要有三个方法ReadSolverParamsFromTextFileOrDie、CreateSolver、Solve。
1 // Train / Finetune a model. 2 int train() { 3 ...... 4 caffe::SolverParameter solver_param; 5 caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param 6 ...... 7 shared_ptr<caffe::Solver<float> > 8 solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式 9 10 if (FLAGS_snapshot.size()) {//迭代snapshot次后保存模型一次 11 LOG(INFO) << "Resuming from " << FLAGS_snapshot; 12 solver->Restore(FLAGS_snapshot.c_str()); 13 } else if (FLAGS_weights.size()) {//若采用finetuning,则拷贝weight到指定模型 14 CopyLayers(solver.get(), FLAGS_weights); 15 } 16 17 if (gpus.size() > 1) { 18 caffe::P2PSync<float> sync(solver, NULL, solver->param()); 19 sync.Run(gpus); 20 } else { 21 LOG(INFO) << "Starting Optimization"; 22 solver->Solve();//开始训练网络 23 } 24 LOG(INFO) << "Optimization Done."; 25 return 0; 26 }
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param)解析-solver指定的solver.prototxt的文件内容到solver_param中
CreateSolver函数构建solver和net,该函数是初始化的入口,会通过执行Solver的构造函数,调用 void Solver<Dtype>::Init(const SolverParameter& param),该函数内有InitTrainNet()、InitTestNets()。对于InitTrainNet函数:
...... net_.reset(new Net<Dtype>(net_param));
调用Net类的构造函数,然后执行Init()操作,该函数具体的内容如下图和源码所示:
1 template <typename Dtype> 2 void Net<Dtype>::Init(const NetParameter& in_param) { 3 ........//过滤校验参数FilterNet 4 FilterNet(in_param, &filtered_param); 5 .........//插入Splits层 6 InsertSplits(filtered_param, ¶m); 7 .......// 构建网络中输入输出存储结构 8 bottom_vecs_.resize(param.layer_size()); 9 top_vecs_.resize(param.layer_size()); 10 bottom_id_vecs_.resize(param.layer_size()); 11 param_id_vecs_.resize(param.layer_size()); 12 top_id_vecs_.resize(param.layer_size()); 13 bottom_need_backward_.resize(param.layer_size()); 14 15 for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) { 16 ...//创建层 17 layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param)); 18 layer_names_.push_back(layer_param.name()); 19 LOG_IF(INFO, Caffe::root_solver()) 20 << "Creating Layer " << layer_param.name(); 21 bool need_backward = false; 22 23 // Figure out this layer‘s input and output 24 for (int bottom_id = 0; bottom_id < layer_param.bottom_size(); 25 ++bottom_id) { 26 const int blob_id = AppendBottom(param, layer_id, bottom_id, 27 &available_blobs, &blob_name_to_idx); 28 29 30 ........//创建相关blob 31 // If the layer specifies that AutoTopBlobs() -> true and the LayerParameter 32 // specified fewer than the required number (as specified by 33 // ExactNumTopBlobs() or MinTopBlobs()), allocate them here. 34 Layer<Dtype>* layer = layers_[layer_id].get(); 35 if (layer->AutoTopBlobs()) { 36 const int needed_num_top = 37 std::max(layer->MinTopBlobs(), layer->ExactNumTopBlobs()); 38 for (; num_top < needed_num_top; ++num_top) { 39 // Add "anonymous" top blobs -- do not modify available_blobs or 40 // blob_name_to_idx as we don‘t want these blobs to be usable as input 41 // to other layers. 42 AppendTop(param, layer_id, num_top, NULL, NULL); 43 } 44 } 45 46 47 .....//执行SetUp() 48 // After this layer is connected, set it up. 49 layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); 50 LOG_IF(INFO, Caffe::root_solver()) 51 << "Setting up " << layer_names_[layer_id]; 52 for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { 53 if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) { 54 blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0)); 55 } 56 blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id); 57 LOG_IF(INFO, Caffe::root_solver()) 58 << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string(); 59 if (layer->loss(top_id)) { 60 LOG_IF(INFO, Caffe::root_solver()) 61 << " with loss weight " << layer->loss(top_id); 62 } 63 memory_used_ += top_vecs_[layer_id][top_id]->count(); 64 } 65 LOG_IF(INFO, Caffe::root_solver()) 66 << "Memory required for data: " << memory_used_ * sizeof(Dtype); 67 const int param_size = layer_param.param_size(); 68 const int num_param_blobs = layers_[layer_id]->blobs().size(); 69 CHECK_LE(param_size, num_param_blobs) 70 << "Too many params specified for layer " <<
SetUp是怎么构建的呢?
1 virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, 2 const vector<Blob<Dtype>*>& top) {} 3 4 void SetUp(const vector<Blob<Dtype>*>& bottom, 5 const vector<Blob<Dtype>*>& top) { 6 InitMutex(); 7 CheckBlobCounts(bottom, top); 8 LayerSetUp(bottom, top); 9 Reshape(bottom, top); 10 SetLossWeights(top); 11 }
初始化的总体流程大概就是新建一个Solver对象,然后调用Solver类的构造函数,然后在Solver的构造函数中又会新建Net类实例,在Net类的构造函数中又会新建各个layer的实例,一直具体到设置每个Blob,大概就完成了网络初始化的工作了。
train函数中CreateSolver()执行完成后,接下来是具体训练过程,执行Solve()函数---->Step()--->结束
Solve的具体内容和代码:
1 template <typename Dtype> 2 void Solver<Dtype>::Solve(const char* resume_file) { 3 CHECK(Caffe::root_solver()); 4 LOG(INFO) << "Solving " << net_->name(); 5 LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); 6 7 // For a network that is trained by the solver, no bottom or top vecs 8 // should be given, and we will just provide dummy vecs. 9 int start_iter = iter_; 10 Step(param_.max_iter() - iter_); 11 12 // overridden by setting snapshot_after_train := false 13 if (param_.snapshot_after_train() 14 && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { 15 Snapshot(); 16 } 17 18 // display loss 19 if (param_.display() && iter_ % param_.display() == 0) { 20 int average_loss = this->param_.average_loss(); 21 Dtype loss; 22 net_->Forward(&loss); 23 24 UpdateSmoothedLoss(loss, start_iter, average_loss); 25 26 27 if (param_.test_interval() && iter_ % param_.test_interval() == 0) { 28 TestAll(); 29 } 30 }
然后开始执行Step函数,具体内容和代码:
1 template <typename Dtype> 2 void Solver<Dtype>::Step(int iters) 3 { 4 // 起始迭代步数 5 const int start_iter = iter_; 6 // 终止迭代步数 7 const int stop_iter = iter_ + iters; 8 9 // 判断是否已经完成设定步数 10 while (iter_ < stop_iter) 11 { 12 // 将net_中的Bolb梯度参数置为零 13 net_->ClearParamDiffs(); 14 15 ... 16 17 // accumulate the loss and gradient 18 Dtype loss = 0; 19 for (int i = 0; i < param_.iter_size(); ++i) 20 { 21 // 正向传导和反向传导,并计算loss 22 loss += net_->ForwardBackward(); 23 } 24 loss /= param_.iter_size(); 25 26 // 为了输出结果平滑,将临近的average_loss个loss数值进行平均,存储在成员变量smoothed_loss_中 27 UpdateSmoothedLoss(loss, start_iter, average_loss); 28 29 // BP算法更新权重 30 ApplyUpdate(); 31 32 // Increment the internal iter_ counter -- its value should always indicate 33 // the number of times the weights have been updated. 34 ++iter_; 35 } 36 }
while循环中先调用了网络类Net::ForwardBackward()成员函数进行正向传导和反向传导,并计算loss
1 Dtype ForwardBackward() { 2 Dtype loss; 3 //正向传导 4 Forward(&loss); 5 //反向传导 6 Backward(); 7 return loss; 8 }
而Fordward函数中调用了ForwardFromTo,而FordwardFromTo又调用了每个layer的Fordward。反向传导函数Backward()调用了BackwardFromTo(int start, int end)函数。正向传导和反向传导结束后,再调用SGDSolver::ApplyUpdate()成员函数进行权重更新。
1 template <typename Dtype> 2 void SGDSolver<Dtype>::ApplyUpdate() 3 { 4 // 获取当前学习速率 5 Dtype rate = GetLearningRate(); 6 if (this->param_.display() && this->iter_ % this->param_.display() == 0) 7 { 8 LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; 9 } 10 11 // 在计算当前梯度的时候,如果该值超过了阈值clip_gradients,则将梯度直接设置为该阈值 12 // 此处阈值设为-1,即不起作用 13 ClipGradients(); 14 15 // 逐层更新网络中的可学习层 16 for (int param_id = 0; param_id < this->net_->learnable_params().size(); 17 ++param_id) 18 { 19 // 归一化 20 Normalize(param_id); 21 // L2范数正则化添加衰减权重 22 Regularize(param_id); 23 // 随机梯度下降法计算更新值 24 ComputeUpdateValue(param_id, rate); 25 } 26 // 更新权重 27 this->net_->Update(); 28 }
最后将迭代次数++iter_,继续while循环,直到迭代次数完成。 这就是整个网络的训练过程。
标签:tsp one com 计算 counter blob ons ble 需要
原文地址:http://www.cnblogs.com/liuzhongfeng/p/7289956.html