
# Traditional Simulation Study Gaussian Output ----------------------------

# This R script takes the output of the traditional simulation study and produces plots and a summary table of the timings and parameter recovery for the Gaussian example, in which the data are on a single grid and are fitted using several meshes.


# Author: Nadeen Khaleel

# Libraries ---------------------------------------------------------------

library(purrr)
library(ggplot2)
library(rlist)
library(magrittr)
library(grid)
library(gridExtra)
library(stringr)
library(xtable)
library(tidyverse)


# Functions ---------------------------------------------------------------

# Summarise the run times for each mesh: mean, standard deviation and a
# normal-approximation 95% interval (mean -/+ qnorm(0.975) * sd).
#   grid.dat - named list with one element per mesh; each element holds
#              run.df$time, the run times of that mesh's simulations. The
#              mesh size is read from the first number in each name.
# Returns a data frame with one row per mesh and columns Mesh (factor
# "Mesh <size>", levels in decreasing alphabetical order), mean, sd, ci.l
# and ci.u.
time.meansd.mesh <- function(grid.dat){
  run.times <- lapply(grid.dat, function(g){g$run.df$time})
  avg <- unname(vapply(run.times, mean, numeric(1)))
  spread <- unname(vapply(run.times, sd, numeric(1)))
  z <- qnorm(0.975)

  labels <- paste0("Mesh ", str_extract(names(grid.dat), "\\d+\\.*\\d*"))
  summary.df <- data.frame(Mesh = labels, mean = avg, sd = spread,
                           ci.l = avg - z * spread, ci.u = avg + z * spread)

  # Factor levels: distinct labels in decreasing alphabetical order.
  lab.levels <- unique(labels)
  lab.levels <- lab.levels[order(lab.levels, decreasing = TRUE)]
  summary.df$Mesh <- factor(labels, levels = lab.levels)
  summary.df
}

# Summarise the parameter estimates for each mesh: per-column mean, standard
# deviation and a normal-approximation 95% interval (mean -/+ qnorm(0.975) * sd).
#   grid.dat - named list with one element per mesh; each element's est.df is
#              a data frame of numeric columns (one row per simulation run).
#              The mesh size is read from the first number in each name.
# Returns a long data frame with one row per (mesh, est.df column), ordered
# mesh by mesh: Mesh (factor "Mesh <size>", levels in decreasing alphabetical
# order), Param (factor, est.df column order), mean, sd, ci.l, ci.u.
param.meansd.mesh <- function(grid.dat){
  meshes <- names(grid.dat); N.m <- length(meshes)
  params <- colnames(grid.dat[[1]]$est.df); N.p <- length(params)
  # param x mesh matrices built with matrix() so that a single-column est.df
  # still gives a matrix (sapply collapsed it to an unnamed vector, losing
  # the parameter names).
  means <- matrix(vapply(grid.dat, function(g){colMeans(g$est.df)}, numeric(N.p)),
                  nrow = N.p, ncol = N.m)
  # na.rm = TRUE reproduces what the former call sd(x, 2) did: the stray
  # positional 2 was being matched to sd()'s na.rm argument.
  sds <- matrix(vapply(grid.dat, function(g){vapply(g$est.df, sd, numeric(1), na.rm = TRUE)}, numeric(N.p)),
                nrow = N.p, ncol = N.m)
  q <- qnorm(0.975)
  lower.c <- as.vector(means) - q*as.vector(sds); upper.c <- as.vector(means) + q*as.vector(sds)
  mesh.lab <- paste0("Mesh ", str_extract(meshes, "\\d+\\.*\\d*"))
  out.est.df <- data.frame(Mesh=rep(mesh.lab,each=N.p),Param=rep(params,N.m),mean=as.vector(means),sd=as.vector(sds),ci.l=lower.c,ci.u=upper.c)
  mesh.levels <- unique(mesh.lab)
  out.est.df$Mesh <- factor(out.est.df$Mesh,levels=mesh.levels[order(mesh.levels,decreasing=TRUE)])
  out.est.df$Param <- factor(out.est.df$Param,levels=unique(params))
  out.est.df
}

