Unbiased Gradient Estimation with Balanced Assignments for Mixtures of Experts

Training large-scale mixture-of-experts models efficiently on modern hardware requires assigning datapoints in a batch to different experts, each with limited capacity. Recently proposed assignment procedures lack a probabilistic interpretation and use biased estimators for training. As an alternative, we propose two unbiased estimators based on principled stochastic assignment procedures: one that skips datapoints which exceed expert capacity, and one that samples perfectly balanced assignments using an extension of the Gumbel-Matching distribution [29]. Both estimators are unbiased because they correct for the sampling procedure used. In a toy experiment, we find that the ‘skip’ estimator is more effective than the balanced-sampling one, and that both are more robust in solving the task than biased alternatives.
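The two assignment procedures can be sketched roughly as follows, assuming a router that produces a matrix of per-token expert logits. The helper names (skip_assignment, balanced_assignment) and the reduction of balanced sampling to a Gumbel-perturbed linear assignment problem are illustrative assumptions, not the paper's reference implementation.

```python
# Minimal sketch (not the paper's code): (i) a "skip" assignment that samples an
# expert per token and drops tokens that overflow an expert's capacity, and
# (ii) a balanced assignment sampled by perturbing router logits with Gumbel
# noise and solving a matching in which every expert gets exactly `capacity` tokens.
import numpy as np
from scipy.optimize import linear_sum_assignment


def skip_assignment(logits, capacity, rng):
    """Sample one expert per token; tokens beyond an expert's capacity are skipped (-1)."""
    n, k = logits.shape
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    experts = np.array([rng.choice(k, p=p) for p in probs])
    counts = np.zeros(k, dtype=int)
    assignment = np.full(n, -1, dtype=int)
    for i, e in enumerate(experts):
        if counts[e] < capacity:      # keep the token only if its expert still has room
            assignment[i] = e
            counts[e] += 1
    return assignment                 # -1 marks skipped tokens


def balanced_assignment(logits, capacity, rng):
    """Sample a perfectly balanced assignment via Gumbel perturbation plus matching."""
    n, k = logits.shape
    assert n == k * capacity, "a perfectly balanced assignment needs n = k * capacity"
    scores = logits + rng.gumbel(size=logits.shape)   # perturb-and-MAP style noise
    # Replicate each expert column `capacity` times so the Hungarian algorithm
    # returns a one-to-one matching between tokens and expert slots.
    cost = -np.repeat(scores, capacity, axis=1)
    rows, cols = linear_sum_assignment(cost)
    return cols[np.argsort(rows)] // capacity         # map slot index back to expert index


rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 4))                      # 8 tokens, 4 experts
print(skip_assignment(logits, capacity=2, rng=rng))
print(balanced_assignment(logits, capacity=2, rng=rng))
```

The sketch covers only the assignment step; to remain unbiased, the gradient estimator would additionally reweight the loss terms to correct for the probabilities under the chosen sampling procedure, as described in the abstract.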

[1] Peter W. Glynn et al. Likelihood ratio gradient estimation for stochastic systems, 1990, CACM.

[2] Naman Goyal et al. BASE Layers: Simplifying Training of Large, Sparse Models, 2021, ICML.

[3] Chang Zhou et al. Exploring Sparse Expert Models and Beyond, 2021, ArXiv.

[4] Marco Cuturi et al. Sinkhorn Distances: Lightspeed Computation of Optimal Transport, 2013, NIPS.

[5] Carlos Riquelme et al. Scaling Vision with Sparse Mixture of Experts, 2021, NeurIPS.

[6] Philip A. Knight et al. The Sinkhorn-Knopp Algorithm: Convergence and Applications, 2008, SIAM J. Matrix Anal. Appl.

[7] Tommi S. Jaakkola et al. Approximate inference using conditional entropy decompositions, 2007, AISTATS.

[8] Harold W. Kuhn et al. The Hungarian method for the assignment problem, 1955, 50 Years of Integer Programming.

[9] Quoc V. Le et al. Diversity and Depth in Per-Example Routing Models, 2018, ICLR.

[10] Yoshua Bengio et al. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation, 2013, ArXiv.

[11] E. Gumbel. Statistical Theory of Extreme Values and Some Practical Applications: A Series of Lectures, 1954.

[12] Scott W. Linderman et al. Learning Latent Permutations with Gumbel-Sinkhorn Networks, 2018, ICLR.

[13] Mark Huber et al. Exact Sampling from Perfect Matchings of Dense Regular Bipartite Graphs, 2006, Algorithmica.

[14] Richard Zemel et al. Efficient Feature Learning Using Perturb-and-MAP, 2013.

[15] J. M. Bilbao et al. Contributions to the Theory of Games, 2005.

[16] Ignacio Cases et al. Routing Networks and the Challenges of Modular and Compositional Computation, 2019, ArXiv.

[17] Richard Sinkhorn. A Relationship Between Arbitrary Positive Matrices and Doubly Stochastic Matrices, 1964.

[18] Eric Vigoda et al. A polynomial-time approximation algorithm for the permanent of a matrix with nonnegative entries, 2004, JACM.

[19] Noam Shazeer et al. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity, 2021, ArXiv.

[20] Ronald J. Williams et al. Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning, 2004, Machine Learning.

[21] Richard Sinkhorn et al. Concerning nonnegative matrices and doubly stochastic matrices, 1967.

[22] Alex Graves et al. Adaptive Computation Time for Recurrent Neural Networks, 2016, ArXiv.

[23] Frank Nielsen et al. Sinkhorn AutoEncoders, 2018, UAI.

[24] David Barber et al. Modular Networks: Learning to Decompose Neural Computation, 2018, NeurIPS.

[25] Stefano Ermon et al. Stochastic Optimization of Sorting Networks via Continuous Relaxations, 2019, ICLR.

[26] Tapani Raiko et al. Techniques for Learning Binary Stochastic Feedforward Neural Networks, 2014, ICLR.

[27] Stephen J. Garland et al. Algorithm 97: Shortest path, 1962, Commun. ACM.

[28] Jasper Snoek et al. Sinkhorn Networks: Using Optimal Transport Techniques to Learn Permutations, 2017.

[29] Jason Weston et al. Hash Layers For Large Sparse Models, 2021, NeurIPS.

[30] Jimmy Ba et al. Adam: A Method for Stochastic Optimization, 2014, ICLR.

[31] Ryan P. Adams et al. Ranking via Sinkhorn Propagation, 2011, ArXiv.

[32] Marc'Aurelio Ranzato et al. Learning Factored Representations in a Deep Mixture of Experts, 2013, ICLR.

[33] George Papandreou et al. Perturb-and-MAP random fields: Using discrete optimization to learn and sample from energy models, 2011, International Conference on Computer Vision.

[34] Tri Dao et al. Approximating the Permanent by Sampling from Adaptive Partitions, 2019, NeurIPS.

[35] Subhransu Maji et al. On Sampling from the Gibbs Distribution with Random Maximum A-Posteriori Perturbations, 2013, NIPS.

[36] Geoffrey E. Hinton et al. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer, 2017, ICLR.

[37] Orhan Firat et al. GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding, 2020, ICLR.

[38] Alexander J. Smola et al. Direct Optimization of Ranking Measures, 2007, ArXiv.

[39] Paul R. Milgrom et al. Designing Random Allocation Mechanisms: Theory and Applications, 2013.

[40] Richard S. Sutton et al. Reinforcement Learning: An Introduction, 1998, IEEE Trans. Neural Networks.

[41] Jakub M. Tomczak. On some properties of the low-dimensional Gumbel perturbations in the Perturb-and-MAP model, 2016.

[42] Tom Minka et al. A* Sampling, 2014, NIPS.

[43] Joelle Pineau et al. Conditional Computation in Neural Networks for faster models, 2015, ArXiv.

[44] M. Klein. A Primal Method for Minimal Cost Flows with Applications to the Assignment and Transportation Problems, 1966.