码迷,mamicode.com
首页 > 其他好文 > 详细

转:谱聚类

时间:2018-07-15 19:41:15      阅读:201      评论:0      收藏:0      [点我收藏+]

标签:trim   lap   ber   error   ted   i+1   title   组成   doc   

广义上来说,任何在算法中用到SVD/特征值分解的,都叫Spectral Algorithm。顺便说一下,对于任意矩阵只存在奇异值分解,不存在特征值分解。对于正定的对称矩阵,奇异值就是特征值,奇异向量就是特征向量。

传统的聚类算法,如K-Means、EM算法都是建立在凸球形样本空间上,当样本空间不为凸时,算法会陷入局部最优,最终结果受初始参数的选择影响比较大。而谱聚类可以在任意形状的样本空间上聚类,且收敛于全局最优解。

谱聚类和CHAMELEON聚类很像,都是把样本点的相似度放到一个带权无向图中,采用“图划分”的方法进行聚类。只是谱聚类算法在进行图划分的时候发现计算量很大,转而求特征值去了,而且最后还在几个小特征向量组成的矩阵上进行了K-Means聚类

Simply speaking,谱聚类算法分为3步:

  1. 构造一个N×N的权值矩阵W,Wij表示样本i和样本j的相似度,显然W是个对称矩阵。相似度的计算方法很多了,你可以用欧拉距离、街区距离、向量夹角、皮尔森相关系数等。并不是任意两个点间的相似度都要表示在图上,我们希望的权值图是比较稀疏的,有2种方法:权值小于阈值的认为是0;K最邻近方法,即每个点只和跟它最近的k个点连起来,CHAMELEON算法的第1阶段就是这么干的。再构造一个对角矩阵D,Dii为W第i列元素之和。最后构造矩阵L=D-W。可以证明L是个半正定和对称矩阵。
  2. 求L的前K小特征值对应的特征向量(这要用到奇异值分解了)。把K个特征向量放在一起构造一个N×K的矩阵M。
  3. 把M的每一行当成一个新的样本点,对这N个新的样本点进行K-Means聚类。

从文件读入样本点,最终算得矩阵L

#include<math.h>
#include<string.h>
#include"matrix.h"
#include"svd.h"
 
#define N 19        //样本点个数
#define K 4         //K-Means算法中的K
#define T 0.1       //样本点之间相似度的阈值
 
double sample[N][2];    //存放所有样本点的坐标(2维的)
 
void readSample(char *filename){
    FILE *fp;
    if((fp=fopen(filename,"r"))==NULL){
        perror("fopen");
        exit(0);
    }
    char buf[50]={0};
    int i=0;
    while(fgets(buf,sizeof(buf),fp)!=NULL){
        char *w=strtok(buf," \t");
        double x=atof(w);
        w=strtok(NULL," \t");
        double y=atof(w);
        sample[i][0]=x;
        sample[i][1]=y;
        i++;
        memset(buf,0x00,sizeof(buf));
    }
    assert(i==N);
    fclose(fp);
}
 
double** getSimMatrix(){
    //为二维矩阵申请空间
    double **matrix=getMatrix(N,N);
    //计算样本点两两之间的相似度,得到矩阵W
    int i,j;
    for(i=0;i<N;i++){
        matrix[i][i]=1;
        for(j=i+1;j<N;j++){
            double dist=sqrt(pow(sample[i][0]-sample[j][0],2)+pow(sample[i][1]-sample[j][1],2));
            double sim=1.0/(1+dist);
            if(sim>T){
                matrix[j][i]=sim;
                matrix[i][j]=sim;
            }
        }
    }
    //计算L=D-W
    for(j=0;j<N;j++){
        double sum=0;
        for(i=0;i<N;i++){
            sum+=matrix[i][j];
            if(i!=j)
                matrix[i][j]=0-matrix[i][j];
        }
        matrix[j][j]=matrix[j][j]-sum;
    }
    return matrix;
}
 
int main(){
    char *file="/home/orisun/data";
    readSample(file);
    double **L=getSimMatrix();
    printMatrix(L,N,N);
     
    double **M=singleVector(L,N,N,5);
    printMatrix(M,N,5);
     
    freeMatrix(L,N);
 
    return 0;
}

L已是对称矩阵,直接奇异值分解的得到的就是特征向量