# Combined summary of run times and parameter estimates for each mesh.
#   grid.dat - named list with one element per mesh. Each element holds
#              run.df$time (run times) and est.df (a data frame of estimates,
#              one row per run, with a column for every entry of param).
#              The mesh size is read from the first number in each name.
#   param    - character vector of est.df column names to summarise.
# Returns a long data frame with one row per (mesh, quantity), ordered mesh by
# mesh, where the quantities are "time" followed by param:
#   Mesh           - factor "Mesh <size>", levels in decreasing alphabetical order
#   Param          - factor with levels "time", param...
#   mean, sd       - across runs (sd drops NA estimates; mean does not)
#   q0.025, q0.975 - empirical 2.5% / 97.5% quantiles across runs, NA for time
#   ci.l, ci.u     - normal-approximation interval mean -/+ qnorm(0.975) * sd
combined.meanquantsd.mesh <- function(grid.dat,param){
  meshes <- names(grid.dat); N.m <- length(meshes); N.par <- length(param)
  # Requested columns of each mesh's estimates. This is taken from grid.dat
  # (it previously read the global run.out.final). drop = FALSE keeps a data
  # frame when only one parameter is requested.
  grid.dat.sub <- lapply(grid.dat, function(g){g$est.df[, param, drop = FALSE]})

  # Apply f (data frame -> one value per parameter) to every mesh and return
  # a param x mesh matrix; matrix() keeps the shape when N.par == 1.
  by.mesh <- function(f){
    matrix(vapply(grid.dat.sub, f, numeric(N.par)),
           nrow = N.par, ncol = N.m, dimnames = list(param, meshes))
  }
  # Empirical quantile as the inverse of the empirical CDF (type = 1), i.e. the
  # ceiling(p * n)-th order statistic. For n = 1000 runs this is exactly the
  # previously hard-coded sort(x)[25] / sort(x)[975], but it now works for any
  # number of runs. NAs are dropped, as sort() did.
  emp.quantile <- function(x, p){
    quantile(x, probs = p, type = 1, na.rm = TRUE, names = FALSE)
  }

  means <- by.mesh(colMeans)
  # na.rm = TRUE is what the former sd(x, 2) did (the stray 2 matched na.rm).
  sds <- by.mesh(function(d){vapply(d, sd, numeric(1), na.rm = TRUE)})
  q.lo <- by.mesh(function(d){vapply(d, emp.quantile, numeric(1), p = 0.025)})
  q.hi <- by.mesh(function(d){vapply(d, emp.quantile, numeric(1), p = 0.975)})

  # Prepend the run-time row; quantiles are not reported for time.
  time.mean <- vapply(grid.dat, function(g){mean(g$run.df$time)}, numeric(1))
  time.sd <- vapply(grid.dat, function(g){sd(g$run.df$time)}, numeric(1))
  no.quant <- rep(NA_real_, N.m)
  means.full <- rbind(time = time.mean, means)
  sds.full <- rbind(time = time.sd, sds)
  q.lo.full <- rbind(time = no.quant, q.lo)
  q.hi.full <- rbind(time = no.quant, q.hi)

  params <- rownames(means.full); N.p <- length(params)  # includes "time"
  q <- qnorm(0.975)
  mean.v <- as.vector(means.full); sd.v <- as.vector(sds.full)
  mesh.lab <- paste0("Mesh ", str_extract(meshes, "\\d+\\.*\\d*"))
  out.df <- data.frame(Mesh=rep(mesh.lab,each=N.p),Param=rep(params,N.m),mean=mean.v,q0.025=as.vector(q.lo.full),q0.975=as.vector(q.hi.full),sd=sd.v,ci.l=mean.v - q*sd.v,ci.u=mean.v + q*sd.v)
  mesh.levels <- unique(mesh.lab)
  out.df$Mesh <- factor(out.df$Mesh,levels=mesh.levels[order(mesh.levels,decreasing=TRUE)])
  out.df$Param <- factor(out.df$Param,levels=unique(params))
  out.df
}

