The marginaleffects package includes three flexible functions to plot estimates and display interactions: plot_predictions(), plot_comparisons(), and plot_slopes().

Those functions can be used to plot two kinds of quantities:

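  • Conditional estimates: computed on a grid of user-specified predictor values.
  • Marginal estimates: computed for each observed unit in the original dataset, then averaged across one or more categorical predictors.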

To begin, let’s download data and fit a model:

# libraries
library(ggplot2)
library(patchwork) # combine plots with the + and / signs
library(marginaleffects)

# visual theme
theme_set(theme_minimal())
okabeito <- c('#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7', '#999999', '#000000')
options(ggplot2.discrete.fill = okabeito)
options(ggplot2.discrete.colour = okabeito)
options(width = 1000)

# download data
dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv")

mod <- lm(body_mass_g ~ flipper_length_mm * species * bill_length_mm + island, data = dat)

Predictions

Conditional predictions

We call a prediction “conditional” when it is made on a grid of user-specified values. For example, we predict penguins’ body mass for different values of flipper length and species:

pre <- predictions(mod, newdata = datagrid(flipper_length_mm = c(172, 231), species = unique))
pre
#> 
#>  Estimate Std. Error    z Pr(>|z|) 2.5 % 97.5 % bill_length_mm island flipper_length_mm   species
#>      3859        204 18.9   <0.001  3460   4259           43.9 Biscoe               172 Adelie   
#>      2545        369  6.9   <0.001  1822   3268           43.9 Biscoe               172 Gentoo   
#>      3146        234 13.5   <0.001  2688   3604           43.9 Biscoe               172 Chinstrap
#>      4764        362 13.2   <0.001  4054   5474           43.9 Biscoe               231 Adelie   
#>      5597        155 36.0   <0.001  5292   5901           43.9 Biscoe               231 Gentoo   
#>      4086        469  8.7   <0.001  3166   5006           43.9 Biscoe               231 Chinstrap
#> 
#> Columns: rowid, estimate, std.error, statistic, p.value, conf.low, conf.high, body_mass_g, bill_length_mm, island, flipper_length_mm, species
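To make the idea of a grid concrete, here is a minimal base R sketch of the same logic. It is only an illustration: the model fit is a hypothetical stand-in fitted to the built-in mtcars data rather than the penguins model, and we build the grid by hand with expand.grid() instead of datagrid(). As in the output above, a variable that we do not vary is held at a single reference value (here, the mean of wt).

```r
# Hypothetical stand-in model fitted to a built-in dataset
fit <- lm(mpg ~ hp * wt, data = mtcars)

# User-specified values for hp; wt is held at its sample mean
grid <- expand.grid(hp = c(100, 200), wt = mean(mtcars$wt))

# Conditional predictions: one estimate per row of the grid
grid$estimate <- predict(fit, newdata = grid)
grid
```

This sketch returns point estimates only; predictions() also reports standard errors and confidence intervals.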

The condition argument of the plot_predictions() function can be used to build meaningful grids of predictor values somewhat more easily:

plot_predictions(mod, condition = c("flipper_length_mm", "species"))

Note that the values at each end of the x-axis correspond to the numerical results produced above. For example, the predicted outcome for a Gentoo with 231mm flippers is 5597.

We can include a third conditioning variable, specify what values we want to consider, supply R functions to compute summaries, and use one of several string shortcuts for common reference values (“threenum”, “minmax”, “quartile”, etc.):

plot_predictions(
    mod,
    condition = list(
        "flipper_length_mm" = 180:220,
        "bill_length_mm" = "threenum",
        "species" = unique))

See ?plot_predictions for more information.

Marginal predictions

We call a prediction “marginal” when it is the result of a two-step process: (1) make predictions for each observed unit in the original dataset, and (2) average predictions across one or more categorical predictors. For example:

predictions(mod, by = "species")
#> 
#>    species Estimate Std. Error     z Pr(>|z|) 2.5 % 97.5 %
#>  Adelie        3701       27.2 136.1   <0.001  3647   3754
#>  Gentoo        5076       30.1 168.5   <0.001  5017   5135
#>  Chinstrap     3733       40.5  92.2   <0.001  3654   3812
#> 
#> Columns: species, estimate, std.error, statistic, p.value, conf.low, conf.high

We can plot those predictions by using the analogous command:

plot_predictions(mod, by = "species")
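The two-step process behind marginal predictions can be sketched in base R. This is an illustration under stated assumptions, not the package's implementation: fit is a hypothetical stand-in model on the built-in mtcars data, and the sketch computes point estimates only, without standard errors.

```r
# Hypothetical stand-in model fitted to a built-in dataset
fit <- lm(mpg ~ hp * factor(cyl), data = mtcars)

# Step 1: one prediction for every observed row of the original data
unit_predictions <- predict(fit)

# Step 2: average those predictions within each level of a categorical predictor
marginal <- tapply(unit_predictions, mtcars$cyl, mean)
marginal
```

Because this linear model includes cyl as a factor, the averaged predictions coincide with the observed group means of mpg. That equality is a property of least squares with group indicators, and it need not hold for other model types.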

We can also make predictions at the intersections of different variables:

predictions(mod, by = c("species", "island"))
#> 
#>    species    island Estimate Std. Error     z Pr(>|z|) 2.5 % 97.5 %
#>  Adelie    Torgersen     3706       46.8  79.2   <0.001  3615   3798
#>  Adelie    Biscoe        3710       50.4  73.7   <0.001  3611   3808
#>  Adelie    Dream         3688       44.6  82.6   <0.001  3601   3776
#>  Gentoo    Biscoe        5076       30.1 168.5   <0.001  5017   5135
#>  Chinstrap Dream         3733       40.5  92.2   <0.001  3654   3812
#> 
#> Columns: species, island, estimate, std.error, statistic, p.value, conf.low, conf.high

Note that certain species only live on certain islands. Visually:

plot_predictions(mod, by = c("species", "island"))

Comparisons

Conditional comparisons

The syntax for conditional comparisons is the same as the syntax for conditional predictions, except that we now need to specify the variable(s) of interest using an additional argument:

comparisons(mod,
  variables = "flipper_length_mm",
  newdata = datagrid(flipper_length_mm = c(172, 231), species = unique))
#> 
#>               Term Contrast Estimate Std. Error    z Pr(>|z|) 2.5 % 97.5 % bill_length_mm island   species
#>  flipper_length_mm       +1     15.3       9.25 1.66   0.0976 -2.81   33.5           43.9 Biscoe Adelie   
#>  flipper_length_mm       +1     51.7       8.70 5.95   <0.001 34.68   68.8           43.9 Biscoe Gentoo   
#>  flipper_length_mm       +1     15.9      11.37 1.40   0.1610 -6.34   38.2           43.9 Biscoe Chinstrap
#>  flipper_length_mm       +1     15.3       9.25 1.66   0.0976 -2.81   33.5           43.9 Biscoe Adelie   
#>  flipper_length_mm       +1     51.7       8.70 5.95   <0.001 34.68   68.8           43.9 Biscoe Gentoo   
#>  flipper_length_mm       +1     15.9      11.37 1.40   0.1610 -6.34   38.2           43.9 Biscoe Chinstrap
#> 
#> Columns: rowid, term, contrast, estimate, std.error, statistic, p.value, conf.low, conf.high, predicted, predicted_hi, predicted_lo, body_mass_g, bill_length_mm, island, flipper_length_mm, species

plot_comparisons(mod,
  variables = "flipper_length_mm",
  condition = c("bill_length_mm", "species"))

We can specify custom comparisons, as we would using the variables argument of the comparisons() function. For example, see what happens to the predicted outcome when flipper_length_mm increases by 1 standard deviation or by 10mm:

plot_comparisons(mod,
  variables = list("flipper_length_mm" = "sd"),
  condition = c("bill_length_mm", "species")) +

plot_comparisons(mod,
  variables = list("flipper_length_mm" = 10),
  condition = c("bill_length_mm", "species"))

Notice that the vertical scale is different in the plots above, reflecting the fact that we are plotting the effect of a change of 1 standard deviation on the left vs 10 units on the right.
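To see what a comparison for a 10-unit change measures, we can compute one by hand as the difference between two sets of predictions. This is a rough base R sketch on a hypothetical stand-in model (mtcars, not the penguins data). It assumes a forward difference, from the observed value to the observed value plus 10. The exact evaluation points used by marginaleffects may differ, although for a model that is linear in the focal variable the result is the same.

```r
# Hypothetical stand-in model fitted to a built-in dataset
fit <- lm(mpg ~ hp * wt, data = mtcars)

# Two counterfactual copies of the data: hp as observed, and hp + 10
lo <- mtcars
hi <- transform(mtcars, hp = hp + 10)

# Unit-level comparison: change in predicted mpg when hp increases by 10
contrast_10 <- predict(fit, newdata = hi) - predict(fit, newdata = lo)
summary(contrast_10)
```

Because of the hp:wt interaction, the comparison varies across rows with wt, which is the kind of heterogeneity that the condition argument displays.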

Like the comparisons() function, plot_comparisons() is a very powerful tool because it allows us to compute and display custom comparisons such as differences, ratios, odds, and arbitrary functions of predicted outcomes. For example, if we want to plot the ratio of predicted body mass for different species of penguins, we could do:

plot_comparisons(mod,
  variables = "species",
  condition = "bill_length_mm",
  comparison = "ratio")

The left panel shows that the ratio of Chinstrap body mass to Adelie body mass is approximately constant, at slightly above 0.8. The right panel shows that the ratio of Gentoo to Adelie body mass depends on bill length. For birds with short bills, Gentoos seem to have smaller body mass than Adelies. For birds with long bills, Gentoos seem heavier than Adelies, although the null ratio (1) is not outside the confidence interval.
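A ratio comparison of this kind can also be sketched by hand: predict the outcome for two categories over the same grid, then divide. The example below is a base R illustration on a hypothetical stand-in model (mtcars, with cyl playing the role of species and wt the role of bill length); it does not compute confidence intervals.

```r
# Hypothetical stand-in data and model: cyl is converted to a factor up front
mt <- transform(mtcars, cyl = factor(cyl))
fit <- lm(mpg ~ wt * cyl, data = mt)

# Grid of values for the conditioning variable
grid <- data.frame(wt = seq(2, 4, by = 0.5))

# Predictions for two categories at the same grid values
pred_4 <- predict(fit, newdata = transform(grid, cyl = factor("4", levels = levels(mt$cyl))))
pred_8 <- predict(fit, newdata = transform(grid, cyl = factor("8", levels = levels(mt$cyl))))

# Ratio of predicted outcomes (8 cylinders relative to 4 cylinders)
grid$ratio <- pred_8 / pred_4
grid
```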

Marginal comparisons

As above, we can also display marginal comparisons, by subgroups:

plot_comparisons(mod,
  variables = "flipper_length_mm",
  by = "species") +

plot_comparisons(mod,
  variables = "flipper_length_mm",
  by = c("species", "island"))

Multiple contrasts at once:

plot_comparisons(mod,
  variables = c("flipper_length_mm", "bill_length_mm"),
  by = c("species", "island"))

Slopes

If you have read the sections above, the behavior of the plot_slopes() function should not be surprising. Here we give two examples in which we compute and display the elasticity of body mass with respect to bill length:

# conditional
plot_slopes(mod,
  variables = "bill_length_mm",
  slope = "eyex",
  condition = c("species", "island"))


# marginal
plot_slopes(mod,
  variables = "bill_length_mm",
  slope = "eyex",
  by = c("species", "island"))
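For readers unfamiliar with elasticities, the quantity can be approximated by hand. The sketch below assumes the usual definition, (dY/dX) * (X/Y), and uses a central finite difference on a hypothetical stand-in model (mtcars); the step size h is an arbitrary choice for illustration, not the one used by the package.

```r
# Hypothetical stand-in model fitted to a built-in dataset
fit <- lm(mpg ~ hp * wt, data = mtcars)

# Central finite difference for the slope of predicted mpg with respect to wt
h <- 1e-4
lo <- transform(mtcars, wt = wt - h / 2)
hi <- transform(mtcars, wt = wt + h / 2)
dydx <- (predict(fit, newdata = hi) - predict(fit, newdata = lo)) / h

# Elasticity: slope rescaled by the ratio of the predictor to the predicted outcome
eyex <- dydx * mtcars$wt / predict(fit, newdata = mtcars)
summary(eyex)
```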

And here is an example of a marginal effects (aka “slopes” or “partial derivatives”) plot for a model with multiplicative interactions between continuous variables:

mod2 <- lm(mpg ~ wt * qsec * factor(gear), data = mtcars)

plot_slopes(mod2, variables = "qsec", condition = c("wt", "gear"))

Uncertainty estimates

As with all the other functions in the package, the plot_*() functions have a conf_level argument and a vcov argument which can be used to control the size of confidence intervals and the types of standard errors used:

plot_slopes(mod,
  variables = "bill_length_mm", condition = "flipper_length_mm") +
  ylim(c(-150, 200)) +

# clustered standard errors
plot_slopes(mod,
  vcov = ~island,
  variables = "bill_length_mm", condition = "flipper_length_mm") +
  ylim(c(-150, 200)) +

# 80% confidence level
plot_slopes(mod,
  conf_level = .8,
  variables = "bill_length_mm", condition = "flipper_length_mm") +
  ylim(c(-150, 200))
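To build intuition for how conf_level changes the width of an interval, here is a small base R sketch of a normal-approximation (Wald) interval. The helper wald_interval() is hypothetical and written only for this illustration, and the model is a stand-in fitted to mtcars. The sketch assumes intervals of the form estimate ± z × standard error, which is consistent with the z statistics printed in the tables above.

```r
# Hypothetical helper: Wald interval for a given confidence level
wald_interval <- function(estimate, std_error, conf_level = 0.95) {
  z <- qnorm(1 - (1 - conf_level) / 2)
  c(conf.low = estimate - z * std_error, conf.high = estimate + z * std_error)
}

# Hypothetical stand-in model and a single prediction with its standard error
fit <- lm(mpg ~ hp * wt, data = mtcars)
p <- predict(fit, newdata = data.frame(hp = 150, wt = 3), se.fit = TRUE)
est <- as.numeric(p$fit)
se <- as.numeric(p$se.fit)

ci95 <- wald_interval(est, se, conf_level = 0.95)
ci80 <- wald_interval(est, se, conf_level = 0.80)
rbind(ci95, ci80)
```

A lower confidence level gives a narrower interval around the same point estimate, which is why the last panel above shows a tighter band.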

Customization

A very useful feature of the plotting functions in this package is that they produce normal ggplot2 objects. So we can customize them to our heart’s content, using ggplot2 itself, or one of the many packages designed to augment its functionalities:

library(ggrepel)

mt <- mtcars
mt$label <- row.names(mt)

mod <- lm(mpg ~ hp * factor(cyl), data = mt)

plot_predictions(mod, condition = c("hp", "cyl"), points = .5, rug = TRUE, vcov = FALSE) +
    geom_text_repel(aes(x = hp, y = mpg, label = label),
                    data = subset(mt, hp > 250),
                    nudge_y = 2) +
    theme_classic()

All the plotting functions work with all the models supported by the marginaleffects package, so we can plot the output of a logistic regression model. This plot shows the probability of survival aboard the Titanic, for different ages and different ticket classes:

library(ggdist)
library(ggplot2)

dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Stat2Data/Titanic.csv"
dat <- read.csv(dat)

mod <- glm(Survived ~ Age * SexCode * PClass, data = dat, family = binomial)

plot_predictions(mod, condition = c("Age", "PClass")) +
    geom_dots(
        alpha = .8,
        scale = .3,
        pch = 18,
        data = dat, aes(
        x = Age,
        y = Survived,
        side = ifelse(Survived == 1, "bottom", "top")))

Thanks to Andrew Heiss who inspired this plot.

Designing effective data visualizations requires a lot of customization to the specific context and data. The plotting functions in marginaleffects offer a powerful way to iterate quickly between plots and models, but they cannot support all the features that users may want. Thankfully, it is very easy to use the plotting functions to generate datasets that can then be used in ggplot2 or any other data visualization tool. Just set the draw argument to FALSE:

p <- plot_predictions(mod, condition = c("Age", "PClass"), draw = FALSE)
head(p)
#>   rowid  estimate      p.value  conf.low conf.high  Survived   SexCode     Age PClass
#> 1     1 0.8679723 0.0013307148 0.6754794 0.9540527 0.4140212 0.3809524 0.17000    1st
#> 2     2 0.8956789 0.0001333343 0.7401973 0.9627887 0.4140212 0.3809524 0.17000    2nd
#> 3     3 0.4044513 0.2667759617 0.2554245 0.5734603 0.4140212 0.3809524 0.17000    3rd
#> 4     4 0.8631027 0.0011563593 0.6749549 0.9503543 0.4140212 0.3809524 1.61551    1st
#> 5     5 0.8813224 0.0001728862 0.7228529 0.9548415 0.4140212 0.3809524 1.61551    2nd
#> 6     6 0.3934924 0.1899483112 0.2535791 0.5533716 0.4140212 0.3809524 1.61551    3rd

This allows us to feed the data easily to other functions, such as those in the useful ggdist and distributional packages:

library(ggdist)
library(distributional)
plot_slopes(mod, variables = "SexCode", condition = c("Age", "PClass"), type = "link", draw = FALSE) |>
  ggplot() +
  stat_lineribbon(aes(
    x = Age,
    ydist = dist_normal(mu = estimate, sigma = std.error),
    fill = PClass),
    alpha = 1 / 4)

Fits and smooths

We can compare the model predictions with linear fits and smoothers using the geom_smooth() function from the ggplot2 package:

dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Stat2Data/Titanic.csv"
dat <- read.csv(dat)
mod <- glm(Survived ~ Age * PClass, data = dat, family = binomial)

plot_predictions(mod, condition = c("Age", "PClass")) +
    geom_smooth(data = dat, aes(Age, Survived), method = "lm", se = FALSE, color = "black") +
    geom_smooth(data = dat, aes(Age, Survived), se = FALSE, color = "black")

Groups and categorical outcomes

In some models, marginaleffects functions generate different estimates for different groups, such as the levels of a categorical outcome. For example:

library(MASS)
mod <- polr(factor(gear) ~ mpg + hp, data = mtcars)

predictions(mod)
#> 
#>  Group Estimate Std. Error    z Pr(>|z|)   2.5 % 97.5 %
#>      3   0.5316     0.1127 4.72   <0.001  0.3107  0.753
#>      3   0.5316     0.1127 4.72   <0.001  0.3107  0.753
#>      3   0.4492     0.1200 3.74   <0.001  0.2140  0.684
#>      3   0.4944     0.1111 4.45   <0.001  0.2765  0.712
#>      3   0.4213     0.1142 3.69   <0.001  0.1974  0.645
#> --- 86 rows omitted. See ?avg_predictions and ?print.marginaleffects --- 
#>      5   0.6894     0.1956 3.52   <0.001  0.3059  1.073
#>      5   0.1650     0.1290 1.28   0.2010 -0.0879  0.418
#>      5   0.1245     0.0698 1.78   0.0744 -0.0123  0.261
#>      5   0.3779     0.3244 1.17   0.2440 -0.2579  1.014
#>      5   0.0667     0.0458 1.46   0.1455 -0.0231  0.157
#> Columns: rowid, group, estimate, std.error, statistic, p.value, conf.low, conf.high, gear, mpg, hp

We can plot those estimates in the same way as before, by specifying group as one of the conditioning variables, or by adding that column to a facet_wrap() call:

plot_predictions(mod, condition = c("mpg", "group"), type = "probs", vcov = FALSE)


plot_predictions(mod, condition = "mpg", type = "probs", vcov = FALSE) +
  facet_wrap(~ group)
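The one-row-per-unit-and-group layout printed earlier in this section can be approximated by hand, which may help clarify what the group column represents. The sketch below relies only on MASS and base R; the data frame long and its column names are our own choices for illustration, and no standard errors are computed.

```r
library(MASS)

# Same ordinal model as above, fitted to the built-in mtcars data
fit <- polr(factor(gear) ~ mpg + hp, data = mtcars)

# Matrix of predicted probabilities: one row per car, one column per outcome level
probs <- predict(fit, type = "probs")

# Reshape to long format: one row per (car, outcome level) pair
long <- data.frame(
  rowid = rep(seq_len(nrow(probs)), times = ncol(probs)),
  group = rep(colnames(probs), each = nrow(probs)),
  estimate = as.vector(probs)
)
head(long)
```

With 32 cars and 3 outcome levels, long has 96 rows, and the predicted probabilities for each car sum to 1 across groups.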

plot() and marginaleffects objects

Some users may feel inclined to call plot() on an object produced by marginaleffects. Doing so will generate an informative error like this one:

mod <- lm(mpg ~ hp * wt * factor(cyl), data = mtcars)
p <- predictions(mod)
plot(p)
#> Error: Please use the `plot_predictions()` function.

The reason for this error is that the user query is underspecified. marginaleffects allows users to compute so many quantities of interest that it is not clear what the user wants when they simply call plot(). Adding several new arguments would compete with the main plotting functions, and risk sowing confusion. The marginaleffects developers thus decided to support one main path to plotting: plot_predictions(), plot_comparisons(), and plot_slopes().

That said, it may be useful to remind users that all marginaleffects outputs are standard “tidy” data frames. Although they get pretty-printed to the console, all the listed columns are accessible via standard R operators. For example:

p <- avg_predictions(mod, by = "cyl")
p
#> 
#>  cyl Estimate Std. Error    z Pr(>|z|) 2.5 % 97.5 %
#>    6     19.7      0.871 22.7   <0.001  18.0   21.5
#>    4     26.7      0.695 38.4   <0.001  25.3   28.0
#>    8     15.1      0.616 24.5   <0.001  13.9   16.3
#> 
#> Columns: cyl, estimate, std.error, statistic, p.value, conf.low, conf.high

p$estimate
#> [1] 19.74286 26.66364 15.10000

p$std.error
#> [1] 0.8713835 0.6951236 0.6161612

p$conf.low
#> [1] 18.03498 25.30122 13.89235

This allows us to plot all results very easily with standard plotting functions:

plot_predictions(mod, by = "cyl")


plot(p$cyl, p$estimate)


ggplot(p, aes(x = cyl, y = estimate, ymin = conf.low, ymax = conf.high)) +
  geom_pointrange()