Reducing BERT Pre-Training Time from 3 Days to 76 Minutes

Large-batch training is key to speeding up deep neural network training in large distributed systems. However, large-batch training is difficult because it produces a generalization gap: straightforward optimization often leads to accuracy loss on the test set. BERT \cite{devlin2018bert} is a state-of-the-art deep learning model built on deep bidirectional transformers for language understanding. Previous large-batch training techniques do not perform well for BERT when we scale the batch size (e.g., beyond 8192). BERT pre-training also takes a long time to finish (around three days on 16 TPUv3 chips). To solve this problem, we propose the LAMB optimizer, which lets us scale the batch size to 65536 without losing accuracy. LAMB is a general optimizer that works for both small and large batch sizes and needs no hyper-parameter tuning other than the learning rate. The baseline BERT-Large model needs 1 million iterations to finish pre-training, while LAMB with batch sizes of 65536/32768 needs only 8599 iterations. We push the batch size to the memory limit of a TPUv3 Pod and finish BERT training in 76 minutes.
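
As a rough illustration of the layerwise adaptation that LAMB performs, the sketch below applies an Adam-style update to a single layer's weights and then rescales the step by a per-layer trust ratio ||w|| / ||update||. This is a minimal sketch based on the update rule described above, not the reference implementation; the hyper-parameter defaults, the identity choice for the weight-norm scaling function, and the handling of zero norms are assumptions made here for illustration.

```python
import numpy as np

def lamb_update(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
                eps=1e-6, weight_decay=0.01):
    """One LAMB step for a single layer's weight tensor (illustrative sketch).

    Assumptions: bias-corrected Adam moments, decoupled weight decay,
    identity scaling of the weight norm, and trust_ratio = 1.0 when a
    norm is zero. Hyper-parameter defaults are only for illustration.
    """
    # Adam-style first and second moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)

    # Adam direction plus decoupled weight decay.
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w

    # Layerwise trust ratio: scale the step by ||w|| / ||update||.
    w_norm = np.linalg.norm(w)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0

    new_w = w - lr * trust_ratio * update
    return new_w, m, v

# Example: one step on a randomly initialized layer (t starts at 1).
w = np.random.randn(1024, 1024) * 0.02
m = np.zeros_like(w)
v = np.zeros_like(w)
w, m, v = lamb_update(w, np.random.randn(*w.shape), m, v, t=1)
```

The per-layer rescaling is what allows the learning rate to be raised aggressively for very large batches: a layer whose raw Adam update is large relative to its weight norm takes a proportionally smaller step, so no single layer diverges.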

References

[1] Stephen J. Wright et al. Hogwild: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent, 2011, NIPS.

[2] Marc'Aurelio Ranzato et al. Large Scale Distributed Deep Networks, 2012, NIPS.

[3] Yoshua Bengio et al. Practical Recommendations for Gradient-Based Training of Deep Architectures, 2012, Neural Networks: Tricks of the Trade.

[4] Geoffrey E. Hinton et al. ImageNet classification with deep convolutional neural networks, 2012, Commun. ACM.

[5] Saeed Ghadimi et al. Stochastic First- and Zeroth-Order Methods for Nonconvex Stochastic Programming, 2013, SIAM J. Optim.

[6] Alex Krizhevsky et al. One weird trick for parallelizing convolutional neural networks, 2014, ArXiv.

[7] Sanja Fidler et al. Aligning Books and Movies: Towards Story-Like Visual Explanations by Watching Movies and Reading Books, 2015, ICCV.

[8] Jimmy Ba et al. Adam: A Method for Stochastic Optimization, 2014, ICLR.

[9] Roger B. Grosse et al. Optimizing Neural Networks with Kronecker-factored Approximate Curvature, 2015, ICML.

[10] Mu Li. Scaling Distributed Machine Learning with System and Algorithm Co-design (thesis proposal), 2016.

[11] Jian Sun et al. Deep Residual Learning for Image Recognition, 2015, CVPR.

[12] Saeed Ghadimi et al. Mini-batch stochastic approximation methods for nonconvex stochastic composite optimization, 2013, Mathematical Programming.

[13] Forrest N. Iandola et al. FireCaffe: Near-Linear Acceleration of Deep Neural Network Training on Compute Clusters, 2015, CVPR.

[14] Kaiming He et al. Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, 2017, ArXiv.

[15] Yang You et al. Large Batch Training of Convolutional Networks, 2017, ArXiv (1708.03888).

[16] David A. Patterson et al. In-datacenter performance analysis of a tensor processing unit, 2017, ISCA.

[17] Frank Hutter et al. Fixing Weight Decay Regularization in Adam, 2017, ArXiv.

[18] Takuya Akiba et al. Extremely Large Minibatch SGD: Training ResNet-50 on ImageNet in 15 Minutes, 2017, ArXiv.

[19] Elad Hoffer et al. Train longer, generalize better: closing the generalization gap in large batch training of neural networks, 2017, NIPS.

[20] Jorge Nocedal et al. On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima, 2016, ICLR.

[21] Vikram A. Saletore et al. Scale out for large minibatch SGD: Residual network training on ImageNet-1K with improved accuracy and reduced time to train, 2017, ArXiv.

[22] Yang You et al. Scaling SGD Batch Size to 32K for ImageNet Training, 2017, ArXiv.

[23] Michael Garland et al. AdaBatch: Adaptive Batch Sizes for Training Deep Neural Networks, 2017, ArXiv.

[24] Satoshi Matsuoka et al. Second-order Optimization Method for Large Mini-batch: Training ResNet-50 on ImageNet in 35 Epochs, 2018, ArXiv.

[25] James Demmel et al. ImageNet Training in Minutes, 2017, ICPP.

[26] Pongsakorn U.-Chupala et al. ImageNet/ResNet-50 Training in 224 Seconds, 2018, ArXiv.

[27] Hiroaki Mikami et al. Massively Distributed SGD: ImageNet/ResNet-50 Training in a Flash, 2018.

[28] Quoc V. Le et al. Don't Decay the Learning Rate, Increase the Batch Size, 2017, ICLR.

[29] Kamyar Azizzadenesheli et al. signSGD: compressed optimisation for non-convex problems, 2018, ICML.

[30] Yuanzhou Yang et al. Highly Scalable Deep Learning Training System with Mixed-Precision: Training ImageNet in Four Minutes, 2018, ArXiv.

[31] Tao Wang et al. Image Classification at Supercomputer Scale, 2018, ArXiv.

[32] James Demmel et al. Large-batch training for LSTM and beyond, 2019, SC.

[33] Masafumi Yamazaki et al. Yet Another Accelerated SGD: ResNet-50 Training on ImageNet in 74.7 seconds, 2019, ArXiv.

[34] Jascha Sohl-Dickstein et al. Measuring the Effects of Data Parallelism on Neural Network Training, 2018, J. Mach. Learn. Res.

[35] Frank Hutter et al. Decoupled Weight Decay Regularization, 2017, ICLR.

[36] Ming-Wei Chang et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, 2019, NAACL.