# Coverage of the 95% Credible Intervals
# Empirical coverage of the credible intervals for each parameter and mesh.
#   grid.dat   - named list with one element per mesh; for every p in param,
#                each element's est.df must contain columns p, "<p>.cil"
#                (lower limit) and "<p>.ciu" (upper limit), one row per run.
#   param      - character vector of parameter names.
#   true.theta - true parameter values, in the same order as param.
# Returns a long data frame ordered mesh by mesh, with columns Parameter
# (factor), Mesh (factor of the names of grid.dat) and Coverage (proportion
# of that mesh's runs with cil <= true value <= ciu; NA if any limit is NA).
coverageparam.mesh <- function(grid.dat,param,true.theta){
  stopifnot(length(true.theta) >= length(param))
  param.cil <- paste0(param,".cil")
  param.ciu <- paste0(param,".ciu")
  meshes <- names(grid.dat); N.par <- length(param)

  # Read from grid.dat (this previously used the global run.out.final). Each
  # mesh is divided by its own number of runs; previously every mesh was
  # divided by the run count of the first mesh.
  cov <- vapply(grid.dat, function(g){
    # Column selection errors if any required column is missing.
    est <- g$est.df[, c(param, param.cil, param.ciu)]
    vapply(seq_len(N.par), function(j){
      mean(est[[param.cil[j]]] <= true.theta[j] & true.theta[j] <= est[[param.ciu[j]]])
    }, numeric(1))
  }, numeric(N.par))
  cov <- matrix(cov, nrow = N.par, ncol = length(grid.dat))

  # Long format in the same (column-major) layout that tidyr::gather over the
  # mesh columns produced, without hard-coding the mesh column names (this
  # was Mesh0.5:Mesh0.1).
  cov.df <- data.frame(Parameter = rep(param, times = length(meshes)),
                       Mesh = rep(meshes, each = N.par),
                       Coverage = as.vector(cov))
  cov.df$Parameter <- as.factor(cov.df$Parameter)
  cov.df$Mesh <- as.factor(cov.df$Mesh)

  cov.df
}

# Main Function -----------------------------------------------------------

# Private helpers for trad_sim_out ------------------------------------------

# Shared plot text sizes; extra theme() settings are passed through `...`.
.trad.theme <- function(...){
  theme(plot.title = element_text(size=35,hjust = 0.5),axis.title = element_text(size=25),axis.text = element_text(size=25),strip.text.x = element_text(size=20),...)
}

# Relabel meshes as "Mesh 1", "Mesh 2", ...: the label that sorts last
# alphabetically becomes "Mesh 1" (assumes the labels sort in mesh order).
# Exact matching replaces the former chained regex replacement, where "." in
# a label matched any character and a replaced label could be rewritten again.
# Returns a factor with levels "Mesh 1", ..., "Mesh k" in numeric order.
.trad.relabel.meshes <- function(mesh.labels){
  mesh.labels <- as.character(mesh.labels)
  ord <- unique(mesh.labels)
  ord <- ord[order(ord, decreasing = TRUE)]
  new.levels <- paste0("Mesh ", seq_along(ord))
  factor(new.levels[match(mesh.labels, ord)], levels = new.levels)
}

# Plotmath labels for parameter names: the first run of digits becomes a
# subscript ("beta1" -> "beta[1]") and the intercept "Int" becomes "beta[0]".
.trad.param.labels <- function(params){
  num <- as.numeric(str_extract(params, "[0-9]+"))
  lab <- str_extract(params, "[a-zA-Z]+")  # was [aA-zZ], which also matched _ [ ] ^
  has.num <- which(!is.na(num))
  lab[has.num] <- paste0(lab[has.num], "[", num[has.num], "]")
  lab[which(lab == "Int")] <- "beta[0]"
  lab
}

