The marginaleffects package offers convenience functions to compute and display predictions, contrasts, and marginal effects from Bayesian models estimated by the brms package. To compute these quantities, marginaleffects relies on workhorse functions from the brms package to draw from the posterior distribution. The type of draws is controlled by the type argument of the predictions or marginaleffects functions:

  • type = "response": Compute posterior draws of the expected value using the brms::posterior_epred function.
  • type = "link": Compute posterior draws of the linear predictor using the brms::posterior_linpred function.
  • type = "prediction": Compute posterior draws of the posterior predictive distribution using the brms::posterior_predict function.
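
The difference between the three types can be sketched in base R for a Bernoulli model with a logit link. This is only an illustration: the "posterior draws" of the coefficients below are simulated, hypothetical values rather than output from brms, and the names b0, b1, and x are made up for the example.

```r
# Toy illustration in base R (no brms required). The coefficient draws are
# simulated and hypothetical; they do not come from a fitted model.
set.seed(1)
n_draws <- 1000
b0 <- rnorm(n_draws, mean = -1.0, sd = 0.10)  # hypothetical intercept draws
b1 <- rnorm(n_draws, mean =  0.5, sd = 0.05)  # hypothetical slope draws
x  <- 2                                       # a single covariate value

link       <- b0 + b1 * x                     # linear predictor (log-odds)
response   <- plogis(link)                    # expected value (a probability)
prediction <- rbinom(n_draws, size = 1, prob = response)  # simulated 0/1 outcomes

range(link)
range(response)
table(prediction)
```

The link draws are unbounded, the response draws lie between 0 and 1, and the prediction draws are individual 0/1 outcomes that also carry observation-level noise.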

The predictions and marginaleffects functions can also pass additional arguments to the brms prediction functions via the ... ellipsis. For example, if mod is a mixed-effects model, then this command will compute 10 draws from the posterior predictive distribution, while ignoring all group-level effects:

predictions(mod, type = "prediction", ndraws = 10, re_formula = NA)

See the brms documentation (for example, ?brms::posterior_epred, ?brms::posterior_linpred, and ?brms::posterior_predict) for a list of available arguments.

Logistic regression with multiplicative interactions

Load libraries and download data on passengers of the Titanic from the Rdatasets archive:

library(marginaleffects)
library(brms)
library(ggplot2)
library(ggdist)
library(magrittr)

dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/carData/TitanicSurvival.csv")
dat$survived <- ifelse(dat$survived == "yes", 1, 0)
dat$woman <- ifelse(dat$sex == "female", 1, 0)

Fit a logit model with a multiplicative interaction:

mod <- brm(survived ~ woman * age + passengerClass,
           family = bernoulli(link = "logit"),
           data = dat)

Adjusted predictions

We can compute adjusted predicted values of the outcome variable (i.e., probability of survival aboard the Titanic) using the predictions function. By default, this function calculates predictions for each row of the dataset:

pred <- predictions(mod)
head(pred)
#>   rowid     type predicted  conf.low conf.high survived woman     age
#> 1     1 response 0.9366604 0.9069674 0.9590097        1     1 29.0000
#> 2     2 response 0.8493050 0.7453010 0.9186720        1     0  0.9167
#> 3     3 response 0.9433293 0.8948592 0.9704210        0     1  2.0000
#> 4     4 response 0.5131011 0.4302430 0.5999582        0     0 30.0000
#> 5     5 response 0.9374937 0.9080051 0.9600572        0     1 25.0000
#> 6     6 response 0.2730542 0.2028999 0.3517513        1     0 48.0000
#>   passengerClass
#> 1            1st
#> 2            1st
#> 3            1st
#> 4            1st
#> 5            1st
#> 6            1st

To visualize the relationship between the outcome and one of the regressors, we can plot conditional adjusted predictions with the plot_cap function:

plot_cap(mod, condition = "age")

Compute adjusted predictions for some user-specified values of the regressors, using the newdata argument and the datagrid function:

pred <- predictions(mod,
                    newdata = datagrid(woman = 0:1,
                                       passengerClass = c("1st", "2nd", "3rd")))
pred
#>   rowid     type  predicted   conf.low conf.high      age woman passengerClass
#> 1     1 response 0.51492993 0.43192231 0.6018749 29.88113     0            1st
#> 2     2 response 0.20128833 0.15362308 0.2613351 29.88113     0            2nd
#> 3     3 response 0.08750369 0.06555724 0.1141134 29.88113     0            3rd
#> 4     4 response 0.93641346 0.90660921 0.9587589 29.88113     1            1st
#> 5     5 response 0.77829290 0.70896643 0.8346419 29.88113     1            2nd
#> 6     6 response 0.57010265 0.49377997 0.6441967 29.88113     1            3rd
#>   survived
#> 1        1
#> 2        1
#> 3        1
#> 4        1
#> 5        1
#> 6        1

The posteriordraws function samples from the posterior distribution of the model, and produces a data frame with drawid and draw columns.

pred <- posteriordraws(pred)
head(pred)
#>   drawid       draw rowid     type  predicted   conf.low conf.high      age
#> 1      1 0.46566713     1 response 0.51492993 0.43192231 0.6018749 29.88113
#> 2      1 0.16658900     2 response 0.20128833 0.15362308 0.2613351 29.88113
#> 3      1 0.08750961     3 response 0.08750369 0.06555724 0.1141134 29.88113
#> 4      1 0.93735755     4 response 0.93641346 0.90660921 0.9587589 29.88113
#> 5      1 0.77437334     5 response 0.77829290 0.70896643 0.8346419 29.88113
#> 6      1 0.62216334     6 response 0.57010265 0.49377997 0.6441967 29.88113
#>   woman passengerClass survived
#> 1     0            1st        1
#> 2     0            2nd        1
#> 3     0            3rd        1
#> 4     1            1st        1
#> 5     1            2nd        1
#> 6     1            3rd        1

This “long” format makes it easy to plot results:

ggplot(pred, aes(x = draw, fill = factor(woman))) +
    geom_density() +
    facet_grid(~ passengerClass, labeller = label_both) +
    labs(x = "Predicted probability of survival", y = "", fill = "Woman")
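
To make the “long” layout concrete, here is a minimal base R sketch that reshapes a matrix of draws (one row per prediction, one column per posterior draw) into a data frame with drawid, rowid, and draw columns. The matrix is filled with random numbers and is purely hypothetical; this is not how posteriordraws is implemented internally.

```r
# Hypothetical draws matrix: one row per prediction, one column per posterior draw.
set.seed(1)
draws_matrix <- matrix(runif(6 * 4), nrow = 6, ncol = 4)

long <- data.frame(
  drawid = rep(seq_len(ncol(draws_matrix)), each  = nrow(draws_matrix)),
  rowid  = rep(seq_len(nrow(draws_matrix)), times = ncol(draws_matrix)),
  # as.vector() unrolls the matrix column by column:
  # all rows of draw 1 first, then all rows of draw 2, and so on.
  draw   = as.vector(draws_matrix)
)

dim(long)
head(long)
```

Each row of the result is one draw for one prediction, which is the shape that ggplot2 expects when mapping draw to an axis.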

Marginal effects

Use marginaleffects() to compute marginal effects (slopes of the regression equation) for each row of the dataset, and use summary() to compute “Average Marginal Effects”, that is, the average of all observation-level marginal effects:

mfx <- marginaleffects(mod)
summary(mfx)
#>             Term  Contrast    Effect     2.5 %    97.5 %
#> 1            age     dY/dX -0.005265 -0.007104 -0.003462
#> 2 passengerClass 2nd - 1st -0.237269 -0.309664 -0.164460
#> 3 passengerClass 3rd - 1st -0.389119 -0.454761 -0.322277
#> 4          woman     dY/dX  0.366252  0.336326  0.392362
#> 
#> Model type:  brmsfit 
#> Prediction type:  response
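
The logic behind an average marginal effect can be sketched in base R: compute a numerical slope for every observation and every posterior draw, average over observations within each draw, then summarize the resulting draws. Everything below (the covariate x, the coefficient draws b0 and b1, the step size eps) is simulated and hypothetical, and the sketch is not the package's actual implementation.

```r
set.seed(1)
n <- 200
x <- rnorm(n)                        # hypothetical covariate
n_draws <- 500
b0 <- rnorm(n_draws, 0.2, 0.1)       # hypothetical posterior draws of the intercept
b1 <- rnorm(n_draws, 0.8, 0.1)       # hypothetical posterior draws of the slope

# Predicted probabilities: rows are observations, columns are posterior draws.
predict_prob <- function(x) plogis(sweep(outer(x, b1), 2, b0, "+"))

eps <- 1e-4
slopes <- (predict_prob(x + eps) - predict_prob(x)) / eps  # n x n_draws matrix

ame_draws <- colMeans(slopes)        # one average marginal effect per draw
median(ame_draws)
quantile(ame_draws, probs = c(0.025, 0.975))
```

Averaging within each draw before summarizing is what turns observation-level slopes into a posterior distribution for the average effect.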

Compute marginal effects with some regressors fixed at user-specified values, and other regressors held at their means:

marginaleffects(mod,
                newdata = datagrid(woman = 1,
                                   passengerClass = "1st"))
#>   rowid     type           term  contrast          dydx     conf.low
#> 1     1 response          woman     dY/dX  0.1562784430  0.111359140
#> 2     1 response            age     dY/dX -0.0002383118 -0.001355295
#> 3     1 response passengerClass 2nd - 1st -0.1574415687 -0.223274987
#> 4     1 response passengerClass 3rd - 1st -0.3653763624 -0.438319255
#>       conf.high predicted predicted_hi predicted_lo      age woman
#> 1  0.2087557243 0.9364135    0.9364295    0.9364135 29.88113     1
#> 2  0.0008710313 0.9364135    0.9364102    0.9364135 29.88113     1
#> 3 -0.1028896556 0.9364135    0.7782929    0.9364135 29.88113     1
#> 4 -0.2947694986 0.9364135    0.5701027    0.9364135 29.88113     1
#>   passengerClass survived        eps
#> 1            1st        1 0.00010000
#> 2            1st        1 0.00798333
#> 3            1st        1         NA
#> 4            1st        1         NA

Compute and plot conditional marginal effects:

plot_cme(mod, effect = "woman", condition = "age")

The posteriordraws function produces a data frame with drawid and draw columns:

draws <- posteriordraws(mfx)

dim(draws)
#> [1] 16736000       17

head(draws)
#>   drawid      draw rowid     type  term contrast       dydx   conf.low
#> 1      1 0.1633609     1 response woman    dY/dX 0.15298096 0.10883007
#> 2      1 0.1359009     2 response woman    dY/dX 0.13546458 0.03607654
#> 3      1 0.0615405     3 response woman    dY/dX 0.05878909 0.02875607
#> 4      1 0.7088821     4 response woman    dY/dX 0.65453708 0.56699300
#> 5      1 0.1473759     5 response woman    dY/dX 0.13828737 0.09722433
#> 6      1 0.6635099     6 response woman    dY/dX 0.70740470 0.58935737
#>   conf.high predicted predicted_hi predicted_lo survived woman     age
#> 1 0.2051898 0.9366604    0.9366751    0.9366604        1     1 29.0000
#> 2 0.2970409 0.8493050    0.8493174    0.8493050        1     0  0.9167
#> 3 0.1009453 0.9433293    0.9433358    0.9433293        0     1  2.0000
#> 4 0.7442501 0.5131011    0.5131652    0.5131011        0     0 30.0000
#> 5 0.1863969 0.9374937    0.9375085    0.9374937        0     1 25.0000
#> 6 0.8406470 0.2730542    0.2731252    0.2730542        1     0 48.0000
#>   passengerClass   eps
#> 1            1st 1e-04
#> 2            1st 1e-04
#> 3            1st 1e-04
#> 4            1st 1e-04
#> 5            1st 1e-04
#> 6            1st 1e-04

We can use this dataset to plot our results. For example, to plot the posterior density of the marginal effect of age when the woman variable is equal to 0 or 1:

mfx <- marginaleffects(mod,
                       variables = "age",
                       newdata = datagrid(woman = 0:1)) |>
       posteriordraws()

ggplot(mfx, aes(x = draw, fill = factor(woman))) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Marginal Effect of Age on Survival",
         y = "Posterior density",
         fill = "Woman")

Random effects model

This section replicates some of the analyses of a random effects model published in Andrew Heiss’ blog post: “A guide to correctly calculating posterior predictions and average marginal effects with multilevel Bayesian models.” The objective is mainly to illustrate the use of marginaleffects. Please refer to the original post for a detailed discussion of the quantities computed below.

Load libraries and download data:

library(brms)
library(ggdist)
library(patchwork)
library(marginaleffects)

vdem_2015 <- read.csv("https://github.com/vincentarelbundock/marginaleffects/raw/main/data-raw/vdem_2015.csv")

head(vdem_2015)
#>   country_name country_text_id year                           region
#> 1       Mexico             MEX 2015  Latin America and the Caribbean
#> 2     Suriname             SUR 2015  Latin America and the Caribbean
#> 3       Sweden             SWE 2015 Western Europe and North America
#> 4  Switzerland             CHE 2015 Western Europe and North America
#> 5        Ghana             GHA 2015               Sub-Saharan Africa
#> 6 South Africa             ZAF 2015               Sub-Saharan Africa
#>   media_index party_autonomy_ord polyarchy civil_liberties party_autonomy
#> 1       0.837                  3     0.631           0.704           TRUE
#> 2       0.883                  4     0.777           0.887           TRUE
#> 3       0.956                  4     0.915           0.968           TRUE
#> 4       0.939                  4     0.901           0.960           TRUE
#> 5       0.858                  4     0.724           0.921           TRUE
#> 6       0.898                  4     0.752           0.869           TRUE

Fit a basic model:

mod <- brm(
  bf(media_index ~ party_autonomy + civil_liberties + (1 | region),
     phi ~ (1 | region)),
  data = vdem_2015,
  family = Beta(),
  control = list(adapt_delta = 0.9))

Posterior predictions

To compute posterior predictions for specific values of the regressors, we use the newdata argument and the datagrid function. We also use the type argument to compute two types of predictions: accounting for observation-level (residual) variance (prediction) or ignoring it (response).

nd <- datagrid(model = mod,
               party_autonomy = c(TRUE, FALSE),
               civil_liberties = .5,
               region = "Middle East and North Africa")
p1 <- predictions(mod, type = "response", newdata = nd) %>%
    posteriordraws()
p2 <- predictions(mod, type = "prediction", newdata = nd) %>%
    posteriordraws()
pred <- rbind(p1, p2)

Plot the posterior draws:

ggplot(pred, aes(x = draw, fill = party_autonomy)) +
    stat_halfeye(alpha = .5) +
    facet_wrap(~ type) +
    labs(x = "Media index (predicted)", 
         y = "Posterior density",
         fill = "Party autonomy")
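
Why the "prediction" panel is wider than the "response" panel can be sketched in base R. The sketch assumes the mean-precision parameterization of the Beta distribution (shape parameters mu * phi and (1 - mu) * phi); the draws of mu and phi are simulated, hypothetical values, not output from the model above.

```r
set.seed(1)
n_draws <- 2000
mu  <- plogis(rnorm(n_draws, 0.5, 0.1))        # hypothetical draws of the expected value
phi <- rlnorm(n_draws, log(10), 0.1)           # hypothetical draws of the precision

response   <- mu                                        # expected value only
prediction <- rbeta(n_draws, mu * phi, (1 - mu) * phi)  # adds observation-level noise

sd(response)
sd(prediction)
```

The response draws only reflect uncertainty about the expected value, while the prediction draws add the spread of individual outcomes around that expectation.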

Marginal effects and contrasts

As noted in the Marginal Effects vignette, there should be one distinct marginal effect for each combination of regressor values. Here, we consider only one combination of regressor values, where region is “Middle East and North Africa”, and civil_liberties is 0.5. Then, we calculate the median of the posterior distribution of marginal effects:

mfx <- marginaleffects(mod,
                       newdata = datagrid(civil_liberties = .5,
                                          region = "Middle East and North Africa"))
mfx
#>   rowid     type            term     contrast      dydx  conf.low conf.high
#> 1     1 response  party_autonomy TRUE - FALSE 0.2517104 0.1663314 0.3355927
#> 2     1 response civil_liberties        dY/dX 0.8160498 0.6213970 1.0066362
#>   predicted predicted_hi predicted_lo party_autonomy civil_liberties
#> 1 0.6213442    0.6213442    0.3684876           TRUE             0.5
#> 2 0.6213442    0.6214182    0.6213442           TRUE             0.5
#>                         region media_index      eps
#> 1 Middle East and North Africa       0.837       NA
#> 2 Middle East and North Africa       0.837 9.56e-05

Use posteriordraws() to extract draws from the posterior distribution of marginal effects, and plot them:

mfx <- posteriordraws(mfx)

ggplot(mfx, aes(x = draw, y = term)) +
  stat_halfeye() +
  labs(x = "Marginal effect", y = "")

Plot marginal effects, conditional on a regressor:

plot_cme(mod,
         effect = "civil_liberties",
         condition = "party_autonomy")

Continuous predictors

pred <- predictions(mod,
                    newdata = datagrid(party_autonomy = FALSE,
                                       region = "Middle East and North Africa",
                                       civil_liberties = seq(0, 1, by = 0.05))) |>
        posteriordraws()

ggplot(pred, aes(x = civil_liberties, y = draw)) +
    stat_lineribbon() +
    scale_fill_brewer(palette = "Reds") +
    labs(x = "Civil liberties",
         y = "Media index (predicted)",
         fill = "")

The slope of this line for different values of civil liberties can be obtained with:

mfx <- marginaleffects(mod,
                       newdata = datagrid(civil_liberties = c(.2, .5, .8),
                                          party_autonomy = FALSE,
                                          region = "Middle East and North Africa"),
                       variables = "civil_liberties")
mfx
#>   rowid     type            term      dydx  conf.low conf.high predicted
#> 1     1 response civil_liberties 0.4900170 0.3609330 0.6388463 0.1700110
#> 2     2 response civil_liberties 0.8071389 0.6121827 0.9926679 0.3684876
#> 3     3 response civil_liberties 0.8069026 0.6744996 0.9336997 0.6244447
#>   predicted_hi predicted_lo civil_liberties party_autonomy
#> 1    0.1700591    0.1700110             0.2          FALSE
#> 2    0.3685630    0.3684876             0.5          FALSE
#> 3    0.6245229    0.6244447             0.8          FALSE
#>                         region media_index      eps
#> 1 Middle East and North Africa       0.837 9.56e-05
#> 2 Middle East and North Africa       0.837 9.56e-05
#> 3 Middle East and North Africa       0.837 9.56e-05

And plotted:

mfx <- posteriordraws(mfx)

ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Marginal effect of Civil Liberties on Media Index",
         y = "Posterior density",
         fill = "Civil liberties")

The marginaleffects function uses the ellipsis (...) to pass additional arguments to the brms prediction functions (posterior_epred when type = "response"). This can alter the type of predictions returned. For example, setting re_formula = NA computes marginal effects without including any group-level effects:

mfx <- marginaleffects(mod,
                       newdata = datagrid(civil_liberties = c(.2, .5, .8),
                                          party_autonomy = FALSE,
                                          region = "Middle East and North Africa"),
                       variables = "civil_liberties",
                       re_formula = NA) |>
       posteriordraws()

ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Marginal effect of Civil Liberties on Media Index",
         y = "Posterior density",
         fill = "Civil liberties")
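
What dropping group-level effects means can be sketched for a single posterior draw in base R. The coefficient values below are hypothetical, and the sketch assumes a logit link for the mean of the Beta model; it only illustrates the idea and is not how brms computes predictions.

```r
# Hypothetical values for one posterior draw (not taken from the fitted model).
b0       <- -1.0   # population-level intercept
b_civil  <-  3.0   # population-level coefficient on civil_liberties
u_region <-  0.4   # group-level intercept for one region
civil_liberties <- c(0.2, 0.5, 0.8)

eta_with_re    <- b0 + b_civil * civil_liberties + u_region  # group-level effect included
eta_without_re <- b0 + b_civil * civil_liberties             # analogous to re_formula = NA

# Map the linear predictors to the (0, 1) scale of the outcome.
res <- data.frame(civil_liberties,
                  with_re    = plogis(eta_with_re),
                  without_re = plogis(eta_without_re))
res
```

With re_formula = NA, only the population-level part of the linear predictor contributes, so every region shares the same curve.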

Global grand mean

pred <- predictions(mod,
                    re_formula = NA,
                    newdata = datagrid(party_autonomy = c(TRUE, FALSE))) |>
        posteriordraws()

mfx <- marginaleffects(mod,
                       re_formula = NA,
                       variables = "party_autonomy") |>
       posteriordraws()

plot1 <- ggplot(pred, aes(x = draw, fill = party_autonomy)) +
         stat_halfeye(slab_alpha = .5) +
         labs(x = "Media index (Predicted)",
              y = "Posterior density",
              fill = "Party autonomy")

plot2 <- ggplot(mfx, aes(x = draw)) +
         stat_halfeye(slab_alpha = .5)  +
         labs(x = "Contrast: Party autonomy TRUE - FALSE",
              y = "",
              fill = "Party autonomy")

# combine plots using the `patchwork` package
plot1 + plot2

Region-specific predictions and contrasts

Predicted media index by region and level of civil liberties:

pred <- predictions(mod,
                    newdata = datagrid(region = vdem_2015$region,
                                       party_autonomy = FALSE, 
                                       civil_liberties = seq(0, 1, length.out = 100))) |> 
        posteriordraws()

ggplot(pred, aes(x = civil_liberties, y = draw)) +
    stat_lineribbon() +
    scale_fill_brewer(palette = "Reds") +
    facet_wrap(~ region) +
    labs(x = "Civil liberties",
         y = "Media index (predicted)",
         fill = "")

Posterior distribution of the predicted media index by region, at two levels of civil liberties:

pred <- predictions(mod,
                    newdata = datagrid(region = vdem_2015$region,
                                       civil_liberties = c(.2, .8),
                                       party_autonomy = FALSE)) |>
        posteriordraws()

ggplot(pred, aes(x = draw, fill = factor(civil_liberties))) +
    stat_halfeye(slab_alpha = .5) +
    facet_wrap(~ region) +
    labs(x = "Media index (predicted)",
         y = "Posterior density",
         fill = "Civil liberties")

Predicted media index by region and party autonomy:

pred <- predictions(mod,
                    newdata = datagrid(region = vdem_2015$region,
                                       party_autonomy = c(TRUE, FALSE),
                                       civil_liberties = .5)) |>
        posteriordraws()

ggplot(pred, aes(x = draw, y = region , fill = party_autonomy)) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Media index (predicted)",
         y = "",
         fill = "Party autonomy")

TRUE/FALSE contrasts (marginal effects) of party autonomy by region:

mfx <- marginaleffects(mod,
                       variables = "party_autonomy",
                       newdata = datagrid(region = vdem_2015$region,
                                          civil_liberties = .5)) |>
        posteriordraws()

ggplot(mfx, aes(x = draw, y = region , fill = party_autonomy)) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Contrast: Party autonomy TRUE - FALSE",
         y = "",
         fill = "Party autonomy")

Hypothetical groups

We can also obtain predictions or marginal effects for a hypothetical group instead of one of the observed regions. To achieve this, we create a dataset whose region column holds a level that does not appear in the original data ("New Region"). Then we call the marginaleffects or predictions functions with the allow_new_levels argument. This argument is pushed through via the ellipsis (...) to the posterior_epred function of the brms package:

dat <- data.frame(civil_liberties = .5,
                  party_autonomy = FALSE,
                  region = "New Region")

mfx <- marginaleffects(
    mod,
    variables = "party_autonomy",
    allow_new_levels = TRUE,
    newdata = dat)

draws <- posteriordraws(mfx)

ggplot(draws, aes(x = draw)) +
     stat_halfeye() +
     labs(x = "Marginal effect of party autonomy in a generic world region", y = "")

Multinomial logit

Fit a model with categorical outcome (heating system choice in California houses) and logit link:

dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/Heating.csv"
dat <- read.csv(dat)
mod <- brm(depvar ~ ic.gc + oc.gc,
           data = dat,
           family = categorical(link = "logit"))

Adjusted predictions

Compute predicted probabilities for each level of the outcome variable:

pred <- predictions(mod)

head(pred)
#>   rowid     type group  predicted   conf.low  conf.high depvar  ic.gc  oc.gc
#> 1     1 response    ec 0.06626689 0.04471591 0.09304536     gc 866.00 199.69
#> 2     2 response    ec 0.07681658 0.05896082 0.09740047     gc 727.93 168.66
#> 3     3 response    ec 0.10300017 0.06181137 0.15849874     gc 599.48 165.58
#> 4     4 response    ec 0.06335247 0.04590258 0.08378809     er 835.17 180.88
#> 5     5 response    ec 0.07452660 0.05739728 0.09467869     er 755.59 174.91
#> 6     6 response    ec 0.07086098 0.04545918 0.10359585     gc 666.11 135.67
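
For intuition, predicted probabilities in a categorical logit model can be sketched as a softmax over one linear predictor per outcome level, with the reference level fixed at 0. The linear predictor values below are hypothetical, and treating ec as the reference level is an assumption made for the illustration.

```r
# Hypothetical linear predictors for one observation and one posterior draw.
# The reference category (assumed here to be "ec") is fixed at 0.
eta <- c(ec = 0, er = 0.3, gc = 2.2, gr = 0.7, hp = -0.1)

prob <- exp(eta) / sum(exp(eta))   # softmax
round(prob, 3)
sum(prob)
```

Because the probabilities sum to one within each observation and draw, the output of predictions contains one row per observation and outcome level, identified by the group column.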

Extract posterior draws and plot them:

draws <- posteriordraws(pred)

ggplot(draws, aes(x = draw, fill = group)) +
    geom_density(alpha = .2, color = "white") +
    labs(x = "Predicted probability",
         y = "Density",
         fill = "Heating system")

Use the plot_cap function to plot conditional adjusted predictions for each level of the outcome variable depvar, conditional on the value of the oc.gc regressor:

plot_cap(mod, condition = "oc.gc") +
    facet_wrap(~ group) +
    labs(y = "Predicted probability")

Marginal effects

mfx <- marginaleffects(mod)
summary(mfx)
#>    Group  Term     Effect      2.5 %    97.5 %
#> 1     ec ic.gc -1.773e-04 -3.961e-04 2.368e-05
#> 2     ec oc.gc  4.877e-04 -4.035e-04 1.447e-03
#> 3     er ic.gc  1.653e-05 -2.256e-04 2.506e-04
#> 4     er oc.gc -1.020e-03 -2.069e-03 2.983e-05
#> 5     gc ic.gc  1.380e-05 -3.724e-04 3.998e-04
#> 6     gc oc.gc  1.044e-03 -7.393e-04 2.784e-03
#> 7     gr ic.gc  4.243e-05 -2.365e-04 3.295e-04
#> 8     gr oc.gc  9.458e-05 -1.186e-03 1.342e-03
#> 9     hp ic.gc  1.073e-04 -7.726e-05 2.967e-04
#> 10    hp oc.gc -5.853e-04 -1.453e-03 2.296e-04
#> 
#> Model type:  brmsfit 
#> Prediction type:  response

Hurdle models

This section replicates some analyses from yet another amazing blog post by Andrew Heiss.

To begin, we estimate a hurdle model in brms with random effects, using data from the gapminder package:

library(gapminder)
library(brms)
library(dplyr)
library(ggplot2)
library(ggdist)
library(cmdstanr)
library(patchwork)
library(marginaleffects)

set.seed(1024)

CHAINS <- 4
ITER <- 2000
WARMUP <- 1000
BAYES_SEED <- 1234

gapminder <- gapminder::gapminder |> 
  filter(continent != "Oceania") |> 
  # Make a bunch of GDP values 0
  mutate(prob_zero = ifelse(lifeExp < 50, 0.3, 0.02),
         will_be_zero = rbinom(n(), 1, prob = prob_zero),
         gdpPercap = ifelse(will_be_zero, 0, gdpPercap)) |> 
  select(-prob_zero, -will_be_zero) |> 
  # Make a logged version of GDP per capita
  mutate(log_gdpPercap = log1p(gdpPercap)) |> 
  mutate(is_zero = gdpPercap == 0)

mod <- brm(
  bf(gdpPercap ~ lifeExp + year + (1 + lifeExp + year | continent),
     hu ~ lifeExp),
  data = gapminder,
  backend = "cmdstanr",
  family = hurdle_lognormal(),
  cores = 2,
  chains = CHAINS, iter = ITER, warmup = WARMUP, seed = BAYES_SEED,
  silent = 2)

Adjusted predictions

Adjusted predictions for every observation in the original data:

predictions(mod) %>% head()
#>   rowid     type predicted conf.low conf.high gdpPercap lifeExp year continent
#> 1     1 response  142.5456 103.1327  218.8240  779.4453  28.801 1952      Asia
#> 2     2 response  168.2657 124.8506  255.8359  820.8530  30.332 1957      Asia
#> 3     3 response  201.7414 152.9039  303.6894  853.1007  31.997 1962      Asia
#> 4     4 response  251.4996 196.5698  372.8495  836.1971  34.020 1967      Asia
#> 5     5 response  312.2482 249.5193  454.2419    0.0000  36.088 1972      Asia
#> 6     6 response  397.5390 324.5934  566.6799  786.1134  38.438 1977      Asia

Adjusted predictions for the hu parameter:

predictions(mod, dpar = "hu") %>% head()
#>   rowid     type predicted  conf.low conf.high gdpPercap lifeExp year continent
#> 1     1 response 0.5739495 0.4746831 0.6515670  779.4453  28.801 1952      Asia
#> 2     2 response 0.5365013 0.4416000 0.6112908  820.8530  30.332 1957      Asia
#> 3     3 response 0.4955596 0.4069162 0.5664250  853.1007  31.997 1962      Asia
#> 4     4 response 0.4455042 0.3656259 0.5106086  836.1971  34.020 1967      Asia
#> 5     5 response 0.3957129 0.3251678 0.4537019    0.0000  36.088 1972      Asia
#> 6     6 response 0.3413458 0.2823879 0.3907218  786.1134  38.438 1977      Asia

Predictions on a different scale:

predictions(mod, type = "link", dpar = "hu") %>% head()
#>   rowid type   predicted   conf.low   conf.high gdpPercap lifeExp year
#> 1     1 link  0.29798370 -0.1013542  0.62593415  779.4453  28.801 1952
#> 2     2 link  0.14626541 -0.2346709  0.45274118  820.8530  30.332 1957
#> 3     3 link -0.01776198 -0.3767284  0.26727993  853.1007  31.997 1962
#> 4     4 link -0.21885246 -0.5510282  0.04244063  836.1971  34.020 1967
#> 5     5 link -0.42336036 -0.7301227 -0.18572429    0.0000  36.088 1972
#> 6     6 link -0.65730248 -0.9326474 -0.44427921  786.1134  38.438 1977
#>   continent
#> 1      Asia
#> 2      Asia
#> 3      Asia
#> 4      Asia
#> 5      Asia
#> 6      Asia
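
The two scales can be connected with a short base R check. Assuming that the hu parameter uses a logit link, applying the inverse logit (plogis) to the link-scale estimates printed above should recover the response-scale estimates, up to rounding. This works for posterior medians because the median commutes with monotone transformations.

```r
# Values copied from the two tables printed above.
link_hat     <- c(0.29798370, 0.14626541, -0.01776198,
                  -0.21885246, -0.42336036, -0.65730248)
response_hat <- c(0.5739495, 0.5365013, 0.4955596,
                  0.4455042, 0.3957129, 0.3413458)

# Largest discrepancy between the back-transformed link and the response scale.
max(abs(plogis(link_hat) - response_hat))
```

The same back-transformation applies to the interval bounds, since quantiles are also preserved by monotone transformations.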

Plot adjusted predictions as a function of lifeExp:

plot_cap(
    mod,
    condition = "lifeExp") +
    labs(y = "mu") +
plot_cap(
    mod,
    dpar = "hu",
    condition = "lifeExp") +
    labs(y = "hu")

Predictions with more than one condition and the re_formula argument from brms:

plot_cap(
    mod,
    re_formula = NULL,
    condition = c("lifeExp", "continent"))

Extract draws with posteriordraws()

The posteriordraws() function extracts raw samples from the posterior from objects produced by marginaleffects. This allows us to use richer geoms and summaries, such as those in the ggdist package:

predictions(
    mod,
    re_formula = NULL,
    newdata = datagrid(model = mod,
                       continent = gapminder$continent,
                       year = c(1952, 2007),
                       lifeExp = seq(30, 80, 1))) %>%
    posteriordraws() %>%
    ggplot(aes(lifeExp, draw, fill = continent, color = continent)) +
    stat_lineribbon(alpha = .25) +
    facet_grid(year ~ continent)

Average Contrasts

What happens to gdpPercap when lifeExp increases by one?

comparisons(mod) %>% summary()
#>      Term Contrast Effect  2.5 % 97.5 %
#> 1 lifeExp       +1 718.87 515.58 811.96
#> 2    year       +1 -63.82 -84.44 -41.05
#> 
#> Model type:  brmsfit 
#> Prediction type:  response

What happens to gdpPercap when lifeExp increases by one standard deviation?

comparisons(mod, variables = list(lifeExp = "sd")) %>% summary()
#>      Term                Contrast Effect 2.5 % 97.5 %
#> 1 lifeExp (x + sd/2) - (x - sd/2)   4050  3718   4741
#> 
#> Model type:  brmsfit 
#> Prediction type:  response
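
The contrast labels in the two outputs above can be read as simple differences in predictions. The base R sketch below uses a hypothetical prediction function f and hypothetical data x; it takes the labels literally and assumes that x denotes each observed value, which may not match the package's exact definition.

```r
# Hypothetical prediction function and data (not the fitted hurdle model).
f <- function(x) exp(0.05 * x)
x <- c(40, 55, 70)

# "+1": average change in the prediction when x increases by one unit.
contrast_plus1 <- mean(f(x + 1) - f(x))

# "(x + sd/2) - (x - sd/2)": average change across one standard deviation.
s <- sd(x)
contrast_sd <- mean(f(x + s / 2) - f(x - s / 2))

# A contrast between two user-specified values, for example 50 and 60.
contrast_50_60 <- f(60) - f(50)

c(plus1 = contrast_plus1, sd = contrast_sd, from_50_to_60 = contrast_50_60)
```

In a Bayesian model, the same differences would be computed for every posterior draw before being summarized.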

What happens to gdpPercap when lifeExp increases from 50 to 60 and year simultaneously increases from its min to its max?

comparisons(
    mod,
    variables = list(lifeExp = c(50, 60), year = "minmax"),
    interaction = TRUE) %>%
    summary()
#>   lifeExp      year Effect 2.5 % 97.5 %
#> 1 60 - 50 Max - Min  834.7 523.1   1404
#> 
#> Model type:  brmsfit 
#> Prediction type:  response

Plot draws from the posterior distribution of average contrasts (not the same thing as draws from the posterior distribution of contrasts):

comparisons(mod) %>%
    summary() %>%
    posteriordraws() %>%
    ggplot(aes(estimate, term)) +
    stat_dotsinterval() +
    labs(x = "Posterior distribution of average contrasts", y = "")

Marginal effects (slopes)

Average Marginal Effect of lifeExp on different scales and for different parameters:

marginaleffects(mod) |> summary()
#>      Term Effect  2.5 % 97.5 %
#> 1 lifeExp 718.71 515.56 811.70
#> 2    year -63.82 -84.44 -41.05
#> 
#> Model type:  brmsfit 
#> Prediction type:  response

marginaleffects(mod, type = "link") |> summary()
#>      Term   Effect    2.5 %    97.5 %
#> 1 lifeExp  0.08249  0.07419  0.088564
#> 2    year -0.00937 -0.01201 -0.006316
#> 
#> Model type:  brmsfit 
#> Prediction type:  link

marginaleffects(mod, dpar = "hu") |> summary()
#>      Term    Effect     2.5 %    97.5 %
#> 1 lifeExp -0.008171 -0.009367 -0.006687
#> 
#> Model type:  brmsfit 
#> Prediction type:  response

marginaleffects(mod, dpar = "hu", type = "link") |> summary()
#>      Term   Effect   2.5 %   97.5 %
#> 1 lifeExp -0.09934 -0.1132 -0.08382
#> 
#> Model type:  brmsfit 
#> Prediction type:  link

Plot Conditional Marginal Effects

plot_cme(
    mod,
    effect = "lifeExp",
    condition = "lifeExp") +
    labs(y = "mu") +

plot_cme(
    mod,
    dpar = "hu",
    effect = "lifeExp",
    condition = "lifeExp") +
    labs(y = "hu")

Or we can combine marginaleffects() or comparisons() with the posteriordraws() function for even more control:

comparisons(
    mod,
    type = "link",
    variables = "lifeExp",
    newdata = datagrid(lifeExp = c(40, 70), continent = gapminder$continent)) %>%
    posteriordraws() %>%
    ggplot(aes(draw, continent, fill = continent)) +
    stat_dotsinterval() +
    facet_grid(lifeExp ~ .) +
    labs(x = "Effect of a 1 unit change in Life Expectancy")

Bayesian estimates and credible intervals

For Bayesian models like those produced by the brms or rstanarm packages, the marginaleffects package functions report the median of the posterior distribution as their main estimates.

The default credible intervals are equal-tailed intervals (ETI), that is, quantiles of the posterior draws. Users can customize the type of interval reported by setting a global option. Currently, the only alternative to ETI is the Highest Density Interval (HDI). Note that, in these two commands, the credible intervals change slightly, but not the median estimates:

library(insight)
library(marginaleffects)

mod <- insight::download_model("brms_1")

options(marginaleffects_credible_interval = "hdi")
comparisons(mod) |> summary()
#>   Term Contrast Effect  2.5 %  97.5 %
#> 1  cyl       +1 -1.494 -2.385 -0.6771
#> 2   wt       +1 -3.195 -4.704 -1.5704
#> 
#> Model type:  brmsfit 
#> Prediction type:  response

options(marginaleffects_credible_interval = "eti")
comparisons(mod) |> summary()
#>   Term Contrast Effect  2.5 %  97.5 %
#> 1  cyl       +1 -1.494 -2.361 -0.6361
#> 2   wt       +1 -3.195 -4.792 -1.6450
#> 
#> Model type:  brmsfit 
#> Prediction type:  response
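
The difference between the two interval types can be sketched in base R. The draws below are simulated from a right-skewed distribution and are hypothetical, and the hdi helper is a simple illustration (the shortest window containing 95% of the sorted draws), not necessarily the algorithm used by marginaleffects.

```r
set.seed(1)
draws <- rgamma(4000, shape = 2, rate = 1)   # hypothetical right-skewed posterior draws

# Equal-tailed interval: cut 2.5% of the draws from each tail.
eti <- quantile(draws, probs = c(0.025, 0.975), names = FALSE)

# Highest density interval: the shortest window containing `prob` of the sorted draws.
hdi <- function(x, prob = 0.95) {
  x <- sort(x)
  n <- length(x)
  k <- ceiling(prob * n)                     # number of draws inside the interval
  starts <- seq_len(n - k + 1)               # candidate left endpoints
  widths <- x[starts + k - 1] - x[starts]
  i <- which.min(widths)
  c(x[i], x[i + k - 1])
}

hdi_95 <- hdi(draws)
eti
hdi_95
median(draws)
```

For a symmetric posterior the two intervals nearly coincide; for a skewed posterior the HDI shifts toward the region of highest density, while the median is the same in both cases.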