# INLA within MCMC Functions ----------------------------------------------

# Functions implementing the INLA within MCMC algorithm of Gomez-Rubio and Rue (Gomez-Rubio, V. and Rue, H., 2018. Markov chain Monte Carlo with the integrated nested Laplace approximation. Statistics and Computing, 28(5), pp.1033-1051.).
# log.post.fun.inla() - calculates the unnormalised log-posterior for a state of the Markov chain. On success it returns type "success" together with the marginal likelihood, the log-posterior and the INLA marginals; if INLA returns an error, it returns type "fail".
# mh.inlawmcmc() - performs the Metropolis-Hastings (MH) step of the INLA within MCMC algorithm. It takes the number of iterations (a previous run can be extended by re-running it), the proposal standard deviations, limits on the number of INLA errors allowed before the function stops with an error, and the further options described below. It outputs the likelihood and log-posterior, the acceptance decision at each iteration, the states of the chain and the INLA marginals at each iteration.
# lhyptohyp() - back-transforms the marginal of an internal (log-scale) hyperparameter to the natural scale with exp(). Used within the BMA function to transform the internal hyperparameter marginals.
# bma.inlawmcmc() - implements the Bayesian model averaging (BMA) step of the INLA within MCMC algorithm, applying burn-in and thinning to the MH output before the BMA step begins.

# Author: Nadeen Khaleel

library(INLA)
library(INLABMA)
library(mvtnorm)
library(mcmcse)
library(sp)
library(spatstat)
library(RandomFields)
library(rgeos)
library(maptools)
library(raster)
library(dplyr)

# Location of the PARDISO licence file, handed to INLA so that the PARDISO
# solver can be used if it is in use.
par.lic.filepath <- file.path(".", "pardiso.lic")
inla.setOption(pardiso.license = par.lic.filepath)


#################################################################################################
#FUNCTIONS
#################################################################################################

# Sub-functions -----------------------------------------------------------

log.post.fun.inla <- function(data_set, param, mesh, spde, prior.mean, prior.sd,
                              cov.names = c("cov1", "cov2")) {
  # Calculate the unnormalised log-posterior of the covariate coefficients
  # (used in the MH acceptance probability) by fitting the conditional model
  # with INLA.
  #
  # INPUTS:
  # data_set - data frame with columns x, y (locations used to build the
  #            projector matrix), count (Poisson response), area (enters the
  #            model as the offset log(area)) and the covariate columns named
  #            in cov.names.
  # param - coefficients of the covariates in cov.names, in the same order
  #         (numeric vector or 1-row matrix, as produced by the MH proposal).
  # mesh - INLA mesh used to build the projector matrix A.
  # spde - SPDE model object for the spatial random field.
  # prior.mean, prior.sd - mean and sd of the independent Gaussian priors on
  #                        param.
  # cov.names - covariate columns whose linear combination with param enters
  #             the model as a fixed offset. Defaults to the two covariates
  #             cov1 and cov2 that were previously hard-coded.
  #
  # OUTPUTS:
  # On success: list(type = "success", likval, logpost, marginals), where
  #   likval is fit.inla$mlik[1] (INLA's log marginal-likelihood estimate),
  #   logpost is likval plus the log-prior density of param, and marginals
  #   holds the "marginals.fixed", "marginals.hyperpar" and
  #   "internal.marginals.hyperpar" components of the INLA fit.
  # If the INLA call errors: list(type = "fail").

  if (length(param) != length(cov.names)) {
    stop("'param' has ", length(param), " values but ", length(cov.names),
         " covariates are named in 'cov.names'.", call. = FALSE)
  }

  # Local copy of the data set, extended with the fixed offset column.
  data <- data_set
  # Offset = param[1] * cov.names[1] + param[2] * cov.names[2] + ...,
  # accumulated in cov.names order.
  beta <- as.numeric(param)
  data$OFF <- Reduce(`+`, Map(function(nm, b) b * data[[nm]], cov.names, beta))

  s.index <- inla.spde.make.index("field", n.spde = spde$n.spde)
  coords <- data[, c("x", "y")]
  coordinates(coords) <- ~ x + y
  A <- inla.spde.make.A(mesh, loc = coords)
  stk <- inla.stack(data = list(resp = data$count), A = list(A, 1),
                    effects = list(c(s.index, list(int = 1)),
                                   list(OFF = data$OFF, larea = log(data$area))),
                    tag = "est")

  # Poisson model: intercept + spatial SPDE field, with log(area) and the
  # covariate term held fixed as offsets (the covariate effects are sampled
  # by MCMC, not estimated by INLA).
  fit.inla <- try(inla(resp ~ 0 + offset(larea) + int + offset(OFF) + f(field, model = spde),
                       family = "poisson", data = inla.stack.data(stk),
                       control.predictor = list(A = inla.stack.A(stk), compute = TRUE),
                       control.fixed = list(mean = list(int = 0), prec = list(int = 1/100))))

  if (inherits(fit.inla, "try-error")) {
    return(list(type = "fail"))
  }

  lik.val <- fit.inla$mlik[1]
  log.post.fun <- lik.val + dmvnorm(param, mean = prior.mean,
                                    sigma = diag(prior.sd^2, length(prior.mean)),
                                    log = TRUE)

  # Copy an INLA marginal component element by element (by name) into a plain
  # named list; a missing component becomes an empty list.
  copy.marginals <- function(component) {
    marg <- fit.inla[[component]]
    setNames(lapply(names(marg), function(nm) marg[[nm]]), names(marg))
  }
  list.names <- c("marginals.fixed", "marginals.hyperpar", "internal.marginals.hyperpar")
  temp.mod.marg <- setNames(lapply(list.names, copy.marginals), list.names)

  list(type = "success", likval = lik.val, logpost = log.post.fun,
       marginals = temp.mod.marg)
}

