
# Locating Timing Errors --------------------------------------------------

# This R script takes the outputs from the SBC runs and searches for those with timing errors.
# It also re-saves those outputs under the same file name with _TIMEERRORFINAL added as a suffix.
# It also produces a list of data frames, covering all the affected data sets, recording the simulation, grid, mesh and allocated 'time' of each error.

# Author: Nadeen Khaleel

# Setwd and Load Libraries ------------------------------------------------

library("rstudioapi")
library(stringr)
# Move to the directory containing this script so that the relative paths used
# below resolve. The active-document lookup only works inside an RStudio
# session; anywhere else, setwd() to the source file location before running.
if (rstudioapi::isAvailable()){
  setwd(dirname(getActiveDocumentContext()$path))
} else {
  message("RStudio API unavailable: assuming the working directory is already the source file location.")
}


# Generate Time Error Data Frame ------------------------------------------

# Move from the script directory into the outputs directory; every path used
# below is relative to it.
setwd("../IRREGPOLGLGCP_OUTPUT")

# Find out which processes have an error labeled "TIME ERROR", regardless of which 'time'.
# For each of the 40 process output files, count the entries of mess.ls$error
# (over all 4 x 4 run.out[[j]][[l]] combinations) that contain "TIME ERROR".
# Everything printed from here until the matching sink() goes to the text file.
sink("TimingErrorProcessesSearchandSetUp.txt")
time.err.count.vec <- rep(0,40)
for (i in seq_along(time.err.count.vec)){
  # The loaded file is expected to provide `run.out` for process i.
  load(paste0("SBCCompletedPre-TimeErrorReruns/GridMeshIrregPolLGCPSBCSS",i,".rda"))
  for (j in 1:4){
    for (l in 1:4){
      is.time.err <- str_detect(run.out[[j]][[l]]$mess.ls$error,"TIME ERROR")
      # na.rm: NA entries of the error vector must not turn the count into NA.
      time.err.count.vec[i] <- time.err.count.vec[i] + sum(is.time.err,na.rm=TRUE)
    }
  }
}

print("procs with timing errors:")
err.procs <- which(time.err.count.vec>0)

# Explicit print(): a bare top-level expression is not auto-printed when the
# script is source()d, which would leave the list out of the sink file.
print(err.procs)

# One data frame per affected process, named "Process_<index>".
err.df.list <- vector(mode="list",length=length(err.procs))
if (length(err.procs)>0){
  # Guarded because paste0() on a zero-length index still returns one string,
  # which cannot be assigned as the names of an empty list.
  names(err.df.list) <- paste0("Process_",err.procs)
}

# First run of digits in the error message; stored below as the error 'time'.
regexp <- "[[:digit:]]+"
for (i in seq_along(err.procs)){
  proc <- err.procs[i]
  load(paste0("SBCCompletedPre-TimeErrorReruns/GridMeshIrregPolLGCPSBCSS",proc,".rda"))

  # Collect the pieces for every (j, l) combination, then build the data frame
  # once instead of growing it row by row.
  sim.parts <- vector(mode="list",length=16)
  grid.parts <- vector(mode="list",length=16)
  mesh.parts <- vector(mode="list",length=16)
  time.parts <- vector(mode="list",length=16)
  cell <- 0
  for (j in 1:4){
    for (l in 1:4){
      cell <- cell + 1
      err.msgs <- run.out[[j]][[l]]$mess.ls$error
      # Positions within the error vector that carry a timing error.
      s <- which(str_detect(err.msgs,"TIME ERROR"))
      sim.parts[[cell]] <- s
      grid.parts[[cell]] <- rep(j,length(s))
      mesh.parts[[cell]] <- rep(l,length(s))
      time.parts[[cell]] <- str_extract(err.msgs[s],regexp)
    }
  }

  # sim/grid/mesh are stored as numeric and err.time as the extracted text,
  # matching what the row-by-row assignment into the data frame produced.
  # NOTE(review): err.time is therefore a character column; confirm that the
  # code consuming TimingErrorDataFrames.rda expects that.
  err.df <- data.frame(sim=as.numeric(unlist(sim.parts)),
                       grid=as.numeric(unlist(grid.parts)),
                       mesh=as.numeric(unlist(mesh.parts)),
                       err.time=as.character(unlist(time.parts)),
                       stringsAsFactors=FALSE)
  # Messages without any digits give NA from str_extract; default those to 1.
  err.df$err.time[is.na(err.df$err.time)] <- "1"

  err.df <- err.df[order(err.df$sim),]

  # Every error starts with rerun set to 0.
  err.df$rerun <- rep(0,nrow(err.df))

  err.df.list[[i]] <- err.df

}

