Status

Estimated and checked against the book.

Libraries

library(tidyverse)
library(tidybayes)
library(rstan)
library(patchwork)
options(mc.cores = parallel::detectCores())

Section 8.1: Building an interaction

# Load the rugged data and build the variables used by m8.1-m8.3.
# Countries without GDP data are dropped *before* standardizing. That way the
# scaling constants (mean log GDP, max ruggedness) are computed on exactly the
# sample the models are fit to, as the book does with complete cases.
#   log_gdp_std : log GDP divided by its sample mean (1 = average country)
#   rugged_std  : ruggedness divided by its sample maximum
#   cid         : integer continent index for Stan's `int cid[n]`;
#                 1 = Africa, 2 = not Africa
#   Region      : readable label matching cid, used for plotting
rugged <- read.csv('data/rugged.csv', sep = ';') %>%
  filter(!is.na(rgdppc_2000)) %>%
  mutate(log_gdp = log(rgdppc_2000),
         log_gdp_std = log_gdp / mean(log_gdp, na.rm = TRUE),
         rugged_std = rugged / max(rugged, na.rm = TRUE),
         cid = ifelse(cont_africa == 1, 1L, 2L),
         Region = ifelse(cont_africa == 1, 'Africa', 'Not Africa')) %>%
  select(country, log_gdp_std, rugged_std, Region, cid)

# Data list for the rugged models. compose_data() turns every column of
# `rugged` into a list element and supplies `n`, which the Stan programs read.
# x_new is an evenly spaced grid of 100 ruggedness values spanning the observed
# range; its length must equal the `vector[100] x_new` declared in the programs.
stan_data <- compose_data(rugged)
stan_data$x_new <- seq(from = min(rugged$rugged_std),
                       to = max(rugged$rugged_std),
                       length.out = 100)
# m8.1: a single regression of log_gdp_std on ruggedness that ignores
# continent. Ruggedness is centred at the hard-coded 0.215 (presumably the
# sample mean of rugged_std; TODO confirm if the data preparation changes).
# `cid` is declared in the data block but not used by this model.
# yhat[i, j] is the linear predictor a + b * (x_new[i] - 0.215). Both columns
# are identical because the model has no continent term. The 2-column shape
# presumably mirrors the m8.2/m8.3 programs so the plotting code can be shared.
stan_program <- '
data {
  int<lower=1> n;        // number of observations
  vector[n] log_gdp_std; // outcome
  vector[n] rugged_std;  // regressor
  vector[100] x_new;     // prediction x
  int cid[n];            // africa indicator
}
parameters {
  real<lower=0> sigma;
  real a;
  real b;
}
transformed parameters {
  vector[n] mu;                    // location
  mu = a + b * (rugged_std - 0.215);
}
model {
  log_gdp_std ~ normal(mu, sigma);
  sigma ~ exponential(1);
  a ~ normal(1, 0.1);
  b ~ normal(0, 0.3);
}
generated quantities {
  matrix[100, 2] yhat;
  for (i in 1:100) {
    for (j in 1:2) {
      yhat[i, j] = a + b * (x_new[i] - 0.215);
    }
  }
}
'

