The `marginaleffects` package includes three flexible functions to plot estimates and display interactions.

Those functions can be used to plot two kinds of quantities:

- Conditional estimates:
    - Estimates computed on a substantively meaningful grid of predictor values.
    - This is analogous to using the `newdata` argument with the `datagrid()` function in a `predictions()`, `comparisons()`, or `slopes()` call.

- Marginal estimates:
    - Estimates computed on the original data, but averaged by subgroup.
    - This is analogous to using the `by` argument in a `predictions()`, `comparisons()`, or `slopes()` call.

To begin, let’s download data and fit a model:

```
# libraries
library(ggplot2)
library(patchwork) # combine plots with the + and / signs
library(marginaleffects)
# visual theme
theme_set(theme_minimal())
okabeito <- c('#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7', '#999999', '#000000')
options(ggplot2.discrete.fill = okabeito)
options(ggplot2.discrete.colour = okabeito)
options(width = 1000)
# download data
dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv")
mod <- lm(body_mass_g ~ flipper_length_mm * species * bill_length_mm + island, data = dat)
```

## Predictions

### Conditional predictions

We call a prediction “conditional” when it is made on a grid of user-specified values. For example, we predict penguins’ body mass for different values of flipper length and species:

```
pre <- predictions(mod, newdata = datagrid(flipper_length_mm = c(172, 231), species = unique))
pre
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 % bill_length_mm island flipper_length_mm species
#> 3859 204 18.9 <0.001 3460 4259 43.9 Biscoe 172 Adelie
#> 2545 369 6.9 <0.001 1822 3268 43.9 Biscoe 172 Gentoo
#> 3146 234 13.5 <0.001 2688 3604 43.9 Biscoe 172 Chinstrap
#> 4764 362 13.2 <0.001 4054 5474 43.9 Biscoe 231 Adelie
#> 5597 155 36.0 <0.001 5292 5901 43.9 Biscoe 231 Gentoo
#> 4086 469 8.7 <0.001 3166 5006 43.9 Biscoe 231 Chinstrap
#>
#> Columns: rowid, estimate, std.error, statistic, p.value, conf.low, conf.high, body_mass_g, bill_length_mm, island, flipper_length_mm, species
```

The `condition` argument of the `plot_predictions()` function can be used to build meaningful grids of predictor values somewhat more easily:

`plot_predictions(mod, condition = c("flipper_length_mm", "species"))`

Note that the values at each end of the x-axis correspond to the numerical results produced above. For example, the predicted outcome for a Gentoo with 231mm flippers is 5597.

We can include a third conditioning variable, specify what values we want to consider, supply `R` functions to compute summaries, and use one of several string shortcuts for common reference values (“threenum”, “minmax”, “quartile”, etc.):

```
plot_predictions(
    mod,
    condition = list(
        "flipper_length_mm" = 180:220,
        "bill_length_mm" = "threenum",
        "species" = unique))
```

See `?plot_predictions` for more information.
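
As a rough guide to what two of the string shortcuts stand for, the sketch below computes the corresponding summaries in base R. It assumes that `"threenum"` means the mean plus or minus one standard deviation and that `"minmax"` means the observed range, which is how the package documentation describes them. The variable `x` is a stand-in taken from `mtcars`, not part of the penguins example.

```r
# Illustrative only: base R summaries that mirror what the string shortcuts
# are understood to compute (see ?plot_predictions to confirm).
x <- mtcars$hp  # stand-in numeric predictor, not the penguins data

# "threenum": mean minus one SD, mean, mean plus one SD
threenum <- mean(x) + c(-1, 0, 1) * sd(x)

# "minmax": smallest and largest observed values
minmax <- range(x)

print(threenum)
print(minmax)
```

Passing a function such as `unique` in the `condition` list works on the same principle: the function is applied to the observed values of the variable to build the grid.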

### Marginal predictions

We call a prediction “marginal” when it is the result of a two-step process: (1) make predictions for each observed unit in the original dataset, and (2) average predictions across one or more categorical predictors. For example:

```
predictions(mod, by = "species")
#>
#> species Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> Adelie 3701 27.2 136.1 <0.001 3647 3754
#> Gentoo 5076 30.1 168.5 <0.001 5017 5135
#> Chinstrap 3733 40.5 92.2 <0.001 3654 3812
#>
#> Columns: species, estimate, std.error, statistic, p.value, conf.low, conf.high
```

We can plot those predictions by using the analogous command:

`plot_predictions(mod, by = "species")`
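
The two-step process described at the start of this section can also be reproduced by hand, which may help clarify what `by` is doing. This is a minimal sketch in base R on `mtcars` (a stand-in for the penguins data; `dat_mt`, `fit`, and `manual` are illustrative names), not the package's internal implementation:

```r
# Stand-in example on mtcars (not the penguins model from this page)
dat_mt <- transform(mtcars, cyl = factor(cyl))
fit <- lm(mpg ~ hp + cyl, data = dat_mt)

# Step 1: one prediction per observed row of the original data
dat_mt$pred <- predict(fit)

# Step 2: average those predictions within each level of the grouping variable
manual <- aggregate(pred ~ cyl, data = dat_mt, FUN = mean)
print(manual)
```

With `marginaleffects` loaded, `predictions(fit, by = "cyl")` should report the same point estimates, with standard errors and confidence intervals added.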

We can also make predictions at the intersections of different variables:

```
predictions(mod, by = c("species", "island"))
#>
#> species island Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> Adelie Torgersen 3706 46.8 79.2 <0.001 3615 3798
#> Adelie Biscoe 3710 50.4 73.7 <0.001 3611 3808
#> Adelie Dream 3688 44.6 82.6 <0.001 3601 3776
#> Gentoo Biscoe 5076 30.1 168.5 <0.001 5017 5135
#> Chinstrap Dream 3733 40.5 92.2 <0.001 3654 3812
#>
#> Columns: species, island, estimate, std.error, statistic, p.value, conf.low, conf.high
```

Note that certain species only live on certain islands. Visually:

`plot_predictions(mod, by = c("species", "island"))`

## Comparisons

### Conditional comparisons

The syntax for conditional comparisons is the same as the syntax for conditional predictions, except that we now need to specify the variable(s) of interest using an additional argument:

```
comparisons(mod,
    variables = "flipper_length_mm",
    newdata = datagrid(flipper_length_mm = c(172, 231), species = unique))
#>
#> Term Contrast Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 % bill_length_mm island species
#> flipper_length_mm +1 15.3 9.25 1.66 0.0976 -2.81 33.5 43.9 Biscoe Adelie
#> flipper_length_mm +1 51.7 8.70 5.95 <0.001 34.68 68.8 43.9 Biscoe Gentoo
#> flipper_length_mm +1 15.9 11.37 1.40 0.1610 -6.34 38.2 43.9 Biscoe Chinstrap
#> flipper_length_mm +1 15.3 9.25 1.66 0.0976 -2.81 33.5 43.9 Biscoe Adelie
#> flipper_length_mm +1 51.7 8.70 5.95 <0.001 34.68 68.8 43.9 Biscoe Gentoo
#> flipper_length_mm +1 15.9 11.37 1.40 0.1610 -6.34 38.2 43.9 Biscoe Chinstrap
#>
#> Columns: rowid, term, contrast, estimate, std.error, statistic, p.value, conf.low, conf.high, predicted, predicted_hi, predicted_lo, body_mass_g, bill_length_mm, island, flipper_length_mm, species
plot_comparisons(mod,
    variables = "flipper_length_mm",
    condition = c("bill_length_mm", "species"))
```

We can specify custom comparisons, as we would using the `variables` argument of the `comparisons()` function. For example, see what happens to the predicted outcome when `flipper_length_mm` increases by 1 standard deviation or by 10mm:

```
plot_comparisons(mod,
    variables = list("flipper_length_mm" = "sd"),
    condition = c("bill_length_mm", "species")) +
plot_comparisons(mod,
    variables = list("flipper_length_mm" = 10),
    condition = c("bill_length_mm", "species"))
```

Notice that the vertical scale is different in the plots above, reflecting the fact that we are plotting the effect of a change of 1 standard deviation on the left vs 10 units on the right.

Like the `comparisons()` function, `plot_comparisons()` is a very powerful tool because it allows us to compute and display custom comparisons such as differences, ratios, odds, and arbitrary functions of predicted outcomes. For example, if we want to plot the ratio of predicted body mass for different species of penguins, we could do:

```
plot_comparisons(mod,
    variables = "species",
    condition = "bill_length_mm",
    comparison = "ratio")
```

The left panel shows that the ratio of Chinstrap body mass to Adelie body mass is approximately constant, at slightly above 0.8. The right panel shows that the ratio of Gentoo to Adelie body mass depends on bill length. For birds with short bills, Gentoos seem to have smaller body mass than Adelies. For birds with long bills, Gentoos seem heavier than Adelies, although the null ratio (1) is not outside the confidence interval.
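
To make the ratio concrete, here is a hand-rolled sketch of the kind of quantity being plotted: predictions for two groups at the same values of the conditioning variable, divided one by the other. It uses `mtcars` as a stand-in with illustrative names (`dat_mt`, `fit`, `hp_grid`, `res`). It is not how `marginaleffects` computes comparisons internally, and it omits standard errors.

```r
# Stand-in model on mtcars; object and variable names are illustrative
dat_mt <- transform(mtcars, cyl = factor(cyl))
fit <- lm(mpg ~ hp * cyl, data = dat_mt)

# Evaluate both groups at the same values of the conditioning variable
hp_grid <- c(80, 110, 140)
lo <- data.frame(hp = hp_grid, cyl = factor("4", levels = levels(dat_mt$cyl)))
hi <- data.frame(hp = hp_grid, cyl = factor("6", levels = levels(dat_mt$cyl)))

pred_lo <- predict(fit, newdata = lo)
pred_hi <- predict(fit, newdata = hi)

# A ratio comparison divides the two predictions; a difference subtracts them
res <- data.frame(
    hp = hp_grid,
    ratio = pred_hi / pred_lo,
    difference = pred_hi - pred_lo)
print(res)
```

A ratio of 1 (or a difference of 0) means the two groups have the same predicted outcome at that value of the conditioning variable, which is why 1 is the null value to look for in the ratio plot.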

### Marginal comparisons

As above, we can also display marginal comparisons, by subgroups:

```
plot_comparisons(mod,
    variables = "flipper_length_mm",
    by = "species") +
plot_comparisons(mod,
    variables = "flipper_length_mm",
    by = c("species", "island"))
```

Multiple contrasts at once:

```
plot_comparisons(mod,
    variables = c("flipper_length_mm", "bill_length_mm"),
    by = c("species", "island"))
```

## Slopes

If you have read the sections above, the behavior of the `plot_slopes()` function should not be surprising. Here we give two examples in which we compute and display the elasticity of body mass with respect to bill length:

```
# conditional
plot_slopes(mod,
    variables = "bill_length_mm",
    slope = "eyex",
    condition = c("species", "island"))
```

```
# marginal
plot_slopes(mod,
    variables = "bill_length_mm",
    slope = "eyex",
    by = c("species", "island"))
```
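
For readers unfamiliar with the term, an elasticity is conventionally defined as the slope rescaled by the ratio of predictor to outcome, `(dY/dX) * (X/Y)`: roughly, the percent change in the outcome for a one percent change in the predictor. The sketch below approximates that quantity with a finite difference in base R. It assumes this conventional definition is what `slope = "eyex"` reports (see `?slopes` to confirm), and uses `mtcars` with illustrative names rather than the penguins model.

```r
# Stand-in model on mtcars; all names below are illustrative
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Finite-difference approximation of the slope dY/dX for each observation
eps <- 1e-4
nd_lo <- mtcars
nd_hi <- transform(mtcars, wt = wt + eps)
y_hat <- predict(fit, newdata = nd_lo)
dydx <- (predict(fit, newdata = nd_hi) - y_hat) / eps

# Elasticity under the conventional definition: slope times X / Y
eyex <- dydx * mtcars$wt / y_hat
print(summary(eyex))
```

In this additive linear model the slope is the same for every observation, but the elasticity still varies from row to row because it depends on the observed `wt` and the predicted outcome.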

And here is an example of a marginal effects (aka “slopes” or “partial derivatives”) plot for a model with multiplicative interactions between continuous variables:

```
mod2 <- lm(mpg ~ wt * qsec * factor(gear), data = mtcars)
plot_slopes(mod2, variables = "qsec", condition = c("wt", "gear"))
```

## Uncertainty estimates

As with all the other functions in the package, the `plot_*()` functions have a `conf_level` argument and a `vcov` argument which can be used to control the size of confidence intervals and the types of standard errors used:

```
plot_slopes(mod,
    variables = "bill_length_mm", condition = "flipper_length_mm") +
    ylim(c(-150, 200)) +

# clustered standard errors
plot_slopes(mod,
    vcov = ~island,
    variables = "bill_length_mm", condition = "flipper_length_mm") +
    ylim(c(-150, 200)) +

# 80% confidence intervals
plot_slopes(mod,
    conf_level = .8,
    variables = "bill_length_mm", condition = "flipper_length_mm") +
    ylim(c(-150, 200))
```
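
The role of `conf_level` can be illustrated with a small base R sketch. It assumes a normal-approximation (Wald) interval, which is consistent with the z statistics printed in the tables on this page. `wald_ci()` is a hypothetical helper, not a `marginaleffects` function, and the inputs are copied from the first row of the `avg_predictions()` output shown in the last section of this page.

```r
# Illustrative inputs: first row of the avg_predictions() output
# printed in the last section of this page
estimate <- 19.74286
std_error <- 0.8713835

# Hypothetical helper: symmetric interval based on a normal quantile
wald_ci <- function(estimate, std_error, conf_level = 0.95) {
    z <- qnorm(1 - (1 - conf_level) / 2)
    c(conf.low = estimate - z * std_error,
      conf.high = estimate + z * std_error)
}

ci95 <- wald_ci(estimate, std_error)
ci80 <- wald_ci(estimate, std_error, conf_level = 0.8)
print(ci95)
print(ci80)
```

At the default level the lower bound comes out at about 18.03, in line with the `conf.low` value printed for that row, and lowering `conf_level` to 0.8 gives a narrower interval.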

## Customization

A very useful feature of the plotting functions in this package is that they produce normal `ggplot2` objects. So we can customize them to our heart’s content, using `ggplot2` itself, or one of the many packages designed to augment its functionalities:

```
library(ggrepel)
mt <- mtcars
mt$label <- row.names(mt)
mod <- lm(mpg ~ hp * factor(cyl), data = mt)
plot_predictions(mod, condition = c("hp", "cyl"), points = .5, rug = TRUE, vcov = FALSE) +
    geom_text_repel(aes(x = hp, y = mpg, label = label),
        data = subset(mt, hp > 250),
        nudge_y = 2) +
    theme_classic()
```

The plotting functions work with all the models supported by the `marginaleffects` package, so we can, for example, plot the output of a logistic regression model. This plot shows the probability of survival aboard the Titanic, for different ages and different ticket classes:

```
library(ggdist)
library(ggplot2)
dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Stat2Data/Titanic.csv"
dat <- read.csv(dat)
mod <- glm(Survived ~ Age * SexCode * PClass, data = dat, family = binomial)
plot_predictions(mod, condition = c("Age", "PClass")) +
    geom_dots(
        alpha = .8,
        scale = .3,
        pch = 18,
        data = dat,
        aes(
            x = Age,
            y = Survived,
            side = ifelse(Survived == 1, "bottom", "top")))
```

Thanks to Andrew Heiss who inspired this plot.

Designing effective data visualizations requires a lot of customization to the specific context and data. The plotting functions in `marginaleffects` offer a powerful way to iterate quickly between plots and models, but they cannot support all the features that users may want. Thankfully, it is very easy to use the `plot_*()` functions to generate datasets that can then be used in `ggplot2` or any other data visualization tool. Just set the `draw` argument to `FALSE`:

```
p <- plot_predictions(mod, condition = c("Age", "PClass"), draw = FALSE)
head(p)
#> rowid estimate p.value conf.low conf.high Survived SexCode Age PClass
#> 1 1 0.8679723 0.0013307148 0.6754794 0.9540527 0.4140212 0.3809524 0.17000 1st
#> 2 2 0.8956789 0.0001333343 0.7401973 0.9627887 0.4140212 0.3809524 0.17000 2nd
#> 3 3 0.4044513 0.2667759617 0.2554245 0.5734603 0.4140212 0.3809524 0.17000 3rd
#> 4 4 0.8631027 0.0011563593 0.6749549 0.9503543 0.4140212 0.3809524 1.61551 1st
#> 5 5 0.8813224 0.0001728862 0.7228529 0.9548415 0.4140212 0.3809524 1.61551 2nd
#> 6 6 0.3934924 0.1899483112 0.2535791 0.5533716 0.4140212 0.3809524 1.61551 3rd
```

This allows us to feed the data easily to other functions, such as those in the useful `ggdist` and `distributional` packages:

```
library(ggdist)
library(distributional)
plot_slopes(mod, variables = "SexCode", condition = c("Age", "PClass"), type = "link", draw = FALSE) |>
    ggplot() +
    stat_lineribbon(aes(
        x = Age,
        ydist = dist_normal(mu = estimate, sigma = std.error),
        fill = PClass),
        alpha = 1 / 4)
```

## Fits and smooths

We can compare the model predictions with fits and smoothers using the `geom_smooth()` function from the `ggplot2` package:

```
dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Stat2Data/Titanic.csv"
dat <- read.csv(dat)
mod <- glm(Survived ~ Age * PClass, data = dat, family = binomial)
plot_predictions(mod, condition = c("Age", "PClass")) +
geom_smooth(data = dat, aes(Age, Survived), method = "lm", se = FALSE, color = "black") +
geom_smooth(data = dat, aes(Age, Survived), se = FALSE, color = "black")
```

## Groups and categorical outcomes

In some models, `marginaleffects` functions generate different estimates for different groups, such as categorical outcomes. For example:

```
library(MASS)
mod <- polr(factor(gear) ~ mpg + hp, data = mtcars)
predictions(mod)
#>
#> Group Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 3 0.5316 0.1127 4.72 <0.001 0.3107 0.753
#> 3 0.5316 0.1127 4.72 <0.001 0.3107 0.753
#> 3 0.4492 0.1200 3.74 <0.001 0.2140 0.684
#> 3 0.4944 0.1111 4.45 <0.001 0.2765 0.712
#> 3 0.4213 0.1142 3.69 <0.001 0.1974 0.645
#> --- 86 rows omitted. See ?avg_predictions and ?print.marginaleffects ---
#> 5 0.6894 0.1956 3.52 <0.001 0.3059 1.073
#> 5 0.1650 0.1290 1.28 0.2010 -0.0879 0.418
#> 5 0.1245 0.0698 1.78 0.0744 -0.0123 0.261
#> 5 0.3779 0.3244 1.17 0.2440 -0.2579 1.014
#> 5 0.0667 0.0458 1.46 0.1455 -0.0231 0.157
#> Columns: rowid, group, estimate, std.error, statistic, p.value, conf.low, conf.high, gear, mpg, hp
```

We can plot those estimates in the same way as before, by specifying `group` as one of the conditioning variables, or by adding that column to a `facet_wrap()` call:

`plot_predictions(mod, condition = c("mpg", "group"), type = "probs", vcov = FALSE)`

```
plot_predictions(mod, condition = "mpg", type = "probs", vcov = FALSE) +
    facet_wrap(~ group)
```

## `plot()` and `marginaleffects` objects

Some users may feel inclined to call `plot()` on an object produced by `marginaleffects`. Doing so will generate an informative error like this one:

```
mod <- lm(mpg ~ hp * wt * factor(cyl), data = mtcars)
p <- predictions(mod)
plot(p)
#> Error: Please use the `plot_predictions()` function.
```

The reason for this error is that the user query is underspecified. `marginaleffects` allows users to compute so many quantities of interest that it is not clear what the user wants when they simply call `plot()`. Adding several new arguments would compete with the main plotting functions, and risk sowing confusion. The `marginaleffects` developers thus decided to support one main path to plotting: `plot_predictions()`, `plot_comparisons()`, and `plot_slopes()`.

That said, it may be useful to remind users that all `marginaleffects` outputs are standard “tidy” data frames. Although they get pretty-printed to the console, all the listed columns are accessible via standard `R` operators. For example:

```
p <- avg_predictions(mod, by = "cyl")
p
#>
#> cyl Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 6 19.7 0.871 22.7 <0.001 18.0 21.5
#> 4 26.7 0.695 38.4 <0.001 25.3 28.0
#> 8 15.1 0.616 24.5 <0.001 13.9 16.3
#>
#> Columns: cyl, estimate, std.error, statistic, p.value, conf.low, conf.high
p$estimate
#> [1] 19.74286 26.66364 15.10000
p$std.error
#> [1] 0.8713835 0.6951236 0.6161612
p$conf.low
#> [1] 18.03498 25.30122 13.89235
```

This allows us to plot all results very easily with standard plotting functions:

`plot_predictions(mod, by = "cyl")`

```
plot(p$cyl, p$estimate)
```

```
ggplot(p, aes(x = cyl, y = estimate, ymin = conf.low, ymax = conf.high)) +
    geom_pointrange()
```