在cs229-notes1的第9页给出了以下等式:
$$ \nabla_A trABA^TC = CAB + C^TAB^T $$
后面推导线性回归的正规方程(Normal Equation)的时候,需要用到这个等式。
我在它的证明上纠结了很久,因此在这里写下我的证明,希望对正在纠结这个证明的同学有所帮助。
以下标题是为了(希望)搜索引擎能将一些模糊的公式输入索引到这个页面来,从而能帮助到正在使用搜索引擎查找这个问题的同学。读者可以忽略。
∇ TrABATC=CAB+CTABT 证明
准备过程
multivariable chain rule (多元链式求导法则):
$$\frac{d}{dt} f(x(t),y(t)) = \frac{\partial f}{\partial x} \cdot \frac{dx}{dt} + \frac{\partial f}{\partial y} \cdot \frac{dy}{dt}$$
多元链式求导法则是一个很有用的东西,其实乘法求导的法则是可以通过它来推导出来的:
将$ f(x(t),y(t)) = x(t)\cdot y(t) $带入上面的多元链式求导法则:
$$\frac{d x\cdot y}{dt} = \frac{\partial x\cdot y}{\partial x} \cdot \frac{dx}{dt} + \frac{\partial x\cdot y}{\partial y} \cdot \frac{dy}{dt} = y\cdot\frac{dx}{dt} + x\cdot\frac{dy}{dt} $$
证明过程
参考资料1和2的证明过程需要“对输出矩阵的函数求自变量的梯度”(比如对$f = AB$求 $\nabla_A$),我认为这样证明是错误的。因为只有在$f:R^{m \times n } \rightarrow R$(也就是说$f$输出实数)的时候,矩阵的梯度才有定义;如果输出的是一个矩阵,则不能求输入矩阵的梯度。(cs229线性代数补充笔记第20页)
我的证明参考了参考资料3,这种方法将问题转化为了标量计算,从而可以用我们更加习惯的方式进行求导。
为了方便分析,我们设A的尺寸是u*v。
为了使矩阵乘法有效,必须满足$AB \rightarrow (A的列数 = B的行数) $。
很容易可以证明,要使$ABA^TC$有定义且结果是方阵(tr运算符中的矩阵必须是方阵),需要满足:
矩阵 | 行数 | 列数 |
---|---|---|
A | u | v |
B | v | v |
C | u | u |
首先,对矩阵$ABA^TC$,写出它的任意元素的表达式。
$$(ABA^TC)_{ij}\\
= \sum_{k = 1}^{v}A_{ik}(BA^TC)_{kj}(利用(AB)_{ij}=\sum_k A_{ik}B_{kj})\\
= \sum_{k = 1}^{v}A_{ik} \sum_{l = 1}^{v} B_{kl}(A^TC)_{lj}\\
= \sum_{k = 1}^{v}A_{ik} \sum_{l = 1}^{v} B_{kl}\sum_{m = 1}^{u}A^T_{lm}C_{mj}\\
= \sum_{k = 1}^{v}A_{ik} \sum_{l = 1}^{v} B_{kl}\sum_{m = 1}^{u}A_{ml}C_{mj}\\
= \sum_{k = 1}^{v}\sum_{l = 1}^{v}\sum_{m = 1}^{u}A_{ik}B_{kl}A_{ml}C_{mj}(求和符号外的系数可以移入求和符号内)$$
因此,
$$trABA^TC = \sum_{n = 1}^{u}(ABA^TC)_{nn}\\
=\sum_{k = 1}^{v}\sum_{l = 1}^{v}\sum_{m = 1}^{u} \sum_{n = 1}^{u} A_{nk}B_{kl}A_{ml}C_{mn}$$
根据梯度的定义:
我们要求$\nabla_A trABA^TC$,需要对每个 $A_{pq}$ 求以下偏导数:
$$ \frac{\partial trABA^TC}{\partial A_{pq}} $$
我们再看一下之前的结论:
$$ trABA^TC = \sum_{k = 1}^{v}\sum_{l = 1}^{v}\sum_{m = 1}^{u} \sum_{n = 1}^{u} A_{nk}B_{kl}A_{ml}C_{mn} $$
这是很多项求和的结果,但是只有当【n=p&&k=q】或【m=p&&l=q】时,这一项对$A_{pq}$的偏导数才不为0(不含$A_{pq}$的项求导以后为0)。
每一项必定是以下4种情况中的一种:
- 【n=p&&k=q】满足,【m=p&&l=q】不满足。满足这种情况的项对$A_{pq}$的偏导数加起来成为:$\sum_{m = 1}^{u}\sum_{l = 1}^{v} B_{ql}A_{ml}C_{mp}$,注意这里遍历的时候要排除【m=p&&l=q】的项。
- 【n=p&&k=q】不满足,【m=p&&l=q】满足。满足这种情况的项对$A_{pq}$的偏导数加起来成为:$\sum_{n = 1}^{u}\sum_{k = 1}^{v} A_{nk}B_{kq}C_{pn}$,注意这里遍历的时候要排除【n=p&&k=q】的项。
- 【n=p&&k=q】不满足,【m=p&&l=q】不满足。满足这种情况的项对$A_{pq}$的偏导都是0,可以忽略。
- 【n=p&&k=q】满足,【m=p&&l=q】满足。满足这种情况的只有这一项:$A_{pq}B_{qq}A_{pq}C_{pp}$,它对$A_{pq}$求导时要使用多元链式法则,得到$B_{qq}A_{pq}C_{pp}+A_{pq}B_{qq}C_{pp}$。恰好补全了第一第二种情况缺少的项!
综上,$ \frac{\partial trABA^TC}{\partial A_{pq}} $ 的结果:
$$\sum_{m = 1}^{u}\sum_{l = 1}^{v} B_{ql}A_{ml}C_{mp} + \sum_{n = 1}^{u}\sum_{k = 1}^{v} A_{nk}B_{kq}C_{pn}$$
稍微调整一下,让结果更加明显:
$$\sum_{m = 1}^{u}\sum_{l = 1}^{v} (C^T)_{pm}A_{ml}(B^T)_{lq} + \sum_{n = 1}^{u}\sum_{k = 1}^{v} C_{pn}A_{nk}B_{kq}$$
这恰好就是
$$ (C^TAB^T)_{pq} + (CAB)_{pq} $$
因此,$$\frac{\partial trABA^TC}{\partial A_{pq}} = (C^TAB^T)_{pq} + (CAB)_{pq} $$
也就是说:
$$ \nabla_A trABA^TC = CAB + C^TAB^T $$
目标得证。
参考资料:
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。