多元线性回归
多特征
多个特征变量也称为多元线性回归(multivariate linear regression)
。先解释一些符号含义:
- $ x^{(i)} $ 表示训练集中的第
i
组用例 - $ x^{(i)}_j $ 表示第i组用例中的第
j
个特征变量 - m表示训练用例的总数
- n表示每组用例的特征数
多个特征变量有如下假设函数:
$$ h_θ(x) = θ_0 + θ_1x_1+ θ_2x_2+ θ_3x_3 + ... + θ_nx_n $$
比如,我们可以认为$ θ_0 $是房子的基础价格,$ θ_1 $是每平米价格,$ x_1 $是面积;$ θ_2 $是每层价格,$ x_2 $是层数;等等。
使用矩阵乘法来表示上面这个函数:
$$ h_θ(x)=\begin{bmatrix} θ_0 & θ_1 & \cdots & θ_n \\ \end{bmatrix} \begin{bmatrix}x_0 \newline x_1 \newline \vdots \newline x_n\end{bmatrix} = θ^T x $$
上一篇,曾给出过关于两个特征$ (θ_0,θ_1) $时候的算法推导式
$$ θ_0 := θ_0 - α{1 \over m} {\sum_{i=1}^m(h_θ(x_i)-y_i)} $$
$$ θ_1 := θ_1 - α{1 \over m} {\sum_{i=1}^m((h_θ(x_i)-y_i)x_i)} $$
应用到多个特征时,不难想到推导式应该如下:
$$ \begin{align*} & \text{repeat until convergence:} \; \lbrace \newline \; & \theta_0 := \theta_0 - \alpha \frac{1}{m} \sum\limits_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) \cdot x_0^{(i)}\newline \; & \theta_1 := \theta_1 - \alpha \frac{1}{m} \sum\limits_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) \cdot x_1^{(i)} \newline \; & \theta_2 := \theta_2 - \alpha \frac{1}{m} \sum\limits_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) \cdot x_2^{(i)} \newline & \cdots \newline \rbrace \end{align*} $$
即:
$$ \begin{align*}& \text{repeat until convergence:} \; \lbrace \newline \; & \theta_j := \theta_j - \alpha \frac{1}{m} \sum\limits_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) \cdot x_j^{(i)} \; & \text{for j := 0...n}\newline \rbrace\end{align*} $$
多元线性回归实践
通过将训练样本中不同的特征的取值范围限制在大致相同的范围内,可以加快梯度下降的收敛速度。这是因为,当输入的范围比较小的时候,$ θ $的递减速度比较快,而输入的范围比较大的时候,递减速度会变慢。理想情况下,可以通过对输入变量进行处理,将其限制在一个范围内,这个范围可能是−1 ≤ x ≤ 1
或−0.5 ≤ x ≤ 0.5
。并不需要严格在这个范围内,因为我们的目的仅仅是让算法执行速度更快一些。
有两种技术做到这点,分别是特征缩放(feature scaling)
和均值归一化(mean normalization)
。特征缩放是将输入值除以输入变量的范围(例如最大值减最小值),得到一个新的仅为1的范围。均值归一化是将输入值减去平均值后,再除以输入变量的范围,得到一个0左右的范围:
$$ x_i := \dfrac{x_i - \mu_i}{s_i} $$
$ u_i $表示特征的平均值,$ s_i $可以是(max-min)表示的范围,或者是标准差。例如,如果$ x_i $表示房价,范围是100到2000,平均为1000,那么:
$$ x_i := \dfrac{price-1000}{2000-100} $$
我们需要保证我们的梯度下降算法工作正常。绘制一张图,x轴是迭代次数,y轴是$ J(θ) $,随着迭代次数的增加,$ J(θ) $应该呈下降趋势,否则,应当减小α
。另外,为了测试收敛,可以设定一个E值,当前后两次$ J(θ) $的下降小于这个值时,认为收敛,比如可以让E取值为$ 10^{-3} $,在实践中这个阈值不容易确定。
我们可以通过多种方式改善我们的假设函数。我们可以将多个特征合并成一个,比如将$ x_1 $和$ x_2 $合并成$ x_3 $,使$ x_3 = x_1 * x_2 $。
有时,我们无法使用一条直线来定义假设函数,因为那样可能并不合适。这个时候,可以把假设函数设计成二次函数、三次函数、平方根函数等。例如,可以把$ h_θ(x) = θ_0 + θ_1 x_1 $设计成二次函数$ h_θ(x) = θ_0 + θ_1 x_1 + θ_2 x_1^2 $或三次函数$ h_θ(x) = θ_0 + θ_1 x_1 + θ_2 x_1^2 + θ_3 x_1^3 $,这样的拟合效果可能更好。对于上面的三次函数,我们可以引入两个新的特征$ x_2 $和$ x_3 $,使得$ x_2 = x_1^2 $,$ x_3 = x_1^3 $。还可以考虑平方根函数$ h_θ(x) = θ_0 + θ_1 x_1 + θ_2 \sqrt{x_1} $。需要记得,这个时候特征缩放
就显得格外重要了。
正规方程
梯度下降给出了最小化代价函数的算法,本节我们要讨论的另一种方式,是一种不基于迭代的算法,而是通过一个直接的计算公式,称为正规方程(Normal Equation)
:
$$ θ = (X^T X)^{-1}X^T y $$
下图是一个例子,以及上述公式中X
和y
的定义:
X
是指一个m x n+1
的矩阵,其中m是指样本数量,n是特征个数,之所以是n+1
,是因为第一列用全1填充。y
是一个m x 1
的向量,表示样本的结果。可以从数学上证明正规方程得到的θ
能使代价函数最小化。对于使用正规方程计算时,我们无需考虑
上面提到的特征缩放和归一化。
对比一下梯度下降和正规方程解法的优劣:
梯度下降 | 正规方程 |
---|---|
需要调整α | 无需调整α |
需要多次迭代 | 无需迭代 |
复杂度$ O(kn^2) $ | 计算$ X^TX^{-1} $时的复杂度为$ O(n^3) $ |
适用于n很大的时候 | n很大时很慢(>10000) |
正规方程还有一个问题是,$ X^TX $未必是可逆的(不可逆的矩阵也称为奇异阵)。在使用octave
计算正规方程是,通常使用pinv
函数而不是inv
函数,前者在$ X^TX $不可逆的情况下也能工作。
导致$ X^TX $不可逆的因素可能包括:
- 特性冗余,即两个特性之间联系比较紧密,比如存在线性依赖关系
- 特性比样本多,即(m ≤ n)
解决办法通常就是删除一些冗余的特性,或者简化特性。
线性回归代码总结
在整个线性回归问题中,主要有如下几个算法需要实现:
- 代价函数
- 梯度下降算法
- 特征缩放
- 正规方程
使用octave
和matlab
利于快速验证算法和模型。在使用这两种编程语言和平台时,要始终以向量和矩阵的思维方式去思考,这样才能更好的利用两种语言的优势,将很多看似复杂的公式用几行代码实现。
前面定义过代价函数:
$$ J(θ) = {1 \over 2m} {\sum_{i=1}^m(h_θ(x_i)-y_i)^2} $$
这里的$ θ $就是向量的概念,而不是单个变量,可以用如下代码实现它,X
是输入矩阵:
y = data(:, 2);
m = length(y);
X = [ones(m, 1), data(:,1)];
J = sum((X * theta - y).^2) / (2*m);
前面归纳过低度下降函数的算法:
$$ \begin{align*}& \text{repeat until convergence:} \; \lbrace \newline \; & \theta_j := \theta_j - \alpha \frac{1}{m} \sum\limits_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) \cdot x_j^{(i)} \; & \text{for j := 0...n}\newline \rbrace\end{align*} $$
用代码实现:
for iter = 1:num_iters
theta = theta - (sum((X * theta - y) .* X) * alpha / m)';
end
对于特性缩放,可以这么实现:
$$ x_i := \dfrac{x_i - \mu_i}{s_i} $$
mu = mean(X);
sigma = std(X);
X_norm = (X - mu) ./ sigma;
正规方程就比较简单了:
$$ θ = (X^T X)^{-1}X^T y $$
theta = pinv(X' * X) * X' * y;
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。