Joint Stochastic Approximation and Its Application to Learning Discrete Latent Variable Models

Despite progress in introducing auxiliary amortized inference models, learning discrete latent variable models remains challenging. In this paper, we show that the difficulty of obtaining reliable stochastic gradients for the inference model, together with the drawback of indirectly optimizing the target log-likelihood, can be gracefully addressed by a new method based on stochastic approximation (SA) theory of the Robbins-Monro type. Specifically, we propose to directly maximize the target log-likelihood and simultaneously minimize the inclusive divergence between the posterior and the inference model. The resulting learning algorithm is called joint SA (JSA). To the best of our knowledge, JSA represents the first method that couples an SA version of the expectation-maximization (EM) algorithm (SAEM) with an adaptive MCMC procedure. Experiments on several benchmark generative modeling and structured prediction tasks show that JSA consistently outperforms recent competitive algorithms, with faster convergence, better final likelihoods, and lower variance of gradient estimates.
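
The abstract describes the algorithm only at a high level, so the following is a minimal, hypothetical sketch of what a single JSA-style update could look like on a toy model with one binary latent variable per datum. It assumes, as one plausible reading of the "adaptive MCMC procedure", that the inference model q_phi(z|x) serves as the proposal of a Metropolis independence sampler targeting the posterior p_theta(z|x), and that the accepted samples then drive stochastic-approximation gradient steps ascending log p_theta(x,z) (maximum likelihood) and log q_phi(z|x) (inclusive-divergence minimization). All names here (jsa_step, z_cache, the Bernoulli-Gaussian toy model) are illustrative, not taken from the paper.

```python
import torch

def log_p_joint(theta, x, z):
    # log p_theta(z) + log p_theta(x | z): Bernoulli prior over z,
    # unit-variance Gaussian likelihood whose mean depends on z.
    prior = torch.distributions.Bernoulli(logits=theta["prior_logit"])
    mean = torch.where(z.bool(), theta["mu1"], theta["mu0"])
    return prior.log_prob(z) + torch.distributions.Normal(mean, 1.0).log_prob(x)

def log_q(phi, x, z):
    # Amortized inference model q_phi(z = 1 | x) = sigmoid(w * x + c).
    return torch.distributions.Bernoulli(logits=phi["w"] * x + phi["c"]).log_prob(z)

def jsa_step(theta, phi, x, z_cache, opt):
    # 1) One Metropolis independence-sampler move per datum, with q_phi as the
    #    proposal, so z_cache stays approximately posterior-distributed.
    with torch.no_grad():
        z_prop = torch.distributions.Bernoulli(logits=phi["w"] * x + phi["c"]).sample()
        log_a = (log_p_joint(theta, x, z_prop) + log_q(phi, x, z_cache)
                 - log_p_joint(theta, x, z_cache) - log_q(phi, x, z_prop))
        accept = torch.log(torch.rand_like(log_a)) < log_a
        z_cache = torch.where(accept, z_prop, z_cache)

    # 2) Stochastic-approximation updates on the accepted samples: ascend
    #    log p_theta(x, z) for maximum likelihood and log q_phi(z | x) to
    #    shrink the inclusive divergence KL(p_theta(z|x) || q_phi(z|x)).
    opt.zero_grad()
    loss = -(log_p_joint(theta, x, z_cache) + log_q(phi, x, z_cache)).mean()
    loss.backward()
    opt.step()
    return z_cache

# Toy usage: 1-D data, persistent latent cache carried across iterations.
theta = {"prior_logit": torch.zeros(1, requires_grad=True),
         "mu0": torch.tensor([-1.0], requires_grad=True),
         "mu1": torch.tensor([1.0], requires_grad=True)}
phi = {"w": torch.zeros(1, requires_grad=True),
       "c": torch.zeros(1, requires_grad=True)}
opt = torch.optim.SGD(list(theta.values()) + list(phi.values()), lr=1e-2)

x = torch.cat([torch.randn(100) - 2.0, torch.randn(100) + 2.0])
z_cache = torch.distributions.Bernoulli(0.5 * torch.ones_like(x)).sample()
for _ in range(200):
    z_cache = jsa_step(theta, phi, x, z_cache, opt)
```

Note the contrast with relaxation- or REINFORCE-style variational training: in this sketch the gradients are taken through exact log-probabilities of accepted discrete samples rather than through a continuous surrogate or a score-function estimator, which is consistent with the abstract's claim of lower-variance gradient estimates.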