# Sample m8.1 using rstan's default sampler settings. Then print the posterior
# summary table for the intercept, slope and residual SD. The `##` lines are
# that table as printed by an earlier run.
m8.1 <- stan(model_code = stan_program, data = stan_data)
summary(m8.1, pars = c('a', 'b', 'sigma'))$summary
##              mean      se_mean          sd       2.5%         25%         50%
## a     1.000144344 0.0001816693 0.010928535  0.9784620  0.99282958 1.000081800
## b     0.002720112 0.0011843653 0.069678663 -0.1300724 -0.04537059 0.002386835
## sigma 0.138378765 0.0001298810 0.007763986  0.1240235  0.13296697 0.138034442
##              75%     97.5%    n_eff      Rhat
## a     1.00752281 1.0211280 3618.770 0.9995055
## b     0.05082129 0.1388013 3461.212 1.0000788
## sigma 0.14335786 0.1542270 3573.370 0.9999128
# m8.2: one intercept per continent, a[cid] (cid 1 = Africa, 2 = not Africa,
# following the cid coding in the data preparation), with a single slope b
# shared by both continents. Ruggedness is centred at the hard-coded 0.215.
# yhat[i, j] is the linear predictor for continent j at ruggedness x_new[i].
stan_program <- '
data {
  int<lower=1> n;        // number of observations
  vector[n] log_gdp_std; // outcome
  vector[n] rugged_std;  // regressor
  vector[100] x_new;     // prediction x
  int cid[n];            // africa indicator
}
parameters {
  real<lower=0> sigma;
  vector[2] a;
  real b;
}
transformed parameters {
  vector[n] mu;                    // location
  mu = a[cid] + b * (rugged_std - 0.215);
}
model {
  log_gdp_std ~ normal(mu, sigma);
  sigma ~ exponential(1);
  a ~ normal(1, 0.1);
  b ~ normal(0, 0.3);
}
generated quantities {
  matrix[100, 2] yhat;
  for (i in 1:100) {
    for (j in 1:2) {
      yhat[i, j] = a[j] + b * (x_new[i] - 0.215);
    }
  }
}
'

# Sample m8.2 using rstan's default sampler settings. Then print the posterior
# summary for both continent intercepts, the shared slope and sigma. The `##`
# lines are that table as printed by an earlier run.
m8.2 <- stan(model_code = stan_program, data = stan_data)
summary(m8.2, pars = c('a', 'b', 'sigma'))$summary
##              mean      se_mean          sd       2.5%         25%         50%
## a[1]   0.87823286 0.0002716276 0.017202359  0.8446714  0.86664811  0.87800775
## a[2]   1.04646218 0.0001633879 0.010583849  1.0258689  1.03921333  1.04654412
## b     -0.05757425 0.0009733675 0.058330815 -0.1702711 -0.09651719 -0.05759421
## sigma  0.11432094 0.0001016560 0.006315514  0.1024028  0.11004820  0.11393490
##               75%      97.5%    n_eff     Rhat
## a[1]   0.89001925 0.91118895 4010.774 1.000378
## a[2]   1.05355768 1.06722397 4196.115 1.000015
## b     -0.01966054 0.05886718 3591.223 1.000693
## sigma  0.11821310 0.12791392 3859.677 1.000470
# Summarise m8.2's regression line for each continent as the posterior median
# with a 95% quantile interval at every x_new grid point. Attach the grid
# value, label the continent (j = 1 is Africa) and draw it over the data.
datplot <- spread_draws(m8.2, yhat[i, j]) %>%
  median_qi(.width = 0.95) %>%
  left_join(tibble(i = seq_len(100), ruggedness = stan_data$x_new),
            by = 'i') %>%
  mutate(Region = if_else(j == 1, 'Africa', 'Not Africa'))

ggplot(datplot) +
  geom_ribbon(aes(x = ruggedness, ymin = .lower, ymax = .upper, fill = Region),
              alpha = 0.1) +
  geom_line(aes(x = ruggedness, y = yhat, color = Region)) +
  geom_point(data = rugged,
             aes(x = rugged_std, y = log_gdp_std, color = Region)) +
  labs(x = 'Ruggedness (standardized)',
       y = 'log GDP (as proportion of mean)',
       color = '', fill = '')

# m8.3: intercept AND slope vary by continent (a[cid], b[cid]; cid 1 = Africa,
# 2 = not Africa), so the ruggedness effect may differ inside and outside
# Africa. Ruggedness is centred at the hard-coded 0.215.
# The linear predictor uses a vectorized multi-indexed expression, the same
# style m8.2 uses for a[cid]. It replaces a loop over observations: each
# element of mu is computed identically, and the Stan User's Guide recommends
# vectorized expressions for efficiency.
# yhat[i, j] is the linear predictor for continent j at ruggedness x_new[i].
stan_program <- '
data {
  int<lower=1> n;        // number of observations
  vector[n] log_gdp_std; // outcome
  vector[n] rugged_std;  // regressor
  vector[100] x_new;     // prediction x
  int cid[n];            // africa indicator
}
parameters {
  real<lower=0> sigma;
  vector[2] a;
  vector[2] b;
}
transformed parameters {
  vector[n] mu;                    // location
  mu = a[cid] + b[cid] .* (rugged_std - 0.215);
}
model {
  log_gdp_std ~ normal(mu, sigma);
  sigma ~ exponential(1);
  a ~ normal(1, 0.1);
  b ~ normal(0, 0.3);
}
generated quantities {
  matrix[100, 2] yhat;
  for (i in 1:100) {
    for (j in 1:2) {
      yhat[i, j] = a[j] + b[j] * (x_new[i] - 0.215);
    }
  }
}
'

