1. Definition
url:https://en.wikipedia.org/wiki...
In pattern recognition, the k-nearest neighbors algorithm (k-NN) is a non-parametric method used for classification and regression.[1] In both cases, the input consists of the k closest training examples in the feature space. The output depends on whether k-NN is used for classification or regression:
In k-NN classification, the output is a class membership. An object is classified by a majority vote of its neighbors, with the object being assigned to the class most common among its k nearest neighbors (k is a positive integer, typically small). If k = 1, then the object is simply assigned to the class of that single nearest neighbor.
In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
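The practical part of this post only implements classification. Below is a minimal sketch of the regression variant described in the quote, assuming Euclidean distance; the helper name `knn_regress` and the 1-D toy data are hypothetical and made up for illustration.

```python
import numpy as np

def knn_regress(x_train, y_train, x, k=3):
    """Hypothetical helper: predict a value for x as the mean target of its k nearest samples."""
    # Euclidean distance from x to every training sample (one row per sample)
    distances = np.sqrt(np.sum((x_train - x) ** 2, axis=1))
    # Indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # k-NN regression output: the average of the neighbors' values
    return float(np.mean(y_train[nearest]))

# Hypothetical 1-D toy data, for illustration only
x_train = np.array([[1.0], [2.0], [3.0], [10.0], [11.0]])
y_train = np.array([1.0, 2.0, 3.0, 10.0, 11.0])
print(knn_regress(x_train, y_train, np.array([2.1]), k=3))  # → 2.0
```

The three nearest samples to 2.1 have values 2, 3 and 1, so the prediction is their mean, 2.0.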
2. My understanding
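Put simply: compute the distance from the new point to every point in the training set, take the K points with the smallest distances, and count how often each target (class label) occurs among them. The new sample is assigned the target with the highest count among those K neighbors.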
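The same voting rule can be written as a formula (the notation is introduced here for illustration and is not from the original post). For a new point x, the predicted target is the class c that receives the most votes among its K nearest neighbors:

$$\hat{y} = \arg\max_{c} \sum_{i \in N_K(x)} \mathbf{1}[y_i = c]$$

Here N_K(x) is the set of indices of the K training samples closest to x, and the indicator 1[y_i = c] equals 1 when sample i has class c and 0 otherwise.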
3. Practice
Language: Python 3
Distance metric: Euclidean distance
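For reference, the Euclidean distance between two points a and b in an n-dimensional feature space is:

$$d(\mathbf{a}, \mathbf{b}) = \sqrt{\sum_{j=1}^{n} (a_j - b_j)^2}$$

In this example n = 2, and the distance loop below evaluates this for the new point against each training sample.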
```python
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 17 11:17:18 2018
@author: yangzinan
"""
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from collections import Counter

# Training samples: ten 2-D feature vectors
x = [
    [3.393533211, 2.331273381],
    [3.110073483, 1.781539638],
    [1.343808831, 3.368360954],
    [3.582294042, 4.679179110],
    [2.280362439, 2.866990263],
    [7.423436942, 4.696522875],
    [5.745051997, 3.533989803],
    [9.172168622, 2.511101045],
    [7.792783481, 3.424088941],
    [7.939820817, 0.791637231]
]
# Targets: the first five samples are class 0, the last five class 1
y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
x_train = np.array(x)
y_train = np.array(y)

# Plot: class 0 in red, class 1 in green, the new point in blue
plt.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], color="red")
plt.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], color="green")
x_point = np.array([8.093607318, 3.365731514])
plt.scatter(x_point[0], x_point[1], color="blue")
plt.show()

# Compute the Euclidean distance from the new point to every training sample
distances = []
for d in x_train:
    # Distance between training sample d and the new point x_point
    d_sum = sqrt(np.sum((d - x_point) ** 2))
    distances.append(d_sum)
print(distances)

# Find the nearest points:
# argsort returns the indices that sort the distances in ascending order
nearest = np.argsort(distances)
# Number of neighbors to consider
k = 3
topK_y = []
# Collect the targets of the K nearest samples
for i in nearest[:k]:
    topK_y.append(y_train[i])
# The most frequent target among the K neighbors is the prediction for the new point
predict_y = Counter(topK_y).most_common(1)[0][0]
print(predict_y)
```
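The loop-based steps above can be condensed into a reusable function. This is a minimal sketch, not the author's code: NumPy broadcasting with `np.linalg.norm(..., axis=1)` replaces the explicit distance loop, and the function name `knn_classify` is hypothetical. It reuses the same ten samples and query point.

```python
import numpy as np
from collections import Counter

def knn_classify(x_train, y_train, x, k=3):
    """Hypothetical helper: classify x by majority vote among its k nearest samples."""
    # Broadcasting subtracts x from every row; the norm along axis=1 gives one distance per sample
    distances = np.linalg.norm(x_train - x, axis=1)
    # Indices of the k nearest samples, closest first
    nearest = np.argsort(distances)[:k]
    # Majority vote over the neighbors' labels
    votes = Counter(y_train[nearest].tolist())
    return votes.most_common(1)[0][0]

x_train = np.array([
    [3.393533211, 2.331273381],
    [3.110073483, 1.781539638],
    [1.343808831, 3.368360954],
    [3.582294042, 4.679179110],
    [2.280362439, 2.866990263],
    [7.423436942, 4.696522875],
    [5.745051997, 3.533989803],
    [9.172168622, 2.511101045],
    [7.792783481, 3.424088941],
    [7.939820817, 0.791637231],
])
y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
x_point = np.array([8.093607318, 3.365731514])
print(knn_classify(x_train, y_train, x_point, k=3))  # → 1
```

With k = 3 the three nearest samples all belong to class 1, so the new point is classified as 1. On ties: `Counter.most_common` orders equal counts by first occurrence, and since the neighbors are counted closest first, a tie goes to the label whose nearest member is closest. Choosing an odd k avoids ties in the two-class case.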