论文标题: 残差持续学习 (Residual Continual Learning)

所属机构: KAIST (韩国科学技术院) 电气工程学院 & 人工智能研究生院

论文地址:https://ojs.aaai.org/index.php/AAAI/article/view/5884

视频讲解:https://www.bilibili.com/video/BV1QDCHY8EC2/?spm_id_from=333....


0. 摘要

  • 提出方法: 提出了一种新的持续学习方法,称为残差持续学习(ResCL)。
  • 解决问题: 该方法旨在解决多任务序列学习中的灾难性遗忘现象。
  • 核心特点:

    • 无需源任务信息(仅需原始网络)。
    • 网络大小完全不增加。
    • 通过线性组合原始网络和微调后网络的每一层来重新参数化网络参数。
    • 考虑了批量归一化(BN)层的影响。
    • 利用类残差学习的重参数化和特殊的权重衰减损失来有效控制源任务和目标任务性能之间的权衡。
  • 实验结果: 在多种持续学习场景中展现了最先进的性能。

1. 引言

  • 背景: 深度学习性能强大,但训练需要大量数据和时间。迁移学习(如微调)可利用源任务知识加速目标任务训练,但会导致灾难性遗忘(源任务性能损失)。
  • 目标: 在保持源任务性能的同时,实现良好的目标任务性能,特别是在基于CNN的图像分类任务上。
  • 实际约束条件:

    1. 无源信息: 目标任务训练期间,无法获取任何形式的源任务信息(包括数据、生成模型、Fisher信息矩阵等),仅有原始的源网络模型。
    2. 网络大小不变: 网络规模不应随任务增加而增长,这与网络扩展方法不同。
  • 主要贡献/特征:

    • 类残差重参数化: 允许持续学习,并通过简单的衰减损失控制性能权衡。
    • 无需源信息: 除原始源网络外,不需要其他源任务信息。
    • 推理时网络大小不变: 除了最后的任务特定线性分类器外,网络结构大小保持不变。
    • 通用性: 可自然地应用于包含批量归一化(BN)层的通用CNN。
    • 公平评估指标: 提出了两种公平的度量标准来比较不同的持续学习方法:理想情况下的“最大可达平均准确率”,以及更实用的“满足要求的目标准确率时的源准确率”。这两种指标的设计思路可见图4。

图 4.由于每种方法具有不同的权衡超参数定义,仅使用特定的权衡超参数设置进行比较是不公平的。(a) 公平度量之一是最大可达平均准确率(圆圈点)。(b1) 实践中,使用目标验证集调整权衡超参数,直到达到所需的目标准确率(虚线圆圈点)。(b2) 具有这些超参数的模型在源测试集上测试一次。

2. 相关工作

  • LwF (Learning without Forgetting): 使用知识蒸馏约束新网络输出接近源网络输出。ResCL将LwF视为一种特殊情况,但训练对象和方式不同。
  • IMM (Incremental Moment Matching): 同样使用网络组合。IMM基于贝叶斯框架,无需额外训练确定组合系数;ResCL通过额外训练学习系数和权重,更适用于标准神经网络。
  • 网络扩展方法 (Terekhov et al., Rusu et al.): 能完美防止遗忘,但会增加网络大小,与ResCL目标不同。
  • 基于正则化的方法 (EWC, SI, MAS, RWalk等): 通过惩罚损失保护重要权重。通常需要源数据或复杂计算,ResCL避免了这些。
  • 剪枝与重训练 (PackNet): 需要源数据。
  • 滤波器组合方法 (Rebuffi et al., Rosenfeld & Tsotsos): 主要用于多任务或迁移学习,通常需为任务保留特定参数或使用受限系数。ResCL专注于在单一固定结构内进行序列学习,学习每层/滤波器的实数组合系数,且不依赖特定结构。

3. 方法

3.1 核心思想与流程

  • 目标: 在源任务和目标任务之间找到一个良好的中间点,平衡性能。
  • 基本流程概述见图1:

    1. 从在源数据训练好的原始网络 net_s (参数 θ_s) 开始。
    2. net_s 在目标数据上微调得到 net_t (参数 θ_t)。
    3. 线性组合 net_s (参数 θ_s 固定) 和 net_t (参数 θ_t 可学习) 的对应层,得到组合网络 net_c。组合系数 α = (α_s, α_t) 也是可学习的。
    4. 使用 LwF 损失 L_s (保留源任务知识)、蒸馏损失 L_t (学习目标任务知识) 和特殊的衰减损失 L_decay (控制权衡并防止遗忘) 来训练 net_c,优化 αθ_t

