Coupled Gradient Estimators for Discrete Latent Variables

Training models with discrete latent variables is challenging due to the high variance of unbiased gradient estimators. While low-variance reparameterization gradients of a continuous relaxation can provide an effective solution, a continuous relaxation is not always available or tractable. Dong et al. (2020) and Yin et al. (2020) introduced a performant estimator that does not rely on continuous relaxations; however, it is limited to binary random variables. We introduce a novel derivation of their estimator based on importance sampling and statistical couplings, which we extend to the categorical setting. Motivated by the construction of a stick-breaking coupling, we introduce gradient estimators based on reparameterizing categorical variables as sequences of binary variables and on Rao-Blackwellization. In systematic experiments, we show that our proposed categorical gradient estimators provide state-of-the-art performance, whereas previous estimators (Yin et al., 2019), even with additional Rao-Blackwellization, underperform a simpler REINFORCE estimator with a leave-one-out baseline (Kool et al., 2019).
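Two ingredients named in the abstract can be made concrete in a few lines. The first is rewriting a categorical variable as an ordered sequence of binary "stick-breaking" decisions. The second is the REINFORCE leave-one-out (RLOO) baseline of Kool et al. (2019), which the abstract uses as a point of comparison. The following is a minimal sketch in plain Python, assuming a softmax-parameterized categorical and a fixed category ordering for the stick-breaking. All function names are hypothetical. This is not the paper's coupled estimator, only the building blocks it refers to.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stick_breaking_conditionals(probs):
    """Rewrite Cat(probs) as K - 1 ordered Bernoulli decisions (assumed index order).

    b[k] = P(z = k | z >= k) = probs[k] / (1 - sum(probs[:k])).
    The last category is reached when every coin comes up tails.
    """
    conds = []
    remaining = 1.0
    for p in probs[:-1]:
        # Clamp to 1 to absorb floating-point drift in `remaining`.
        conds.append(min(1.0, p / remaining) if remaining > 0 else 0.0)
        remaining -= p
    return conds


def probs_from_conditionals(conds):
    """Invert the stick-breaking map: p[k] = b[k] * prod_{j<k} (1 - b[j])."""
    probs = []
    stick = 1.0
    for b in conds:
        probs.append(b * stick)
        stick *= 1.0 - b
    probs.append(stick)  # mass left over for the final category
    return probs


def sample_stick_breaking(conds, rng):
    """Sample z by flipping Bernoulli(b[k]) coins in order; stop at the first heads."""
    for k, b in enumerate(conds):
        if rng.random() < b:
            return k
    return len(conds)


def rloo_gradient(logits, f, num_samples, rng):
    """REINFORCE with a leave-one-out baseline for
    d/d logits of E_{z ~ Cat(softmax(logits))}[f(z)].

    Each sample's baseline is the mean of f over the other samples. Because it
    does not depend on that sample, the estimator stays unbiased.
    """
    if num_samples < 2:
        raise ValueError("a leave-one-out baseline needs at least 2 samples")
    probs = softmax(logits)
    num_cats = len(probs)
    zs = rng.choices(range(num_cats), weights=probs, k=num_samples)
    fs = [f(z) for z in zs]
    total = sum(fs)
    grad = [0.0] * num_cats
    for z, fz in zip(zs, fs):
        baseline = (total - fz) / (num_samples - 1)
        for j in range(num_cats):
            # Score function of a softmax categorical: onehot(z) - probs.
            score = (1.0 if j == z else 0.0) - probs[j]
            grad[j] += (fz - baseline) * score / num_samples
    return grad


def exact_gradient(logits, f):
    """Closed form for comparison: dE[f]/dlogits_j = p_j * (f(j) - E[f])."""
    probs = softmax(logits)
    mean_f = sum(p * f(k) for k, p in enumerate(probs))
    return [p * (f(k) - mean_f) for k, p in enumerate(probs)]


def objective(z):
    """Arbitrary test objective chosen for illustration."""
    return (z - 1.7) ** 2


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [0.5, -1.0, 0.2, 1.5]
    reps = 20000
    avg = [0.0] * len(logits)
    for _ in range(reps):
        g = rloo_gradient(logits, objective, num_samples=4, rng=rng)
        avg = [a + gi / reps for a, gi in zip(avg, g)]
    print("RLOO average:", [round(x, 3) for x in avg])
    print("exact:       ", [round(x, 3) for x in exact_gradient(logits, objective)])
```

Averaged over many repetitions, the RLOO estimate should approach the closed-form gradient, which illustrates its unbiasedness. The stick-breaking helpers only show that K - 1 ordered binary decisions reproduce the categorical distribution. That is the reparameterization the abstract builds on; the coupling between binary variables that the estimators exploit is not modeled here.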

[1] Michael I. Jordan, et al. Tree-Structured Stick Breaking for Hierarchical Data, 2010, NIPS.

[2] Yoshua Bengio, et al. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation, 2013, ArXiv.

[3] Sergey Levine, et al. MuProp: Unbiased Backpropagation for Stochastic Neural Networks, 2015, ICLR.

[4] Mingyuan Zhou, et al. ARM: Augment-REINFORCE-Merge Gradient for Stochastic Binary Networks, 2018, ICLR.

[5] Mingyuan Zhou, et al. ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variables, 2019, ICML.

[6] David Duvenaud, et al. Backpropagation through the Void: Optimizing control variates for black-box gradient estimation, 2017, ICLR.

[7] Sean Gerrish, et al. Black Box Variational Inference, 2013, AISTATS.

[8] Yoshua Bengio, et al. Reweighted Wake-Sleep, 2014, ICLR.

[9] Ronald J. Williams. Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning, 1992, Machine Learning.

[10] Joshua B. Tenenbaum, et al. Human-level concept learning through probabilistic program induction, 2015, Science.

[11] Peter W. Glynn. Likelihood ratio gradient estimation for stochastic systems, 1990, CACM.

[12] Karol Gregor, et al. Neural Variational Inference and Learning in Belief Networks, 2014, ICML.

[13] Paul Glasserman. Monte Carlo Methods in Financial Engineering, 2003.

[14] Zhijian Ou, et al. Joint Stochastic Approximation and Its Application to Learning Discrete Latent Variable Models, 2020, UAI.

[15] Miguel Lázaro-Gredilla, et al. Local Expectation Gradients for Black Box Variational Inference, 2015, NIPS.

[16] Michael I. Jordan, et al. Rao-Blackwellized Stochastic Gradients for Discrete Distributions, 2018, ICML.

[17] Michael I. Jordan, et al. An Introduction to Variational Methods for Graphical Models, 1999, Machine Learning.

[18] Chen Liang, et al. Carbon Emissions and Large Neural Network Training, 2021, ArXiv.

[19] Emily M. Bender, et al. On the Dangers of Stochastic Parrots: Can Language Models Be Too Big? 🦜, 2021, FAccT.

[20] Max Welling, et al. Estimating Gradients for Discrete Random Variables by Sampling without Replacement, 2020, ICLR.

[21] Jascha Sohl-Dickstein, et al. REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models, 2017, NIPS.

[22] Ömer Deniz Akyildiz, et al. VarGrad: A Low-Variance Gradient Estimator for Variational Inference, 2020, NeurIPS.

[23] Ben Poole, et al. Categorical Reparameterization with Gumbel-Softmax, 2016, ICLR.

[24] Alek Dimitriev, et al. ARMS: Antithetic-REINFORCE-Multi-Sample Gradient for Binary Variables, 2021, ICML.

[25] Andreas Krause, et al. Rao-Blackwellizing the Straight-Through Gumbel-Softmax Gradient Estimator, 2020, ICLR.

[26] Yee Whye Teh, et al. The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables, 2016, ICLR.

[27] David M. Blei, et al. Overdispersed Black-Box Variational Inference, 2016, UAI.

[28] Roland Vollgraf, et al. Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms, 2017, ArXiv.

[29] Max Welling, et al. Buy 4 REINFORCE Samples, Get a Baseline for Free!, 2019, DeepRLStructPred@ICLR.

[30] R. Rubinstein, et al. Optimization of static simulation models by the score function method, 1990.

[31] Daan Wierstra, et al. Stochastic Backpropagation and Approximate Inference in Deep Generative Models, 2014, ICML.

[32] George Tucker, et al. DisARM: An Antithetic Gradient Estimator for Binary Latent Variables, 2020, NeurIPS.

[33] Max Welling, et al. Auto-Encoding Variational Bayes, 2013, ICLR.

[34] Nhat Ho, et al. Probabilistic Best Subset Selection by Gradient-Based Optimization, 2020, ArXiv:2006.06448.

[35] Mohammad Emtiyaz Khan, et al. A Stick-Breaking Likelihood for Categorical Data Analysis with Latent Gaussian Models, 2012, AISTATS.

[36] Tianqi Chen, et al. Empirical Evaluation of Rectified Activations in Convolutional Network, 2015, ArXiv.