package jxutcm.edu.cn.hmm.model;
import jxutcm.edu.cn.hmm.bean.HMMHelper;
import jxutcm.edu.cn.util.TCMMath;
/**
 * Baum-Welch algorithm, also called the forward-backward algorithm.
 * <p>
 * Purpose:
 * <ol>
 * <li>Given several observation sequences (the sequences may have different lengths), train the
 * HMM parameters A (transition matrix), B (emission matrix) and pi (initial distribution).</li>
 * <li>The result is only a local optimum and depends strongly on the initial values, which
 * {@link #buldModel()} obtains from {@code HMMHelper.randomdiscrete}.</li>
 * </ol>
 * All probabilities are kept in natural-log space ({@code logPI}, {@code logA}, {@code logB}) and
 * added with {@code TCMMath.logplus}.
 *
 * @author aool
 */
public class BaumWelch extends HMM {
    /** Training data: one observation sequence per row; every symbol must lie in [0, M). */
    public int[][] O;
    /**
     * Convergence threshold on |NewLogProb - OldLogProb| between two iterations, where the
     * log-probability is the total over all training sequences; 0 disables this criterion.
     */
    public double threshold = 0;
    /** Probability floor mixed into every re-estimated initial and transition probability. */
    public double MSP = 0.001;
    /** Probability floor mixed into every re-estimated emission probability. */
    public double MOP = 0.001;
    /** Maximum number of EM iterations; 0 disables this criterion. */
    public int Iteration = 0;

    /**
     * Stores the training data and the model dimensions.
     *
     * @param O training observation sequences
     * @param N number of hidden states
     * @param M number of distinct observation symbols
     */
    private BaumWelch(int[][] O, int N, int M) {
        this.O = O;
        this.M = M;
        this.N = N;
    }

    /**
     * Creates a trainer that stops once the log-likelihood change is at most {@code threshold}.
     */
    public BaumWelch(int[][] O, int N, int M, double threshold) {
        this(O, N, M);
        this.threshold = threshold;
    }

    /** Creates a trainer that performs at most {@code Iteration} EM iterations. */
    public BaumWelch(int[][] O, int N, int M, int Iteration) {
        this(O, N, M);
        this.Iteration = Iteration;
    }

    /**
     * Creates a threshold-terminated trainer with custom probability floors.
     *
     * @param MSP floor for initial/transition probabilities; must satisfy 0 <= MSP * N <= 1
     * @param MOP floor for emission probabilities; must satisfy 0 <= MOP * M <= 1
     */
    public BaumWelch(int[][] O, int N, int M, double threshold, double MSP, double MOP) {
        this(O, N, M);
        this.threshold = threshold;
        // Previously these arguments were silently ignored and the defaults were always used.
        this.MSP = MSP;
        this.MOP = MOP;
    }

    /**
     * Creates an iteration-limited trainer with custom probability floors.
     *
     * @param MSP floor for initial/transition probabilities; must satisfy 0 <= MSP * N <= 1
     * @param MOP floor for emission probabilities; must satisfy 0 <= MOP * M <= 1
     */
    public BaumWelch(int[][] O, int N, int M, int Iteration, double MSP, double MOP) {
        this(O, N, M);
        this.Iteration = Iteration;
        // Previously these arguments were silently ignored and the defaults were always used.
        this.MSP = MSP;
        this.MOP = MOP;
    }

