Distributionally Robust Neural Networks

Overparameterized neural networks trained to minimize average loss can be highly accurate on average on an i.i.d. test set, yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that do not hold at test time). Distributionally robust optimization (DRO) provides an approach for learning models that instead minimize the worst-case training loss over a set of pre-defined groups. We find, however, that naively applying DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss also already has vanishing worst-case training loss. Instead, the poor worst-case performance of these models arises from poor generalization on some groups. As a solution, we show that increased regularization---e.g., stronger-than-typical weight decay or early stopping---allows DRO models to achieve substantially higher worst-group accuracies, with 10 to 40 percentage point improvements over standard models on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is critical for worst-group performance in the overparameterized regime, even if it is not needed for average performance. Finally, we introduce a stochastic optimizer for this group DRO setting, provide convergence guarantees for it, and use it to underpin the empirical study above.
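
To make the objective concrete, below is a minimal PyTorch sketch of one group DRO update in the spirit of the algorithm described above: per-group average losses are computed on a minibatch, an exponentiated-gradient step up-weights the worst-performing groups, and the model is then trained on the resulting weighted loss. The function name, the group-weight vector `q`, and the step size `eta` are illustrative choices, not taken from the paper's released code; the strong regularization emphasized above (e.g., heavy weight decay or early stopping) is assumed to be applied in the surrounding training loop.

```python
import torch

def group_dro_step(per_example_losses, group_ids, q, eta=0.01):
    """One sketched update of the online group DRO objective.

    per_example_losses: 1-D tensor of losses for the current minibatch
    group_ids:          1-D long tensor assigning each example to a group
    q:                  current weights over groups (a point on the simplex)
    eta:                step size for the group-weight update (illustrative)
    """
    n_groups = q.numel()
    group_losses = per_example_losses.new_zeros(n_groups)
    for g in range(n_groups):
        mask = group_ids == g
        if mask.any():
            group_losses[g] = per_example_losses[mask].mean()

    # Exponentiated-gradient ascent on the group weights: groups with
    # higher current loss are up-weighted, so the weighted objective
    # tracks the worst-case group.
    q = q * torch.exp(eta * group_losses.detach())
    q = q / q.sum()

    # Robust loss to backpropagate: a q-weighted sum of group losses.
    robust_loss = (q * group_losses).sum()
    return robust_loss, q
```

In a training loop, `robust_loss.backward()` would replace the usual average-loss backward pass, with `q` carried across minibatches; pairing this update with stronger-than-typical weight decay or early stopping is what the abstract identifies as necessary for good worst-group generalization.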
