
Outcome predicted by a fitted model on a specified scale for a given combination of values of the predictor variables, such as their observed values, their means, or factor levels (a.k.a. "reference grid"). The tidy() and summary() functions can be used to aggregate the output of predictions(). To learn more, read the predictions vignette or visit the package website: https://vincentarelbundock.github.io/marginaleffects/

Usage

predictions(
  model,
  newdata = NULL,
  variables = NULL,
  vcov = TRUE,
  conf_level = 0.95,
  type = NULL,
  by = NULL,
  byfun = NULL,
  wts = NULL,
  transform_post = NULL,
  hypothesis = NULL,
  ...
)

Arguments

model

Model object

newdata

NULL, data frame, string, or datagrid() call. Determines the grid of predictors on which we make predictions. A brief sketch follows the list below.

  • NULL (default): Predictions for each observed value in the original dataset.

  • data frame: Predictions for each row of the newdata data frame.

  • string:

    • "mean": Predictions at the Mean. Predictions when each predictor is held at its mean or mode.

    • "median": Predictions at the Median. Predictions when each predictor is held at its median or mode.

    • "marginalmeans": Predictions at Marginal Means. See Details section below.

    • "tukey": Predictions at Tukey's 5 numbers.

    • "grid": Predictions on a grid of representative numbers (Tukey's 5 numbers and unique values of categorical predictors).

  • datagrid() call to specify a custom grid of regressors. For example:

    • newdata = datagrid(cyl = c(4, 6)): cyl variable equal to 4 and 6 and other regressors fixed at their means or modes.

    • See the Examples section and the datagrid() documentation.
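
For instance, a brief sketch reusing the mtcars model from the Examples section (the specific values are illustrative):

# model fit on the mtcars data shipped with R
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)

# default: one prediction per observed row
predictions(mod)

# every predictor held at its mean or mode
predictions(mod, newdata = "mean")

# custom grid: hp set to 100 and 120, cyl fixed at 4
predictions(mod, newdata = datagrid(hp = c(100, 120), cyl = 4))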

variables

Named list of variables with values to create a counterfactual grid of predictions. The entire dataset is replicated for each unique combination of the variables in this list. See the Examples section below. Warning: This can use a lot of memory if there are many variables and values, or when the dataset is large.
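
As a sketch (the values are illustrative), supplying two variables replicates the 32-row mtcars data once per combination of their values:

mod <- lm(mpg ~ hp + am, data = mtcars)

# 2 hp values x 2 am values = 4 combinations, so 32 x 4 = 128 counterfactual rows
p <- predictions(mod, variables = list(hp = c(100, 120), am = 0:1))
nrow(p)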

vcov

Type of uncertainty estimates to report (e.g., for robust standard errors). Acceptable values (a brief sketch follows this list):

  • FALSE: Do not compute standard errors. This can speed up computation considerably.

  • TRUE: Unit-level standard errors using the default vcov(model) variance-covariance matrix.

  • String which indicates the kind of uncertainty estimates to return.

    • Heteroskedasticity-consistent: "HC", "HC0", "HC1", "HC2", "HC3", "HC4", "HC4m", "HC5". See ?sandwich::vcovHC

    • Heteroskedasticity and autocorrelation consistent: "HAC"

    • Mixed-Models degrees of freedom: "satterthwaite", "kenward-roger"

    • Other: "NeweyWest", "KernHAC", "OPG". See the sandwich package documentation.

  • One-sided formula which indicates the name of cluster variables (e.g., ~unit_id). This formula is passed to the cluster argument of the sandwich::vcovCL function.

  • Square covariance matrix

  • Function which returns a covariance matrix (e.g., stats::vcov(model))
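
A brief sketch of the main forms; the robust and clustered options assume the sandwich package is installed:

mod <- lm(mpg ~ hp, data = mtcars)

# skip standard errors entirely (faster)
predictions(mod, vcov = FALSE)

# heteroskedasticity-consistent standard errors
predictions(mod, vcov = "HC3")

# standard errors clustered by cyl (passed to sandwich::vcovCL)
predictions(mod, vcov = ~cyl)

# supply the model's covariance matrix directly
predictions(mod, vcov = stats::vcov(mod))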

conf_level

numeric value between 0 and 1. Confidence level to use to build a confidence interval.

type

string indicating the type (scale) of the predictions used to compute marginal effects or contrasts. This can differ based on the model type, but will typically be a string such as "response", "link", "probs", or "zero". When an unsupported string is entered, the model-specific list of acceptable values is returned in an error message. When type is NULL, the default value is used; this default is the first model-related row of the marginaleffects:::type_dictionary data frame.
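
For example, a brief sketch with a logistic regression (the available type values depend on the model class, as described above):

mod <- glm(am ~ hp, data = mtcars, family = binomial)

# default type for this model class (predicted probabilities)
predictions(mod)

# predictions on the linear predictor (link) scale
predictions(mod, type = "link")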

by

Character vector of variable names over which to compute group-wise estimates. The Examples section below also shows how to supply a data frame with a by column to collapse response levels into custom groups.

byfun

A function such as mean() or sum() used to aggregate estimates within the subgroups defined by the by argument. NULL uses the mean() function. Must accept a numeric vector and return a single numeric value. This is sometimes used to take the sum or mean of predicted probabilities across outcome or predictor levels. See the Examples section below.

wts

string or numeric: weights to use when computing average contrasts or marginal effects. These weights only affect the averaging in tidy() or summary(), and not the unit-level estimates themselves. A brief sketch follows the list below.

  • string: column name of the weights variable in newdata. When supplying a column name to wts, it is recommended to supply the original data (including the weights variable) explicitly to newdata.

  • numeric: vector of length equal to the number of rows in the original data or in newdata (if supplied).
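
A brief sketch of both forms, using the wt column of mtcars as an illustrative weighting variable:

mod <- lm(mpg ~ hp + am, data = mtcars)

# column name: pass the original data (including the weights column) to newdata
p <- predictions(mod, newdata = mtcars, wts = "wt")
summary(p)  # weighted average prediction

# numeric vector with one weight per row of the data
p <- predictions(mod, wts = mtcars$wt)
summary(p)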

transform_post

(experimental) A function applied to unit-level adjusted predictions and confidence intervals just before the function returns results. For Bayesian models, this function is applied to individual draws from the posterior distribution before summaries are computed.
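
For instance, a sketch that computes predictions and intervals on the link scale of a Poisson model and then exponentiates them (illustrative):

mod <- glm(carb ~ hp, data = mtcars, family = poisson)

# link-scale predictions and confidence intervals, exponentiated just before being returned
predictions(mod, type = "link", transform_post = exp)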

hypothesis

specify a hypothesis test or custom contrast using a vector, matrix, string, or string formula.

  • String:

    • "pairwise": pairwise differences between estimates in each row.

    • "reference": differences between the estimates in each row and the estimate in the first row.

    • "sequential": difference between an estimate and the estimate in the next row.

    • "revpairwise", "revreference", "revsequential": inverse of the corresponding hypotheses, as described above.

  • String formula to specify linear or non-linear hypothesis tests. If the term column uniquely identifies rows, terms can be used in the formula. Otherwise, use b1, b2, etc. to identify the position of each parameter. Examples:

    • hp = drat

    • hp + drat = 12

    • b1 + b2 + b3 = 0

  • Numeric vector: Weights to compute a linear combination of (custom contrast between) estimates. Length equal to the number of rows generated by the same function call, but without the hypothesis argument.

  • Numeric matrix: Each column is a vector of weights, as described above, used to compute a distinct linear combination of (contrast between) estimates. The column names of the matrix are used as labels in the output.

  • See the Examples section below and the vignette: https://vincentarelbundock.github.io/marginaleffects/articles/hypothesis.html

...

Additional arguments are passed to the predict() method supplied by the modeling package. These arguments are particularly useful for mixed-effects or Bayesian models (see the online vignettes on the marginaleffects website). Available arguments can vary from model to model, depending on the range of arguments supported by each modeling package. See the "Model-Specific Arguments" section of the ?marginaleffects documentation for a non-exhaustive list of available arguments.

Value

A data.frame with one row per observation and several columns:

  • rowid: row number of the newdata data frame

  • type: prediction type, as defined by the type argument

  • group: (optional) value of the grouped outcome (e.g., categorical outcome models)

  • predicted: predicted outcome

  • std.error: standard errors computed by the insight::get_predicted function or, if unavailable, via the delta method functionality in marginaleffects.

  • conf.low: lower bound of the confidence interval (or equal-tailed interval for Bayesian models)

  • conf.high: upper bound of the confidence interval (or equal-tailed interval for Bayesian models)

Details

The newdata argument, the tidy() function, and the datagrid() function can be used to control the kind of predictions to report:

  • Average Predictions

  • Predictions at the Mean

  • Predictions at User-Specified Values (a.k.a. Predictions at Representative Values)

When possible, predictions() delegates the computation of confidence intervals to the insight::get_predicted() function, which uses back transformation to produce adequate confidence intervals on the scale specified by the type argument. When this is not possible, predictions() uses the delta method to compute standard errors around adjusted predictions and builds symmetric confidence intervals. These naive symmetric intervals may not always be appropriate; for instance, they may stretch beyond the bounds of a binary response variable.
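
As a sketch, assuming the binomial model below goes through the insight::get_predicted() route described above, the back-transformed intervals should respect the bounds of the response:

mod <- glm(vs ~ hp, data = mtcars, family = binomial)
p <- predictions(mod)

# with back-transformed intervals, these bounds are expected to stay within [0, 1]
range(p$conf.low)
range(p$conf.high)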

Model-Specific Arguments

Some model types allow model-specific arguments to modify the nature of marginal effects, predictions, marginal means, and contrasts.

Package    Class     Argument          Documentation
brms       brmsfit   ndraws            brms::posterior_predict
                     re_formula
lme4       merMod    include_random    insight::get_predicted
                     re.form           lme4::predict.merMod
                     allow.new.levels  lme4::predict.merMod
glmmTMB    glmmTMB   re.form           glmmTMB::predict.glmmTMB
                     allow.new.levels  glmmTMB::predict.glmmTMB
                     zitype            glmmTMB::predict.glmmTMB
mgcv       bam       exclude           mgcv::predict.bam
robustlmm  rlmerMod  re.form           robustlmm::predict.rlmerMod
                     allow.new.levels  robustlmm::predict.rlmerMod
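
For instance, a brief sketch with lme4 (assuming the lme4 package is installed and that re.form is forwarded as in the table above; sleepstudy ships with lme4):

library(lme4)
mod <- lmer(Reaction ~ Days + (Days | Subject), data = sleepstudy)

# pass lme4's re.form argument through predictions() to ignore the random effects
predictions(mod, newdata = datagrid(Days = c(2, 5)), re.form = NA)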

Examples

# Adjusted Prediction for every row of the original dataset
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)
pred <- predictions(mod)
head(pred)
#>   rowid     type predicted std.error statistic       p.value conf.low conf.high
#> 1     1 response  20.03819 1.2041405  16.64107  3.512623e-62 17.57162  22.50476
#> 2     2 response  20.03819 1.2041405  16.64107  3.512623e-62 17.57162  22.50476
#> 3     3 response  26.41451 0.9619738  27.45866 5.476301e-166 24.44399  28.38502
#> 4     4 response  20.03819 1.2041405  16.64107  3.512623e-62 17.57162  22.50476
#> 5     5 response  15.92247 0.9924560  16.04350  6.347069e-58 13.88952  17.95543
#> 6     6 response  20.15839 1.2186288  16.54186  1.832792e-61 17.66214  22.65463
#>    mpg  hp cyl
#> 1 21.0 110   6
#> 2 21.0 110   6
#> 3 22.8  93   4
#> 4 21.4 110   6
#> 5 18.7 175   8
#> 6 18.1 105   6

# Adjusted Predictions at User-Specified Values of the Regressors
predictions(mod, newdata = datagrid(hp = c(100, 120), cyl = 4))
#>   rowid     type predicted std.error statistic       p.value conf.low conf.high
#> 1     1 response  26.24623 0.9856325  26.62883 3.148430e-156 24.22726  28.26521
#> 2     2 response  25.76546 1.1096486  23.21947 2.895018e-119 23.49245  28.03847
#>    hp cyl mpg
#> 1 100   4  21
#> 2 120   4  21

# Average Adjusted Predictions (AAP)
library(dplyr)
mod <- lm(mpg ~ hp * am * vs, mtcars)

pred <- predictions(mod)
summary(pred)
#>   Predicted Std. Error z value   Pr(>|z|) CI low CI high
#> 1     20.09     0.4844   41.47 < 2.22e-16  19.14   21.04
#> 
#> Model type:  lm 
#> Prediction type:  response 

predictions(mod, by = "am")
#>       type am predicted std.error statistic       p.value conf.low conf.high
#> 1 response  1  24.39231 0.7600565  32.09275 5.564392e-226 22.90262  25.88199
#> 2 response  0  17.14737 0.6286961  27.27449 8.515145e-164 15.91515  18.37959

# Conditional Adjusted Predictions
plot_cap(mod, condition = "hp")


# Counterfactual predictions with the `variables` argument
# the `mtcars` dataset has 32 rows

mod <- lm(mpg ~ hp + am, data = mtcars)
p <- predictions(mod)
head(p)
#>   rowid     type predicted std.error statistic       p.value conf.low conf.high
#> 1     1 response  25.38434 0.8176495  31.04550 1.311940e-211 23.71206  27.05662
#> 2     2 response  25.38434 0.8176495  31.04550 1.311940e-211 23.71206  27.05662
#> 3     3 response  26.38543 0.8495566  31.05789 8.927432e-212 24.64790  28.12297
#> 4     4 response  20.10726 0.7754954  25.92827 3.197502e-148 18.52119  21.69332
#> 5     5 response  16.27955 0.6773841  24.03297 1.258166e-127 14.89414  17.66495
#> 6     6 response  20.40169 0.7962179  25.62325 8.401973e-145 18.77325  22.03014
#>    mpg  hp am
#> 1 21.0 110  1
#> 2 21.0 110  1
#> 3 22.8  93  1
#> 4 21.4 110  0
#> 5 18.7 175  0
#> 6 18.1 105  0
nrow(p)
#> [1] 32

# counterfactual predictions obtained by replicating the entire dataset for
# different values of the predictors
p <- predictions(mod, variables = list(hp = c(90, 110)))
nrow(p)
#> [1] 64


# hypothesis test: is the prediction in the 1st row equal to the prediction in the 2nd row?
mod <- lm(mpg ~ wt + drat, data = mtcars)

predictions(
    mod,
    newdata = datagrid(wt = 2:3),
    hypothesis = "b1 = b2")
#>       type  term predicted std.error statistic      p.value conf.low conf.high
#> 1 response b1=b2   4.78289 0.7970353  6.000851 1.962855e-09  3.22073  6.345051

# same hypothesis test, expressed as a difference equal to zero
predictions(
    mod,
    newdata = datagrid(wt = 2:3),
    hypothesis = "b1 - b2 = 0")
#>       type    term predicted std.error statistic      p.value conf.low
#> 1 response b1-b2=0   4.78289 0.7970353  6.000851 1.962855e-09  3.22073
#>   conf.high
#> 1  6.345051

# same hypothesis test using numeric vector of weights
predictions(
    mod,
    newdata = datagrid(wt = 2:3),
    hypothesis = c(1, -1))
#>       type   term predicted std.error statistic      p.value conf.low conf.high
#> 1 response custom   4.78289 0.7970353  6.000851 1.962855e-09  3.22073  6.345051

# two custom contrasts using a matrix of weights
lc <- matrix(c(
    1, -1,
    2, 3),
    ncol = 2)
predictions(
    mod,
    newdata = datagrid(wt = 2:3),
    hypothesis = lc)
#>   rowid     type   term predicted std.error statistic       p.value  conf.low
#> 1     1 response custom   4.78289 0.7970353  6.000851  1.962855e-09   3.22073
#> 2     2 response custom 115.21432 3.6474827 31.587352 5.507900e-219 108.06539
#>    conf.high
#> 1   6.345051
#> 2 122.363255


# `by` argument
mod <- lm(mpg ~ hp * am * vs, data = mtcars)
predictions(mod, by = c("am", "vs")) 
#>       type am vs predicted std.error statistic       p.value conf.low conf.high
#> 1 response  1  0  19.75000 1.1187729  17.65327  9.603187e-70 17.55725  21.94275
#> 2 response  1  1  28.37143 1.0357825  27.39130 3.481693e-165 26.34133  30.40152
#> 3 response  0  1  20.74286 1.0357825  20.02627  3.251286e-89 18.71276  22.77295
#> 4 response  0  0  15.05000 0.7910919  19.02434  1.072330e-80 13.49949  16.60051

library(nnet)
nom <- multinom(factor(gear) ~ mpg + am * vs, data = mtcars, trace = FALSE)

# first 5 raw predictions
predictions(nom, type = "probs") |> head()
#>   rowid  type group    predicted    std.error    statistic    p.value
#> 1     1 probs     3 3.623918e-05 2.002490e-03   0.01809706 0.98556142
#> 2     2 probs     3 3.623918e-05 2.002490e-03   0.01809706 0.98556142
#> 3     3 probs     3 9.347603e-08 6.911938e-06   0.01352385 0.98920986
#> 4     4 probs     3 4.044657e-01 1.965452e-01   2.05787667 0.03960197
#> 5     5 probs     3 9.999714e-01 1.246217e-03 802.40562752 0.00000000
#> 6     6 probs     3 5.183336e-01 2.898025e-01   1.78857550 0.07368321
#>        conf.low    conf.high gear  mpg am vs
#> 1 -3.888569e-03 3.961047e-03    4 21.0  1  0
#> 2 -3.888569e-03 3.961047e-03    4 21.0  1  0
#> 3 -1.345367e-05 1.364063e-05    4 22.8  1  1
#> 4  1.924426e-02 7.896871e-01    3 21.4  0  1
#> 5  9.975289e-01 1.002414e+00    3 18.7  0  0
#> 6 -4.966881e-02 1.086336e+00    3 18.1  0  1

# average predictions
predictions(nom, type = "probs", by = "group") |> summary()
#>   Group Predicted Std. Error z value   Pr(>|z|)  CI low CI high
#> 1     3    0.4688    0.04043  11.595 < 2.22e-16 0.38952  0.5480
#> 2     4    0.3750    0.06142   6.106 1.0231e-09 0.25462  0.4954
#> 3     5    0.1562    0.04624   3.379  0.0007279 0.06561  0.2469
#> 
#> Model type:  multinom 
#> Prediction type:  probs 

by <- data.frame(
    group = c("3", "4", "5"),
    by = c("3,4", "3,4", "5"))

predictions(nom, type = "probs", by = by)
#>    type predicted  std.error statistic      p.value   conf.low conf.high  by
#> 1 probs 0.4218766 0.02312133 18.246210 2.217708e-74 0.37655960 0.4671935 3,4
#> 2 probs 0.1562469 0.04624265  3.378848 7.279037e-04 0.06561294 0.2468808   5

# sum of predicted probabilities for combined response levels
mod <- multinom(factor(cyl) ~ mpg + am, data = mtcars, trace = FALSE)
by <- data.frame(
    by = c("4,6", "4,6", "8"),
    group = as.character(c(4, 6, 8)))
predictions(mod, newdata = "mean", byfun = sum, by = by)
#> Warning: Some of the variable names are missing from the model data: group
#>    type  predicted std.error statistic      p.value   conf.low conf.high  by
#> 1 probs 0.91584335 0.1218077 7.5187647 5.529622e-14  0.6771047 1.1545820 4,6
#> 2 probs 0.08415665 0.1218077 0.6908977 4.896298e-01 -0.1545820 0.3228953   8