Parallelizing Linear Recurrent Neural Nets Over Sequence Length

Recurrent neural networks (RNNs) are widely used to model sequential data, but their non-linear dependencies between sequence elements prevent parallelizing training over the sequence length. We show that the training of RNNs with only linear sequential dependencies can be parallelized over the sequence length using the parallel scan algorithm, leading to rapid training on long sequences even with small minibatch sizes. We develop a parallel linear recurrence CUDA kernel and show that it can be applied to immediately speed up training and inference of several state-of-the-art RNN architectures by up to 9x. We abstract recent work on linear RNNs into a new framework of linear surrogate RNNs and develop a linear surrogate model for the long short-term memory unit, the GILR-LSTM, that utilizes parallel linear recurrence. We extend sequence learning to new extremely long sequence regimes that were previously out of reach by successfully training a GILR-LSTM on a synthetic sequence classification task with a one-million-timestep dependency.
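To illustrate the core idea, the sketch below evaluates a linear recurrence h_t = a_t * h_{t-1} + b_t over the whole sequence with an associative (parallel) scan. This is only a minimal illustration of the technique, not the paper's CUDA kernel: it uses JAX's jax.lax.associative_scan as a stand-in for a hand-written parallel scan, and the function and variable names here are illustrative rather than taken from the paper.

```python
# Minimal sketch: parallelize the linear recurrence h_t = a_t * h_{t-1} + b_t
# over sequence length via an associative scan (illustrative, not the paper's kernel).
import jax
import jax.numpy as jnp

def parallel_linear_recurrence(a, b):
    """Compute h_t = a_t * h_{t-1} + b_t for t = 1..T with h_0 = 0.

    a, b: arrays of shape (T, hidden_dim); the recurrence is elementwise.
    """
    def combine(left, right):
        # Composing two affine maps h -> a*h + b is again affine,
        # which is what makes this binary operator associative.
        a_l, b_l = left
        a_r, b_r = right
        return a_r * a_l, a_r * b_l + b_r

    # The second component of the running composition equals h_t when h_0 = 0.
    _, h = jax.lax.associative_scan(combine, (a, b))
    return h

# Usage: a toy sequence of length 8 with a 4-dimensional hidden state.
T, d = 8, 4
a = jax.random.uniform(jax.random.PRNGKey(0), (T, d))
b = jax.random.normal(jax.random.PRNGKey(1), (T, d))
h = parallel_linear_recurrence(a, b)

# Check against the plain sequential recurrence.
h_seq = jnp.zeros(d)
for t in range(T):
    h_seq = a[t] * h_seq + b[t]
assert jnp.allclose(h[-1], h_seq, atol=1e-5)
```

Because the combine operator is associative, the scan can be evaluated in O(log T) parallel depth instead of T sequential steps, which is the property the paper's parallel linear recurrence kernel exploits.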