# Save the list of data frames for the timing errors to guide the simulations on Balena
save(err.procs,err.df.list,file="TimingErrorDataFrames.rda")

# Prefix each per-process data frame with a "Process" column holding its list
# name, then stack them for the printed record. Reduce(rbind, ...) binds them
# one after another and also copes with a single process, where a
# 2:length(...) loop would run backwards and index past the end of the list.
labelled.dfs <- lapply(seq_along(err.df.list),function(i){
  data.frame("Process"=rep(names(err.df.list)[i],nrow(err.df.list[[i]])),err.df.list[[i]])
})
err.df.full <- Reduce(rbind,labelled.dfs)

print(err.df.full)

# Stop diverting output to TimingErrorProcessesSearchandSetUp.txt and clear the
# workspace; the next section reloads what it needs from the saved .rda file.
sink()
rm(list=ls())


# Re-Save Outputs ---------------------------------------------------------
# Store a copy of each affected output, as is, with the suffix '_TIMEERRORFINAL'
# so that re-running will not overwrite these versions.

load("TimingErrorDataFrames.rda")

# Objects copied from each original output file into its _TIMEERRORFINAL copy.
kept.objects <- c("run.out","gm","true.theta","data.err.tracker","seed.vec")

for (proc in err.procs){
  print(paste0("GridMeshIrregPolLGCPSBCSS",proc,".rda"))
  load(paste0("SBCCompletedPre-TimeErrorReruns/GridMeshIrregPolLGCPSBCSS",proc,".rda"))
  save.file <- paste0("GridMeshIrregPolLGCPSBCSS",proc,"_TIMEERRORFINAL.rda")
  print(save.file)
  save(list=kept.objects,file=save.file)
  # Drop the loaded objects so nothing carries over into the next process.
  rm(list=kept.objects)
}

# Checks for the Old and New Output files
#
# For every process with timing errors, load the original output file and its
# _TIMEERRORFINAL copy and print a record comparing the two (names, message and
# error counts, true.theta, seed.vec, est.df, FFT and the rank matrices) to
# DoubleCheckingTimeErrorNewFilesforReRun.txt.

procs <- err.procs
N.g <- 4; N.m <- 4

# Number of non-NA entries in the last column of ranks.mf, for each element of g.
gm.s <- function(g){
  vapply(seq_along(g),function(i){
    sum(!is.na(g[[i]]$ranks.mf[,dim(g[[i]]$ranks.mf)[2]]))
  },integer(1))
}

# Number of non-NA entries of mess.ls$error, for each element of g.
e.s <- function(g){
  vapply(seq_along(g),function(i){sum(!is.na(g[[i]]$mess.ls$error))},integer(1))
}

# Print the comparison of one object between the new and old files:
# the sum and max-abs of (new - old) over the first `max.row` rows (by.row=TRUE)
# or elements (by.row=FALSE), then the NA and non-NA counts of the full objects
# (old first, then new).
report.diff <- function(new.obj,old.obj,max.row,by.row){
  rows <- 1:max.row
  if (by.row){
    delta <- new.obj[rows,] - old.obj[rows,]
  } else {
    delta <- new.obj[rows] - old.obj[rows]
  }
  print(sum(delta,na.rm=TRUE))
  print(max(abs(delta)))

  print(sum(is.na(old.obj)))
  print(sum(is.na(new.obj)))

  print(sum(!is.na(old.obj)))
  print(sum(!is.na(new.obj)))
}

