On the Promise of the Stochastic Generalized Gauss-Newton Method for Training DNNs

Following early work on Hessian-free methods for deep learning, we study a stochastic generalized Gauss-Newton method (SGN) for training DNNs. SGN is a second-order optimization method with efficient iterations that, as we demonstrate, often requires substantially fewer iterations than standard SGD to converge. As its name suggests, SGN approximates the Hessian matrix with the generalized Gauss-Newton matrix and computes an approximate search direction with the conjugate gradient method, combining forward-mode and reverse-mode automatic differentiation. Despite the success of SGD and its first-order variants, and despite Hessian-free methods based on the Gauss-Newton approximation having long been proposed as practical methods for training DNNs, we argue that SGN has largely untapped potential in large mini-batch scenarios. In this setting, we demonstrate that SGN substantially improves over SGD not only in the number of iterations but also in runtime. This is made possible by an efficient, easy-to-use, and flexible implementation of SGN that we provide for the Theano deep learning platform, which, unlike TensorFlow and PyTorch, supports forward-mode automatic differentiation. Our implementation enables researchers to further study and improve this promising optimization technique, and we hope it encourages the community to reconsider stochastic second-order methods as competitive for training DNNs and to add forward-mode automatic differentiation to TensorFlow and PyTorch. Our results also show that, in large mini-batch scenarios, SGN is more robust than SGD with respect to its hyperparameters (we never had to tune its step size for our benchmarks), easing the expensive hyperparameter tuning that is crucial for the performance of first-order methods.
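
For concreteness, the sketch below illustrates the core computation described above: a generalized Gauss-Newton vector product assembled from one forward-mode and one reverse-mode automatic-differentiation pass, fed to the conjugate gradient method to obtain an approximate search direction on a mini-batch. It is written in JAX rather than Theano and is not the paper's implementation; the names model_fn, loss_fn, damping, and cg_iters are illustrative assumptions.

    import jax
    from jax.scipy.sparse.linalg import cg

    def sgn_direction(params, x, y, model_fn, loss_fn, damping=1e-3, cg_iters=10):
        # Approximately solve (G + damping*I) p = -g for the search direction p,
        # where g is the mini-batch gradient and G = J^T H_L J is the generalized
        # Gauss-Newton matrix (J: Jacobian of the network outputs w.r.t. the
        # parameters, H_L: Hessian of the loss w.r.t. the outputs).
        f = lambda p: model_fn(p, x)           # network outputs on the mini-batch
        loss_of_out = lambda z: loss_fn(z, y)  # scalar loss as a function of outputs

        outputs = f(params)
        grad = jax.grad(lambda p: loss_of_out(f(p)))(params)

        def ggn_vp(v):
            # J v: forward-mode (Jacobian-vector) product through the network.
            _, Jv = jax.jvp(f, (params,), (v,))
            # H_L (J v): Hessian of the loss w.r.t. the (low-dimensional) outputs.
            _, HJv = jax.jvp(jax.grad(loss_of_out), (outputs,), (Jv,))
            # J^T (H_L J v): reverse-mode (vector-Jacobian) product.
            _, vjp_fn = jax.vjp(f, params)
            (JtHJv,) = vjp_fn(HJv)
            # Levenberg-Marquardt-style damping keeps the system well conditioned.
            return jax.tree_util.tree_map(lambda a, b: a + damping * b, JtHJv, v)

        neg_grad = jax.tree_util.tree_map(lambda g_: -g_, grad)
        direction, _ = cg(ggn_vp, neg_grad, maxiter=cg_iters)  # a few CG steps suffice
        return direction

In the paper itself the same matrix-vector products are realized in Theano, whose R-operator (theano.gradient.Rop) supplies the forward-mode pass that TensorFlow and PyTorch did not expose at the time.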
