# Grid-Mesh: Gaussian Example: Traditional Simulation Study ---------------

# This script implements the traditional simulation study for the Gaussian example, in order to assess outcomes such as parameter recovery.
# The covariate and mesh simulation sections were run once before the code was moved to Balena, and their outputs were saved so that they do not need to be re-simulated in each iteration.
# Helper functions generate the data and the A matrix for the INLA runs, keeping the implementation of the simulation study as compact as possible.

# Author: Nadeen Khaleel


# Example of R CMD BATCH command from job slurm script for implementing this simulation study:
# R CMD BATCH --vanilla GridMeshOptimTradSG_final.R gm_tradsg.out

# Libraries ---------------------------------------------------------------

# Set up parallel computing: start `Nprocs` local workers and register them
# as the foreach back-end, so that `%dopar%` below runs one chunk of the
# simulation study on each worker.
Nprocs <- 16

library(doParallel)
library(foreach)

parallelCluster <- parallel::makeCluster(spec = Nprocs)
print(parallelCluster)
registerDoParallel(cl = parallelCluster)


ptm <- proc.time()
# Each worker k runs M.it[k] iterations of the simulation study and saves its
# own output file, GridMeshSGTradSS<k>.rda.
foreach(k = 1:Nprocs) %dopar% {
  
  library(INLA)
  library(mvtnorm)
  library(sp)
  library(spatstat)
  library(raster)
  library(maptools)
  
  
  # Per-worker INLA working directory, so that parallel fits do not collide.
  inla.setOption('keep'=TRUE, 'working.directory' = paste0('./inlalogstrad',k,'/temp')) # initially used to access logfiles, however for later simulation studies, this is not used as it is not needed to check logfiles for warnings/FFT messages.
  
  par.lic.filepath <- "./pardiso.lic" # file path for pardiso licence if in use
  inla.setOption(pardiso.license = par.lic.filepath)
  
  
  # Functions ---------------------------------------------------------------
  
  # Simulate one Gaussian data set at the n.max x n.max cell centres of W.
  #
  # Arguments:
  #   W        : window; only W$xrange and W$yrange are used.
  #   n.max    : number of grid cells per axis (n.max^2 data points).
  #   theta    : true parameters, in order (beta0, beta1, beta2, sigma, rho,
  #              tau). sigma is the marginal SD of the Matern (nu = 1) field,
  #              rho its range (kappa = sqrt(8)/rho), and tau the precision of
  #              the Gaussian observation noise.
  #   disc     : no longer used for data generation; kept so that existing
  #              calls still work.
  #   cov1.ras, cov2.ras : rasters whose values at the cell centres are the
  #              covariates.
  #
  # Returns list(data = data.frame(x, y, resp, cov1, cov2), field = latent
  # field values). Both are ordered by x increasing and, within each x, by y
  # decreasing.
  data.gen <- function(W,n.max,theta,disc,cov1.ras,cov2.ras){
    minx <- W$xrange[1]; maxx <- W$xrange[2]
    miny <- W$yrange[1]; maxy <- W$yrange[2]
    # Half cell widths, so that the sequences below are the cell centres.
    midx <- 0.5*(maxx-minx)/n.max; midy <- 0.5*(maxy-miny)/n.max
    x <- seq(from=minx+midx,to=maxx-midx,by=2*midx)
    y <- seq(from=miny+midy,to=maxy-midy,by=2*midy)
    Grid <- expand.grid(x, y) # simulation order: x varies fastest
    
    # Output order: x increasing, y decreasing within each x. new.order[r] is
    # the row of Grid holding output cell r. The coordinates come from the same
    # x/y vectors, so exact equality is safe. vapply() fails loudly if a cell
    # is not matched exactly once; sapply() would instead silently return a
    # list.
    coord.inla <- expand.grid(sort(y,decreasing=TRUE),x)
    out.x <- coord.inla$Var2; out.y <- coord.inla$Var1
    grid.x <- Grid$Var1; grid.y <- Grid$Var2
    new.order <- vapply(seq_along(out.x),function(r){which(out.x[r]==grid.x & out.y[r]==grid.y)},integer(1))
    
    distance <- as.matrix(dist(Grid,diag=TRUE,upper=TRUE))
    # Simulate from this process with the parameters as defined above.
    beta.0 <- theta[[1]]; beta.1 <- theta[[2]]; beta.2 <- theta[[3]]
    sigma.t <- theta[[4]]; rho <- theta[[5]]; tau <- theta[[6]]
    mod.Corr <- inla.matern.cov(nu=1,kappa=sqrt(8)/rho,x=distance,d=2,corr=TRUE)
    mod.Cov <- sigma.t^2*mod.Corr
    
    xy <- cbind(Grid$Var1,Grid$Var2)
    cov1 <- extract(cov1.ras,xy); cov2 <- extract(cov2.ras,xy)
    mu <- rep(beta.0,nrow(Grid)) + beta.1*cov1 + beta.2*cov2
    # RNG order (field first, then noise) is kept so that seeds reproduce the
    # same data sets.
    gf <- mvtnorm::rmvnorm(n=1,sigma=mod.Cov)
    y.i <- mu + gf + rnorm(nrow(Grid),mean=0,sd=sqrt(1/tau))
    
    df.main <- data.frame(x=Grid[,1],y=Grid[,2],resp=as.vector(y.i),cov1=cov1,cov2=cov2)
    
    return(list("data"=df.main[new.order,],"field"=gf[new.order]))
  }
  
  # Generate A matrix and stack data for one mesh.
  #
  # sigma.star, rho.star : length-2 vectors passed to inla.spde2.pcmatern as
  #   prior.sigma and prior.range (see the INLA documentation for the
  #   probability interpretation).
  # cov1.ras, cov2.ras   : unused; the covariates are taken from data$cov1 and
  #   data$cov2. Kept so that existing calls still work.
  # Returns list(spde, A, stack.est); the stack is tagged 'est'.
  A_stack.gen <- function(data,mesh,sigma.star,rho.star,cov1.ras,cov2.ras){
    the_spde <- inla.spde2.pcmatern(mesh,alpha=2,prior.range = c(rho.star[1],rho.star[2]),prior.sigma = c(sigma.star[1],sigma.star[2]))
    s.index <- inla.spde.make.index("field",n.spde=the_spde$n.spde)
    coords <- data[,c("x","y")]
    coordinates(coords) <- ~ x + y
    A <- inla.spde.make.A(mesh, loc=coords)
    stk <- inla.stack(data=list(resp=data$resp),A=list(A,1),effects=list(c(s.index,list(int=1)),list(cov1=data$cov1,cov2=data$cov2)),tag='est')
    
    return(list("spde"=the_spde,"A"=A,"stack.est"=stk))
  }
  
  # Copy the posterior mean, SD, 2.5% and 97.5% quantiles in row `row` of an
  # INLA summary table (summary.fixed / summary.hyperpar) into the columns
  # `name`, `name.sd`, `name.cil` and `name.ciu` of row i of est.df.
  # Returns the updated est.df.
  fill.est <- function(est.df, i, name, summ, row){
    est.df[[name]][i] <- summ$mean[row]
    est.df[[paste0(name,".sd")]][i] <- summ$sd[row]
    est.df[[paste0(name,".cil")]][i] <- summ$`0.025quant`[row]
    est.df[[paste0(name,".ciu")]][i] <- summ$`0.975quant`[row]
    est.df
  }
  
  
  # Covariate and Mesh Set-up --------------------------------------------------------
  
  # Set working directory to save the outputs before moving them over to Balena
  # library("rstudioapi")
  # # Either setwd() to the source file location, or run the following:
  # setwd(dirname(getActiveDocumentContext()$path))
  
  # # Will comment this out for future running of the script to prevent re-production of the code - will set the seed for covariate production and then for the rest of the script set the seed above when running.
  # W <- owin(c(0,5),c(0,5))
  # 
  # ## THESE WERE SIMULATED FROM THE GridMeshOptimSG_final.R BUT WILL KEEP THE CODE HERE
  # set.seed(625)
  # reg.poly <- W$type=="rectangle" # True if rectangular window/false if different polygon.
  # 
  # # Points at which to simulate the data
  # nx <- 150; ny <- 150;
  # W.gridc <- gridcentres(W, nx, ny)
  # keep.c <- inside.owin(W.gridc$x,W.gridc$y,W)
  # 
  # W.gridc$x <- W.gridc$x[keep.c]; W.gridc$y <- W.gridc$y[keep.c];
  # 
  # 
  # # Simulate the covariate data (and save covariates too!)
  # # Create (fine) regular grid over the window in order to be able to simulate the covariate. Want to use intersect function to make sure that the overlayed grid points retained lie within the boundary, this is not needed for the rectangular grid, the function gridcentres should work fine there.
  # # However, we can use the function (in spatstat), inside.owin to create a logical vector for which points lie inside/outside the grid.
  # 
  # 
  # max.y <- max(W.gridc$y)
  # cov1.df <- data.frame(x=W.gridc$x,y=W.gridc$y)
  # cov1.df$val <- W.gridc$y/max.y + rnorm(length(W.gridc$x),mean = 0,sd=0.5)
  # 
  # max.norm <- sqrt(max(W.gridc$x)^2+max(W.gridc$y)^2)
  # cov2.df <- data.frame(x=W.gridc$x,y=W.gridc$y)
  # cov2.df$val <- sqrt(W.gridc$x^2+W.gridc$y^2)/max.norm + rnorm(length(W.gridc$x),mean = 0,sd = 0.5)
  # 
  # cov1.ras <- rasterFromXYZ(cov1.df)
  # cov2.ras <- rasterFromXYZ(cov2.df)
  # 
  # # no non-missing arguments to min; returning Infno non-missing arguments to max; returning -Infno non-missing arguments to min; returning Infno non-missing arguments to max; returning -Inf
  # 
  # # cov1.ras <- disaggregate(cov1.ras,fact=) if this is ever needed
  # 
  # c1 <- raster::extract(cov1.ras,cbind(W.gridc$x,W.gridc$y),method="bilinear")
  # c2 <- raster::extract(cov2.ras,cbind(W.gridc$x,W.gridc$y),method="bilinear")
  # 
  cov.name <- "GridMeshSGSSCov.rda"
  # save(c1,c2,cov1.ras,cov2.ras,W.gridc,file=cov.name)
  # 
  
  # # Meshes
  # # Run this once then load up the meshes at the beginning.
  # 
  # ## THESE WERE SIMULATED FROM THE GridMeshOptimSG_final.R BUT WILL KEEP THE CODE HERE
  # set.seed(1250)
  # n.max <- 50 # (0.1x0.1)
  # disc <- matrix(c(10,10,20,20,25,25),ncol=2,byrow=T)
  # disc.full <- rbind(disc,c(n.max,n.max))
  # 
  # mesh.gen <- function(W,disc,save.name){
  #   mesh.size <- matrix(c(round((W$xrange[2]-W$xrange[1])/disc[,1],3),round((W$yrange[2]-W$yrange[1])/disc[,2],3)),ncol=2)
  #   mesh.names <- paste0("mesh",mesh.size[,1],mesh.size[,2])
  #   mesh.list <- vector(mode="list",length=length(mesh.names))
  #   names(mesh.list) <- mesh.names
  #   for (i in 1:dim(disc)[1]){
  #     M <- disc[i,1]; N <- disc[i,2]
  #     cellsize <- c((W$xrange[2]-W$xrange[1])/M,(W$yrange[2]-W$yrange[1])/N)
  #     # q <- quadrats(W,M,N)
  #     g <- gridcenters(W,M,N)
  #     x.ord <- sort(g$x); y.ord <- as.vector(matrix(rev(g$y),ncol=M,byrow=TRUE))
  #     # cell.area <- diff(q$xgrid)*diff(q$ygrid)
  #     df <- data.frame(x=x.ord,y=y.ord)
  # 
  #     coords <- df[,c("x","y")]
  #     coordinates(coords) <- ~ x + y
  #     boundary <- as(W,"SpatialPolygons") # For the meshes
  #     mesh <- inla.mesh.2d(loc=coords, boundary=boundary, max.edge=c(max(cellsize), max(cellsize)+0.5), min.angle=c(30, 21),
  #                          max.n=c(48000, 16000), ## Safeguard against large meshes.
  #                          max.n.strict=c(128000, 128000), ## Don't build a huge mesh!
  #                          cutoff=0.01, ## Filter away adjacent points.
  #                          offset=c(0.1, 1)) ## Offset for extra boundaries, if needed.
  #     # the_spde.prior.sub <- inla.spde2.pcmatern(mesh.sub,alpha=2,prior.range = c(rho.star[1],rho.star[2]),prior.sigma = c(sigma.star[1],sigma.star[2]))
  #     # mesh.list[[(i+1)]] <- mesh.sub
  #     mesh.list[[i]] <- mesh
  #   }
  #   save("mesh"=mesh.list,file=save.name)
  # }
  # 
  meshes.file <- "MeshesRegPolSG.rda"
  # mesh.gen(W,disc.full,meshes.file)
  
  
  # Simulations -------------------------------------------------------------
  
  
  # Iterations of Simulation: N iterations split over the Nprocs workers.
  # Every worker gets floor(N/Nprocs) iterations and the last one also takes
  # the remainder. When N is divisible by Nprocs this gives N/Nprocs each.
  N <- 1000
  M.it <- rep(floor(N/Nprocs),length.out=Nprocs)
  M.it[Nprocs] <- N - (Nprocs-1)*M.it[1]
  
  # Number of "Fail to factorize" log messages above which a message is
  # emitted. The counts themselves are always stored in the output.
  fft.threshold <- 5
  
  load(cov.name) # load the covariates (cov1.ras, cov2.ras, ...)
  load(meshes.file) # load meshes (mesh.list)
  sim <- 0 # set sim to 1 if continuing from previous run
  
  
  # Prior for the Gaussian latent field covariance parameters
  alpha.rho <- 0.65; alpha.sigma <- 0.1; rho.0 <- 4; sigma.0 <- 2
  rho.star <- c(rho.0,alpha.rho) ; sigma.star <- c(sigma.0,alpha.sigma)
  
  W <- owin(c(0,5),c(0,5))
  
  # Saving the output - GridMeshSimpleGaussian
  save.file <- paste0("GridMeshSGTradSS",k,".rda")
  print(save.file)
  
  n.max <- 50 # (0.1x0.1)
  disc <- matrix(c(10,10,20,20,25,25),ncol=2,byrow=TRUE)
  # Mesh edge lengths, in the same order as mesh.list (coarsest to finest).
  mesh.edge <- (W$xrange[2]-W$xrange[1])/c(disc[,1],n.max)
  N.m <- length(mesh.edge)
  
  mesh.ind <- paste0("Mesh",round(mesh.edge,3))
  
  if (sim==0){
    p.length <- 1
    
    # Output template for one mesh. est.df has, for each parameter, the
    # columns <param>, <param>.sd, <param>.cil, <param>.ciu (all NA initially).
    est.names <- as.vector(t(outer(c("beta0","beta1","beta2","sigma","rho","tau"),c("",".sd",".cil",".ciu"),paste0)))
    list.param <- list(
      est.df=as.data.frame(matrix(NA,nrow=M.it[k],ncol=length(est.names),dimnames=list(NULL,est.names))),
      run.df=list(time=rep(NA,M.it[k]),cpo=vector(mode="list",length=M.it[k]),waic=rep(NA,M.it[k]),dic=rep(NA,M.it[k])),
      mess.ls=list(error=rep(NA,M.it[k]),warning=rep(NA,M.it[k]),FFT=rep(NA,M.it[k]),message=vector(mode="list",length=M.it[k]))
    )
    
    # Final Data list: one copy of the template per mesh.
    run.out <- setNames(rep(list(list.param),N.m),mesh.ind)
    
  } else {
    load(save.file)
    # Resume after the last iteration that was attempted on the final mesh,
    # where a time (successful fit) or an error was recorded. Counting only
    # non-NA estimates would restart too early whenever an earlier fit failed.
    done <- !is.na(run.out[[N.m]]$run.df$time) | !is.na(run.out[[N.m]]$mess.ls$error)
    p.length <- if (any(done)) max(which(done)) + 1 else 1
  }
  
  # True theta values
  theta.tilde <- list(beta0.tilde=1,beta1.tilde=-2,beta2.tilde=2,sigma.tilde=1,rho.tilde=1.5,tau.tilde=25)
  
  # Guard against p.length > M.it[k] (nothing left to do): `a:b` would count
  # backwards and re-run / overrun the last iteration.
  iters <- if (p.length <= M.it[k]) p.length:M.it[k] else integer(0)
  
  for (i in iters){
    # Distinct seed per (worker, iteration); this formula is kept unchanged so
    # that previously generated data sets are reproduced.
    seed <- (k-1)*sum(M.it[0:(k-1)]) + i
    set.seed(5*seed)
    data.sim <- data.gen(W,n.max,theta=theta.tilde,disc=disc,cov1.ras,cov2.ras)
    
    data <- data.sim$data
    for (l in seq_len(N.m)){
      mesh <- mesh.list[[l]]
      print(mesh.edge[l])
      
      stk.info <- A_stack.gen(data,mesh,sigma.star,rho.star,cov1.ras,cov2.ras)
      start.time <- proc.time()
      fit.inla <- try(inla(resp ~ 0 + int + cov1 + cov2 + f(field,model=stk.info$spde), data=inla.stack.data(stk.info$stack.est),control.predictor=list(A=inla.stack.A(stk.info$stack.est),compute=TRUE),control.compute = list(config=TRUE,cpo=TRUE,waic=TRUE,dic=TRUE),control.family=list(hyper=list(prec=list(prior="loggamma",param=c(2,0.1))))))
      end.time <- proc.time()
      
      if (inherits(fit.inla,"try-error")){
        run.out[[l]]$mess.ls$error[i] <- "ERROR"
      } else {
        log.lines <- fit.inla$logfile
        
        # Always keep the number of "Fail to factorize" messages; only flag
        # it when it exceeds the threshold.
        n.fft <- length(grep('Fail to factorize',log.lines))
        run.out[[l]]$mess.ls$FFT[i] <- n.fft
        if (n.fft > fft.threshold){
          message("Worker ",k,", iteration ",i,", ",mesh.ind[l],": ",n.fft," 'Fail to factorize' messages")
        }
        
        warn.lines <- grep('WARNING',log.lines,value=TRUE)
        if (length(warn.lines) > 0){
          run.out[[l]]$mess.ls$warning[i] <- "WARNING"
          run.out[[l]]$mess.ls$message[[i]] <- warn.lines
        }
        
        # Put results of approximations into the output data set.
        run.out[[l]]$run.df$time[i] <- unname(end.time[3] - start.time[3])
        # `[i] <- list(...)` rather than `[[i]] <-`: assigning NULL via `[[`
        # would delete element i and shift all later iterations.
        run.out[[l]]$run.df$cpo[i] <- list(fit.inla$cpo)
        run.out[[l]]$run.df$waic[i] <- fit.inla$waic$waic
        run.out[[l]]$run.df$dic[i] <- fit.inla$dic$dic
        
        # summary.fixed rows follow the formula: int, cov1, cov2.
        # summary.hyperpar rows are assumed to be (Gaussian observation
        # precision, field range, field SD); verify if the model changes.
        est <- run.out[[l]]$est.df
        est <- fill.est(est,i,"beta0",fit.inla$summary.fixed,1)
        est <- fill.est(est,i,"beta1",fit.inla$summary.fixed,2)
        est <- fill.est(est,i,"beta2",fit.inla$summary.fixed,3)
        est <- fill.est(est,i,"sigma",fit.inla$summary.hyperpar,3)
        est <- fill.est(est,i,"rho",fit.inla$summary.hyperpar,2)
        est <- fill.est(est,i,"tau",fit.inla$summary.hyperpar,1)
        run.out[[l]]$est.df <- est
      }
      # Checkpoint after every fit so that a crashed run can be resumed (sim = 1).
      save(run.out,file=save.file)
      unlink(inla.getOption('working.directory'), recursive=TRUE)
    }
  }
  save(run.out,file=save.file)
  
}
# Stop the clock
print(proc.time() - ptm)

parallel::stopCluster(parallelCluster) # release the worker processes
#################################################################################################

# Clear the global workspace at the end of the batch run.
rm(list = ls()) # Must finish with this.

