Script 3 - MCMC - Gibbs sampling

Author

Roberto Ascari

In this section, we show two Gibbs samplers: the first one aims to estimate the parameters \((\mu, \sigma^2)\) of a Normal likelihood, whereas the second one estimates the parameters of a linear regression model.

Gibbs Sampling for the Mean and Variance of a Normal Distribution

We consider a sample \(Y_1, \dots, Y_n\) such that \(Y_i \mid \mu, \sigma^2 \overset{iid}{\sim} N(\mu, \sigma^2)\). Furthermore, we assume prior independence, meaning that \(\pi(\mu, \sigma^2) = \pi(\mu)\pi(\sigma^2)\). Then, we specify the following priors:

  • \(\mu \sim N(\mu_0, \sigma^2_0)\);

  • \(\sigma^2 \sim Inv.Gamma(a_0, b_0)\).

In a simulation framework, we need to specify the true values of the parameters generating the data.

# Parameters of the data-generating process (sample size and true mu, sigma2)
n <- 100
mu.true <- 10
sigma2.true <- 5

Then, we can generate a sample…

set.seed(42)
y <- rnorm(n, mean = mu.true, sd = sqrt(sigma2.true))
(ybar <- mean(y))
[1] 10.07271

… and set the hyperparameters (in the code, `alpha` and `beta` play the role of \(a_0\) and \(b_0\)):

# Prior on mu: N(mu0, sigma2.0), a vague prior centred at 0
mu0 <- 0
sigma2.0 <- 1000

# Prior on sigma2: Inv.Gamma(alpha, beta), i.e. a_0 = alpha and b_0 = beta
alpha <- beta <- 10

The prior expectation of \(\sigma^2\) is

# Prior mean of sigma2 ~ Inv.Gamma(alpha, beta): beta / (alpha - 1), for alpha > 1
beta / (alpha - 1)
[1] 1.111111

and its prior variance is:

# Var(sigma2) for sigma2 ~ Inv.Gamma(alpha, beta) is beta^2 / ((alpha - 1)^2 * (alpha - 2)),
# finite only for alpha > 2; note that the (alpha - 1) factor is squared.
(beta^2)/((alpha-1)^2*(alpha-2))
[1] 0.154321

Initialization

We initialize each chain by drawing its first value from the corresponding prior:

# Chains are stored in preallocated vectors of length B
B <- 1000
sigma2.chain <- mu.chain <- numeric(B)

# Starting values: one draw from each prior (mu first, then sigma2)
set.seed(42)
mu.chain[1] <- rnorm(n = 1, mean = mu0, sd = sqrt(sigma2.0))
sigma2.chain[1] <- 1 / rgamma(n = 1, shape = alpha, rate = beta)

Then, we can implement the Gibbs sampling by generating each parameter from its full-conditional distribution. The full-conditionals are the following:

  • \(\mu| \sigma^2, \textbf{y} \sim N\left(\frac{n\sigma^2_0 \bar{y} + \sigma^2 \mu_0}{n\sigma^2_0 + \sigma^2}, \frac{\sigma^2 \sigma^2_0}{n\sigma^2_0 + \sigma^2}\right)\);

  • \(\sigma^2| \mu, \textbf{y} \sim Inv.Gamma\left(a_0 + \frac{n}{2}, b_0 + \frac{\sum_{i=1}^n (y_i - \mu)^2}{2}\right)\).

set.seed(42)
for (b in seq(2, B)) {
  s2.prev <- sigma2.chain[b - 1]
  denom <- n * sigma2.0 + s2.prev

  # mu | sigma2, y ~ N(...)
  mu.chain[b] <- rnorm(1,
                       mean = (n * sigma2.0 * ybar + s2.prev * mu0) / denom,
                       sd = sqrt((s2.prev * sigma2.0) / denom))

  # sigma2 | mu, y ~ Inv.Gamma(alpha + n/2, beta + SS/2)
  sigma2.chain[b] <- 1 / rgamma(1, shape = alpha + n / 2,
                                rate = beta + .5 * sum((y - mu.chain[b])^2))
}
plot(mu.chain, sigma2.chain, pch = 20)
points(mu.true, sigma2.true, col = "#D55E00", pch = 20)

We can now generate some diagnostic plots.

# Traceplots (the initial draw from the prior is excluded):
par(mfrow = c(2, 1))
plot(mu.chain[-1], pch = 20, type = "l")
plot(sigma2.chain[-1], pch = 20, type = "l")

par(mfrow = c(1, 1))
# Running means: element b (b >= 2) is the mean of draws 2..b, so the initial
# value never enters; element 1 keeps the initial value.
# cumsum() replaces the O(B^2) loop of repeated mean() calls.
mu.means <- c(mu.chain[1], cumsum(mu.chain[-1]) / seq_len(B - 1))
sigma2.means <- c(sigma2.chain[1], cumsum(sigma2.chain[-1]) / seq_len(B - 1))

plot(mu.means[-1], pch = 20, type = "l", ylim = c(9.2, 10.8))
abline(h = mu.true, col = "#D55E00")

plot(sigma2.means[-1], pch = 20, type = "l", ylim = c(3, 7))
abline(h = sigma2.true, col = "#D55E00")

# ACF of the chains. As in the traceplots, the initial value is excluded:
# it is drawn from the vague prior and can lie far from the bulk of the draws,
# which would distort the estimated autocorrelations.
par(mfrow = c(2, 1))
acf(mu.chain[-1])
acf(sigma2.chain[-1])

par(mfrow = c(1, 1))

Once we have a chain for each parameter, we need to discard the warm-up (i.e., the initial part of each chain, for which we cannot assume that the chain has already converged to its stationary distribution).

warm_perc <- .5

# Indices of the iterations kept after the warm-up
keep <- seq(from = round(B * warm_perc + 1), to = B)
mu.new <- mu.chain[keep]
sigma2.new <- sigma2.chain[keep]

par(mfrow = c(2, 1))
plot(mu.new, type = "l", pch = 20, ylim = c(9.2, 10.8))
abline(h = mu.true, col = "#D55E00")

plot(sigma2.new, type = "l", pch = 20, ylim = c(3, 7))
abline(h = sigma2.true, col = "#D55E00")

par(mfrow = c(1, 1))

The post-warm-up draws can be used to compute point estimates, credible sets (CSs), and posterior probabilities:

# Posterior mean and 95% equal-tailed credible set for mu:
mean(mu.new)
[1] 10.07776
quantile(mu.new, probs = c(.025, .975))
     2.5%     97.5% 
 9.660935 10.497956 
# Posterior mean and 95% equal-tailed credible set for sigma2:
mean(sigma2.new)
[1] 4.710458
quantile(sigma2.new, probs = c(.025, .975))
    2.5%    97.5% 
3.694683 6.045644 
# Monte Carlo estimate of the posterior probability P(mu > 10 | y):
mean(mu.new > 10)
[1] 0.648
hist(mu.new, probability = TRUE, xlab = expression(mu),
     main = "Posterior distribution of mu")
lines(density(mu.new), col = "#D55E00")
abline(v = mu.true, col = "black", lty = "dashed")

# plotmath: sigma^2 renders as a Greek sigma with a superscript, whereas
# expression(sigma2) would print the literal symbol name "sigma2".
hist(sigma2.new, probability = TRUE, xlab = expression(sigma^2),
     main = "Posterior distribution of sigma2")
lines(density(sigma2.new), col = "#D55E00")
abline(v = sigma2.true, col = "black", lty = "dashed")

# Joint posterior draws with the true (mu, sigma2) highlighted
plot(mu.new, sigma2.new, pch = 20)
points(mu.true, sigma2.true, pch = 19, col = "#D55E00")

We can wrap this Gibbs sampler in a function, so that it can easily be run on new data.

# Gibbs sampler for (mu, sigma2) under y_i | mu, sigma2 ~ N(mu, sigma2), with
# independent priors mu ~ N(mu0, sigma2.0) and sigma2 ~ Inv.Gamma(alpha, beta).
#
# Arguments:
#   y          numeric vector of observations (finite values only).
#   B          total number of iterations, warm-up included; must be >= 2.
#   mu0        prior mean of mu.
#   sigma2.0   prior variance of mu.
#   alpha      shape of the Inv.Gamma prior on sigma2.
#   beta       rate of the Gamma distribution of 1/sigma2 (i.e. Inv.Gamma scale).
#   warm_perc  fraction of iterations discarded as warm-up, in [0, 1).
#   seed       passed to set.seed() before sampling; NULL leaves the global
#              RNG state untouched.
#
# Returns a matrix with columns "mu.chain" and "sigma2.chain" holding the
# post-warm-up draws; at least the last draw is always returned.
normal_GS <- function(y, B = 5000,
                      mu0 = 0, sigma2.0 = 1000,
                      alpha = 10, beta = 10,
                      warm_perc = .5, seed = 42) {
  stopifnot(
    is.numeric(y), length(y) > 0, all(is.finite(y)),
    length(B) == 1, B >= 2,
    length(warm_perc) == 1, warm_perc >= 0, warm_perc < 1
  )

  mu.chain <- numeric(B)
  sigma2.chain <- numeric(B)
  ybar <- mean(y)
  n <- length(y)
  n_s2_0 <- n * sigma2.0  # loop-invariant term of the full conditional of mu

  # Initialization: one draw from each prior.
  if (!is.null(seed)) set.seed(seed)
  mu.chain[1] <- rnorm(1, mu0, sqrt(sigma2.0))
  sigma2.chain[1] <- 1 / rgamma(1, alpha, rate = beta)

  for (b in seq_len(B)[-1]) {
    s2 <- sigma2.chain[b - 1]

    # Draw mu from its full conditional
    mu.chain[b] <- rnorm(1,
                         (n_s2_0 * ybar + s2 * mu0) / (n_s2_0 + s2),
                         sqrt((s2 * sigma2.0) / (n_s2_0 + s2)))

    # Draw sigma2 from its full conditional
    sigma2.chain[b] <- 1 / rgamma(1, alpha + n / 2,
                                  rate = beta + .5 * sum((y - mu.chain[b])^2))
  }

  # Discard the warm-up. The start index is clamped to B: for warm_perc close
  # to 1, round(B * warm_perc + 1) can exceed B and `(B + 1):B` would count
  # down and return an NA row.
  keep <- min(round(B * warm_perc + 1), B):B
  cbind(mu.chain = mu.chain[keep], sigma2.chain = sigma2.chain[keep])
}

Gala dataset (I)

# Attach faraway; the gala data set loaded below is taken from it
library(faraway)
Warning in check_dep_version(): ABI version mismatch: 
lme4 was built with Matrix ABI version 1
Current Matrix ABI version is 0
Please re-install lme4 from source or restore original 'Matrix' package
# Load the gala data frame and inspect its structure (30 observations of 7 variables)
data(gala)
str(gala)
'data.frame':   30 obs. of  7 variables:
 $ Species  : num  58 31 3 25 2 18 24 10 8 2 ...
 $ Endemics : num  23 21 3 9 1 11 0 7 4 2 ...
 $ Area     : num  25.09 1.24 0.21 0.1 0.05 ...
 $ Elevation: num  346 109 114 46 77 119 93 168 71 112 ...
 $ Nearest  : num  0.6 0.6 2.8 1.9 1.9 8 6 34.1 0.4 2.6 ...
 $ Scruz    : num  0.6 26.3 58.7 47.4 1.9 ...
 $ Adjacent : num  1.84 572.33 0.78 0.18 903.82 ...
# Number of species, modelled with the Normal likelihood of normal_GS
y <- gala$Species
gala_GS <- normal_GS(y, B = 10000)

# B = 10000 with the default warm_perc = .5 leaves 5000 draws per parameter
str(gala_GS)
 num [1:5000, 1:2] 92.5 44.1 32.6 24.1 102.8 ...
 - attr(*, "dimnames")=List of 2
  ..$ : NULL
  ..$ : chr [1:2] "mu.chain" "sigma2.chain"
head(gala_GS)
      mu.chain sigma2.chain
[1,]  92.52505     9972.609
[2,]  44.11655     8021.474
[3,]  32.56115     9979.261
[4,]  24.07501    10765.908
[5,] 102.78028     7780.960
[6,]  63.76942     7516.838
# Joint posterior draws of (mu, sigma2)
plot(gala_GS, pch = 20)

# Marginal posteriors with kernel density estimates; axis labels use plotmath
# as in the simulated example.
hist(gala_GS[, "mu.chain"], probability = TRUE, xlab = expression(mu),
     main = "Posterior distribution of mu")
lines(density(gala_GS[, "mu.chain"]), col = "#D55E00")

hist(gala_GS[, "sigma2.chain"], probability = TRUE, xlab = expression(sigma^2),
     main = "Posterior distribution of sigma2")
lines(density(gala_GS[, "sigma2.chain"]), col = "#D55E00")

# Posterior means of mu and sigma2
colMeans(gala_GS)
    mu.chain sigma2.chain 
    67.11751   8299.55326 
# 95% equal-tailed credible sets, one row per parameter
t(apply(gala_GS, 2, function(x) quantile(x, probs=c(.025, .975))))
                   2.5%       97.5%
mu.chain       37.25436    96.47314
sigma2.chain 5468.21124 12628.64946
# Posterior probability P(mu > 70 | y)
mean(gala_GS[,1] > 70)
[1] 0.4212
# Joint posterior probability P(mu > 70, sigma2 < 5500 | y)
mean(gala_GS[,1] > 70 & gala_GS[,2] < 5500)
[1] 0.017

Gibbs Sampling for the parameters of a Linear Regression Model

rm(list=ls())

# True values:
beta <- c(-2, 5, 3)
sigma2 <- 6

# Generating data: an intercept column plus two N(0, 50^2) covariates
n <- 100
set.seed(42)
X <- cbind(rep(1, n), matrix(rnorm(2 * n, 0, 50), ncol = 2))

y <- as.numeric(X %*% beta + rnorm(n, 0, sqrt(sigma2)))

Classical OLS/ML estimates:

(summ <- summary(lm(y ~ X[, 2] + X[, 3])))

Call:
lm(formula = y ~ X[, 2] + X[, 3])

Residuals:
    Min      1Q  Median      3Q     Max 
-6.3114 -1.6213 -0.2021  1.5605  6.1741 

Coefficients:
             Estimate Std. Error  t value Pr(>|t|)    
(Intercept) -1.995674   0.249606   -7.995 2.75e-12 ***
X[, 2]       4.992960   0.004795 1041.303  < 2e-16 ***
X[, 3]       3.004178   0.005522  543.997  < 2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 2.483 on 97 degrees of freedom
Multiple R-squared:  0.9999,    Adjusted R-squared:  0.9999 
F-statistic: 7.085e+05 on 2 and 97 DF,  p-value: < 2.2e-16
# Prior hyperparameters: beta ~ N_3(beta0, Sigma0), sigma2 ~ Inv.Gamma(a0, b0)
beta0 <- numeric(3)
Sigma0 <- diag(100, 3)

a0 <- b0 <- 10
# Prior expectation of sigma2: b0 / (a0 - 1)
b0 / (a0 - 1)
[1] 1.111111
# Prior Variance for sigma2: b0^2 / ((a0 - 1)^2 * (a0 - 2)), finite only for a0 > 2
(b0^2)/((a0-1)^2*(a0-2))
[1] 0.154321
B <- 5000
sigma2.chain <- numeric(B)
beta.chain <- matrix(NA, nrow = B, ncol = 3)

# Fixed starting values (not drawn from the prior)
beta.chain[1, ] <- 0
sigma2.chain[1] <- 1

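Gibbs sampling. With priors \(\boldsymbol{\beta} \sim N_p(\boldsymbol{\beta}_0, \Sigma_0)\) and \(\sigma^2 \sim Inv.Gamma(a_0, b_0)\), the full-conditionals are:

  • \(\boldsymbol{\beta} | \sigma^2, \textbf{y} \sim N_p\left(\Sigma_n\left(\Sigma_0^{-1}\boldsymbol{\beta}_0 + \frac{X^\top \textbf{y}}{\sigma^2}\right), \Sigma_n\right)\), where \(\Sigma_n = \left(\Sigma_0^{-1} + \frac{X^\top X}{\sigma^2}\right)^{-1}\);

  • \(\sigma^2 | \boldsymbol{\beta}, \textbf{y} \sim Inv.Gamma\left(a_0 + \frac{n}{2}, b_0 + \frac{(\textbf{y} - X\boldsymbol{\beta})^\top(\textbf{y} - X\boldsymbol{\beta})}{2}\right)\).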

library(MASS)

# Loop-invariant pieces of the full conditional of beta, computed once instead
# of at every iteration (solve(Sigma0) used to be evaluated twice per draw).
Sigma0.inv <- solve(Sigma0)
prior.term <- Sigma0.inv %*% beta0
XtX <- t(X) %*% X
Xty <- t(X) %*% y

for (b in seq_len(B)[-1]) {
  # beta | sigma2, y ~ N_p(beta.n, Sigma.n)
  Sigma.n <- solve(Sigma0.inv + XtX / sigma2.chain[b - 1])
  beta.n <- Sigma.n %*% (prior.term + Xty / sigma2.chain[b - 1])
  beta.chain[b, ] <- mvrnorm(n = 1, mu = beta.n, Sigma = Sigma.n)

  # sigma2 | beta, y ~ Inv.Gamma(a0 + n/2, b0 + RSS/2)
  resid <- y - X %*% beta.chain[b, ]
  sigma2.chain[b] <- 1 / rgamma(1, a0 + .5 * n,
                                rate = b0 + 0.5 * (t(resid) %*% resid))
}

Diagnostic plots:

# Traceplots including the starting values; the orange line marks the true value
par(mfrow = c(2, 2))
plot(beta.chain[, 1], type = "l", pch = 20)
abline(h = beta[1], col = "#D55E00")
plot(beta.chain[, 2], type = "l", pch = 20)
abline(h = beta[2], col = "#D55E00")
plot(beta.chain[, 3], type = "l", pch = 20)
abline(h = beta[3], col = "#D55E00")
plot(sigma2.chain, type = "l", pch = 20)
abline(h = sigma2, col = "#D55E00")

par(mfrow = c(1, 1))

# Same traceplots with the starting values removed
par(mfrow = c(2, 2))
plot(beta.chain[-1, 1], type = "l", pch = 20)
abline(h = beta[1], col = "#D55E00")
plot(beta.chain[-1, 2], type = "l", pch = 20)
abline(h = beta[2], col = "#D55E00")
plot(beta.chain[-1, 3], type = "l", pch = 20)
abline(h = beta[3], col = "#D55E00")
plot(sigma2.chain[-1], type = "l", pch = 20)
abline(h = sigma2, col = "#D55E00")

par(mfrow = c(1, 1))
# Running (ergodic) means of each chain. Element b (b >= 2) is the mean of
# draws 2..b, so the fixed starting value never enters; element 1 keeps it.
# cumsum() replaces the original O(B^2) loop of repeated mean() calls.
n.iter <- nrow(beta.chain)
beta_mean <- beta.chain
beta_mean[-1, ] <- apply(beta.chain[-1, , drop = FALSE], 2, cumsum) /
  seq_len(n.iter - 1)
sigma2.means <- c(sigma2.chain[1],
                  cumsum(sigma2.chain[-1]) / seq_len(length(sigma2.chain) - 1))

par(mfrow = c(2, 2))
plot(beta_mean[-1, 1], pch = 20, type = "l")
plot(beta_mean[-1, 2], pch = 20, type = "l")
plot(beta_mean[-1, 3], pch = 20, type = "l")
plot(sigma2.means[-1], pch = 20, type = "l")

par(mfrow = c(1, 1))

# Autocorrelation of the chains themselves (starting value excluded).
# Running means are cumulative averages and are strongly autocorrelated by
# construction, so their ACF says nothing about how well the sampler mixes.
acf(beta.chain[-1, 1])

acf(beta.chain[-1, 2])

acf(beta.chain[-1, 3])

acf(sigma2.chain[-1])

# Removing the warm-up:
warm_perc <- .5
keep <- seq(from = round(B * warm_perc + 1), to = B)

beta.new <- beta.chain[keep, ]
sigma2.new <- sigma2.chain[keep]

# Computing estimates:
# Posterior means of the regression coefficients (post-warm-up draws):
colMeans(beta.new)
[1] -1.990287  4.992966  3.004100
# Posterior mean of sigma2:
mean(sigma2.new)
[1] 5.379681
# Histogram and kernel estimate:
# plotmath labels: beta[j] renders as a subscripted Greek beta and sigma^2 as a
# superscript; expression(beta0) / expression(sigma2) would print the literal
# symbol names. Dashed lines mark the true values, as in the first example.
par(mfrow = c(2, 2))
hist(beta.new[, 1], probability = TRUE, xlab = expression(beta[0]),
     main = "Posterior distribution of beta0")
lines(density(beta.new[, 1]), col = "#D55E00")
abline(v = beta[1], col = "black", lty = "dashed")

hist(beta.new[, 2], probability = TRUE, xlab = expression(beta[1]),
     main = "Posterior distribution of beta1")
lines(density(beta.new[, 2]), col = "#D55E00")
abline(v = beta[2], col = "black", lty = "dashed")

hist(beta.new[, 3], probability = TRUE, xlab = expression(beta[2]),
     main = "Posterior distribution of beta2")
lines(density(beta.new[, 3]), col = "#D55E00")
abline(v = beta[3], col = "black", lty = "dashed")

hist(sigma2.new, probability = TRUE, xlab = expression(sigma^2),
     main = "Posterior distribution of sigma2")
lines(density(sigma2.new), col = "#D55E00")
abline(v = sigma2, col = "black", lty = "dashed")

par(mfrow = c(1, 1))
# Posterior covariance matrix of beta (post-warm-up draws)
round(cov(beta.new), digits = 5)
         [,1]   [,2]    [,3]
[1,]  0.05480 -3e-05 0.00011
[2,] -0.00003  2e-05 0.00000
[3,]  0.00011  0e+00 0.00003
# Frequentist covariance of the OLS estimator, sigma_hat^2 * (X'X)^{-1}.
# summ$sigma is the exact residual standard error; the previously hard-coded
# 2.483 was only its printed, rounded value.
round(summ$sigma^2 * summ$cov.unscaled, 5)
            (Intercept) X[, 2]  X[, 3]
(Intercept)     0.06231 -4e-05 0.00013
X[, 2]         -0.00004  2e-05 0.00000
X[, 3]          0.00013  0e+00 0.00003

Defining a function:

# Remove every object from the global environment before defining LM_GS below
rm(list=ls())

# Gibbs sampler for the linear model y = X beta + e, e ~ N(0, sigma2 I), with
# independent priors beta ~ N_p(beta0, Sigma0) and sigma2 ~ Inv.Gamma(a0, b0).
#
# Arguments:
#   y          numeric response vector of length n.
#   X          n x p design matrix (include a column of ones for an intercept).
#   B          total number of iterations, warm-up included; must be >= 2.
#   beta0      prior mean of beta (length p).
#   Sigma0     p x p prior covariance of beta. The default identity matrix
#              gives every coefficient prior variance 1.
#   a0, b0     shape and rate parameters of the Inv.Gamma prior on sigma2.
#   warm_perc  fraction of iterations discarded as warm-up, in [0, 1).
#   seed       passed to set.seed() before sampling; NULL leaves the global
#              RNG state untouched. (Previously this argument was ignored.)
#
# Returns a list with
#   beta    matrix of post-warm-up draws, one column per column of X (named
#           after colnames(X), if any); always a matrix, even when p == 1;
#   sigma2  vector of post-warm-up draws of sigma2.
# The chains start at beta = 0 and sigma2 = 1.
LM_GS <- function(y, X, B = 5000,
                  beta0 = rep(0, ncol(X)),
                  Sigma0 = diag(ncol(X)),
                  a0 = 10, b0 = 10, warm_perc = .5, seed = 42) {
  X <- as.matrix(X)
  p <- ncol(X)
  n <- length(y)
  stopifnot(
    is.numeric(y), nrow(X) == n,
    length(beta0) == p, all(dim(as.matrix(Sigma0)) == p),
    length(B) == 1, B >= 2,
    length(warm_perc) == 1, warm_perc >= 0, warm_perc < 1
  )

  beta.chain <- matrix(NA_real_, nrow = B, ncol = p,
                       dimnames = list(NULL, colnames(X)))
  sigma2.chain <- numeric(B)

  # Loop-invariant pieces of the full conditional of beta.
  Sigma0.inv <- solve(Sigma0)
  prior.term <- Sigma0.inv %*% beta0
  XtX <- t(X) %*% X
  Xty <- t(X) %*% y

  # Initialization:
  beta.chain[1, ] <- rep(0, p)
  sigma2.chain[1] <- 1

  if (!is.null(seed)) set.seed(seed)
  for (b in seq_len(B)[-1]) {
    # beta | sigma2, y ~ N_p(beta.n, Sigma.n)
    Sigma.n <- solve(Sigma0.inv + XtX / sigma2.chain[b - 1])
    beta.n <- Sigma.n %*% (prior.term + Xty / sigma2.chain[b - 1])
    beta.chain[b, ] <- MASS::mvrnorm(n = 1, mu = beta.n, Sigma = Sigma.n)

    # sigma2 | beta, y ~ Inv.Gamma(a0 + n/2, b0 + RSS/2)
    resid <- y - X %*% beta.chain[b, ]
    sigma2.chain[b] <- 1 / rgamma(1, a0 + .5 * n,
                                  rate = b0 + 0.5 * sum(resid^2))
  }

  # Discard the warm-up; clamp the start so at least the last draw is kept,
  # and keep beta a matrix even when a single row or column remains.
  keep <- min(round(B * warm_perc + 1), B):B
  list(beta = beta.chain[keep, , drop = FALSE], sigma2 = sigma2.chain[keep])
}

Gala dataset (II)

# Classical fit for comparison; column 2 (Endemics) is dropped from the data
gala_lm <- lm(Species ~ ., data = gala[, -2])
(summ <- summary(gala_lm))

Call:
lm(formula = Species ~ ., data = gala[, -2])

Residuals:
     Min       1Q   Median       3Q      Max 
-111.679  -34.898   -7.862   33.460  182.584 

Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept)  7.068221  19.154198   0.369 0.715351    
Area        -0.023938   0.022422  -1.068 0.296318    
Elevation    0.319465   0.053663   5.953 3.82e-06 ***
Nearest      0.009144   1.054136   0.009 0.993151    
Scruz       -0.240524   0.215402  -1.117 0.275208    
Adjacent    -0.074805   0.017700  -4.226 0.000297 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 60.98 on 24 degrees of freedom
Multiple R-squared:  0.7658,    Adjusted R-squared:  0.7171 
F-statistic:  15.7 on 5 and 24 DF,  p-value: 6.838e-07
# Definition of y and X:
y <- gala$Species
X <- model.matrix(Species ~ ., data = gala[,-2])

# Fitting the model:
# NOTE: the default prior of LM_GS is Sigma0 = diag(ncol(X)), i.e. prior
# variance 1 for every coefficient around beta0 = 0. This is far more
# informative than the 100 * diag(3) used in the simulated example, so large
# coefficients are pulled towards 0 (compare the intercept with lm() above).
lm_gala <- LM_GS(y, X)
str(lm_gala)
List of 2
 $ beta  : num [1:2500, 1:6] -0.685 0.788 -0.497 0.317 -0.501 ...
 $ sigma2: num [1:2500] 1677 2015 1942 1978 2534 ...
# Extracting the elements of the chain:
betas <- lm_gala[["beta"]]
colnames(betas) <- colnames(X)
sigma2 <- lm_gala[["sigma2"]]

colMeans(betas)
(Intercept)        Area   Elevation     Nearest       Scruz    Adjacent 
 0.02312137 -0.02602818  0.32983671 -0.00566936 -0.20657747 -0.07616102 
# 95% equal-tailed credible sets, one row per coefficient
t(apply(betas, 2, function(x) quantile(x, probs=c(.025, .975))))
                   2.5%        97.5%
(Intercept) -1.90212462  1.943814977
Area        -0.05552613  0.003312421
Elevation    0.26572467  0.392750020
Nearest     -1.14952611  1.231034257
Scruz       -0.48552725  0.070630919
Adjacent    -0.09987118 -0.051690537
# Posterior mean of sigma2
mean(sigma2)
[1] 2049.94
# Posterior covariance matrix of the coefficients
round(cov(betas), 5)
            (Intercept)     Area Elevation  Nearest    Scruz Adjacent
(Intercept)     0.97945  0.00058  -0.00232  0.02294 -0.00695  0.00054
Area            0.00058  0.00023  -0.00038  0.00136  0.00035  0.00007
Elevation      -0.00232 -0.00038   0.00104 -0.00467 -0.00106 -0.00025
Nearest         0.02294  0.00136  -0.00467  0.38000 -0.05012  0.00196
Scruz          -0.00695  0.00035  -0.00106 -0.05012  0.01880 -0.00008
Adjacent        0.00054  0.00007  -0.00025  0.00196 -0.00008  0.00016