码迷,mamicode.com
首页 > 编程语言 > 详细

K-近邻算法入门

时间:2018-12-19 19:36:26      阅读:240      评论:0      收藏:0      [点我收藏+]

标签:sqrt   https   验证   class   items   欧几里得   lines   nes   euc   

  K-近邻算法的直观理解就是:给定一个训练集合,对于新的实例,在训练集合中找到k个与该实例最近的邻居,然后根据“少数服从多数”原则判断该实例归属于哪一类,又称“随大流”

K-近邻算法的三大要素:K值得选取,邻居距离度量,分类决策的制定。

(1)K值选取:通常采用交叉验证选取最优的K值(自己了解)

(2)邻居距离度量:根据不同的应用场景选取相应的距离度量。常见的距离度量有欧几里得距离、曼哈顿距离、马氏距离。同时要注意的是归一化机制。

(3)分类决策制定:一般分为平等投票表决原则和加权投票原则。

import operator
import csv
import math
import random

def loadDataSet(filename,split,trainingSet=[],testSet=[]):
    #读取本地数据#
    with open(filename,r) as csvfile:
        lines=csv.reader(csvfile)
        dataset=list(lines)
        for x in range(len(dataset)-1):
            for y in range (4):
                dataset[x][y]=float(dataset[x][y])
            if random.random()<split:
                trainingSet.append(dataset[x])
            else:
                testSet.append(dataset[x])

def EuclidDist(instance1,instance2,len):
    #求欧几里得距离#
    distance=0.0
    for x in range(len):
        distance+=pow((instance1[x]-instance2[x]),2)
    return math.sqrt(distance)


def getNeighbors(trainSet,testInstance,k):
    #获取最近邻居#
    distance=[]
    length=len(testInstance)-1
    for x in range(len(trainSet)):
        dist=EuclidDist(testInstance,trainSet[x],length)
        distance.append((trainSet[x],dist))
    distance.sort(key=operator.itemgetter(1))
    #列表的sort(key)方法用来根据关键字排序
    neighbors=[]
    for x in range(k):
        neighbors.append(distance[x][0])
    return neighbors

def getClass(neighbors):
    #分类与评估函数#
     classVotes={}
     for x in range(len(neighbors)):
         instance_class=neighbors[x][-1]
         if instance_class in classVotes:
             classVotes[instance_class]+=1
         else:
             classVotes[instance_class]=1
         sortedVotes=sorted(classVotes.items(),key=operator.itemgetter(1),reverse=True)
     return sortedVotes[0][0]

def getAccuracy(testSet,predictions):
    #预测正确率计算#
    correct=0
    for x in range(len(testSet)):
        if testSet[x][-1]==predictions[x]:
            correct+=1
    return (correct/float(len(testSet)))*100.0

def main():
    trainingSet=[]
    testSet=[]
    split=0.7
    loadDataSet(iris.data.csv,split,trainingSet,testSet)
    print(训练集合:+repr(len(trainingSet)))
    print(测试集合:+repr(len(testSet)))
    predictions=[]
    k=3
    for x in range(len(testSet)):
        neighbors=getNeighbors(trainingSet,testSet[x],k)
        result=getClass(neighbors)
        predictions.append(result)
        print(>预测=+repr(result)+,实际=+repr(testSet[x][-1]))
    accuracy=getAccuracy(testSet,predictions)
    print(精确度为:+repr(accuracy)+%)

main()

针对此代码中的数据来源为UCI机器学习库中的鸢尾花卉数据集,可以直接获取(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data),也可以下载我转换好的CSV文件(链接:https://pan.baidu.com/s/1YSLhrPMn3RflGE8VDGGbHQ 提取码:42se )

本次范例属于“自己动手丰衣足食”,每个函数都自己实现,可以在入门阶段对K-近邻算法流程有个初步认识,在有了一定基础之后,我们就没有必要重造轮子,可以使用常见的机器学习算法,毕竟其专业性远远目前超过我们自己的程序。例如scikit-learn模块。

K-近邻算法入门

标签:sqrt   https   验证   class   items   欧几里得   lines   nes   euc   

原文地址:https://www.cnblogs.com/zjq-115/p/10145198.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!