码迷,mamicode.com
首页 > 编程语言 > 详细

KNN算法

时间:2016-03-01 00:44:25      阅读:323      评论:0      收藏:0      [点我收藏+]

标签:

KNN算法的介绍请参考:

http://blog.csdn.net/zouxy09/article/details/16955347

统计学习方法里面给出了KD Tree的算法介绍,按照书上的进行了实现:

技术分享
# -*- coding: utf-8 -*-

from operator import itemgetter
from copy import deepcopy
import heapq


class Node(object):
    def __init__(self, dim, label=None, parent = None,
                 split = 0):
        """
        kd树的节点
        :param dim: 节点的向量
        :param label: 节点的label
        :param parent: 父节点
        :param split: 第 split 进行切分
        :return:
        """
        self.dim = deepcopy(dim)
        self.label = deepcopy(label)
        self.left_node = None
        self.right_node = None
        self.parent = parent
        self.split = split


class KdTree(object):

    def __init__(self):
        """
        j 主要是一个递增值,用来计算当前使用哪个维度进行切面
        :return:
        """
        self.j = 0

    def __get_kd_tree(self, samples, k, parent_node):
        """
        生成kd树,主要采用递归
        :param samples: 样本[[(1,2,3),‘A‘],[(2,3,4),‘b‘]]
        :param k: 样本的维度
        :param parent_node: 父节点
        :return:
        """

        if samples is None or len(samples) == 0:
            return None

        #计算切面
        l = self.j % k
        self.j = self.j + 1

        #对样本进行排序,并取中位数
        samples.sort(key=lambda s:s[0][l])
        len_sam = len(samples)
        mid_index = len_sam / 2
        mid_value = samples[mid_index][0][l]
        i = 0
        while i < len_sam and samples[i][0][l] < mid_value:
            i += 1

        #将中位数对应的样本设置为当前节点
        root_node = Node(samples[i][0], samples[i][1])
        root_node.parent = parent_node
        root_node.split = l
        if 0 == i:
            left_samples = []
        else:
            left_samples = samples[0:i]
        if i >= len_sam - 1:
            right_samples = []
        else:
            right_samples = samples[i+1:]

        #del(samples[mid_index])
        root_node.left_node = self.__get_kd_tree(left_samples, k, root_node)
        root_node.right_node = self.__get_kd_tree(right_samples, k, root_node)

        return root_node

    def get_kd_tree(self, samples):
        """
        :param samples: [[(1,2,3),‘A‘],[(2,3,4),‘b‘]]
        :return:
        """
        return  self.__get_kd_tree(samples, len(samples[0][0]), None)

    def cal_dist(self, target, sample):
        """
        欧拉距离
        :param target: 目标样本
        :param sample: 需要计算距离的样本
        :return:
        """

        dis = 0.

        for i in range(0, len(target.dim)):
           dis += (target.dim[i] - sample.dim[i]) ** 2
        dis = dis ** 0.5

        return dis


    def __insert_heap(self, k, dis, node, heap):
        """
        python 的 heap 是小顶堆, 将数值设置为负数,就变成了大顶堆
        [(-dis,node)]
        :param k:
        :param dis:
        :param node:
        :param heap:
        :return:
        """
        if len(heap) < k:
            heapq.heappush(heap, (-dis, node))
        else:
            d  = - heap[0][0]
            if dis < d:
                heapq.heapreplace(heap, (-dis, node))

    def get_k_neighbors(self, target, root, k, heap):
        """

        :param root: kd树的根节点
        :param k: 邻居个数
        :return:
        """
        if root is None:
            return

        s = root.split
        if target.dim[s] < root.dim[s]:
            self.get_k_neighbors(target, root.left_node, k, heap)
        else:
             self.get_k_neighbors(target, root.right_node, k, heap)

        dis = self.cal_dist(target, root)
        self.__insert_heap(k, dis, root, heap)

        if root.parent is not None:
            father_s = root.parent.split
            check_node = None
            if root.dim[father_s] < root.parent.dim[father_s]:
                check_node = root.parent.right_node
            else:
                check_node = root.parent.left_node

            smallest = heapq.nlargest(1, heap)
            if check_node is not None and self.cal_dist(target, check_node) < -smallest[0][0]:
                self.get_k_neighbors(target, check_node, k, heap)
            else:
                return

    def get_label_of_sample(self, heap):
        lable_dict = {}
        for i in range(0, len(heap)):
            node_label = heap[i][1].label
            if lable_dict.has_key(node_label):
                lable_dict[node_label] = lable_dict[node_label] + 1
            else:
                lable_dict[node_label] = 1

        max = 0
        max_label = ‘‘
        for key in lable_dict.keys():
            if lable_dict[key] > max:
                max = lable_dict[key]
                max_label = key
        return max_label


