码迷,mamicode.com
首页 > 编程语言 > 详细

基于spark的kmeans算法

时间:2018-10-30 21:21:34      阅读:247      评论:0      收藏:0      [点我收藏+]

标签:read   lin   seve   中心   tor   turn   array   rdd   bsp   

from __future__ import print_function

import sys

import numpy as np
from pyspark.sql import SparkSession


def parseVector(line):
    return np.array([float(x) for x in line.split(‘ ‘)])


def closestPoint(p, centers):
    bestIndex = 0
    closest = float("+inf")
    for i in range(len(centers)):
        tempDist = np.sum((p - centers[i]) ** 2)
        if tempDist < closest:
            closest = tempDist
            bestIndex = i
    return bestIndex


if __name__ == "__main__":

    if len(sys.argv) != 4:
        print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
        sys.exit(-1)

    spark = SparkSession        .builder        .appName("PythonKMeans")        .getOrCreate()

    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
    data = lines.map(parseVector).cache()
//聚类超参数K K
= int(sys.argv[2])
//收敛阈值 convergeDist
= float(sys.argv[3]) //初始化K个中心点 kPoints = data.takeSample(False, K, 1) tempDist = 1.0 while tempDist > convergeDist:
// map Key: 聚类中心点 Value: (当前点, 数量1) closest
= data.map( lambda p: (closestPoint(p, kPoints), (p, 1)))
// reduce Key:聚类中心点, 计算每个聚类中心点下的分布 pointStats
= closest.reduceByKey( lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
//map 计算新的中心点 newPoints
= pointStats.map( lambda st: (st[0], st[1][0] / st[1][1])).collect() tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints) for (iK, p) in newPoints: kPoints[iK] = p print("Final centers: " + str(kPoints)) spark.stop()

 

基于spark的kmeans算法

标签:read   lin   seve   中心   tor   turn   array   rdd   bsp   

原文地址:https://www.cnblogs.com/energy1010/p/9879043.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!