Please credit the source when reposting: http://blog.csdn.net/zhongkejingwang/article/details/44132771
KNN (K-Nearest Neighbor) is one of the simplest classification methods in data mining. To decide which of the known classes a test sample belongs to, we find the K training samples that are closest (most similar) to it, count how many of those K neighbors fall into each class, and assign the test sample to the majority class. How do we measure the similarity of two samples? One common choice is the vector p-norm.
Given two samples X = (x1, x2, ..., xn) and Y = (y1, y2, ..., yn), the distance between them under the p-norm is defined as:

$$d_p(X, Y) = \left( \sum_{i=1}^{n} |x_i - y_i|^p \right)^{1/p}$$
When p = 2, this is the Euclidean distance between X and Y.
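For example, with X = (1, 2) and Y = (4, 6), the p = 2 distance is $\sqrt{3^2 + 4^2} = 5$, while the p = 1 (Manhattan) distance is $|1-4| + |2-6| = 7$.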
This article walks through a Java implementation of a KNN classifier applied to the Iris dataset. Each record in the Iris data looks like this:
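The original post showed a screenshot of the data, not reproduced here. In the standard UCI Iris file, each record is one comma-separated line, for example:

5.1,3.5,1.4,0.2,Iris-setosa

The three class names are Iris-setosa, Iris-versicolor, and Iris-virginica.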
The first four items are numeric attributes and the last is the class name; there are three classes in total. The complete dataset can be downloaded here.
To evaluate the KNN classifier, we first need to randomly split the raw data into a test set and a training set. The random split can be implemented as follows:
/**
 * Randomly split the dataset into a training set and a test set.
 *
 * @param filePath   path of the dataset file
 * @param testCount  number of test samples
 * @param outputPath output directory
 * @throws Exception
 */
public static void splitDataSet(String filePath, int testCount,
        String outputPath) throws Exception {
    BufferedWriter trainFile = new BufferedWriter(new FileWriter(new File(
            outputPath + "/train.txt")));
    BufferedWriter testFile = new BufferedWriter(new FileWriter(new File(
            outputPath + "/test.txt")));
    BufferedReader input = new BufferedReader(new FileReader(new File(
            filePath)));
    List<String> lines = new ArrayList<String>();
    String line = null;
    // Read the whole file into a list
    while ((line = input.readLine()) != null)
        lines.add(line);
    input.close();
    // Shuffle: walk the list once, swapping each element with a randomly chosen one
    for (int i = 0; i < lines.size(); i++) {
        int ran = (int) (Math.random() * lines.size());
        String temp = lines.get(i);
        lines.set(i, lines.get(ran));
        lines.set(ran, temp);
    }
    int i = 0;
    // Write the requested number of test samples to test.txt
    for (; i < testCount; i++) {
        testFile.write(lines.get(i) + "\n");
    }
    // Write the rest to train.txt
    for (; i < lines.size(); i++) {
        trainFile.write(lines.get(i) + "\n");
    }
    testFile.close();
    trainFile.close();
}
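A minimal call might look like the following, assuming the raw Iris file is saved at E:/iris.txt, 30 samples are held out for testing, and the method lives in the DataUtil class used later (the path, count, and host class are illustrative assumptions, not from the original post):

// Hypothetical usage: hold out 30 random samples as the test set
DataUtil.splitDataSet("E:/iris.txt", 30, "E:/");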
Next, read the data in:
/**
 * Build a data list from a file. Note: the first non-numeric attribute
 * encountered on each line is treated as the class name.
 *
 * @param fileName file name
 * @param sep      field separator
 * @return
 * @throws Exception
 */
public List<DataNode> getDataList(String fileName, String sep)
        throws Exception {
    List<DataNode> list = new ArrayList<DataNode>();
    BufferedReader br = new BufferedReader(new FileReader(
            new File(fileName)));
    String line = null;
    while ((line = br.readLine()) != null) {
        String splits[] = line.split(sep);
        // DataNode holds a sample's attributes and its class
        DataNode node = new DataNode();
        int i = 0;
        for (; i < splits.length; i++) {
            try {
                node.addAttrib(Float.valueOf(splits[i]));
            } catch (NumberFormatException e) {
                // Not a number, so it is the class name; map it to an integer index
                if (!mTypes.containsKey(splits[i])) {
                    mTypes.put(splits[i], mTypeCount);
                    mTypeCount++;
                }
                node.setType(mTypes.get(splits[i]));
                list.add(node);
            }
        }
    }
    br.close();
    return list;
}

For each sample in testList, we compute its distance to every sample in trainList, keep the K closest, and return them; this is done by KnnClassifier.java, listed below.
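The DataNode class itself is not listed in the post; a minimal sketch consistent with how it is used above (the field names are assumptions) could be:

package com.jingchen.knn;

import java.util.ArrayList;
import java.util.List;

// Holds one sample: its numeric attributes and its class index
public class DataNode {
    private List<Float> mAttribs = new ArrayList<Float>();
    private int mType;

    public void addAttrib(float attrib) {
        mAttribs.add(attrib);
    }

    public List<Float> getAttribs() {
        return mAttribs;
    }

    public void setType(int type) {
        mType = type;
    }

    public int getType() {
        return mType;
    }
}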
KnnClassifier.java
package com.jingchen.knn;

import java.util.List;

/**
 * @author chenjing
 */
public class KnnClassifier {
    // The k nearest neighbors found so far
    private int k;
    private KNode[] mNearestK;
    private List<DataNode> mTrainData;

    public KnnClassifier(int k, List<DataNode> trainList) {
        mTrainData = trainList;
        this.k = k;
        mNearestK = new KNode[k];
        for (int i = 0; i < k; i++)
            mNearestK[i] = new KNode();
    }

    public void setK(int k) {
        this.k = k;
        mNearestK = new KNode[k];
        for (int i = 0; i < k; i++)
            mNearestK[i] = new KNode();
    }

    // Scan the whole training set, maintaining the k nearest neighbors of the test sample
    private void train(DataNode test, float p) {
        for (int i = 0; i < mTrainData.size(); i++) {
            putNode(getSim(test, mTrainData.get(i), p));
        }
    }

    /**
     * Compare a newly computed node against the current k nearest neighbors
     * and insert it if it is closer than one of them.
     *
     * @param node
     */
    private void putNode(KNode node) {
        for (int i = 0; i < k; i++) {
            if (node.getD() < mNearestK[i].getD()) {
                for (int j = k - 1; j > i; j--)
                    mNearestK[j] = mNearestK[j - 1];
                mNearestK[i] = node;
                break;
            }
        }
    }

    /**
     * Compute the p-norm distance between two samples and wrap it in a KNode.
     *
     * @param test
     * @param trainNode
     * @param p
     * @return
     */
    private KNode getSim(DataNode test, DataNode trainNode, float p) {
        List<Float> list1 = test.getAttribs();
        List<Float> list2 = trainNode.getAttribs();
        float d = 0;
        for (int i = 0; i < list1.size(); i++)
            d += Math.pow(
                    Math.abs(list1.get(i).floatValue()
                            - list2.get(i).floatValue()), p);
        d = (float) Math.pow(d, 1 / p);
        KNode node = new KNode(d, trainNode.getType());
        return node;
    }

    private void reset() {
        for (int i = 0; i < k; i++)
            mNearestK[i].reset();
    }

    /**
     * Return the K nearest neighbors of a test sample.
     *
     * @param test
     * @param p
     * @return
     */
    public KNode[] getKNN(DataNode test, float p) {
        reset();
        train(test, p);
        return mNearestK;
    }
}
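KNode is also not shown in the post. From its usage it just pairs a distance with a class index, with reset() restoring the distance to a sentinel value so any real neighbor replaces it. A minimal sketch (anything beyond the method names used above is an assumption):

package com.jingchen.knn;

// Pairs a distance with the class index of the training sample it came from
public class KNode {
    private float d;  // distance to the test sample
    private int type; // class index of the training sample

    public KNode() {
        reset();
    }

    public KNode(float d, int type) {
        this.d = d;
        this.type = type;
    }

    public float getD() {
        return d;
    }

    public int getType() {
        return type;
    }

    // Make the node "infinitely far" so any computed neighbor replaces it
    public void reset() {
        d = Float.MAX_VALUE;
        type = -1;
    }
}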
Finally, the main method ties everything together: it classifies each test sample by majority vote among its K nearest neighbors, increasing K and retrying whenever the vote is tied:

public static void main(String[] args) throws Exception {
    DataUtil util = DataUtil.getInstance();
    // Load the training and test sets
    List<DataNode> trainList = util.getDataList("E:/train.txt", ",");
    List<DataNode> testList = util.getDataList("E:/test.txt", ",");
    int K = BASE_K;
    KnnClassifier classifier = new KnnClassifier(K, trainList);
    BufferedWriter output = new BufferedWriter(new FileWriter(new File(
            "E:/output.txt")));
    int typeCount = util.getTypeCount();
    int[] count = new int[typeCount];
    for (int i = 0; i < testList.size();) {
        for (int m = 0; m < typeCount; m++)
            count[m] = 0;
        DataNode test = testList.get(i);
        classifier.setK(K);
        KNode[] nodes = classifier.getKNN(test, 2);
        // Tally the classes of the K nearest neighbors
        for (int j = 0; j < nodes.length; j++)
            count[nodes[j].getType()]++;
        int type = -1;
        int max = -1;
        for (int j = 0; j < typeCount; j++) {
            if (count[j] > max) {
                max = count[j];
                type = j;
            } else if (count[j] == max) {
                // Two classes are tied, so the vote is ambiguous;
                // increase K and retry the same test sample
                type = -1;
                K++;
                break;
            }
        }
        if (type == -1)
            continue;
        else {
            i++;
            K = BASE_K;
        }
        // Write the classification result to the output file
        List<Float> attribs = test.getAttribs();
        for (int n = 0; n < attribs.size(); n++) {
            output.write(attribs.get(n) + ",");
        }
        output.write(util.getTypeName(type) + "\n");
        output.flush();
    }
    output.close();
}
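DataUtil, which owns the getDataList method and the mTypes map used earlier, is not listed either. A minimal singleton sketch consistent with the calls in main (all names besides those already seen are assumptions) might be:

package com.jingchen.knn;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Singleton that loads data files and maps class names to integer indices
public class DataUtil {
    private static DataUtil sInstance;
    private Map<String, Integer> mTypes = new HashMap<String, Integer>();
    private List<String> mTypeNames = new ArrayList<String>();
    private int mTypeCount = 0;

    private DataUtil() {
    }

    public static DataUtil getInstance() {
        if (sInstance == null)
            sInstance = new DataUtil();
        return sInstance;
    }

    public int getTypeCount() {
        return mTypeCount;
    }

    // Reverse lookup: class index back to class name
    public String getTypeName(int type) {
        return mTypeNames.get(type);
    }

    // getDataList(...) from above goes here, appending each new class name
    // to mTypeNames when it is first seen
}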