    /**
     * Trains logPI, logA and logB from the observation sequences in {@link #O}.
     * <p>
     * The parameters are randomly initialised and EM iterations are run. Training stops when the
     * change in total log-likelihood is at most {@link #threshold} (if non-zero), or after
     * {@link #Iteration} iterations (if non-zero).
     *
     * @throws IllegalArgumentException if {@link #O} is null/empty, contains a null/empty sequence
     *     or a symbol outside [0, M)
     * @throws IllegalStateException if no usable stopping criterion is configured, or if the
     *     probability floors are out of range
     */
    public void buldModel() {
        validateTrainingSetup();
        // Random starting point; Baum-Welch only reaches a local optimum, so results depend on it.
        this.logPI = HMMHelper.randomdiscrete(N);
        this.logA = new double[N][];
        this.logB = new double[N][];
        for (int i = 0; i < N; i++) {
            this.logA[i] = HMMHelper.randomdiscrete(N);
            this.logB[i] = HMMHelper.randomdiscrete(M);
        }
        double oldLogProb = Math.log(1.0E-10);
        // Expected-count accumulators (log space), reset at the start of every iteration.
        double[] piSum = new double[N];
        double[][] kexiSumA = new double[N][N];
        double[][] gammaSumBB = new double[N][M];
        double[] gammaSumA = new double[N];
        double[] gammaSumB = new double[N];
        for (int iteration = 1; ; iteration++) {
            this.Init(piSum, kexiSumA, gammaSumA, gammaSumBB, gammaSumB);
            // E-step: accumulate expected counts over every training sequence.
            double piSumSum = Double.NEGATIVE_INFINITY;
            for (int[] o : O) {
                piSumSum = accumulateSample(o, piSum, piSumSum, kexiSumA, gammaSumA, gammaSumBB, gammaSumB);
            }
            // M-step.
            this.EstimateParameter(piSum, piSumSum, kexiSumA, gammaSumA, gammaSumBB, gammaSumB);
            double newLogProb = totalLogProb();
            this.CheckParameter(oldLogProb, newLogProb, iteration);
            // Threshold criterion; a log-probability >= 1 cannot be genuine and only arises from
            // floating-point breakdown, so training stops in that case as well.
            if (threshold != 0
                    && (Math.abs(newLogProb - oldLogProb) <= threshold || newLogProb >= 1)) {
                break;
            }
            // Iteration-count criterion.
            if (Iteration != 0 && iteration >= Iteration) {
                break;
            }
            oldLogProb = newLogProb;
        }
    }

    /**
     * Verifies the training data, the stopping criteria and the probability floors before EM runs.
     */
    private void validateTrainingSetup() {
        if (O == null || O.length == 0) {
            throw new IllegalArgumentException("At least one observation sequence is required");
        }
        for (int s = 0; s < O.length; s++) {
            int[] seq = O[s];
            if (seq == null || seq.length == 0) {
                throw new IllegalArgumentException("Observation sequence " + s + " is empty");
            }
            for (int t = 0; t < seq.length; t++) {
                if (seq[t] < 0 || seq[t] >= M) {
                    throw new IllegalArgumentException("Observation sequence " + s + " has symbol "
                            + seq[t] + " at t=" + t + " outside [0, " + M + ")");
                }
            }
        }
        // With neither criterion active the EM loop could run forever.
        if (threshold <= 0 && Iteration == 0) {
            throw new IllegalStateException(
                    "Set a positive threshold or a non-zero Iteration limit; otherwise training may never stop");
        }
        // log(1 - MSP * N) and log(1 - MOP * M) must not be taken of a negative number.
        if (MSP < 0 || MSP * N > 1) {
            throw new IllegalStateException("MSP must satisfy 0 <= MSP * N <= 1, got MSP=" + MSP);
        }
        if (MOP < 0 || MOP * M > 1) {
            throw new IllegalStateException("MOP must satisfy 0 <= MOP * M <= 1, got MOP=" + MOP);
        }
    }