####################################################################

# Main Functions ----------------------------------------------------------


# MCMC --------------------------------------------------------------------

mh.inlawmcmc <- function(data, spde, mesh, its, init = NULL, prior.mean, prior.sd,
                         prop.sd, param.names, save.name = "IwMMH.rda",
                         restart = 0, lb.iterr = 100, lb.buff = 50,
                         it.err.lim = 2, tot.err.lim = 10) {
  # Random-walk Metropolis-Hastings step of the INLA within MCMC algorithm for
  # the model with covariates: the covariate effects (betas) are sampled by MH,
  # and each state is evaluated with an INLA fit via log.post.fun.inla().
  #
  # INPUTS:
  # data - count data frame of the point pattern (passed to log.post.fun.inla()).
  # spde - SPDE set up for INLA on the input mesh, with the pre-chosen
  #        covariance priors.
  # mesh - pre-produced mesh, reused for every INLA run for consistency.
  # its - number of (additional) iterations to run. When restarting this is how
  #       many more iterations to add, e.g. 500 already run and 1000 wanted:
  #       its = 500.
  # init - initial parameter values; required when restart = 0, ignored when
  #        restarting (the chain continues from its last stored state).
  # prior.mean, prior.sd - mean and sd of the Gaussian priors on the betas.
  # prop.sd - sd(s) of the independent Gaussian random-walk proposal.
  # param.names - character vector of parameter names.
  # save.name - .rda file (inc. ".rda") the list `out` is saved to after every
  #             iteration and before stopping on an INLA error; it is read back
  #             when restarting.
  # restart - 0 to start a new chain; any other value continues the chain
  #           stored in save.name.
  # lb.iterr - INLA errors at or before iteration lb.iterr + lb.buff stop the
  #            run (should be at least as large as the burn-in). Later errors
  #            are handled by replacing the current state with an older one.
  # lb.buff - buffer added to lb.iterr; replacement states are drawn from
  #           iterations lb.iterr + lb.buff up to i - 2.
  # it.err.lim - stop once iteration i has produced it.err.lim INLA errors.
  # tot.err.lim - stop once the error table holds tot.err.lim rows (the table
  #               is carried over from earlier runs when restarting).
  #
  # OUTPUTS:
  # out - a list of length 2:
  #   run - information from the MCMC:
  #     theta (data frame) - state of the chain at each iteration
  #     acc.rej (0/1 vector) - 1 if the proposal at that iteration was accepted
  #     logpost.lik (data frame) - lik.val (INLA marginal likelihood) and
  #                                log.post (log-posterior) of each state
  #     ess (numeric) - effective sample size per parameter (updated from
  #                     iteration 1000 onwards)
  #     error (data frame) - one row per INLA error: iteration, current and
  #                          proposed parameter values, and the iteration whose
  #                          state replaced the current one (if any)
  #   inla - INLA posterior marginals of each state, used for BMA

  n.par <- length(param.names)
  rep.iteration.min <- lb.iterr + lb.buff

  # Empty output with room for n.tot iterations of n.col parameters.
  new.output <- function(n.tot, n.col) {
    res <- list(run = vector(mode = "list", length = 5),
                inla = vector(mode = "list", length = n.tot))
    names(res$run) <- c("theta", "acc.rej", "logpost.lik", "ess", "error")
    res$run$theta <- data.frame(matrix(NA, nrow = n.tot, ncol = n.col))
    colnames(res$run$theta) <- param.names
    res$run$acc.rej <- rep(NA, n.tot)
    res$run$logpost.lik <- data.frame(lik.val = rep(NA, n.tot), log.post = rep(NA, n.tot))
    res$run$ess <- rep(NA, n.col)
    # Which iteration errored, the parameter values involved, and the
    # replacement iteration used.
    res$run$error <- data.frame(matrix(vector(), ncol = 2 * n.col + 2, nrow = 0))
    colnames(res$run$error) <- c("iteration", paste0(param.names, "_curr"),
                                 paste0(param.names, "_prop"), "replacement iterations")
    names(res$inla) <- paste0("Model", seq_len(n.tot))
    res
  }

  # Stored state (likelihood, log-posterior, marginals) of iteration k.
  state.at <- function(out, k) {
    list(likval = out$run$logpost.lik$lik.val[k],
         logpost = out$run$logpost.lik$log.post[k],
         marginals = out$inla[[k]])
  }

  # Effective sample size of every parameter over iterations 1..upto.
  chain.ess <- function(theta, upto) {
    sapply(seq_len(ncol(theta)), function(j) { ess(theta[seq_len(upto), j]) })
  }

  # Append (iteration, current values, proposed values) to the error table.
  add.error <- function(err, it, cur, prop) {
    r <- nrow(err) + 1
    err[r, 1] <- it
    err[r, 2:(n.par + 1)] <- cur
    err[r, (n.par + 2):(2 * n.par + 1)] <- prop
    err
  }

  # Save `out` (so the error table is not lost), then stop with msg.
  save.and.stop <- function(out, msg) {
    save(out, file = save.name)
    stop(msg, call. = FALSE)
  }

  if (restart == 0) {
    if (is.null(init)) {
      stop("'init' must be supplied when starting a new MH run (restart = 0).", call. = FALSE)
    }
    out <- new.output(its, length(init))

    theta.c <- init
    log.post.c <- log.post.fun.inla(data, theta.c, mesh, spde, prior.mean, prior.sd)
    if (!identical(log.post.c$type, "success")) {
      stop("INLA failed for the initial values 'init'; choose different starting values.",
           call. = FALSE)
    }
    start.it <- 1
  } else {
    # Load into a separate environment so the saved objects cannot overwrite
    # this function's arguments or locals.
    saved <- new.env()
    load(save.name, envir = saved)
    out.old <- saved$out
    if (is.null(out.old)) {
      stop("No object 'out' found in ", save.name, call. = FALSE)
    }
    l.old <- sum(!is.na(out.old$run$theta[, 1]))
    if (l.old == 0) {
      stop("The saved run in ", save.name, " has no completed iterations; use restart = 0.",
           call. = FALSE)
    }

    out <- new.output(l.old + its, ncol(out.old$run$theta))
    kept <- seq_len(l.old)
    out$run$theta[kept, ] <- out.old$run$theta[kept, ]
    out$run$acc.rej[kept] <- out.old$run$acc.rej[kept]
    out$run$logpost.lik[kept, ] <- out.old$run$logpost.lik[kept, ]
    out$run$ess <- out.old$run$ess
    out$run$error <- out.old$run$error
    out$inla[kept] <- out.old$inla[kept]

    theta.c <- as.numeric(as.vector(out$run$theta[l.old, ]))
    log.post.c <- state.at(out, l.old)
    start.it <- l.old + 1
  }

  for (i in seq.int(from = start.it, length.out = its)) {
    it.err <- 0
    # Propose until INLA succeeds for iteration i. On an INLA error we follow
    # INLAMH from the INLABMA package (Bivand, Gomez-Rubio and Rue, 2015):
    # (a) too early in the chain (no usable older states): save and stop;
    # (b) otherwise record the error, replace the current state by a randomly
    #     chosen older state and re-propose, stopping after it.err.lim errors
    #     at this iteration or tot.err.lim errors in total.
    repeat {
      theta.p <- theta.c + rmvnorm(1, mean = rep(0, length(theta.c)),
                                   sigma = diag(prop.sd^2, length(theta.c)))
      log.post.p <- log.post.fun.inla(data, theta.p, mesh, spde, prior.mean, prior.sd)
      # A NA/NaN log-posterior cannot be used in the acceptance ratio, so it
      # is handled like an INLA error.
      inla.ok <- identical(log.post.p$type, "success") &&
        length(log.post.p$logpost) == 1 && !is.na(log.post.p$logpost)

      if (inla.ok) {
        # Random-walk MH: accept with probability min(1, exp(log-ratio)),
        # compared on the log scale.
        accept <- isTRUE(log(runif(1)) < log.post.p$logpost - log.post.c$logpost)
        if (accept) {
          theta.c <- theta.p
          log.post.c <- log.post.p
        }
        out$run$acc.rej[i] <- as.numeric(accept)
        out$run$theta[i, ] <- theta.c
        out$run$logpost.lik$log.post[i] <- log.post.c$logpost
        out$run$logpost.lik$lik.val[i] <- log.post.c$likval
        # list() wrapper: assigning NULL via [[ ]] would delete the element.
        out$inla[i] <- list(log.post.c$marginals)

        if (i >= 1e3) {
          out$run$ess <- chain.ess(out$run$theta, i)
        }
        save(out, file = save.name)
        print(i)
        break
      }

      out$run$error <- add.error(out$run$error, i, theta.c, theta.p)

      # Replacement needs at least one completed iteration past the minimum.
      if (i <= max(rep.iteration.min, 1)) {
        save.and.stop(out, paste0("INLA error occured in iteration ", i,
                                  " of MH run. Too early in the chain to replace current value."))
      }
      print(paste0("INLA run with parameter values ", paste(theta.p, collapse = ", "),
                   " at iteration ", i, " resulted in error"))

      if (nrow(out$run$error) >= tot.err.lim) {
        print(out$run$error)
        save.and.stop(out, paste0("There have been ", tot.err.lim,
                                  " errors, reaching the maximum limit of INLA errors for this MCMC run."))
      }

      # Replace the current state with a random older (post-minimum) state.
      # Index into an explicit candidate vector: sample(x, 1) with a single
      # value x would sample from 1:x instead.
      first.rep <- max(rep.iteration.min, 1)
      candidates <- seq.int(first.rep, max(first.rep, i - 2))
      rep.i <- candidates[sample.int(length(candidates), 1)]
      out$run$error[nrow(out$run$error), 2 * n.par + 2] <- rep.i

      theta.c <- as.numeric(as.vector(out$run$theta[rep.i, ]))
      log.post.c <- state.at(out, rep.i)

      if ((i - 1) >= 1e3) {
        out$run$ess <- chain.ess(out$run$theta, i - 1)
      }
      it.err <- it.err + 1
      if (it.err >= it.err.lim) {
        save.and.stop(out, paste0("Repeatedly (", it.err, ") received errors for iteration ", i))
      }
    }
  }

  out
}



