码迷,mamicode.com
首页 > 编程语言 > 详细

隐马尔可夫训练参数,BaumWelch算法,java实现【参考52nlp的博客算法原理实现】

时间:2016-01-05 22:44:08      阅读:1030      评论:0      收藏:0      [点我收藏+]

标签:

package jxutcm.edu.cn.hmm.model;

import jxutcm.edu.cn.hmm.bean.HMMHelper;
import jxutcm.edu.cn.util.TCMMath;

/**
 *  Baum-Welch算法也叫前向-后向算法:
 *  目的:
 *  1、在给定多个观测状态序列(多个观测序列,维数可不同)的条件下,训练和学习HMM模型的参数A,B,pi
 *  2、该算法得出的是一个局部最优解,较依赖于初始值
 * 
@author aool
 
*/
public class BaumWelch extends HMM{
    public int[][] O;
    public double threshold=0;
    public double MSP=0.001;
    public double MOP=0.001;
    public int Iteration=0;
    
    private BaumWelch(int[][] O, int N, int M){
        this.O=O;
        this.M=M;
        this.N=N;
    }
    
    public BaumWelch( int[][] O, int N, int M, double threshold){
        this(O, N, M);
        this.threshold=threshold;
    }
    
    public BaumWelch(int[][] O, int N, int M, int Iteration){
        this(O, N, M);
        this.Iteration=Iteration;
    }
    
    public BaumWelch( int[][] O, int N, int M, double threshold, double MSP, double MOP){
        this(O, N, M);
        this.threshold=threshold;
    }
    
    public BaumWelch( int[][] O, int N, int M,  int Iteration, double MSP, double MOP){
        this(O, N, M);
        this.Iteration=Iteration;
    }
    
    /**
     * 训练HMM模型的A,B,PI
     
*/
    public void buldModel() {
        int mTrain = O.length;

        this.logPI=HMMHelper.randomdiscrete(N);
        this.logA =new double[ N ][ ];
        this.logB =new double[ N ][ ];
        for (int i = 0; i < N; i++) {
            this.logA[ i ] = HMMHelper.randomdiscrete(N);
            this.logB[ i ] = HMMHelper.randomdiscrete(M);
        }
        
        double OldLogProb= Math.log(1.0E-10), NewLogProb=Math.log(1.0E-10);
        int T=0;
        int[] o;
        
        Forward fore;
        Backward back;
        
        double[] piSum=new double[ N ];
        double piSumSum= Double.NEGATIVE_INFINITY;
        
        double[][] gamma;
        double[][][] kexi;
        
        double[][] kexiSumA=new double[ N ][ N ];
        double[][]     gammaSumBB=new double[ N ][ M ];
        
        double[] gammaSumA=new double[ N ];
        double[] gammaSumB=new double[ N ];
        
        double gammaSumAi, gammaSumBi;
        
        for(int iteration=1; ; iteration++){
            piSumSum = Double.NEGATIVE_INFINITY;
            this.Init(piSum, kexiSumA, gammaSumA, gammaSumBB, gammaSumB);
            
            for(int sample = 0; sample< mTrain; sample++){
                o=O[ sample ];
                T=o.length;
                
                fore=new Forward(this, o);
                fore.CalculateForeMatrix();
                back=new Backward(this, o);
                back.CalculateBackMatrix();
                
                gamma=this.CalculateGamma(fore.alpha, back.beta);
                kexi=this.ComputeKeXi(fore.alpha, back.beta, o);
                for(int i=0; i< N; i++){
                    piSum[ i ] = TCMMath.logplus( piSum[ i ], gamma[ 0 ][ i ] );
                    piSumSum =TCMMath.logplus( piSumSum, gamma[ 0 ][ i ] );
                    
                    gammaSumAi=Double.NEGATIVE_INFINITY;
                    for(int t=0; t< T - 1 ; t++){//注意这里是T-1
                        gammaSumAi = TCMMath.logplus(gammaSumAi, gamma[ t ] [ i ]);
                    }
                    gammaSumA[ i ] =TCMMath.logplus(gammaSumA[ i ], gammaSumAi);//之前样本再累加上
                    
                    for(int j=0; j< N; j++){
                        for(int t=0; t< T - 1; t++){//注意这里是T-1
                            kexiSumA[ i ][ j ] =TCMMath.logplus(kexiSumA[ i ][ j ], kexi[ t ] [ i ] [ j ]);
                        }
                    }
                    
                    //【重新估计混淆矩阵的分子/分母】
                    gammaSumBi =  TCMMath.logplus(gammaSumAi, gamma[ T-1 ] [ i ]);
                    gammaSumB[ i ] = TCMMath.logplus(gammaSumB[ i ], gammaSumBi);//之前样本类加上
                    for(int k=0; k< M; k++){
                        for(int t=0; t< T; t++){//注意这里是T
                            if ( o[ t ] == k ){
                                gammaSumBB[ i ][ k ]= TCMMath.logplus(gammaSumBB[ i ][ k ], gamma[ t ] [ i ] );
                            }
                        }
                    }
                }
            }
            this.EstimateParameter(piSum, piSumSum, kexiSumA, gammaSumA, gammaSumBB, gammaSumB);
            
            for(int sample = 0; sample< mTrain; sample++){
                o=O[ sample ];
                fore=new Forward(this, o);
                fore.CalculateForeMatrix();
                NewLogProb =  fore.logProb();
            }
            
            this.CheckParameter(OldLogProb, NewLogProb, iteration);
            
            if( threshold !=0 && 
                    ( Math.abs(NewLogProb - OldLogProb) <= threshold || NewLogProb >=1)//浮点运算问题
                    ) break;//临界值约束
            if(Iteration !=0 && iteration >= Iteration) break;//迭代次数约束
            OldLogProb=NewLogProb;//保留上次新值
        }
    }
    
