5.7 Local Surrogate (LIME)
Local surrogate models are interpretable models that are used to explain individual predictions of black box machine learning models. Local interpretable model-agnostic explanations (LIME)37 is a concrete implementation of local surrogate models, proposed by Ribeiro et al. (2016). Surrogate models are trained to approximate the predictions of the underlying black box model. Instead of training a global surrogate model, LIME focuses on training local surrogate models to explain individual predictions.
The idea is quite intuitive. First, forget about the training data and imagine you only have the black box model where you can input data points and get the predictions of the model. You can probe the box as often as you want. Your goal is to understand why the machine learning model made a certain prediction. LIME tests what happens to the predictions when you feed variations of your data into the machine learning model. LIME generates a new dataset consisting of perturbed samples and the corresponding predictions of the black box model. On this new dataset LIME then trains an interpretable model, which is weighted by the proximity of the sampled instances to the instance of interest. The interpretable model can be anything from the interpretable models chapter, for example Lasso or a decision tree. The learned model should be a good approximation of the machine learning model predictions locally, but it does not have to be a good global approximation. This kind of accuracy is also called local fidelity.
Mathematically, local surrogate models with an interpretability constraint can be expressed as follows:
\[\text{explanation}(x)=\arg\min_{g\in{}G}L(f,g,\pi_x)+\Omega(g)\]
The explanation model for instance x is the model g (e.g. linear regression model) that minimizes loss L (e.g. mean squared error), which measures how close the explanation is to the prediction of the original model f (e.g. an xgboost model), while the model complexity \(\Omega(g)\) is kept low (e.g. prefer fewer features). G is the family of possible explanations, for example all possible linear regression models. The proximity measure \(\pi_x\) defines how large the neighborhood around instance x is that we consider for the explanation. In practice, LIME only optimizes the loss part. The user has to determine the complexity, e.g. by selecting the maximum number of features that the linear regression model may use.
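In the LIME paper, the loss L is a proximity-weighted squared error over the set Z of perturbed samples; if we gloss over the distinction between the original and the interpretable representation of a sample, it reads roughly:
\[L(f,g,\pi_x)=\sum_{z\in{}Z}\pi_x(z)\left(f(z)-g(z)\right)^2\]
where Z is the set of perturbed samples around x.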
The recipe for training local surrogate models:
- Select your instance of interest for which you want to have an explanation of its black box prediction.
- Perturb your dataset and get the black box predictions for these new points.
- Weight the new samples according to their proximity to the instance of interest.
- Train a weighted, interpretable model on the dataset with the variations.
- Explain the prediction by interpreting the local model.
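The following minimal sketch in R illustrates these steps. It is an illustration only, not the reference implementation, and it makes several assumptions: X is a data frame of numeric training features, x_interest is the single row to be explained, f is a fitted model whose predict() returns a numeric value (a regression output or a predicted probability), and the kernel width of 0.75 is an arbitrary choice.
# Minimal sketch of the LIME recipe (illustration only, not the reference implementation)
set.seed(1)
n_samples = 500
# Perturb: sample each feature from a normal distribution with its empirical mean and sd
Z = as.data.frame(lapply(X, function(col) rnorm(n_samples, mean(col), sd(col))))
# Get the black box predictions for the perturbed points
Z$y_hat = predict(f, newdata = Z)
# Weight the samples by their proximity to the instance of interest
# (exponential kernel on standardized distances; the width 0.75 is an arbitrary choice)
X_sd = apply(X, 2, sd)
d = sqrt(rowSums(scale(Z[names(X)], center = unlist(x_interest[names(X)]), scale = X_sd)^2))
w = exp(-d^2 / 0.75^2)
# Train a weighted, interpretable model on the perturbed data
local_model = lm(y_hat ~ ., data = Z, weights = w)
# Interpret the coefficients of the local model as the explanation
coef(local_model)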
In the current implementations in R and Python, for example, linear regression can be chosen as the interpretable surrogate model. In advance, you have to select K, the number of features you want to have in your interpretable model. The lower K, the easier it is to interpret the model. A higher K potentially produces models with higher fidelity. There are several methods for training models with exactly K features. A good choice is Lasso. A Lasso model with a high regularization parameter \(\lambda\) yields a model without any features. By retraining the Lasso model with slowly decreasing \(\lambda\), one feature after the other gets a weight estimate that differs from zero. If there are K features in the model, you have reached the desired number of features. Other strategies are forward or backward selection of features. This means you either start with the full model (= containing all features) or with a model with only the intercept and then test which feature would bring the biggest improvement when added or removed, until a model with K features is reached.
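For example, with the glmnet package you could fit the entire Lasso path once and then pick the largest \(\lambda\) at which exactly K coefficients are non-zero. This is only a sketch: it reuses the hypothetical perturbed data Z, predictions Z$y_hat and weights w from the sketch above and assumes that some \(\lambda\) on the path yields exactly K features.
library("glmnet")
# Fit the whole Lasso path on the weighted, perturbed data
x_mat = as.matrix(Z[setdiff(names(Z), "y_hat")])
fit = glmnet(x = x_mat, y = Z$y_hat, weights = w, alpha = 1)
# fit$df holds the number of non-zero coefficients for each lambda on the path
K = 2
# Largest lambda with exactly K non-zero coefficients (assumes such a lambda exists)
lambda_K = max(fit$lambda[fit$df == K])
coef(fit, s = lambda_K)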
How do you get the variations of the data? This depends on the type of data, which can be either text, image or tabular data. For text and images, the solution is to turn single words or super-pixels on or off. In the case of tabular data, LIME creates new samples by perturbing each feature individually, drawing from a normal distribution with mean and standard deviation taken from the feature.
5.7.1 LIME for Tabular Data
Tabular data is data that comes in tables, with each row representing an instance and each column a feature. LIME samples are not taken around the instance of interest, but from the training data’s mass center, which is problematic. But it increases the probability that the predictions for some of the sample points differ from the prediction for the data point of interest, so that LIME can learn at least some explanation.
It is best to visually explain how sampling and local model training work. The following code chunk demonstrates this on simulated data:
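Note that the chunk relies on a few helper functions (get_y, get_distances, kernel and my_theme) from the book's code repository that are not shown here. Roughly, and this is my description rather than the exact definitions, they do the following:
# Helper functions used below (not shown; rough descriptions, an assumption rather than the exact code):
#   get_y(x1, x2, noise_prob)      generates binary class labels from x1 and x2,
#                                  flipping each label with probability noise_prob
#   get_distances(point, points)   Euclidean distances of the sampled points to the point of interest
#   kernel(d, kernel_width)        exponential smoothing kernel, e.g. sqrt(exp(-d^2 / kernel_width^2)),
#                                  turning distances into proximity weights
#   my_theme(legend.position)      a small ggplot2 theme wrapper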
## Creating dataset ###########################################################
library("dplyr")
library("ggplot2")
# Define range of set
lower_x1 = -2
upper_x1 = 2
lower_x2 = -2
upper_x2 = 1
# Size of the training set for the black box classifier
n_training = 20000
# Size for the grid to plot the decision boundaries
n_grid = 100
# Number of samples for LIME explanations
n_sample = 500
# Simulate y ~ x1 + x2
set.seed(1)
x1 = runif(n_training, min = lower_x1, max = upper_x1)
x2 = runif(n_training, min = lower_x2, max = upper_x2)
y = get_y(x1, x2)
# Add noise
y_noisy = get_y(x1, x2, noise_prob = 0.01)
lime_training_df = data.frame(x1=x1, x2=x2, y=as.factor(y), y_noisy=as.factor(y_noisy))
# For scaling later on
x_means = c(mean(x1), mean(x2))
x_sd = c(sd(x1), sd(x2))
# Learn model
rf = randomForest::randomForest(y_noisy ~ x1 + x2, data = lime_training_df, ntree=100)
lime_training_df$predicted = predict(rf, lime_training_df)
# The decision boundaries
grid_x1 = seq(from=lower_x1, to=upper_x1, length.out=n_grid)
grid_x2 = seq(from=lower_x2, to=upper_x2, length.out=n_grid)
grid_df = expand.grid(x1 = grid_x1, x2 = grid_x2)
grid_df$predicted = as.numeric(as.character(predict(rf, newdata = grid_df)))
# The observation to be explained
explain_x1 = 1
explain_x2 = -0.5
explain_y_model = predict(rf, newdata = data.frame(x1=explain_x1, x2=explain_x2))
df_explain = data.frame(x1=explain_x1, x2=explain_x2, y_predicted=explain_y_model)
point_explain = c(explain_x1, explain_x2)
point_explain_scaled = (point_explain - x_means) / x_sd
# Drawing the samples for the LIME explanations
x1_sample = rnorm(n_sample, x_means[1], x_sd[1])
x2_sample = rnorm(n_sample, x_means[2], x_sd[2])
df_sample = data.frame(x1 = x1_sample, x2 = x2_sample)
# Scale the samples
points_sample = apply(df_sample, 1, function(x){
(x - x_means) / x_sd
}) %>% t
# Add weights to the samples
kernel_width = sqrt(dim(df_sample)[2]) * 0.15
distances = get_distances(point_explain_scaled,
points_sample = points_sample)
df_sample$weights = kernel(distances, kernel_width=kernel_width)
df_sample$predicted = predict(rf, newdata = df_sample)
# Trees
# mod = rpart(predicted ~ x1 + x2, data = df_sample, weights = df_sample$weights)
# grid_df$explained = predict(mod, newdata = grid_df, type='prob')[,2]
# Logistic regression model
mod = glm(predicted ~ x1 + x2, data = df_sample, weights = df_sample$weights, family='binomial')
grid_df$explained = predict(mod, newdata = grid_df, type='response')
# logistic decision boundary
coefs = coefficients(mod)
logistic_boundary_x1 = grid_x1
logistic_boundary_x2 = - (1/coefs['x2']) * (coefs['(Intercept)'] + coefs['x1'] * grid_x1)
logistic_boundary_df = data.frame(x1 = logistic_boundary_x1, x2 = logistic_boundary_x2)
logistic_boundary_df = filter(logistic_boundary_df, x2 <= upper_x2, x2 >= lower_x2)
# Create a smaller grid for visualization of local model boundaries
x1_steps = unique(grid_df$x1)[seq(from=1, to=n_grid, length.out = 20)]
x2_steps = unique(grid_df$x2)[seq(from=1, to=n_grid, length.out = 20)]
grid_df_small = grid_df[grid_df$x1 %in% x1_steps & grid_df$x2 %in% x2_steps,]
grid_df_small$explained_class = round(grid_df_small$explained)
colors = c('#132B43', '#56B1F7')
# Data with some noise
p_data = ggplot(lime_training_df) +
geom_point(aes(x=x1,y=x2,fill=y_noisy, color=y_noisy), alpha =0.3, shape=21) +
scale_fill_manual(values = colors) +
scale_color_manual(values = colors) +
my_theme(legend.position = 'none')
# The decision boundaries of the learned black box classifier
p_boundaries = ggplot(grid_df) +
geom_raster(aes(x=x1,y=x2,fill=predicted), alpha = 0.3, interpolate=TRUE) +
my_theme(legend.position='none') +
ggtitle('A')
# Drawing some samples
p_samples = p_boundaries +
geom_point(data = df_sample, aes(x=x1, y=x2)) +
scale_x_continuous(limits = c(-2, 2)) +
scale_y_continuous(limits = c(-2, 1))
# The point to be explained
p_explain = p_samples +
geom_point(data = df_explain, aes(x=x1,y=x2), fill = 'yellow', shape = 21, size=4) +
ggtitle('B')
p_weighted = p_boundaries +
geom_point(data = df_sample, aes(x=x1, y=x2, size=weights)) +
scale_x_continuous(limits = c(-2, 2)) +
scale_y_continuous(limits = c(-2, 1)) +
geom_point(data = df_explain, aes(x=x1,y=x2), fill = 'yellow', shape = 21, size=4) +
ggtitle('C')
p_boundaries_lime = ggplot(grid_df) +
geom_raster(aes(x=x1,y=x2,fill=predicted), alpha = 0.3, interpolate=TRUE) +
geom_point(aes(x=x1, y=x2, color=explained), size = 2, data = grid_df_small[grid_df_small$explained_class==1,], shape=3) +
geom_point(aes(x=x1, y=x2, color=explained), size = 2, data = grid_df_small[grid_df_small$explained_class==0,], shape=95) +
geom_point(data = df_explain, aes(x=x1,y=x2), fill = 'yellow', shape = 21, size=4) +
geom_line(aes(x=x1, y=x2), data =logistic_boundary_df, color = 'white') +
my_theme(legend.position='none') + ggtitle('D')
gridExtra::grid.arrange(p_boundaries, p_explain, p_weighted, p_boundaries_lime, ncol=2)
As always, the devil is in the detail. Defining a meaningful neighborhood around a point is difficult. LIME currently uses an exponential smoothing kernel to define the neighborhood. A smoothing kernel is a function that takes two data instances and returns a proximity measure. The kernel width determines how large the neighborhood is: A small kernel width means that an instance must be very close to influence the local model, a larger kernel width means that instances that are farther away also influence the model. If you look at LIME’s Python implementation (file lime/lime_tabular.py) you will see that it uses an exponential smoothing kernel (on the normalized data) and the kernel width is 0.75 times the square root of the number of columns of the training data. It looks like an innocent line of code, but it is like an elephant sitting in your living room next to the good porcelain you got from your grandparents. The big problem is that we do not have a good way to find the best kernel or width. And where does the 0.75 even come from? In certain scenarios, you can easily turn your explanation around by changing the kernel width, as shown in the following figure:
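For reference, the proximity weight used by the Python implementation corresponds roughly (based on reading the code, not an official definition) to
\[\pi_{x}(z)=\sqrt{\exp\left(-\frac{D(x,z)^{2}}{\sigma^{2}}\right)},\qquad\sigma=0.75\cdot\sqrt{p},\]
where D is the Euclidean distance on the normalized features and p is the number of features of the training data. The code for the figure: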
set.seed(42)
df = data.frame(x = rnorm(200, mean = 0, sd = 3))
df$x[df$x < -5] = -5
df$y = (df$x + 2)^2
df$y[df$x > 1] = -df$x[df$x > 1] + 10 - 0.05 * df$x[df$x > 1]^2
#df$y = df$y + rnorm(nrow(df), sd = 0.05)
explain.p = data.frame(x = 1.6, y = 8.5)
w1 = kernel(get_distances(data.frame(x = explain.p$x), df), 0.1)
w2 = kernel(get_distances(data.frame(x = explain.p$x), df), 0.75)
w3 = kernel(get_distances(data.frame(x = explain.p$x), df), 2)
lm.1 = lm(y ~ x, data = df, weights = w1)
lm.2 = lm(y ~ x, data = df, weights = w2)
lm.3 = lm(y ~ x, data = df, weights = w3)
df.all = rbind(df, df, df)
df.all$lime = c(predict(lm.1), predict(lm.2), predict(lm.3))
df.all$width = factor(c(rep(c(0.1, 0.75, 2), each = nrow(df))))
ggplot(df.all, aes(x = x, y = y)) +
geom_line(size = 2.5) +
geom_rug(sides = "b") +
geom_line(aes(x = x, y = lime, group = width, color = width, linetype = width)) +
geom_point(data = explain.p, aes(x = x, y = y), size = 12, shape = "x") +
viridis::scale_color_viridis("Kernel width", discrete = TRUE) +
scale_linetype("Kernel width") +
scale_y_continuous("Black Box prediction")
The example shows only one feature. It gets worse in high-dimensional feature spaces. It is also very unclear whether the distance measure should treat all features equally. Is a distance unit for feature x1 identical to one unit for feature x2? Distance measures are quite arbitrary and distances in different dimensions (aka features) might not be comparable at all.
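A small, made-up illustration of why this matters: if one feature lives on a much larger scale than another, it dominates the Euclidean distance and therefore the neighborhood, unless the features are standardized first.
# Toy example (made-up numbers): feature scale dominates the Euclidean distance
p0 = c(x1 = 0,   x2 = 0)    # the point to be explained
a  = c(x1 = 0.1, x2 = 50)   # close in x1, far in x2 (x2 lives on a large scale)
b  = c(x1 = 1.0, x2 = 5)    # farther in x1, close in x2
dist(rbind(p0, a))  # ~50.0: dominated by x2
dist(rbind(p0, b))  # ~5.1
# After dividing each feature by its (hypothetical) standard deviation, the ordering flips
sds = c(x1 = 0.5, x2 = 100)
dist(rbind(p0 / sds, a / sds))  # ~0.54: a is now the closer point
dist(rbind(p0 / sds, b / sds))  # ~2.0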
5.7.1.1 Example
Let us look at a concrete example. We go back to the bike rental data and turn the prediction problem into a classification: After accounting for the trend that bicycle rental has become more popular over time, we want to know on a given day whether the number of rented bicycles will be above or below the trend line. You can also interpret “above” as being above the average number of bicycles, but adjusted for the trend.
data("bike")
ntree = 100
bike.train.resid = factor(resid(lm(cnt ~ days_since_2011, data = bike)) > 0, levels = c(FALSE, TRUE), labels = c('below', 'above'))
bike.train.x = bike[names(bike) != 'cnt']
model <- caret::train(bike.train.x,
bike.train.resid,
method = 'rf', ntree=ntree, maximise = FALSE)
n_features_lime = 2
First we train a random forest with 100 trees on the classification task. On what day will the number of rental bikes be above the trend-free average, based on weather and calendar information?
The explanations are created with 2 features. The following figure shows the results of the sparse local linear models trained for two instances with different predicted classes:
library("iml")
library("gridExtra")
instance_indices = c(295, 8)
set.seed(44)
bike.train.x$temp = round(bike.train.x$temp, 2)
pred = Predictor$new(model, data = bike.train.x, class = "above", type = "prob")
lim1 = LocalModel$new(pred, x.interest = bike.train.x[instance_indices[1],], k = n_features_lime)
lim2 = LocalModel$new(pred, x.interest = bike.train.x[instance_indices[2],], k = n_features_lime)
wlim = c(min(c(lim1$results$effect, lim2$results$effect)), max(c(lim1$results$effect, lim2$results$effect)))
a = plot(lim1) +
scale_y_continuous(limits = wlim) +
geom_hline(aes(yintercept=0)) +
theme(axis.title.y=element_blank(),
axis.ticks.y=element_blank())
b = plot(lim2) +
scale_y_continuous(limits = wlim) +
geom_hline(aes(yintercept=0)) +
theme(axis.title.y=element_blank(),
axis.ticks.y=element_blank())
grid.arrange(a, b, ncol = 1)
From the figure it becomes clear that it is easier to interpret categorical features than numerical features. One solution is to categorize the numerical features into bins.
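For example, the temperature could be grouped into a few intervals before interpreting its effect. A quick sketch; the cut points are arbitrary:
# Bin the temperature into a few (arbitrary) intervals; the local model then gets one
# indicator per bin instead of a single numerical weight for temperature
temp_binned = cut(bike$temp, breaks = c(-Inf, 0, 10, 20, Inf),
                  labels = c("below 0", "0 to 10", "10 to 20", "above 20"))
table(temp_binned)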
5.7.2 LIME for Text
LIME for text differs from LIME for tabular data. Variations of the data are generated differently: Starting from the original text, new texts are created by randomly removing words from the original text. The dataset is represented with binary features for each word. A feature is 1 if the corresponding word is included and 0 if it has been removed.
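A minimal sketch of this sampling step (illustration only; the LIME implementations do this internally):
# Create variations of a sentence by randomly removing words,
# and represent each variation as a binary vector over the original words
text = "For Christmas Song visit my channel! ;)"
words = strsplit(text, " ")[[1]]
set.seed(1)
n_variations = 5
keep = matrix(rbinom(n_variations * length(words), size = 1, prob = 0.7),
              nrow = n_variations)
colnames(keep) = words
# The actual text of each variation
apply(keep, 1, function(row) paste(words[row == 1], collapse = " "))
# Proximity to the original sentence: the share of words that were kept
rowMeans(keep)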
5.7.2.1 Example
In this example we classify YouTube comments as spam or normal.
The black box model is a deep decision tree trained on the document word matrix. Each comment is one document (= one row) and each column is the number of occurrences of a given word. Short decision trees are easy to understand, but in this case the tree is very deep. Instead of this tree, we could also have trained a recurrent neural network or a support vector machine on word embeddings (abstract vectors). Let us look at two comments of this dataset and the corresponding classes (1 for spam, 0 for normal comment):
data("ycomments")
example_indices = c(267, 173)
texts = ycomments$CONTENT[example_indices]
kable(ycomments[example_indices, c('CONTENT', 'CLASS')])
|     | CONTENT                                 | CLASS |
|-----|-----------------------------------------|-------|
| 267 | PSY is a good guy                       | 0     |
| 173 | For Christmas Song visit my channel! ;) | 1     |
The next step is to create some variations of the data that are used to train the local model. For example, here are some variations of one of the comments:
library("tm")
labeledTerms = prepare_data(ycomments$CONTENT)
labeledTerms$class = factor(ycomments$CLASS, levels = c(0,1), labels = c('no spam', 'spam'))
labeledTerms2 = prepare_data(ycomments, trained_corpus = labeledTerms)
rp = rpart::rpart(class ~ ., data = labeledTerms)
predict_fun = get_predict_fun(rp, labeledTerms)
tokenized = tokenize(texts[2])
set.seed(2)
variations = create_variations(texts[2], predict_fun, prob=0.7, n_variations = 5, class='spam')
colnames(variations) = c(tokenized, 'prob', 'weight')
example_sentence = paste(colnames(variations)[variations[2, ] == 1], collapse = ' ')
kable(variations)
|   | For | Christmas | Song | visit | my | channel! | ;) | prob | weight |
|---|-----|-----------|------|-------|----|----------|----|------|--------|
| 2 | 1   | 0         | 1    | 1     | 0  | 0        | 1  | 0.17 | 0.57   |
| 3 | 0   | 1         | 1    | 1     | 1  | 0        | 1  | 0.17 | 0.71   |
| 4 | 1   | 0         | 0    | 1     | 1  | 1        | 1  | 0.99 | 0.71   |
| 5 | 1   | 0         | 1    | 1     | 1  | 1        | 1  | 0.99 | 0.86   |
| 6 | 0   | 1         | 1    | 1     | 0  | 0        | 1  | 0.17 | 0.57   |
Each column corresponds to one word in the sentence. Each row is a variation: 1 means that the word is part of this variation and 0 means that the word has been removed. The corresponding sentence for one of the variations is “Christmas Song visit my ;)”. The “prob” column shows the predicted probability of spam for each of the sentence variations. The “weight” column shows the proximity of the variation to the original sentence, calculated as 1 minus the proportion of words that were removed; for example, if 1 out of 7 words was removed, the proximity is 1 - 1/7 ≈ 0.86.
Here are the two sentences (one spam, one no spam) with their estimated local weights found by the LIME algorithm:
set.seed(42)
ycomments.predict = get.ycomments.classifier(ycomments)
explanations = data.table::rbindlist(lapply(seq_along(texts), function(i) {
explain_text(texts[i], ycomments.predict, class='spam', case=i, prob = 0.5)
})
)
explanations = data.frame(explanations)
kable(explanations[c("case", "label_prob", "feature", "feature_weight")])
| case | label_prob | feature  | feature_weight |
|------|------------|----------|----------------|
| 1    | 0.1701170  | good     | 0.000000       |
| 1    | 0.1701170  | a        | 0.000000       |
| 1    | 0.1701170  | is       | 0.000000       |
| 2    | 0.9939024  | channel! | 6.180747       |
| 2    | 0.9939024  | For      | 0.000000       |
| 2    | 0.9939024  | ;)       | 0.000000       |
The word “channel!” indicates a high probability of spam. For the non-spam comment, no feature received a non-zero weight, because no matter which word is removed, the predicted class remains the same.
5.7.3 LIME for Images
This section was written by Verena Haunschmid.
LIME for images works differently than LIME for tabular data and text. Intuitively, it would not make much sense to perturb individual pixels, since far more than one pixel contributes to each class. Randomly changing individual pixels would probably not change the predictions by much. Therefore, variations of the images are created by segmenting the image into “superpixels” and turning superpixels off or on. Superpixels are interconnected pixels with similar colors and can be turned off by replacing each pixel with a user-defined color such as gray. The user can also specify a probability for turning off a superpixel in each perturbation.
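Here is a rough sketch of that perturbation step. It is an illustration only: it uses a random segmentation and simulated pixel values instead of a real superpixel algorithm such as quickshift.
# Toy illustration: "turn off" a random subset of superpixels by graying them out.
# img: height x width x 3 array with values in [0, 1]; segments: height x width matrix
# of superpixel labels (both simulated here; a real segmentation would be used in practice)
set.seed(1)
h = 32; w = 32
img = array(runif(h * w * 3), dim = c(h, w, 3))
segments = matrix(sample(1:10, h * w, replace = TRUE), nrow = h)
p_off = 0.5                      # probability of turning a superpixel off
superpixels = sort(unique(as.vector(segments)))
off = sample(superpixels, size = round(p_off * length(superpixels)))
perturbed = img
mask = segments %in% off         # pixels belonging to switched-off superpixels
for (channel in 1:3) {
  channel_values = perturbed[, , channel]
  channel_values[mask] = 0.5     # replace with gray
  perturbed[, , channel] = channel_values
}
# The binary on/off vector over superpixels is the interpretable representation for LIME
interpretable_repr = as.integer(!(superpixels %in% off))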
5.7.3.1 Example
Since the computation of image explanations is rather slow, the lime R package contains a precomputed example which we will also use to show the output of the method. The explanations can be displayed directly on the image samples. Since we can have several predicted labels per image (sorted by probability), we can explain the top n_labels.
For the following image the top 3 predictions were electric guitar; acoustic guitar; and Labrador.
# Having trouble installing ImageMagick version 6.8.8 or higher on TravisCI,
# which would be required for this code, so it runs only locally and the
# image was added manually.
# For running locally, set eval = TRUE and make sure lime is installed.
library("lime")
explanation <- .load_image_example()
plot_image_explanation(explanation)
knitr::include_graphics("images/lime-images-package-example-1.png")
The prediction and explanation in the first case are very reasonable. The first prediction of electric guitar is of course wrong, but the explanation shows us that the neural network still behaved reasonably, because the image region it identified suggests that it could be an electric guitar.
5.7.4 Advantages
Even if you replace the underlying machine learning model, you can still use the same local, interpretable model for explanation. Suppose the people looking at the explanations understand decision trees best. Because you use local surrogate models, you can use decision trees as explanations without actually having to use a decision tree to make the predictions. For example, you can use an SVM. And if it turns out that an xgboost model works better, you can replace the SVM and still use a decision tree to explain the predictions.
Local surrogate models benefit from the literature and experience of training and interpreting interpretable models.
When using Lasso or short trees, the resulting explanations are short (= selective) and possibly contrastive. Therefore, they make human-friendly explanations. This is why I see LIME more in applications where the recipient of the explanation is a lay person or someone with very little time. It is not sufficient for complete attributions, so I do not see LIME in compliance scenarios where you might be legally required to fully explain a prediction. Also for debugging machine learning models, it is useful to have all the reasons instead of a few.
LIME is one of the few methods that works for tabular data, text and images.
The fidelity measure (how well the interpretable model approximates the black box predictions) gives us a good idea of how reliable the interpretable model is in explaining the black box predictions in the neighborhood of the data instance of interest.
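One simple way to quantify this fidelity is a proximity-weighted R-squared of the surrogate on the perturbed samples. This is a sketch, reusing the hypothetical objects Z, w and local_model from the recipe sketch above:
# Local fidelity as weighted R-squared of the surrogate model on the perturbed samples
surrogate_pred = predict(local_model, newdata = Z)
ss_res = sum(w * (Z$y_hat - surrogate_pred)^2)
ss_tot = sum(w * (Z$y_hat - weighted.mean(Z$y_hat, w))^2)
local_fidelity = 1 - ss_res / ss_tot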
LIME is implemented in Python (lime and Skater) and R (lime package and iml package) and is very easy to use.
The explanations created with local surrogate models can use other features than the original model. This can be a big advantage over other methods, especially if the original features cannot be interpreted. A text classifier can rely on abstract word embeddings as features, but the explanation can be based on the presence or absence of words in a sentence. A regression model can rely on a non-interpretable transformation of some attributes, but the explanations can be created with the original attributes.
5.7.5 Disadvantages
The correct definition of the neighborhood is a very big, unsolved problem when using LIME with tabular data. In my opinion it is the biggest problem with LIME and the reason why I would recommend using LIME only with great care. For each application you have to try different kernel settings and see for yourself whether the explanations make sense. Unfortunately, this is the best advice I can give for finding good kernel widths.
Sampling could be improved in the current implementation of LIME. Data points are sampled from a Gaussian distribution, ignoring the correlation between features. This can lead to unlikely data points which can then be used to learn local explanation models.
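One possible improvement (not part of the current LIME implementations, just a sketch of the idea) would be to sample from a multivariate normal distribution that respects the empirical covariance of the features, again assuming a numeric feature data frame X:
# Sample perturbed points from a multivariate normal with the empirical mean vector
# and covariance matrix of the training features, so that correlated features stay
# (approximately) correlated in the samples
Z_corr = as.data.frame(MASS::mvrnorm(n = 500, mu = colMeans(X), Sigma = cov(X)))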
The complexity of the explanation model has to be defined in advance. This is just a small complaint, because in the end the user always has to define the compromise between fidelity and sparsity.
Another really big problem is the instability of the explanations. In an article38, the authors showed that the explanations of two very close points varied greatly in a simulated setting. Also, in my experience, if you repeat the sampling process, the explanations that come out can be different. Instability means that it is difficult to trust the explanations, and you should be very critical.
Conclusion: Local surrogate models, with LIME as a concrete implementation, are very promising. But the method is still in the development phase and many problems need to be solved before it can be safely applied.
Ribeiro, Marco Tulio, Sameer Singh, and Carlos Guestrin. “Why should I trust you?: Explaining the predictions of any classifier.” Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining. ACM (2016).↩
Alvarez-Melis, David, and Tommi S. Jaakkola. “On the robustness of interpretability methods.” arXiv preprint arXiv:1806.08049 (2018).↩