查看原文
其他

NeurIPS 2022 | 稀疏且鲁棒的预训练语言模型

刘源鑫 PaperWeekly 2022-12-14
©PaperWeekly 原创 · 作者 | 刘源鑫
单位 | 北京大学
研究方向 | 自然语言处理

论文标题:
A Win-win Deal: Towards Sparse and Robust Pre-trained Language Models

收录会议:

NeurIPS 2022

论文链接:

https://arxiv.org/abs/2210.05211

代码链接:

https://github.com/llyx97/sparse-and-robust-PLM




背景及动机

尽管预训练语言模型(PLM)在自然语言处理(NLP)领域取得了很大的成功,它们仍然面临两个主要问题。一方面,PLM 的参数量通常很大,会消化大量的计算资源。另一方面,虽然 PLM 在大规模语料上进行了预训练,但是它们仍然容易受下游数据集偏见(dataset bias)的影响,因而在 out-of-distribution(OOD)测试集上的性能不理想。

对于模型参数量问题,一些工作尝试用稀疏子网络代替 PLM。[1,2,3] 将微调后的 PLM 剪枝为稀疏子网络。[4,5,6,7] 采用了彩票假设 [8] 的设定,直接剪枝未经微调的 PLM,并把得到的子网络在下游任务进行微调。更进一步,[9] 发现 PLM 中本身就包括一些子网络,它们可以直接用于下游任务测试,而无须对模型权重进行任何微调。图 1 展示了这三种微调-剪枝流程。


▲ 图1 通过不同微调-剪枝流程得到的PLM子网络,在in-distribution和out-of-distribution两种场景下测试。

同时,很多去偏方法 [13-15] 也被提出用于缓解 dataset bias 问题。一个主流的思想是根据训练样本的偏见程度调整对应损失函数的权重,使模型不过多关注偏见样本(可以通过表层特征就分类对的样本),从而提升其在 OOD 测试集上的泛化能力。

虽然近期工作在上述两个问题上都取得了不错的进展,但是还很少有工作对 PLM 的高效性和鲁棒性同时进行探究。然而,为了促进 PLM 在真实场景中的应用,这两个问题是需要被同时解决的。因此,本文将 PLM 剪枝研究扩展到了 OOD 场景,在上述三种微调-剪枝流程下,探究是否存在既稀疏又对 dataset bias 鲁棒的 PLM 子网络Sparse and Robust SubNetworks, SRNets)?




BERT剪枝及去偏

2.1 BERT子网络

本文中,我们以预训练语言模型 BERT 为研究对象。BERT 由一个词嵌入矩阵,若干 Transformer 层,以及一个任务相关的分类器 组成。其中,每个 Transformer 层都由多头自注意力模块(MHAtt)和前馈网络(FFN)组成。MHAtt 中包括三个计算注意力的矩阵 和一个输出矩阵 。FFN 中包括两个线性层
假设 BERT 模型为 ,我们通过在其参数 上加二元掩码 ,获得子网络 。本文中,我们对 MHAtt、FFN 模块,以及任务相关分类器进行剪枝,即需要剪枝的参数矩阵集合为

2.2 剪枝方法

2.2.1 迭代权重剪枝

基于权重的剪枝 [8,10] 移除绝对值最小的模型参数。通常,剪枝和训练是交替进行的,这整个流程也叫做迭代权重剪枝(Iterative Magnitude Pruning, IMP):

  1. 将完整模型训练至收敛。
  2. 将一部分绝对值最小的模型参数移除。
  3. 训练剪枝后的子网络。
  4. 重复 1-3,直到子网络达到目标稀疏程度。
以上流程一般用于获得微调后 PLM 的子网络(即图 1 中的(a)),如果要获得未经微调的子网络(即图 1 中的(b)和(c)),在第三步之后还需要将子网络参数重置回预训练的值,并且第一步可以被跳过。

2.2.2 掩码训练

