Variance-reduced Language Pretraining via a Mask Proposal Network

Self-supervised learning, also known as pretraining, is important in natural language processing. Most pretraining methods first randomly mask some positions in a sentence and then train a model to recover the tokens at the masked positions. In this way, the model can be trained without human labeling, and massive corpora can be used to train models with billions of parameters, which makes optimization efficiency critical. In this paper, we tackle the problem from the viewpoint of gradient variance reduction. In particular, we first propose a principled gradient variance decomposition theorem, which shows that the variance of the stochastic gradient in language pretraining naturally decomposes into two terms: the variance that arises from sampling the data in a batch, and the variance that arises from sampling the mask. The second term is the key difference between self-supervised learning and supervised learning, and it is what makes pretraining slower. To reduce this second term, we leverage importance sampling, which samples masks from a proposal distribution instead of the uniform distribution. It can be shown that if the proposal distribution is proportional to the gradient norm, the sampling variance is reduced. To improve efficiency, we introduce a MAsk Proposal Network (MAPNet), which approximates the optimal mask proposal distribution and is trained end-to-end along with the model. Experimental results show that our model converges much faster and achieves higher performance than the baseline BERT model.
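For concreteness, here is a minimal sketch of the decomposition and the importance-sampling argument described above, under standard assumptions. The notation (x for a sentence drawn from the corpus, m for a mask drawn from a finite mask set M, g(x, m) for the resulting stochastic gradient, q(m | x) for the mask proposal distribution) is ours and not taken from the paper.

\documentclass{article}
\usepackage{amsmath}
\usepackage{amssymb}
\begin{document}
% Law of total variance: the total gradient variance splits into a term
% from sampling sentences and a term from sampling masks (illustrative
% notation, not the paper's).
\begin{align*}
\operatorname{Var}_{x,m}\bigl[g(x,m)\bigr]
  = \underbrace{\operatorname{Var}_{x}\bigl[\mathbb{E}_{m}[\,g(x,m)\mid x\,]\bigr]}_{\text{data-sampling variance}}
  + \underbrace{\mathbb{E}_{x}\bigl[\operatorname{Var}_{m}[\,g(x,m)\mid x\,]\bigr]}_{\text{mask-sampling variance}}
\end{align*}
% Importance-weighted estimator: drawing the mask from a proposal q and
% reweighting keeps the gradient estimate unbiased with respect to the
% uniform mask distribution over the finite mask set M.
\begin{align*}
\hat{g}(x,m) = \frac{g(x,m)}{\lvert \mathcal{M} \rvert \, q(m \mid x)},
\qquad m \sim q(\cdot \mid x)
\end{align*}
% Standard importance-sampling result: the mask-sampling variance of this
% estimator is minimized when the proposal is proportional to the gradient norm.
\begin{align*}
q^{*}(m \mid x) \;\propto\; \bigl\lVert g(x,m) \bigr\rVert_{2}
\end{align*}
\end{document}

The first identity is the law of total variance; the last line is the classical importance-sampling result that the variance-minimizing proposal is proportional to the norm of the quantity being averaged, which is the distribution MAPNet is trained to approximate.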
