码迷,mamicode.com
首页 > 其他好文 > 详细

KNN 分类程序

时间:2019-07-24 22:37:11      阅读:104      评论:0      收藏:0      [点我收藏+]

标签:打开   sort   pyplot   count   百分比   with   group   map排序   测试数据   

# coding: utf-8
import numpy as np
import operator
import matplotlib
from numpy import *
import matplotlib.pyplot as plt
import os


def CreateDataSet():
  group = np.array([
    [1.0, 1.1],
    [1.0, 1.0],
    [0.0, 0.0],
    [0.0, 0.1]])
  label = [a, a, b, b]
  return group, label


def Classify(intx, datax, label, k):
  datasize = datax.shape[0]
  diffmat = np.tile(intx, (datasize, 1)) - datax #每一位相减
  sqdiffmat = diffmat ** 2 #每一位平方
  sqdistence = sqdiffmat.sum(axis=1) #axis=1按照行求和 axix=0按照列求和
  distence = sqdistence ** 0.5
  sorteddistenceindicies = distence.argsort()
  classcount = {}
  for i in range(k):
    voteilabel = label[sorteddistenceindicies[i]]
    classcount[voteilabel] = classcount.get(voteilabel, 0) + 1 #map标记
  sortedclasscount = sorted(classcount.items(), key=operator.itemgetter(1), reverse = True) #map排序
  return sortedclasscount[0][0]
def file2matrix(filename) :
  with open(filename, mode = "r") as fr : #表示打开文件,使用这一句会系统自动调用 fr.close关闭文件,无论文件是否打开都会调用
    arrayolines = fr.readlines() #https://blog.csdn.net/liuyhoo/article/details/80756812
    numberoflines = len(arrayolines)
    returnmat = np.zeros((numberoflines, 3)) #生成一个 num * 3 d的全0矩阵
    labels = []
    index = 0
    for line in arrayolines :
      listfromline = line.split("\t") #数据中间是\t 结尾是\n
      returnmat[index, :] = listfromline[0: 3]
      labels.append(int(listfromline[-1])) # 处理结尾\n
      index = index + 1
    return returnmat, labels

def autonorm(datax) :
  minval = datax.min(0) #min() 表示矩阵中最小是 min(0)表示每列中最小值 min(1)表示每行中最小值
  maxval = datax.max(0)
  ranges = maxval - minval
  rows = datax.shape[0] #查看矩阵的维数
  newval = datax - tile(minval, (rows, 1)) #minval是三维,后面的是生成的矩阵为 rows * 1 倍
  newval = newval / tile(ranges, (rows, 1)) # 矩阵除法相当于c中每一位直接整除
  return newval, ranges, minval

def datingClassTest():
  hoRatio = 0.1  # 设置测试集百分比
  filename = "datingTestSet2.txt"
  dataX, labels = file2matrix(filename) #读数据
  normMat, ranges, minVals = autonorm(dataX)  # 归一化
  m = dataX.shape[0]  #numbers of rows
  numTestVecs = int(m * hoRatio)
  errorcount = 0  # 错误数
  for i in range(numTestVecs):
    classifierResult = Classify(normMat[i, :], normMat[numTestVecs:m, :], labels[numTestVecs:m], 5)  # 前10%作为测试数据
    #   print("the classifier predict %d, the real answer is :%d" %((classifierResult),labels[i]))
    if (classifierResult != labels[i]):
      errorcount = errorcount + 1.0
  print("error rate :%f" % ((errorcount) / (numTestVecs)))

def plot():  # 画datingTestSet2.txt这个数据的图像
  k = 3
  filename = "datingTestSet2.txt"
  dataX, labels = file2matrix(filename)
  fig = plt.figure() #创建一个图
  ax = fig.add_subplot(111)
  ax.scatter(dataX[:, 0], dataX[:, 1], c=15 * np.array(labels), s=15 * np.array(labels))
  ax = fig.add_subplot(121)
  ax.scatter(dataX[:, 0], dataX[:, 2], c=15 * np.array(labels), s=15 * np.array(labels))
  ax = fig.add_subplot(131)
  ax.scatter(dataX[:, 1], dataX[:, 2], c=15 * np.array(labels), s=15 * np.array(labels))
  plt.show()

if __name__ == __main__:
    # plot()
    datingClassTest()

 搬运门

KNN 分类程序

标签:打开   sort   pyplot   count   百分比   with   group   map排序   测试数据   

原文地址:https://www.cnblogs.com/lalalatianlalu/p/11241152.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!