Understanding Straight-Through Estimator in Training Activation Quantized Neural Nets

Training activation quantized neural networks involves minimizing a piecewise constant function whose gradient vanishes almost everywhere, which makes the standard back-propagation (chain rule) inapplicable. An empirical workaround is to use a straight-through estimator (STE) (Bengio et al., 2013) in the backward pass only, so that the "gradient" through the modified chain rule becomes non-trivial. Since this unusual "gradient" is certainly not the gradient of the loss function, the following question arises: why does searching in its negative direction minimize the training loss? In this paper, we provide a theoretical justification of the concept of STE by answering this question. We consider the problem of learning a two-linear-layer network with binarized ReLU activation and Gaussian input data. We refer to the unusual "gradient" given by the STE-modified chain rule as the coarse gradient. The choice of STE is not unique. We prove that if the STE is properly chosen, the expected coarse gradient correlates positively with the population gradient (which is not available for training), and its negation is a descent direction for minimizing the population loss. We further show that the associated coarse gradient descent algorithm converges to a critical point of the population loss minimization problem. Moreover, we show that a poor choice of STE leads to instability of the training algorithm near certain local minima, which is verified with CIFAR-10 experiments.
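
To make the mechanism concrete, the following is a minimal sketch (in PyTorch, which the abstract does not prescribe) of a binarized activation trained with an STE: the forward map is piecewise constant with zero derivative almost everywhere, and the backward pass substitutes a surrogate derivative, here that of a clipped ReLU, which is one possible (not the only) choice of STE. The product returned in the backward pass is what the paper calls the coarse gradient.

```python
import torch

class BinarizedReLU(torch.autograd.Function):
    """Binarized activation with a straight-through estimator (illustrative sketch)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Piecewise constant forward map: its true derivative is zero almost everywhere.
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # STE: replace the vanishing derivative with the derivative of a
        # clipped ReLU, min(max(x, 0), 1), i.e. the indicator of 0 < x < 1.
        # The product below is the "coarse gradient" passed to earlier layers.
        surrogate = ((x > 0) & (x < 1)).to(grad_output.dtype)
        return grad_output * surrogate

# usage: y = BinarizedReLU.apply(x)
```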

[1] A. A. Mullin et al. Principles of neurodynamics, 1962.

[2] Jack Xin et al. Blended coarse gradient descent for full quantization of deep neural networks, 2018, Research in the Mathematical Sciences.

[3] Zhen Li et al. Deep Neural Nets with Interpolating Function as Output Activation, 2018, NeurIPS.

[4] James T. Kwok et al. Loss-aware Weight Quantization of Deep Networks, 2018, ICLR.

[5] S. P. Lloyd et al. Least squares quantization in PCM, 1982, IEEE Trans. Inf. Theory.

[6] Yuandong Tian et al. An Analytical Formula of Population Gradient for two-layered ReLU network and its Applications in Convergence and Critical Point Analysis, 2017, ICML.

[7] Yoshua Bengio et al. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation, 2013, arXiv.

[8] Ran El-Yaniv et al. Quantized Neural Networks: Training Neural Networks with Low Precision Weights and Activations, 2016, J. Mach. Learn. Res.

[9] Shane Legg et al. Human-level control through deep reinforcement learning, 2015, Nature.

[10] Jack Xin et al. Quantization and Training of Low Bit-Width Convolutional Neural Networks for Object Detection, 2016, Journal of Computational Mathematics.

[11] Jian Sun et al. Deep Residual Learning for Image Recognition, 2016, IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

[12] Mahdi Soltanolkotabi et al. Learning ReLUs via Gradient Descent, 2017, NIPS.

[13] Raghu Meka et al. Learning One Convolutional Layer with Overlapping Patches, 2018, ICML.

[14] Jason Weston et al. A unified architecture for natural language processing: deep neural networks with multitask learning, 2008, ICML.

[15] Qianxiao Li et al. An Optimal Control Approach to Deep Learning and Applications to Discrete-Weight Neural Networks, 2018, ICML.

[16] Bin Liu et al. Ternary Weight Networks, 2016, ICASSP 2023.

[17] Geoffrey E. Hinton et al. ImageNet classification with deep convolutional neural networks, 2012, Commun. ACM.

[18] Yoshua Bengio et al. Gradient-based learning applied to document recognition, 1998, Proc. IEEE.

[19] Yoshua Bengio et al. BinaryConnect: Training Deep Neural Networks with binary weights during propagations, 2015, NIPS.

[20] Inderjit S. Dhillon et al. Recovery Guarantees for One-hidden-layer Neural Networks, 2017, ICML.

[21] Shuchang Zhou et al. DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients, 2016, arXiv.

[22] Song Han et al. Trained Ternary Quantization, 2016, ICLR.

[23] Yuanzhi Li et al. Convergence Analysis of Two-layer Neural Networks with ReLU Activation, 2017, NIPS.

[24] Pedro M. Domingos et al. Deep Learning as a Mixed Convex-Combinatorial Optimization Problem, 2017, ICLR.

[25] Ran El-Yaniv et al. Binarized Neural Networks, 2016, NIPS.

[26] Yoav Freund et al. Large Margin Classification Using the Perceptron Algorithm, 1998, COLT.

[27] Juncai He et al. ReLU Deep Neural Networks and Linear Finite Elements, 2020.

[28] Jian Sun et al. Deep Learning with Low Precision by Half-Wave Gaussian Quantization, 2017, IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

[29] David A. Wagner et al. Obfuscated Gradients Give a False Sense of Security: Circumventing Defenses to Adversarial Examples, 2018, ICML.

[30] Andrew Zisserman et al. Very Deep Convolutional Networks for Large-Scale Image Recognition, 2014, ICLR.

[31] Yuandong Tian et al. Gradient Descent Learns One-hidden-layer CNN: Don't be Afraid of Spurious Local Minima, 2017, ICML.

[32] Hanan Samet et al. Training Quantized Nets: A Deeper Understanding, 2017, NIPS.

[33] Joan Bruna et al. Intriguing properties of neural networks, 2013, ICLR.

[34] Lin Xu et al. Incremental Network Quantization: Towards Lossless CNNs with Low-Precision Weights, 2017, ICLR.

[35] Alex Krizhevsky et al. Learning Multiple Layers of Features from Tiny Images, 2009.

[36] Bernard Widrow et al. 30 years of adaptive neural networks: perceptron, Madaline, and backpropagation, 1990, Proc. IEEE.

[37] Yoshua Bengio et al. Difference Target Propagation, 2014, ECML/PKDD.

[38] Demis Hassabis et al. Mastering the game of Go with deep neural networks and tree search, 2016, Nature.

[39] Swagath Venkataramani et al. PACT: Parameterized Clipping Activation for Quantized Neural Networks, 2018, arXiv.

[40] Ali Farhadi et al. XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks, 2016, ECCV.

[41] Jack Xin et al. BinaryRelax: A Relaxation Approach For Training Deep Neural Networks With Quantized Weights, 2018, SIAM J. Imaging Sci.

[42] Sergey Ioffe et al. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, 2015, ICML.

[43] Kaiming He et al. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks, 2015, IEEE Transactions on Pattern Analysis and Machine Intelligence.

[44] Amir Globerson et al. Globally Optimal Gradient Descent for a ConvNet with Gaussian Inputs, 2017, ICML.