一、算法
关于REPTree我实在是没找到什么相关其算法的资料,或许是Weka自创的一个关于决策树的改进,也许是其它某种决策树方法的别名,根据类的注释:Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting). Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5).
我们大概知道和C4.5相比,大概多了backfitting过程,并且数值型排序只进行一次(回想一下J48也就是C4.5算法是每个数据子集都要进行排序),并且缺失值的处理方式和C4.5一样,走不同的path再把结果进行加权。
具体和C4.5的比较将在代码分析之后给出一个总结。
二、buildClassifier
“大名鼎鼎”的分类器训练主入口,几乎每篇分析分类器源码都从这个方法入手。
public void buildClassifier(Instances data) throws Exception {
// 首先例行公事看一下给定数据集是否能使用REPTree进行分类,REPTREE基本能支持所有类型
getCapabilities().testWithFail(data);
// 把classIndex上没有数据的instance干掉,这些数据既不能用于训练也不能用于backfit
data = new Instances(data);
data.deleteWithMissingClass();
Random random = new Random(m_Seed);
m_zeroR = null;
if (data.numAttributes() == 1) {
m_zeroR = new ZeroR();//如果只有一列的话,就是用m_ZerO作为分类器,很直观只有一列的话肯定就是结果列了,只有结果列无法训练分类器,只能使用最基本的米ZerO作为分类器,mZerO的分类方法再上篇日志有说到。
m_zeroR.buildClassifier(data);
return;
}
// Randomize and stratify
data.randomize(random);//进行随机排列
if (data.classAttribute().isNominal()) {
data.stratify(m_NumFolds);//如果枚举型还要进行一下分层,目的是
}
// 如果需要剪枝,则分为train集合和prune集合,否则只要train集合就行了
Instances train = null;
Instances prune = null;
if (!m_NoPruning) {
train = data.trainCV(m_NumFolds, 0, random);//这里是用了多折交叉验证的方法取得train和test
prune = data.testCV(m_NumFolds, 0);
} else {
train = data;
}
// 建立了两个数组,第一维数据无意义,只是把三维数组当二维数组用而已,第二维代表各属性,第三维代表排序的index(顺序统计量)
int[][][] sortedIndices = new int[1][train.numAttributes()][0];//这个里面存放的是各instance的下标
double[][][] weights = new double[1][train.numAttributes()][0];//这个里面存放的是下标对应的instance的weight
double[] vals = new double[train.numInstances()];//这个是临时数组,用于排序用的
for (int j = 0; j < train.numAttributes(); j++) {
if (j != train.classIndex()) {
weights[0][j] = new double[train.numInstances()];
if (train.attribute(j).isNominal()) {
//如果是枚举类型,所做的排序工作就是简单的把Missing放到最后面
sortedIndices[0][j] = new int[train.numInstances()];
int count = 0;
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (!inst.isMissing(j)) {
sortedIndices[0][j][count] = i;
weights[0][j][count] = inst.weight();
count++;
}
}
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (inst.isMissing(j)) {
sortedIndices[0][j][count] = i;
weights[0][j][count] = inst.weight();
count++;
}
}
} else {
// 如果是数值类型,则进行排序
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
vals[i] = inst.value(j);
}
sortedIndices[0][j] = Utils.sort(vals);
for (int i = 0; i < train.numInstances(); i++) {
weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
}
}
}
}
// 这里建立数组存放训练集中每个类的分布
double[] classProbs = new double[train.numClasses()];
double totalWeight = 0, totalSumSquared = 0;
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (data.classAttribute().isNominal()) { classProbs[(int)inst.classValue()] += inst.weight();//如果是枚举类型,就进行简单的统计
totalWeight += inst.weight();
} else {
classProbs[0] += inst.classValue() * inst.weight();//如果是数值型,就相加,到后面进行取平均的操作
totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
totalWeight += inst.weight();
}
}
m_Tree = new Tree();//建立决策树节点
double trainVariance = 0;//训练集的方差
if (data.classAttribute().isNumeric()) {
trainVariance = m_Tree.
singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
classProbs[0] /= totalWeight;//这里取平均操作
}
// Build tree
m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
new Instances(train, 0), m_MinNum, m_MinVarianceProp *
trainVariance, 0, m_MaxDepth);//执行具体树上的构建操作,这参数还真多
// Insert pruning data and perform reduced error pruning
if (!m_NoPruning) {
m_Tree.insertHoldOutSet(prune);//传入剪枝数据
m_Tree.reducedErrorPrune();//进行剪枝
m_Tree.backfitHoldOutSet();//backfit
}
}
(2)Tree.buildTree
Tree是REPTree的一个子对象,训练用参数较多。
protected void buildTree(int[][][] sortedIndices, double[][][] weights,
Instances data, double totalWeight,
double[] classProbs, Instances header,
double minNum, double minVariance,
int depth, int maxDepth)
throws Exception {
//第一个参数是按属性排好序的下标,第二个是这些下标对应的weight,第三个是训练数据<span style="white-space:pre"> </span>//第四个是总权重,第五个是各类的分布,第六个是表头,第七个是每个节点最小instance数量
<span style="white-space:pre"> </span>//第八个是最小的方差 ,第九个是当前深度(0 base),第十个是最大深度
m_Info = header;//首先存下表头
if (data.classAttribute().isNumeric()) {
m_HoldOutDist = new double[2];//这个数组用于存放分布
} else {
m_HoldOutDist = new double[data.numClasses()];
}
// 看看是否有有效数据
int helpIndex = 0;
if (data.classIndex() == 0) {
helpIndex = 1;//传入的数据至少两列,因为一列的话上层就用m_zerO模型了,这个if是为了保证helpIndex对应的肯定是训练数据
}
if (sortedIndices[0][helpIndex].length == 0) {//如果没数据,就直接反悔了
if (data.classAttribute().isNumeric()) {
m_Distribution = new double[2];//为什么是二维的?第一维存放方差,第二维存放weight,基于约定的编程方式
} else {
m_Distribution = new double[data.numClasses()];
}
m_ClassProbs = null;
sortedIndices[0] = null;
weights[0] = null;
return;
}
double priorVar = 0;//存放class的方差(其实是方差*num),只有class是数值才有意义,下面就是计算方差的过程。
if (data.classAttribute().isNumeric()) {
// 每个sortedIndices[0][i]里面的都是一个Instances的index不同排列而已,使用helpIndex只是为了保证别对应到classIndex上
double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
for (int i = 0; i < sortedIndices[0][helpIndex].length; i++) {
Instance inst = data.instance(sortedIndices[0][helpIndex][i]);
totalSum += inst.classValue() * weights[0][helpIndex][i];
totalSumSquared +=
inst.classValue() * inst.classValue() * weights[0][helpIndex][i];
totalSumOfWeights += weights[0][helpIndex][i];
}
priorVar = singleVariance(totalSum, totalSumSquared,
totalSumOfWeights);
}
//把分布拷贝一下
m_ClassProbs = new double[classProbs.length];
System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
if ((//退出条件有4个<span style="white-space:pre"> </span>//第一个是instances里面的totalweight总量(可以理解成里面的instance数量,因为weight默认都是1)小于两倍的minNum,minNum默认是2.
<span style="white-space:pre"> </span>totalWeight < (2 * minNum)) ||
// 如果是枚举类型,并且都在一类中
(data.classAttribute().isNominal() &&
Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
Utils.sum(m_ClassProbs))) ||
// 数值型则比较方差是否小于minVariance,这个minVariance默认是原始方差的0.001,从上层代码可以得知
(data.classAttribute().isNumeric() &&
((priorVar / totalWeight) < minVariance)) ||
// 达到最大深度
((m_MaxDepth >= 0) && (depth >= maxDepth))) {
// 设置成叶子
m_Attribute = -1;
if (data.classAttribute().isNominal()) {
// 设置枚举类型的分布
m_Distribution = new double[m_ClassProbs.length];
for (int i = 0; i < m_ClassProbs.length; i++) {
m_Distribution[i] = m_ClassProbs[i];
}
Utils.normalize(m_ClassProbs);
} else {
// 设置数值类型的“分布”
m_Distribution = new double[2];
m_Distribution[0] = priorVar;
m_Distribution[1] = totalWeight;
}
sortedIndices[0] = null;
weights[0] = null;
return;
}
// 下面是寻找分裂点的过程
double[] vals = new double[data.numAttributes()];//每个属性产生的信息增益
double[][][] dists = new double[data.numAttributes()][0][0];//每个属性下每个类的分布
double[][] props = new double[data.numAttributes()][0];//每个属性下class的概率,也就是根据上面这个数组的分布求概率
double[][] totalSubsetWeights = new double[data.numAttributes()][0];//每个属性下每个subset的数量
double[] splits = new double[data.numAttributes()];//每个属性的分裂点,如果是枚举型则为NaN
if (data.classAttribute().isNominal()) {
// 首先来看classAttribute是枚举类型的情况
for (int i = 0; i < data.numAttributes(); i++) {
if (i != data.classIndex()) {
splits[i] = distribution(props, dists, i, sortedIndices[0][i],
weights[0][i], totalSubsetWeights, data);//得到分裂点、概率和分布
vals[i] = gain(dists[i], priorVal(dists[i]));//得到信息增益
}
}
} else {
// 如果是数值类型则不算信息增益(为什么数值类型不算增益?只有因为枚举型才算的出信息熵)(吐个槽:话说这个if-else为啥不放在循环里面??)
for (int i = 0; i < data.numAttributes(); i++) {
if (i != data.classIndex()) {
splits[i] =
numericDistribution(props, dists, i, sortedIndices[0][i],
weights[0][i], totalSubsetWeights, data,
vals);
}
}
}
// 选出信息增益最大的作为分裂属性
m_Attribute = Utils.maxIndex(vals);
int numAttVals = dists[m_Attribute].length;
// 每个subset都要多于minNum,这样才算一个有效subset
int count = 0;
for (int i = 0; i < numAttVals; i++) {
if (totalSubsetWeights[m_Attribute][i] >= minNum) {
count++;
}
if (count > 1) {
break;
}
}
// 至少存在2个有效subset,才算是一个有效的split
if (Utils.gr(vals[m_Attribute], 0) && (count > 1)) {
// Set split point, proportions, and temp arrays
m_SplitPoint = splits[m_Attribute];
m_Prop = props[m_Attribute];
double[][] attSubsetDists = dists[m_Attribute];
double[] attTotalSubsetWeights = totalSubsetWeights[m_Attribute];
// 释放内存
vals = null;
dists = null;
props = null;
totalSubsetWeights = null;
splits = null;
// 得到subSet的有序index
int[][][][] subsetIndices =
new int[numAttVals][1][data.numAttributes()][0];
double[][][][] subsetWeights =
new double[numAttVals][1][data.numAttributes()][0];
splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
sortedIndices[0], weights[0], data);
// 释放内存
sortedIndices[0] = null;
weights[0] = null;
//释放内存
m_Successors = new Tree[numAttVals];
for (int i = 0; i < numAttVals; i++) {
m_Successors[i] = new Tree();//构建孩子节点
m_Successors[i].
buildTree(subsetIndices[i], subsetWeights[i],
data, attTotalSubsetWeights[i],
attSubsetDists[i], header, minNum,
minVariance, depth + 1, maxDepth);
// 还是释放内存
attSubsetDists[i] = null;
}
} else {
// 如果不存在2个有效的subset,就直接当叶子节点了
m_Attribute = -1;
sortedIndices[0] = null;
weights[0] = null;
}
// 构建attribute用于之后的分类过程(当然这是在没有prune和backfit情况下用的)
if (data.classAttribute().isNominal()) {
m_Distribution = new double[m_ClassProbs.length];
for (int i = 0; i < m_ClassProbs.length; i++) {
m_Distribution[i] = m_ClassProbs[i];
}
Utils.normalize(m_ClassProbs);
} else {
m_Distribution = new double[2];
m_Distribution[0] = priorVar;
m_Distribution[1] = totalWeight;
}
}Weka算法Classifier-trees-REPTree源码分析(一)
原文地址:http://blog.csdn.net/roger__wong/article/details/39453865