码迷,mamicode.com
首页 > 其他好文 > 详细

KNN分类器

时间:2015-03-08 17:22:38      阅读:208      评论:0      收藏:0      [点我收藏+]

标签:knn   数据挖掘   分类器   

  转载请声明出处:http://blog.csdn.net/zhongkejingwang/article/details/44132771

   KNN即K-Nearest Neighbor,是数据挖掘中一种最简单的分类方法,即要判断某一个样本属于已知样本种类中的哪一类时,通过计算找出所有样本中与测试样本最近或者最相似的K个样本,统计这K个样本中哪一种类最多则把测试样本归位该类。如何衡量两个样本的相似度?可以用向量的p-范数来定义。

假设有两个样本X=(x1, x2, ..., xn),Y=(y1, y2, ..., yn),则他们之间的相似度可以用以下向量p-范数定义:

                技术分享

当p=2时即为计算X、Y的欧几里得距离。

  本文将介绍用Java实现KNN分类器对Iris数据进行分类。Iris数据如下:

技术分享技术分享

  前面四个item是属性,最后一个是类别名,总共有三类。完整的数据集可点击这里下载

拿到原始数据后为了测试KNN分类效果,需要在原始数据中随机抽取一部分作为测试集,另一部分作为训练集。随机抽取的方法可以用下面代码实现:

/**
	 * 将数据集划分为训练集和测试集,随机划分
	 * 
	 * @param filePath
	 *            数据集文件路径
	 * @param testCount
	 *            测试集个数
	 * @param outputPath
	 *            输出路径
	 * @throws Exception
	 */
	public static void splitDataSet(String filePath, int testCount,
			String outputPath) throws Exception
	{
		BufferedWriter trainFile = new BufferedWriter(new FileWriter(new File(
				outputPath + "/train.txt")));
		BufferedWriter testFile = new BufferedWriter(new FileWriter(new File(
				outputPath + "/test.txt")));
		BufferedReader input = new BufferedReader(new FileReader(new File(
				filePath)));
		List<String> lines = new ArrayList<String>();
		String line = null;
		//将所有数据读取到一个List里
		while ((line = input.readLine()) != null)
			lines.add(line);
		//遍历一次List,每次产生一个随机序号,将该随机序号和当前序号内容进行交换
		for (int i = 0; i < lines.size(); i++)
		{
			int ran = (int) (Math.random() * lines.size());
			String temp = lines.get(i);
			lines.set(i, lines.get(ran));
			lines.set(ran, temp);
		}
		int i = 0;
		//将指定数目的测试集写进test.txt中
		for (; i < testCount; i++)
		{
			testFile.write(lines.get(i) + "\n");
			testFile.flush();
		}
		//剩余的写进train.txt中
		for (; i < lines.size(); i++)
		{
			trainFile.write(lines.get(i) + "\n");
			trainFile.flush();
		}
		testFile.close();
		trainFile.close();
	}

  调用这个方法后就可以得到train.txt和test.txt两份数据了。

  接下来将数据读入:

/**
	 * 根据文件生成训练集,注意:程序将以第一个出现的非数字的属性作为类别名称
	 * 
	 * @param fileName
	 *            文件名
	 * @param sep
	 *            分隔符
	 * @return
	 * @throws Exception
	 */
	public List<DataNode> getDataList(String fileName, String sep)
			throws Exception
	{
		List<DataNode> list = new ArrayList<DataNode>();
		BufferedReader br = new BufferedReader(new FileReader(
				new File(fileName)));
		String line = null;
		while ((line = br.readLine()) != null)
		{
			String splits[] = line.split(sep);
			//DataNode类用于保存数据属性和数据类别
			DataNode node = new DataNode();
			int i = 0;
			for (; i < splits.length; i++)
			{
				try
				{
					node.addAttrib(Float.valueOf(splits[i]));
				} catch (NumberFormatException e)
				{
					// 非数字,则为类别名称,将类别映射为数字
					if (!mTypes.containsKey(splits[i]))
					{
						mTypes.put(splits[i], mTypeCount);
						mTypeCount++;
					}
					node.setType(mTypes.get(splits[i]));
					list.add(node);
				}
			}
		}
		return list;
	}
对于testList中的每一个样本,均与所有trainList中的样本进行计算,取出最接近的K个样本并返回:

KnnClassifier.java

package com.jingchen.knn;

