DisARM: An Antithetic Gradient Estimator for Binary Latent Variables

Training models with discrete latent variables is challenging due to the difficulty of estimating the gradients accurately. Much of the recent progress has been achieved by taking advantage of continuous relaxations of the discrete variables, which are not always available or even possible. The Augment-REINFORCE-Merge (ARM) estimator provides an alternative that, instead of relaxation, uses continuous augmentation. Applying antithetic sampling over the augmenting variables yields a relatively low-variance and unbiased estimator applicable to any model with binary latent variables. However, while antithetic sampling reduces variance, the augmentation process itself adds variance. We show that ARM can be improved by analytically integrating out the randomness introduced by the augmentation, guaranteeing substantial variance reduction. Our estimator, \emph{DisARM}, is simple to implement and has the same computational cost as ARM. We evaluate DisARM on several generative modeling benchmarks and show that it consistently outperforms ARM and a strong independent-sample baseline in terms of both variance and log-likelihood. Furthermore, we propose a local version of DisARM designed for optimizing the multi-sample variational bound, and show that it outperforms VIMCO, the current state-of-the-art method.
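
To make the construction concrete, the following is a minimal NumPy sketch of a single-sample DisARM gradient with respect to the Bernoulli logits, consistent with the description above: an antithetic pair of binary samples is drawn from a shared uniform, and the uniform is then integrated out analytically, leaving a closed-form weight. The function names and vectorized interface here are illustrative, not the authors' reference implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def disarm_gradient(f, logits, rng):
    """One-sample DisARM estimate of d/d(logits) E_{b ~ Bernoulli(sigmoid(logits))}[f(b)].

    f      : callable mapping a 0/1 vector to a scalar objective
    logits : per-coordinate Bernoulli logits, shape (d,)
    rng    : np.random.Generator
    """
    u = rng.uniform(size=logits.shape)
    b = (u < sigmoid(logits)).astype(float)              # primary sample
    b_anti = (1.0 - u < sigmoid(logits)).astype(float)   # antithetic sample
    # DisARM weight: the ARM factor (u - 1/2) with u analytically
    # integrated out given the pair (b, b_anti). It is nonzero only
    # at coordinates where the two samples disagree.
    weight = 0.5 * (-1.0) ** b_anti * (b != b_anti) * sigmoid(np.abs(logits))
    return (f(b) - f(b_anti)) * weight
```

As with ARM, the estimate costs two evaluations of f; the estimators differ only in the weight, which DisARM computes in closed form from the sampled pair rather than from the underlying uniform, eliminating the variance contributed by the augmentation.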
