Differentiable Top-k Operator with Optimal Transport

The top-k operation, i.e., finding the k largest or smallest elements from a collection of scores, is an important model component that is widely used in information retrieval, machine learning, and data mining. However, if the top-k operation is implemented algorithmically, e.g., with bubble sort, the resulting model cannot be trained end-to-end with prevalent gradient descent algorithms. This is because such implementations typically involve swapping indices, whose gradients cannot be computed. Moreover, the mapping from the input scores to the indicator vector of whether each element belongs to the top-k set is essentially discontinuous. To address this issue, we propose a smoothed approximation, namely the SOFT (Scalable Optimal transport-based diFferenTiable) top-k operator. Specifically, our SOFT top-k operator approximates the output of the top-k operation as the solution of an Entropic Optimal Transport (EOT) problem. The gradient of the SOFT operator can then be efficiently approximated based on the optimality conditions of the EOT problem. We apply the proposed operator to the k-nearest neighbors and beam search algorithms, and demonstrate improved performance.
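
To make the idea concrete, the following is a minimal NumPy sketch of the forward pass only, contrasting the discontinuous hard indicator with a smoothed one obtained from an entropic optimal transport problem. The abstract does not specify the construction, so everything here is an assumption made for illustration: the function names `hard_top_k` and `soft_top_k`, the rescaling of scores to [0, 1], the two target points 0 and 1, the squared-distance cost, the marginals, the regularization strength `epsilon`, and the fixed number of Sinkhorn iterations `n_iters`. The sketch does not implement the gradient approximation mentioned above.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def hard_top_k(scores, k):
    """Exact 0/1 indicator of the k largest scores (piecewise constant in the scores)."""
    x = np.asarray(scores, dtype=float)
    indicator = np.zeros_like(x)
    indicator[np.argsort(-x)[:k]] = 1.0
    return indicator


def soft_top_k(scores, k, epsilon=0.01, n_iters=500):
    """Smoothed top-k indicator from an entropic OT plan (illustrative sketch).

    Assumed setup, not taken from the source: each of the n scores carries
    mass 1/n and is transported to one of two targets, 0 ("not selected",
    mass (n - k)/n) or 1 ("selected", mass k/n), under a squared-distance cost.
    Requires 0 < k < n.
    """
    x = np.asarray(scores, dtype=float)
    n = x.shape[0]
    # Rescale to [0, 1] so the cost is comparable across inputs.
    x = (x - x.min()) / (x.max() - x.min() + 1e-12)
    targets = np.array([0.0, 1.0])
    cost = (x[:, None] - targets[None, :]) ** 2          # shape (n, 2)

    log_mu = np.full(n, -np.log(n))                      # source marginal
    log_nu = np.log(np.array([(n - k) / n, k / n]))      # target marginal

    # Log-domain Sinkhorn iterations on the dual potentials f and g.
    f = np.zeros(n)
    g = np.zeros(2)
    for _ in range(n_iters):
        f = epsilon * (log_mu - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_nu - _logsumexp((f[:, None] - cost) / epsilon, axis=0))

    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Mass sent to target 1, rescaled so entries lie approximately in [0, 1].
    return n * plan[:, 1]


scores = [0.1, 2.0, -1.0, 3.5, 0.7]
print(hard_top_k(scores, k=2))
print(np.round(soft_top_k(scores, k=2), 3))
```

In this sketch, a larger `epsilon` gives a smoother, less peaked indicator, while a smaller `epsilon` brings the output closer to the hard indicator at the price of slower Sinkhorn convergence. Because the last update enforces the target marginal, the soft indicator sums to k.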
