码迷,mamicode.com
首页 > 其他好文 > 详细

kd树的创建和求最近邻

时间:2019-10-22 23:45:04      阅读:110      评论:0      收藏:0      [点我收藏+]

标签:nump   orm   距离   turn   __init__   style   shape   深度   取绝对值   

 1 import numpy as np
 2 arr = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
 3 arr.shape
 4 
 5 class KDTree():
 6     def __init__(self):
 7         self.value = None
 8         self.left = None
 9         self.right = None
10         self.axis = None
11 
12 def create(arr, k, h=0):
13     if arr.shape[0] == 0:
14         return None
15     tree = KDTree()
16     axis = h % k
17     
18     if arr.shape[0] == 1:
19         tree.value = arr[0]
20         tree.left = None
21         tree.right = None
22         tree.axis = axis
23     else:
24         arr = sorted(arr, key = lambda x:x[axis])
25         arr = np.array(arr)
26         i = arr.shape[0]//2
27         
28         tree.value =  arr[i]
29         tree.left = create(arr[0:i], k, h+1)
30         tree.right = create(arr[i+1:], k, h+1)
31         tree.axis = axis
32     return tree
33 
34 k = KDTree()
35 
36 k = create(arr, arr.shape[1])
37 
38 def preOrder(k):
39     print(当前节点: + str(k.value))
40     
41     if k.left:
42         preOrder(k.left)
43     if k.right:
44         preOrder(k.right)
45 
46 preOrder(k)
47 
48 def dis(a, b):
49     return np.linalg.norm(a-b)
50 def search(kd, goal, k, h=0):
51     ‘‘‘输入:kd树,目标点、特征维度k以及当前深度h‘‘‘
52     ‘‘‘输出:在kd树上的与目标点距离(欧氏距离)最近的距离‘‘‘
53     if kd.left == None and kd.right == None:
54         return dis(goal, kd.value)
55     if kd.left == None:
56         return min(search(kd.right, goal, k, h+1), dis(kd.value, goal))
57     if kd.right == None:
58         return min(search(kd.left, goal, k, h+1), dis(kd.value, goal))
59     axis = h%k
60     
61     if goal[axis] < kd.value[axis]:
62         cur_dis = search(kd.left, goal, k, h+1)
63     else:
64         cur_dis = search(kd.right, goal, k, h+1)
65     
66     
67     if cur_dis < kd.value[axis]-goal[axis]:////cut  取绝对值
68         return cur_dis;
69     else:
70         if goal[axis] < kd.value[axis]:
71             cur_dis = min(search(kd.right, goal, k, h+1), cur_dis, dis(kd.value, goal))
72         else:
73             cur_dis = min(search(kd.left, goal, k, h+1), cur_dis, dis(kd.value, goal))
74     return cur_dis
75 
76 search(k, np.array([9, 6]), 2)

 

kd树的创建和求最近邻

标签:nump   orm   距离   turn   __init__   style   shape   深度   取绝对值   

原文地址:https://www.cnblogs.com/liuwenhan/p/11723354.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!