title: "复习: 反向传播与参数初始化" categories:
以最简单的前馈网络为例, 向量默认为列向量.
记号 | 含义 |
---|---|
$M_l$ | 第 $l$ 层神经元个数 (向量维数) |
$W^{(l)}\in \mathbb R^{Ml \times M{l-1}}$ | 第 $l-1$ 层到第 $l$ 层的权重矩阵 |
$b^{(l)}\in \mathbb R^{M_l}$ | 第 $l-1$ 层到第 $l$ 层的偏置 (列向量) |
$z^{(l)}\in \mathbb R^{M_l}$ | 第 $l$ 层净输入 (logits) |
$f_l\colon \mathbb R \to \mathbb R$ | 第 $l$ 层激活函数 |
$a^{(l)}\in \mathbb R^{M_l}$ | 第 $l$ 层输出 |
输入向量时约定 $f_l((x_1, x_2)') = (f_l(x_1), f_l(x_2))'$. 网络如下
$$ \begin{align} z^{(l)} &= W^{(l)}a^{(l-1)} + b^{(l)},\ a^{(l)} &= f_l(z^{(l)}). \end{align} $$
损失函数为 $L(y, \hat y)$, 其中 $y$ 为真实值, $\hat y = \operatorname{softmax}(z^{(L)})$ (以多分类为例), $L$ 表示总层数 (最后一层).
下面计算 $L$ 分别对 $W$ 和 $b$ 的偏导数, 把对矩阵 $W$ 的微分拆解为对每个元素的微分. 我不喜欢分母布局 (邱锡鹏以及很多文章中采用的), 所以下面用 分子布局.
$$ \begin{align*} \frac{\partial L}{\partial w^{(l)}{i, j}} &= \frac{\partial L}{\partial z^{(l)}} \frac{\partial z^{(l)}}{\partial w^{(l)}{i, j}},\
\frac{\partial L}{\partial b^{(l)}} &= \frac{\partial L}{\partial z^{(l)}} \frac{\partial z^{(l)}}{\partial b^{(l)}}. \end{align*} $$
$$ \begin{align} \frac{\partial z^{(l)}}{\partial w^{(l)}{i, j}} &= \left( 0, \dots, \frac{\partial(w{i,:}^{(l)}a^{(l-1)} + b^{(l)})}{\partial w^{(l)}_{i, j}}, \dots, 0 \right)'\ &= a^{(l-1)}_j e_i \in \mathbb R^{M_l}, \end{align} $$
其中 $e_i$ 表示第 $i$ 个元素为 $1$, 其余为 0 的列向量.
$$ \begin{align} \frac{\partial z^{(l)}}{\partial b^{(l)}} = I \in \mathbb R^{M_l\times M_l}, \end{align} $$
为单位阵.
$$ \begin{align} (\delta^{(l)})' :=& \frac{\partial L}{\partial z^{(l)}}\ =& \frac{\partial L}{\partial z^{(l+1)}} \frac{\partial z^{(l+1)}}{\partial a^{(l)}} \frac{\partial a^{(l)}}{\partial z^{(l)}}\ =& (\delta^{(l+1)})' W^{(l+1)} \operatorname{diag}(f_l'(z^{(l)})) \in \mathbb R^{1\times M_l}. \end{align} $$
反向传播的含义是, 第 $l$ 层的一个神经元的 误差项 $\delta^{(l)}$ 是所有与该神经元相连的第 $l+1$ 层的神经元的误差项的权重和, 再乘上该神经元激活函数的梯度.
$$ \begin{align} \frac{\partial L}{\partial w^{(l)}_{i, j}} = \delta^{(l)}_i a^{(l-1)}_j, \end{align} $$
因此
$$ \begin{align} \frac{\partial L}{\partial W^{(l)}} = (\delta^{(l)})' a^{(l-1)}. \end{align} $$
另外
$$ \begin{align} \frac{\partial L}{\partial b^{(l)}} &= (\delta^{(l)})'. \end{align} $$
In hindsight, 其他人采用分母布局主要是为了结果稍微好看一点.
对权重的每个元素, 通常用正态分布 $\operatorname{N}(0, \sigma^2)$ 或者均匀分布 $U(-r, r)$ 初始化参数, 注意均匀分布的方差为 $r^2/3$.
考虑第 $i$ 个神经元,
$$ a^{(l)}_i = fl \left( \sum{j=1}^{M{l-1}}w^{(l)}{i, j}a^{(l-1)}_j + b^{(l)}_i \right) =: f_l (z^{(l)}_i), $$
简单起见先令激活函数 $fl$ 为 恒等函数. 假设 $b{i}^{(l)} = 0$, 以及 $w{i, j}^{(l)}$ 与 $a{j}^{(l-1)}$ 互相独立, 期望为 0, 且 $w{j}^{(l)}$ 方差相等. 则可知 $\mathbb E a{i}^{(l)} = 0$, 以及
$$ \operatorname{Var}(a^{(l)}_i) = \operatorname{Var}(z^{(l)}i) = M{l-1}\operatorname{Var}(w^{(l)}_{1, 1})\operatorname{Var}(a^{(l-1)}_i). $$
为了使 $a^{(l)}$ 在传播过程中方差不变, 应当使 $\operatorname{Var}(w{1,1}^{(l)}) = 1/M{l-1}$. 类似地为了使 $\delta^{(l)}$ 方差不变, 应当使 $\operatorname{Var}(w{1,1}^{(l)}) = 1/M{l}$. 取调和平均为
$$ \operatorname{Var}(w^{(l)}{1,1}) = \frac{2}{M{l-1} + M_{l}}, $$
称为 Xavier 初始化或者 Glorot 初始化.
"保持方差不变" 的动机来自, 网络层数加深, 方差递推式中方差的变化因子 (假设为常数) 如果不为 1 则方差要么爆炸要么收敛到 0. 更进一步希望权重矩阵正交, 好在一个简单的结果是高维空间任意 (球面平均) 两个向量近似是正交的 (比如参考 这里). 当然也有更直接的正交初始化.
虽然在 Xavier 初始化中我们假设激活函数为恒等函数, 但是 Xavier 初始化也适用于 Logistic 函数和 Tanh 函数. 这是因为神经元的参数和输入的绝对值通常比较小, 处于激活函数的线性区间. 这时 Logistic 函数和 Tanh 函数可以近似为线性函数. 由于 Logistic 函数在线性区间的斜率约为 0.25, 所以初始化方差要乘 16, 或者其他缩放因子.
以严格的数学来说有些问题, 但实际使用效果不错.
如果激活函数为 ReLU, 则期望不为 0, 其他假设同上, 另外假设 $w_{i, j}$ 的分布关于 0 对称. 由于
$$ \operatorname{Var}(z^{(l)}i) = M{l-1}\operatorname{Var}(w^{(l)}_{1, 1})\mathbb E[(a^{(l-1)}i)^2] = M{l-1}\operatorname{Var}(w^{(l)}_{1, 1})\operatorname{Var}(z^{(l-1)}_i) / 2, $$
因此
$$ \operatorname{Var}(w^{(l)}{1,1}) = \frac{2}{M{l-1}}, $$
称为 He 初始化或者 Kaiming 初始化.
另外为了适配 parametric (or leaky) ReLU, 于是有
$$ \operatorname{Var}(w^{(l)}{1,1}) = \frac{2}{(1 + a^2) M{l-1}}. $$
比如 PyTorch 源码就采用了这个形式, 全连接层的默认初始化用的并非是上述两种初始化, 而是方差为 $1/(3M{l-1})$ 的均匀分布, 即 $U(-1/\sqrt{M{l-1}}, 1/\sqrt{M_{l-1}})$.
# torch/nn/modules/linear.py#L105
class Linear(Module):
def reset_parameters(self) -> None:
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
新文章