标签:
本文是Ranklib部分源码的分析,参考了RankLib源码分析——guoguo881218的专栏以及Learning to Rank——wowarsenal,在此对原博主表示感谢
在The Lemur Project可以下载到Ranklib程序,Ranklib2.1和Ranklib2.3有源码可以下载,Ranklib2.4和Ranklib2.5只有jar文件可以下载。通过jad反编译后可以看到源码,整体结构差别不大。本文以Ranklib2.3为标准进行说明。
Ranklib程序主入口为ciir.umass.edu.eval.Evaluator
类中main函数.
其中public static void main(String[] args)
函数接收命令行传入参数。
for(int i=0;i<args.length;i++)
{
if(args[i].compareTo("-train")==0)
trainFile = args[++i]; //训练集
else if(args[i].compareTo("-ranker")==0)
rankerType = Integer.parseInt(args[++i]); //Rank类型
...
else if(args[i].compareTo("-metric2t")==0)
trainMetric = args[++i]; //训练集Metric
else if(args[i].compareTo("-metric2T")==0)
testMetric = args[++i]; //测试集Metric
...
else if(args[i].compareTo("-validate")==0)
validationFile = args[++i]; //验证集
else if(args[i].compareTo("-test")==0)
{
testFile = args[++i];
testFiles.add(testFile);
} //测试集
...
else if(args[i].compareTo("-save")==0)
Evaluator.modelFile = args[++i]; //模型保存位置
...
else if(args[i].compareTo("-load")==0)
{
savedModelFile = args[++i];
savedModelFiles.add(args[i]);
} //导入模型
...
else if(args[i].compareTo("-rank")==0)
rankFile = args[++i]; //待排序数据
... ... ...
//MART / LambdaMART / Random forest
else if(args[i].compareTo("-tree")==0)
{
LambdaMART.nTrees = Integer.parseInt(args[++i]);
RFRanker.nTrees = Integer.parseInt(args[i]);
} //树的棵树
else if(args[i].compareTo("-leaf")==0)
{
LambdaMART.nTreeLeaves = Integer.parseInt(args[++i]);
RFRanker.nTreeLeaves = Integer.parseInt(args[i]);
} //每棵树叶子结点数
else if(args[i].compareTo("-shrinkage")==0)
{
LambdaMART.learningRate = Float.parseFloat(args[++i]);
RFRanker.learningRate = Float.parseFloat(args[i]);
} //收缩系数
...
//Random forest
else if(args[i].compareTo("-bag")==0)
RFRanker.nBag = Integer.parseInt(args[++i]); //bags数目
if(nThread == -1)
nThread = Runtime.getRuntime().availableProcessors();
MyThreadPool.init(nThread); //线程池初始化
...
Evaluator e = new Evaluator(rType2[rankerType], trainMetric, testMetric); //根据Rank类型以及训练集、测试集上的评价函数生成Evaluator对象
... ...
RankerFactory rf = new RankerFactory();
rf.createRanker(rType2[rankerType]).printParameters();//根据参数创建Rank对象
...
e.evaluate() //多个实现,针对不同情况进行evaluate
...
if(testFiles.size() > 1)
e.test(savedModelFiles, testFiles, prpFile);
else
e.test(savedModelFiles, testFile, prpFile) //利用已有模型在测试集上进行预测s
public Evaluator(RANKER_TYPE rType, String trainMetric, String testMetric)
{
this.type = rType; //Ranke类型
trainScorer = mFact.createScorer(trainMetric); //训练集上得分
testScorer = mFact.createScorer(testMetric); //测试集上得分
...
}
//根据训练集验证集和测试集进行训练的evaluate()函数调用
public void evaluate(String trainFile, String validationFile, String testFile, String featureDefFile)
{
List<RankList> train = readInput(trainFile);
...
test = readInput(testFile);//读取训练、验证、测试文件
... ...
RankerTrainer trainer = new RankerTrainer();
Ranker ranker = trainer.train(type, train, validation, features, trainScorer);//利用训练集和验证集训练模型
... ...
double rankScore = evaluate(ranker, test); //计算测试集得分
... ...
ranker.save(modelFile); //保存模型
test函数调用
//根据已有模型以及测试集进行训练的test()函数调用
public void test(String modelFile, String testFile, String prpFile)
{
Ranker ranker = rFact.loadRanker(modelFile); //导入模型
List<RankList> test = readInput(testFile); //读取测试数据
RankList l = ranker.rank(test.get(i)); //排序评分
double score = testScorer.score(l); //取得评分
数据格式与SVM-Rank、libSVM、LETOR格式均相同。格式如下
<line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info>
<target> .=. <positive integer> //正整数型评分
<qid> .=. <positive integer> //正整数型查询
<feature> .=. <positive integer> //正整数型特征序号
<value> .=. <float> //浮点型特征值
<info> .=. <string> //注释
ciir.umass.edu.learning.DataPoint
实现了需要评分的对象的数据结构。每个对象是一个待评分文档。
ciir.umass.edu.learning.RankList
实现了需要评分的对象组成的列表的数据结构。每个对象是一个包含对应于同一查询的不同文献的集合。
ciir.umass.edu.learning.RANKER_TYPE
枚举类型,包含各种Rank类型
ciir.umass.edu.learning.RankerFactory
实现了RankerFactory,所有Rank方法都需要在此类中注册。
public Ranker createRanker(RANKER_TYPE type)
创建某种类型的Rank对象。
public Ranker loadRanker(String modelFile)
导入已有模型。
ciir.umass.edu.learning.Ranker
Ranker类实现了一般的Rank接口,所有Rank类型都需要集成Ranker。
通用方法有:
public void setTrainingSet(List<RankList> samples) //设置训练集
public void setValidationSet(List<RankList> samples) //设置验证集
public double getScoreOnTrainingData() //训练集得分
...
public void save(String modelFile) //保存模型
public RankList rank(RankList rl) //给出评分后的排序
public List<RankList> rank(List<RankList> l) //给出评分后的排序
必须在子类中实现的方法有:
public void init() //初始化
public void learn() //学习
public double eval(DataPoint p) //评价
public String toString() //模型转为字符串
public void load(String fn) //导入模型
ciir.umass.edu.learning.RankerTrainer
实现了对模型进行训练的函数:
public Ranker train(RANKER_TYPE type, List<RankList> train, List<RankList> validation, int[] features, MetricScorer scorer)
ciir.umass.edu.learning.tree.FeatureHistogram
特征直方图类,对RankList对象进行特征的直方图统计,选择每次split时的最优feature和最优划分点。
public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
{
...
sum = new double[features.length][];
count = new int[features.length][];
...
}
* sum[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的label(该算法里为lambda)之和。
* count[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的总数。
protected void update(double[] labels)
用新的label更新sum[i][j]。
//findBestSplit方法:
protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
{
...
int countLeft = count[i][t]; //countLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的总数
int countRight = totalCount - countLeft; //countRight是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的总数
...
double sumLeft = sum[i][t]; //sumLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的lambad之和
double sumRight = sumResponse - sumLeft; //sumRight 是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的lambad之和
...
}
构建树的时候,输入为(xi,lambdai),其中lambdai代表着对xi的评分(影响排序结果,是增大还是减少)。最好的划分点,就是把增大的划分到一起(全部为正值,相加结果为sumA),减少的划分到一起(全部为负值,相加结果为sumb).此时的sumA*sumA/countA+sumB*sumB/countB为最大。
因此,这里的S的含义为:该划分点尽量把正值和负值区分开。正值表示:后续评分调大;负值表示:后续评分调小;lambdai就是si从newTree中获取的值,表示si的值如何调整才能满足C最大(类似梯度)。C表示的是排序后的NDCG,求其最大值。
ciir.umass.edu.learning.tree.RegressionTree
回归树实现。
protected int nodes = 10; //控制分裂的次数,这个次数是按照节点来算的,而不是按照层数来计算的
protected int minLeafSupport = 1;//控制分裂的次数,如果某个节点所包含的训练数据小于2*minLeafSupport ,则该节点不再分裂。
...
protected DataPoint[] trainingSamples = null; //训练的数据
protected double[] trainingLabels = null; //这里的lables就是y值,在lambdaMART里为lambda值
...
public void fit() //根据输入的数据以及lable值,生成回归树
ciir.umass.edu.learning.tree.Ensemble
ciir.umass.edu.learning.tree.LambdaMART
init()
public void init()
{
...
//将样本根绝特征排序,方便做树的分列时快速找出最优分列点
sortedIdx = new int[features.length][];
MyThreadPool p = MyThreadPool.getInstance();
if(p.size() == 1)//single-thread
sortSamplesByFeature(0, features.length-1);
...
//创建存放候选阈值(分列点)的表
thresholds = new float[features.length][];
for(int f=0;f<features.length;f++)
{...}
//计算特征直方图,加速寻找分列点
hist = new FeatureHistogram();
hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
...
}
learn()
public void learn()
{
//开始梯度提升训练过程
for(int m=0; m<nTrees; m++)
{
PRINT(new int[]{7}, new String[]{(m+1)+""});
//计算lambdas (pseudo responses)
computePseudoResponses();
//根据新的label更新特征直方图
hist.update(pseudoResponses);
//回归决策树
RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
rt.fit();
//将新生成的树加入模型
ensemble.add(rt, learningRate);
//更新树的输出(同时计算利用Newton-Raphson方法计算gamma)
updateTreeOutput(rt);
//更新所有训练样本的模型输出
List<Split> leaves = rt.leaves();
for(int i=0;i<leaves.size();i++)
{
Split s = leaves.get(i);
int[] idx = s.getSamples();
for(int j=0;j<idx.length;j++)
modelScores[idx[j]] += learningRate * s.getOutput();
}
//评价模型
scoreOnTrainingData = computeModelScoreOnTraining();
//检验是否应该提前结束
if(m - bestModelOnValidation > nRoundToStopEarly)
break;
...
//回滚到在验证集上最优的模型
ensemble.remove(ensemble.treeCount()-1);
...
}
标签:
原文地址:http://blog.csdn.net/clheang/article/details/51685265