图 1:我们方法的图示。可学习参数以红色显示。我们从一个在源数据上训练的原始网络 net_s 开始。首先,使用目标数据对 net_s进行微调以获得 net_t。net_s 和 net_t 中的每个线性块都通过 net_c 中的组合层进行组合。在 net_c 上执行持续学习,使用 LwF 损失 L_s 来保持源任务性能,并使用蒸馏损失 L_t 来适应目标任务。还有一个特殊的衰减损失 L_decay,这是防止遗忘的最重要的损失。DKL(·||·) 指的是 Kullback–Leibler 散度,softmax 温度为 2。请注意,由于源任务和目标任务通常具有不同的类别,因此每个任务都有其自己最后的任务特定全连接层。因此,net_c 有两个不同的输出:net_c(·; task_s) 用于源任务,net_c(·; task_t) 用于目标任务。

3.2 层的线性组合 (Linear Combination of Two Layers)

  • 组合方式 (公式 1): (1_Co + α_s) ∘ (W_s x) + α_t ∘ (W_t x)

    • W_s 固定,W_t 可学习(从 W_s 初始化)。
    • α_s, α_t 是逐通道(特征)的可学习向量。
    • 1 + α_s 的参数化形式是关键设计,使得衰减损失倾向于将网络推向源网络状态,类似于残差学习。
  • 推理时等效性 (公式 2): ((1_Co + α_s) ⊗ 1_Ci^T) ∘ W_s + (α_t ⊗ 1_Ci^T) ∘ W_t (原文公式2的简化理解,表示权重合并)

    • 由于线性,组合操作在推理时可以合并为一个单一层,因此网络大小不增加。
  • 非线性层: 组合应在非线性激活函数(如ReLU)之前进行。

3.3 训练 (Training)

  • 损失函数:

    1. 源任务损失 L_s: DKL(net_s(x) || net_c(x; task_s)) (LwF损失)。
    2. 目标任务损失 L_t: DKL(net_t(x) || net_c(x; task_t)) (蒸馏损失)。
    3. 衰减损失 L_decay: λ||(α_s, α_t)||_1 + (λ_dec/2)||θ_t||_2^2

      • λ||(α_s, α_t)||_1:L1 范数促使 α 稀疏并趋于零,结合 1 + α_s 参数化,保护源网络权重。实验发现 L1 比 L2 效果稍好。
      • (λ_dec/2)||θ_t||_2^2:对目标路径权重 θ_t 进行标准 L2 衰减。
  • 算法流程 (Algorithm 1): 详细描述了从微调、获取蒸馏目标、初始化α、组合网络到最终训练优化的步骤。

Algorithm 1:残差持续学习算法。

3.4 卷积层和批量归一化 (Convolution and Batch Normalization)

  • 卷积层: 组合系数 α 在空间维度共享。
  • BN层处理 (图 2):

    • 在组合网络中,源路径和目标路径各自保留独立的BN层 (BN(s)BN(t)),如图2(b)所示。
    • 关键: 训练 net_c 时,BN(s) 必须使用其在源任务上计算好的、固定的总体统计量(population statistics),以保留源任务的分布信息。BN(t) 正常使用其批次统计量或总体统计量。
    • 推理合并 (图 2(c)): 推理时,两个卷积层、两个BN层(此时为固定线性变换)和组合层可以合并成一个等效的卷积层 Conv(c),网络结构与原始单元 (a1)(a2) 相同,不增加参数或计算量。

图 2 :源和目标预激活残差单元 (He et al. 2016b) 的组合。“Comb”代表组合层。两条路径在每个非线性层之前通过组合层进行组合。可学习的层显示为红色。在推理阶段,使用组合网络 (c),它等效于 (b) 并且具有与 (a1) 和 (a2) 相同的网络大小。

4. 实验

4.1 评估指标:最大可达平均准确率

  • 目的: 衡量方法的理论最优性能,消除特定超参数选择的影响。
  • 方法: 搜索一系列权衡超参数 (λ 或 α1/α2),找到使 (源准确率 + 目标准确率) / 2 最大的点。
  • 结果 (表1, 表2):

    • ResCL 在所有三个场景 (CIFAR-10→100, CIFAR-100→10, CIFAR-10→SVHN) 的最大可达平均准确率上均显著优于其他无需源信息的方法(Fine-tuning, LwF, Mean-IMM)。
    • 尤其在任务差异大的 CIFAR-10→SVHN 场景,ResCL (89.49%) 大幅领先 LwF (68.90%) 和 Mean-IMM (79.91%)。

