This is Jessica, but today I’m blogging about a blog post on interpretable machine learning that co-blogger Keith wrote for his day job and shared with me. I agree with multiple observations he makes. Some highlights:
The often-suggested simple remedy for this unmanageable complexity is just finding ways to explain these black box models; however, those explanations can sometimes miss key information. In turn, rather than being directly connected with what is going on in the black box model, they end up being “stories” for getting concordant predictions. Given that concordance is not perfect, they can result in very misleading outcomes for many situations.
I blogged a bit about this before, giving examples like inconsistency in explanations for the same inputs and outputs. Thinking about these complications, I’m reminded of a talk by Chris Olah that I saw back in 2018, where he talked about how feature visualizations of activations that fire for different image inputs to a deep neural net allow us to seriously consider what’s going on inside, in a way that makes them analogous to the discovery of the microscope opening up a whole new world of microorganisms. I wonder if this idea has lost favor given that sometimes these explanations don’t behave the way we would hope.
I can also buy that not-quite-correct explanations can cause problems downstream, since I see it in the human context. The other day I had to ask a collaborator to refrain from offering explanations in place of methodological details for unexpected analysis results, since when delivered passionately the explanations could seem good enough that I wouldn’t question them initially, only to later realize we had wasted time when there was a better explanation. Plus, when every unexpected result comes with an explanation, skepticism about the explanation can make it feel like you’re undercutting everything a person says, even if you want to encourage discussion of these things.
While we need to accept what we cannot understand, we should never overlook the advantages of what we can understand. For example, we may never fully understand the physical world, nor how people think, interact, create, or decide. In ML, Geoffrey Hinton’s 2018 YouTube video drew attention to the fact that people are unable to explain exactly how they decide in general whether something is the digit 2 or not. This fact was originally pointed out, a while ago, by Herbert Simon, and has not been seriously disputed (Ericsson and Simon, 1980). However, prediction models are just abstractions, and we can understand the abstractions created to represent that reality, which is complex and often beyond our direct access. So not being able to understand people is not a valid reason to dismiss desires to understand prediction models.
In essence, abstractions are diagrams or symbols that can be manipulated, in error-free ways, to discern their implications. Usually referred to as models or assumptions, they are deductive and hence can be understood in and of themselves for simply what they imply. That is, until they become too complex. For instance, triangles on the plane are understood by most, while triangles on the sphere are understood by fewer. Reality may always be too complex, but models that adequately represent reality for some purpose need not be. Triangles on the plane serve for navigating short distances, while triangles on the sphere serve for long distances. Emphatically, it is the abstract model that is understood, not necessarily the reality it attempts to represent.
I’m not sure I totally grasp the distinction Keith is trying to get at here. To me the above passage implies we should be careful about assuming that some aspects of reality are too complex to explain. But given the part about concordances being misleading above, it seems applying this recursively can lead to problems: when the deep model is the complex thing we want to explain, we have to be careful isolating what we think are simpler units of abstraction to capture what it’s doing. For instance, a node in a neural network is a relatively simple abstraction (i.e., a linear regression wrapped in a non-linear activation function), but is thinking at that level of abstraction as a means of trying to understand the much more complex behavior of the network as a whole useful? Maybe Keith is trying to motivate considering interpretability in your choice of model, which he talks about later.
Related to people not being able to say how they recognize a 2, one thing that people can potentially do is point to the processor they think is responsible; e.g., I can’t succinctly describe why it’s a 2 based on low-level properties like edge detection, but maybe I could say something higher level like, ‘I would guess it’s something like visual word form memory.’ It’s not a complete explanation, but it seems that sort of meta statement could potentially be useful, since the first step to debugging is to figure out where to start looking.
[A] persistent misconception has arisen in ML that models for accurate prediction usually need to be complex. To build upon previous examples, there remain some application areas where simple models have yet to achieve accuracy comparable to black box models. On the other hand, simple models continue to predict as accurately as any state of the art black box model, and thus the question, as noted in the 2019 article by Rudin and Radin, is: “Why Are We Using Black Box Models in AI When We Don’t Need To?”
The referenced paper describes how the authors entered a NeurIPS competition on explainability, but then realized they didn’t need a black box at all to do the task; they could just use one of many simpler, interpretable models. Oops. Some of the interpretability work coming out of ML does seem like what you get when complexity enthusiasts excitedly latch onto a new problem that can motivate more of what they’re good at (e.g., optimization), without necessarily questioning the premise.
Interpretable models are far more trustworthy in that one can more readily discern where, when, and in what ways they should or should not be trusted. But how can one do this without understanding how the model works, especially for a model that is patently not trustworthy? This is especially important in cases where the underlying distribution of data changes, where it is critical to troubleshoot and modify without delays, as noted in the 2020 article by Hamamoto et al. It is arguably much more difficult to remain successful over the full ML life cycle with black box models than with interpretable models.
Agreed; debugging calls for some degree of interpretability. And often the more people you can get helping debug something, the more likely you are to find the problem.
There is increasing understanding based on considering numerous possible prediction models in a given prediction task. The not-too-unusual observation of simple models performing well for tabular data (a collection of variables, each of which has meaning on its own) was noted over 20 years ago and was labeled the “Rashomon effect” (Breiman, 2001). Breiman posited the possibility of a large Rashomon set in many applications; that is, a multitude of models with approximately the same minimum error rate. A simple check for this is to fit a number of different ML models to the same data set. If many of these are as accurate as the most accurate (within the margin of error), then many other untried models might also be. A recent study (Semenova et al., 2019) now supports running a set of different (mostly black box) ML models to determine their relative accuracy on a given data set to predict the existence of a simple accurate interpretable model: a way to quickly identify applications where it is a good bet that an accurate interpretable prediction model can be developed.
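The following is a minimal sketch of the “simple check” described above. It fits several models to the same data and then asks how many land within the margin of error of the best one. It makes three assumptions that go beyond the quoted text. First, every model is scored on one shared test set. Second, the margin is taken as 1.96 binomial standard errors of the best model’s accuracy. Third, the accuracies are made up for illustration. The function name `within_margin_of_best` is hypothetical.

```python
import math

def within_margin_of_best(accuracies, n_test, z=1.96):
    """Return the names of models whose test accuracy is within z standard errors of the best.

    accuracies: dict mapping a model name to its accuracy on one shared test set.
    n_test: number of test cases (assumed identical for every model).
    The margin uses the binomial standard error of the best model's accuracy and
    ignores the correlation between models scored on the same test cases, so it
    is a quick screen rather than a formal test.
    """
    best = max(accuracies.values())
    margin = z * math.sqrt(best * (1.0 - best) / n_test)
    return sorted(name for name, acc in accuracies.items() if best - acc <= margin)

# Made-up accuracies for illustration only; these are not results from any study.
accs = {
    "gradient_boosting": 0.842,
    "random_forest": 0.838,
    "neural_net": 0.835,
    "logistic_regression": 0.831,
    "decision_stump": 0.760,
}
near_best = within_margin_of_best(accs, n_test=2000)
print(near_best)  # ['gradient_boosting', 'logistic_regression', 'neural_net', 'random_forest']
```

In this toy case, four quite different model families are indistinguishable at this test-set size. That is the kind of pattern that, on the quoted argument, hints at a large Rashomon set.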
I like the idea of trying to estimate how many different ways there are to achieve good accuracy on some inference problem. I’m reminded of a paper I recently read which does basically the inverse – generate a bunch of hypothetical datasets and see how well a model intended to explain human behavior does across them, to understand when you just have a very flexible model versus when it’s actually providing some insight into behavior.
The full data science and life-cycle process is likely different when using interpretable models. More input is needed from domain experts to produce an interpretable model that makes sense to them. This should be seen as an advantage. For instance, it is not too unusual at a given stage to find numerous equally interpretable and accurate models. To the data scientist, there may seem to be little to guide the choice between these. But domain experts shown these models may easily discern opportunities to improve constraints, as well as indications of which ones are less likely to generalize well. Not all equally interpretable and accurate models are equal in the eyes of domain experts.
I definitely agree with this and other comments Keith makes about the need to consider interpretability early in the process. I was involved in a paper a few years ago where my co-authors had interviewed a bunch of machine learning developers about interpretability. One of the more surprising things we found was that, in contrast to the ML literature implying that interpretability can be applied after model development, it was seen by many of the developers as a more holistic thing related to how much others in their organization trusted their work at all, and consequently many thought about it from the beginning of model development.
There is now a vast and confusing literature, which conflates interpretability and explainability. In this brief blog, the degree of interpretability is taken simply as how easily the user can grasp the connection between input data and what the ML model would predict. Erasmus et al. (2020) provide a more general and philosophical view. Rudin et al. (2021) avoid trying to provide an exhaustive definition by instead providing general guiding principles to help readers avoid common, but problematic ways of thinking about interpretability. On the other hand, the term “explainability” often refers to post hoc attempts to explain a black box by using simpler ‘understudy’ models that predict the black box predictions.
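To make the “understudy” idea concrete, here is a minimal sketch of a post hoc surrogate. A one-split decision stump is trained to mimic a black box’s predictions rather than the real outcome. We then measure fidelity, the fraction of fresh inputs on which the surrogate and the black box agree. Fidelity is one common way to quantify the concordance discussed earlier. `black_box` and `fit_stump` are hypothetical stand-ins, and the data are simulated.

```python
import random

def black_box(x1, x2):
    # Stand-in for an opaque model: a nonlinear rule we pretend we cannot inspect.
    return 1 if x1 * x1 + 0.5 * x2 > 0.6 else 0

def fit_stump(points, labels):
    """Fit a one-split 'understudy' (decision stump) to the black box's labels."""
    best = None
    for feature in (0, 1):
        for threshold in sorted({p[feature] for p in points}):
            for above in (0, 1):
                preds = [above if p[feature] > threshold else 1 - above for p in points]
                agree = sum(pr == y for pr, y in zip(preds, labels))
                if best is None or agree > best[0]:
                    best = (agree, feature, threshold, above)
    _, feature, threshold, above = best
    return lambda p: above if p[feature] > threshold else 1 - above

rng = random.Random(0)
train = [(rng.random(), rng.random()) for _ in range(300)]
test = [(rng.random(), rng.random()) for _ in range(1000)]

# The understudy is trained to mimic the black box, not the real-world outcome.
surrogate = fit_stump(train, [black_box(*p) for p in train])
fidelity = sum(surrogate(p) == black_box(*p) for p in test) / len(test)
print(f"fidelity (concordance) on fresh inputs: {fidelity:.3f}")
```

Fidelity here is high but not perfect. On the disagreeing inputs, the surrogate’s “explanation” is a story about a model that does not actually behave that way.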
I’ve always found the simple definition of interpretability as ability to simulate what a model will predict interesting. At one point I was thinking about how if interpretability is mainly aimed at building trust in model predictions, maybe a “deeper” proxy for trust could be called internalizability, which is where the person (after using the model) is simulating the model but they don’t know it.
I’m reminded of discussions about sports training. It can be hard for us to describe what we do when we swing a bat or throw a ball; indeed, our descriptions of our own actions may be highly inaccurate, as we can see from going back and watching a video. Usually we think that if we want to improve, it’s a good idea to figure out what we’re currently doing.
Relatedly, when I teach a skill (whether in statistics or sports or whatever), I become much more aware of my current practice.
This post reminds me of the AI startup where I used to work.
There was a cult of deep learning. Complex models exerted a powerful fascination.
The longer the name the better.
GAN based meta X learning of hyperbolic latent space would soon take over the world, then covid hit, then half of the company was laid off…
Ah yes, the religion of Complexify_it_(but_try_to_keep_the_math_discrete).
I’m currently working on evaluating the fire risk of various parts of electricity infrastructure — at the level of individual electric poles and stretches of wire — based on historical data on fires and electric outages. The culture and practices of the team are very much aligned with machine learning; the idea of hand-defining and hand-tuning a model is pretty much a non-starter. So I started with random forest models and then gradient-boosted forest models. But just for laffs I put the same variables into a simple logistic regression model…and it performs just as well. Not better, but just as well. Much easier to interpret and to explain, and equally good empirically. And with some careful selection of interaction terms it can probably improve a bit. I’m going to argue for making the logistic regression model the official one, but I’m expecting some pushback.
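Part of why a logistic regression is “much easier to interpret and to explain” can be shown in a few lines. Each coefficient is a change in log-odds, and exponentiating it gives an odds ratio. With an interaction term, each main-effect odds ratio applies when the other feature is zero. The sketch below makes several assumptions. The data are simulated, and the features `dryness` and `wind` are made-up stand-ins rather than the commenter’s variables. The model is fit by Newton’s method in plain Python so the example is self-contained; in practice one would use a statistics library.

```python
import math
import random

def solve(a, b):
    """Solve the small linear system a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x

def fit_logistic(rows, y, iters=15):
    """Maximum-likelihood logistic regression via Newton's method; rows include an intercept column."""
    k = len(rows[0])
    beta = [0.0] * k
    for _ in range(iters):
        grad = [0.0] * k                       # gradient of the log-likelihood
        info = [[0.0] * k for _ in range(k)]   # X^T W X, the observed information
        for x, yi in zip(rows, y):
            p = 1.0 / (1.0 + math.exp(-sum(b * xi for b, xi in zip(beta, x))))
            w = p * (1.0 - p)
            for i in range(k):
                grad[i] += (yi - p) * x[i]
                for j in range(k):
                    info[i][j] += w * x[i] * x[j]
        beta = [b + s for b, s in zip(beta, solve(info, grad))]
    return beta

# Simulated data with two hypothetical features; "dryness" and "wind" are made-up stand-ins.
rng = random.Random(1)
rows, y = [], []
for _ in range(3000):
    dryness, wind = rng.gauss(0, 1), rng.gauss(0, 1)
    true_logit = -1.0 + 0.8 * dryness + 0.4 * wind + 0.6 * dryness * wind
    y.append(1 if rng.random() < 1.0 / (1.0 + math.exp(-true_logit)) else 0)
    rows.append([1.0, dryness, wind, dryness * wind])  # intercept, main effects, interaction

beta = fit_logistic(rows, y)
for name, b in zip(["intercept", "dryness", "wind", "dryness:wind"], beta):
    print(f"{name:>13}: log-odds coef {b:+.2f}, odds ratio {math.exp(b):.2f}")
```

The fitted coefficients should land near the simulated values (0.8, 0.4, 0.6). Each line of output is a statement a domain expert can check directly, which is harder to get from a random forest’s variable importances.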
This is frankly terrifying to me. Fire risk is exactly where rigorous uncertainty quantification seems important; poorly thought out normality assumptions implicit in mean squared loss functions should be a non-starter.
You have sufficient historical fire data at the level of individual poles? Do you incorporate information about the local topographic, forest, and/or typical weather conditions (e.g., topographic features can create different wind conditions, and S- or SE-facing slopes are drier), or is that done by a different group or something? Seems like it would be relevant.
I’m reminded of the WWII warplane joke. The right place to put the additional armoring is the place where you don’t see any holes in returning planes, since planes that were hit in those places went down and didn’t return. (The holes you see in returning planes are holes that didn’t knock out the plane. So those holes aren’t representative of places that you don’t want to get shot in.)
The story here would be: places that already have had fires are places that don’t have a lot of stuff (e.g., dried dead underbrush) that causes fires, since the fires already burned that stuff. But the places where fires are likely to occur are places where fires haven’t occurred (and thus still have a danger of fires).
How does ML do on this sort of problem?
> Perhaps, in the end, it may boil down to the fact that if simple works, then why make things more complex?
I’d just add to this testing. Things are easier to understand if they’re tested well. If you can test well enough, then there’s less of a need to understand what’s going on (cuz you’ve defined what’s going on by the tests or whatever).
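One way to “define what’s going on by the tests” is a property-style behavioral test. Such a test only queries the model and never looks inside it. The sketch below assumes prior knowledge that predicted risk should not go down when a feature goes up. It checks that monotonicity on a few probe inputs. The feature name, both models, and the helper `monotonicity_violations` are hypothetical.

```python
def monotonicity_violations(predict, base_inputs, feature, deltas=(0.1, 0.5, 1.0)):
    """Return (input, delta) pairs where raising `feature` by delta lowered the prediction."""
    violations = []
    for x in base_inputs:
        before = predict(x)
        for d in deltas:
            bumped = dict(x)
            bumped[feature] += d
            if predict(bumped) < before:
                violations.append((x, d))
    return violations

# Two hypothetical "black box" risk models that we can only query, not inspect.
def model_ok(x):
    return min(1.0, 0.1 + 0.2 * x["dryness"])

def model_buggy(x):
    # Quietly predicts *lower* risk once dryness passes 2.0.
    return 0.1 + 0.2 * x["dryness"] - 0.5 * max(0.0, x["dryness"] - 2.0)

cases = [{"dryness": v} for v in (0.0, 1.0, 2.5)]
print(len(monotonicity_violations(model_ok, cases, "dryness")))     # 0
print(len(monotonicity_violations(model_buggy, cases, "dryness")))  # 3
```

A test like this catches the bug without any explanation of why the model behaves that way. It is only as good as the probe inputs chosen, which is part of why good testing is not easy.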
I don’t think testing these sorts of things would necessarily be easy, even if you can use off the shelf neural network components or somehow don’t have to deal with most of the training difficulties. It might be easier than explaining the neural network, and you’d need to do the testing either way.
As a child, movies led me to believe I would frequently run into dangerous quicksand and lava, so having a grappling hook handy was essential. In the same way, school led me to believe I would constantly be classifying lots of unstructured data like bitmaps and so I had to have a deep neural network handy.
Great post–reminds me of this blog by Wolfram (curse his name)
Hey! My 4yo is always on the lookout for lava! Are you saying she’s misinformed?
I get how you use a grappling hook to pull yourself out of quicksand. I don’t know if it will work with lava.
The cover of my book on information quality is “42”. I gave several talks titled “42” asking the audience what it means. Having 42 near 38 and 43 makes it an OK number. Having it next to 0.00000000000000000000000000000000000000007 makes it huge. However, near 2589000000000000000000, it is small.
The trick is the context. If you add a story on where the number comes from, the number becomes data.
Hinton’s YouTube video on “2” is similar. Computer scientists think differently from statisticians. He wants a definition of “2” without considering context.
Another point is the possibility to explain something by demarcation of a Boundary of Meaning. You can explain what something is, and what it is not. I called this generalization of findings. https://papers.ssrn.com/sol3/papers.cfm?abstract_id=3035070
Sameness: two things cannot be completely the same in all respects; else they would be one; and not two. Difference: two things cannot be entirely different in every respect; for even then they would have that in common – at least; that relation of difference. What is a thing and what is a collection of things: is it not apparent that the distinction is – properly speaking – really a phenomenon of the attention?
Don’t think context was overlooked.
The inaccurate reports found by other research [e.g. asking subjects how they decided] are shown to result from requesting information that was never directly heeded, thus forcing S[ubject]s to _infer_ rather than _remember_ their mental processes. https://content.apa.org/record/1980-24435-001
Andrew provided a physical example of this https://statmodeling.stat.columbia.edu/2021/06/30/not-being-able-to-say-why-you-see-a-2-doesnt-excuse-your-uninterpretable-model/#comment-1891736
I don’t doubt that people sometimes make things up when asked to verbalize about information they never paid attention to. I guess that’s why people who use think-aloud protocols have to be so careful about not asking leading questions. I have found it useful, though, in experimental settings to ask people to describe their strategies as they do some information task. It’s always a risk that they aren’t accurate, so it’s never our main analysis, but often in experiments on decision making with some interface, given the amount and kind of data we can collect, it’s not possible to specify a model that can differentiate between multiple similar strategies a person might be using. At least not without designing a series of tasks with the express purpose of doing that.
I was browsing some interpretability research today and came across a paper that uses info from a deep neural net to improve the accuracy of simple models (e.g., because they are preferred by experts): https://arxiv.org/pdf/1807.07506.pdf Seems in line with the sentiment of Keith’s post.
I’m not suggesting this line of research won’t get anywhere, but given the intuitive explanation, some concerns immediately arise.
“The primary intuition behind our approach is to identify examples that the simple model will most likely fail on, i.e. identify truly hard examples. We then want to inform the simple model to ignore these examples.”
Hopefully the easiness was determined on held-out samples (i.e., not just by kicking out the data points with the largest observed prediction errors).
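The held-out concern can be made concrete with a deliberately extreme toy case: a 1-nearest-neighbour model, which memorizes its training data. Judged in-sample, every point looks “easy”, because each point is its own nearest neighbour. Judged out-of-fold via 5-fold cross-fitting, the noisy, genuinely hard points show up. This is a generic illustration on simulated data, not the procedure used in the linked paper.

```python
import random

def nn_predict(train, x):
    """1-nearest-neighbour prediction; `train` is a list of (x, label) pairs."""
    return min(train, key=lambda t: abs(t[0] - x))[1]

# Simulated 1-D data: label is 1 when x > 0.5, with 15% of labels flipped at random.
rng = random.Random(42)
data = []
for _ in range(200):
    x = rng.random()
    label = int(x > 0.5)
    if rng.random() < 0.15:
        label = 1 - label
    data.append((x, label))

# In-sample "hardness": every point is its own nearest neighbour, so nothing looks hard.
in_sample_hard = [i for i, (x, y) in enumerate(data) if nn_predict(data, x) != y]

# Held-out hardness via 5-fold cross-fitting: each point is scored by a model that never saw it.
k = 5
held_out_hard = []
for fold in range(k):
    train = [d for i, d in enumerate(data) if i % k != fold]
    for i in range(fold, len(data), k):
        x, y = data[i]
        if nn_predict(train, x) != y:
            held_out_hard.append(i)

print(len(in_sample_hard), len(held_out_hard))
```

Flexible models can drive in-sample error toward zero. Any rule that says “ignore the hard examples” therefore needs hardness scored on data the scoring model did not train on.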
Simple is not necessarily interpretable, except perhaps in cases of extreme simplicity, and method preferences can be largely unconnected with interpretability (e.g., the preference for Cox proportional hazards regression in survival analysis).
The non-easy samples may be the more critical ones to get even a somewhat close prediction for (e.g., highly infectious variants).
Given that the prediction is either a number or the label “too hard”, the “too hard” predictions are left as an even blacker box output than the deep neural net predictions.
However, a prediction model taken as just a model is just math, and in math, how it was happened upon simply does not matter. If an accurate interpretable model is happened upon, it is what it is.
tl;dr: For interpretability of a deep network’s predictions and its underlying training data, there are compelling practical reasons to structure the deep network in such a way as to be able to back-out predictions at a lower resolution than that of the training labels, and then approximate those predictions as a weighted vote over the training set. Dense matching constraints against the observed data (which we get from the model approximation) yield heuristics/signals for prediction reliability/uncertainty, and the approximations additionally provide a mechanism for updating the model in some settings without a full re-training of the underlying deep network. See “Detecting Local Insights from Global Labels: Supervised & Zero-Shot Sequence Labeling via a Convolutional Decomposition”: https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00416/106772. Here’s a 12 minute video: https://youtu.be/iJ_udvksyqE
To get toward a practically viable approach for interpretability (and by extension, deployment) of the deep networks, we needed a few technical advancements and conceptual shifts. (The technical details, including relative to existing approaches, are important from a computer science perspective and for actually implementing for a given task, but here I will focus on the higher-level conceptual points, which I believe are themselves rather unique when taken together in a single method/model.) In a sense, we are going to view interpretability as an interactive/human-in-the-loop prediction task at a lower resolution of the input than that for which we have labels (at least at scale, at least without further updating/labeling, as noted below), but to do so, we’ll ideally need some properties/behaviors that are not typically associated with the deep networks out-of-the-box. (Note that I’m not going to try to separate “interpretability” from “explainability”. In CS there does not seem to be a well-agreed-upon distinction between those two terms, in part because it has been somewhat amorphous as to what our ideal target behavior even really was with the deep networks. See the point about asking the right questions, below.)
Typically when we use the large deep networks it is in the context of massive training datasets, which may have been curated by means/methods somewhat opaque to the model designers. This is particularly true for data used for pre-training the networks. To reach scale, such data may be derived via semi-automated means and only lightly curated by annotators. As such, a core theme of this notion of interpretability is that we will want methods for analyzing the data itself, and in particular, relative to the model’s predictions for a given instance. The other important point on the data front is that often when we think of deploying the deep networks, it is in the context of some type of interactive/dynamic end-user application where we have new data coming in which has the potential to differ rather significantly from what was seen in training. This is referred to as the out-of-domain (OOD) and/or domain/sub-population shift problem, and is an issue since the deep networks tend to be sensitive to relatively modest perturbations in the input data (i.e., they are “brittle”). So in summary, when it comes to data, we face challenges both in terms of knowing what exactly is in the tails of the training data, and handling input data at inference in the tails of the distribution that could throw the whole endeavor unexpectedly off the rails.
OK, so now we know the data itself is going to pose some serious challenges in this context. Unfortunately, it gets even worse from there, since the models themselves pose some intrinsic challenges, which will require some conceptual shifts in how we think about interpretability. In the context of document-level classification, one means of making sense of a document-level prediction is to use some type of learned mechanism of the network by which we can back out predictions/activations for lower levels of the input (such as words) in an indirect manner. These mechanisms fall under the general heading of “attention mechanisms”. These approaches can be useful when the goal is to create a sequence labeler when we don’t have word/token-level labels for training, and, provided the sequence labeling is effective, they are already a means of analyzing the data (e.g., simply run it over your data and visualize the activations). However, there are some subtle caveats to remember when thinking about these mechanisms in terms of the document-level predictions. The parameters of the deep network are not identifiable, and as a reflection, it is possible to bias the attention mechanisms to have rather different word-level predictions while the document-level predictions nonetheless stay similar. That we now know how to structure the networks with suitable inductive biases/priors (and losses) to create such indirect sequence labelers is good news for practical applications, but the nature of the problem is such that we will have to think carefully about the inductive biases of both the architecture and the losses for each dataset/task (deciding how the local level relates to the global level, and vice versa).
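A toy illustration of the identifiability caveat follows. It assumes a simple model in which the document score is an attention-weighted average of token-level scores. That is a generic stand-in, not the convolutional decomposition from the paper, and all numbers are made up. Two parameter settings give the same document-level score while highlighting different words, so the document-level labels alone cannot tell them apart.

```python
import math

def doc_prediction(token_scores, attention):
    """Document score as an attention-weighted average of token-level scores.

    Returns the document score and each token's contribution to it.
    """
    total = sum(attention)
    contributions = [s * a / total for s, a in zip(token_scores, attention)]
    return sum(contributions), contributions

tokens = ["the", "food", "was", "not", "good"]

# Two hypothetical parameter settings of the same toy model (made-up numbers).
settings = {
    "A": ([0.1, 0.2, 0.1, 0.9, 0.3], [0.05, 0.05, 0.05, 0.80, 0.05]),
    "B": ([0.1, 0.2, 0.1, 0.3, 0.9], [0.05, 0.05, 0.05, 0.05, 0.80]),
}

results = {}
for name, (scores, attn) in settings.items():
    doc, contrib = doc_prediction(scores, attn)
    top = tokens[max(range(len(tokens)), key=lambda i: contrib[i])]
    results[name] = (doc, top)
    print(f"setting {name}: document score {doc:.3f}, most-highlighted token {top!r}")

print("same document score:", math.isclose(results["A"][0], results["B"][0]))
```

Both settings fit a document-level label equally well, yet one points a reader to “not” and the other to “good”. This is why the inductive biases and losses that tie the local level to the global level have to be chosen deliberately.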
However, even if our particular attention mechanism produces a strong predictor for our held-out data, there is a sense in which we have just kicked the can down to a lower level: We’ve turned our uninterpretable document-level prediction into an uninterpretable word-level prediction, and we haven’t done anything about the aforementioned domain-shift/reliability issue. As such, we ideally want some type of additional mechanism to relate the test predictions to those for which we have known labels (namely, those from the training set), and to have some means of constraining the predictions to at least be similar to what we have observed in training. Finally, if we’re analyzing the data, or constraining the predictions, and we find that something is indeed wrong in our training data (e.g., an incorrect label) or to a limited extent, we want to add new data to the training set (for new domains), is our only recourse to retrain the massive model we took weeks/large compute to train? What if an end user wants to make a local update (changing labels, etc.)?
In short, we have a number of hurdles to overcome in order to have a reasonable handle over interpreting the networks so that we can deploy them in practice: We need some means of analyzing the data under the model; we are going to have to address the OOD and uncertainty/reliability issues (to have some sense of not only when the predictions may go off the rails, but also importantly, when our interpretability methods themselves may be going off the rails, if interpretability is itself viewed as a prediction task); we need some means of relating the global instance-level predictions down to constituent parts (and vice-versa), with flexibility in the approach to be adaptable to various priors we may have; and ideally, we seek some degree of updatability when things (inevitably) go wrong with the model or data without having to re-train the full model.
Part of the challenge has been simply asking the right questions to determine what properties/characteristics we even need in practice. In my view, the whole endeavor is rather more shaky without any one of the aforementioned characteristics. For example, you could have a highly effective indirect sequence labeler (via an attention-style mechanism) on your in-domain data, but if you haven’t addressed the domain-shift problem, it could all unexpectedly fall apart at deployment on the new data that users submit. As another example, perhaps you have some reasonably effective intrinsic measure of reliability/uncertainty, but you lack an explicit connection to the training data, so you get blindsided (and have no direct means of analyzing and addressing) when it turns out that your massive training set has very poor quality labels for certain sub-populations, cutting the data space differently than that of your pre-deployment trials and analyses.
I’ve been working on this line of research for some time now, but I wanted to mention the initial paper in this series which focuses on the document-level classification case, since I recently presented it at EMNLP (the paper is to appear in the journal Computational Linguistics) and these ideas are perhaps not yet widely known. (It’s been in the TODO queue to reply to this post for a while.)
We obtain the aforementioned properties over the deep networks with a particular attention-style mechanism that gives us flexibility in producing word-level predictions from models trained with document-level labels. By structuring the network in this particular way, we will then derive dense representations for each word-level prediction that we use for matching against the training data, or a support set with known labels. We will then construct an approximation of those word-level predictions as a K-NN over the matches and their associated labels. So in summary, at a high level, we’re going to decompose the model first across the input, down to a lower resolution than that for which we had annotated labels. Next, we’re going to decompose each of those predictions as a weighted vote over the training set.
This will give us a tool for analyzing the data at a lower resolution than our available ground-truth labels, and this approach gives us an explicit connection to the training instances for each prediction. Importantly, it also turns out that the approximation gives us strong signals as to the reliability of the predictions: The predictions become more reliable as the output magnitude of the K-NN increases (recall it’s a weighted vote of labels and predictions over the training set) and as the distance to the first match decreases. We can then use that to screen instances unlike those seen in training. Interestingly, to a limited extent, we can then update the model provided the representations can match to a new domain (or alternatively, in the more perfunctory case, we can simply update the labels in the training/support set). Holding other things constant, we haven’t gained anything in terms of robustness over domain-shifted/OOD data (and in fact, there have not been to my knowledge any approaches that have consistently improved robustness, ceteris paribus holding the model and training data fixed), but importantly, at least we can screen such data, to prevent things from going off the rails, sending the input to humans for adjudication, and then in some cases, we can perform an update via the model approximation.
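The general shape of the two reliability signals can be sketched as follows: the magnitude of a weighted vote over labeled neighbours, and the distance to the first match. This is a minimal sketch under stated assumptions. It uses a generic distance-weighted K-NN over hypothetical 2-D vectors with labels in {-1, +1}. The vectors, the exponential weighting, `k`, and both thresholds are illustrative choices, not the representations or weighting derived in the paper. In practice, the thresholds would be set on held-out data.

```python
import math

def knn_vote(query, support, k=3, temperature=1.0):
    """Distance-weighted vote over the k nearest support items.

    support: list of (vector, label) pairs with labels in {-1, +1}.
    Returns (vote, distance_to_first_match); the vote lies in [-1, 1].
    """
    nearest = sorted((math.dist(query, v), y) for v, y in support)[:k]
    weights = [math.exp(-d / temperature) for d, _ in nearest]
    vote = sum(w * y for w, (_, y) in zip(weights, nearest)) / sum(weights)
    return vote, nearest[0][0]

def screen(query, support, min_magnitude=0.5, max_first_distance=1.0):
    """Keep the vote's label only if the vote is strong AND the closest match is near; else defer."""
    vote, first = knn_vote(query, support)
    if abs(vote) >= min_magnitude and first <= max_first_distance:
        return ("+1" if vote > 0 else "-1"), vote, first
    return "defer to human", vote, first

# Hypothetical 2-D representations with known labels (a stand-in for a training/support set).
support = [((0.0, 0.0), -1), ((0.2, 0.1), -1), ((0.1, 0.3), -1),
           ((3.0, 3.0), +1), ((3.1, 2.8), +1), ((2.9, 3.2), +1)]

for q in [(0.1, 0.1), (1.5, 1.5), (10.0, -10.0)]:
    decision, vote, first = screen(q, support)
    print(q, decision, f"vote={vote:+.2f}", f"first_match_distance={first:.2f}")
```

The third query shows why both signals are used. Its vote is unanimous, but its nearest match is far from anything in the support set, so it is routed to a human rather than trusted. Updating the support set (for example, correcting a label) changes future votes without retraining anything upstream.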
Paper title: “Detecting Local Insights from Global Labels: Supervised & Zero-Shot Sequence Labeling via a Convolutional Decomposition”
12 minute video: https://youtu.be/iJ_udvksyqE
(The most recent arXiv version has the same content if you want to read in the interim without the ‘uncorrected proof’ watermark: https://arxiv.org/pdf/1906.01154.pdf)
Online Appendix: https://raw.githubusercontent.com/allenschmaltz/exa/main/online_appendix/binary_appendix.pdf
What’s next? Well, from a statistics perspective, it may seem interesting that the signals from the magnitude of the K-NN output and the distance to the first match encode information about prediction reliability of the deep network; that’s a rather surprising and useful side-effect of the particular approximation approach. Nonetheless, it may be somewhat unsatisfying, as it’s still otherwise not obvious how one would then turn that into a prediction set/confidence band/etc. IF we want to take into account domain shifts (keeping with the theme that we really should be expecting domain shifts, as our default). If we’re setting a threshold on the held-out data, there’s obviously a dependence on the distribution of the held-out data if we want to maintain coverage, as opposed to just knocking out the most unreliable predictions as a hard cut (without imposing the more challenging requirement of having some handle over the proportion of instances in the subsets/divisions). Just applying a conformal approach, even if we condition on distances, won’t necessarily be sufficient, since domain shifts could lead to arbitrarily large undercoverage over the test distribution for some sub-populations. It turns out to be very non-trivial to resolve this issue; my colleague and I have a new paper coming soon that proposes an approach. I’ll try to remember to update this post with a comment when it’s available.
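The undercoverage concern can be seen in a toy example using plain split conformal prediction. This is the textbook method, not the approach in the forthcoming paper. The model and data are hypothetical: y = x + noise, with the “fitted” model predicting y_hat = x. The cutoff is calibrated on in-domain data at a 90% target. It then gives roughly 90% coverage on new in-domain data but far less on a noisier sub-population that calibration never saw. The marginal guarantee rests on exchangeability, which the shift breaks.

```python
import math
import random

def conformal_threshold(scores, alpha):
    """Split-conformal cutoff: the ceil((n + 1)(1 - alpha))-th smallest calibration score."""
    n = len(scores)
    rank = math.ceil((n + 1) * (1 - alpha))
    return sorted(scores)[min(rank, n) - 1]

rng = random.Random(7)

def sample(n, noise_sd):
    # Hypothetical data: y = x + noise, and the (already fitted) model predicts y_hat = x.
    xs = [rng.gauss(0, 1) for _ in range(n)]
    return [(x, x + rng.gauss(0, noise_sd)) for x in xs]

def coverage(data, q):
    """Fraction of points whose interval [y_hat - q, y_hat + q] contains y."""
    return sum(abs(y - x) <= q for x, y in data) / len(data)

calibration = sample(1000, noise_sd=1.0)
q = conformal_threshold([abs(y - x) for x, y in calibration], alpha=0.1)

in_domain = sample(5000, noise_sd=1.0)
shifted = sample(5000, noise_sd=2.0)  # a sub-population noisier than anything seen in calibration
print(f"q = {q:.2f}")
print(f"in-domain coverage: {coverage(in_domain, q):.3f}")  # close to the 0.90 target
print(f"shifted coverage:   {coverage(shifted, q):.3f}")    # well below the target
```

Conditioning the cutoff on a distance-to-training signal would help only if the shifted points are also far in that representation. Nothing in this construction guarantees that.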
What else? It’s not yet clear how dependent this approach is on the particular neural architecture (but note that Transformers, the architecture I focused on, are commonly used in practice). It’s also an interesting question to ask whether there are other approximations that yield similar behavior (e.g., tangent kernel/path kernel).