The models below were estimated and checked against the book:
library(tidyverse)
library(tidybayes)
library(rstan)
library(patchwork)
# Let rstan run its chains in parallel on all detected cores.
options(mc.cores = parallel::detectCores())

# Build the analysis sample for the ruggedness models.
#
# Rows without a usable log GDP are dropped BEFORE the rescaling constants are
# computed, so that mean(log_gdp) and max(rugged) describe exactly the rows
# that are modelled. Previously max(rugged) was taken over every row of the
# file, including countries that were filtered out afterwards, so rugged_std
# could silently end up on a different scale than intended.
#
# Columns kept:
#   log_gdp_std  log GDP divided by its sample mean
#   rugged_std   ruggedness divided by its sample maximum
#   cid          1 when cont_africa == 1, otherwise 2 (index used by the models)
#   Region       text label matching cid, used for plotting
rugged <- read.csv('data/rugged.csv', sep = ';') %>%
  mutate(log_gdp = log(rgdppc_2000)) %>%
  filter(!is.na(log_gdp)) %>%
  mutate(log_gdp_std = log_gdp / mean(log_gdp),
         rugged_std = rugged / max(rugged, na.rm = TRUE),
         cid = ifelse(cont_africa == 1, 1, 2),
         Region = ifelse(cont_africa == 1, 'Africa', 'Not Africa')) %>%
  select(country, log_gdp_std, rugged_std, Region, cid)

# Convert the data frame into the named list that stan() expects, then add an
# evenly spaced grid of 100 ruggedness values on which the models predict.
stan_data <- compose_data(rugged)
stan_data$x_new <- seq(min(rugged$rugged_std), max(rugged$rugged_std), length.out = 100)
# Model m8.1: one regression line shared by all countries.
#   log_gdp_std ~ normal(mu, sigma), mu = a + b * (rugged_std - 0.215)
# Priors: a ~ normal(1, 0.1), b ~ normal(0, 0.3), sigma ~ exponential(1).
# Notes:
#   - 0.215 is a hard-coded centering constant (presumably the sample mean of
#     rugged_std -- TODO confirm).
#   - cid is declared in the data block but is not referenced by this model.
#   - yhat holds the regression line evaluated on x_new; both of its columns
#     are identical here because j does not enter the expression.
stan_program <- '
data {
int<lower=1> n; // number of observations
vector[n] log_gdp_std; // outcome
vector[n] rugged_std; // regressor
vector[100] x_new; // prediction x
int cid[n]; // africa indicator
}
parameters {
real<lower=0> sigma;
real a;
real b;
}
transformed parameters {
vector[n] mu; // location
mu = a + b * (rugged_std - 0.215);
}
model {
log_gdp_std ~ normal(mu, sigma);
sigma ~ exponential(1);
a ~ normal(1, 0.1);
b ~ normal(0, 0.3);
}
generated quantities {
matrix[100, 2] yhat;
for (i in 1:100) {
for (j in 1:2) {
yhat[i, j] = a + b * (x_new[i] - 0.215);
}
}
}
'
# Compile and sample; only the model code and data are passed, so every
# sampler setting is left at the rstan default.
m8.1 <- stan(model_code = stan_program, data = stan_data)
# Posterior summary table restricted to a, b and sigma. The `##` lines below
# appear to be output pasted from an earlier run.
summary(m8.1, c('a', 'b', 'sigma'))$ summary
## mean se_mean sd 2.5% 25% 50%
## a 1.000144344 0.0001816693 0.010928535 0.9784620 0.99282958 1.000081800
## b 0.002720112 0.0011843653 0.069678663 -0.1300724 -0.04537059 0.002386835
## sigma 0.138378765 0.0001298810 0.007763986 0.1240235 0.13296697 0.138034442
## 75% 97.5% n_eff Rhat
## a 1.00752281 1.0211280 3618.770 0.9995055
## b 0.05082129 0.1388013 3461.212 1.0000788
## sigma 0.14335786 0.1542270 3573.370 0.9999128
# Model m8.2: separate intercept per region, common slope.
#   mu = a[cid] + b * (rugged_std - 0.215)
# a is a 2-vector indexed by cid (cid is built upstream as 1 = Africa,
# 2 = not Africa); both elements get the normal(1, 0.1) prior.
# b ~ normal(0, 0.3), sigma ~ exponential(1).
# 0.215 is a hard-coded centering constant (presumably the sample mean of
# rugged_std -- TODO confirm).
# yhat[i, j] is the regression line for region j evaluated at x_new[i].
stan_program <- '
data {
int<lower=1> n; // number of observations
vector[n] log_gdp_std; // outcome
vector[n] rugged_std; // regressor
vector[100] x_new; // prediction x
int cid[n]; // africa indicator
}
parameters {
real<lower=0> sigma;
vector[2] a;
real b;
}
transformed parameters {
vector[n] mu; // location
mu = a[cid] + b * (rugged_std - 0.215);
}
model {
log_gdp_std ~ normal(mu, sigma);
sigma ~ exponential(1);
a ~ normal(1, 0.1);
b ~ normal(0, 0.3);
}
generated quantities {
matrix[100, 2] yhat;
for (i in 1:100) {
for (j in 1:2) {
yhat[i, j] = a[j] + b * (x_new[i] - 0.215);
}
}
}
'
# Compile and sample with rstan defaults (no sampler arguments are passed).
m8.2 <- stan(model_code = stan_program, data = stan_data)
# Posterior summary table for a[1], a[2], b and sigma. The `##` lines below
# appear to be output pasted from an earlier run.
summary(m8.2, c('a', 'b', 'sigma'))$ summary
## mean se_mean sd 2.5% 25% 50%
## a[1] 0.87823286 0.0002716276 0.017202359 0.8446714 0.86664811 0.87800775
## a[2] 1.04646218 0.0001633879 0.010583849 1.0258689 1.03921333 1.04654412
## b -0.05757425 0.0009733675 0.058330815 -0.1702711 -0.09651719 -0.05759421
## sigma 0.11432094 0.0001016560 0.006315514 0.1024028 0.11004820 0.11393490
## 75% 97.5% n_eff Rhat
## a[1] 0.89001925 0.91118895 4010.774 1.000378
## a[2] 1.05355768 1.06722397 4196.115 1.000015
## b -0.01966054 0.05886718 3591.223 1.000693
## sigma 0.11821310 0.12791392 3859.677 1.000470
# Summarise the m8.2 predictions: one median and interval per grid point i
# and region j, joined back to the ruggedness value behind each grid index.
datplot <- left_join(
  median_qi(spread_draws(m8.2, yhat[i, j])),
  tibble(i = seq_along(stan_data$x_new), ruggedness = stan_data$x_new),
  by = 'i'
)
# Column 1 of yhat is the Africa line, column 2 the rest of the world.
datplot <- mutate(datplot, Region = if_else(j == 1, 'Africa', 'Not Africa'))

