k近邻法的思想
你打算预测我会在大选中投票给谁。假设你对我一无所知,一个明智的方法是看看我的邻居们都投票给谁。当然,你可能还知道我的年龄、收入、有几个孩子,等等,根据我的行为受这些维度影响的程度,你可以观察在这些维度上最接近我的邻居们而不是我所有的邻居会得到更好的预测结果。这就是最近邻分类(nearest neighbors classification)方法背后的思想。
k近邻法的优缺点
k近邻法是最简单的预测模型之一,它没有多少数学上的假设,也不要求任何复杂的数学处理,它所要求的仅仅是:
某种距离的概念
彼此接近的点具有相似性质的假设(否则用近邻来预测结果就是不合理的)
k近邻法有意忽略了大量信息,对每个新的数据点的预测只依赖少量最接近它的点。
此外,它不能解释为什么。例如,基于我邻居的投票行为来预测我的投票并不能告诉你我为什么要这样投票,而基于我的收入、婚姻等因素来预测我的投票行为的模型则能揭示我投票的原因。
案例:最喜欢的编程语言
假设我们有一份数据,这份数据是各个城市经纬度及该城市最受欢迎的编程语言的集合。数据以列表存储。
cities = [(-86.75,33.5666666666667,'Python'),(-88.25,30.6833333333333,'Python'),(-112.016666666667,33.4333333333333,'Java')......]
数据可视化
我们将每种语言及其经度(x)、纬度(y)按如下的格式存储到字典中:键是语言,值是成对的数据。
plots={"Java":([],[]),"python":([],[]),"R":([],[])}
每种语言用不同的符号和颜色标记:
markers = { "Java" : "o", "Python" : "s", "R" : "^" }
colors = { "Java" : "r", "Python" : "b", "R" : "g" }
将cities列表中的数据存放到plots字典中:
for (longitude, latitude), language in cities:
plots[language][0].append(longitude)
plots[language][1].append(latitude)
我们可以用items方法来返回字典元素的列表:
In [1]: plots.items()
Out[1]:
[('Python',([-86.75, -88.25, -118.15......],[33.5666666666667, 30.6833333333333, 33.8166666666667.....])),('R'......),('Java'.....)]
这样我们就可以用for语句来循环遍历字典元素,为每种语言创建一个散点序列:
import matplotlib.pyplot as plt
for language, (x, y) in plots.items():
plt.scatter(x, y, color=colors[language], marker=markers[language],
label=language, zorder=10)
plt.legend(loc=0) #让matplotlib选择一个位置
plt.axis([-130,-60,20,55]) #设置坐标轴
plt.title("most popular language") #设置图标标题
put.show()
最终效果如下:
k近邻法的python实现
下面的点是cities列表中的第一个点,这个城市最受欢迎的编程语言是python:
In [2]: cities[0]
Out[2]: ([-86.75, 33.5666666666667], 'Python')
假设我们不知道这个城市最受欢迎的语言是什么。根据k近邻法的思想,为了预测结果,(1)我们首先需要知道这个点即这个城市与其他所有点的距离,(2)然后找到离这个点最近的某个点或几个点最受欢迎的编程语言是什么,以此作为预测结果,(3)如果是几个点,我们需要计算哪种语言出现的次数最多,以此作为预测结果。
思路清楚了,让我们来一步步实现吧。
先将除其他城市存放在other_cities列表中(这里用列表解析式遍历所有城市,找到与该城市不同的所有城市):
other_cities=[other_city for other_city in cities if other_city != (cities[0][0],cities[0][1])]
按其他城市与待预测城市之间的距离从近到远对other_cities列表进行排序:
from linear_algebra import distance
by_distance=sorted(other_cities, key= lambda point_label:
distance(point_label[0], cities[0][0]))
找到最近的一个城市:
In [3]: k_nearest_labels=[label for _, label in by_distance[:1]]
In [4]: k_nearest_labels
Out[4]: ['Python']
当然,我们也可以找到最近的3个城市、5个城市或7个城市。
In [5]: k_nearest_labels=[label for _, label in by_distance[:3]]
In [6]: k_nearest_labels
Out[6]: ['Python', 'R', 'Python']
In [7]: k_nearest_labels=[label for _, label in by_distance[:5]]
In [8]: k_nearest_labels
Out[8]: ['Python', 'R', 'Python', 'Java', 'R']
In [9]: k_nearest_labels=[label for _, label in by_distance[:7]]
In [10]: k_nearest_labels
Out[10]: ['Python', 'R', 'Python', 'Java', 'R', 'Python', 'Java']
这时,我们需要一个计数器找到出现次数最多的语言:
def majority_vote(labels):
"""假设labels已经从最近到最远排序"""
vote_counts = Counter(labels) #Counter返回的是字典,以label为键,出现次数为值
winner, winner_count = vote_counts.most_common(1)[0] #most_common方法可以找出vote_counts中出现次数前1(前几由括号内参数指定)的键和值,以元组组织
num_winners = len([count
for count in vote_counts.values()
if count == winner_count]) #计算vote_counts中前1的出现次数出现了几次,有几个胜出者
if num_winners == 1:
return winner # 如果只有一个胜出者,直接返回
else:
return majority_vote(labels[:-1]) # 如果有几个胜出者,排除lavels中最远的点,再试一次
在计算距离时,我们仅仅计算了一个点与其他点的距离。由于所有点的逻辑都是一样的,这种情况下,我们可以构造函数来执行同类操作。
def knn_classify(k, labeled_points, new_point):
"""k决定取最近的几个点;labeled_points指带标签的点,即(point, label)的数据对,是除待预测的点之外所有的点;new_ponit为待预测的点"""
# 将带标签的点,即其他点从近到远排序
by_distance = sorted(labeled_points,
key=lambda point_label: distance(point_label[0], new_point))
# 找到最近的k个点
k_nearest_labels = [label for _, label in by_distance[:k]]
# 对每个点进行计数
return majority_vote(k_nearest_labels)
现在,让我们看看尝试利用近邻城市来预测每个城市的偏爱语言会得到什么结果:
for k in [1, 3, 5, 7]:
num_correct = 0
for location, actual_language in cities:
other_cities = [other_city
for other_city in cities
if other_city != (location, actual_language)]
predicted_language = knn_classify(k, other_cities, location)
if predicted_language == actual_language:
num_correct += 1
print(k, "neighbor[s]:", num_correct, "correct out of", len(cities))
可以看到,k=3时的预测准确率最高,59%的时间能给出正确答案:
(1, 'neighbor[s]:', 40, 'correct out of', 75)
(3, 'neighbor[s]:', 44, 'correct out of', 75)
(5, 'neighbor[s]:', 41, 'correct out of', 75)
(7, 'neighbor[s]:', 35, 'correct out of', 75)
参考资料:
Joel Grus《数据科学入门》第12章
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。