标签:
1 void FakeSolver::testTrainedClassifiers() { 2 3 map<int, PersonClassifier>::iterator model_it; 4 map<int, vector<Photo> >::iterator vali_it; 5 map<int, vector<float> > model_posi_dis; 6 map<int, vector<float> >::iterator posi_it; 7 8 int classifier_num = 0; 9 int validation_set_size = 0; 10 float max_dis; 11 int max_label; 12 float tmp; 13 float accuracy = 0; 14 int spl_posi_num = 0; 15 16 vector<int> classifiers_keys; 17 for (model_it = this->classifiers_map.begin(); model_it != this->classifiers_map.end(); ++model_it) { 18 model_it->second.loadXMl(); 19 classifiers_keys.push_back(model_it->first); //分类器的id ---> classifiers_keys; 20 21 // // Model_it->first: 1 ---> 第一个分类器; 22 // // Model_it->first: 2 ---> 第二个分类器; 23 // cout << "Model_it->first: " << model_it->first << end; 24 25 spl_posi_num += model_it->second.getSplPositiveNum(); //spl正样本的数量; 26 classifier_num++; //分类器的数目; 27 } 28 29 vector<int> validation_set_keys; 30 for (vali_it = this->album->validation_set.begin(); vali_it != this->album->validation_set.end(); ++vali_it) { 31 validation_set_keys.push_back(vali_it->first); //测试集键值; 32 33 // // 测试集的索引; 34 // // 输出 vali_it->first 1 35 // // 输出 vali_it->first 2 36 // cout << "vali_it->first" << vali_it->first << end; 37 } 38 39 int validation_keys_size = validation_set_keys.size(); 40 int photo_size = 0; 41 42 // // 输出检查 测试集的大小 ---> 大小为2; 43 // cout << "validation_keys_size: " << validation_keys_size << end; 44 45 //////////////////////////////////////////////////////////////////////////////////////////////////////////////// 46 47 for (int j = 0; j < validation_keys_size; ++j) { 48 vali_it = this->album->validation_set.find(validation_set_keys[j]); 49 photo_size = vali_it->second.size(); //测试集每一类的照片数量-->为1; 50 51 // cout << "photo_size: " << photo_size << end; 52 53 float acc = 0; 54 55 56 double TP=1; 57 double TN=1; 58 double FP=1; 59 double FN=1; 60 61 62 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 63 model_it = this->classifiers_map.find(vali_it->first); // 第一个分类器;第二个分类器; 64 if (model_it != this->classifiers_map.end() && model_it->second.isTrained()) 65 { 66 67 for (int i = 0; i < photo_size; ++i) { // 对于每一个图像; 68 max_dis = -1000; 69 max_label = -1; 70 int predict_label; 71 int test_image_id; 72 73 // 将该图像输入每一个分类器,观察其离每一个分界面的距离; 74 for (model_it = this->classifiers_map.begin(); model_it != this->classifiers_map.end(); ++model_it) { 75 if (model_it->second.isTrained()) 76 77 // // 测试的是哪一张图像? 78 // int test_image_id; 79 // test_image_id = vali_it->first; 80 // cout << "test_image_id: " << test_image_id << end; 81 82 83 tmp = model_it->second.getDis(vali_it->second[i]); // tmp 是到分界面的距离 84 85 if (tmp > 1) { 86 predict_label = 1; //到分界面的距离大于某一阈值,则认为存在该属性; 预测标签predict_label为1; 87 } 88 else 89 predict_label = 0; 90 91 92 93 //------------------------------------------------------------------ 94 //------------------------------------------------------------------ 95 // 找到分类器的id 和 属性 之间的对应关系; 96 // 该分类器对应的id是 validation_set_keys[j] ; 97 ifstream relations("/home/wangxiao/Downloads/imageID_Attribute.txt"); 98 if (!relations) { 99 cerr << "fail to open the input file !" << end; 100 } 101 102 map<int, string> ID_Attributes; 103 104 // int id = validation_set_keys[j]; 105 int id; 106 std::string label; 107 std::string Attribute; 108 // 109 while(relations >> id >> label){ 110 ID_Attributes.insert(make_pair<int, string>(id, label)); 111 } 112 // 113 // map<int, string>::iterator la_it = ID_Attributes.find(validation_set_keys[j]); 114 115 map<int, string>::iterator la_it = ID_Attributes.find(vali_it->first); 116 Attribute = la_it->second; //该id对应的属性Attribute; 117 // //cout<<"Attribute: "<<Attribute<<end; 118 relations.close(); 119 120 ifstream file("/home/wangxiao/Downloads/Label.txt"); 121 if (!file) { 122 cerr << "fail to open the input file !" << endl; 123 } 124 125 std::string line; 126 // map<int, string> id_label; 127 // // int id2 = vali_it->first; 128 int id2; 129 std::string label2; 130 int xxx; 131 // std::string Attribute2; 132 // std::string line; 133 // int id2; 134 135 while(getline(file, line)){ 136 label2 = line.substr(5, line.size()); 137 id2 = atoi((line.substr(1,4)).c_str()); 138 139 if (id2 == xxx) 140 break; 141 } 142 // 143 // while (in >> id2 >> label2) { 144 // id_label.insert(make_pair<int, string>(id2, label2)); 145 // } 146 // // test_image_id = 0; 147 // map<int, string>::iterator search_it = id_label.find(id2); 148 // //std::string label2; 149 // //cout<<"id2: "<<id2<<end; 150 // Attribute2 = search_it->second; 151 152 // search_it=>first ---> 1 有问题!!! 153 // cout << "search_it->first: " << search_it->first << end; 154 155 file.close(); 156 157 // id_label: search_it->first ---> 1 158 // cout << "id_label: " << search_it->first <<end; 159 // cout << "label2:" << search_it->second << end; 160 // cout << "test_image_id: " << test_image_id << end; 161 162 163 164 165 166 int true_label; 167 bool condition = label2.find(Attribute); 168 169 // cout << "condition: " << condition << end; 170 171 if (condition) 172 true_label = 1; 173 else 174 true_label = 0; 175 176 if ((true_label = 1) && (predict_label = 1)) 177 TP++; 178 else if ((true_label = 1) && (predict_label = 0)) 179 TN++; 180 else if ((true_label = 0) && (predict_label = 0)) 181 FP++; 182 else if ((true_label = 0) && (predict_label = 1)) 183 FN++; 184 185 } 186 187 188 } 189 190 } 191 192 193 194 } 195 196 int TP, TN, FP, FN; 197 int all_positive = TP+FP; 198 int all_negative = FN+TN; 199 200 if ((all_positive==0) || (all_negative==0)){ 201 all_positive = 1; 202 all_negative = 1; 203 } 204 205 accuracy = 0.5*(TP/(all_positive) + TN/(all_negative)); 206 207 cout << "accuracy:" << accuracy << end; 208 cout << "TP:" << TP << end; 209 cout << "TN:" << TN << end; 210 cout << "FP:" << FP << end; 211 cout << "FN:" << FN << end; 212 cout << "TP+TN+FP+FN = " << TP+TN+FP+FN << end; 213 214 215 216 }
C++ Multilabel compare function
标签:
原文地址:http://www.cnblogs.com/wangxiaocvpr/p/4937530.html