起步
这次使用的训练集由 sklearn
模块提供,关于虹膜(一种鸢尾属植物)的数据。
数据载入
from sklearn import datasets
iris = datasets.load_iris()
数据存储在 .data
成员中,它是一个 (n_samples, n_features) numpy
数组:
print(iris.data)
# [[ 5.1 3.5 1.4 0.2]
# [ 4.9 3. 1.4 0.2]
# ...
它有四个特征,萼片长度,萼片宽度,花瓣长度,花瓣宽度 (sepal length, sepal width, petal length and petal width)。
它的品种分类有山鸢尾,变色鸢尾,菖蒲锦葵(Iris setosa, Iris versicolor, Iris virginica.)三种。
print iris.data.shape
# output:(150L, 4L)
这是一个含有 150 个数据的训练集。
构造 KNN 分类器
from sklearn import neighbors
knn = neighbors.KNeighborsClassifier(n_neighbors=5)
n_neighbors
参数级是指定获取 K 个邻近点。
训练
训练的函数一般就是 fit
:
knn.fit(iris.data, iris.target)
测试
模拟一些测试数据,使用刚刚的模型进行预测:
predict = knn.predict([[0.1, 0.2, 0.3, 0.4]])
print(predict) # output: [0]
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。