最后是运行KMeans的Java代码

package ai;
 
public class Global {
    //计算两个向量的欧氏距离
    public static double calEuraDist(double[] arr1,double[] arr2,int len){
        double result=0.0;
        for(int i=0;i<len;i++){
            result+=Math.pow(arr1[i]-arr2[i],2.0);
        }
        return Math.sqrt(result);
    }
}
package ai;
 
public class DataObject {
 
    String docname;
    double[] vector;
    int cid;   
    boolean visited;
     
    public DataObject(int len){
        vector=new double[len];
    }
 
    public String getName() {
        return docname;
    }
 
    public void setName(String docname) {
        this.docname = docname;
    }
 
    public double[] getVector() {
        return vector;
    }
 
    public void setVector(double[] vector) {
        this.vector = vector;
    }
 
    public int getCid() {
        return cid;
    }
 
    public void setCid(int cid) {
        this.cid = cid;
    }
 
    public boolean isVisited() {
        return visited;
    }
 
    public void setVisited(boolean visited) {
        this.visited = visited;
    }
 
}
package ai;
 
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
public class DataSource {
 
    ArrayList<DataObject> objects;
    int row;
    int col;
 
    public void readMatrix(File dataFile) {
        try {
            FileReader fr = new FileReader(dataFile);
            BufferedReader br = new BufferedReader(fr);
            String line = br.readLine();
            String[] words = line.split("\\s+");
            row = Integer.parseInt(words[0]);
            // row=1000;
            col = Integer.parseInt(words[1]);
            objects = new ArrayList<DataObject>(row);
            for (int i = 0; i < row; i++) {
                DataObject object = new DataObject(col);
                line = br.readLine();
                words = line.split("\\s+");
                for (int j = 0; j < col; j++) {
                    object.getVector()[j] = Double.parseDouble(words[j]);
                }
                objects.add(object);
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
 
    public void readRLabel(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            for (int i = 0; i < row; i++) {
                line = br.readLine();
                objects.get(i).setName(line.trim());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
 
    public void printResult(ArrayList<DataObject> objects, int n) {
        //DBScan是从第1类开始,K-Means是从第0类开始
//      for (int i =0; i <n; i++) {
        for(int i=1;i<=n;i++){
            System.out.println("=============属于第"+i+"类的有:===========================");
            Iterator<DataObject> iter = objects.iterator();
            while (iter.hasNext()) {
                DataObject object = iter.next();
                int cid=object.getCid();
                if(cid==i){
                    System.out.println(object.getName());
//                  switch(Integer.parseInt(object.getName())/1000){
//                  case 0:
//                      System.out.println(0);
//                      break;
//                  case 1:
//                      System.out.println(1);
//                      break;
//                  case 2:
//                      System.out.println(2);
//                      break;
//                  case 3:
//                      System.out.println(3);
//                      break;
//                  case 4:
//                      System.out.println(4);
//                      break;
//                  case 5:
//                      System.out.println(5);
//                      break;
//                  default:
//                      System.out.println("Go Out");
//                      break;
//                  }              
                }
            }
        }
    }
}
package ai;
 
import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
  
public class KMeans {
  
    int k; // 指定划分的簇数
    double mu; // 迭代终止条件,当各个新质心相对于老质心偏移量小于mu时终止迭代
    double[][] center; // 上一次各簇质心的位置
    int repeat; // 重复运行次数
    double[] crita; // 存放每次运行的满意度
  
    public KMeans(int k, double mu, int repeat, int len) {
        this.k = k;
        this.mu = mu;
        this.repeat = repeat;
        center = new double[k][];
        for (int i = 0; i < k; i++)
            center[i] = new double[len];
        crita = new double[repeat];
    }
  
    // 初始化k个质心,每个质心是len维的向量,每维均在left--right之间
    public void initCenter(int len, ArrayList<DataObject> objects) {
        Random random = new Random(System.currentTimeMillis());
        int[] count = new int[k]; // 记录每个簇有多少个元素
        Iterator<DataObject> iter = objects.iterator();
        while (iter.hasNext()) {
            DataObject object = iter.next();
            int id = random.nextInt(10000)%k;
            count[id]++;
            for (int i = 0; i < len; i++)
                center[id][i] += object.getVector()[i];
        }
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < len; j++) {
                center[i][j] /= count[i];
            }
        }
    }
  
    // 把数据集中的每个点归到离它最近的那个质心
    public void classify(ArrayList<DataObject> objects) {
        Iterator<DataObject> iter = objects.iterator();
        while (iter.hasNext()) {
            DataObject object = iter.next();
            double[] vector = object.getVector();
            int len = vector.length;
            int index = 0;
            double neardist = Double.MAX_VALUE;
            for (int i = 0; i < k; i++) {
                double dist = Global.calEuraDist(vector, center[i], len); // 使用欧氏距离
                if (dist < neardist) {
                    neardist = dist;
                    index = i;
                }
            }
            object.setCid(index);
        }
    }
  
    // 重新计算每个簇的质心,并判断终止条件是否满足,如果不满足更新各簇的质心,如果满足就返回true.len是数据的维数
    public boolean calNewCenter(ArrayList<DataObject> objects, int len) {
        boolean end = true;
        int[] count = new int[k]; // 记录每个簇有多少个元素
        double[][] sum = new double[k][];
        for (int i = 0; i < k; i++)
            sum[i] = new double[len];
        Iterator<DataObject> iter = objects.iterator();
        while (iter.hasNext()) {
            DataObject object = iter.next();
            int id = object.getCid();
            count[id]++;
            for (int i = 0; i < len; i++)
                sum[id][i] += object.getVector()[i];
        }
        for (int i = 0; i < k; i++) {
            if (count[i] != 0) {
                for (int j = 0; j < len; j++) {
                    sum[i][j] /= count[i];
                }
            }
            // 簇中不包含任何点,及时调整质心
            else {
                int a=(i+1)%k;
                int b=(i+3)%k;
                int c=(i+5)%k;
                for (int j = 0; j < len; j++) {
                    center[i][j] = (center[a][j]+center[b][j]+center[c][j])/3;
                }
            }
        }
        for (int i = 0; i < k; i++) {
            // 只要有一个质心需要移动的距离超过了mu,就返回false
            if (Global.calEuraDist(sum[i], center[i], len) >= mu) {
                end = false;
                break;
            }
        }
        if (!end) {
            for (int i = 0; i < k; i++) {
                for (int j = 0; j < len; j++)
                    center[i][j] = sum[i][j];
            }
        }
        return end;
    }
  
    // 计算各簇内数据和方差的加权平均,得出本次聚类的满意度.len是数据的维数
    public double getSati(ArrayList<DataObject> objects, int len) {
        double satisfy = 0.0;
        int[] count = new int[k];
        double[] ss = new double[k];
        Iterator<DataObject> iter = objects.iterator();
        while (iter.hasNext()) {
            DataObject object = iter.next();
            int id = object.getCid();
            count[id]++;
            for (int i = 0; i < len; i++)
                ss[id] += Math.pow(object.getVector()[i] - center[id][i], 2.0);
        }
        for (int i = 0; i < k; i++) {
            satisfy += count[i] * ss[i];
        }
        return satisfy;
    }
  
    public double run(int round, DataSource datasource, int len) {
        System.out.println("第" + round + "次运行");
        initCenter(len,datasource.objects);
        classify(datasource.objects);
        while (!calNewCenter(datasource.objects, len)) {
            classify(datasource.objects);
        }
        datasource.printResult(datasource.objects, k);
        double ss = getSati(datasource.objects, len);
        System.out.println("加权方差:" + ss);
        return ss;
    }
  
    public static void main(String[] args) {
        DataSource datasource = new DataSource();
        datasource.readMatrix(new File("/home/orisun/test/dot.mat"));
        datasource.readRLabel(new File("/home/orisun/test/dot.rlabel"));
        int len = datasource.col;
        // 划分为4个簇,质心移动小于1E-8时终止迭代,重复运行7次
        KMeans km = new KMeans(4, 1E-10, 7, len);
        int index = 0;
        double minsa = Double.MAX_VALUE;
        for (int i = 0; i < km.repeat; i++) {
            double ss = km.run(i, datasource, len);
            if (ss < minsa) {
                minsa = ss;
                index = i;
            }
        }
        System.out.println("最好的结果是第" + index + "次。");
    }
}

转:谱聚类

标签:trim   lap   ber   error   ted   i+1   title   组成   doc   

原文地址:https://www.cnblogs.com/lm3306/p/9314032.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!