PALM: Machine Learning Explanations For Iterative Debugging

When a deep neural network makes a misprediction, it can be challenging for a developer to understand why. While there are many models for interpretability in terms of predictive features, it may be more natural to isolate a small set of training examples that have the greatest influence on the prediction. In practice, however, every training example often contributes to a prediction in some way, with varying degrees of responsibility. We present Partition Aware Local Model (PALM), a tool that learns and summarizes this responsibility structure to aid machine learning debugging. PALM approximates a complex model (e.g., a deep neural network) using a two-part surrogate model: a meta-model that partitions the training data, and a set of sub-models that approximate the patterns within each partition. The sub-models can be arbitrarily complex in order to capture intricate local patterns, but the meta-model is constrained to be a decision tree. This way, the user can examine the structure of the meta-model, determine whether its rules match intuition, and efficiently link problematic test examples to the training data responsible for them. Queries to PALM are nearly 30x faster than nearest neighbor queries for identifying relevant data, which is a key property for interactive applications.
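The two-part surrogate structure can be illustrated with a minimal sketch. This is not PALM's implementation or training procedure: it assumes one-dimensional inputs, uses a depth-1 decision tree (a single threshold) as the meta-model, and fits a linear sub-model per partition by least squares. The function `black_box` and all other names are hypothetical stand-ins chosen for illustration.

```python
from statistics import mean

def black_box(x):
    # Hypothetical stand-in for the complex model being explained.
    return -x if x < 0 else 3 * x + 5

def fit_line(points):
    """Least-squares line y = a*x + b through a list of (x, y) points."""
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    mx, my = mean(xs), mean(ys)
    var = sum((x - mx) ** 2 for x in xs)
    a = 0.0 if var == 0 else sum((x - mx) * (y - my) for x, y in points) / var
    return a, my - a * mx

def squared_error(points, line):
    a, b = line
    return sum((y - (a * x + b)) ** 2 for x, y in points)

def fit_surrogate(points):
    """Meta-model: one threshold (depth-1 tree). Sub-models: one line per leaf.

    Assumption: the threshold is chosen to minimize the summed squared
    error of the two per-leaf fits; the paper may train differently.
    """
    xs = sorted({x for x, _ in points})
    best = None
    for lo, hi in zip(xs, xs[1:]):
        t = (lo + hi) / 2
        left = [p for p in points if p[0] <= t]
        right = [p for p in points if p[0] > t]
        if len(left) < 2 or len(right) < 2:
            continue  # need at least two points to fit a line
        lines = (fit_line(left), fit_line(right))
        err = squared_error(left, lines[0]) + squared_error(right, lines[1])
        if best is None or err < best["error"]:
            best = {"threshold": t, "error": err,
                    "parts": (left, right), "lines": lines}
    return best

def leaf(model, x):
    # Route a query through the meta-model to a partition index.
    return 0 if x <= model["threshold"] else 1

def predict(model, x):
    a, b = model["lines"][leaf(model, x)]
    return a * x + b

def responsible_examples(model, x):
    # Training examples in the same partition as the query point.
    return model["parts"][leaf(model, x)]

# Training data labeled by the black box, as a surrogate would see it.
train = [(float(x), float(black_box(x))) for x in range(-4, 5)]
model = fit_surrogate(train)
```

In this sketch, `responsible_examples(model, 2.0)` only evaluates the tree rule and returns the matching partition, rather than comparing the query against every training point as a nearest neighbor search would. The learned threshold is also directly inspectable, which is the property the decision-tree constraint is meant to preserve.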