# Results matrix of "mean (lower,upper)" strings: rows are the quantities in
# full.df$Param ("time" first), columns are the meshes. Time uses the normal
# interval (ci.l, ci.u); parameters use the empirical quantiles.
.trad.results.matrix <- function(full.df, round.vals){
  r <- function(x){round(x, round.vals)}
  is.time <- full.df$Param == "time"
  lower <- ifelse(is.time, full.df$ci.l, full.df$q0.025)
  upper <- ifelse(is.time, full.df$ci.u, full.df$q0.975)
  msd <- paste0(r(full.df$mean), " (", r(lower), ",", r(upper), ")")
  rn <- unique(as.character(full.df$Param)); cn <- unique(as.character(full.df$Mesh))
  matrix(msd, length(rn), length(cn), dimnames = list(rn, cn))
}

# xtable version of the results matrix. Row names after the first ("time")
# become LaTeX maths ("beta1" -> "$ \beta_1 $", "Int" -> "$ \beta_0 $").
.trad.latex.table <- function(full.matrix){
  tab <- xtable(full.matrix)
  # One ruled, centred column for the row names plus one per mesh
  # (this was hard-coded as "c|c|c|c|c|", i.e. exactly four meshes).
  align(tab) <- paste(rep("c|", ncol(full.matrix) + 1), collapse = "")
  digits(tab) <- 4
  rn <- rownames(tab)
  ind <- str_extract(rn, "\\d+\\.*\\d*")
  sub.ind <- which(!is.na(ind))
  rn[sub.ind] <- paste0(str_extract(rn[sub.ind], "[a-z]+"), "_", ind[sub.ind])
  rn[which(rn == "Int")] <- "beta_0"  # "\Int" is not a LaTeX command
  maths <- seq_along(rn)[-1]          # safe when there is only one row
  rn[maths] <- paste0("$ ", "\\", rn[maths], " $")
  rownames(tab) <- rn
  tab
}

# Long data frame of WAIC and DIC values: for each mesh, all of its WAIC
# values followed by all of its DIC values (taken from run.df).
.trad.criterion.df <- function(data, mesh.names){
  pieces <- lapply(seq_along(data), function(i){
    run.df <- data[[i]]$run.df
    vals <- as.numeric(c(run.df$waic, run.df$dic))
    data.frame(Mesh = rep(mesh.names[i], length(vals)),
               Criterion = rep(c("WAIC","DIC"), times = c(length(run.df$waic), length(run.df$dic))),
               Value = vals)
  })
  crit.df <- do.call(rbind, pieces)
  crit.df$Mesh <- factor(crit.df$Mesh, levels = mesh.names)
  crit.df$Criterion <- as.factor(crit.df$Criterion)
  crit.df
}

# Long data frame of per-run indicators from mess.ls: Error / Warning are 1
# when the entry is not NA, else 0; FFT is converted with as.numeric()
# (presumably already 0/1 or logical -- confirm against the simulation code).
.trad.message.df <- function(data, mesh.names){
  pieces <- lapply(seq_along(data), function(i){
    mess <- data[[i]]$mess.ls
    err <- as.numeric(!is.na(mess$error))
    warn <- as.numeric(!is.na(mess$warning))
    fft <- as.numeric(mess$FFT)
    vals <- c(err, warn, fft)
    data.frame(Mesh = rep(mesh.names[i], length(vals)),
               Criterion = rep(c("Error","Warning","FFT"), times = c(length(err), length(warn), length(fft))),
               Run = c(seq_along(err), seq_along(warn), seq_along(fft)),
               Value = vals)
  })
  err.df <- do.call(rbind, pieces)
  err.df$Mesh <- factor(err.df$Mesh, levels = mesh.names)
  err.df$Criterion <- as.factor(err.df$Criterion)
  err.df
}

# Mean, sd and count of Value for every (Mesh, Criterion) pair, ungrouped.
.trad.summarise.values <- function(df){
  df %>%
    group_by(Mesh,Criterion) %>%
    summarise(mean = mean(Value),sd=sd(Value), n = n(), .groups = "drop")
}