掩码训练将  视为可以训练的变量。参考 [9,11] 中的做法,我们在前向传播中进行二元化,在反向传播中进行梯度估计,以实现对 的训练。
具体地,PLM 中每个需要剪枝的参数矩阵 都会和两个掩码矩阵, 关联。前向传播时, 将由 替代,即只用到了 定义的子网络。 根据  计算得到:


其中 为阈值。在反向传播时,我们用 straight-through estimator [12] 进行梯度估计,以解决二元化操作不可导的问题。如果 是微调后的参数,那么掩码训练搜索到的是微调后 BERT 的子网络。同理,如果 是未经微调的预训练参数,那么掩码训练得到的是预训练 BERT 中的子网络。
不论是 IMP 还是掩码训练,都对训练中使用的损失函数没有限制。因此,这两种剪枝方法都可以和下面要介绍的去偏方法结合
2.3 去偏方法
在去偏方法中,通常用一个偏见模型估计样本的偏见程度。偏见模型在训练时,输入是人工构造的表层特征。这样偏见模型就会学习到表层特征和类别之间的联系。如果偏见模型对某训练样本的预测类别分布是 ,那么正确类别 的概率 就是该样本的偏见程度。
有了 后,可以用不同的去偏方法对训练损失函数进行调整。本文中考虑以下几种损失函数:
  • 标准交叉熵损失:计算主模型预测概率 和正确类别独热分布(one-hot) 之间的交叉熵
  • Product-of-Experts(PoE)[13]:先结合主模型和偏见模型的预测概率 ,再计算交叉熵
  • 样本重权重 [14]:直接用偏见程度调整每个样本损失函数值的权重,给予偏见程度高的样本低权重

  • 置信度正则化 [15]:这是一种基于知识蒸馏的方法,需要一个用标准交叉熵训练好的教师模型。教师模型的预测分布 给学生模型(主模型)提供监督信号,同时用偏见程度调整每个样本损失值的权重:




搜索稀疏且鲁棒的BERT子网络

3.1 实验设置

模型我们主要以 BERT-base 为研究对象。同时,为了验证结论的普适性,我们还在 RoBERTa-base 和 BERT-large 上进行了部分实验,相关结果请参见论文。

任务及数据集本文在三个自然语言理解任务上进行了实验:自然语言推断(Natural Language Inference, NLI),释义识别(Paraphrase Identification)和事实验证(Fact Verification)。每个任务都有一个 in-distribution(ID)数据集和一个 out-of-distribution(OOD)数据集。其中 ID 数据集中存在 dataset bias,而 OOD 数据集在构建时去除了这些 bias。

  • NLI:MNLI 作为 ID 数据集,HANS 作为 OOD 数据集。
  • 释义识别:QQP 作为 ID 数据集,PAWS-qqp, PAWS-wiki 作为 OOD 数据集。
  • 事实验证:FEVER 为 ID 数据集,FEVER-symmetric(v1,v2)为 OOD 数据集。
对每个数据集统计、bias 类型以及评测方法的介绍,请参见我们的论文。

3.2 微调后搜索BERT子网络

我们首先探究微调后的 BERT 中是否存在鲁棒的子网络(对应图 1, (a))。我们考虑两种微调后的 BERT,一种是用标准交叉熵训练的(3.2.1小节),一种是用 PoE 去偏方法训练的(3.2.2小节),其中前者本身对 dataset bias 是不鲁棒的。
3.2.1 标准交叉熵微调BERT

▲ 图2 标准交叉熵微调后搜索BERT子网络的效果

我们用四种不同的剪枝-损失函数组合对微调后的 BERT 进行压缩。这里我们只展示标准交叉熵和 PoE 两种损失函数,关于样本重权重和置信度正则化请参见我们的论文。

