# Stat 753 - Numerically solving SDEs
# Stochastic 4th-order Runge-Kutta algorithm
# May 5, 2020

############################################################
# Stochastic Runge-Kutta-4 for Stratonovich calculus,
# as described in Hansen and Penland 2006.
############################################################

############################################################
# Example: Susceptible-Infected-Recovered (SIR) Model 
#
# dX = mu dt + B dW
# where X(t) = [S(t) I(t)]^T is the 2D state vector,
# W(t) = [W_1(t) W_2(t)]^T is a 2D standard Brownian motion,
# mu is the drift (dSI below) and B is the 2x2 diffusion matrix.
############################################################

library(Matrix)

# Drift coefficient mu(X) of the SIR SDE, i.e. its deterministic part.
#
# X  : state; must support X[["S"]] and X[["I"]] (the script passes a
#      one-row data frame).
# ps : parameters; reads ps[["beta"]] and ps[["gamma"]].
# Returns a 2x1 matrix [ -beta*S*I , beta*S*I - gamma*I ]^T: the rates of
# change of S and I. sdeRK4() multiplies it by dt.
dSI <- function(X, ps) {
  new_infections <- ps[["beta"]] * X[["S"]] * X[["I"]]
  recoveries <- ps[["gamma"]] * X[["I"]]
  # Infections leave S and enter I; recoveries leave I.
  matrix(c(-new_infections, new_infections - recoveries), nrow = 2, ncol = 1)
}

