码迷,mamicode.com
首页 > 其他好文 > 详细

Ranklib部分源码分析

时间:2016-06-16 15:02:17      阅读:724      评论:0      收藏:0      [点我收藏+]

标签:

Ranklib部分源码分析(LambdaMART+RandomForest)

声明

本文是Ranklib部分源码的分析,参考了RankLib源码分析——guoguo881218的专栏以及Learning to Rank——wowarsenal,在此对原博主表示感谢

关于Ranklib

The Lemur Project可以下载到Ranklib程序,Ranklib2.1和Ranklib2.3有源码可以下载,Ranklib2.4和Ranklib2.5只有jar文件可以下载。通过jad反编译后可以看到源码,整体结构差别不大。本文以Ranklib2.3为标准进行说明。

主框架(Evlauator.java)

程序主入口main函数

Ranklib程序主入口为ciir.umass.edu.eval.Evaluator类中main函数.
其中public static void main(String[] args)函数接收命令行传入参数。

  1. 首先初始化一些变量并根据传入参数给变量赋值:
        for(int i=0;i<args.length;i++)
        {
            if(args[i].compareTo("-train")==0)
                trainFile = args[++i]; //训练集
            else if(args[i].compareTo("-ranker")==0)
                rankerType = Integer.parseInt(args[++i]); //Rank类型
            ...
            else if(args[i].compareTo("-metric2t")==0)
                trainMetric = args[++i]; //训练集Metric
            else if(args[i].compareTo("-metric2T")==0)
                testMetric = args[++i]; //测试集Metric
            ...
            else if(args[i].compareTo("-validate")==0)
                validationFile = args[++i]; //验证集
            else if(args[i].compareTo("-test")==0)
            {
                testFile = args[++i];
                testFiles.add(testFile);
            } //测试集
            ...
            else if(args[i].compareTo("-save")==0)
                Evaluator.modelFile = args[++i]; //模型保存位置
            ...
            else if(args[i].compareTo("-load")==0)
            {
                savedModelFile = args[++i];
                savedModelFiles.add(args[i]);
            } //导入模型
            ...
            else if(args[i].compareTo("-rank")==0)
                rankFile = args[++i]; //待排序数据
            ... ... ...
            //MART / LambdaMART / Random forest
            else if(args[i].compareTo("-tree")==0)
            {
                LambdaMART.nTrees = Integer.parseInt(args[++i]);
                RFRanker.nTrees = Integer.parseInt(args[i]);
            } //树的棵树
            else if(args[i].compareTo("-leaf")==0)
            {
                LambdaMART.nTreeLeaves = Integer.parseInt(args[++i]);
                RFRanker.nTreeLeaves = Integer.parseInt(args[i]);
            } //每棵树叶子结点数
            else if(args[i].compareTo("-shrinkage")==0)
            {
                LambdaMART.learningRate = Float.parseFloat(args[++i]);
                RFRanker.learningRate = Float.parseFloat(args[i]);
            } //收缩系数
             ...
            //Random forest
            else if(args[i].compareTo("-bag")==0)
                RFRanker.nBag = Integer.parseInt(args[++i]); //bags数目
  1. 根据参数变量进行训练
        if(nThread == -1)
            nThread = Runtime.getRuntime().availableProcessors();
        MyThreadPool.init(nThread); //线程池初始化
        ...
        Evaluator e = new Evaluator(rType2[rankerType], trainMetric, testMetric); //根据Rank类型以及训练集、测试集上的评价函数生成Evaluator对象
        ... ...
            RankerFactory rf = new RankerFactory();
            rf.createRanker(rType2[rankerType]).printParameters();//根据参数创建Rank对象
                ...
                e.evaluate() //多个实现,针对不同情况进行evaluate
                ...
                        if(testFiles.size() > 1)
                            e.test(savedModelFiles, testFiles, prpFile);
                        else
                            e.test(savedModelFiles, testFile, prpFile) //利用已有模型在测试集上进行预测s

训练器evaluate函数

  1. Evaluator初始化
    public Evaluator(RANKER_TYPE rType, String trainMetric, String testMetric)
    {
        this.type = rType; //Ranke类型
        trainScorer = mFact.createScorer(trainMetric); //训练集上得分
        testScorer = mFact.createScorer(testMetric); //测试集上得分
        ...
    }
  1. evaluate函数调用
//根据训练集验证集和测试集进行训练的evaluate()函数调用
    public void evaluate(String trainFile, String validationFile, String testFile, String featureDefFile)
    {
        List<RankList> train = readInput(trainFile);
        ...
            test = readInput(testFile);//读取训练、验证、测试文件
            ... ...

        RankerTrainer trainer = new RankerTrainer();
        Ranker ranker = trainer.train(type, train, validation, features, trainScorer);//利用训练集和验证集训练模型
        ... ...
            double rankScore = evaluate(ranker, test); //计算测试集得分
            ... ...
            ranker.save(modelFile); //保存模型

预测test函数

test函数调用

//根据已有模型以及测试集进行训练的test()函数调用
    public void test(String modelFile, String testFile, String prpFile)
    {
        Ranker ranker = rFact.loadRanker(modelFile); //导入模型

        List<RankList> test = readInput(testFile); //读取测试数据

            RankList l = ranker.rank(test.get(i)); //排序评分
            double score = testScorer.score(l); //取得评分

数据结构基础类(DataPoint、RankList)

数据格式

数据格式与SVM-Rank、libSVM、LETOR格式均相同。格式如下

<line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info>
<target> .=. <positive integer> //正整数型评分
<qid> .=. <positive integer> //正整数型查询
<feature> .=. <positive integer> //正整数型特征序号
<value> .=. <float> //浮点型特征值
<info> .=. <string> //注释

DataPoint.java

ciir.umass.edu.learning.DataPoint

实现了需要评分的对象的数据结构。每个对象是一个待评分文档。

RankList.java

ciir.umass.edu.learning.RankList
实现了需要评分的对象组成的列表的数据结构。每个对象是一个包含对应于同一查询的不同文献的集合。

Rank基础类(RANKER_TYPE、RankerFactory、Ranker)

RANKER_TYPE.java

ciir.umass.edu.learning.RANKER_TYPE

枚举类型,包含各种Rank类型

RankerFactory.java

ciir.umass.edu.learning.RankerFactory

实现了RankerFactory,所有Rank方法都需要在此类中注册。

public Ranker createRanker(RANKER_TYPE type)创建某种类型的Rank对象。
public Ranker loadRanker(String modelFile)导入已有模型。

Ranker.java

ciir.umass.edu.learning.Ranker

Ranker类实现了一般的Rank接口,所有Rank类型都需要集成Ranker。

通用方法有:

public void setTrainingSet(List<RankList> samples) //设置训练集
public void setValidationSet(List<RankList> samples) //设置验证集
public double getScoreOnTrainingData() //训练集得分
 ...
public void save(String modelFile) //保存模型
public RankList rank(RankList rl) //给出评分后的排序
public List<RankList> rank(List<RankList> l) //给出评分后的排序

必须在子类中实现的方法有:

public void init() //初始化
public void learn() //学习
public double eval(DataPoint p) //评价
public String toString() //模型转为字符串
public void load(String fn) //导入模型

RankerTrainer

ciir.umass.edu.learning.RankerTrainer
实现了对模型进行训练的函数:
public Ranker train(RANKER_TYPE type, List<RankList> train, List<RankList> validation, int[] features, MetricScorer scorer)

LambdaMART基础类(FeatureHistorgram、RegressionTree、Split、Ensemble)

FeatureHistorgram.java

ciir.umass.edu.learning.tree.FeatureHistogram

特征直方图类,对RankList对象进行特征的直方图统计,选择每次split时的最优feature和最优划分点。

  • construct方法:
public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
    {
    ...
        sum = new double[features.length][];
        count = new int[features.length][];
    ...
    }
 * sum[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的label(该算法里为lambda)之和。
 * count[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的总数。
  • update方法:
protected void update(double[] labels)

用新的label更新sum[i][j]。

  • findBestSplit方法思路:
    • 选取feature作为划分的备选(可全选,可选部分)。
    • 选取最优feature和最优划分点
      • 计算每个feature的每个划分点,doubleS=sumLeft?sumLeft/countLeft+sumRight?sumRight/countRight,最小的S即为最优feature和最优划分点s(该s是feature的具体值)。
//findBestSplit方法:
    protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
    {
    ...
                int countLeft = count[i][t]; //countLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的总数
                int countRight = totalCount - countLeft; //countRight是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的总数
                ...
                double sumLeft = sum[i][t]; //sumLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的lambad之和
                double sumRight = sumResponse - sumLeft; //sumRight 是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的lambad之和
                ...
    }

构建树的时候,输入为(xi,lambdai),其中lambdai代表着对xi的评分(影响排序结果,是增大还是减少)。最好的划分点,就是把增大的划分到一起(全部为正值,相加结果为sumA),减少的划分到一起(全部为负值,相加结果为sumb).此时的sumA*sumA/countA+sumB*sumB/countB为最大。
因此,这里的S的含义为:该划分点尽量把正值和负值区分开。正值表示:后续评分调大;负值表示:后续评分调小;lambdai就是si从newTree中获取的值,表示si的值如何调整才能满足C最大(类似梯度)。C表示的是排序后的NDCG,求其最大值。

RegressionTree.java

ciir.umass.edu.learning.tree.RegressionTree
回归树实现。

    protected int nodes = 10; //控制分裂的次数,这个次数是按照节点来算的,而不是按照层数来计算的
    protected int minLeafSupport = 1;//控制分裂的次数,如果某个节点所包含的训练数据小于2*minLeafSupport ,则该节点不再分裂。
    ...
    protected DataPoint[] trainingSamples = null; //训练的数据
    protected double[] trainingLabels = null; //这里的lables就是y值,在lambdaMART里为lambda值
    ...
public void fit() //根据输入的数据以及lable值,生成回归树

Ensemble.java

ciir.umass.edu.learning.tree.Ensemble

LambdaMART流程

ciir.umass.edu.learning.tree.LambdaMART

  • LambdaMART初始化函数init()
    a. 设置训练数据,为每个训练数据i设置初值(0),为每个训练数据的y设置初值(0),为每个训练数据的w设置初值(0)
    b. 按照每个feature的大小重新排训练数据,为方便后面的计算。
    c. 每个feature都设置一批值以供后续做回归树split时的切分点。
    d. 初始化一个回归树(该树未进行分裂)
    public void init()
    {
        ...
        //将样本根绝特征排序,方便做树的分列时快速找出最优分列点
        sortedIdx = new int[features.length][];
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            sortSamplesByFeature(0, features.length-1);
        ...
        //创建存放候选阈值(分列点)的表
        thresholds = new float[features.length][];
        for(int f=0;f<features.length;f++)
        {...}

        //计算特征直方图,加速寻找分列点
        hist = new FeatureHistogram();
        hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
        ...
    }
  • LambdaMART训练函数learn()
    生成指定数目的tree,以下为生成一个树的流程。
    1. 清空以前生成的pseudoResponses(yi),weights(wi)
    2. computePseudoResponses函数中生成新的pseudoResponses(yi),weights(wi)
    3. 用新生成的pseudoResponses(yi)来更新回归树
    4. 生成一棵新的回归树,并保存结果
    5. 求得γlk。
    6. 重新计算modelScores,即每个训练数据的评分
    7. 通过early stop的方式校验数据和退出
    public void learn()
    {
        //开始梯度提升训练过程
        for(int m=0; m<nTrees; m++)
        {
            PRINT(new int[]{7}, new String[]{(m+1)+""});

            //计算lambdas (pseudo responses)
            computePseudoResponses();

            //根据新的label更新特征直方图
            hist.update(pseudoResponses);

            //回归决策树
            RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
            rt.fit();

            //将新生成的树加入模型
            ensemble.add(rt, learningRate);

            //更新树的输出(同时计算利用Newton-Raphson方法计算gamma)
            updateTreeOutput(rt);

            //更新所有训练样本的模型输出
            List<Split> leaves = rt.leaves();
            for(int i=0;i<leaves.size();i++)
            {
                Split s = leaves.get(i);
                int[] idx = s.getSamples();
                for(int j=0;j<idx.length;j++)
                    modelScores[idx[j]] += learningRate * s.getOutput();
            }

            //评价模型
            scoreOnTrainingData = computeModelScoreOnTraining();

            //检验是否应该提前结束
            if(m - bestModelOnValidation > nRoundToStopEarly)
                break;
            ...
        //回滚到在验证集上最优的模型
            ensemble.remove(ensemble.treeCount()-1);
        ...
    }

Ranklib部分源码分析

标签:

原文地址:http://blog.csdn.net/clheang/article/details/51685265

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!