论文标题: PyCIL: 一个用于类增量学习的Python工具箱 (PyCIL: A Python Toolbox for Class-Incremental Learning)
所属机构: 南京大学计算机软件新技术国家重点实验室,中国南京 210023
论文:https://link.springer.com/article/10.1007/s11432-022-3600-y
代码:https://github.com/G-U-N/PyCIL
1. 摘要 (Abstract - 隐式,基于前两段):
- 问题: 标准的深度学习模型通常假设类别集合是固定的,难以适应现实世界中数据以流式格式出现或因隐私问题临时可用的情况。直接用新数据微调模型会导致“灾难性遗忘”现象,即模型在旧类别上的性能急剧下降。
- 类增量学习 (CIL) 的目标: 使模型能够逐步学习新的类别,同时不丢失对旧类别的识别能力。
- 需求: 随着机器学习社区对类增量学习兴趣的增长,迫切需要一个简单、高效且包含多种标准算法的工具箱。
- 提出的解决方案: 介绍了 PyCIL,一个基于Python的类增量学习工具箱。
PyCIL 特点:
- 使用Python开发,便于机器学习社区广泛使用。
- 实现了多个CIL领域的基础性和当前先进的算法。
- 设计易于使用、易于获取,所有功能具有一致的语法和约定。
- 仅依赖标准的开源库。
- 跨平台兼容 (Linux, macOS, Windows)。
- 源代码已在GitHub上公开 (https://github.com/G-U-N/PyCIL)。
2. 引言 (Introduction - 第1-3段):
- 背景: 深度模型在固定数据集上表现优异,但在动态开放环境中面临挑战。
- 问题阐述: 现实世界数据常以流式或临时形式出现,要求模型能增量学习新类别。直接微调会因灾难性遗忘导致性能下降。
- CIL 定义与示例: CIL旨在扩展模型知识以包含新类别,同时保留旧知识。举例说明:模型先学分类鸟和狗,然后增量学习虎和鱼,之后再学习猴和羊,每一步都需要能够区分所有见过的类别。
- PyCIL 动机: CIL社区发展迅速,需要一个简单高效的工具箱。Python因其在机器学习领域的广泛应用而被选用。PyCIL旨在为学术研究和工业应用提供便利工具。
3. 定义 1 (类增量学习 - Class-Incremental Learning):
- 形式化设定: 学习过程是一系列(B个)不重叠类别的训练任务 {D1, D2, ..., DB}。
- 任务数据: 每个任务 Db 包含 nb 个实例 (xi, yi),其中 xi ∈ RD 是数据实例,yi ∈ Yb 是任务 b 的标签空间。
- 类别不相交: 不同任务的标签空间互不重叠 (Yb ∩ Yb' = Ø 对于 b ≠ b')。
- 数据访问限制: 在训练任务 b 时,只能访问当前任务的数据 Db。
- 目标: 不仅要从当前任务 Db 获取知识,还要保留从先前任务中学到的知识。
- 评估: 每个任务训练完成后,模型需在所有已见类别 (Yb = Y1 ∪ ... ∪ Yb) 上进行评估。
4. 定义 2 (样本集 - Exemplar Set):
- 解决的问题: 仅用当前任务数据 Db 更新模型会导致严重的灾难性遗忘。
- 解决方案: 维护一个额外的“样本集” E,其中存储了来自先前见过的类别的有限数量 (M) 的实例。
- 作用机制: 通过回顾 (revisiting) 这些样本,可以帮助模型克服灾难性遗忘。
- 选择方法: 通常使用如herding等算法来选择样本,以确保其代表性。
5. 实现的算法 (Implemented Algorithms):
- 概述: PyCIL实现了11种典型的CIL算法。
列表与简述:
- Finetune: 基线方法,直接用新任务数据微调,遗忘严重。
- Replay: 基线方法,用新任务数据和样本集中的旧数据一起训练。
- EWC [1]: 利用费雪信息矩阵评估参数重要性,并施加正则化以减缓遗忘。
- LwF [2]: 利用知识蒸馏,使新模型在旧任务上的输出与旧模型对齐。
- iCaRL [3]: 基于LwF,引入样本集进行排练,并使用最近中心均值分类器。
- GEM [4]: 使用样本集作为约束,确保新任务的梯度更新不会增加在旧任务上的损失。
- BiC [5]: 在iCaRL基础上,训练一个额外的适配层来修正新类的预测偏差。
- WA [6]: 在iCaRL基础上,每次学习后根据权重范数归一化分类器权重。
- PODNet [7]: 引入空间层面的池化特征蒸馏来约束网络表示。
- DER [8]: 采用两阶段学习方法和动态可扩展的表示来更有效地建模。
- Coil [9]: 利用最优传输理论构建双向知识迁移。
6. 依赖项 (Dependencies):
- 核心库: NumPy 和 SciPy (用于线性代数和优化)。
- 深度学习框架: PyTorch (用于网络构建和训练)。
7. 基本用法 (Basic Usage):
- 功能: 提供了上述11种方法的实现,以及CIFAR100和ImageNet100/1000等基准数据集的设置。
- 配置: 用户可以编辑全局参数 (如内存大小
Memory-Size
, 初始类别数Init-Cls
, 增量步长Increment
, 网络类型Convnet-type
, 随机种子Seed
) 和算法特定超参数。 - 运行: 配置完成后运行主函数即可开始训练和评估。
8. 评估 (Evaluation):
- 指标: CIL常用的评估指标是每个增量阶段结束后的测试准确率 Ab,以及所有阶段的平均准确率 Ā = (1/B) Σ Ab。
- 数据集: 使用了CIFAR100和ImageNet100基准数据集。
- 实验设置: 将100个类别划分成多个增量阶段进行测试(例如,10个阶段,每阶段10类;或初始50类,后续4个阶段每阶段10类)。
- 结果 (图1): 展示了不同算法在CIFAR100和ImageNet100上的增量准确率曲线。作者表示,经过超参数搜索,复现的算法性能与原始论文报告相当或更好。
9. 结论 (Conclusion):
- 总结: 介绍了PyCIL,一个用Python编写的类增量学习工具箱。
- 贡献: 它实现了CIL领域的一系列开创性工作和当前先进的算法。
- 目标: 提供了一个代码一致、易于使用的工具,适用于研究、教学和工业应用。
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。