    /**
     * Runs forward/backward on one sequence and adds its expected counts to the accumulators.
     *
     * @return the updated log-sum of the initial-state posteriors ({@code piSumSum})
     */
    private double accumulateSample(int[] o, double[] piSum, double piSumSum, double[][] kexiSumA,
            double[] gammaSumA, double[][] gammaSumBB, double[] gammaSumB) {
        int T = o.length;
        Forward fore = new Forward(this, o);
        fore.CalculateForeMatrix();
        Backward back = new Backward(this, o);
        back.CalculateBackMatrix();
        double[][] gamma = this.CalculateGamma(fore.alpha, back.beta);
        double[][][] kexi = this.ComputeKeXi(fore.alpha, back.beta, o);
        for (int i = 0; i < N; i++) {
            // Initial-distribution numerator/denominator.
            piSum[i] = TCMMath.logplus(piSum[i], gamma[0][i]);
            piSumSum = TCMMath.logplus(piSumSum, gamma[0][i]);
            // Transition denominator: occupancy of state i for t = 0..T-2 (note: T-1 steps).
            double gammaSumAi = Double.NEGATIVE_INFINITY;
            for (int t = 0; t < T - 1; t++) {
                gammaSumAi = TCMMath.logplus(gammaSumAi, gamma[t][i]);
            }
            gammaSumA[i] = TCMMath.logplus(gammaSumA[i], gammaSumAi); // add to previous samples
            // Transition numerator: expected i -> j transitions for t = 0..T-2.
            for (int j = 0; j < N; j++) {
                for (int t = 0; t < T - 1; t++) {
                    kexiSumA[i][j] = TCMMath.logplus(kexiSumA[i][j], kexi[t][i][j]);
                }
            }
            // Emission denominator: occupancy of state i over all t = 0..T-1 (adds the last step).
            double gammaSumBi = TCMMath.logplus(gammaSumAi, gamma[T - 1][i]);
            gammaSumB[i] = TCMMath.logplus(gammaSumB[i], gammaSumBi); // add to previous samples
            // Emission numerator: occupancy of state i at each step, credited to the observed symbol.
            // Symbols are validated to lie in [0, M), so o[t] indexes the column directly.
            for (int t = 0; t < T; t++) {
                gammaSumBB[i][o[t]] = TCMMath.logplus(gammaSumBB[i][o[t]], gamma[t][i]);
            }
        }
        return piSumSum;
    }

    /**
     * Computes the total log-likelihood of all training sequences under the current parameters.
     * Sequences are treated as independent, so their log-probabilities are summed (previously only
     * the last sequence's value was kept).
     * NOTE(review): assumes {@code Forward.logProb()} returns log P(o | model) after
     * {@code CalculateForeMatrix()} — confirm.
     */
    private double totalLogProb() {
        double total = 0;
        for (int[] o : O) {
            Forward fore = new Forward(this, o);
            fore.CalculateForeMatrix();
            total += fore.logProb();
        }
        return total;
    }

