# Markov Chain Monte Carlo (MCMC) Examples 
# STAT 753
# March 5, 2020

##############################################
# Gibbs Sampler for MCMC - 2 examples below
##############################################


#########################################
# Simple Discrete Joint Distribution
#########################################

# X and Y are dependent binary random variables with the joint distribution
# P(X = x, Y = y) shown below (rows index X, columns index Y):
#   X/Y   0     1
#   0     0.6   0.1  
#   1     0.15  0.15

# Draw a single Bernoulli(p) variate.
#
# Arguments:
#   p - success probability
# Returns:
#   Numeric 1 with probability p, numeric 0 with probability 1 - p.
rbernoulli <- function(p) {
  # `1 * runif(1) < p` parses as `(1 * runif(1)) < p` and yields a logical,
  # so convert the comparison result explicitly to get a numeric 0/1.
  as.numeric(runif(1) < p)
}

# Draw X from its conditional distribution given Y = y (joint table above).
#   P(X = 1 | Y = 0) = 0.15 / 0.75 = 0.2
#   P(X = 1 | Y = 1) = 0.15 / 0.25 = 0.6
sample_XgivenY <- function(y) {
  p_x1 <- if (y == 0) 0.2 else 0.6
  rbernoulli(p_x1)
}

# Draw Y from its conditional distribution given X = x (joint table above).
#   P(Y = 1 | X = 0) = 0.1  / 0.7 = 1/7
#   P(Y = 1 | X = 1) = 0.15 / 0.3 = 0.5
sample_YgivenX <- function(x) {
  p_y1 <- if (x == 0) 1 / 7 else 0.5
  rbernoulli(p_y1)
}

# Simple Gibbs sampler: alternate draws from X | Y and Y | X
set.seed(100)
niter <- 1000
X <- numeric(niter)
Y <- numeric(niter)

# The chain starts from the state (1, 1)
X[1] <- 1
Y[1] <- 1

for (i in 2:niter) {
  X[i] <- sample_XgivenY(Y[i - 1]) # new X conditions on the previous Y
  Y[i] <- sample_YgivenX(X[i])     # new Y conditions on the X just drawn
}
res <- data.frame(X = X, Y = Y)


# Show the first 20 iterations of the chain
head(res, n = 20)

# Proportion of iterations spent in each (X, Y) state
table(res) / niter



#########################################
# Bivariate Normal Distribution
#########################################

# Simulate n draws from a bivariate normal distribution whose marginals both
# have mean 0 and variance 1 and whose correlation is rho.
#
# X is drawn from N(0, 1); Y is then drawn from its conditional distribution
# given X, which is N(rho * x, 1 - rho^2).
#
# Returns an n x 2 matrix with columns named "x" and "y".
rbvnorm <- function(n, rho) {
  cond_sd <- sqrt(1 - rho^2)
  x_draws <- rnorm(n, mean = 0, sd = 1)
  y_draws <- rnorm(n, mean = rho * x_draws, sd = cond_sd)
  cbind(x = x_draws, y = y_draws)
}

# Direct simulation with rho = 0.98 (highly correlated)
n <- 1000
bvnorm <- rbvnorm(n, 0.98)

# 3 x 2 panel: points, connected path, per-variable traces, marginal histograms
par(mfrow = c(3, 2))
plot(bvnorm, col = 1:n, xlab = "X values", ylab = "Y values",
     main = "Scatter Plot")
plot(bvnorm, type = "l", xlab = "X values", ylab = "Y values",
     main = "Scatter Plot")
plot(ts(bvnorm[, 1]), col = "blue", ylab = "X values", main = "Time Series")
plot(ts(bvnorm[, 2]), col = "blue", ylab = "Y values", main = "Time Series")
hist(bvnorm[, 1], breaks = 40, col = "wheat", freq = FALSE,
     xlab = "X values", main = "Histogram of Marginal Distribution")
