码迷,mamicode.com
首页 > 编程语言 > 详细

K-Means 算法的 Hadoop 实现

时间:2016-12-03 02:22:17      阅读:468      评论:0      收藏:0      [点我收藏+]

标签:创建   数据存储   cli   hand   output   turn   leo   mpi   成员   

K-Means 算法的 Hadoop 实现

K-Means 算法简介

k-Means是一种聚类分析算法,它是一种无监督学习算法。它主要用来计算数据的聚集,将数据相近的点归到同一数据蔟。学习聚类时我们需要了解聚类与分类的区别,分类的类别是我们实现设定好的,而聚类的类别是通过计算得到的。

算法原理

维基百科的算法描述如下:

已知观测集 (x1,x2,x3,...,xn) ,其中每个观测都是一个d-维实向量,k-平均聚类要把这n个观测划分到k个集合中 (k≤n) ,使得组内平方和(WCSS within-cluster sum of squares)最小。换句话说,它的目标是找到使得下式满足的聚类 Si

argminS=i=1kxSi||x?μi||2

其中 μiSi 中所有点的均值。

简单描述就是:不断迭代计算各个数据簇的中心点,直到该中心点趋于稳定。
该算法的优点是实现非常简单,主要缺点有如下:

  • 对异常数据敏感。当单独几个数据远离数据簇时会影响聚类效果。
  • 由于 K 值是事先给定的,所以 K 值的选择难以估计。也就是我们事先并不知道需要分多少个类别。
    • ISODATA 算法可用于解决此问题,得到较为合理的类型数目K
  • 初始的数据簇的中心点需要事先给定,初始种子点很大程度上会影响聚类的结果。
    • K-Means++ 算法可以用来解决这个问题,其可以有效地选择初始点

步骤

  1. 创建k个数据簇的中心点。
  2. 计算所有数据点到这 k 个中心点的距离,将其划归到距离自己最近的中心点。
  3. 根据上次聚类结果,计算各个数据簇的算数平均值作为新的数据簇中心点。
  4. 将所有数据在新的中心点上重新聚类。
  5. 重复第4步,直到中心点趋于稳定。

中心点距离算法

求某一数据点到中心点的距离可以采用欧几里得距离公式:

distance=k=1n(xik?xjk)2?

可以参考 K-Means 算法(CoolShell) 里面的 求点群中心的算法 这一节,有三种距离公式。

Hadoop 环境简介

参考该篇文章(基于Docker搭建Hadoop集群之升级版)搭建了基于 Docker 的 Hadoop 环境。

Hadoop Version

root@hadoop-master:/myjob/kmeans# hadoop version
Hadoop 2.7.2
Subversion Unknown -r Unknown
Compiled by root on 2016-05-27T18:05Z
Compiled with protoc 2.5.0
From source with checksum d0fda26633fa762bff87ec759ebe689c
This command was run using /usr/local/hadoop/share/hadoop/common/hadoop-common-2.7.2.jar

Hadoop 参数设置

core-site.xml

<?xml version="1.0"?>
<configuration>
    <property>
        <name>fs.defaultFS</name>
        <value>hdfs://hadoop-master:9000/</value>
    </property>
</configuration>

hdfs-site.xml

<?xml version="1.0"?>
<configuration>
    <property>
        <name>dfs.namenode.name.dir</name>
        <value>file:///root/hdfs/namenode</value>
        <description>NameNode directory for namespace and transaction logs storage.</description>
    </property>
    <property>
        <name>dfs.datanode.data.dir</name>
        <value>file:///root/hdfs/datanode</value>
        <description>DataNode directory</description>
    </property>
    <property>
        <name>dfs.replication</name>
        <value>3</value>
    </property>
</configuration>

mapred-site.xml

<?xml version="1.0"?>
<configuration>
    <property>
        <name>mapreduce.framework.name</name>
        <value>yarn</value>
    </property>
</configuration>

yarn-site.xml

