码迷,mamicode.com
首页 > 编程语言 > 详细

C++ Multilabel compare function

时间:2015-11-04 22:47:39      阅读:297      评论:0      收藏:0      [点我收藏+]

标签:

  1 void FakeSolver::testTrainedClassifiers() {
  2 
  3     map<int, PersonClassifier>::iterator model_it;
  4     map<int, vector<Photo> >::iterator vali_it;
  5     map<int, vector<float> > model_posi_dis;
  6     map<int, vector<float> >::iterator posi_it;
  7 
  8     int classifier_num = 0;
  9     int validation_set_size = 0;
 10     float max_dis;
 11     int max_label;
 12     float tmp;
 13     float accuracy = 0;
 14     int spl_posi_num = 0;
 15 
 16     vector<int> classifiers_keys;
 17     for (model_it = this->classifiers_map.begin(); model_it != this->classifiers_map.end(); ++model_it) {
 18         model_it->second.loadXMl();
 19         classifiers_keys.push_back(model_it->first); //分类器的id ---> classifiers_keys;
 20 
 21 //        // Model_it->first: 1  ---> 第一个分类器;
 22 //        // Model_it->first: 2  ---> 第二个分类器;
 23 //        cout << "Model_it->first: " << model_it->first << end;
 24 
 25         spl_posi_num += model_it->second.getSplPositiveNum(); //spl正样本的数量;
 26         classifier_num++; //分类器的数目;
 27     }
 28 
 29     vector<int> validation_set_keys;
 30     for (vali_it = this->album->validation_set.begin(); vali_it != this->album->validation_set.end(); ++vali_it) {
 31         validation_set_keys.push_back(vali_it->first);  //测试集键值;
 32 
 33 //        // 测试集的索引;
 34 //        // 输出 vali_it->first    1
 35 //        // 输出 vali_it->first    2
 36 //        cout << "vali_it->first" << vali_it->first << end;
 37     }
 38 
 39     int validation_keys_size = validation_set_keys.size();
 40     int photo_size = 0;
 41 
 42 //    // 输出检查 测试集的大小 ---> 大小为2;
 43 //    cout << "validation_keys_size: " << validation_keys_size << end;
 44 
 45     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 46 
 47     for (int j = 0; j < validation_keys_size; ++j) {
 48         vali_it = this->album->validation_set.find(validation_set_keys[j]);
 49         photo_size = vali_it->second.size(); //测试集每一类的照片数量-->为1;
 50 
 51 //        cout << "photo_size: " << photo_size << end;
 52 
 53         float acc = 0;
 54 
 55 
 56         double TP=1;
 57         double TN=1;
 58         double FP=1;
 59         double FN=1;
 60 
 61 
 62         ////////////////////////////////////////////////////////////////////////////////////////////////////////////
 63         model_it = this->classifiers_map.find(vali_it->first); // 第一个分类器;第二个分类器;
 64         if (model_it != this->classifiers_map.end() && model_it->second.isTrained())
 65         {
 66 
 67             for (int i = 0; i < photo_size; ++i) {  // 对于每一个图像;
 68                 max_dis = -1000;
 69                 max_label = -1;
 70                 int predict_label;
 71                 int test_image_id;
 72 
 73                 // 将该图像输入每一个分类器,观察其离每一个分界面的距离;
 74                 for (model_it = this->classifiers_map.begin(); model_it != this->classifiers_map.end(); ++model_it) {
 75                     if (model_it->second.isTrained())
 76 
 77 //                        // 测试的是哪一张图像?
 78 //                        int test_image_id;
 79 //                        test_image_id = vali_it->first;
 80 //                        cout << "test_image_id: " << test_image_id << end;
 81 
 82 
 83                         tmp = model_it->second.getDis(vali_it->second[i]); // tmp 是到分界面的距离
 84 
 85                         if (tmp > 1) {
 86                             predict_label = 1; //到分界面的距离大于某一阈值,则认为存在该属性; 预测标签predict_label为1;
 87                         }
 88                         else
 89                             predict_label = 0;
 90 
 91 
 92 
 93                         //------------------------------------------------------------------
 94                         //------------------------------------------------------------------
 95                         // 找到分类器的id 和 属性 之间的对应关系;
 96                         // 该分类器对应的id是 validation_set_keys[j] ;
 97                         ifstream relations("/home/wangxiao/Downloads/imageID_Attribute.txt");
 98                         if (!relations) {
 99                             cerr << "fail to open the input file !" << end;
100                         }
101 
102                         map<int, string> ID_Attributes;
103 
104                         // int id = validation_set_keys[j];
105                         int id;
106                         std::string label;
107                         std::string Attribute;
108 //
109                         while(relations >> id >> label){
110                             ID_Attributes.insert(make_pair<int, string>(id, label));
111                         }
112 //
113 //                         map<int, string>::iterator la_it = ID_Attributes.find(validation_set_keys[j]);
114 
115                         map<int, string>::iterator la_it = ID_Attributes.find(vali_it->first);
116                         Attribute = la_it->second;   //该id对应的属性Attribute;
117 //                        //cout<<"Attribute: "<<Attribute<<end;
118                         relations.close();
119 
120                         ifstream file("/home/wangxiao/Downloads/Label.txt");
121                         if (!file) {
122                             cerr << "fail to open the input file !" << endl;
123                         }
124 
125                         std::string line;
126 //                        map<int, string> id_label;
127 //                        // int id2 = vali_it->first;
128                         int id2;
129                         std::string label2;
130                         int xxx;
131 //                        std::string Attribute2;
132 //                        std::string line;
133 //                        int id2;
134 
135                         while(getline(file, line)){
136                             label2 = line.substr(5, line.size());
137                             id2 = atoi((line.substr(1,4)).c_str());
138 
139                             if (id2 == xxx)
140                                 break;
141                         }
142 //
143 //                        while (in >> id2 >> label2) {
144 //                             id_label.insert(make_pair<int, string>(id2, label2));
145 //                        }
146 //                        // test_image_id = 0;
147 //                        map<int, string>::iterator search_it = id_label.find(id2);
148 //                        //std::string label2;
149 //                        //cout<<"id2: "<<id2<<end;
150 //                        Attribute2 = search_it->second;
151 
152                         // search_it=>first ---> 1  有问题!!!
153                 //                    cout << "search_it->first: " << search_it->first << end;
154 
155                         file.close();
156 
157                         // id_label: search_it->first ---> 1
158                 //                    cout << "id_label: " << search_it->first <<end;
159                 //                    cout << "label2:" << search_it->second << end;
160                 //                    cout << "test_image_id: " << test_image_id << end;
161 
162 
163 
164 
165 
166                     int true_label;
167                     bool condition = label2.find(Attribute);
168 
169 //                    cout << "condition: " << condition << end;
170 
171                     if (condition)
172                         true_label = 1;
173                     else
174                         true_label = 0;
175 
176                     if ((true_label = 1) && (predict_label = 1))
177                         TP++;
178                     else if ((true_label = 1) && (predict_label = 0))
179                         TN++;
180                     else if ((true_label = 0) && (predict_label = 0))
181                         FP++;
182                     else if ((true_label = 0) && (predict_label = 1))
183                         FN++;
184 
185                 }
186 
187 
188             }
189 
190         }
191 
192 
193 
194     }
195 
196     int TP, TN, FP, FN;
197     int all_positive = TP+FP;
198     int all_negative = FN+TN;
199 
200     if ((all_positive==0) || (all_negative==0)){
201         all_positive = 1;
202         all_negative = 1;
203     }
204 
205     accuracy = 0.5*(TP/(all_positive) + TN/(all_negative));
206 
207     cout << "accuracy:" << accuracy << end;
208     cout << "TP:" << TP << end;
209     cout << "TN:" << TN << end;
210     cout << "FP:" << FP << end;
211     cout << "FN:" << FN << end;
212     cout << "TP+TN+FP+FN = " << TP+TN+FP+FN << end;
213 
214 
215 
216 }

 

C++ Multilabel compare function

标签:

原文地址:http://www.cnblogs.com/wangxiaocvpr/p/4937530.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!