Scan and Snap: Understanding Training Dynamics and Token Composition in 1-layer Transformer

The Transformer architecture has shown impressive performance in multiple research domains and has become the backbone of many neural network models. However, there is limited understanding of how it works. In particular, with a simple predictive loss, how the representation emerges from the gradient \emph{training dynamics} remains a mystery. In this paper, for a 1-layer transformer with one self-attention layer plus one decoder layer, we analyze its SGD training dynamics for the task of next-token prediction in a mathematically rigorous manner. We open the black box of the dynamic process of how the self-attention layer combines input tokens, and reveal the nature of the underlying inductive bias. More specifically, under the assumptions of (a) no positional encoding, (b) long input sequences, and (c) a decoder layer that learns faster than the self-attention layer, we prove that self-attention acts as a \emph{discriminative scanning algorithm}: starting from uniform attention, it gradually attends more to distinct key tokens for a specific next token to be predicted, and pays less attention to common key tokens that occur across different next tokens. Among distinct tokens, it progressively drops attention weights, following the order of low to high co-occurrence between the key and the query token in the training set. Interestingly, this procedure does not lead to winner-takes-all, but decelerates due to a \emph{phase transition} that is controllable by the learning rates of the two layers, leaving an (almost) fixed token combination. We verify this \textbf{\emph{scan and snap}} dynamics on synthetic and real-world data (WikiText).
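To make the analyzed setup concrete, the following is a minimal sketch (not the paper's code) of such a 1-layer model in PyTorch: one self-attention layer combining input tokens without positional encoding, followed by a decoder layer, trained by SGD on next-token prediction with separate learning rates for the two layers. The vocabulary size, dimensions, and learning rates are illustrative placeholders, not values from the paper.

```python
# Minimal sketch, assuming a PyTorch implementation of the 1-layer setup:
# one self-attention layer plus one decoder layer, no positional encoding,
# trained with SGD on next-token prediction. All hyperparameters are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

class OneLayerTransformer(nn.Module):
    def __init__(self, vocab_size=64, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)   # token embeddings only (no positional encoding)
        self.attn = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)  # decoder layer

    def forward(self, tokens):
        x = self.embed(tokens)                # (batch, seq, d_model)
        q = x[:, -1:, :]                      # the last (query) token predicts the next token
        ctx, attn_weights = self.attn(q, x, x)  # self-attention combines the input (key) tokens
        logits = self.decoder(ctx.squeeze(1)) # (batch, vocab)
        return logits, attn_weights

model = OneLayerTransformer()

# Separate learning rates for the two layers; the phase transition described in the
# abstract is controlled by the ratio of these rates (values here are illustrative).
opt = torch.optim.SGD([
    {"params": model.attn.parameters(), "lr": 1e-3},
    {"params": list(model.embed.parameters()) + list(model.decoder.parameters()), "lr": 1e-2},
])

# Toy batch: long input sequences and their next-token targets.
tokens = torch.randint(0, 64, (8, 16))
next_tokens = torch.randint(0, 64, (8,))

logits, attn = model(tokens)
loss = F.cross_entropy(logits, next_tokens)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item(), attn.shape)  # attention over key tokens: (8, 1, 16)
```

Tracking `attn_weights` over training steps is how one would observe the scanning behavior: attention starting uniform, concentrating on distinct key tokens, and then (nearly) freezing after the phase transition.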