<?xml version="1.0"?>
<configuration>
    <property>
        <name>yarn.nodemanager.aux-services</name>
        <value>mapreduce_shuffle</value>
    </property>
    <property>
        <name>yarn.nodemanager.aux-services.mapreduce_shuffle.class</name>
        <value>org.apache.hadoop.mapred.ShuffleHandler</value>
    </property>
    <property>
        <name>yarn.resourcemanager.hostname</name>
        <value>hadoop-master</value>
    </property>
</configuration>

NameNode Information

在本机通过 Docker 启动 Hadoop 集群后共有一个 master 节点和四个 slave 节点,如下图:

技术分享

从管理后台截得下图:

技术分享

Datanode Information

共有四个计算节点,在 hdfs-site.xml 文件中设定每个数据都备份三份,如下图:

技术分享

MapReduce 编程

整体思路

整体设计思路如下图:

技术分享

每次迭代过程都是一个 Hadoop Job ,通过不断迭代计算得到新的中心点文件,然后跟旧的中心点文件进行比较,直到新的中心点与旧的中心点误差小于给定的阙值,此时迭代结束,最后一次得到的中心点为计算结果。

项目的 GitHub 地址: https://github.com/CHAAAAA/hadoop-kmeans

KMeansData (辅助 K-Means 算法的单个数据点类)

为辅助计算,设计了类 KMeansData ,其用于保存 K-Means 计算过程中的数据,并实现了 WritableComparable 接口使其支持在 Hadoop 计算过程中向下传递。

成员变量

  • Text kMeansData
    • 格式: 1 2 ... 6 7 ,多维数据用空格隔开, 使用 Double 解析
    • 含义: 一行多维数据;或者在 CombinerReducer 过程中累加的多维数据
  • IntWritable dataSize
    • 含义:该对象的 kMeansData 字段是有几行数据累加的

主要成员函数

  • public void add(KMeansData data, int dimension)

    在当前 KMeansData 对象上累加一个 KMeansData 对象,更新 kMeansDatadataSize 的值。该函数主要在 CombinerReducer 计算过程中用到。

    @param data 一个 KMeansData 数据对象

    @param dimension 数据对象的维度

    @throws KMeansCentroidFormatException

  • public String getNewCentroids()

    根据当前 KMeansData 对象生成新的中心点,也就是将 kMeansData 中各维的数据除以 dataSize。该函数主要用于 Reducer 过程中的最后一步,生成新中心点。

    @return 新中心点,数据以空格隔开

代码

public class KMeansData implements WritableComparable<KMeansData> {

    private Text kMeansData;
    private IntWritable dataSize;

    public KMeansData() {
        set(new Text(), new IntWritable());
    }

    public KMeansData(Text text, IntWritable intWritable) {
        this.kMeansData = text;
        this.dataSize = intWritable;
    }

    private void set(Text textWritable, IntWritable intWritable) {
        this.kMeansData = textWritable;
        this.dataSize = intWritable;
    }

    public Text getkMeansData() {
        return kMeansData;
    }

    public IntWritable getDataSize() {
        return dataSize;
    }

    /**
     * isEqual
     *
     * @param o
     * @return
     */
    public int compareTo(KMeansData o) {

        int flag = 0;
        if (kMeansData.compareTo(o.kMeansData) != 0 || dataSize.compareTo(o.dataSize) != 0) {
            flag = 1;
        }
        return flag;
    }

    public void write(DataOutput dataOutput) throws IOException {
        kMeansData.write(dataOutput);
        dataSize.write(dataOutput);
    }

    public void readFields(DataInput dataInput) throws IOException {
        kMeansData.readFields(dataInput);
        dataSize.readFields(dataInput);
    }

    /**
     * add a KMeansData to this
     * @param data a KMeansData
     * @param dimension dimension
     * @throws KMeansCentroidFormatException
     */
    public void add(KMeansData data, int dimension) throws KMeansCentroidFormatException {
        Text newData = data.kMeansData;
        String[] newStrings = newData.toString().trim().split(" ");
        String[] strings = kMeansData.toString().trim().split(" ");
        if (newStrings.length != dimension || strings.length != dimension) {
            throw new KMeansCentroidFormatException("Dimension Error");
        }

        StringBuffer result = new StringBuffer();

        for (int i = 0; i < dimension; i++) {
            double a = Double.parseDouble(newStrings[i]) + Double.parseDouble(strings[i]);
            DecimalFormat df = new DecimalFormat("0.0");
            result.append(df.format(a)).append(" ");
        }

        String r = result.toString().trim();
        this.kMeansData.set(r.substring(0, r.length() - 1));
        this.dataSize.set(this.dataSize.get() + data.dataSize.get());
    }