# Sample m8.3 using rstan's default sampler settings. Then print the posterior
# summary for the per-continent intercepts and slopes and sigma. The `##`
# lines are that table as printed by an earlier run.
m8.3 <- stan(model_code = stan_program, data = stan_data)
summary(m8.3, pars = c('a', 'b', 'sigma'))$summary
##             mean      se_mean          sd        2.5%         25%        50%
## a[1]   0.8937274 2.627894e-04 0.017266066  0.86022762  0.88195080  0.8934872
## a[2]   1.0426974 1.487089e-04 0.010436005  1.02264529  1.03541964  1.0428636
## b[1]   0.1646241 1.436960e-03 0.094685741 -0.01607232  0.09889606  0.1626060
## b[2]  -0.1766077 1.148739e-03 0.071814537 -0.31436246 -0.22669293 -0.1769578
## sigma  0.1116664 9.367801e-05 0.006149022  0.10008288  0.10752841  0.1113868
##              75%       97.5%    n_eff      Rhat
## a[1]   0.9052545  0.92764830 4316.891 0.9991032
## a[2]   1.0498282  1.06294908 4924.870 0.9998572
## b[1]   0.2293413  0.35228492 4341.901 0.9993918
## b[2]  -0.1280748 -0.03274406 3908.244 1.0002802
## sigma  0.1157290  0.12416776 4308.606 1.0002863
# Summarise m8.3's continent-specific regression lines as the posterior median
# with a 95% quantile interval at every x_new grid point. Attach the grid
# value, label the continent (j = 1 is Africa) and draw it over the data.
datplot <- spread_draws(m8.3, yhat[i, j]) %>%
  median_qi(.width = 0.95) %>%
  left_join(tibble(i = seq_len(100), ruggedness = stan_data$x_new),
            by = 'i') %>%
  mutate(Region = if_else(j == 1, 'Africa', 'Not Africa'))

ggplot(datplot) +
  geom_ribbon(aes(x = ruggedness, ymin = .lower, ymax = .upper, fill = Region),
              alpha = 0.1) +
  geom_line(aes(x = ruggedness, y = yhat, color = Region)) +
  geom_point(data = rugged,
             aes(x = rugged_std, y = log_gdp_std, color = Region)) +
  labs(x = 'Ruggedness (standardized)',
       y = 'log GDP (as proportion of mean)',
       color = '', fill = '')

Section 8.3: Continuous interactions

# Tulips data: blooms divided by their maximum, water and shade centred on
# their sample means.
tulips <- read.csv('data/tulips.csv', sep = ';') %>%
  mutate(blooms_std = blooms / max(blooms),
         water_cent = water - mean(water),
         shade_cent = shade - mean(shade))
stan_data <- tulips %>% compose_data()
# Prediction grid: all 9 combinations of centred water and shade in {-1, 0, 1}.
# The Stan programs read it as `matrix[pred_n, 2] pred`, so column order
# matters: column 1 is water, column 2 is shade.
stan_data$pred <- expand_grid(water_cent = -1:1, shade_cent = -1:1)
stan_data$pred_n <- nrow(stan_data$pred)