# BMA ---------------------------------------------------------------------


# Back-transform the marginal of an internal (log-scale) hyperparameter to the
# scale of the output hyperparameters: returns the marginal of exp(x).
lhyptohyp <- function(marg) {
  marg.mat <- as.matrix(marg)
  inla.tmarginal(function(x) exp(x), marg.mat)
}

bma.inlawmcmc <- function(mh.inlawmcmc.out, burnin = 0, thin = 1, save.name = "IwMBMA.rda") {
  # Bayesian Model Averaging step of the INLA within MCMC algorithm, using the
  # INLABMA function fitmargBMA2 with equal weights for every retained state.
  #
  # INPUTS:
  # mh.inlawmcmc.out - output of the MH step (mh.inlawmcmc): the states of the
  #                    chain and the INLA posterior marginals saved at each state.
  # burnin - number of initial iterations to discard (mh.inlawmcmc applies no
  #          burn-in itself).
  # thin - keep every thin-th iteration after the burn-in (must be >= 1).
  # save.name - file name for saving the output.
  #
  # OUTPUTS (also saved to save.name as `out`):
  # run$theta - the burnt-in, thinned theta chain.
  # run$logpost.lik - the matching rows of the MH likelihood/log-posterior table.
  # run$ess - effective sample size of each thinned theta chain.
  # inla.mh - the INLA marginals of the retained states.
  # inla.bma - BMA-approximated posterior marginals (marginals.fixed,
  #            marginals.hyperpar, internal.marginals.hyperpar), plus
  #            transformed.internal.marginals.hyperpar (exp-transformed).

  theta.all <- mh.inlawmcmc.out$run$theta
  # An interrupted MH run leaves NA rows after the last completed iteration;
  # only completed iterations are used.
  n.done <- sum(!is.na(theta.all[, 1]))
  if (thin < 1) {
    stop("'thin' must be at least 1.", call. = FALSE)
  }
  if (burnin >= n.done) {
    stop("'burnin' (", burnin, ") must be smaller than the number of completed MH iterations (",
         n.done, ").", call. = FALSE)
  }

  # Burn-in and thinning; drop = FALSE keeps single-parameter chains as data frames.
  step_thin <- seq(from = burnin + 1, to = n.done, by = thin)
  theta_thin <- theta.all[step_thin, , drop = FALSE]
  ess_thin <- sapply(seq_len(ncol(theta_thin)), function(j) { ess(theta_thin[, j]) })

  out <- vector(mode = "list", length = 3)
  names(out) <- c("run", "inla.mh", "inla.bma")
  out$run <- list(theta = theta_thin,
                  logpost.lik = mh.inlawmcmc.out$run$logpost.lik[step_thin, , drop = FALSE],
                  ess = ess_thin)
  out$inla.mh <- mh.inlawmcmc.out$inla[step_thin] # thinned list of marginals

  # BMA of each marginal component over the retained states (equal weights).
  listmarg <- c("marginals.fixed", "marginals.hyperpar", "internal.marginals.hyperpar")
  marg.all <- mh.inlawmcmc.out$inla[step_thin]
  ws <- rep(1 / length(marg.all), length(marg.all))
  margeff <- parallel::mclapply(listmarg, function(X) { INLABMA:::fitmargBMA2(marg.all, ws, X) })
  names(margeff) <- listmarg
  out$inla.bma <- margeff

  # Back-transform the internal (log-scale) hyperparameter marginals.
  int.marg <- out$inla.bma$internal.marginals.hyperpar
  if (length(int.marg) > 0) {
    transformed <- lapply(int.marg, lhyptohyp)
    names(transformed) <- paste0("Transformed ", names(int.marg))
    out$inla.bma$transformed.internal.marginals.hyperpar <- transformed
  }

  save(out, file = save.name)

  out
}

# INLABMA package
# Roger S. Bivand, Virgilio Gomez-Rubio, Havard Rue (2015). Spatial Data Analysis with R-INLA with Some Extensions. Journal of Statistical Software, 63(20), 1-31. URL http://www.jstatsoft.org/v63/i20/.
#' @Article{,
#'   title = {Spatial Data Analysis with {R}-{INLA} with Some
#'     Extensions},
#'   author = {Roger S. Bivand and Virgilio G\'omez-Rubio and H{\aa}vard
#'       Rue},
#'     journal = {Journal of Statistical Software},
#'     year = {2015},
#'     volume = {63},
#'     number = {20},
#'     pages = {1--31},
#'     url = {http://www.jstatsoft.org/v63/i20/},
#'   }

