Article / 文章中心

ICLR2022顶会论文分享-PoNet:使用多粒度Pooling结构替代attention的网络

发布时间:2022-02-18 点击数:659
    
简介: 近年来,在机器学习范畴Transformer模型已成为最先进的(SOTA) 序列建模模型,包含在自然言语处理 (NLP)、核算机视觉、语音处理、基因组数据等都有着广泛的运用。


image.png


近年来,在机器学习范畴Transformer模型已成为最先进的(SOTA) 序列建模模型,包含在自然言语处理 (NLP)、核算机视觉、语音处理、基因组数据等都有着广泛的运用。

Transformer 成功的关键原因在于它的自我留意(self-attention)机制,核算输入表征的每个方位之间的点积。Transformer被证明在学习上下文表征方面非常有用,它成为最主要的骨干模型,例如 BERT和 RoBERTa。这些预练习言语模型展现了强壮的搬迁学习才能,并在广泛NLP使命中完成了 SOTA。

然而,因为Transformer模型中的self-attention机制相对于语句长度的复杂度是二次的(O(N^2)),因而在核算速度和显存空间方面都限制了它在长序列中的运用。咱们提出了一种具有线性复杂度 (O(N)) 的核算模型 PoNet ,运用 pooling 网络代替 self-attention 机制对语句词汇进行混合,从而捕捉上下文信息。

试验表明,PoNet 在长文本测验 Long Range Arena (LRA) 榜[1] 上在准确率上比 Transformer 高 2.28 个点,在GPU上运转速度是Transformer的 9 倍,显存占用只有 1/10。此外,试验也展现了 PoNet 的搬迁学习才能,PoNet-Base 在 GLUE 基准上达到了 BERT-Base 的 95.7% 的准确性。

|| 模型

受到用于视觉使命的外部留意EA[2]的启示,咱们将其简化为 多层感知器  softmax,并观察到 softmax 经过分母项将序列信息融入到 token 中供给了上下文建模才能。然而,softmax 涉及到指数的核算,这仍然是很慢的。因而,咱们考虑运用池化法作为代替办法,以明显下降的复杂度来捕捉语境信息。

模型主要由三个不同粒度的 pooling 组成,一个大局的pooling模块(GA),分段的segment max-pooling模块(SMP),和局部的max-pooling模块(LMP),对应捕捉不同粒度的序列信息:


 在第一阶段,GA沿着序列长度进行均匀得到语句的大局表征g。为了加强对大局信息的捕捉,GA在第二阶段对g和输入练习核算cross-attention。因为g的长度为1,因而总的核算复杂度仍为O(N)。

 SMP按每个分段求取最大值,以捕获中等颗粒度的信息。

 LMP沿着序列长度的方向核算滑动窗口max-pooling。

 然后经过池化交融(PF)将这些池化特征聚合起来。因为GA的特征在整个token序列是同享的,SMP的特征在segment内部也是同享的,直接将这些特征加到原始token上会使得token趋同(向量加法),而这种token表征同质化的影响将会下降比如语句对分类使命的性能。因而,咱们在PF层将原始的token于对应的GA,SMP特征核算元素乘法得到新的特征,使得不同的token对应了不同的特征。

image.png


|| 试验结果

长序列使命

Long Range Arena(LRA) 是用来评价捕捉长距离依赖关系的基准测验。在LRA上,PoNet取得了比Transformer更好的分数。

image.png

在速度和显存方面,仅次于FNet[3],明显优于Transformer。

image.png

搬迁学习

咱们用大规模语料库对PoNet进行预练习,然后测验它在下流使命上的性能。下图是预练习的  MLM[4] 和 SSO[5] 两个子使命的练习曲线,能够看到,咱们的模型在 MLM 上略弱小于 BERT ,在 SSO 上与 BERT 还有必定的差距,两个使命上都明显要优于 FNet 。

image.png

GLUE

PoNet取得了76.80的AVG分数,达到了 BERT 在 GLUE 上的准确率(80.21)的95.7%,相对来说比 FNet 要好4.5%。这些性能比较与图2中显示的预练习准确率一致。

image.png

长文本使命

咱们还评价了预练习的 PoNet 在四个长文本分类数据集上的性能。从表4能够看出,PoNet-Base 在 HND 和 Arxiv 上优于 BERT-Base,在 IMDb 和 Yelp-5 上的F1分数达到了 BERT-Base 的99%。

image.png

融化分析

下面的融化试验也证明晰每个组件的重要性。一起与 L_MN(MLM+NSP),L_OM(MLM) 也说明晰预练习使命运用 MLM+SSO 的必要性。

image.png

|| 总结

咱们提出了一个运用多粒度的 Pooling 结构来代替 attention 的网络(PoNet),它能够捕捉到不同层次的上下文信息,让序列的 token 之间能够得到有用的交互。试验表明,PoNet 既完成了有竞争力的长距离依赖性建模才能,又完成了强壮的搬迁学习才能,而且具有线性的时刻和显存复杂度。


|| Future Work

未来的作业包含进一步优化模型结构和预练习,以及将 PoNet 运用于包含生成使命在内的更广泛的使命。咱们希望PoNet模型能够对探究更高效的序列建模模型供给一些启示。