import java.util.List;

/**
 * @author chenjing
 * 
 */
public class KnnClassifier
{
	//k个近邻节点
	private int k;
	private KNode[] mNearestK;
	private List<DataNode> mTrainData;

	public KnnClassifier(int k, List<DataNode> trainList)
	{
		mTrainData = trainList;
		this.k = k;
		mNearestK = new KNode[k];
		for (int i = 0; i < k; i++)
			mNearestK[i] = new KNode();
	}
	public void setK(int k){
		this.k = k;
		mNearestK = new KNode[k];
		for (int i = 0; i < k; i++)
			mNearestK[i] = new KNode();
	}
	private void train(DataNode test, float p)
	{
		for (int i = 0; i < mTrainData.size(); i++)
		{
			putNode(getSim(test, mTrainData.get(i), p));
		}
	}

	/**
	 * 将新计算出来的节点与k个近邻节点比较,如果比其中之一小则插入
	 * @param node
	 */
	private void putNode(KNode node)
	{
		for (int i = 0; i < k; i++)
		{
			if (node.getD() < mNearestK[i].getD())
			{
				for (int j = k - 1; j > i; j--)
					mNearestK[j] = mNearestK[j - 1];
				mNearestK[i] = node;
				break;
			}
		}
	}

	/**
	 * 获取相似度并封装成一个KNode类型返回
	 * @param test
	 * @param trainNode
	 * @param p
	 * @return
	 */
	private KNode getSim(DataNode test, DataNode trainNode, float p)
	{
		List<Float> list1 = test.getAttribs();
		List<Float> list2 = trainNode.getAttribs();
		float d = 0;
		for (int i = 0; i < list1.size(); i++)
			d += Math.pow(
					Math.abs(list1.get(i).floatValue() - list2.get(i).floatValue()), p);
		d = (float) Math.pow(d, 1/p);
		KNode node = new KNode(d, trainNode.getType());
		return node;
	}

	private void reset()
	{
		for (int i = 0; i < k; i++)
			mNearestK[i].reset();
	}

	/**
	 * 返回K个近邻节点
	 * @param test
	 * @param p
	 * @return
	 */
	public KNode[] getKNN(DataNode test, float p)
	{
		reset();
		train(test, p);
		return mNearestK;
	}
}

  main方法:

public static void main(String[] args) throws Exception
	{
		DataUtil util = DataUtil.getInstance();
		//获得训练集和测试集
		List<DataNode> trainList = util.getDataList("E:/train.txt", ",");
		List<DataNode> testList = util.getDataList("E:/test.txt", ",");
		int K = BASE_K;
		KnnClassifier classifier = new KnnClassifier(K, trainList);
		BufferedWriter output = new BufferedWriter(new FileWriter(new File(
				"E:/output.txt")));
		int typeCount = util.getTypeCount();
		int[] count = new int[typeCount];
		for (int i = 0; i < testList.size();)
		{
			for (int m = 0; m < typeCount; m++)
				count[m] = 0;
			DataNode test = testList.get(i);
			classifier.setK(K);
			KNode[] nodes = classifier.getKNN(test, 2);
			for (int j = 0; j < nodes.length; j++)
				count[nodes[j].getType()]++;
			int type = -1;
			int max = -1;
			for (int j = 0; j < typeCount; j++)
			{
				if (count[j] > max)
				{
					max = count[j];
					type = j;
				} else if (count[j] == max)
				{
					// 存在两个类型分个数相同,无法判断属于哪个类型,增加K的值继续从该节点开始
					type = -1;
					K++;
					break;
				}
			}
			if (type == -1)
				continue;
			else
			{
				i++;
				K = BASE_K;
			}
			//将分类结果写入文件
			List<Float> attribs = test.getAttribs();
			for (int n = 0; n < attribs.size(); n++)
			{
				output.write(attribs.get(n) + ",");
				output.flush();
			}
			output.write(util.getTypeName(type) + "\n");
			output.flush();
		}
		output.close();

	}

  经测试,KNN对Iris数据集分类准确率基本都在90+%以上,此分类方法也比较直观。数据集及完整的项目代码可以从这里下载:点击下载。后面介绍的BP神经网络将会把准确率提高到接近100%。




KNN分类器

标签:knn   数据挖掘   分类器   

原文地址:http://blog.csdn.net/zhongkejingwang/article/details/44132771

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!