Mahout Bayes Algorithm, Extension Part 3: Classifying Unlabeled Data

Posted: 2016-03-05 20:13:09


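Test environment: Hadoop 2.4 + Mahout 1.0

The previous posts, "Mahout Bayes algorithm development notes (Extension) 1" and "Mahout Bayes algorithm development notes (Extension) 2", analyzed how Mahout's Bayes algorithm handles numeric data. Neither of them covered how to classify raw data that carries no label.

This post deals with exactly that kind of data.

The latest source code and jar (for a Hadoop 2.4 + Mahout 1.0 environment) can be downloaded here: Mahout Bayes classification of unlabeled data

After downloading, use fz.bayes.model.BayesRunner from the jar to build the Bayes model; that step is not covered here. What follows is the approach for classifying unlabeled data.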


Input data:

0.2,0.3,0.4
0.32,0.43,0.45
0.23,0.33,0.54
2.4,2.5,2.6
2.3,2.2,2.1
5.4,7.2,7.2
5.6,7,6
5.8,7.1,6.3
6,6,5.4
11,12,13
Compared with the original training data, the only difference is that the last column (the label) is missing.
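If labeled training data is already at hand, an unlabeled file like the one above can be produced by dropping the last field of each line. The sketch below is a hypothetical helper (the class name StripLabels is not part of the fz.bayes code) and assumes, as stated above, that the label is always the final field.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of the fz.bayes code: removes the trailing label column.
class StripLabels {

    // Returns the line without its last field; assumes the label is always the final field.
    static String stripLabel(String line, char separator) {
        int cut = line.lastIndexOf(separator);
        // A line without a separator has no label column, so it is returned unchanged.
        return cut < 0 ? line : line.substring(0, cut);
    }

    // Strips the label from every non-blank line.
    static List<String> stripAll(List<String> labeledLines, char separator) {
        List<String> out = new ArrayList<>();
        for (String line : labeledLines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // blank lines would otherwise become empty records
            }
            out.add(stripLabel(trimmed, separator));
        }
        return out;
    }

    public static void main(String[] args) {
        // The trailing "1" and "3" are made-up labels for illustration.
        List<String> labeled = Arrays.asList("0.2,0.3,0.4,1", "", "11,12,13,3");
        System.out.println(stripAll(labeled, ',')); // prints [0.2,0.3,0.4, 11,12,13]
    }
}
```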

Main classification program:

package fz.bayes;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.training.WeightsMapper;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.Vector;

/**
 * Job that classifies data without a label, such as
 * <pre>
 *   2.1,3.2,1.2
 *   2.1,3.2,1.3
 * </pre>
 * @author fansy
 */
public class BayesClassifiedJob extends AbstractJob {

	public static void main(String[] args) throws Exception {
		ToolRunner.run(new Configuration(), new BayesClassifiedJob(), args);
	}

	@Override
	public int run(String[] args) throws Exception {
		addInputOption();
		addOutputOption();
		addOption("model", "m", "Path of the trained Bayesian model");
		addOption("labelIndex", "labelIndex", "Path of the label index file");
		addOption("labelNumber", "ln", "Number of labels");
		addOption("mapreduce", "mr", "true to classify with MapReduce, anything else runs standalone");
		addOption("SV", "SV", "Separator between the fields of an input record, default is comma", ",");

		if (parseArguments(args) == null) {
			return -1;
		}
		Configuration conf = getConf();
		Path input = getInputPath();
		Path output = getOutputPath();
		String labelNumber = getOption("labelNumber");
		String modelPath = getOption("model");
		String useMR = getOption("mapreduce");
		String SV = getOption("SV");
		String labelIndex = getOption("labelIndex");
		// equalsIgnoreCase instead of endsWith: "true".endsWith("e") or endsWith("") is also true,
		// and endsWith(null) throws when the option is missing
		if ("true".equalsIgnoreCase(useMR)) {
			return useMRToClassify(conf, labelNumber, modelPath, input, output, SV, labelIndex);
		}
		return classify(conf, input, output, labelNumber, modelPath, SV, labelIndex);
	}

	/**
	 * Standalone version: reads the input file line by line, classifies each record
	 * and writes "record + separator + label" to output/result.
	 * @param conf Hadoop configuration
	 * @param input input file, one unlabeled record per line
	 * @param output output directory
	 * @param labelNumber number of labels (not needed in this mode)
	 * @param modelPath path of the trained Bayesian model
	 * @param sv separator between the fields of a record
	 * @param labelIndex path of the label index file (labelIndex.bin)
	 * @return 0 on success, -1 on failure
	 */
	private int classify(Configuration conf, Path input, Path output, String labelNumber, String modelPath,
			String sv, String labelIndex) {
		try {
			// Load the model parameters and the label index (index -> label name)
			NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), conf);
			AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
			Map<Integer, String> labelMap = BayesUtils.readLabelIndex(conf, new Path(labelIndex));
			Path outputPath = new Path(output, "result");
			FileSystem fs = FileSystem.get(input.toUri(), conf);
			if (fs.exists(outputPath)) {
				fs.delete(outputPath, true);
			}
			// Read the file line by line and collect "record + separator + label" lines
			StringBuilder buff = new StringBuilder();
			try (BufferedReader br = new BufferedReader(
					new InputStreamReader(fs.open(input), StandardCharsets.UTF_8))) {
				String line;
				while ((line = br.readLine()) != null) {
					if (line.trim().isEmpty()) {
						continue; // skip blank lines instead of stopping at the first one
					}
					Vector original = BayesUtil.transformToVector(line.split(sv));
					if (original == null) {
						continue; // skip records with non-numeric fields
					}
					Vector result = classifier.classifyFull(original);
					String label = BayesUtil.classifyVector(result, labelMap);
					buff.append(line).append(sv).append(label).append('\n');
				}
			}
			// Write raw UTF-8 bytes; writeUTF would prepend a 2-byte length and is limited to 64 KB
			try (FSDataOutputStream out = fs.create(outputPath)) {
				out.write(buff.toString().getBytes(StandardCharsets.UTF_8));
			}
			// fs is not closed: FileSystem.get returns a shared, cached instance
		} catch (Exception e) {
			e.printStackTrace();
			return -1;
		}
		return 0;
	}

	/**
	 * MapReduce version (map-only job).
	 * @param conf Hadoop configuration
	 * @param labelNumber number of labels
	 * @param modelPath path of the trained Bayesian model
	 * @param input input path
	 * @param output output directory
	 * @param SV separator between the fields of a record
	 * @param labelIndex path of the label index file
	 * @return 0 on success, -1 on failure
	 */
	private int useMRToClassify(Configuration conf, String labelNumber, String modelPath, Path input, Path output,
			String SV, String labelIndex) throws IOException, ClassNotFoundException, InterruptedException {
		if (labelNumber != null) {
			// Stored under the WeightsMapper key; BayesClassifyMapper itself does not read it
			conf.set(WeightsMapper.class.getName() + ".numLabels", labelNumber);
		}
		conf.set("SV", SV);
		conf.set("labelIndex", labelIndex);
		// Ship the model to every mapper through the distributed cache
		HadoopUtil.cacheFiles(new Path(modelPath), conf);
		HadoopUtil.delete(conf, output);
		Job job = Job.getInstance(conf, "Use bayesian model to classify the input:" + input.getName());
		job.setJarByClass(BayesClassifiedJob.class);

		job.setInputFormatClass(TextInputFormat.class);
		job.setOutputFormatClass(TextOutputFormat.class);

		// No reducers: each mapper writes "record<TAB>label" directly
		job.setMapperClass(BayesClassifyMapper.class);
		job.setMapOutputKeyClass(Text.class);
		job.setMapOutputValueClass(Text.class);
		job.setNumReduceTasks(0);
		job.setOutputKeyClass(Text.class);
		job.setOutputValueClass(Text.class);
		FileInputFormat.setInputPaths(job, input);
		FileOutputFormat.setOutputPath(job, output);

		return job.waitForCompletion(true) ? 0 : -1;
	}
}
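The SV option is passed straight to String.split, which treats its argument as a regular expression. The default comma (or an intentional pattern such as \\s+) therefore works, but a literal | or . would not. The sketch below illustrates the difference; the class SplitDemo is only for illustration and is not part of the job.

```java
import java.util.Arrays;
import java.util.regex.Pattern;

// Illustrative only: String.split interprets its argument as a regular expression.
class SplitDemo {

    // What split(sv) does: the separator is used as a regex.
    static String[] splitAsRegex(String line, String sv) {
        return line.split(sv);
    }

    // Pattern.quote makes the separator match literally.
    static String[] splitLiteral(String line, String sv) {
        return line.split(Pattern.quote(sv));
    }

    public static void main(String[] args) {
        String line = "0.2|0.3|0.4";
        // As a regex, "|" is an empty alternation that matches between characters,
        // so the record is broken into single characters.
        System.out.println(Arrays.toString(splitAsRegex(line, "|")));
        // Quoted, it is a literal pipe.
        System.out.println(Arrays.toString(splitLiteral(line, "|"))); // prints [0.2, 0.3, 0.4]
    }
}
```

Quoting the separator would make any character safe, at the cost of no longer accepting regex separators such as \\s+; which behavior is preferable depends on how SV is meant to be used.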
If MapReduce is used, the Mapper is as follows:

package fz.bayes;

import java.io.IOException;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.math.Vector;

/**
 * Custom Mapper: outputs the current record together with its classification result.
 * @author Administrator
 */
@SuppressWarnings("deprecation")
public class BayesClassifyMapper extends Mapper<LongWritable, Text, Text, Text> {
	private AbstractNaiveBayesClassifier classifier;
	private String SV;
	private Map<Integer, String> labelMap;

	@Override
	public void setup(Context context) throws IOException, InterruptedException {
		Configuration conf = context.getConfiguration();
		// The driver put the model into the distributed cache with HadoopUtil.cacheFiles
		Path modelPath = new Path(DistributedCache.getCacheFiles(conf)[0].getPath());
		NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
		classifier = new StandardNaiveBayesClassifier(model);
		SV = conf.get("SV");
		// Label index (index -> label name), used to decode the winning index
		labelMap = BayesUtils.readLabelIndex(conf, new Path(conf.get("labelIndex")));
	}

	@Override
	public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
		String values = value.toString();
		if (values.trim().isEmpty()) {
			context.getCounter("Records", "Bad Record").increment(1);
			return;
		}
		Vector original = BayesUtil.transformToVector(values.split(SV));
		if (original == null) {
			// A field could not be parsed as a number
			context.getCounter("Records", "Bad Record").increment(1);
			return;
		}
		Vector result = classifier.classifyFull(original);
		String label = BayesUtil.classifyVector(result, labelMap);
		// Output key: the original record; output value: the predicted label
		context.write(value, new Text(label));
	}
}


Utility class used by both versions:

package fz.bayes;

import java.util.Map;

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class BayesUtil {

	/**
	 * Converts the fields of one input record into a Vector.
	 * @param line fields of the record, already split by the separator
	 * @return the vector, or null if a field is not a number
	 */
	public static Vector transformToVector(String[] line) {
		Vector v = new RandomAccessSparseVector(line.length);
		for (int i = 0; i < line.length; i++) {
			double item;
			try {
				item = Double.parseDouble(line[i]);
			} catch (NumberFormatException e) {
				return null; // a field that cannot be parsed means the input record is malformed
			}
			v.setQuick(i, item);
		}
		return v;
	}

	/**
	 * Picks the label with the highest score.
	 * @param v score of every label, indexed by the encoded label index
	 * @param labelMap mapping from label index to label name
	 * @return the decoded label, or null if v has no elements
	 */
	public static String classifyVector(Vector v, Map<Integer, String> labelMap) {
		int bestIdx = -1;
		// Scores are log-likelihoods, so start from negative infinity rather than Long.MIN_VALUE
		double bestScore = Double.NEGATIVE_INFINITY;
		for (Vector.Element element : v.all()) {
			if (element.get() > bestScore) {
				bestScore = element.get();
				bestIdx = element.index();
			}
		}
		return bestIdx < 0 ? null : labelMap.get(bestIdx);
	}
}
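To check the parsing and "highest score wins" logic without a Hadoop cluster, the same two steps can be written with plain double[] arrays in place of Mahout's Vector. This is a stdlib-only sketch under that substitution; the class name PlainBayesUtil is hypothetical and the scores in main are made up.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical stdlib-only counterpart of BayesUtil, using double[] instead of Mahout's Vector.
class PlainBayesUtil {

    // Parses every field as a double; returns null if any field is not a number.
    static double[] toVector(String[] fields) {
        double[] v = new double[fields.length];
        for (int i = 0; i < fields.length; i++) {
            try {
                v[i] = Double.parseDouble(fields[i]);
            } catch (NumberFormatException e) {
                return null; // malformed record
            }
        }
        return v;
    }

    // Returns the label whose score is highest, or null if there are no scores.
    static String bestLabel(double[] scores, Map<Integer, String> labelMap) {
        int bestIdx = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > bestScore) {
                bestScore = scores[i];
                bestIdx = i;
            }
        }
        return bestIdx < 0 ? null : labelMap.get(bestIdx);
    }

    public static void main(String[] args) {
        Map<Integer, String> labelMap = new HashMap<>();
        labelMap.put(0, "1");
        labelMap.put(1, "2");
        labelMap.put(2, "3");
        System.out.println(Arrays.toString(toVector("0.2,0.3,0.4".split(",")))); // prints [0.2, 0.3, 0.4]
        // Made-up log-scores: index 1 is the largest, so label "2" is chosen.
        System.out.println(bestLabel(new double[] {-5.1, -1.2, -3.4}, labelMap)); // prints 2
    }
}
```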
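A brief walkthrough of the approach (see the standalone code or the Mapper code):

1. Load the model. Using the model path, the label index file (labelIndex.bin) and the number of labels (labelNumber), initialize the model-related variables;

2. For each record, e.g. 0.2,0.3,0.4, split it by SV (the separator of the input vector) and turn it into a vector, giving Vector(0=0.2,1=0.3,2=0.4);

3. Use the model to compute a score for each label. The result is also a vector, holding one score per label: Vector result = classifier.classifyFull(original);

4. From the label scores, decide which label the record belongs to, then decode it: labels were encoded as integer indices during training, so the winning index has to be mapped back to the label name.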
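Steps 2 to 4 can be traced end to end with a toy, stdlib-only example. The weight table and label names are made up, and the scoring rule is a simplified multinomial naive Bayes log-score with Laplace smoothing (alpha). Mahout's StandardNaiveBayesClassifier is believed to score labels along similar lines, but this is a sketch of the technique, not its source code.

```java
import java.util.Arrays;

// Toy illustration of steps 2-4 with made-up numbers; not Mahout's implementation.
class ToyNaiveBayes {

    // Step 3: one log-score per label. weights[label][feature] stands in for the per-label
    // feature totals learned during training (hypothetical values). Assumes every row of
    // weights has the same length as x.
    static double[] classifyFull(double[] x, double[][] weights, double alpha) {
        double[] scores = new double[weights.length];
        for (int label = 0; label < weights.length; label++) {
            double labelWeight = Arrays.stream(weights[label]).sum();
            double denominator = labelWeight + alpha * x.length; // Laplace smoothing over all features
            double score = 0.0;
            for (int f = 0; f < x.length; f++) {
                // Feature value times the smoothed log-probability of that feature under this label
                score += x[f] * Math.log((weights[label][f] + alpha) / denominator);
            }
            scores[label] = score;
        }
        return scores;
    }

    // Step 4: index of the highest score (expects at least one score).
    static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Step 2: vectorize one record using the separator
        double[] x = Arrays.stream("0.2,0.3,0.4".split(","))
                .mapToDouble(Double::parseDouble)
                .toArray();
        double[][] weights = { {1.0, 1.5, 2.0}, {6.0, 2.0, 1.0} }; // hypothetical training totals
        String[] labelNames = { "A", "B" };                       // hypothetical label index
        double[] scores = classifyFull(x, weights, 1.0);
        // Step 4: decode the winning index back to a label name
        System.out.println(Arrays.toString(scores) + " -> " + labelNames[argmax(scores)]);
    }
}
```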

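Here is the output.

MR version:

[Screenshot: output of the MR version]

Standalone version:

[Screenshot: output of the standalone version]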

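In the standalone output, the first line starts with a garbled character. It does not affect the results, and reading the file with hadoop fs -cat shows no problem. The stray character most likely comes from DataOutputStream.writeUTF (FSDataOutputStream is a DataOutputStream), which writes a two-byte length before the string; writing the raw bytes instead, for example with out.write(text.getBytes(StandardCharsets.UTF_8)), avoids it.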
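The effect of writeUTF can be reproduced in memory with the JDK alone. This illustrative sketch (the class WriteUtfDemo is not part of the job) writes the same text once with writeUTF and once as raw UTF-8 bytes; the sample text uses made-up labels.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

// Illustrative only: compares DataOutputStream.writeUTF with writing raw UTF-8 bytes.
class WriteUtfDemo {

    // writeUTF: a 2-byte big-endian length, then the string in modified UTF-8.
    static byte[] withWriteUtf(String text) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(text);
        }
        return bytes.toByteArray();
    }

    // write(byte[]): exactly the encoded text, nothing else.
    static byte[] withRawBytes(String text) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.write(text.getBytes(StandardCharsets.UTF_8));
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        String result = "0.2,0.3,0.4,1\n11,12,13,3"; // "1" and "3" are made-up labels
        byte[] a = withWriteUtf(result);
        byte[] b = withRawBytes(result);
        System.out.println(a.length - b.length);                   // prints 2
        System.out.println(((a[0] & 0xFF) << 8) | (a[1] & 0xFF)); // prints 24, the encoded length
    }
}
```

Besides the extra leading bytes, writeUTF throws UTFDataFormatException once the encoded text exceeds 65535 bytes, so it does not suit large result files.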


Share, grow, be happy.

If you repost this, please credit the blog: http://blog.csdn.net/fansy1990




Original article: http://www.cnblogs.com/gcczhongduan/p/5245586.html