# Interval band and median line per region, with the observed countries
# drawn on top.
ggplot(data = datplot) +
  geom_ribbon(
    aes(x = ruggedness, ymin = .lower, ymax = .upper, fill = Region),
    alpha = 0.1
  ) +
  geom_line(aes(x = ruggedness, y = yhat, color = Region)) +
  geom_point(
    data = rugged,
    aes(x = rugged_std, y = log_gdp_std, color = Region)
  ) +
  labs(
    x = 'Ruggedness (standardized)',
    y = 'log GDP (as proportion of mean)',
    color = '',
    fill = ''
  )
# Model m8.3: separate intercept AND separate slope per region.
#   mu[i] = a[cid[i]] + b[cid[i]] * (rugged_std[i] - 0.215)
# a and b are 2-vectors indexed by cid; mu is filled element by element in a
# loop over the n observations.
# Priors: a ~ normal(1, 0.1), b ~ normal(0, 0.3), sigma ~ exponential(1),
# applied to every element of a and b.
# 0.215 is a hard-coded centering constant (presumably the sample mean of
# rugged_std -- TODO confirm).
# yhat[i, j] is the regression line for region j evaluated at x_new[i].
stan_program <- '
data {
int<lower=1> n; // number of observations
vector[n] log_gdp_std; // outcome
vector[n] rugged_std; // regressor
vector[100] x_new; // prediction x
int cid[n]; // africa indicator
}
parameters {
real<lower=0> sigma;
vector[2] a;
vector[2] b;
}
transformed parameters {
vector[n] mu; // location
for (i in 1:n) {
mu[i] = a[cid[i]] + b[cid[i]] * (rugged_std[i] - 0.215);
}
}
model {
log_gdp_std ~ normal(mu, sigma);
sigma ~ exponential(1);
a ~ normal(1, 0.1);
b ~ normal(0, 0.3);
}
generated quantities {
matrix[100, 2] yhat;
for (i in 1:100) {
for (j in 1:2) {
yhat[i, j] = a[j] + b[j] * (x_new[i] - 0.215);
}
}
}
'
# Compile and sample with rstan defaults (no sampler arguments are passed).
m8.3 <- stan(model_code = stan_program, data = stan_data)
# Posterior summary table for a[1:2], b[1:2] and sigma. The `##` lines below
# appear to be output pasted from an earlier run.
summary(m8.3, c('a', 'b', 'sigma'))$ summary
## mean se_mean sd 2.5% 25% 50%
## a[1] 0.8937274 2.627894e-04 0.017266066 0.86022762 0.88195080 0.8934872
## a[2] 1.0426974 1.487089e-04 0.010436005 1.02264529 1.03541964 1.0428636
## b[1] 0.1646241 1.436960e-03 0.094685741 -0.01607232 0.09889606 0.1626060
## b[2] -0.1766077 1.148739e-03 0.071814537 -0.31436246 -0.22669293 -0.1769578
## sigma 0.1116664 9.367801e-05 0.006149022 0.10008288 0.10752841 0.1113868
## 75% 97.5% n_eff Rhat
## a[1] 0.9052545 0.92764830 4316.891 0.9991032
## a[2] 1.0498282 1.06294908 4924.870 0.9998572
## b[1] 0.2293413 0.35228492 4341.901 0.9993918
## b[2] -0.1280748 -0.03274406 3908.244 1.0002802
## sigma 0.1157290 0.12416776 4308.606 1.0002863
# Summarise the m8.3 predictions: one median and interval per grid point i
# and region j, joined back to the ruggedness value behind each grid index.
datplot <- left_join(
  median_qi(spread_draws(m8.3, yhat[i, j])),
  tibble(i = seq_along(stan_data$x_new), ruggedness = stan_data$x_new),
  by = 'i'
)
# Column 1 of yhat is the Africa line, column 2 the rest of the world.
datplot <- mutate(datplot, Region = if_else(j == 1, 'Africa', 'Not Africa'))

