Transformers learn to implement preconditioned gradient descent for in-context learning

Motivated by the striking ability of transformers to perform in-context learning, several works have demonstrated that transformers can implement algorithms such as gradient descent. Through careful constructions of weights, these works show that multiple layers of transformers are expressive enough to simulate gradient descent iterations. Going beyond the question of expressivity, we ask: can transformers learn to implement such algorithms by training over random problem instances? To our knowledge, we make the first theoretical progress on this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove that the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix adapts not only to the input distribution but also to the variance induced by data inadequacy. For a transformer with $k$ attention layers, we prove that certain critical points of the training objective implement $k$ iterations of preconditioned gradient descent. Our results call for future theoretical studies on learning algorithms by training transformers.
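
The following minimal NumPy sketch illustrates the single-layer claim under the standard construction used in this line of work: with context tokens (x_i, y_i) and a query token (x_query, 0), a linear attention layer whose key-query product encodes a preconditioner Gamma produces the same prediction as one step of preconditioned gradient descent on the in-context least-squares loss, started from w_0 = 0. The choice of Gamma, the token layout, and the hand-set weights here are illustrative assumptions for the sketch, not the learned solution characterized in the paper.

import numpy as np

rng = np.random.default_rng(0)
d, n = 5, 20                      # feature dimension, context length

# Random in-context linear regression instance: y_i = <w*, x_i> + noise
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))       # context inputs
y = X @ w_star + 0.1 * rng.normal(size=n)
x_query = rng.normal(size=d)      # query input whose label we predict

# Hypothetical preconditioner Gamma (any positive-definite matrix works
# for this illustration; the paper characterizes the one that is learned).
A = rng.normal(size=(d, d))
Gamma = A @ A.T / d + np.eye(d)

# --- One step of preconditioned gradient descent from w0 = 0 ---
# L(w) = (1/2n) * sum_i (<w, x_i> - y_i)^2, so grad L(0) = -(1/n) X^T y
w1 = Gamma @ (X.T @ y) / n
pred_gd = x_query @ w1

# --- Single linear-attention layer with hand-constructed weights ---
# Tokens z_i = (x_i, y_i); the query token is (x_query, 0). The key-query
# product W_kq acts only on the x-coordinates through Gamma.
Z = np.concatenate([X, y[:, None]], axis=1)          # (n, d+1) context tokens
z_query = np.concatenate([x_query, [0.0]])           # query token
W_kq = np.block([[Gamma, np.zeros((d, 1))],
                 [np.zeros((1, d)), np.zeros((1, 1))]])
attn_scores = (z_query @ W_kq) @ Z.T / n             # score_i = x_query^T Gamma x_i / n
# The value projection selects each context token's label y_i, so the output
# written into the query's label slot is sum_i score_i * y_i.
pred_attn = attn_scores @ y

print(pred_gd, pred_attn)          # the two predictions coincide
assert np.allclose(pred_gd, pred_attn)

Running the sketch, pred_attn equals pred_gd exactly, since both compute x_query^T Gamma (1/n) X^T y; the paper's single-layer result says that training over random regression instances drives the attention weights to (the optimal choice of) exactly this form.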