    private double[][] CalculateGamma(double[][] alpha, double[][] beta){
        int T=alpha.length;
        double[][] gamma=new double[ T ][ N ];
        double sum=Double.NEGATIVE_INFINITY;
        for(int t=0; t< T; t++){
            sum=Double.NEGATIVE_INFINITY;
            for(int i=0; i< N; i++){
                gamma[ t ][ i ] = alpha[ t ][ i ] + beta [ t ][ i ];
                sum = TCMMath.logplus(sum, gamma[ t ][ i ]);
            }
            for(int i=0; i< N; i++){
                gamma[ t ][ i ] = gamma[ t ][ i ] - sum;//归一化,保证各时刻的概率总和等于1 
            }
        }
        return gamma;
    }
    
    private double[][][] ComputeKeXi(double[][] alpha, double[][] beta, int[] O){
        int T=alpha.length;
        double[][][] kexi=new double[ T-1 ][ N ] [ N ];//注意这里是T-1,从i到j,最后时刻没有
        double sum=Double.NEGATIVE_INFINITY;
        forint t=0; t< T-1; t++){//最后时刻不用算
            sum=Double.NEGATIVE_INFINITY;
            forint i=0; i< N; i++){
                for(int j=0; j< N; j++){
                    kexi[ t ][ i ][ j ] = alpha[ t ][ i ] + logA[ i ][ j ] + logB[ j ][ O[t+1] ] + beta[ t+1 ][ j ];
                    sum = TCMMath.logplus( sum, kexi[ t ][ i ][ j ]);
                }
            }
            for(int i=0; i< N; i++){
                for(int j=0; j< N; j++){
                    kexi[ t ][ i ][ j ] = kexi[ t ][ i ][ j ] -sum;//归一化,保证各时刻的概率总和等于1 
                }
            }
        }
        return kexi;
    }
    
    private void Init(double[] piSum, double[][] kexiSumA, double[] gammaSumA, double[][] gammaSumBB, double[] gammaSumB){
        for(int i=0; i<N; i++){
            piSum[ i ] =Double.NEGATIVE_INFINITY;
            for(int j=0; j<N; j++){
                kexiSumA[ i ][ j ]= Double.NEGATIVE_INFINITY;
            }
            for(int j=0; j<M; j++){
                gammaSumBB[ i ][ j ]= Double.NEGATIVE_INFINITY;
            }
            gammaSumA[ i ]=Double.NEGATIVE_INFINITY;
            gammaSumB[ i ]= Double.NEGATIVE_INFINITY;
        }
    }
    
    private void EstimateParameter(double[] piSum, double piSumSum, double[][] kexiSumA, double[] gammaSumA, double[][] gammaSumBB, double[] gammaSumB){
        for(int i=0; i< N; i++){
            //初始概率向量
            
//pi[i] = MSP + (1 - MSP * N) * piSum[ i ] / piSumSum;//【修正训练计算错误和保证概率总和为1】
            logPI[ i ]= TCMMath.logplus( Math.log( MSP ),  Math.log( 1 - MSP * N ) +piSum[ i ] - piSumSum );//当前状态累加;
            
//转移矩阵
            for(int j=0; j< N; j++){
                //logA[ i ] [ j ]=MSP + (1 - MSP * N) * kexiSumA[ i ][ j ] / gammaSumA[ i ];//【修正训练计算错误和保证概率总和为1】
                logA[ i ][ j ] = TCMMath.logplus( Math.log( MSP ),  Math.log( 1 - MSP * N ) + kexiSumA[ i ][ j ] - gammaSumA[ i ]);
            }
            //发射矩阵
            for(int k=0; k< M; k++){
                //logB[ i ][ k ] = MOP + (1 - MOP * M) * gammaSumC /gammaSumB[ i ];//【修正训练计算错误和归一化保证概率总和为1】
                logB[ i ][ k ] = TCMMath.logplus( Math.log( MOP ),  Math.log( 1 - MOP * M ) + gammaSumBB[ i ][ k ] - gammaSumB[ i ]);
            }
        }
    }
    
    private void CheckParameter(double OldLogProb, double NewLogProb, int iteration){
        System.out.println("------------------------------------------------------------------------------------------------第"+iteration+"次迭代------------------------------------------------------------------------------------------------");
        double piSumSum=Double.NEGATIVE_INFINITY;
        System.out.println("初始矩阵横向之和:");
        for(int i=0; i< N; i++){
            piSumSum = TCMMath.logplus(piSumSum, logPI[ i ] );
        }
        System.out.print("和="+HMMHelper.fmtlog( piSumSum )+ "初始时刻的隐状态出现概率:" );
        for(int i=0; i< N; i++){
            System.out.print( HMMHelper.fmtlog( logPI[ i ]  ) );
        }
        System.out.println();
        
        System.out.println("转移矩阵横向之和:");
        for(int i=0; i< N; i++){
            piSumSum=Double.NEGATIVE_INFINITY;
            for(int j=0; j< N; j++){
                piSumSum = TCMMath.logplus(piSumSum, logA[ i ][ j ] );
            }
            System.out.print( "和="+HMMHelper.fmtlog( piSumSum ) + "【第"+i+"种隐状态】转移至【各隐状态】概率:" );
            for(int j=0; j< N; j++){
                System.out.print( HMMHelper.fmtlog( logA[ i ][ j ] ) );
            }
            System.out.println();
        }
        
        System.out.println("发射矩阵横向之和:");
        for(int i=0; i< N; i++){
            piSumSum=Double.NEGATIVE_INFINITY;
            for(int j=0; j< M; j++){
                piSumSum = TCMMath.logplus(piSumSum, logB[ i ][ j ] );
            }
            System.out.print(  "和="+HMMHelper.fmtlog( piSumSum ) +"【第"+i+"种隐状态】发射出【各显状态】概率:" );
            for(int j=0; j< M; j++){
                System.out.print( HMMHelper.fmtlog( logB[ i ][ j ] ) );
            }
            System.out.println();
        }
        
        System.out.println( "OldLogProb = " +HMMHelper.fmt( OldLogProb )  +"对应概率值:"+HMMHelper.fmtlog( OldLogProb ) 
        +"  NewLogProb =" +HMMHelper.fmt( NewLogProb )+"对应概率值:"+HMMHelper.fmtlog( NewLogProb ) 
        +"  |NewLogProb - OldLogProb| = "+Math.abs(NewLogProb - OldLogProb));
    }
}

隐马尔可夫训练参数,BaumWelch算法,java实现【参考52nlp的博客算法原理实现】

标签:

原文地址:http://www.cnblogs.com/whaozl/p/5103799.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!