Outcome predicted by a fitted model on a specified scale for a given combination of values of the predictor variables, such as their observed values, their means, or factor levels (a.k.a. "reference grid").
- predictions(): unit-level (conditional) estimates.
- avg_predictions(): average (marginal) estimates.

The newdata argument and the datagrid() function can be used to control where statistics are evaluated in the predictor space: "at observed values", "at the mean", "at representative values", etc.

See the predictions vignette and package website for worked examples and case studies.
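For instance, a minimal sketch of the two entry points (the model and variables are illustrative, using the built-in mtcars data):

library(marginaleffects)
mod <- lm(mpg ~ hp + wt, data = mtcars)
predictions(mod)       # one (conditional) estimate per row of mtcars
avg_predictions(mod)   # a single average (marginal) estimate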
Usage
predictions(
  model,
  newdata = NULL,
  variables = NULL,
  vcov = TRUE,
  conf_level = 0.95,
  type = NULL,
  by = FALSE,
  byfun = NULL,
  wts = NULL,
  transform_post = NULL,
  hypothesis = NULL,
  df = Inf,
  ...
)
avg_predictions(
  model,
  newdata = NULL,
  variables = NULL,
  vcov = TRUE,
  conf_level = 0.95,
  type = NULL,
  by = TRUE,
  byfun = NULL,
  wts = NULL,
  transform_post = NULL,
  hypothesis = NULL,
  df = Inf,
  ...
)
Arguments
- model
Model object
- newdata
  NULL, data frame, string, or datagrid() call. Determines the grid of predictors on which we make predictions.
  - NULL (default): Predictions for each observed value in the original dataset.
  - data frame: Predictions for each row of the newdata data frame.
  - datagrid() call to specify a custom grid of regressors. For example, newdata = datagrid(cyl = c(4, 6)) sets the cyl variable equal to 4 and 6, with other regressors fixed at their means or modes. See the Examples section and the datagrid() documentation.
  - string:
    - "mean": Predictions at the Mean. Predictions when each predictor is held at its mean or mode.
    - "median": Predictions at the Median. Predictions when each predictor is held at its median or mode.
    - "marginalmeans": Predictions at Marginal Means. See the Details section below.
    - "tukey": Predictions at Tukey's 5 numbers.
    - "grid": Predictions on a grid of representative numbers (Tukey's 5 numbers and unique values of categorical predictors).
- variables
  NULL, character vector, or named list. The subset of variables to use for creating a counterfactual grid of predictions. The entire dataset is replicated for each unique combination of the variables in this list. See the Examples section below.
  Warning: This can use a lot of memory if there are many variables and values, and when the dataset is large.
  - NULL: computes one prediction per row of newdata.
  - Named list: names identify the subset of variables of interest and their values. For numeric variables, the variables argument supports functions and string shortcuts:
    - A function which returns a numeric value
    - Numeric vector: Contrast between the 2nd element and the 1st element of the x vector.
    - "iqr": Contrast across the interquartile range of the regressor.
    - "sd": Contrast across one standard deviation around the regressor mean.
    - "2sd": Contrast across two standard deviations around the regressor mean.
    - "minmax": Contrast between the maximum and the minimum values of the regressor.
    - "threenum": mean and 1 standard deviation on both sides.
    - "fivenum": Tukey's five numbers.
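A minimal sketch of the counterfactual grid created by a named list (mtcars has 32 rows, so two hp values produce 64 predictions):

mod <- lm(mpg ~ hp + am, data = mtcars)
p <- predictions(mod, variables = list(hp = c(90, 110)))
nrow(p)  # 64: the full dataset is replicated once per hp value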
- vcov
  Type of uncertainty estimates to report (e.g., for robust standard errors). Acceptable values:
  - FALSE: Do not compute standard errors. This can speed up computation considerably.
  - TRUE: Unit-level standard errors using the default vcov(model) variance-covariance matrix.
  - String which indicates the kind of uncertainty estimates to return:
    - Heteroskedasticity-consistent: "HC", "HC0", "HC1", "HC2", "HC3", "HC4", "HC4m", "HC5". See ?sandwich::vcovHC
    - Heteroskedasticity and autocorrelation consistent: "HAC"
    - Mixed-models degrees of freedom: "satterthwaite", "kenward-roger"
    - Other: "NeweyWest", "KernHAC", "OPG". See the sandwich package documentation.
  - One-sided formula which indicates the name of cluster variables (e.g., ~unit_id). This formula is passed to the cluster argument of the sandwich::vcovCL function.
  - Square covariance matrix
  - Function which returns a covariance matrix (e.g., stats::vcov(model))
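A hedged sketch of a few vcov options (the cluster variable cyl is chosen only for illustration; the string and formula options require the sandwich package):

mod <- lm(mpg ~ hp + wt, data = mtcars)
predictions(mod, vcov = "HC3")      # heteroskedasticity-consistent standard errors
predictions(mod, vcov = ~cyl)       # standard errors clustered by cyl
predictions(mod, vcov = vcov(mod))  # explicit covariance matrix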
- conf_level
numeric value between 0 and 1. Confidence level to use to build a confidence interval.
- type
  string indicating the type (scale) of the predictions used to compute contrasts or slopes. This can differ based on the model type, but will typically be a string such as: "response", "link", "probs", or "zero". When an unsupported string is entered, the model-specific list of acceptable values is returned in an error message. When type is NULL, the default value is used. This default is the first model-related row in the marginaleffects:::type_dictionary dataframe.
- by
  Aggregate unit-level estimates (a.k.a. marginalize, average over). Valid inputs:
  - FALSE: return the original unit-level estimates.
  - TRUE: aggregate estimates for each term.
  - Character vector of column names in newdata or in the data frame produced by calling the function without the by argument.
  - Data frame with a by column of group labels, and merging columns shared by newdata or the data frame produced by calling the same function without the by argument.
  See the examples below.
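A minimal sketch of aggregation with by (grouping variables are from mtcars and purely illustrative):

mod <- lm(mpg ~ hp * am, data = mtcars)
predictions(mod, by = "am")               # one average prediction per am group
avg_predictions(mod, by = c("am", "vs"))  # cross two grouping variables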
- byfun
  A function such as mean() or sum() used to aggregate estimates within the subgroups defined by the by argument. NULL uses the mean() function. Must accept a numeric vector and return a single numeric value. This is sometimes used to take the sum or mean of predicted probabilities across outcome or predictor levels. See the Examples section.
- wts
  string or numeric: weights to use when computing average contrasts or slopes. These weights only affect the averaging in avg_*() or with the by argument, and not the unit-level estimates themselves.
  - string: column name of the weights variable in newdata. When supplying a column name to wts, it is recommended to supply the original data (including the weights variable) explicitly to newdata.
  - numeric: vector of length equal to the number of rows in the original data or in newdata (if supplied).
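A hedged sketch of weighted averaging (the weight column w is invented for illustration):

set.seed(42)
dat <- transform(mtcars, w = runif(nrow(mtcars)))
mod <- lm(mpg ~ hp + am, data = dat)
avg_predictions(mod, wts = "w", newdata = dat)  # weights supplied as a column name
avg_predictions(mod, wts = dat$w)               # weights supplied as a numeric vector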
- transform_post
(experimental) A function applied to unit-level adjusted predictions and confidence intervals just before the function returns results. For bayesian models, this function is applied to individual draws from the posterior distribution, before computing summaries.
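A minimal sketch, assuming a logit model: predictions made on the link scale are exponentiated, so estimates and interval bounds are reported as odds:

mod <- glm(am ~ hp, data = mtcars, family = binomial)
predictions(mod, type = "link", transform_post = exp)  # log-odds -> odds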
- hypothesis
  specify a hypothesis test or custom contrast using a numeric value, vector, or matrix, a string, or a string formula.
  - Numeric:
    - Single value: the null hypothesis used in the computation of Z and p (before applying transform_post).
    - Vector: Weights to compute a linear combination of (custom contrast between) estimates. Length equal to the number of rows generated by the same function call, but without the hypothesis argument.
    - Matrix: Each column is a vector of weights, as described above, used to compute a distinct linear combination of (contrast between) estimates. The column names of the matrix are used as labels in the output.
  - String formula to specify linear or non-linear hypothesis tests. If the term column uniquely identifies rows, terms can be used in the formula. Otherwise, use b1, b2, etc. to identify the position of each parameter. Examples:
    - hp = drat
    - hp + drat = 12
    - b1 + b2 + b3 = 0
  - String:
    - "pairwise": pairwise differences between estimates in each row.
    - "reference": differences between the estimates in each row and the estimate in the first row.
    - "sequential": difference between an estimate and the estimate in the next row.
    - "revpairwise", "revreference", "revsequential": inverse of the corresponding hypotheses, as described above.
  See the Examples section below and the vignette: https://vincentarelbundock.github.io/marginaleffects/articles/hypothesis.html
- df
  Degrees of freedom used to compute p values and confidence intervals. A single numeric value between 1 and Inf. When df is Inf, the normal distribution is used. When df is finite, the t distribution is used. See insight::get_df for a convenient function to extract degrees of freedom. Ex: predictions(model, df = insight::get_df(model))
- ...
  Additional arguments are passed to the predict() method supplied by the modeling package. These arguments are particularly useful for mixed-effects or bayesian models (see the online vignettes on the marginaleffects website). Available arguments can vary from model to model, depending on the range of supported arguments by each modeling package. See the "Model-Specific Arguments" section of the ?marginaleffects documentation for a non-exhaustive list of available arguments.
Value
A data.frame with one row per observation and several columns:
- rowid: row number of the newdata data frame
- type: prediction type, as defined by the type argument
- group: (optional) value of the grouped outcome (e.g., categorical outcome models)
- estimate: predicted outcome
- std.error: standard errors computed by the insight::get_predicted function or, if unavailable, via the marginaleffects delta method functionality.
- conf.low: lower bound of the confidence interval (or equal-tailed interval for bayesian models)
- conf.high: upper bound of the confidence interval (or equal-tailed interval for bayesian models)
Details
The newdata argument, the tidy() function, and the datagrid() function can be used to control the kind of predictions to report:
- Average Predictions
- Predictions at the Mean
- Predictions at User-Specified values (a.k.a. Predictions at Representative values)
When possible, predictions() delegates the computation of confidence intervals to the insight::get_predicted() function, which uses back transformation to produce adequate confidence intervals on the scale specified by the type argument. When this is not possible, predictions() uses the delta method to compute standard errors around adjusted predictions, and builds symmetric confidence intervals. These naive symmetric intervals may not always be appropriate. For instance, they may stretch beyond the bounds of a binary response variable.
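A hedged illustration with a logit model: predictions on the response scale are back-transformed from the link scale, so the interval respects the [0, 1] bounds of the response:

mod <- glm(am ~ hp, data = mtcars, family = binomial)
predictions(mod, type = "response")  # back-transformed interval, stays within [0, 1]
predictions(mod, type = "link")      # symmetric interval on the link scale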
Model-Specific Arguments
Some model types allow model-specific arguments to modify the nature of marginal effects, predictions, marginal means, and contrasts.
Package   | Class    | Argument         | Documentation
brms      | brmsfit  | ndraws           | brms::posterior_predict
          |          | re_formula       |
lme4      | merMod   | include_random   | insight::get_predicted
          |          | re.form          | lme4::predict.merMod
          |          | allow.new.levels | lme4::predict.merMod
glmmTMB   | glmmTMB  | re.form          | glmmTMB::predict.glmmTMB
          |          | allow.new.levels | glmmTMB::predict.glmmTMB
          |          | zitype           | glmmTMB::predict.glmmTMB
mgcv      | bam      | exclude          | mgcv::predict.bam
robustlmm | rlmerMod | re.form          | robustlmm::predict.rlmerMod
          |          | allow.new.levels | robustlmm::predict.rlmerMod
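A hedged sketch of passing one such argument through ... to the underlying predict() method (the lme4 model fitted on the sleepstudy data is purely illustrative):

library(lme4)
mod <- lmer(Reaction ~ Days + (Days | Subject), data = sleepstudy)
predictions(mod, re.form = NA)  # population-level predictions, ignoring random effects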
Bayesian posterior summaries
By default, credible intervals in bayesian models are built as equal-tailed intervals. This can be changed to a highest density interval by setting a global option:
options("marginaleffects_posterior_interval" = "eti")
options("marginaleffects_posterior_interval" = "hdi")
By default, the center of the posterior distribution in bayesian models is identified by the median. Users can use a different summary function by setting a global option:
options("marginaleffects_posterior_center" = "mean")
options("marginaleffects_posterior_center" = "median")
When estimates are averaged using the by argument, the tidy() function, or the summary() function, the posterior distribution is marginalized twice over. First, we take the average across units but within each iteration of the MCMC chain, according to what the user requested in the by argument or the tidy()/summary() functions. Then, we identify the center of the resulting posterior using the function supplied to the "marginaleffects_posterior_center" option (the median by default).
Examples
# Adjusted Predictions for every row of the original dataset
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)
pred <- predictions(mod)
head(pred)
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 20.04 1.2041 16.64 < 2.22e-16 17.68 22.40
#> 20.04 1.2041 16.64 < 2.22e-16 17.68 22.40
#> 26.41 0.9620 27.46 < 2.22e-16 24.53 28.30
#> 20.04 1.2041 16.64 < 2.22e-16 17.68 22.40
#> 15.92 0.9925 16.04 < 2.22e-16 13.98 17.87
#> 20.16 1.2186 16.54 < 2.22e-16 17.77 22.55
#>
#> Prediction type: response
#> Columns: rowid, type, estimate, std.error, statistic, p.value, conf.low, conf.high, mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb
#>
# Adjusted Predictions at User-Specified Values of the Regressors
predictions(mod, newdata = datagrid(hp = c(100, 120), cyl = 4))
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 % hp cyl
#> 26.25 0.9856 26.63 < 2.22e-16 24.31 28.18 100 4
#> 25.77 1.1096 23.22 < 2.22e-16 23.59 27.94 120 4
#>
#> Prediction type: response
#> Columns: rowid, type, estimate, std.error, statistic, p.value, conf.low, conf.high, mpg, hp, cyl
#>
m <- lm(mpg ~ hp + drat + factor(cyl) + factor(am), data = mtcars)
predictions(m, newdata = datagrid(FUN_factor = unique, FUN_numeric = median))
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 21.95 1.288 17.04 < 2.22e-16 19.43 24.48
#> 18.19 1.271 14.31 < 2.22e-16 15.70 20.68
#> 25.55 1.322 19.32 < 2.22e-16 22.96 28.14
#> 21.78 1.541 14.13 < 2.22e-16 18.76 24.81
#> 22.62 2.141 10.56 < 2.22e-16 18.42 26.81
#> 18.85 1.734 10.87 < 2.22e-16 15.45 22.25
#>
#> Prediction type: response
#> Columns: rowid, type, estimate, std.error, statistic, p.value, conf.low, conf.high, mpg, hp, drat, cyl, am
#>
# Average Adjusted Predictions (AAP)
library(dplyr)
mod <- lm(mpg ~ hp * am * vs, mtcars)
avg_predictions(mod)
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 20.09 0.4844 41.47 < 2.22e-16 19.14 21.04
#>
#> Prediction type: response
#> Columns: type, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
predictions(mod, by = "am")
#>
#> am Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 1 24.39 0.7601 32.09 < 2.22e-16 22.90 25.88
#> 0 17.15 0.6287 27.27 < 2.22e-16 15.92 18.38
#>
#> Prediction type: response
#> Columns: type, am, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
# Conditional Adjusted Predictions
plot_predictions(mod, condition = "hp")
# Counterfactual predictions with the `variables` argument
# the `mtcars` dataset has 32 rows
mod <- lm(mpg ~ hp + am, data = mtcars)
p <- predictions(mod)
head(p)
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 25.38 0.8176 31.05 < 2.22e-16 23.78 26.99
#> 25.38 0.8176 31.05 < 2.22e-16 23.78 26.99
#> 26.39 0.8496 31.06 < 2.22e-16 24.72 28.05
#> 20.11 0.7755 25.93 < 2.22e-16 18.59 21.63
#> 16.28 0.6774 24.03 < 2.22e-16 14.95 17.61
#> 20.40 0.7962 25.62 < 2.22e-16 18.84 21.96
#>
#> Prediction type: response
#> Columns: rowid, type, estimate, std.error, statistic, p.value, conf.low, conf.high, mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb
#>
nrow(p)
#> [1] 32
# counterfactual predictions obtained by replicating the entire dataset
# for different values of the predictors
p <- predictions(mod, variables = list(hp = c(90, 110)))
nrow(p)
#> [1] 64
# hypothesis test: is the prediction in the 1st row equal to the prediction in the 2nd row?
mod <- lm(mpg ~ wt + drat, data = mtcars)
predictions(
mod,
newdata = datagrid(wt = 2:3),
hypothesis = "b1 = b2")
#>
#> Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> b1=b2 4.783 0.797 6.001 1.9629e-09 3.221 6.345
#>
#> Prediction type: response
#> Columns: type, term, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
# same hypothesis test using row indices
predictions(
mod,
newdata = datagrid(wt = 2:3),
hypothesis = "b1 - b2 = 0")
#>
#> Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> b1-b2=0 4.783 0.797 6.001 1.9629e-09 3.221 6.345
#>
#> Prediction type: response
#> Columns: type, term, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
# same hypothesis test using numeric vector of weights
predictions(
mod,
newdata = datagrid(wt = 2:3),
hypothesis = c(1, -1))
#>
#> Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> custom 4.783 0.797 6.001 1.9629e-09 3.221 6.345
#>
#> Prediction type: response
#> Columns: type, term, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
# two custom contrasts using a matrix of weights
lc <- matrix(c(
1, -1,
2, 3),
ncol = 2)
predictions(
mod,
newdata = datagrid(wt = 2:3),
hypothesis = lc)
#>
#> Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> custom 4.783 0.797 6.001 1.9629e-09 3.221 6.345
#> custom 115.214 3.647 31.587 < 2.22e-16 108.065 122.363
#>
#> Prediction type: response
#> Columns: rowid, type, term, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
# `by` argument
mod <- lm(mpg ~ hp * am * vs, data = mtcars)
predictions(mod, by = c("am", "vs"))
#>
#> am vs Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 1 0 19.75 1.1188 17.65 < 2.22e-16 17.56 21.94
#> 1 1 28.37 1.0358 27.39 < 2.22e-16 26.34 30.40
#> 0 1 20.74 1.0358 20.03 < 2.22e-16 18.71 22.77
#> 0 0 15.05 0.7911 19.02 < 2.22e-16 13.50 16.60
#>
#> Prediction type: response
#> Columns: type, am, vs, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
library(nnet)
nom <- multinom(factor(gear) ~ mpg + am * vs, data = mtcars, trace = FALSE)
# first 6 raw predictions
predictions(nom, type = "probs") |> head()
#>
#> Group Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 3 3.624e-05 2.002e-03 0.01810 0.985561 -3.889e-03 3.961e-03
#> 3 3.624e-05 2.002e-03 0.01810 0.985561 -3.889e-03 3.961e-03
#> 3 9.348e-08 6.912e-06 0.01352 0.989210 -1.345e-05 1.364e-05
#> 3 4.045e-01 1.965e-01 2.05788 0.039602 1.924e-02 7.897e-01
#> 3 1.000e+00 1.246e-03 802.40563 < 2e-16 9.975e-01 1.002e+00
#> 3 5.183e-01 2.898e-01 1.78858 0.073683 -4.967e-02 1.086e+00
#>
#> Prediction type: probs
#> Columns: rowid, type, group, estimate, std.error, statistic, p.value, conf.low, conf.high, mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb
#>
# average predictions
avg_predictions(nom, type = "probs", by = "group")
#>
#> Group Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 3 0.4688 0.04043 11.595 < 2.22e-16 0.38952 0.5480
#> 4 0.3750 0.06142 6.106 1.0231e-09 0.25462 0.4954
#> 5 0.1562 0.04624 3.379 0.0007279 0.06561 0.2469
#>
#> Prediction type: probs
#> Columns: type, group, estimate, std.error, statistic, p.value, conf.low, conf.high
#>
by <- data.frame(
group = c("3", "4", "5"),
by = c("3,4", "3,4", "5"))
predictions(nom, type = "probs", by = by)
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 % By
#> 0.4219 0.02312 18.246 < 2.22e-16 0.37656 0.4672 3,4
#> 0.1562 0.04624 3.379 0.0007279 0.06561 0.2469 5
#>
#> Prediction type: probs
#> Columns: type, estimate, std.error, statistic, p.value, conf.low, conf.high, by
#>
# sum of predicted probabilities for combined response levels
mod <- multinom(factor(cyl) ~ mpg + am, data = mtcars, trace = FALSE)
by <- data.frame(
by = c("4,6", "4,6", "8"),
group = as.character(c(4, 6, 8)))
predictions(mod, newdata = "mean", byfun = sum, by = by)
#>
#> Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 % By
#> 0.91584 0.1218 7.5188 5.5296e-14 0.6771 1.1546 4,6
#> 0.08416 0.1218 0.6909 0.48963 -0.1546 0.3229 8
#>
#> Prediction type: probs
#> Columns: type, estimate, std.error, statistic, p.value, conf.low, conf.high, by
#>