# Draw every diagnostic plot and, when plots.save is TRUE, also write each
# one to a PDF (15 x 10) in the working directory.
.trad.draw.plots <- function(full.df.plot, cov.df, crit.df, critsum.df, err.df, errsum.df,
                             true.theta, time.marker, plots.save, plot.eachp){
  mesh.limits <- levels(full.df.plot$Mesh)
  save.pdf <- function(file, p){
    if (plots.save){
      ggsave(file, plot = p, width = 15, height = 10, device = "pdf")
    }
  }

  # Time plot: mean run time with normal 95% interval; dashed line at time.marker.
  p1 <- ggplot(full.df.plot[full.df.plot$Param=="time",],aes(x=Mesh,y=mean,colour=Mesh)) +
    geom_point(aes(size=2)) +
    geom_errorbar(aes(ymin=ci.l,ymax=ci.u,size=1),width=0.1) +
    geom_hline(yintercept=time.marker,color="red",linetype="dashed",size=1.5) +
    ggtitle("Average Time Taken for INLA Run") +
    scale_x_discrete(limits = mesh.limits) +
    ylab("Time") +
    .trad.theme(legend.position = "none")
  print(p1)
  save.pdf("SGTradTimevMesh.pdf", p1)

  # Parameter estimates with empirical 95% quantile bars; the dashed line is
  # the true value (true.theta assumed to be in the same order as param).
  # Both display modes now use the relabelled full.df.plot; the non-saving
  # branch previously plotted full.df against "Mesh k" limits and dropped
  # every point.
  params <- as.character(unique(full.df.plot$Param[full.df.plot$Param!="time"]))
  param.lab <- .trad.param.labels(params)
  param.plots <- lapply(seq_along(params), function(i){
    ggplot(full.df.plot[full.df.plot$Param==params[i],],aes(x=Mesh,y=mean,colour=Mesh)) +
      scale_x_discrete(limits = mesh.limits) + xlab("Mesh") +
      geom_point(aes(size=2)) +
      geom_errorbar(aes(ymin=q0.025,ymax=q0.975,size=1),width=0.1) +
      geom_hline(yintercept=true.theta[i],color="red",linetype="dashed",size=1.5) +
      ggtitle(parse(text=paste0("Parameter~Estimates~of~", param.lab[i]))) +
      ylab("Estimate") +
      .trad.theme(legend.position = "none")
  })
  all.param <- do.call(grid.arrange, param.plots)  # grid.arrange() also draws it
  save.pdf("SGTradParameterEstimateswQuant.pdf", all.param)
  if (plot.eachp){
    for (i in seq_along(params)){
      if (plots.save){
        save.pdf(paste0("SGTrad",params[i],"EstimateswQuant.pdf"), param.plots[[i]])
      } else {
        print(param.plots[[i]])
      }
    }
  }

  # WAIC / DIC: histograms per mesh and mean per mesh.
  p2 <- ggplot(crit.df) + geom_histogram(aes(Value,fill=Criterion),binwidth = 25) + facet_wrap(vars(Mesh)) +
    .trad.theme(legend.text = element_text(size=15),legend.title = element_text(size=20))
  p3 <- ggplot(critsum.df) + geom_point(aes(x=Mesh,y=mean,col=Criterion,size=2)) + scale_x_discrete(limits = mesh.limits) +
    xlab("Mesh") + facet_wrap(vars(Criterion)) + .trad.theme(legend.position = "none")
  print(p2)
  print(p3)
  save.pdf("SGTradDICWAIC.pdf", p2)
  save.pdf("SGTradDICWAICMean.pdf", p3)

  # Errors, warnings and FFT messages: per run, and proportion per mesh.
  p4 <- ggplot(err.df) + geom_point(aes(x=Run,y=Value,col=Criterion,size=2)) + facet_wrap(vars(Mesh,Criterion),ncol = 3) +
    .trad.theme(legend.position = "none")
  p5 <- ggplot(errsum.df) + geom_point(aes(x=Mesh,y=mean,col=Criterion,size=2)) + scale_x_discrete(limits = mesh.limits) +
    xlab("Mesh") + facet_wrap(vars(Criterion)) +
    .trad.theme(axis.text.x=element_text(angle = 45, vjust = 1, hjust = 1),legend.position = "none")
  print(p4)
  print(p5)
  save.pdf("SGTradErr.pdf", p4)
  save.pdf("SGTradErrMean.pdf", p5)

  # Coverage of the 95% credible intervals, with the nominal 0.95 line.
  p6 <- ggplot(cov.df) + geom_point(aes(x=Mesh,y=Coverage,col=Mesh,size=2)) + scale_x_discrete(limits = levels(cov.df$Mesh)) +
    xlab("Mesh") + geom_hline(yintercept=0.95,color="red",linetype="dashed",size=1.5) +
    facet_wrap(vars(Label), labeller = label_parsed) + ggtitle("Coverage of 95% Credible Intervals") +
    .trad.theme(axis.text.x=element_text(angle = 45, vjust = 1, hjust = 1),legend.position = "none")
  print(p6)
  save.pdf("SGCoverage.pdf", p6)

  invisible(NULL)
}

