Signal Propagation in Transformers: Theoretical Perspectives and the Role of Rank Collapse

Transformers have achieved remarkable success in several domains, ranging from natural language processing to computer vision. Nevertheless, it has recently been shown that stacking self-attention layers, the distinctive architectural component of Transformers, can result in rank collapse of the tokens’ representations at initialization. Whether and how rank collapse affects training is still largely an open question, and its investigation is necessary for a more comprehensive understanding of this architecture. In this work, we shed new light on the causes and effects of this phenomenon. First, we show that rank collapse of the tokens’ representations hinders training by causing the gradients of the queries and keys to vanish at initialization. Furthermore, we provide a thorough description of the origin of rank collapse and discuss how to prevent it via an appropriate depth-dependent scaling of the residual branches. Finally, our analysis reveals that specific architectural hyperparameters affect the gradients of queries and values differently, leading to disproportionate gradient norms. This suggests an explanation for the widespread use of adaptive methods in Transformer optimization.
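
To make the rank-collapse phenomenon described above concrete, the following is a minimal NumPy sketch (not the paper's code): it stacks randomly initialized, attention-only layers, without residual connections or MLP blocks, and tracks the relative Frobenius distance of the token representations to their best rank-one approximation as a simple proxy for rank collapse. The sequence length, width, depth, and Gaussian initialization scale below are illustrative assumptions.

```python
# Minimal sketch (assumed hyperparameters, not the paper's code): rank collapse
# of token representations under stacked, randomly initialized self-attention
# layers without residual connections or MLP blocks.
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, depth = 32, 64, 24  # illustrative sequence length, width, depth


def self_attention(X, rng, d_model):
    """One head of softmax self-attention with Gaussian-initialized weights."""
    std = 1.0 / np.sqrt(d_model)
    W_q, W_k, W_v = (rng.normal(0.0, std, (d_model, d_model)) for _ in range(3))
    scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d_model)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)              # row-stochastic attention matrix
    return A @ (X @ W_v)


def rank_one_residual(X):
    """Relative Frobenius distance of X to its best rank-one approximation."""
    s = np.linalg.svd(X, compute_uv=False)
    return np.sqrt(np.sum(s[1:] ** 2)) / np.sqrt(np.sum(s ** 2))


X = rng.normal(size=(n_tokens, d_model))            # token representations at initialization
for layer in range(1, depth + 1):
    X = self_attention(X, rng, d_model)             # attention-only network
    if layer % 4 == 0:
        print(f"layer {layer:2d}: rank-one residual = {rank_one_residual(X):.3e}")
```

Under a setup of this kind, the printed residual shrinks rapidly with depth, meaning the token representations approach a rank-one matrix in which all tokens are nearly identical. Reintroducing residual branches whose contribution is scaled down as a function of depth, as discussed in the abstract, is one way to counteract this collapse; the precise scaling and its analysis are contributions of the paper and are not reproduced in this sketch.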
