码迷,mamicode.com
首页 > 其他好文 > 详细

机器学习:weka中Evaluation类源码解析及输出AUC及交叉验证介绍

时间:2016-04-13 11:19:01      阅读:192      评论:0      收藏:0      [点我收藏+]

标签:

  在机器学习分类结果的评估中,ROC曲线下的面积AOC是一个非常重要的指标。下面是调用weka类,输出AOC的源码:

try {
// 1.读入数据集

                Instances data = new Instances(
                                      new BufferedReader(
                                        new FileReader("E:\\Develop/Weka-3-6/data/contact-lenses.arff")));

                data.setClassIndex(data.numAttributes() - 1);

// 2.训练分类器并用十字交叉验证法来获得Evaluation对象
// 注意这里的方法与我们在上几节中使用的验证法是不同。
                Classifier cl = new NaiveBayes();
                Evaluation eval = new Evaluation(data);
                eval.crossValidateModel(cl, data, 10, new Random(1));

         
// 3.生成用于得到ROC曲面和AUC值的Instances对象int classIndex = 0;
                System.out.println("The area under the ROC curve: " + eval.areaUnderROC(classIndex));
                
       System.out.println(eval.toClassDetailsString());
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toMatrixString()); }
catch (Exception e) { e.printStackTrace(); }

 

  接着说一下交叉验证;

  如果没有分开训练集和测试集,可以使用Cross Validation方法,Evaluation中crossValidateModel方法的四个参数分别为,第一个是分类器,第二个是在某个数据集上评价的数据集,第三个参数是交叉检验的次数(10是比较常见的),第四个是一个随机数对象。

  注意:使用crossValidateModel时,分类器不需要先训练。

  类crossValidateModel的源码如下:

 public void crossValidateModel(Classifier classifier, Instances data,
    int numFolds, Random random, Object... forPredictionsPrinting)
    throws Exception {

    // Make a copy of the data we can reorder
    data = new Instances(data);
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
      data.stratify(numFolds);
    }

    // We assume that the first element is a StringBuffer, the second a Range
    // (attributes
    // to output) and the third a Boolean (whether or not to output a
    // distribution instead
    // of just a classification)
    if (forPredictionsPrinting.length > 0) {
      // print the header first
      StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
      Range attsToOutput = (Range) forPredictionsPrinting[1];
      boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
      printClassificationsHeader(data, attsToOutput, printDist, buff);
    }

    // Do the folds
    for (int i = 0; i < numFolds; i++) {
      Instances train = data.trainCV(numFolds, i, random);
      setPriors(train);
      Classifier copiedClassifier = Classifier.makeCopy(classifier);
      copiedClassifier.buildClassifier(train);
      Instances test = data.testCV(numFolds, i);
      evaluateModel(copiedClassifier, test, forPredictionsPrinting);
    }
    m_NumFolds = numFolds;
  }

 

输出结果截图:

更新中。。。

 

机器学习:weka中Evaluation类源码解析及输出AUC及交叉验证介绍

标签:

原文地址:http://www.cnblogs.com/rongyux/p/5386120.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!