Mahout Bayes Algorithm Extension, Part 3: Classifying Unlabeled Data

Posted: 2014-07-20 23:19:06

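Test environment for the code: Hadoop 2.4 + Mahout 1.0

The earlier posts, "Mahout Bayes Algorithm Development Ideas (Extension) 1" and "Mahout Bayes Algorithm Development Ideas (Extension) 2", analyzed how the Bayes algorithm in Mahout handles numeric data. Neither post covered how to classify raw data that carries no label. This post deals with that case.

The latest source code and jar (for the Hadoop 2.4 + Mahout 1.0 environment) can be downloaded here: "Mahout Bayes classification of unlabeled data".

After downloading, use fz.bayes.model.BayesRunner from the jar to build the Bayes model; that step is not covered further here. What follows is the approach for classifying unlabeled data.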


Input data:

0.2,0.3,0.4
0.32,0.43,0.45
0.23,0.33,0.54
2.4,2.5,2.6
2.3,2.2,2.1
5.4,7.2,7.2
5.6,7,6
5.8,7.1,6.3
6,6,5.4
11,12,13
Compared with the original training data, this data simply lacks the final label column.

Main classification program:

package fz.bayes;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.training.WeightsMapper;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.Vector;
/**
 * Job that performs the classification.
 * Intended for unlabeled input such as
 * [
 *   2.1,3.2,1.2
 *   2.1,3.2,1.3
 * ]
 * @author fansy
 *
 */
public class BayesClassifiedJob extends AbstractJob {
	/**
	 * @param args
	 * @throws Exception 
	 */
	public static void main(String[] args) throws Exception {
		ToolRunner.run(new Configuration(), new BayesClassifiedJob(),args);
	}
	
	@Override
	public int run(String[] args) throws Exception {
		addInputOption();
	    addOutputOption();
	    addOption("model","m", "Path of the stored Bayesian model");
	    addOption("labelIndex","labelIndex", "Path of the stored label index file");
	    addOption("labelNumber","ln", "Number of labels");
	    addOption("mapreduce","mr", "Whether to use MapReduce: true to use it, anything else runs standalone");
	    addOption("SV","SV","Separator of the input vector fields, default is comma",",");
	    
	    if (parseArguments(args) == null) {
		      return -1;
		}
	    Configuration conf=getConf();
	    Path input = getInputPath();
	    Path output = getOutputPath();
	    String labelNumber=getOption("labelNumber");
	    String modelPath=getOption("model");
	    String useMR = getOption("mapreduce");
	    String SV = getOption("SV");
	    String labelIndex = getOption("labelIndex");
	    int returnCode=-1;
	    // equals (not endsWith): endsWith also matched "e" or "" and threw on null
	    if("true".equals(useMR)){
	    	returnCode = useMRToClassify(conf,labelNumber,modelPath,input,output,SV,labelIndex);
	    }else{
	    	returnCode = classify(conf,input, output, labelNumber, modelPath, SV, labelIndex);
	    }
	    return returnCode;
	}
	/**
	 * Standalone (single-machine) version
	 * @param conf
	 * @param input
	 * @param output
	 * @param labelNumber
	 * @param modelPath
	 * @param sv
	 * @param labelIndex
	 * @return
	 * @throws IOException 
	 * @throws IllegalArgumentException 
	 */
	private int classify(Configuration conf, Path input ,Path output ,String labelNumber,String modelPath,
			String sv,String labelIndex)  {
		// Load the model parameters
		try{
		NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), conf);
		 AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
		 Map<Integer, String> labelMap = BayesUtils.readLabelIndex(conf, new Path(labelIndex));
		 Path outputPath =new Path(output,"result");
		 // Read the file line by line and write the classification results to another file
		 FileSystem fs =FileSystem.get(input.toUri(),conf);
		 FSDataInputStream in=fs.open(input);  
		 
	     InputStreamReader istr=new InputStreamReader(in);  
	     BufferedReader br=new BufferedReader(istr);  
	     if(fs.exists(outputPath)){
	    	 fs.delete(outputPath, true);
	     }
	     FSDataOutputStream out = fs.create(outputPath);
	     
	     String lines;
	     StringBuffer buff = new StringBuffer();
	     while((lines=br.readLine())!=null&&!"".equals(lines)){  
	    	 String[] line = lines.toString().split(sv);
	    	 if(line.length<1){
	    		 break;
	    	 }
			  Vector original =BayesUtil.transformToVector(line);
        	  Vector result = classifier.classifyFull(original);
        	  String label = BayesUtil.classifyVector(result, labelMap);
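        	  buff.append(lines+sv+label+"\n");
	     }
	     // NOTE: writeUTF prepends a 2-byte length and fails above 65535 encoded bytes,
	     // so this only suits small inputs; out.write(bytes) would avoid both limits.
	     out.writeUTF(buff.substring(0, buff.length()-1));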
	     out.flush();
	     out.close();
	     br.close();
	     istr.close();
	     in.close();
//	     fs.close();
		}catch(Exception e){
			e.printStackTrace();
			return -1;
		}
		return 0;
	}
/**
 * MapReduce version
 * @param conf
 * @param labelNumber
 * @param modelPath
 * @param input
 * @param output
 * @param SV
 * @param labelIndex
 * @return
 * @throws IOException
 * @throws ClassNotFoundException
 * @throws InterruptedException
 */
	private int useMRToClassify(Configuration conf, String labelNumber, String modelPath, Path input, Path output, 
			String SV, String labelIndex) throws IOException, ClassNotFoundException, InterruptedException {
		
	    conf.set(WeightsMapper.class.getName() + ".numLabels",labelNumber);
	    conf.set("SV", SV);
	    conf.set("labelIndex", labelIndex);
	    HadoopUtil.cacheFiles(new Path(modelPath), conf);
	    HadoopUtil.delete(conf, output);
	    Job job=Job.getInstance(conf, "");
	    job.setJobName("Use bayesian model to classify the  input:"+input.getName());
	    job.setJarByClass(BayesClassifiedJob.class); 
	    
	    job.setInputFormatClass(TextInputFormat.class);
	    job.setOutputFormatClass(TextOutputFormat.class);
	    
	    job.setMapperClass(BayesClassifyMapper.class);
	    job.setMapOutputKeyClass(Text.class);
	    job.setMapOutputValueClass(Text.class);
	    job.setNumReduceTasks(0);
	    job.setOutputKeyClass(Text.class);
	    job.setOutputValueClass(Text.class);
	    FileInputFormat.setInputPaths(job, input);
	    FileOutputFormat.setOutputPath(job, output);
	    
	    if(job.waitForCompletion(true)){
	    	return 0;
	    }
		return -1;
	}
	
	
	
}
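
The options registered in run() determine how the job is invoked. The following is a minimal sketch of assembling the argument array, not taken from the original post: the class name ClassifyArgsSketch, every path, and the label count are hypothetical placeholders, and the "--name" spelling assumes the usual long-option form of Mahout's AbstractJob. SV is left out so the default comma declared in addOption applies.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch: builds the argument array for BayesClassifiedJob (all values are placeholders). */
final class ClassifyArgsSketch {

    static String[] buildArgs(String input, String output, String model,
                              String labelIndex, int labelNumber, boolean useMapReduce) {
        List<String> args = new ArrayList<>();
        args.add("--input");       args.add(input);        // from addInputOption()
        args.add("--output");      args.add(output);       // from addOutputOption()
        args.add("--model");       args.add(model);
        args.add("--labelIndex");  args.add(labelIndex);
        args.add("--labelNumber"); args.add(String.valueOf(labelNumber));
        args.add("--mapreduce");   args.add(String.valueOf(useMapReduce));
        return args.toArray(new String[0]);
    }

    public static void main(String[] unused) {
        // Hypothetical HDFS locations; replace with the real ones.
        String[] args = buildArgs("/user/demo/unlabeled.txt", "/user/demo/classified",
                "/user/demo/model", "/user/demo/labelIndex.bin", 4, true);
        System.out.println(String.join(" ", args));
        // With Hadoop, Mahout and the job class on the classpath, one would then call:
        // ToolRunner.run(new Configuration(), new BayesClassifiedJob(), args);
    }
}
```

Passing false for the last parameter makes run() take the standalone branch instead of submitting a MapReduce job.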
If MapReduce is used, the Mapper is as follows:

package fz.bayes;

import java.io.IOException;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.math.Vector;

/**
 *  Custom Mapper that outputs the input record and its classification result
 * @author Administrator
 *
 */
@SuppressWarnings("deprecation")
public  class BayesClassifyMapper extends Mapper<LongWritable, Text, Text, Text>{
	private AbstractNaiveBayesClassifier classifier;
	private String SV;
	private Map<Integer, String> labelMap;
	private String labelIndex;
		@Override
	  public void setup(Context context) throws IOException, InterruptedException {
			
	    Configuration conf = context.getConfiguration();
	    Path modelPath = new Path(DistributedCache.getCacheFiles(conf)[0].getPath());
	    NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
	    classifier = new StandardNaiveBayesClassifier(model);
	    SV = conf.get("SV");
	    labelIndex=conf.get("labelIndex");
		labelMap = BayesUtils.readLabelIndex(conf, new Path(labelIndex));
	  }

	  @Override
	  public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
		  String values =value.toString();
		  if("".equals(values)){
			  context.getCounter("Records", "Bad Record").increment(1);
			  return; 
		  }
		  String[] line = values.split(SV);
		  
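		  Vector original =BayesUtil.transformToVector(line);
		  if(original==null){ // a field was not numeric: count the record and skip it
			  context.getCounter("Records", "Bad Record").increment(1);
			  return;
		  }
		  Vector result = classifier.classifyFull(original);
		  String label = BayesUtil.classifyVector(result, labelMap);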
	    
	    //the key is the vector 
	    context.write(value, new Text(label));
	  }
}


The utility class used:

package fz.bayes;

import java.util.Map;

import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class BayesUtil {

	/**
	 * Converts the input string fields into a Vector
	 * @param line
	 * @return
	 */
	public static Vector transformToVector(String[] line){
		Vector v=new RandomAccessSparseVector(line.length);
		for(int i=0;i<line.length;i++){
			double item=0;
			try{
				item=Double.parseDouble(line[i]);
			}catch(Exception e){
				return null; // a field that cannot be parsed means the input data is malformed
			}
			v.setQuick(i, item);
		}
		return v;
	}
	/**
	 * Picks the label with the highest score
	 * @param v
	 * @param labelMap
	 * @return
	 */
	public static String classifyVector(Vector v,Map<Integer, String> labelMap){
		int bestIdx = Integer.MIN_VALUE;
		double bestScore = Double.NEGATIVE_INFINITY;
		for (Vector.Element element : v.all()) {
			if (element.get() > bestScore) {
				bestScore = element.get();
				bestIdx = element.index();
			}
		}
		if (bestIdx != Integer.MIN_VALUE) {
			ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
			return classifierResult.getLabel();
		}
		
		return null;
	}
}
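Here is a brief walk-through of the approach (see the standalone code or the Mapper code):

1. Load the model. The inputs are the model path, the label encoding file (labelIndex.bin), and the number of labels (labelNumber); the model-related variables are initialized from these paths.

2. For each record, for example 0.2,0.3,0.4, split it on SV (the field separator of the input vectors) and vectorize it, giving Vector(0=0.2,1=0.3,2=0.4).

3. Use the model to compute a score for every label. The result is itself a vector holding one score per label: Vector result = classifier.classifyFull(original);

4. From the label scores, determine which label the record belongs to, then decode it back to the label name (the labels were encoded as integer indexes, so the index has to be decoded here).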
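
Steps 2 to 4 above can be sketched in plain Java without Mahout. This is only an illustration under stated assumptions: the class ClassifyStepsSketch, the score values, and the label names are made up, and the hard-coded scores stand in for what classifier.classifyFull would return in step 3.

```java
import java.util.HashMap;
import java.util.Map;

/** Sketch of steps 2-4: vectorize a record, pick the best score, decode the label. */
final class ClassifyStepsSketch {

    /** Step 2: split the record on the separator (a regex, as in String.split) and parse doubles. */
    static double[] parse(String record, String separator) {
        String[] fields = record.split(separator);
        double[] features = new double[fields.length];
        for (int i = 0; i < fields.length; i++) {
            features[i] = Double.parseDouble(fields[i].trim());
        }
        return features;
    }

    /** Step 4a: index of the highest score, or -1 if there are no scores. */
    static int argmax(double[] scores) {
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > bestScore) {
                bestScore = scores[i];
                best = i;
            }
        }
        return best;
    }

    /** Step 4b: decode the winning index through the label index map. */
    static String decode(double[] scores, Map<Integer, String> labelMap) {
        return labelMap.get(argmax(scores));
    }

    public static void main(String[] args) {
        double[] features = parse("0.2,0.3,0.4", ",");

        // Step 3 stand-in: made-up per-label scores; a real run would take them
        // from the model via classifier.classifyFull(original).
        double[] scores = {-1.2, -7.5, -9.1, -20.3};

        // Hypothetical label index; the real one is read from labelIndex.bin.
        Map<Integer, String> labelMap = new HashMap<>();
        labelMap.put(0, "label-0");
        labelMap.put(1, "label-1");
        labelMap.put(2, "label-2");
        labelMap.put(3, "label-3");

        System.out.println(features.length + " features -> " + decode(scores, labelMap));
    }
}
```

Starting the running maximum at Double.NEGATIVE_INFINITY keeps the comparison correct for any finite score, including large negative log-scores.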

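The output looks as follows.

MR version:

[Screenshot of the MR output; image not preserved]

Standalone version:

[Screenshot of the standalone output; image not preserved]

In the standalone output the first line starts with a garbled character. This is most likely the two-byte length that writeUTF writes in front of the string. It does not affect the results: reading the file with hadoop fs -cat works without problems.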
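
The stray bytes at the start of the standalone output can be reproduced with the standard java.io classes alone. The sketch below is an illustration, not part of the original code: the class name WriteUtfPrefixSketch and the sample line (including its label "1") are made up. It compares DataOutputStream.writeUTF with a plain write of UTF-8 bytes; FSDataOutputStream extends DataOutputStream, so the same behavior applies to the stream used in the standalone classifier.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

/** Sketch: writeUTF adds a 2-byte length prefix, a plain write does not. */
final class WriteUtfPrefixSketch {

    /** Bytes produced by writeUTF: 2-byte big-endian length, then the encoded text. */
    static byte[] viaWriteUTF(String text) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(text);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Bytes produced by writing the UTF-8 encoding directly: just the text. */
    static byte[] viaWrite(String text) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.write(text.getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) {
        String line = "0.2,0.3,0.4,1"; // hypothetical classified record
        System.out.println("writeUTF bytes: " + viaWriteUTF(line).length);
        System.out.println("write bytes:    " + viaWrite(line).length);
    }
}
```

If the prefix is unwanted, replacing the writeUTF call with a write of the UTF-8 bytes produces a file that contains only the text, and it also removes the 65535-byte limit that writeUTF imposes on a single string.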


Share, grow, enjoy.

When reposting, please cite the blog address: http://blog.csdn.net/fansy1990



Original article: http://blog.csdn.net/fansy1990/article/details/37991447