# m8.4: blooms_std as an additive (no interaction) function of centred water
# and shade.
# Fix: the intercept `a` previously had no prior statement, so Stan gave it an
# implicit improper flat prior. The book's prior a ~ normal(0.5, 0.25) is now
# included. It centres the intercept mid-way in the range of blooms_std, which
# is blooms divided by its maximum.
# yhat[i] is the linear predictor at row i of the prediction grid `pred`
# (column 1 = water_cent, column 2 = shade_cent).
stan_program <- '
data {
  int<lower=1> n;
  vector[n] blooms_std;
  vector[n] water_cent;
  vector[n] shade_cent;
  int<lower=1> pred_n;
  matrix[pred_n, 2] pred;
}
parameters {
  real<lower=0> sigma;
  real a;
  real bw;
  real bs;
}
transformed parameters {
  vector[n] mu;
  mu = a + bw * water_cent + bs * shade_cent;
}
model {
  blooms_std ~ normal(mu, sigma);
  sigma ~ exponential(1);
  a ~ normal(0.5, 0.25);
  bw ~ normal(0, 0.25);
  bs ~ normal(0, 0.25);
}
generated quantities {
  vector[pred_n] yhat;
  for (i in 1:pred_n) {
    yhat[i] = a + bw * pred[i, 1] + bs * pred[i, 2];
  }
}
'

# Sample m8.4 using rstan's default sampler settings.
# Alternative sampler settings, currently disabled:
#   control = list(adapt_delta = 0.99), iter = 10000
m8.4 <- stan(model_code = stan_program, data = stan_data)

# Posterior mean and 95% quantile interval of predicted blooms at each of the
# 9 water/shade grid points, with one facet per shade level.
pred <- stan_data$pred
pred$i <- seq_len(nrow(pred))
datplot <- spread_draws(m8.4, yhat[i]) %>%
  mean_qi(.width = 0.95) %>%
  left_join(pred, by = 'i')
ggplot(datplot,
       aes(x = water_cent, y = yhat, ymin = .lower, ymax = .upper)) +
  geom_pointrange() +
  facet_grid(. ~ shade_cent) +
  theme_classic() +
  labs(x = 'Water', y = 'Predicted blooms')

# m8.5: like m8.4 but with a water x shade interaction term (bws).
# Fix: the intercept `a` previously had no prior statement, so Stan gave it an
# implicit improper flat prior. The book's prior a ~ normal(0.5, 0.25) is now
# included. It centres the intercept mid-way in the range of blooms_std, which
# is blooms divided by its maximum.
# yhat[i] is the linear predictor, including the interaction, at row i of the
# prediction grid `pred` (column 1 = water_cent, column 2 = shade_cent).
stan_program <- '
data {
  int<lower=1> n;
  vector[n] blooms_std;
  vector[n] water_cent;
  vector[n] shade_cent;
  int<lower=1> pred_n;
  matrix[pred_n, 2] pred;
}
parameters {
  real<lower=0> sigma;
  real a;
  real bw;
  real bs;
  real bws;
}
transformed parameters {
  vector[n] mu;
  mu = a + bw * water_cent + bs * shade_cent + bws * water_cent .* shade_cent;
}
model {
  blooms_std ~ normal(mu, sigma);
  sigma ~ exponential(1);
  a ~ normal(0.5, 0.25);
  bw ~ normal(0, 0.25);
  bs ~ normal(0, 0.25);
  bws ~ normal(0, 0.25);
}
generated quantities {
  vector[pred_n] yhat;
  for (i in 1:pred_n) {
    yhat[i] = a + bw * pred[i, 1] + bs * pred[i, 2] + bws * pred[i, 1] * pred[i, 2];
  }
}
'

# Sample m8.5 (the interaction model) using rstan's default sampler settings.
m8.5 <- stan(model_code = stan_program, data = stan_data)

# Posterior mean and 95% quantile interval of predicted blooms at each of the
# 9 water/shade grid points, with one facet per shade level.
pred <- stan_data$pred
pred$i <- seq_len(nrow(pred))
datplot <- spread_draws(m8.5, yhat[i]) %>%
  mean_qi(.width = 0.95) %>%
  left_join(pred, by = 'i')
ggplot(datplot,
       aes(x = water_cent, y = yhat, ymin = .lower, ymax = .upper)) +
  geom_pointrange() +
  facet_grid(. ~ shade_cent) +
  theme_classic() +
  labs(x = 'Water', y = 'Predicted blooms')