    /**
     * Computes the log-space state posteriors gamma[t][i] = alpha[t][i] + beta[t][i], normalised
     * so that the probabilities at each time step sum to 1.
     *
     * @param alpha log forward variables, one row per time step
     * @param beta log backward variables, same shape as {@code alpha}
     * @return a T x N matrix of log posteriors
     */
    private double[][] CalculateGamma(double[][] alpha, double[][] beta) {
        int T = alpha.length;
        double[][] gamma = new double[T][N];
        for (int t = 0; t < T; t++) {
            double sum = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < N; i++) {
                gamma[t][i] = alpha[t][i] + beta[t][i];
                sum = TCMMath.logplus(sum, gamma[t][i]);
            }
            for (int i = 0; i < N; i++) {
                gamma[t][i] = gamma[t][i] - sum; // normalise: probabilities at time t sum to 1
            }
        }
        return gamma;
    }

    /**
     * Computes the log-space transition posteriors xi[t][i][j] for t = 0..T-2 (the last time step
     * has no outgoing transition), normalised so that each time step sums to 1.
     *
     * @param alpha log forward variables
     * @param beta log backward variables
     * @param o the observation sequence that alpha and beta were computed for
     * @return a (T-1) x N x N array of log posteriors
     */
    private double[][][] ComputeKeXi(double[][] alpha, double[][] beta, int[] o) {
        int T = alpha.length;
        double[][][] kexi = new double[T - 1][N][N];
        for (int t = 0; t < T - 1; t++) {
            double sum = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    kexi[t][i][j] = alpha[t][i] + logA[i][j] + logB[j][o[t + 1]] + beta[t + 1][j];
                    sum = TCMMath.logplus(sum, kexi[t][i][j]);
                }
            }
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    kexi[t][i][j] = kexi[t][i][j] - sum; // normalise: probabilities at time t sum to 1
                }
            }
        }
        return kexi;
    }

    /** Resets every log-space accumulator to log(0) = negative infinity. */
    private void Init(double[] piSum, double[][] kexiSumA, double[] gammaSumA, double[][] gammaSumBB, double[] gammaSumB) {
        for (int i = 0; i < N; i++) {
            piSum[i] = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < N; j++) {
                kexiSumA[i][j] = Double.NEGATIVE_INFINITY;
            }
            for (int j = 0; j < M; j++) {
                gammaSumBB[i][j] = Double.NEGATIVE_INFINITY;
            }
            gammaSumA[i] = Double.NEGATIVE_INFINITY;
            gammaSumB[i] = Double.NEGATIVE_INFINITY;
        }
    }

    /**
     * M-step: re-estimates logPI, logA and logB from the accumulated expected counts.
     * <p>
     * The floors MSP (initial/transition) and MOP (emission) are mixed in. Every probability is
     * therefore at least the floor, while each distribution still sums to 1.
     */
    private void EstimateParameter(double[] piSum, double piSumSum, double[][] kexiSumA, double[] gammaSumA, double[][] gammaSumBB, double[] gammaSumB) {
        double logMSP = Math.log(MSP);
        double logMSPRest = Math.log(1 - MSP * N);
        double logMOP = Math.log(MOP);
        double logMOPRest = Math.log(1 - MOP * M);
        for (int i = 0; i < N; i++) {
            // Initial distribution, linear-space form: pi[i] = MSP + (1 - MSP * N) * piSum[i] / piSumSum
            logPI[i] = TCMMath.logplus(logMSP, logMSPRest + piSum[i] - piSumSum);
            // Transition matrix: a[i][j] = MSP + (1 - MSP * N) * kexiSumA[i][j] / gammaSumA[i]
            for (int j = 0; j < N; j++) {
                logA[i][j] = TCMMath.logplus(logMSP, logMSPRest + kexiSumA[i][j] - gammaSumA[i]);
            }
            // Emission matrix: b[i][k] = MOP + (1 - MOP * M) * gammaSumBB[i][k] / gammaSumB[i]
            for (int k = 0; k < M; k++) {
                logB[i][k] = TCMMath.logplus(logMOP, logMOPRest + gammaSumBB[i][k] - gammaSumB[i]);
            }
        }
    }

    /**
     * Prints the current parameters and the log-likelihood change for one iteration to stdout.
     * Each row's sum is shown so that normalisation can be checked by eye.
     */
    private void CheckParameter(double OldLogProb, double NewLogProb, int iteration) {
        System.out.println("------------------------------------------------------------------------------------------------第"+iteration+"次迭代------------------------------------------------------------------------------------------------");
        double rowSum = Double.NEGATIVE_INFINITY;
        System.out.println("初始矩阵横向之和:");
        for (int i = 0; i < N; i++) {
            rowSum = TCMMath.logplus(rowSum, logPI[i]);
        }
        System.out.print("和="+HMMHelper.fmtlog( rowSum )+ "初始时刻的隐状态出现概率:" );
        for (int i = 0; i < N; i++) {
            System.out.print(HMMHelper.fmtlog(logPI[i]));
        }
        System.out.println();
        System.out.println("转移矩阵横向之和:");
        for (int i = 0; i < N; i++) {
            rowSum = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < N; j++) {
                rowSum = TCMMath.logplus(rowSum, logA[i][j]);
            }
            System.out.print( "和="+HMMHelper.fmtlog( rowSum ) + "【第"+i+"种隐状态】转移至【各隐状态】概率:" );
            for (int j = 0; j < N; j++) {
                System.out.print(HMMHelper.fmtlog(logA[i][j]));
            }
            System.out.println();
        }
        System.out.println("发射矩阵横向之和:");
        for (int i = 0; i < N; i++) {
            rowSum = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < M; j++) {
                rowSum = TCMMath.logplus(rowSum, logB[i][j]);
            }
            System.out.print( "和="+HMMHelper.fmtlog( rowSum ) +"【第"+i+"种隐状态】发射出【各显状态】概率:" );
            for (int j = 0; j < M; j++) {
                System.out.print(HMMHelper.fmtlog(logB[i][j]));
            }
            System.out.println();
        }
        System.out.println( "OldLogProb = " +HMMHelper.fmt( OldLogProb ) +"对应概率值:"+HMMHelper.fmtlog( OldLogProb )
                +" NewLogProb =" +HMMHelper.fmt( NewLogProb )+"对应概率值:"+HMMHelper.fmtlog( NewLogProb )
                +" |NewLogProb - OldLogProb| = "+Math.abs(NewLogProb - OldLogProb));
    }
}