    /**
     * get the new Centroids
     * callback by Reducer
     * @return data
     */
    public String getNewCentroids() {
        StringBuffer r = new StringBuffer();
        String[] strings = kMeansData.toString().trim().split(" ");
        for (String s : strings) {
            double d = Double.parseDouble(s) / dataSize.get();
            DecimalFormat df = new DecimalFormat("0.0");
            r.append(df.format(d)).append(" ");
        }

        return r.toString().trim();
    }

    /**
     * return an ArrayList<Double> about this data
     * @return arrayList
     */
    public ArrayList<Double> getArray() {
        ArrayList<Double> arrayList = new ArrayList<Double>();
        String[] data = this.kMeansData.toString().trim().split(" ");
        for (String s : data) {
            arrayList.add(Double.parseDouble(s));
        }
        return arrayList;
    }
}

KMeansCentroids (辅助 K-Means 算法的中心点类)

该类用于实例化中心点文件,将中心点文件中的所有中心点保存在该类中用于计算和比较。

成员变量

  • Map<Integer, ArrayList<Double>> centroids
    • 格式: centroids 为一个 MapMapKey 为一个中心点的行号;Value 为一个 ArrayList ,存储了该中心点的数据。
  • int centroidDimension
    • 含义:该字段保存了中心点文件的数据维度

主要成员函数

  • private void initCentroid(String centroidPath)

    在初始化一个 KMeansCentroids 对象时会调用此函数,函数根据传入的文件路径,读取该文件并将数据存储在 centroidscentroidDimension 中。

    @param centroidPath 中心点文件的路径

    @throws KMeansCentroidFormatException

  • public int getCentroid(KMeansData point)

    根据传入的一个 KMeansData 对象,找到距离该数据最近的中心点并返回。使用欧几里得距离公式求距离。

    @param point 一个 KMeansData 数据对象

    @return 距离该数据最近的中心点的行号

    @throws KMeansCentroidFormatException

  • public boolean isEquals(KMeansCentroids o, Double error)

    计算该对象是否与传入的 KMeansCentroids 对象相等,当另个对象之间的数据差小于 error 时,我们也认为其相等

    @param o 一个 KMeansCentroids 数据对象

    @param error 两个 KMeansCentroids 对象之间允许的误差

    @return 距离该数据最近的中心点的行号

    @return 另个中心点文件是否相等

代码

public class KMeansCentroids {
    private static final Logger LOG = LogManager.getLogger(KMeansCentroids.class);

    private Map<Integer, ArrayList<Double>> centroids = new HashMap<Integer, ArrayList<Double>>();
    // Dimension
    private int centroidDimension = 3;


    public KMeansCentroids(String centroidPath, int centroidDimension) {
        this.centroidDimension = centroidDimension;
        initCentroid(centroidPath);
    }

    public KMeansCentroids(String centroidPath) {
        initCentroid(centroidPath);
    }

    /**
     * init centroids
     *
     * @param centroidPath centroids file uri
     */
    private void initCentroid(String centroidPath) {

        try {
            Configuration configuration = new Configuration();
            FileSystem fileSystem = FileSystem.get(URI.create(centroidPath), configuration);
            FSDataInputStream inputStream = null;
            try {
                LOG.debug("Start read centroids file,URI: " + centroidPath);
                inputStream = fileSystem.open(new Path(centroidPath));

                BufferedReader d = new BufferedReader(new InputStreamReader(inputStream));
                String line;
                while ((line = d.readLine()) != null) {
                    line = line.replace("\t", " ").trim();
                    String[] points = line.split(" ");
                    if (points.length != centroidDimension + 1) {
                        throw new KMeansCentroidFormatException("Centroid Dimension Error");
                    }
                    int index = Integer.valueOf(points[0]);

                    ArrayList<Double> oneCentroid = new ArrayList<Double>();

                    for (int i = 1; i <= centroidDimension; i++) {
                        oneCentroid.add(Double.valueOf(points[i]));
                    }
                    centroids.put(index, oneCentroid);
                }

                LOG.debug("Read centroids file success. Centroids: \n" + readCentroids());
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                IOUtils.closeStream(inputStream);
            }
        } catch (Exception e1) {
            e1.printStackTrace();
        }
    }