# Interval band and median line per region, with the observed countries
# drawn on top.
ggplot(data = datplot) +
  geom_ribbon(
    aes(x = ruggedness, ymin = .lower, ymax = .upper, fill = Region),
    alpha = 0.1
  ) +
  geom_line(aes(x = ruggedness, y = yhat, color = Region)) +
  geom_point(
    data = rugged,
    aes(x = rugged_std, y = log_gdp_std, color = Region)
  ) +
  labs(
    x = 'Ruggedness (standardized)',
    y = 'log GDP (as proportion of mean)',
    color = '',
    fill = ''
  )
# Tulips data: blooms rescaled by its maximum, water and shade centered on
# their respective means.
tulips <- mutate(
  read.csv('data/tulips.csv', sep = ';'),
  blooms_std = blooms / max(blooms),
  water_cent = water - mean(water),
  shade_cent = shade - mean(shade)
)

# Data list for Stan, extended with every combination of centered water and
# shade levels (-1, 0, 1) to predict on, and the number of such combinations.
stan_data <- tulips %>% compose_data()
stan_data[['pred']] <- expand_grid(water_cent = -1:1, shade_cent = -1:1)
stan_data[['pred_n']] <- nrow(stan_data[['pred']])
# Model m8.4: tulips, main effects of water and shade only (no interaction).
#   blooms_std ~ normal(mu, sigma)
#   mu = a + bw * water_cent + bs * shade_cent
#
# Fix: the intercept `a` previously had no sampling statement, which leaves it
# with Stan's implicit flat (improper) prior while every other parameter had a
# proper one. It now gets a ~ normal(0.5, 0.25): blooms_std is blooms divided
# by its maximum, so 0.5 sits mid-scale. NOTE(review): this is intended to be
# the intercept prior the book uses for this model -- confirm against the
# edition being followed.
#
# yhat[i] is the linear predictor at row i of the prediction grid `pred`
# (column 1 = water_cent, column 2 = shade_cent).
stan_program <- '
data {
int<lower=1> n;
vector[n] blooms_std;
vector[n] water_cent;
vector[n] shade_cent;
int<lower=1> pred_n;
matrix[pred_n, 2] pred;
}
parameters {
real<lower=0> sigma;
real a;
real bw;
real bs;
}
transformed parameters {
vector[n] mu;
mu = a + bw * water_cent + bs * shade_cent;
}
model {
blooms_std ~ normal(mu, sigma);
sigma ~ exponential(1);
a ~ normal(0.5, 0.25);
bw ~ normal(0, 0.25);
bs ~ normal(0, 0.25);
}
generated quantities {
vector[pred_n] yhat;
for (i in 1:pred_n) {
yhat[i] = a + bw * pred[i, 1] + bs * pred[i, 2];
}
}
'
# Compile and sample with rstan defaults. Options that were tried earlier and
# left disabled, should the sampler need more care:
#   control = list(adapt_delta = 0.99), iter = 10000
m8.4 <- stan(model_code = stan_program, data = stan_data)