if __name__ == __main__:
    samples = [[(2,3),"A"],[(5,4),"B"],[(9,6),"C"],[(4,7),"D"],[(8,1),"E"],[(7,2),"F"]]
    kd = KdTree()
    kd_root = kd.get_kd_tree(samples)
    print kd_root.dim
    heap = []
    target_node = Node((2.1, 3.1), "P")
    kd.get_k_neighbors(target_node, kd_root, 2, heap)
    print heap[0][1].dim
    print heap[1][1].dim
    print kd.get_label_of_sample(heap)
    print samples
kd tree

实现了后,利用上面博客给的手写数据集进行了下测试

技术分享
# -*- coding: utf-8 -*-

import os

import numpy as np

import kd_tree

class KnnDigits(object):

    def __init__(self):
        pass

    def img2array(self, filename):
        """

        :return:
        """
        rows = 32
        cols = 32

        img_array = np.zeros(rows * cols)

        with open(filename) as read_fp:
            for row in xrange(0, rows):
                line_str = read_fp.readline()
                for col in xrange(0, cols):
                    img_array[row * rows + col] = int(line_str[col])
                    #img_array[row] += int(line_str[col])
        return img_array

    def load_data(self, data_dir):
        """

        :param data_dir:
        :return:
        """
        samples = []
        files_list = os.listdir(data_dir)
        num_samples = len(files_list)
        for i in xrange(0, num_samples):
            file_name = os.path.join(data_dir, files_list[i])
            img_array = self.img2array(file_name)
            img_label = int(files_list[i].split(_)[0])
            samples.append([img_array, img_label])

        return samples

    def run(self, train_dir, test_dir):
        """

        :param train_dir:
        :param test_dir:
        :return:
        """
        train_samples = self.load_data(train_dir)
        test_samples = self.load_data(test_dir)
        kd = kd_tree.KdTree()
        kd_root = kd.get_kd_tree(train_samples)
        nums_test_samples = len(test_samples)
        match_count = 0
        for i in range(0, nums_test_samples):
            heap = []
            target_node = kd_tree.Node(test_samples[i][0], test_samples[i][1])
            kd.get_k_neighbors(target_node, kd_root, 3, heap)
            pridict_label = kd.get_label_of_sample(heap)
            if pridict_label == test_samples[i][1]:
                match_count += 1
            #print "pridict label is %s and test lable is %s" %(pridict_label, test_samples[i][1])
        accur = float(match_count) / nums_test_samples

        return accur

if __name__ == __main__:
    train_dir = "/Users/baidu/PycharmProjects/statistics_learning_method/digits/trainingDigits"
    test_dir = "/Users/baidu/PycharmProjects/statistics_learning_method/digits/testDigits"

    knn = KnnDigits()
    print knn.run(train_dir, test_dir)
View Code

我的娘亲哟,结果只有0.879492600423

这说明kd 树实现的不好,并且生成的树不平衡,并且很可能有BUG。

算法改进:

http://my.oschina.net/keyven/blog/221792

 

KNN算法

标签:

原文地址:http://www.cnblogs.com/SpeakSoftlyLove/p/5229346.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!