sink("DoubleCheckingTimeErrorNewFilesforReRun.txt")
for (ii in seq_along(procs)){

  k <- procs[ii]

  print(paste0("Process ",k))

  # Load the original output and keep its objects under *.old names, removing
  # the originals so the new file's objects cannot be confused with them.
  old.file <- paste0("SBCCompletedPre-TimeErrorReruns/GridMeshIrregPolLGCPSBCSS",k,".rda")
  load(old.file)
  run.out.old <- run.out
  rm(run.out)
  gm.old <- gm
  rm(gm)
  true.theta.old <- true.theta
  rm(true.theta)
  seed.vec.old <- seed.vec
  rm(seed.vec)
  data.err.tracker.old <- data.err.tracker
  rm(data.err.tracker)

  save.file <- paste0("GridMeshIrregPolLGCPSBCSS",k,"_TIMEERRORFINAL.rda")
  load(save.file)


  # Rank rows plus recorded errors for the last grid/mesh combination, plus one.
  # NOTE(review): presumably the index of the next simulation to run - confirm.
  nn <- dim(gm[[N.g*N.m]]$ranks.mf)[2]
  p.length <- sum(!is.na(gm[[N.g*N.m]]$ranks.mf[,nn])) + sum(!is.na(run.out[[N.g]][[N.m]]$mess.ls$error)) + 1 # no ranks in trad ss

  s <- matrix(gm.s(gm),nrow=N.m) # fills in down the columns, so for each grid, fills in row i with mesh i, following the output from e.s below
  s <- s + vapply(seq_along(run.out),function(i){e.s(run.out[[i]])},integer(N.m))

  # Locate where the counts in s (rows = meshes, columns = grids) first change.
  ds <- diff(s)
  if (any(ds!=0)){
    # which(arr.ind=TRUE) gives one (row, col) pair per match; take the first
    # match explicitly rather than linear-indexing the whole matrix.
    w <- which(ds!=0,arr.ind=TRUE)
    grid.start.ind <- unname(w)[1,2]
    mesh.start.ind <- unname(w)[1,1] + 1
  } else if (any(diff(t(s))!=0)){
    w <- which(diff(t(s))!=0,arr.ind=TRUE) # should be easy to extract the common row, then +1 to get the required GRID that needs to begin running...
    grid.start.ind <- unname(w)[1,1] + 1
    mesh.start.ind <- unname(w)[1,2]
  } else {
    grid.start.ind <- 1
    mesh.start.ind <- 1
  }


  print("Names gm")
  print(names(gm.old))
  print(names(gm))

  print(sum(names(gm.old)==names(gm)))
  print(sum(names(gm.old)!=names(gm)))


  for (i in seq_len(N.g)){
    for (j in seq_len(N.m)){
      print("Length Message")
      n.msg <- length(run.out[[i]][[j]]$mess.ls$message)
      print(n.msg)
      print("Null Messages Counts")
      # seq_len() keeps the count at 0 when there are no messages at all.
      print(sum(vapply(seq_len(n.msg),function(l){is.null(run.out.old[[i]][[j]]$mess.ls$message[[l]])},logical(1))))
      print(sum(vapply(seq_len(n.msg),function(l){is.null(run.out[[i]][[j]]$mess.ls$message[[l]])},logical(1))))
    }
  }


  for (i in seq_len(N.g)){
    for (j in seq_len(N.m)){
      print("Length FFT")
      n.fft <- length(run.out[[i]][[j]]$mess.ls$FFT)
      print(n.fft)
      print("Length Error")
      n.err <- length(run.out[[i]][[j]]$mess.ls$error)
      print(n.err)
      print("FFT Count")
      print(sum((run.out.old[[i]][[j]]$mess.ls$FFT),na.rm=TRUE))
      print(sum((run.out[[i]][[j]]$mess.ls$FFT),na.rm=TRUE))
      print("Warnings Present")
      print(sum(!is.na(run.out.old[[i]][[j]]$mess.ls$warning)))
      print(sum(!is.na(run.out[[i]][[j]]$mess.ls$warning)))
      print("Errors Present")
      print(sum(!is.na(run.out.old[[i]][[j]]$mess.ls$error)))
      print(sum(!is.na(run.out[[i]][[j]]$mess.ls$error)))
    }
  }


  # CHECKS ####

  if (grid.start.ind==1 && mesh.start.ind==1){ # for time error addition
    p.check <- p.length - 1
  } else {
    p.check <- p.length
  }

  print("Compare true.theta and seed.vec")
  print(sum(true.theta[1:p.check,]-true.theta.old[1:p.check,]))
  print(max(abs(true.theta[1:p.check,]-true.theta.old[1:p.check,])))
  print(sum(seed.vec[1:p.check]-seed.vec.old[1:p.check]))
  print(max(abs(seed.vec[1:p.check]-seed.vec.old[1:p.check])))

  print("data.err.tracker comparison")
  print(data.err.tracker)
  print(data.err.tracker.old)

  # Number of leading rows/elements to compare for each (grid jj, mesh ll):
  # p.length where jj <= grid.start.ind and ll < mesh.start.ind - 1,
  # p.length - 1 everywhere else.
  replacement.max.mat <- matrix(p.length - 1,nrow=N.g,ncol=N.m)
  uses.full.length <- outer(seq_len(N.g),seq_len(N.m),function(jj,ll){
    jj <= grid.start.ind & ll < (mesh.start.ind-1)
  })
  replacement.max.mat[uses.full.length] <- p.length

  print("Sum and Max-Abs Different and sum is.na in est.df")
  for (jj in seq_len(N.g)){
    for (ll in seq_len(N.m)){
      report.diff(run.out[[jj]][[ll]]$est.df,run.out.old[[jj]][[ll]]$est.df,
                  replacement.max.mat[jj,ll],by.row=TRUE)
    }
  }

  print("Sum and Max-Abs difference in FFT")
  for (jj in seq_len(N.g)){
    for (ll in seq_len(N.m)){
      report.diff(run.out[[jj]][[ll]]$mess.ls$FFT,run.out.old[[jj]][[ll]]$mess.ls$FFT,
                  replacement.max.mat[jj,ll],by.row=FALSE)
    }
  }

  print("Sum and Max-Abs difference in ranks.param")
  for (jj in seq_len(N.g)){
    for (ll in seq_len(N.m)){
      # gm is a flat list; element (jj-1)*N.m + ll pairs with run.out[[jj]][[ll]].
      ind <- (jj-1)*N.m + ll
      report.diff(gm[[ind]]$ranks.param,gm.old[[ind]]$ranks.param,
                  replacement.max.mat[jj,ll],by.row=TRUE)
    }
  }

  print("Sum and Max-Abs difference in ranks.mf")
  for (jj in seq_len(N.g)){
    for (ll in seq_len(N.m)){
      ind <- (jj-1)*N.m + ll
      report.diff(gm[[ind]]$ranks.mf,gm.old[[ind]]$ranks.mf,
                  replacement.max.mat[jj,ll],by.row=TRUE)
    }
  }

  for (i in seq_len(N.g*N.m)){
    print(paste0("Dimension gm.old ", i))
    print(dim(gm.old[[i]]$ranks.param))
    print(dim(gm.old[[i]]$ranks.mf))
    print(paste0("Dimension gm ", i))
    print(dim(gm[[i]]$ranks.param))
    print(dim(gm[[i]]$ranks.mf))
  }



}

sink()

rm(list=ls())


# sessionInfo -------------------------------------------------------------

# Record the R version and attached packages used for this run. print() is
# explicit because a bare top-level expression is not auto-printed when the
# script is source()d.
print(sessionInfo())