    public Map<Integer, ArrayList<Double>> getCentroids() {
        return centroids;
    }

    public int getCentroidDimension() {
        return centroidDimension;
    }

    /**
     * Get the Shortest Path Centroid index
     *
     * @param point data point
     * @return centroid
     * @throws KMeansCentroidFormatException
     */
    public int getCentroid(KMeansData point) throws KMeansCentroidFormatException {
        double distance = Double.MAX_VALUE;
        int r = 0;
        ArrayList<Double> pointData = point.getArray();

        for (Integer i : centroids.keySet()) {
            double temp = getEnumDistance(centroids.get(i), pointData);
            if (temp < distance) {
                distance = temp;
                r = i;
            }
        }
        return r;
    }

    /**
     * Get the Enum Distance
     *
     * @param centroid centroid
     * @param point    data point
     * @return Enum Distance
     */
    private double getEnumDistance(ArrayList<Double> centroid, ArrayList<Double> point) {
        double distance = 0.0;
        for (int i = 0; i < centroidDimension; i++) {
            distance += ((centroid.get(i) - point.get(i)) * (centroid.get(i) - point.get(i)));
        }
        distance = Math.sqrt(distance);
        return distance;
    }


    /**
     * Show the Centroids
     */
    public String readCentroids() {
        StringBuffer sb = new StringBuffer();

        for (Integer i : centroids.keySet()) {
            sb.append(i + "");
            for (Double j : centroids.get(i)) {
                sb.append(j + " ");
            }
            sb.append("\n");
        }

        return sb.toString();
    }


    /**
     * two KMeansCentroids is equal
     *
     * @param o     compare object
     * @param error allowable error
     * @return flag
     */
    public boolean isEquals(KMeansCentroids o, Double error) {
        boolean flag = true;
        for (Integer i : centroids.keySet()) {
            if (!arrayEquals(centroids.get(i), o.getCentroids().get(i), error)) {
                flag = false;
                break;
            }
        }
        return flag;
    }

    private boolean arrayEquals(ArrayList<Double> a, ArrayList<Double> b, Double error) {
        boolean flag = true;
        for (int i = 0; i < centroidDimension; i++) {
            if (Math.abs(a.get(i) - b.get(i)) > error) {
                flag = false;
                break;
            }
        }
        return flag;
    }
}

KMeansMapper

Mapper 计算过程:

<Object key, Text value> -> <IntWritable index, KMeansData data>

代码

    protected void map(Object key, Text value, Context context) throws IOException, InterruptedException {
        //get data
        String s = value.toString().trim();
        String[] fields = s.split(" ");
        if (fields.length == dimension) {
            KMeansData kMeansData = new KMeansData(new Text(s), new IntWritable(1));
            try {
                int index = centroids.getCentroid(kMeansData);
                context.write(new IntWritable(index), kMeansData);
            } catch (KMeansCentroidFormatException e) {
                throw new IOException();
            }
        }
    }

KMeansCombiner

Combiner 计算过程:

<IntWritable index, KMeansData data> -> <IntWritable index, KMeansData data>

代码

    protected void reduce(IntWritable key, Iterable<KMeansData> values, Context context) throws IOException, InterruptedException {

        StringBuilder stringBuilder = new StringBuilder();
        for (int i = 0; i < dimension; i++) {
            stringBuilder.append(0).append(" ");
        }

        KMeansData data = new KMeansData(new Text(stringBuilder.toString().trim()), new IntWritable(dimension));
        for (KMeansData val : values) {
            try {
                data.add(val, dimension);
            } catch (KMeansCentroidFormatException e) {
                throw new IOException();
            }
        }
        context.write(key, data);
    }

