Several packages in the R ecosystem allow users to estimate multinomial logit models and discrete choice models. This case study illustrates the use of marginaleffects with the nnet and mlogit packages.
We begin by loading two libraries:
library(marginaleffects)
library(tidyverse)
nnet package
The multinom function of the nnet package allows users to fit multinomial log-linear models via neural networks. The data used by this function is a data frame with one observation per row, and the response variable is coded as a factor. All marginaleffects functions work seamlessly with this model. For example, we can estimate a model and compute average marginal effects as follows:
library(nnet)
head(mtcars)
#> mpg cyl disp hp drat wt qsec vs am gear carb
#> Mazda RX4 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4
#> Mazda RX4 Wag 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4
#> Datsun 710 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1
#> Hornet 4 Drive 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1
#> Hornet Sportabout 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2
#> Valiant 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1
mod <- multinom(factor(gear) ~ hp + mpg, data = mtcars, trace = FALSE)
avg_slopes(mod, type = "probs")
#>
#> Group Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> 3 hp -3.438e-05 0.002253 -0.01526 0.98782440 -0.004450 0.0043815
#> 3 mpg -7.131e-02 0.026441 -2.69703 0.00699606 -0.123134 -0.0194886
#> 4 hp -4.667e-03 0.002199 -2.12294 0.03375920 -0.008976 -0.0003583
#> 4 mpg 1.591e-02 0.020003 0.79525 0.42646997 -0.023298 0.0551125
#> 5 hp 4.702e-03 0.001304 3.60439 0.00031289 0.002145 0.0072584
#> 5 mpg 5.540e-02 0.016469 3.36416 0.00076776 0.023126 0.0876827
#>
#> Prediction type: probs
#> Columns: type, group, term, estimate, std.error, statistic, p.value, conf.low, conf.high
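The table above reports one average slope per term and per outcome level. As a rough cross-check, the average slopes for hp can be approximated with a forward finite difference, using only predict() from nnet. This is a sketch under that assumption and not necessarily how marginaleffects computes its estimates internally; the names eps, p_lo, p_hi, and avg_slope_hp are illustrative. Because the predicted probabilities sum to one for every observation, the slopes across the three outcome levels should sum to approximately zero.

```r
library(nnet)

mod <- multinom(factor(gear) ~ hp + mpg, data = mtcars, trace = FALSE)

# Small step used for the forward finite difference (an assumed value)
eps <- 1e-4

# Predicted probabilities at the observed data and after nudging hp upward
p_lo <- predict(mod, newdata = mtcars, type = "probs")
p_hi <- predict(mod, newdata = transform(mtcars, hp = hp + eps), type = "probs")

# Unit-level slopes (one column per outcome level), then their averages
unit_slopes <- (p_hi - p_lo) / eps
avg_slope_hp <- colMeans(unit_slopes)

avg_slope_hp       # one approximate average slope per level of gear
sum(avg_slope_hp)  # close to zero, since probabilities sum to one
```

The three values should be close to the hp rows of the table, although small numerical differences are expected because the step size and differencing scheme are choices made for this sketch.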
Notice that in such models, we get one marginal effect for each term, for each level of the response variable. For this reason, we should use "group" in the condition argument (or a facet_*() function) when calling one of the plotting functions:
library(ggplot2)
plot_predictions(mod, condition = c("mpg", "group"), type = "probs")
plot_predictions(mod, condition = "mpg", type = "probs") + facet_wrap(~group)
plot_comparisons(
    mod,
    effect = list(mpg = c(15, 30)),
    condition = "group",
    type = "probs")
mlogit package
The mlogit package uses data in a slightly different structure, with one row per observation-choice combination. For example, this dataset on the choice of travel mode includes 4 rows per individual, one for each mode of transportation:
library("AER")
library("mlogit")
library(tidyverse)
data("TravelMode", package = "AER")
head(TravelMode)
#> individual mode choice wait vcost travel gcost income size
#> 1 1 air no 69 59 100 70 35 1
#> 2 1 train no 34 31 372 71 35 1
#> 3 1 bus no 35 25 417 70 35 1
#> 4 1 car yes 0 10 180 30 35 1
#> 5 2 air no 64 58 68 68 30 2
#> 6 2 train no 44 31 354 84 30 2
mod <- mlogit(choice ~ wait + gcost | income + size, TravelMode)
avg_slopes(mod, variables = c("income", "size"))
#>
#> Group Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> air income 0.0027855 0.001218 2.2876 0.022159 0.0003990 0.005172
#> bus income -0.0003721 0.001103 -0.3374 0.735811 -0.0025337 0.001790
#> car income 0.0033731 0.001373 2.4559 0.014052 0.0006812 0.006065
#> train income -0.0057865 0.001319 -4.3861 1.1540e-05 -0.0083723 -0.003201
#> air size -0.1264647 0.028918 -4.3732 1.2245e-05 -0.1831434 -0.069786
#> bus size 0.0113450 0.025867 0.4386 0.660962 -0.0393539 0.062044
#> car size 0.0458798 0.024755 1.8534 0.063830 -0.0026388 0.094398
#> train size 0.0692398 0.024785 2.7936 0.005212 0.0206624 0.117817
#>
#> Prediction type: response
#> Columns: type, group, term, estimate, std.error, statistic, p.value, conf.low, conf.high
Note that the slopes function will always return estimates of zero for regressors that appear before the vertical bar in the formula (here, wait and gcost). This is because the predict() function supplied by the mlogit package does not produce different predictions for different values of these variables.
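The long format used by mlogit, with one row per individual-alternative pair, can be sketched with a toy data frame in base R. The data frame wide, its values, and the column names id, choice, and chosen are hypothetical and serve only to illustrate the layout; they are not part of TravelMode.

```r
# Hypothetical "wide" data: one row per individual, recording the chosen mode
wide <- data.frame(id = 1:2, choice = c("car", "air"))

# The four alternatives available to every individual
alts <- data.frame(mode = c("air", "train", "bus", "car"))

# Cross join: by = NULL makes merge() return the Cartesian product
long <- merge(wide, alts, by = NULL)

# Flag the row that corresponds to the alternative actually chosen
long$chosen <- long$choice == long$mode

# Sort so that the four rows of each individual are adjacent
long <- long[order(long$id), ]
long
```

Each individual now occupies four consecutive rows, exactly one of which is flagged as chosen, which mirrors the layout of the TravelMode data.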
To compute different kinds of marginal effects, we can construct customized data frames and feed them to the newdata argument of the slopes function.
Important: The newdata argument for mlogit models must be a “balanced” data frame, that is, it must have a number of rows that is a multiple of the number of choices.
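The row-count requirement can be checked before calling slopes. The helper is_balanced below is hypothetical (it is not part of marginaleffects or mlogit), and the two data frames contain arbitrary placeholder values. The check is only a necessary condition: this sketch does not verify that the alternatives appear in the expected order within each block of rows.

```r
# Hypothetical helper: TRUE when the number of rows is a positive
# multiple of the number of choices
is_balanced <- function(newdata, n_choices) {
  nrow(newdata) > 0 && nrow(newdata) %% n_choices == 0
}

n_choices <- 4  # air, train, bus, car

# Placeholder grids: 4 rows (one per alternative) versus 6 rows
ok  <- data.frame(income = rep(35, 4), size = rep(2, 4))
bad <- data.frame(income = rep(35, 6), size = rep(2, 6))

is_balanced(ok, n_choices)   # 4 is a multiple of 4
is_balanced(bad, n_choices)  # 6 is not a multiple of 4
```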
If we want to compute the slope of the response function (marginal effects) when each of the predictors is fixed to its global mean, we can do:
nd <- TravelMode |>
    summarize(across(c("wait", "gcost", "income", "size"),
                     function(x) rep(mean(x), 4)))
nd
#> wait gcost income size
#> 1 34.58929 110.8798 34.54762 1.742857
#> 2 34.58929 110.8798 34.54762 1.742857
#> 3 34.58929 110.8798 34.54762 1.742857
#> 4 34.58929 110.8798 34.54762 1.742857
avg_slopes(mod, newdata = nd, variables = c("income", "size"))
#>
#> Group Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> air income 6.656e-03 2.426e-03 2.743 0.0060816 1.901e-03 0.0114108
#> bus income -1.141e-03 9.454e-04 -1.207 0.2273591 -2.994e-03 0.0007117
#> car income 6.480e-06 2.032e-05 0.319 0.7497346 -3.334e-05 0.0000463
#> train income -5.521e-03 1.910e-03 -2.890 0.0038527 -9.265e-03 -0.0017767
#> air size -1.694e-01 5.877e-02 -2.883 0.0039386 -2.846e-01 -0.0542485
#> bus size 4.672e-02 2.723e-02 1.716 0.0862434 -6.656e-03 0.1000991
#> car size 1.358e-03 8.808e-04 1.542 0.1230361 -3.680e-04 0.0030845
#> train size 1.214e-01 4.447e-02 2.729 0.0063528 3.420e-02 0.2085119
#>
#> Prediction type: response
#> Columns: type, group, term, estimate, std.error, statistic, p.value, conf.low, conf.high
If we want to compute marginal effects with gcost and wait fixed at their mean values, conditional on the transportation mode:
nd <- TravelMode |>
    group_by(mode) |>
    summarize(across(c("wait", "gcost", "income", "size"), mean))
nd
#> # A tibble: 4 × 5
#> mode wait gcost income size
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 air 61.0 103. 34.5 1.74
#> 2 train 35.7 130. 34.5 1.74
#> 3 bus 41.7 115. 34.5 1.74
#> 4 car 0 95.4 34.5 1.74
avg_slopes(mod, newdata = nd, variables = c("income", "size"))
#>
#> Group Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> air income 0.0060149 0.002332 2.5793 0.0098996 0.0014443 0.010586
#> bus income -0.0007128 0.001461 -0.4878 0.6256800 -0.0035768 0.002151
#> car income 0.0054450 0.002288 2.3800 0.0173140 0.0009609 0.009929
#> train income -0.0107471 0.002563 -4.1925 2.7592e-05 -0.0157713 -0.005723
#> air size -0.2329273 0.056622 -4.1137 3.8936e-05 -0.3439050 -0.121950
#> bus size 0.0204397 0.034364 0.5948 0.5519797 -0.0469129 0.087792
#> car size 0.0678200 0.041226 1.6451 0.0999517 -0.0129810 0.148621
#> train size 0.1446676 0.047752 3.0295 0.0024493 0.0510747 0.238261
#>
#> Prediction type: response
#> Columns: type, group, term, estimate, std.error, statistic, p.value, conf.low, conf.high
We can also explore more complex alternatives. Here, for example, we build four scenarios of four rows each, and in each scenario only one alternative is affected by a cost reduction (gcost set to 30):
nd <- datagrid(mode = TravelMode$mode, newdata = TravelMode)
nd <- lapply(1:4, function(i) mutate(nd, gcost = ifelse(1:4 == i, 30, gcost)))
nd <- bind_rows(nd)
nd
#> individual choice wait vcost travel gcost income size mode
#> 1 1 no 35 48 486 30 35 2 air
#> 2 1 no 35 48 486 111 35 2 train
#> 3 1 no 35 48 486 111 35 2 bus
#> 4 1 no 35 48 486 111 35 2 car
#> 5 1 no 35 48 486 111 35 2 air
#> 6 1 no 35 48 486 30 35 2 train
#> 7 1 no 35 48 486 111 35 2 bus
#> 8 1 no 35 48 486 111 35 2 car
#> 9 1 no 35 48 486 111 35 2 air
#> 10 1 no 35 48 486 111 35 2 train
#> 11 1 no 35 48 486 30 35 2 bus
#> 12 1 no 35 48 486 111 35 2 car
#> 13 1 no 35 48 486 111 35 2 air
#> 14 1 no 35 48 486 111 35 2 train
#> 15 1 no 35 48 486 111 35 2 bus
#> 16 1 no 35 48 486 30 35 2 car
avg_slopes(mod, newdata = nd, variables = c("income", "size"))
#>
#> Group Term Estimate Std. Error z Pr(>|z|) 2.5 % 97.5 %
#> air income 8.240e-03 2.463e-03 3.3454 0.00082164 0.0034123 0.0130668
#> bus income -1.328e-03 1.304e-03 -1.0183 0.30853207 -0.0038827 0.0012276
#> car income 2.659e-05 4.321e-05 0.6154 0.53832121 -0.0000581 0.0001113
#> train income -6.939e-03 1.860e-03 -3.7298 0.00019166 -0.0105848 -0.0032924
#> air size -2.124e-01 6.027e-02 -3.5247 0.00042400 -0.3305707 -0.0943084
#> bus size 6.062e-02 3.791e-02 1.5993 0.10974996 -0.0136709 0.1349205
#> car size 2.378e-03 1.571e-03 1.5130 0.13028300 -0.0007024 0.0054575
#> train size 1.494e-01 4.286e-02 3.4870 0.00048852 0.0654413 0.2334331
#>
#> Prediction type: response
#> Columns: type, group, term, estimate, std.error, statistic, p.value, conf.low, conf.high
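The structure of the grid printed above, four stacked scenarios with the reduced cost moving from one alternative to the next, can be reproduced in base R on a toy data frame. This is a sketch only: base_grid, the scenario column, and nd_sketch are hypothetical names, and the gcost value of 111 is a placeholder taken from the rounded printout rather than the exact value.

```r
# Toy baseline grid: one row per alternative, same cost everywhere
base_grid <- data.frame(
  mode  = c("air", "train", "bus", "car"),
  gcost = 111  # placeholder, rounded value from the printed grid
)

# One scenario per alternative: lower gcost to 30 for row i only
scenarios <- lapply(seq_len(nrow(base_grid)), function(i) {
  out <- base_grid
  out$gcost[i] <- 30
  out$scenario <- i  # label added for readability, not used by the model
  out
})

# Stack the four scenarios into a single 16-row data frame
nd_sketch <- do.call(rbind, scenarios)
nd_sketch
```

The stacked result has 16 rows, a multiple of the four choices, so it satisfies the row-count requirement for newdata described earlier.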