图 2 展示了不同子网络的效果,我们可以得出以下主要发现:
  • 如果在剪枝过程中采用标准交叉熵损失,即 mask train(std)和 imp(std),得到的子网络相比完整 BERT 在 HANS 和 PAWS 上总体有略微的提升。这可能是因为部分和表层特征相关的参数被剪枝了。

  • 如果在剪枝过程中采用 PoE 去偏方法,即 mask train(poe)和 imp(poe),我们可以获得 70% 稀疏的子网络(保留 30% 参数),它们在 OOD 数据集上比完整 BERT 有显著的提升,并且在 ID 数据集上保持了95%以上的性能。这说明在剪枝过程中采用去偏损失函数对于同时实现压缩和去偏的目标是很有效的。另外,由于掩码训练没有改变模型参数值,mask train(poe)的效果意味着带有偏见的 BERT 模型中本身就存在鲁棒子网络。
  • 在两种训练损失下,掩码训练的总体效果都优于 IMP。

3.2.2 PoE去偏方法微调BERT


▲ 图3 PoE微调后搜索BERT子网络的效果

现在我们对 PoE 微调的 BERT 进行剪枝,看看子网络的效果如何。因为前一小节已经发现在剪枝过程中使用 PoE 比标准交叉熵效果好,此处我们仅考虑 PoE。

从图 3 的结果中,我们可以看出:

  • 和带有偏见的 full bert(std)不同,在已经较为鲁棒的 full bert(poe)中,子网络的 OOD 效果没有显著超越完整 BERT。但是在较高的 70% 稀疏程度下,子网络的 ID 和 OOD 效果仍然没有很多下降,保持了完整 BERT 95% 以上的性能。

  • 从带有偏见的 BERT 中搜索到的子网络(图 3 中橙色曲线)和较为鲁棒 BERT 中搜索到的子网络(图 3 中蓝色曲线)效果没有很大差距。这意味着只要在剪枝过程中引入了去偏方法,对完整BERT的去偏就不是必须的。

3.3 单独微调的BERT子网络


▲ 图4 单独微调的BERT子网络效果
在本小节设置下(对应图1,(b)),我们首先通过 PoE 剪枝得到子网络,这些子网络参数值和原始 BERT 相同,但是在结构上(即二元掩码 )已经学习到了一些去偏信息。然后我们分别用标准交叉熵和 PoE 微调这些子网络的参数。

根据图 4 中的结果,我们发现:

  • 如果采用标准交叉熵微调,在 20%~50% 稀疏程度内,用掩码训练搜索到的子网络 mask train(poe)的 OOD 效果要优于完整 BERT。这说明拥有较为鲁棒结构的子网络在训练过程中,相比完整 BERT 更不容易受到 dataset bias 的影响
  • 如果采用 PoE 微调,在 70% 稀疏程度内,掩码训练和 IMP 剪枝得到的子网络和完整 BERT 的效果相当。
  • 结合以上两点,我们可以吧 BERT 中的彩票假设 [4,5] 推广到 OOD 场景:预训练 BERT 中包含了一些子网络,它们可以在下游任务上用标准交叉熵或 PoE 去偏方法单独微调,并且在 ID 和 OOD 场景下取得和微调完整 BERT 相当的效果
  • 对比标准交叉熵微调和 PoE 微调,在所有情况下后者的 OOD 性能都有明显优势。这说明即使子网络已经学习到了较为鲁棒的结构,对于参数值的去偏训练仍然是重要的

3.4 不微调参数的BERT子网络


▲ 图5 不微调参数的BERT子网络效果

在本小节中,我们直接在预训练 BERT 参数(包括随机初始化的下游分类器参数)上进行掩码训练而不对参数值进行微调(对应图1,(c))。在掩码训练中,我们采用了标准交叉熵和 PoE 两种损失函数。

从图 5 中,我们发现:

  • 在掩码训练过程中采用交叉熵得到的子网络(50% 稀疏程度以下)和同样用交叉熵训练出的完整 BERT 效果相当。
  • 对于 PoE 损失函数也是同样的结论。这说明未经微调的预训练 BERT 中本身已经存在适用于特定下游任务的 SRNets
  • 对比预训练 BERT 中的子网络(图 5 中绿色曲线)和微调后 BERT 中的子网络(图 5 中橙色曲线),我们发现后者总体效果略优于前者,但是二者差距较小。这引出了一个比较有意思的问题:是否有必要首先把完整 BERT 微调至收敛,再开始掩码训练?在 3.6.1 小节中将对这个问题进行进一步探究。

