Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes

Training large deep neural networks on massive datasets is computationally very challenging. There has been a recent surge of interest in using large-batch stochastic optimization methods to tackle this issue. The most prominent algorithm in this line of research is LARS, which, by employing layerwise adaptive learning rates, trains ResNet on ImageNet in a few minutes. However, LARS performs poorly for attention models like BERT, indicating that its performance gains are not consistent across tasks. In this paper, we first study a principled layerwise adaptation strategy to accelerate training of deep neural networks using large mini-batches. Using this strategy, we develop a new layerwise adaptive large-batch optimization technique called LAMB; we then provide convergence analysis of LAMB as well as LARS, showing convergence to a stationary point in general nonconvex settings. Our empirical results demonstrate the superior performance of LAMB across various tasks such as BERT and ResNet-50 training with very little hyperparameter tuning. In particular, for BERT training, our optimizer enables the use of very large batch sizes of 32868 without any degradation of performance. By increasing the batch size to the memory limit of a TPUv3 Pod, BERT training time can be reduced from 3 days to just 76 minutes (Table 1). The LAMB implementation is available online.
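The layerwise adaptation at the heart of LARS and LAMB amounts to rescaling each layer's update by the ratio of the parameter norm to the update norm, so every layer effectively receives its own learning rate. Below is a minimal NumPy sketch of a LAMB-style step for a single layer, assuming an Adam-style base update with decoupled weight decay; the exact bias correction, scaling function, and hyperparameter defaults are those of the paper's algorithm and may differ from this illustration.

```python
import numpy as np

def lamb_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-6, weight_decay=0.01):
    """Hypothetical LAMB-style update for one layer's parameter tensor."""
    # Adam-style first and second moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)

    # Base update direction, with decoupled weight decay added on top.
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param

    # Layerwise trust ratio: ||w|| / ||update||, guarded against zero norms.
    w_norm = np.linalg.norm(param)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0

    # The trust ratio acts as a per-layer learning-rate multiplier.
    param = param - lr * trust_ratio * update
    return param, m, v
```

The trust-ratio rescaling is the only departure from a plain Adam step here; replacing the Adam-style direction with SGD momentum would give a LARS-style update instead.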