# Summarise a traditional simulation study (single data grid, several meshes):
# builds a results table and draws (optionally saves) diagnostic plots.
#   data        - named list, one element per mesh, each holding run.df (time,
#                 waic, dic per run), est.df (estimates plus "<p>.cil"/"<p>.ciu"
#                 interval limits per parameter) and mess.ls (error, warning,
#                 FFT per run). The WAIC/DIC and message plots label element i
#                 as "Mesh i", so the list order is assumed to match the
#                 decreasing alphabetical order of the mesh names -- confirm.
#   meshes      - mesh identifiers; only the number of unique values is used,
#                 to check it against length(data)
#   true.theta  - true values of the fixed parameters, in the order of param
#   param       - character vector of parameter names (columns of est.df)
#   table       - build the results table?
#   round.vals  - decimal places used when rounding the table entries
#   time.marker - y position of the dashed reference line on the time plot
#   latex.table - also build an xtable version of the results table?
#   plots       - draw the plots?
#   plots.save  - also save each plot as a PDF in the working directory?
#   plot.eachp  - also draw (or, when saving, save) each parameter plot alone?
# Returns the results table (a "table" of "mean (lower,upper)" strings, one
# row per quantity and one column per mesh), or a list with results.table
# and latex.results.table when latex.table is TRUE; invisible NULL when table
# is FALSE.
trad_sim_out <- function(data,meshes,true.theta,param,table=TRUE,round.vals=3,time.marker=60,latex.table=FALSE,plots=TRUE,plots.save=FALSE,plot.eachp=TRUE){
  n.mesh <- length(data)
  if (length(unique(meshes)) != n.mesh){
    warning("length(unique(meshes)) differs from the number of meshes in data", call. = FALSE)
  }

  # Time/parameter summaries; the plotting copy relabels meshes "Mesh 1", ...
  full.df <- combined.meanquantsd.mesh(data,param)
  full.df.plot <- full.df
  full.df.plot$Mesh <- .trad.relabel.meshes(full.df$Mesh)

  # Credible-interval coverage, relabelled, with plotmath facet labels.
  cov.df <- coverageparam.mesh(data,param,true.theta)
  cov.df$Mesh <- .trad.relabel.meshes(cov.df$Mesh)
  cov.df$Label <- .trad.param.labels(as.character(cov.df$Parameter))

  if (table){
    full.matrix <- .trad.results.matrix(full.df, round.vals)
    full.table <- as.table(full.matrix)
    if (latex.table){
      full.table.latex <- .trad.latex.table(full.matrix)
    }
  }

  if (plots){
    # Built from `data` for any number of meshes; these previously read the
    # global run.out.final and hard-coded four meshes.
    mesh.names <- paste0("Mesh ", seq_len(n.mesh))
    crit.df <- .trad.criterion.df(data, mesh.names)
    err.df <- .trad.message.df(data, mesh.names)
    .trad.draw.plots(full.df.plot, cov.df, crit.df, .trad.summarise.values(crit.df),
                     err.df, .trad.summarise.values(err.df),
                     true.theta, time.marker, plots.save, plot.eachp)
  }

  if (!table){
    return(invisible(NULL))
  }
  if (latex.table){
    return(list("results.table"=full.table,"latex.results.table"=full.table.latex))
  }
  full.table
}
