标签:
libsvm很早之前就用了,现在封装一下方便自己使用,也方便大家更快的使用这个库,这个库一个挺有用的特性就是对测试样本的概率估计。源码在随笔的最后。liblinear的版本也是类似移植,主要是处理好数据的传入即可。
1.封装的类CxLibSVM
基于libsvm封装的类,如下:
#pragma once #include <string> #include <vector> #include <iostream> #include "./libsvm/svm.h" using namespace std; //内存分配 #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) /************************************************************************/ /* 封装svm */ /************************************************************************/ class CxLibSVM { private: struct svm_model* model_; struct svm_parameter param; struct svm_problem prob; struct svm_node * x_space; public: //************************************ // 描 述: 构造函数 // 方 法: CxLibSVM // 文 件 名: CxLibSVM::CxLibSVM // 访问权限: public // 返 回 值: // 限 定 符: //************************************ CxLibSVM() { model_ = NULL; } //************************************ // 描 述: 析构函数 // 方 法: ~CxLibSVM // 文 件 名: CxLibSVM::~CxLibSVM // 访问权限: public // 返 回 值: // 限 定 符: //************************************ ~CxLibSVM() { free_model(); } //************************************ // 描 述: 训练模型 // 方 法: train // 文 件 名: CxLibSVM::train // 访问权限: public // 参 数: const vector<vector<double>> & x // 参 数: const vector<double> & y // 参 数: const int & alg_type // 返 回 值: void // 限 定 符: //************************************ void train(const vector<vector<double>>& x, const vector<double>& y) { if (x.size() == 0)return; //释放先前的模型 free_model(); /*初始化*/ long len = x.size(); long dim = x[0].size(); long elements = len*dim; //参数初始化,参数调整部分在这里修改即可 // 默认参数 param.svm_type = C_SVC; //算法类型 param.kernel_type = LINEAR; //核函数类型 param.degree = 3; //多项式核函数的参数degree param.coef0 = 0; //多项式核函数的参数coef0 param.gamma = 0.5; //1/num_features,rbf核函数参数 param.nu = 0.5; //nu-svc的参数 param.C = 10; //正则项的惩罚系数 param.eps = 1e-3; //收敛精度 param.cache_size = 100; //求解的内存缓冲 100MB param.p = 0.1; param.shrinking = 1; param.probability = 1; //1表示训练时生成概率模型,0表示训练时不生成概率模型,用于预测样本的所属类别的概率 param.nr_weight = 0; //类别权重 param.weight = NULL; //样本权重 param.weight_label = NULL; //类别权重 //转换数据为libsvm格式 prob.l = len; prob.y = Malloc(double, prob.l); prob.x = Malloc(struct svm_node *, prob.l); x_space = Malloc(struct svm_node, elements+len); int j = 0; for (int l = 0; l < len; l++) { prob.x[l] = &x_space[j]; for (int d = 0; d < dim; d++) { x_space[j].index = d+1; x_space[j].value = x[l][d]; j++; } x_space[j++].index = -1; prob.y[l] = y[l]; } /*训练*/ model_ = svm_train(&prob, ¶m); } //************************************ // 描 述: 预测测试样本所属类别和概率 // 方 法: predict // 文 件 名: CxLibSVM::predict // 访问权限: public // 参 数: const vector<double> & x 样本 // 参 数: double & prob_est 类别估计的概率 // 返 回 值: double 预测的类别 // 限 定 符: //************************************ int predict(const vector<double>& x,double& prob_est) { //数据转换 svm_node* x_test = Malloc(struct svm_node, x.size()+1); for (unsigned int i=0; i<x.size(); i++) { x_test[i].index = i; x_test[i].value = x[i]; } x_test[x.size()].index = -1; double *probs = new double[model_->nr_class];//存储了所有类别的概率 //预测类别和概率 int value = (int)svm_predict_probability(model_, x_test, probs); for (int k = 0; k < model_->nr_class;k++) {//查找类别相对应的概率 if (model_->label[k] == value) { prob_est = probs[k]; break; } } delete[] probs; return value; } //************************************ // 描 述: 导入svm模型 // 方 法: load_model // 文 件 名: CxLibSVM::load_model // 访问权限: public // 参 数: string model_path 模型路径 // 返 回 值: int 0表示成功;-1表示失败 // 限 定 符: //************************************ int load_model(string model_path) { //释放原来的模型 free_model(); //导入模型 model_ = svm_load_model(model_path.c_str()); if (model_ == NULL)return -1; return 0; } //************************************ // 描 述: 保存模型 // 方 法: save_model // 文 件 名: CxLibSVM::save_model // 访问权限: public // 参 数: string model_path 模型路径 // 返 回 值: int 0表示成功,-1表示失败 // 限 定 符: //************************************ int save_model(string model_path) { int flag = svm_save_model(model_path.c_str(), model_); return flag; } private: //************************************ // 描 述: 释放svm模型内存 // 方 法: free_model // 文 件 名: CxLibSVM::free_model // 访问权限: private // 返 回 值: void // 限 定 符: //************************************ void free_model() { if (model_ != NULL) { svm_free_and_destroy_model(&model_); svm_destroy_param(¶m); free(prob.y); free(prob.x); free(x_space); } } };
2.调用封装的类CxLibSVM
如何调用该类请看如下代码:
#include "cxlibsvm.hpp" #include <time.h> #include <iostream> using namespace std; void main() { //初始化libsvm CxLibSVM svm; /*1、准备训练数据*/ vector<vector<double>> x; //样本集 vector<double> y; //样本类别集 long sample_num = 200; //样本数 long dim = 10000; //样本类别 double scale = 1; //数据缩放尺度 srand((unsigned)time(NULL));//随机数 //生成随机的正类样本 for (int i = 0; i < sample_num; i++) { vector<double> rx; for (int j = 0; j < dim; j++) { rx.push_back(scale*(rand() % 10) ); } x.push_back(rx); y.push_back(1); } //生成随机的负类样本 for (int i = 0; i < sample_num; i++) { vector<double> rx; for (int j = 0; j < dim; j++) { rx.push_back(-scale*(rand() % 10)); } x.push_back(rx); y.push_back(2); } /*2、训练*/ svm.train(x, y); /*3、保存模型*/ string model_path = ".\\svm_model.txt"; svm.save_model(model_path); /*4、导入模型*/ string model_path_p = ".\\svm_model.txt"; svm.load_model(model_path_p); /*5、预测*/ //生成随机测试数据 vector<double> x_test; for (int j = 0; j < dim; j++) { x_test.push_back(scale*(rand() % 10)); } double prob_est; //预测 double value = svm.predict(x_test, prob_est); //打印预测类别和概率 printf("label:%f,prob:%f", value, prob_est); }
3.测试模型
模型如下:
svm_type c_svc kernel_type linear nr_class 2 total_sv 8 rho -0.0379061 label 1 2 probA -3.05015 probB 0.103192 nr_sv 2 6 SV 0.002455897026356498 1:5 2:0 3:0 4:2 5:4 6:0 7:1 8:4 9:5 10:3 0.007680247728335155 1:2 2:3 3:6 4:0 5:3 6:0 7:1 8:0 9:2 10:1 -0.000110773050020484 1:-1 2:-3 3:-1 4:-2 5:-5 6:-1 7:-2 8:-4 9:-0 10:-6 -0.002310331085133643 1:-1 2:-0 3:-1 4:-2 5:-3 6:-5 7:-5 8:-5 9:-0 10:-4 -0.001462570160622233 1:-5 2:-6 3:-2 4:-9 5:-0 6:-0 7:-2 8:-0 9:-1 10:-0 -0.002824751492599935 1:-0 2:-2 3:-5 4:-0 5:-0 6:-0 7:-3 8:-8 9:-1 10:-0 -0.003207598246179264 1:-2 2:-4 3:-3 4:-1 5:-1 6:-7 7:-1 8:-2 9:-1 10:-1 -0.0002201207201360932 1:-4 2:-2 3:-1 4:-7 5:-0 6:-2 7:-0 8:-4 9:-1 10:-6
测试样本的类别如下:
label:1.000000,prob:0.994105
4.源码
码农最喜欢的稻草了,封装的项目源码,请看附件:CxLibSVM.zip
libsvm源码:libsvm
标签:
原文地址:http://www.cnblogs.com/cv-pr/p/5646434.html