感知机介绍
感知机(perceptron)是二分类的线性模型,其输入为实例的特征向量, 输出为实例的类别,取+1或-1两个值。感知机对于输入空间(特征空间)中将实例划分为正负两类的分类超平面,属于判别模型。感知机学习旨在求出将训练数据进行线性划分的超平面。
感知机具有简单易于实现的优点,它是神经网络和支持向量机的基础。
假设输入空间(特征空间)是$x\subseteq \mathcal{R}_n$,输出空间是$\gamma=\{+1,-1\}$。输入$x\in \chi$表示实例的特征向量,对应于输入空间(特征空间)点;输出$y\in \gamma$表示实例的类别。由输入空间到输出空间的如下函数
$$f(x) = sign(w \cdot x + b)\tag{1}$$
称为感知机。其中,w和b为感知机模型参数,$w\in R^n$叫作权值或权值向量,$b\in R$叫作偏置,w·x表示w和x的内积。sign是符号函数,即
$$\begin{align}sign(x)=\begin{cases}+1,&x \geq 0 \\\ -1, &x < 0\end{cases}\end{align}$$
感知机是一种线性分类模型,属于判别模型。感知机模型的假设空间是定义在特征空间中的所有线性分类模型或线性分类器,即函数集合 $$\{f|f(x)=w·x+b\}$$
感知机的学习策略
给定一个数据集$T=\{(x_1,y_1),(x_2, y_2),...,(x_n,y_n)\}$,其中$x_i\in \chi =R^n, y_i\in \gamma=\{+1, -1\},i=1, 2, ...,N$,如果存在某个超平面S
$$w\cdot x+b=0\tag{3}$$
能将数据集的真实例点和负实例点完全正确的分到超平面的两侧,即对所有$y_i=+1$的实例i,有$w\cdot x_i+b>0$,对于所有$y_i=-1$的实例i,有$w\cdot x_i+b<0$,则称数据集T为线性可分数据集(Linear separable dataset);否则,称数据集T不可分。
假设数据集T是线性可分的,感知机学习的目的是,学习一个能完全正确分离正负样本的超平面S,如下图所示。确定该超平面,即确定模型参数w和b。需要确定模型的损失函数,并将其最小化。
损失函数的一个最直接的选择是误分类样本的数目(类似于L0正则化),但这样的损失函数并不容易求解。另一个选择是将误分类点到超平面的总距离作为损失函数,可以得到输入空间$R^n$任意一个点$x_0$到超平面的距离。
$$\frac{1}{||w||}|w\cdot x_0 +b|\tag{4}$$
$||w||$为权重w的L2范数。
对于误分类的实例而言,有下式成立:
$$-y\cdot (w\cdot x + b)>0\tag{5}$$
当$w\cdot x+b>0$时,$y_i$=-1;而当$w\cdot x+b<0$时,$y_i$=+1。因此,误分类点$x_i$到超平面S的距离为
$$-\frac{1}{||w||}y_i\cdot (w\cdot x_i +b)\tag{6}$$
这样可以得到所有误分类实例到超平面的总距离
$$-\frac{1}{||w||}\sum_{x_i\in M}y_i\cdot (w\cdot x_i +b)\tag{7}$$
M为误分类实例的集合,不考虑$\frac{1}{||w||}$,就得到感知机$sign(w\cdot x+b)$的损失函数。
$$L(w, b)=-\sum_{x_i\in M}y_i\cdot (w\cdot x_i +b)\tag{8}$$
感知机学习的策略是在假设空间中选择使损失函数(8)最小的模型参数w、b,即感知机模型。
感知机算法步骤
感知机算法是由误分类驱动的,具体采用随机梯度下降法。即首先随机选择一个超平面(w,b),然后使用梯度下降最小化目标函数
$$ min_{w, b}L(w, b)=-\sum_{x_i\in M}y_i\cdot (w\cdot x_i +b) \tag{9}$$
极小化过程不是一次使用全部误分类点的梯度下降,而是一次随机选取一个误分类点使其梯度下降。
假设误分类点集合M是固定的,那么损失函数L(w,b)的梯度由
$$\begin{equation}\tag{10}\begin{cases}\bigtriangledown_wL(w, b)=-\sum_{x_i\in M}y_ix_i \\ \bigtriangledown_bL(w, b)=-\sum_{x_i\in M}y_i\end{cases}\end{equation}$$
给出。
随机选取一个误分类点$(x_i, y_i)$,对参数w、b进行更新:
$$w\leftarrow w+\eta y_ix_i \tag{11}$$
$$b\leftarrow b+\eta y_i \tag{12}$$
$\eta$表示步长,也叫学习率。
对于数据集$T=\{(x_1,y_1),(x_2, y_2),...,(x_n,y_n)\}$,其中$x_i\in x=R^n, y_i\in y={+1, -1},i=1, 2, ...,N,0<\eta<1$。
由此,我们可以写出感知机算法的过程
- 初始化模型参数w、b。
- 计算误分类点的集合。
- 利用误分类点数据,使用式(11)和式(12)对参数w、b进行更新。
- 转至步骤2,直至误分类点的集合为空。
感知机的代码实现
加载数据并预处理。这里使用sklearn自带的鸢尾花数据集,我们只是用数据集中花萼长度和花萼宽度两个特征来构建感知机模型。
# 加载数据
data = load_iris()
x, y = data['data'], data['target']
# 只选择标签为0和1的样本
mask = y < 2
x, y = x[mask], y[mask]
idx0 = (y == 0)
idx1 = (y == 1)
# 只选择花萼长度和花萼宽度两个特征
x = x[:, :2]
# 将标签0和1替换为-1和1
y = np.where(y == 1, 1, -1)
构建感知机模型
class Perceptron:
def __init__(self, x, y, lr=0.1):
self.x = x
self.y = y
self.w = np.ones(x.shape[1], dtype=np.float32)
self.b = 0
self.lr=lr
# 一次训练整个误分类样本集
def train1(self):
while True:
y_pred = self.y * (self.x @ self.w)
mask = (y_pred <= 0).reshape((-1, ))
error_num = np.count_nonzero(mask)
if error_num == 0:
break
else:
self._update(self.x[mask], self.y[mask])
# 单个样本训练
def train(self):
while True:
error_num = 0
for i in range(self.x.shape[0]):
X = self.x[i]
Y = self.y[i]
if self.y[i] * (X @ self.w + self.b) <= 0:
self.w = self.w + self.lr * Y * X
self.b = self.b + self.lr * Y
error_num += 1
if error_num == 0:
break
# 参数更新
def _update(self, x, y):
self.w += self.lr * x.T @ y / len(y)
可视化感知机分类结果
net = Perceptron(x,y)
net.train()
fig, ax = plt.subplots()
ax.scatter(x[idx0, 0], x[idx0, 1], label='0')
ax.scatter(x[idx1, 0], x[idx1, 1], label='1')
ax.plot(x[:,0], -(x[:, 0] * net.w[0] + net.b) / net.w[1], c='red', label='fit')
ax.legend()
plt.show()
从图中可以看出,根据误分类集拟合得到的超平面并非最优超平面,这是因为恰好完全正确分类作为一个临界条件,所以模型只是刚好能满足这个条件而已。另外需要注意的是,在大部分情况下,梯度下降都只是获得一个去不可行解,而非最优解。
总结
1、感知机是根据输入实例的特征向量x对其进行二分类的线性模型,$f(x)=sign(w\cdot x+b)$,感知机对应于输入空间中的超平面$w\cdot x+b=0$。
2、感知机的策略是极小化损失函数:
$$min_{w, b}L(w, b)=-\sum_{x_i\in M}y_i\cdot (w\cdot x_i +b) $$
损失函数对应于误分类点到超平面的距离。
3、感知机学习算法是基于随机梯度下降的对损失函数的最优化算法,有原始形式和对偶形式(这里只介绍了原始形式)。算法简单且易于实现。原始形式中,首先任意选取一个超平面,然后用梯度下降法不断极小化目标函数。在这个过程中一次随机选取一个误分类点使其梯度下降。
4、当训练数据集线性可分时,感知机学习算法是收敛的。* 当训练数据集线性可分时,感知机学习算法存在无穷多个解,其解由于不同的初值或不同的迭代顺序而可能有所不同。
Reference
《李航统计学习方法》
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。