# Advance the SIR state by one step of a stochastic fourth-order
# Runge-Kutta scheme (after Hansen and Penland 2006).
#
# X  : current state; must support X[["S"]], X[["I"]] and elementwise
#      arithmetic with a 2x1 matrix (the script passes a one-row data frame,
#      and then a one-row data frame is returned).
# dt : step size; each Brownian increment is drawn as N(0, dt).
# ps : parameters; reads ps[["beta"]] and ps[["gamma"]].
# Returns X + (k1 + 2*k2 + 2*k3 + k4) / 6.
#
# Consumes two rnorm() draws per call. The square roots below give NaN if
# beta*S*I or gamma*I is negative, and B is NaN when both are zero.
# NOTE(review): B and dW are computed once from X and the same B %*% dW is
# added in every stage; confirm this matches the intended scheme (a full
# stochastic RK4 may re-evaluate the diffusion at each stage).
sdeRK4 <- function(X, dt, ps) {
  infection <- ps[["beta"]] * X[["S"]] * X[["I"]] # beta*S*I
  recovery <- ps[["gamma"]] * X[["I"]]            # gamma*I

  # Brownian increments for the two noise sources, each with variance dt.
  dW <- matrix(rnorm(2, mean = 0, sd = sqrt(dt)), nrow = 2, ncol = 1)

  # B is the symmetric square root of the noise covariance
  #   Sigma = [  beta*S*I            -beta*S*I  ]
  #           [ -beta*S*I   beta*S*I + gamma*I  ]
  # via the 2x2 identity sqrt(Sigma) = (Sigma + s*Id) / t, where
  # s = sqrt(det(Sigma)) and t = sqrt(trace(Sigma) + 2*s).
  root_det <- sqrt(infection) * sqrt(recovery)
  denom <- sqrt((sqrt(recovery) + sqrt(infection))^2 + infection)
  B <- matrix(c(root_det + infection, -infection,
                -infection, recovery + root_det + infection) / denom,
              nrow = 2, ncol = 2)
  noise <- B %*% dW # the same noise term enters all four stages

  # One stage increment: drift evaluated at Y, scaled by dt, plus the noise.
  stage <- function(Y) dSI(Y, ps) * dt + noise
  k1 <- stage(X)
  k2 <- stage(X + 0.5 * k1)
  k3 <- stage(X + 0.5 * k2)
  k4 <- stage(X + k3)
  X + (1 / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
}


# Parameters and initial conditions for the numerical solution.
Pars <- c(beta = 0.0003, gamma = 0.1) # passed to dSI()/sdeRK4() as ps
N <- 1000                             # total population size, S + I + R
X <- data.frame(S = 998, I = 2)       # initial state; R(0) = N - S - I = 0
Time <- 0                             # starting time
tstep <- 0.1                          # step size dt

# Iterate sdeRK4() steps to simulate a full trajectory, stopping once
# t >= 150 or as soon as I(t) is no longer positive.
# States are written into preallocated vectors and assembled into X and Time
# at the end, because growing a data frame one row at a time copies it on
# every step. `i` ends one past the last stored row.
set.seed(246) # For repeatable random number generation
n_alloc <- max(1, ceiling((150 - Time[1]) / tstep) + 2) # capacity hint only
S_path <- numeric(n_alloc)
I_path <- numeric(n_alloc)
T_path <- numeric(n_alloc)
S_path[1] <- X[["S"]][1]
I_path[1] <- X[["I"]][1]
T_path[1] <- Time[1]
state <- X[1, ] # one-row data frame, the form sdeRK4() is called with
i <- 2 # index of the next row to fill
while (T_path[i - 1] < 150 && I_path[i - 1] > 0) { # while t < 150 and I > 0
  T_path[i] <- T_path[i - 1] + tstep
  state <- sdeRK4(state, tstep, Pars)
  S_path[i] <- state[["S"]]
  I_path[i] <- state[["I"]]
  i <- i + 1 # R extends the vectors automatically if n_alloc is exceeded
}
keep <- seq_len(i - 1)
X <- data.frame(S = S_path[keep], I = I_path[keep])
Time <- T_path[keep]

# Assemble the output table: time, the simulated S and I columns, and the
# recovered class R = N - (S + I); then show the last four rows.
# data.frame(..., check.names = FALSE) is exactly what cbind() calls when
# one of its arguments is a data frame.
Xout <- data.frame(Time, X, R = N - rowSums(X), check.names = FALSE)
tail(Xout, 4)

### Output to the screen - size of each population
### S = susceptible, I = infected, R = recovered
# Time        S           I        R
# 1417 141.6 81.05222  0.02383346 918.9239
# 1418 141.7 81.05465  0.04914261 918.8962
# 1419 141.8 81.06478  0.02691306 918.9083
# 1420 141.9 81.07035 -0.01539963 918.9450


# If the run stopped because I(t) fell to or below zero, clamp that final I
# to 0 and hold the state constant out to t = 150, so the curves span the
# same time range as a run that reached t = 150. R is left as computed.
# The last row is found from Xout itself rather than a leftover loop counter.
last_row <- nrow(Xout)
if (Xout[last_row, "I"] <= 0) {
  Xout[last_row, "I"] <- 0
  # Only extend if the run ended before t = 150; otherwise the new row
  # would be stamped with an earlier time than the row before it.
  if (Xout[last_row, "Time"] < 150) {
    Xout[last_row + 1, ] <- Xout[last_row, ] # duplicate the final state
    Xout[last_row + 1, "Time"] <- 150        # ... at the final time
  }
}


# Plot the numerical solution curves: S, I and R against time in three
# side-by-side panels sharing the y-range [0, N], each with a line at 0.
# The panel loop runs inside local() so no helper variables leak out.
par(mfrow = c(1, 3)) # Plot 3 panels in 1 row
invisible(local({
  ylabs <- c("Susceptible (S)", "Infected (I)", "Recovered (R)")
  cols <- c("mediumseagreen", "tomato", "darkorchid3")
  for (j in seq_along(ylabs)) {
    plot(Xout[, 1], Xout[, j + 1], ylab = ylabs[j], ylim = c(0, N),
         xlab = "Time", type = "l", lwd = 2, col = cols[j])
    abline(h = 0)
  }
}))

