标签:不同的 term 分类 image ros 图片 next idt k-means算法
数据挖掘中一个重要算法是K-means,我这里就不做详细介绍。如果感兴趣的话可以移步陈皓的博客:
总的来讲,k-means聚类需要以下几个步骤:
①.初始化数据
②.计算初始的中心点,可以随机选择
③.计算每个点到每个聚类中心的距离,并且划分到距离最短的聚类中心簇中
④.计算每个聚类簇的平均值,这个均值作为新的聚类中心,重复步骤3
⑤.如果达到最大循环或者是聚类中心不再变化或者聚类中心变化幅度小于一定范围时,停止循环。
恩,原理就是这样,超级简单。但是Java算法实现起来代码量并不小。这个代码也不算是完全自己写的啦,也有些借鉴。我把k-means实现封装在了一个类里面,这样就可以随时调用了呢。
import java.util.ArrayList; import java.util.Random; public class kmeans { private int k;//簇数 private int m;//迭代次数 private int dataSetLength;//数据集长度 private ArrayList<double[]> dataSet;//数据集合 private ArrayList<double[]> center;//中心链表 private ArrayList<ArrayList<double[]>> cluster;//簇 private ArrayList<Float> jc;//误差平方和,这个是用来计算中心聚点的移动哦 private Random random; //设置原始数据集合 public void setDataSet(ArrayList<double[]> dataSet){ this.dataSet=dataSet; } //获得簇分组 public ArrayList<ArrayList<double[]>> getCluster(){ return this.cluster; } //构造函数,传入要分的簇的数量 public kmeans(int k){ if(k<=0) k=1; this.k=k; } //初始化 private void init(){ m=0; random=new Random(); if(dataSet==null||dataSet.size()==0) initDataSet(); dataSetLength=dataSet.size(); if(k>dataSetLength) k=dataSetLength; center=initCenters(); cluster=initCluster(); jc=new ArrayList<Float>(); } //初始化数据集合 private void initDataSet(){ dataSet=new ArrayList<double[]>(); double[][] dataSetArray=new double[][]{{8,2},{3,4},{2,5},{4,2}, {7,3},{6,2},{4,7},{6,3},{5,3},{6,3},{6,9}, {1,6},{3,9},{4,1},{8,6}}; for(int i=0;i<dataSetArray.length;i++) dataSet.add(dataSetArray[i]); } //初始化中心链表,分成几簇就有几个中心 private ArrayList<double[]> initCenters(){ ArrayList<double[]> center= new ArrayList<double[]>(); //生成一个随机数列, int[] randoms=new int[k]; boolean flag; int temp=random.nextInt(dataSetLength); randoms[0]=temp; for(int i=1;i<k;i++){ flag=true; while(flag){ temp=random.nextInt(dataSetLength); int j=0; while(j<i){ if(temp==randoms[j]) break; j++; } if(j==i) flag=false; } randoms[i]=temp; } for(int i=0;i<k;i++) center.add(dataSet.get(randoms[i])); return center; } //初始化簇集合 private ArrayList<ArrayList<double[]>> initCluster(){ ArrayList<ArrayList<double[]>> cluster= new ArrayList<ArrayList<double[]>>(); for(int i=0;i<k;i++) cluster.add(new ArrayList<double[]>()); return cluster; } //计算距离 private double distance(double[] element,double[] center){ double distance=0.0f; double x=element[0]-center[0]; double y=element[1]-center[1]; double z=element[2]-center[2]; double sum=x*x+y*y+z*z; distance=(double)Math.sqrt(sum); return distance; } //计算最短的距离 private int minDistance(double[] distance){ double minDistance=distance[0]; int minLocation=0; for(int i=0;i<distance.length;i++){ if(distance[i]<minDistance){ minDistance=distance[i]; minLocation=i; }else if(distance[i]==minDistance){ if(random.nextInt(10)<5){ minLocation=i; } } } return minLocation; } //每个点分类 private void clusterSet(){ double[] distance=new double[k]; for(int i=0;i<dataSetLength;i++){ //计算到每个中心店的距离 for(int j=0;j<k;j++) distance[j]=distance(dataSet.get(i),center.get(j)); //计算最短的距离 int minLocation=minDistance(distance); //把他加到聚类里 cluster.get(minLocation).add(dataSet.get(i)); } } //计算新的中心 private void setNewCenter(){ for(int i=0;i<k;i++){ int n=cluster.get(i).size(); if(n!=0){ double[] newcenter={0,0}; for(int j=0;j<n;j++){ newcenter[0]+=cluster.get(i).get(j)[0]; newcenter[1]+=cluster.get(i).get(j)[1]; } newcenter[0]=newcenter[0]/n; newcenter[1]=newcenter[1]/n; center.set(i, newcenter); } } } //求2点的误差平方 private double errosquare(double[] element,double[] center){ double x=element[0]-center[0]; double y=element[1]-center[1]; double errosquare=x*x+y*y; return errosquare; } //计算误差平方和准则函数 private void countRule(){ float jcf=0; for(int i=0;i<cluster.size();i++){ for(int j=0;j<cluster.get(i).size();j++) jcf+=errosquare(cluster.get(i).get(j),center.get(i)); jc.add(jcf); } } //核心算法 private void Kmeans(){ //初始化各种变量,随机选定中心,初始化聚类 init(); //开始循环 while(true){ //把每个点分到聚类中去 clusterSet(); //计算目标函数 countRule(); //检查误差变化,因为我规定的计算循环次数为50次,所以就不用计算这个啦,你要愿意用也可以,就是慢一点 /* if(m!=0){ if(jc.get(m)-jc.get(m-1)==0) break; }*/ if(m>=50) break; //否则继续生成新的中心 setNewCenter(); m++; cluster.clear(); cluster=initCluster(); } }
//只暴露一个接口给外部类 public void execute(){ System.out.print("start kmeans\n"); Kmeans(); System.out.print("kmeans end\n"); }
//用来在外面打印出来已经分好的聚类 public void printDataArray(ArrayList<double[]> data,String dataArrayName){ for(int i=0;i<data.size();i++){ System.out.print("print:"+dataArrayName+"["+i+"]={"+data.get(i)[0]+","+data.get(i)[1]+"}\n"); } System.out.print("=========================="); } }嗯,代码就是这样。注释写的很详细,也都能看得懂。下面我给一个测试例子。
import java.util.ArrayList; public class Test { public static void main(String[] args){ kmeans k=new kmeans(2); ArrayList<double[]> dataSet=new ArrayList<double[]>(); dataSet.add(new double[]{2,2,2}); dataSet.add(new double[]{1,2,2}); dataSet.add(new double[]{2,1,2}); dataSet.add(new double[]{1,3,2}); dataSet.add(new double[]{3,1,2}); dataSet.add(new double[]{-2,-2,-2}); dataSet.add(new double[]{-1,-2,-2}); dataSet.add(new double[]{-2,-1,-2}); dataSet.add(new double[]{-3,-1,-2}); dataSet.add(new double[]{-1,-3,-2}); k.setDataSet(dataSet); k.execute(); ArrayList<ArrayList<double[]>> cluster=k.getCluster(); for(int i=0;i<cluster.size();i++){ k.printDataArray(cluster.get(i), "cluster["+i+"]"); } } }没啥难度,也就是输入写初始数据,然后执行k-means在进行分类,最后打印一下。这个原型代码很粗糙,没有添加聚类个数以及循环次数的变量,这些需要自己动手啦。
//读取指定目录的图片数据,并且写入数组,这个数据要继续处理 private int[][] getImageData(String path){ BufferedImage bi=null; try{ bi=ImageIO.read(new File(path)); }catch (IOException e){ e.printStackTrace(); } int width=bi.getWidth(); int height=bi.getHeight(); int [][] data=new int[width][height]; for(int i=0;i<width;i++) for(int j=0;j<height;j++) data[i][j]=bi.getRGB(i, j); /*测试输出 for(int i=0;i<data.length;i++) for(int j=0;j<data[0].length;j++) System.out.println(data[i][j]);*/ return data; } //用来处理获取的像素数据,提取我们需要的写入dataItem数组 private dataItem[][] InitData(int [][] data){ dataItem[][] dataitems=new dataItem[data.length][data[0].length]; for(int i=0;i<data.length;i++){ for(int j=0;j<data[0].length;j++){ dataItem di=new dataItem(); Color c=new Color(data[i][j]); di.r=(double)c.getRed(); di.g=(double)c.getGreen(); di.b=(double)c.getBlue(); di.group=1; dataitems[i][j]=di; } } return dataitems; }
//介货是用来输出图像的 <pre name="code" class="java"> private void ImagedataOut(String path){ Color c0=new Color(255,0,0); Color c1=new Color(0,255,0); Color c2=new Color(0,0,255); Color c3=new Color(128,128,128); BufferedImage nbi=new BufferedImage(source.length,source[0].length,BufferedImage.TYPE_INT_RGB); for(int i=0;i<source.length;i++){ for(int j=0;j<source[0].length;j++){ if(source[i][j].group==0) nbi.setRGB(i, j, c0.getRGB()); else if(source[i][j].group==1) nbi.setRGB(i, j, c1.getRGB()); else if(source[i][j].group==2) nbi.setRGB(i, j, c2.getRGB()); else if (source[i][j].group==3) nbi.setRGB(i, j, c3.getRGB()); //Color c=new Color((int)center[source[i][j].group].r, // (int)center[source[i][j].group].g,(int)center[source[i][j].group].b); //nbi.setRGB(i, j, c.getRGB()); } } try{ ImageIO.write(nbi, "jpg", new File(path)); }catch(IOException e){ e.printStackTrace(); } }
标签:不同的 term 分类 image ros 图片 next idt k-means算法
原文地址:http://blog.csdn.net/zhuzhuzhu22/article/details/52944979