3.5 利用OOD数据搜索无偏的oracle子网络


▲ 图6 用OOD训练集搜索出的oracle BERT子网络

通过 3.2-3.4 小节的实验分析,我们已经发现在不同的微调-剪枝流程下都存在稀疏且鲁棒的 BERT 子网络。本小节希望探究这些子网络 OOD 性能的上界。为此我们利用 OOD 训练集进行掩码训练搜索 oracle 子网络,而用和训练集没有重合数据的 OOD 测试集进行测试。和之前的实验一样,我们探究了三种微调-剪枝流程下的效果。为了反映子网络结构本身的去偏能力,在对子网络参数进行微调时(对应图1,(b)的设置),我们采用的是对 dataset bias 较为敏感的标准交叉熵损失。

根据图 6 的结果,我们发现:
  • 在 BERT 微调前后搜索到的 oracle 子网络(分别对应图 6 中 bert-pt subnet 和 bert-ft subnet)在一定稀疏范围内可以取得很高的 OOD 性能。特别地,20%~70% 稀疏的 bert-ft subnet 在 HANS 上都取得了 100% 的准确率。
  • 如果用标准交叉熵对 oracle 子网络参数进行微调,它们的 OOD 效果会有一定的下降。然而相比完整 BERT,这些 oracle 子网络在训练过程中对 dataset bias 明显更加鲁棒。
  • 以上发现说明,BERT 子网络对 dataset bias 鲁棒性的上界很高,在不同微调-剪枝流程下理论上都存在几乎无偏的 BERT 子网络

3.6 改进掩码训练方法

虽然我们已经证明了 SRNets 的存在,但是前几小节的实验结果显示整个微调-剪枝流程还有改进空间。在本小节中,我们探究如何对基于掩码训练的 SRNet 搜索方法做进一步的改进。
3.6.1 开始掩码训练的时刻

▲ 图7 NLI任务上70%稀疏程度的掩码训练曲线。不同曲线代表从不同的完整BERT微调checkpoint开始掩码训练。

相比于在微调后的 BERT 中搜索子网络,在微调前进行搜索的总体训练开销更小。在 3.4 小节中我们发现,二者的最终效果存在一些差距。那么,我们是否可以在训练开销和子网络效果之间找到一个更合适的 trade-off?为此我们进行了一系列的实验,从不同微调程度的完整 BERT checkpoints 开始掩码训练。

图7展示了掩码训练的准确率变化曲线。我们发现在预训练 BERT 上进行掩码训练(ft step=0)的收敛速度比在微调后的 BERT 上慢(ft to end),特别是在 HANS OOD 数据集上。然而,如果以经过 20,000 步(ft to end 的约 55%)微调的 BERT 为起点,掩码训练的最终效果和 ft to end 就没有显著差距了。这说明整个微调-剪枝流程的开销可以在完整 BERT 微调阶段被节省很大一部分而不影响子网络最终效果
不过值得注意的是,在以上分析中我们只是证明了减少 BERT 微调步数的可行性。要真正减少训练开销,我们需要事先对开始掩码训练的时刻进行预测,对这个问题的进一步探究也是一个值得研究的未来方向。

3.6.2 逐渐提升稀疏程度


▲ 图8 掩码训练中固定子网络稀疏程度和逐渐提升稀疏程度对比。

在之前掩码训练的实验中,我们都是固定子网络的稀疏程度,并且发现高于 70% 稀疏程度后效果就急剧下降。我们猜想这是因为在掩码训练的一开始就将稀疏程度设置过高,对于 的优化比较困难。因此,我们改进了掩码训练过程,在训练过程中采用 cubic sparsity schedule [16],将 的稀疏程度逐渐从 70% 提升到 90%。
图 8 在不同的训练开销下,对比了固定稀疏程度和逐渐提升稀疏程度(gradual sparsity increase, GSI)。我们发现,直接增加掩码训练轮次也能提升高稀疏程度下的效果。但是 GSI 只需要和原本相同的训练开销,就取得了更好的 ID 和 OOD 效果。



