kd树就是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,可以运用在k近邻法中,实现快速k近邻搜索。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,依次选择坐标轴对空间进行切分,选择训练实例点在选定坐标轴上的中位数为切分点。具体kd树的原理可以参考kd树的原理。
代码是参考《统计学习方法》k近邻 kd树的python实现得到
首先创建一个类,用于表示树的节点,包括:该节点的值,该节点的切分轴,左子树,右子树
class decisionnode: def __init__(self,value=None,col=None,rb=None,lb=None): self.value=value self.col=col self.rb=rb self.lb=lb
切分点为坐标轴上的中值,下面代码求得一个序列的中值
def median(x): n=len(x) x=list(x) x_order=sorted(x) return x_order[n//2],x.index(x_order[n//2])
然后就可以构造一颗kd树,左子树小于切分点,右子树大于切分点
def buildtree(x,j=0): rb=[] lb=[] m,n=x.shape if m==0: return None edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j]<edge: lb.append(i) rb_x=x[rb,:] lb_x=x[lb,:] rightBranch=buildtree(rb_x,(j+1)%n) leftBranch=buildtree(lb_x,(j+1)%n) return decisionnode(x[row,:],j,rightBranch,leftBranch)
接下来是树的搜索过程,可以用下图表示树的搜索过程,具体过程可以参考kd树的原理。
代码如下:
#搜索树:nearestPoint,nearestValue均为全局变量 def traveltree(node,point): global nearestPoint,nearestValue if node==None: return print(node.value) print(‘---‘) col=node.col if point[col]>node.value[col]: traveltree(node.rb,point) if point[col]<node.value[col]: traveltree(node.lb,point) dis=dist(node.value,point) print(dis) if dis<nearestValue: nearestPoint=node nearestValue=dis #print(‘nearestPoint,nearestValue‘ % (nearestPoint,nearestValue)) if node.rb!=None or node.lb!=None: if abs(point[node.col] - node.value[node.col]) < nearestValue: if point[node.col]<node.value[node.col]: traveltree(node.rb,point) if point[node.col]>node.value[node.col]: traveltree(node.lb,point) def searchtree(tree,aim): global nearestPoint,nearestValue #nearestPoint=None nearestValue=float(‘inf‘) traveltree(tree,aim) return nearestPoint def dist(x1, x2): #欧式距离的计算 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
完整代码在此处取
1 import numpy as np 2 from numpy import array 3 class decisionnode: 4 def __init__(self,value=None,col=None,rb=None,lb=None): 5 self.value=value 6 self.col=col 7 self.rb=rb 8 self.lb=lb 9 10 #读取数据并将数据转换为矩阵形式 11 def readdata(filename): 12 data=open(filename).readlines() 13 x=[] 14 for line in data: 15 line=line.strip().split(‘\t‘) 16 x_i=[] 17 for num in line: 18 num=float(num) 19 x_i.append(num) 20 x.append(x_i) 21 x=array(x) 22 return x 23 24 #求序列的中值 25 def median(x): 26 n=len(x) 27 x=list(x) 28 x_order=sorted(x) 29 return x_order[n//2],x.index(x_order[n//2]) 30 31 #以j列的中值划分数据,左小右大,j=节点深度%列数 32 def buildtree(x,j=0): 33 rb=[] 34 lb=[] 35 m,n=x.shape 36 if m==0: return None 37 edge,row=median(x[:,j].copy()) 38 for i in range(m): 39 if x[i][j]>edge: 40 rb.append(i) 41 if x[i][j]<edge: 42 lb.append(i) 43 rb_x=x[rb,:] 44 lb_x=x[lb,:] 45 rightBranch=buildtree(rb_x,(j+1)%n) 46 leftBranch=buildtree(lb_x,(j+1)%n) 47 return decisionnode(x[row,:],j,rightBranch,leftBranch) 48 49 #搜索树:nearestPoint,nearestValue均为全局变量 50 def traveltree(node,point): 51 global nearestPoint,nearestValue 52 if node==None: return 53 print(node.value) 54 print(‘---‘) 55 col=node.col 56 if point[col]>node.value[col]: 57 traveltree(node.rb,point) 58 if point[col]<node.value[col]: 59 traveltree(node.lb,point) 60 dis=dist(node.value,point) 61 print(dis) 62 if dis<nearestValue: 63 nearestPoint=node 64 nearestValue=dis 65 #print(‘nearestPoint,nearestValue‘ % (nearestPoint,nearestValue)) 66 if node.rb!=None or node.lb!=None: 67 if abs(point[node.col] - node.value[node.col]) < nearestValue: 68 if point[node.col]<node.value[node.col]: 69 traveltree(node.rb,point) 70 if point[node.col]>node.value[node.col]: 71 traveltree(node.lb,point) 72 73 def searchtree(tree,aim): 74 global nearestPoint,nearestValue 75 #nearestPoint=None 76 nearestValue=float(‘inf‘) 77 traveltree(tree,aim) 78 return nearestPoint 79 80 81 def dist(x1, x2): #欧式距离的计算 82 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5