![表 1:每种方法的最大可达平均准确率[%]。显示了四次运行的均值和标准差。最优权衡超参数在括号中给出。](https://files.mdnice.com/user/91772/a5abf711-32f2-4953-9a00-5...)

![表 2:对应表1中最优平均准确率时,各方法分别在源任务和目标任务上的准确率[%]。](https://files.mdnice.com/user/91772/27bebfa4-7bb7-4cd9-a3c7-0...)

4.2 评估指标:满足要求的目标准确率时的源准确率

  • 目的: 模拟实际应用,评估在保证基本目标性能下的遗忘程度。
  • 方法: 设置目标准确率要求(如 Fine-tuning 的 95%),仅用目标验证集调整超参数,找到满足该要求的模型,然后评估其源任务准确率。
  • 结果 (表4, 表5):

    • 在所有场景下,ResCL 在满足目标准确率要求时,其源任务准确率显著高于 LwF。例如,在 CIFAR-10→SVHN 中,ResCL 源准确率为 76.83%,而 LwF 仅为 38.08%。

![表 4:在要求的目标准确率下,各种方法的源任务准确率[%]。目标准确率要求设定为微调模型性能的95%。](https://files.mdnice.com/user/91772/b8dcb743-3925-4d8e-9818-b...)

![表 5:在要求的目标准确率下,各种方法的源任务准确率[%]。第二列表示在三个任务上的序列学习。](https://files.mdnice.com/user/91772/2ed2e3f6-d1a2-417d-8a6e-1...)

4.3 实验设置

  • 数据集与网络: CIFAR-10/100, SVHN 使用 PreResNet-32;ImageNet→CUB 使用 AlexNet 和 VGG。
  • 训练细节: 标准 CIFAR 训练流程(数据增强、SGD、学习率衰减等)、He 初始化、α 初始化为 (-0.51, 0.51) 以初始平衡。SVHN 不使用数据增强。

4.4 结果分析

  • BN 统计量 (图 3): LwF 训练后,BN 层的均值和标准差分布与原始网络差异巨大,说明源分布信息丢失。ResCL 通过保留独立的源 BN 层及其固定统计量,避免了此问题。

图 3:特定网络中所有BN层的统计参数(均值μ和标准差σ)的分布。黑线是原始网络的分布,红线是使用LwF在目标任务上训练原始网络后的分布。

  • 多任务扩展 (表 3): ResCL 同样适用于三个任务的序列学习以及大规模数据集 (ImageNet→CUB),表现出良好的可扩展性和性能。

![表 3:每种方法的最大可达平均准确率[%]。第二列代表三个任务的序列学习。](https://files.mdnice.com/user/91772/8a0d84e1-b329-4dea-9a44-0...)

  • 超参数λ的影响 (图 5): 在 CIFAR-10→SVHN 场景中,λ 确实有效地控制了源/目标准确率的权衡。源准确率随 λ 增大而饱和,目标准确率随 λ 增大而下降,平均准确率呈凹形曲线。

图 5:在 CIFAR-10 → SVHN 场景下,源准确率、目标准确率和平均准确率随权衡超参数 λ 的变化情况。

  • 组合系数α的分析 (图 6): 组合系数 α 的平均绝对值(表示与源特征的偏离程度)通常随网络层加深而增大。这符合认知:浅层学习通用特征(任务间变化小),深层学习任务特定特征(任务间变化大)。

图 6:组合参数 α 元素的平均绝对值随其所在层深度的变化情况。

5. 结论

  • 总结: 提出了 ResCL,一种新颖的持续学习方法,在图像分类任务上展现了最先进的性能。
  • 优点:

    • 有效防止灾难性遗忘,即使源/目标任务差异很大。
    • 实用性强:无需额外源信息,推理时网络大小不增加。
    • 通用性好:可应用于包含BN层的通用CNN架构,处理BN层的方式自然有效。
  • 未来工作: 将 ResCL 方法扩展到其他类型的神经网络(如RNN)和其他机器学习领域。

本文由mdnice多平台发布


一只云卷云舒
1 声望0 粉丝