KMeansReducer

Reducer 计算过程:

<IntWritable index, KMeansData data> -> <IntWritable index, Text centroid>

代码

    protected void reduce(IntWritable key, Iterable<KMeansData> values, Context context) throws IOException, InterruptedException {
        StringBuilder stringBuilder = new StringBuilder();
        for (int i = 0; i < dimension; i++) {
            stringBuilder.append(0).append(" ");
        }
        KMeansData data = new KMeansData(new Text(stringBuilder.toString().trim()), new IntWritable(dimension));
        for (KMeansData val : values) {
            try {
                data.add(val, dimension);
            } catch (KMeansCentroidFormatException e) {
                throw new IOException();
            }
        }

KMeansRun

这是 Hadoop 作业的执行入口,每次迭代计算后都判断一下是否进行下次迭代。程序启动时需要 1 个或者 3 个参数,说明如下:

  • 1 个参数:参数需要表示本次计算的数据维度
  • 3 个参数:数据维度 输入文件夹 初始数据中心点路径

其中初始数据中心点的文件名称必须为 centroid

代码

public class KMeansRun {

    private static final Logger LOG = LogManager.getLogger(KMeansRun.class);

    //default centroid path /centroid$times
    private static String centroidPath = "kmeans/";
    private static String inputPath = "input/kmeans/";
    private static Integer dimension = 3;

    public static void main(String[] args) throws Exception {
        //iteration times
        int iterations = 0;

        Configuration conf = new Configuration();

        GenericOptionsParser optionParser = new GenericOptionsParser(conf, args);
        String[] remainingArgs = optionParser.getRemainingArgs();
        if (remainingArgs.length != 1 && remainingArgs.length != 3) {
            System.err.println("Usage: K-Means <dimension> [in] [centroidRootPath(without filename , default filename is centroid)]");
            System.exit(2);
        }

        dimension = Integer.valueOf(args[0]);
        if (remainingArgs.length == 3) {
            inputPath = args[1];
            centroidPath = args[2];
            if (!args[2].endsWith(File.separator)) {
                centroidPath = args[2] + File.separator;
            }
        }

        //set dimension
        conf.set("dimension", dimension.toString());

        String oldCentroidPath = centroidPath + "centroid";
        String currentCentroidPath = centroidPath + "centroid";
        do {
            conf.set("centroid.path", currentCentroidPath);

            Job job = Job.getInstance(conf, "K-Means" + iterations);
            job.setJarByClass(KMeansRun.class);
            job.setMapperClass(KMeansMapper.class);
            job.setCombinerClass(KMeansCombiner.class);
            job.setReducerClass(KMeansReducer.class);

            job.setMapOutputKeyClass(IntWritable.class);
            job.setMapOutputValueClass(KMeansData.class);
            job.setOutputKeyClass(IntWritable.class);
            job.setOutputValueClass(Text.class);
            FileInputFormat.addInputPath(job, new Path(inputPath));
            FileOutputFormat.setOutputPath(job, new Path(centroidPath + iterations + "/"));
            if (!job.waitForCompletion(true)) {
                System.exit(1);
            }

            oldCentroidPath = currentCentroidPath;
            currentCentroidPath = centroidPath + iterations + "/part-r-00000";
            iterations++;
        } while (isContinue(oldCentroidPath, currentCentroidPath));


    }

    private static boolean isContinue(String oldPath, String newPath) {
        boolean flag = false;

        KMeansCentroids oldCentroids = new KMeansCentroids(oldPath, dimension);
        KMeansCentroids newCentroids = new KMeansCentroids(newPath, dimension);
        if (!oldCentroids.isEquals(newCentroids, 0.1)) {
            flag = true;
        }
        return flag;
    }
}

参考

K-Means 算法(CoolShell)

机器学习算法-K-means聚类

k-平均算法(wikipedia)

基于Docker搭建Hadoop集群之升级版

K-Means 算法的 Hadoop 实现

标签:创建   数据存储   cli   hand   output   turn   leo   mpi   成员   

原文地址:http://blog.csdn.net/chaaaa_wangyc/article/details/53426612

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!