The Natural Neural Tangent Kernel: Neural Network Training Dynamics under Natural Gradient Descent

Gradient-based optimization methods have proven successful in learning complex, overparameterized neural networks from non-convex objectives. Yet the precise theoretical relationship between gradient-based optimization, the induced training dynamics, and generalization in deep neural networks remains unclear. In this work, we investigate the training dynamics of overparameterized neural networks under natural gradient descent. Taking a function-space view, we give an exact analytic solution for the dynamics on the training points. We derive a bound on the discrepancy between the distribution over functions at the global optimum of natural gradient descent and that given by the analytic solution of the dynamics linearized around the parameters at initialization, and we validate this bound empirically. In particular, we show that the discrepancy between the functions obtained from linearized and non-linearized natural gradient descent is provably smaller than under standard gradient descent, and we demonstrate empirically that this discrepancy is small for overparameterized networks without resorting to a limit argument on the width of the network layers, as was done in previous work. Finally, we confirm on a set of regression benchmark problems that the empirical discrepancy between linearized and non-linearized natural gradient descent is small and consistent with our theoretical results.
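
For readers unfamiliar with the objects named in the abstract, the following is a minimal sketch of the standard definitions it builds on (the notation is chosen here for illustration and is not taken from the paper): the linearization of a network around its parameters at initialization, the natural gradient update preconditioned by the Fisher information matrix, and the neural tangent kernel at initialization.

```latex
% First-order linearization of a scalar-output network f(x; \theta) around initialization \theta_0:
f_{\mathrm{lin}}(x;\theta) = f(x;\theta_0) + \nabla_\theta f(x;\theta_0)^{\top}(\theta - \theta_0)

% Natural gradient descent update with learning rate \eta and Fisher information matrix F:
\theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla_\theta \mathcal{L}(\theta_t),
\qquad
F(\theta) = \mathbb{E}_{x,\, y \sim p(y \mid x;\theta)}\!\left[ \nabla_\theta \log p(y \mid x;\theta)\, \nabla_\theta \log p(y \mid x;\theta)^{\top} \right]

% Neural tangent kernel at initialization (scalar-output case):
\Theta_0(x, x') = \nabla_\theta f(x;\theta_0)^{\top}\, \nabla_\theta f(x';\theta_0)
```

Under a linearized model, the function-space training dynamics are governed by the kernel induced by the Jacobian at initialization; the abstract's results concern how closely the full (non-linearized) natural gradient dynamics track this linearized solution.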