总结及未来方向

在本文中,对于预训练语言模型 BERT,我们探究能否同时实现其子网络的稀疏性和鲁棒性。通过在三种自然语言处理任务上进行大量的实验,我们发现在三种常见的微调-剪枝流程下,的确存在稀疏且鲁棒的 BERT 子网络(SRNets)。进一步利用 OOD 训练集,我们发现 BERT 中存在对特定 dataset bias 几乎无偏的子网络。最后,针对掩码训练剪枝方法,我们从开始剪枝的时刻和掩码训练过程中子网络稀疏程度的控制两个角度,对子网络搜索的效率和效果提出了改进的思路。

在我们工作的基础上,仍然有几个方向值得继续改进和探究:

  • 本文只探究了 BERT 类型的 PLM 和自然语言理解任务。在其他类型的 PLM(例如 GPT)和 NLP 任务(例如自然语言生成)中也可能存在 dataset bias 问题。在这些场景下实现 PLM 的压缩和去偏也是很重要的。
  • 本文采用的整个微调-剪枝流程仍然有很大的优化空间。例如 3.6.1 小节中提到的,事先对开始掩码训练的时刻进行精确预测也是一个有意思的研究方向。


参考文献

[1] Z. Li, E. Wallace, S. Shen, K. Lin, K. Keutzer, D. Klein, and J. E. Gonzalez. Train large, then compress: Rethinking model size for efficient training and inference of transformers. In ICML 2020.
[2] P. Michel, O. Levy, and G. Neubig. Are sixteen heads really better than one? In NeurIPS 2019.
[3] Y. Liu, Z. Lin, and F. Yuan. ROSITA: refined BERT compression with integrated techniques. In AAAI 2021.
[4] T. Chen, J. Frankle, S. Chang, S. Liu, Y. Zhang, Z. Wang, and M. Carbin. The lottery ticket hypothesis for pre-trained BERT networks. In NeurIPS 2020.
[5] S. Prasanna, A. Rogers, and A. Rumshisky. When BERT plays the lottery, all tickets are winning. In EMNLP 2020.
[6] Y. Liu, F. Meng, Z. Lin, P. Fu, Y. Cao, W. Wang, and J. Zhou. Learning to win lottery tickets in BERT transfer via task-agnostic mask training. In NAACL 2022.
[7] C. Liang, S. Zuo, M. Chen, H. Jiang, X. Liu, P. He, T. Zhao, and W. Chen. Super tickets in pre-trained language models: From model compression to improving generalization. In ACL 2021.
[8] Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Finding sparse, trainable neural networks. In ICLR 2019.
[9] M. Zhao, T. Lin, F. Mi, M. Jaggi, and H. Schütze. Masking as an efficient alternative to finetuning for pretrained language models. In EMNLP 2020.
[10] S. Han, J. Pool, J. Tran, and W. Dally. Learning both weights and connections for efficient neural network. In NIPS 2015.
[11] A. Mallya, D. Davis, and S. Lazebnik. Piggyback: Adapting a single network to multiple tasks by learning to mask weights. In ECCV 2018.
[12] Y. Bengio, N. Léonard, and A. C. Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. CoRR, abs/1308.3432.
[13] C. Clark, M. Yatskar, and L. Zettlemoyer. Don’t take the easy way out: Ensemble based methods for avoiding known dataset biases. In EMNLP 2019.
[14] T. Schuster, D. J. Shah, Y. J. S. Yeo, D. Filizzola, E. Santus, and R. Barzilay. Towards debiasing fact verification models. In EMNLP 2019.
[15] P. A. Utama, N. S. Moosavi, and I. Gurevych. Mind the trade-off: Debiasing NLU models without degrading the in-distribution performance. In ACL 2020.
[16] M. Zhu and S. Gupta. To prune, or not to prune: Exploring the efficacy of pruning for model compression. In ICLR (Workshop) 2018.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存