# Attach a row index to the prediction grid so the posterior summaries
# (indexed by i) can be joined back to their water/shade combination.
pred <- stan_data$pred
pred$i <- seq_len(nrow(pred))
datplot <- m8.4 %>%
  spread_draws(yhat[i]) %>%
  mean_qi() %>%
  left_join(pred, by = 'i')

# Posterior mean and interval of predicted blooms against water, one panel
# per shade level.
ggplot(datplot, aes(water_cent, yhat, ymin = .lower, ymax = .upper)) +
  geom_pointrange() +
  facet_grid(. ~ shade_cent) +
  theme_classic() +
  labs(x = 'Water', y = 'Predicted blooms')
# Model m8.5: tulips, water and shade main effects plus their interaction.
#   blooms_std ~ normal(mu, sigma)
#   mu = a + bw * water_cent + bs * shade_cent
#          + bws * water_cent .* shade_cent
#
# Fix: the intercept `a` previously had no sampling statement, which leaves it
# with Stan's implicit flat (improper) prior while every other parameter had a
# proper one. It now gets a ~ normal(0.5, 0.25): blooms_std is blooms divided
# by its maximum, so 0.5 sits mid-scale. NOTE(review): this is intended to be
# the intercept prior the book uses for this model -- confirm against the
# edition being followed.
#
# yhat[i] is the linear predictor at row i of the prediction grid `pred`
# (column 1 = water_cent, column 2 = shade_cent).
stan_program <- '
data {
int<lower=1> n;
vector[n] blooms_std;
vector[n] water_cent;
vector[n] shade_cent;
int<lower=1> pred_n;
matrix[pred_n, 2] pred;
}
parameters {
real<lower=0> sigma;
real a;
real bw;
real bs;
real bws;
}
transformed parameters {
vector[n] mu;
mu = a + bw * water_cent + bs * shade_cent + bws * water_cent .* shade_cent;
}
model {
blooms_std ~ normal(mu, sigma);
sigma ~ exponential(1);
a ~ normal(0.5, 0.25);
bw ~ normal(0, 0.25);
bs ~ normal(0, 0.25);
bws ~ normal(0, 0.25);
}
generated quantities {
vector[pred_n] yhat;
for (i in 1:pred_n) {
yhat[i] = a + bw * pred[i, 1] + bs * pred[i, 2] + bws * pred[i, 1] * pred[i, 2];
}
}
'
# Compile and sample with rstan defaults (no sampler arguments are passed).
m8.5 <- stan(model_code = stan_program, data = stan_data)

# Attach a row index to the prediction grid so the posterior summaries
# (indexed by i) can be joined back to their water/shade combination.
pred <- stan_data$pred
pred$i <- seq_len(nrow(pred))
datplot <- m8.5 %>%
  spread_draws(yhat[i]) %>%
  mean_qi() %>%
  left_join(pred, by = 'i')

# Posterior mean and interval of predicted blooms against water, one panel
# per shade level.
ggplot(datplot, aes(water_cent, yhat, ymin = .lower, ymax = .upper)) +
  geom_pointrange() +
  facet_grid(. ~ shade_cent) +
  theme_classic() +
  labs(x = 'Water', y = 'Predicted blooms')