hist(bvnorm[, 2], breaks = 40, col = "wheat", freq = FALSE,
     xlab = "Y values", main = "Histogram of Marginal Distribution")
par(mfrow = c(1, 1)) # back to a single-panel layout


# Direct simulation with rho = 0.3 (weakly correlated)
n <- 1000
bvnorm <- rbvnorm(n, 0.3)

# 3 x 2 panel: points, connected path, per-variable traces, marginal histograms
par(mfrow = c(3, 2))
plot(bvnorm, col = 1:n, xlab = "X values", ylab = "Y values",
     main = "Scatter Plot")
plot(bvnorm, type = "l", xlab = "X values", ylab = "Y values",
     main = "Scatter Plot")
plot(ts(bvnorm[, 1]), col = "blue", ylab = "X values", main = "Time Series")
plot(ts(bvnorm[, 2]), col = "blue", ylab = "Y values", main = "Time Series")
hist(bvnorm[, 1], breaks = 40, col = "wheat", freq = FALSE,
     xlab = "X values", main = "Histogram of Marginal Distribution")
hist(bvnorm[, 2], breaks = 40, col = "wheat", freq = FALSE,
     xlab = "Y values", main = "Histogram of Marginal Distribution")
par(mfrow = c(1, 1)) # back to a single-panel layout


#################
# Gibbs Sampler
#################

# Gibbs sampler for a bivariate normal with standard normal marginals and
# correlation rho.
#
# Each iteration draws x from N(rho * y, 1 - rho^2) and then y from
# N(rho * x, 1 - rho^2), using the x that was just drawn.
#
# Arguments:
#   n   - number of rows in the returned chain, including the starting row
#   rho - correlation between the two coordinates
# Returns:
#   An n x 2 numeric matrix; column 1 holds the x draws and column 2 the
#   y draws. Row 1 is the fixed starting state (0, 0).
gibbs <- function(n, rho) {
  # Filling with 0 also places the starting state (0, 0) in row 1.
  mat <- matrix(0, nrow = n, ncol = 2)

  # `2:n` counts backwards when n < 2 and would index a row that does not
  # exist, so return the (possibly empty) starting matrix directly.
  if (n < 2) {
    return(mat)
  }

  # The conditional standard deviation is the same for every draw.
  cond_sd <- sqrt(1 - rho^2)
  x <- 0
  y <- 0
  for (i in 2:n) {
    x <- rnorm(1, rho * y, cond_sd)
    y <- rnorm(1, rho * x, cond_sd)
    mat[i, ] <- c(x, y)
  }
  mat
}

# Run the Gibbs sampler with rho = 0.98 (highly correlated)
n <- 10000
rho <- 0.98
bvn_gibbs <- gibbs(n, rho)

# Alternative settings (weakly correlated); uncomment to use instead
# n = 1000
# rho = 0.3
# bvn_gibbs <- gibbs(n, rho)

# 3 x 2 panel: points, connected path, per-variable traces, marginal histograms
par(mfrow = c(3, 2))
plot(bvn_gibbs, col = 1:n, xlab = "X values", ylab = "Y values",
     main = "Scatter Plot")
plot(bvn_gibbs, type = "l", xlab = "X values", ylab = "Y values",
     main = "Scatter Plot")
plot(ts(bvn_gibbs[, 1]), col = "aquamarine4", ylab = "X values",
     main = "Time Series")
plot(ts(bvn_gibbs[, 2]), col = "aquamarine4", ylab = "Y values",
     main = "Time Series")
hist(bvn_gibbs[, 1], breaks = 40, col = "wheat", freq = FALSE,
     xlab = "X values", main = "Histogram of Marginal Distribution")
hist(bvn_gibbs[, 2], breaks = 40, col = "wheat", freq = FALSE,
     xlab = "Y values", main = "Histogram of Marginal Distribution")
par(mfrow = c(1, 1)) # back to a single-panel layout

