DiT:用于文档图像Transformer的自监督预训练
📖阅读时长:19分钟
🕙发布时间:2025-02-13
近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企】
公众号【柏企科技说】【柏企阅文】
DiT是一种自监督预训练的文档图像Transformer模型,它使用大规模无标签文本图像来处理文档人工智能(Document AI)任务。我们将DiT用作各种基于视觉的文档AI任务的骨干网络,这些任务包括文档图像分类、文档布局分析、表格检测以及用于光学字符识别(OCR)的文本检测。
一、文档AI模型预训练的现状
通常,预训练文档AI模型的流程是从基于视觉的理解开始的,比如光学字符识别(OCR)或文档布局分析,而这在很大程度上仍然依赖于有人类标记训练样本的有监督计算机视觉骨干模型。尽管在基准数据集上取得了不错的成果,但由于领域差异以及与训练数据的模板/格式不匹配,这些视觉模型在实际应用中往往存在性能差距。
目前并没有像ImageNet那样常用的大规模人工标注基准数据集,这使得大规模有监督预训练难以实现。即使弱监督方法被用于创建文档AI基准,但这些数据集的领域通常来自学术论文,它们具有相似的模板和格式,与现实世界中的文档(如表格、发票/收据、报告等)有所不同。这可能导致在处理一般的文档AI问题时,结果不尽人意。因此,使用来自通用领域的大规模无标签数据对文档图像骨干模型进行预训练至关重要,因为这能够支持多种文档AI任务。
二、模型架构
参照ViT,DiT采用普通的Transformer架构作为骨干。我们将文档图像划分为不重叠的图像块,并获取一系列图像块嵌入。在添加1维位置嵌入后,这些图像块会被输入到带有多头注意力机制的Transformer块堆栈中。最后,Transformer编码器的输出将作为图像块的表示。
三、预训练
受BEiT启发,DiT采用掩码图像建模(MIM)作为预训练目标。在这个过程中,图像分别从两个视角被表示为图像块和视觉标记。在预训练期间,DiT将图像块作为输入,并利用输出表示预测视觉标记。
就像自然语言中的文本标记一样,图像可以通过图像标记器表示为一系列离散标记。BEiT使用来自DALLE的离散变分自编码器(dVAE)作为图像标记器,该标记器在包含4亿张图像的大型数据集上进行训练。然而,自然图像和文档图像之间存在领域不匹配的问题,这使得DALL-E标记器不适用于文档图像。因此,为了在文档图像领域获得更好的离散视觉标记,我们在包含4200万张文档图像的IIT-CDIP数据集上训练了一个dVAE。
新的dVAE标记器通过均方误差(MSE)损失和困惑度损失共同训练。MSE损失用于重建输入图像,困惑度损失则用于增强量化码本表示的使用效果。
为了有效地预训练DiT模型,我们在给定一系列图像块时,用特殊标记[MASK]随机掩码一部分输入。DiT编码器通过线性投影和添加位置嵌入来嵌入掩码后的图像块序列,然后通过Transformer块堆栈对其进行上下文编码。模型需要根据掩码位置的输出预测视觉标记的索引。掩码图像建模任务要求模型预测由图像标记器获得的离散视觉标记,而不是预测原始像素。
四、微调
(一)图像分类
在图像分类任务中,我们使用平均池化来聚合图像块的表示。接着,将全局表示输入到一个简单的线性分类器中。
(二)目标检测
在目标检测任务中,我们采用Mask R-CNN和Cascade R-CNN作为检测框架,并使用基于ViT的模型作为骨干网络。在四个不同的Transformer块中使用分辨率调整模块,以使单尺度的ViT适应多尺度特征金字塔网络(FPN)。
设总块数为$N$,对第$\frac{N}{3}$个块,使用一个带有2个步长为2的$2\times2$转置卷积的模块将其进行4倍上采样;对于第$\frac{N}{2}$个块的输出,使用一个步长为2的$2\times2$转置卷积进行2倍上采样;第$\frac{2N}{3}$个块的输出则直接使用,无需额外操作;最后,对第$\frac{3N}{3}$个块的输出,使用步长为2的$2\times2$最大池化进行2倍下采样。
五、评估
预训练的DiT模型在四个公开的文档AI基准数据集上进行评估:
- RVL-CDIP数据集,用于文档图像分类;
- PubLayNet数据集,用于文档布局分析;
- ICDAR 2019 cTDaR数据集,用于表格检测;
- FUNSD数据集,用于OCR文本检测。
六、论文
DiT: Self-supervised Pre-training for Document Image Transformer 2203.02378
## 推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。