#################################################################
# Preparation, import packages.
#################################################################
require(gtools)
require(combinat)
source("TreeClone_fn.R")
source("TreeClone_plot.R")

#################################################################
# Preparation, generate matrix A, corresponding to 
# A(h_g, z_{kc}) in the paper.
#################################################################
# HG: one row per read pattern h_g, written as (locus 1, locus 2).
# Entries are 0 or 1; the value 2 is the code that gen_A treats as
# "this locus is not compared".
HG = cbind(c(0, 0, 1, 1, 2, 2, 0, 1),
           c(0, 1, 0, 1, 0, 1, 2, 2))

# ZQ: enumerate all 16 binary 4-tuples (first column varies slowest),
# then keep only the 10 rows listed below.
ZQ = cbind(rep(c(0, 1), each = 8),
           rep(c(0, 1), each = 4, times = 2),
           rep(c(0, 1), each = 2, times = 4),
           rep(c(0, 1), times = 8))
ZQ = ZQ[c(1:4, 6:8, 11, 12, 16), ]

# Q genotypes (rows of ZQ) and G read patterns (rows of HG).
Q = nrow(ZQ)
G = nrow(HG)

gen_A = function(hg, zq){
  # input: hg - read pattern, length 2; a value of 2 means that position
  #             is skipped in the comparison
  # input: zq - genotype, length 4; zq[1:2] and zq[3:4] are each compared
  #             against hg
  # output: 0, 1/2 or 1 - each of the two halves of zq that agrees with
  #         hg on the compared position(s) contributes 1/2
  
  hits = c(FALSE, FALSE)
  
  if(!any(hg == 2)){
    # both positions present: each half must match hg entirely
    hits = c(all(hg == zq[1:2]), all(hg == zq[3:4]))
  } else if(hg[1] == 2){
    # first position skipped: compare second positions only
    hits = c(hg[2] == zq[2], hg[2] == zq[4])
  } else if(hg[2] == 2){
    # second position skipped: compare first positions only
    hits = c(hg[1] == zq[1], hg[1] == zq[3])
  }
  
  return(sum(hits) / 2)
}

# A[g, q] = gen_A(read pattern g, genotype q); filled one genotype
# (column) at a time.
A = matrix(0, G, Q)
for(q in 1:Q){
  A[ , q] = apply(HG, 1, gen_A, zq = ZQ[q, ])
}

# Reorder the genotype columns: columns 7 and 8 trade places.
col_order = c(1:6, 8, 7, 9, 10)
A = A[ , col_order]

# HG and ZQ were only needed to build A.
rm(HG, ZQ)


#################################################################
# Function for simulating data
#################################################################
simulate_n = function(Z, w, rho, v, N){
  # Simulate read counts from the multinomial sampling model.
  # input: Z   - K * C matrix of genotype indices (column indices of the
  #              global matrix A)
  # input: w   - T * (C + 1) matrix of weights; column C + 1 multiplies rho
  # input: rho - length-G vector paired with w[ , C + 1]
  # input: v   - T * K * 2 array, prob of observing left and right
  #              missing reads
  # input: N   - T * K matrix, total number of reads for k and t
  # output: list(n, p), both T * K * G arrays (counts and the
  #         probabilities passed to rmultinom)
  # Relies on the global matrix A and on patterns being ordered as
  # 1:4, 5:6 (scaled by v[ , , 1]) and 7:8 (scaled by v[ , , 2]).
  
  n_pairs = dim(Z)[1]
  n_clones = dim(Z)[2]
  n_patterns = length(rho)
  n_samples = dim(w)[1]
  
  # p: T * K * G array, to represent multinomial probabilities
  p = array(0, c(n_samples, n_pairs, n_patterns))
  
  for(t in 1:n_samples){
    clone_w = w[t, 1:n_clones]
    extra_w = w[t, n_clones + 1]
    for(k in 1:n_pairs){
      # unscaled mixture for each of the 8 patterns
      mix = vapply(1:8, function(g){
        sum(clone_w * A[g, Z[k, ]]) + extra_w * rho[g]
      }, numeric(1))
      
      v_first = v[t, k, 1]
      v_second = v[t, k, 2]
      
      p[t, k, 1:4] = (1 - v_first - v_second) * mix[1:4]
      p[t, k, 5:6] = v_first * mix[5:6]
      p[t, k, 7:8] = v_second * mix[7:8]
    }
  }
  
  # draw the counts; kept as a separate pass in (t, k) order so the
  # random number stream is consumed exactly as before
  n = array(0, c(n_samples, n_pairs, n_patterns))
  for(t in 1:n_samples){
    for(k in 1:n_pairs){
      n[t, k, ] = rmultinom(n = 1, size = N[t, k], prob = p[t, k, ])
    }
  }
  
  return(list(n = n, p = p))
}





transform_theta = function(theta){
  # Normalize every row of theta so that it sums to 1.
  # input: theta - numeric matrix (or data frame), one row per sample
  # output: matrix with the same dimensions and dimnames as theta
  #
  # Implemented with sweep()/rowSums(), like theta_to_w. The previous
  # t(apply(...)) form returned a 1 x nrow matrix when theta had a
  # single column, because apply() then collapses to a plain vector.
  
  if(is.data.frame(theta)) theta = as.matrix(theta)
  
  return(sweep(theta, 1, rowSums(theta), "/"))
}


pdfplot_hist_pdiff = function(p_star, psim, vsim, breaks = 10, path = "./results/pdiff.pdf"){
  # input:
  # p_star - point estimate of \tilde{p}, i.e. p from post_point_PT
  # psim - true value of p, not p tilde, but normalized p
  # vsim - true value of v, used to normalize p_star
  # breaks - passed to hist()
  # path - pdf file to write
  # output: residual plot in a pdf file; returns the value of dev.off()
  
  # Put p_star on the same scale as psim: patterns 1:4 are multiplied by
  # (1 - v1 - v2), patterns 5:6 by v1 and patterns 7:8 by v2.
  p_star_rescale = p_star
  
  for(g in 1:4){
    p_star_rescale[ , , g] = p_star[ , , g] * (1 - vsim[ , , 1] - vsim[ , , 2])
  }
  for(g in 5:6){
    p_star_rescale[ , , g] = p_star[ , , g] * vsim[ , , 1]
  }
  for(g in 7:8){
    p_star_rescale[ , , g] = p_star[ , , g] * vsim[ , , 2]
  }
  
  pdf(path)
  pdf_dev = dev.cur()
  # If plotting fails, still close the pdf device opened above so that a
  # failed call does not leave a dangling graphics device behind.
  finished = FALSE
  on.exit(if(!finished) dev.off(pdf_dev), add = TRUE)
  
  par(mar = c(4.5, 5, 1, 1))
  hist(p_star_rescale - psim, xlab = "", main = "", 
       cex.axis = 1.6, cex.lab = 1.6, breaks = breaks)
  
  finished = TRUE
  dev.off(pdf_dev)
}



PairTree_R = function(n, treeStateMat, treeCntMat, hp, thin, niter){
  # R driver around the compiled routine "PairTree_MCMC".
  # input: n            - T * K * G array of read counts
  # input: treeStateMat - n_State * C matrix, one candidate row of Z per state
  # input: treeCntMat   - n_State-row matrix; column sums over the chosen
  #                       rows give the initial `count` passed to C
  # input: hp           - list with elements lambda, d, d0, d1
  # input: thin         - handed to the C routine as its `niter`
  #                       (presumably the number of inner iterations per
  #                       stored draw - confirm against the C source)
  # input: niter        - number of stored draws (>= 1); draw 1 is the
  #                       random initial state
  # output: list(Z, Z_vec, theta, rho_star, logpost); the last dimension /
  #         column indexes the stored draw. logpost[1] stays -Inf.
  
  if(niter < 1) stop("niter must be at least 1")
  
  lambda = hp$lambda
  d = hp$d
  d0 = hp$d0
  d1 = hp$d1
  
  T = dim(n)[1]
  K = dim(n)[2]
  G = dim(n)[3]
  n_State = dim(treeStateMat)[1]
  C = dim(treeStateMat)[2]
  
  Z_sam = array(0, c(K, C, niter))
  Z_vec_sam = matrix(0, K, niter)
  
  theta_sam = array(0, c(T, C + 1, niter))
  rho_star_sam = matrix(0, G, niter)
  
  logpost_sam = rep(-Inf, niter)
  
  # Initial state: one tree state per mutation pair, chosen uniformly.
  # drop = FALSE keeps the matrix shape when K == 1, which colSums needs.
  Z_vec_sam[ , 1] = sample(x = 1:n_State, size = K, replace = TRUE)
  Z_sam[ , , 1] = treeStateMat[Z_vec_sam[ , 1], , drop = FALSE]
  count = colSums(treeCntMat[Z_vec_sam[ , 1], , drop = FALSE])
  
  # to accommodate C (state indices are stored 0-based from here on)
  Z_vec_sam[ , 1] = Z_vec_sam[ , 1] - 1 
  
  
  theta_sam[ , 1:C, 1] = matrix(rgamma(T * C , rep(d, T * C), 1), T, C)
  theta_sam[ , C+1, 1] = apply(matrix(theta_sam[ , 1:C, 1], T, C), 1, sum) / 999
  
  # the shape vector has 8 entries, i.e. this assumes G == 8
  rho_star_sam[ , 1] = rgamma(G, c(rep(d1, 4), rep(2*d1, 4)), 1)
  
  
  # seq_len(niter)[-1] is 2, ..., niter and is empty for niter == 1
  # (2:niter would count down to 1 and index past the arrays).
  for(i in seq_len(niter)[-1]){
    
    # start draw i from draw i - 1
    Z_vec_sam[ , i] = Z_vec_sam[ , i - 1]
    Z_sam[ , , i] = Z_sam[ , , i - 1]
    theta_sam[ , , i] = theta_sam[ , , i - 1]
    rho_star_sam[ , i] = rho_star_sam[ , i - 1]
    
    
    output = .C("PairTree_MCMC", Z = as.integer(Z_sam[ , , i]),
                Z_vec = as.integer(Z_vec_sam[ , i]), 
                count = as.integer(count),
                theta = as.double(theta_sam[ , , i]), 
                rho_star = as.double(rho_star_sam[ ,i]),
                n = as.double(n), C = as.integer(C), T = as.integer(T),
                K = as.integer(K), G = as.integer(G), 
                lambda = as.double(lambda), d = as.double(d),
                d0 = as.double(d0), d1 = as.double(d1),
                niter = as.integer(thin), 
                treeStateMat = as.integer(treeStateMat),
                treeCntMat = as.integer(treeCntMat),
                n_State = as.integer(n_State),
                logpost = as.double(0))
    
    # copy the updated state back; `count` is carried to the next draw
    Z_vec_sam[ , i] = output$Z_vec
    Z_sam[ , , i] = matrix(output$Z, K, C)
    theta_sam[ , , i] = matrix(output$theta, T, C + 1)
    rho_star_sam[ , i] = output$rho_star
    count = output$count
    logpost_sam[i] = output$logpost
    
  }
  
  return(list(Z = Z_sam, Z_vec = Z_vec_sam, theta = theta_sam, 
              rho_star = rho_star_sam, logpost = logpost_sam))
  
}








################################################################
# get data from .tj file, still writing
##################################################################
get_data_tj = function(read){
  # Turn a parsed .tj table into the T * K * 8 count array n.
  # input: read - table with one row per mutation pair:
  #        read[1, 1]   number of samples T
  #        column 2     two-character reference string
  #        column 2 + t summary for sample t, "XY=count" entries separated
  #                     by ";", where "_" marks a missing position
  # output: n[t, k, ] = counts for the 8 patterns, in the order
  #         1 both match ref, 2 first matches only, 3 second matches only,
  #         4 neither matches, 5/6 first missing (second matches / not),
  #         7/8 second missing (first matches / not)
  # An empty summary string leaves all 8 counts at 0.
  
  T = read[1, 1]
  K = dim(read)[1]
  G = 8
  
  n = array(0, c(T, K, G))
  
  for(k in seq_len(K)){
    ref_genome = strsplit(toString(read[k, 2]), "")[[1]]
    for(t in seq_len(T)){
      nk = rep(0, G)
      obs_summary = strsplit(toString(read[k, 2 + t]), ";")[[1]]
      # iterate over the entries themselves: nothing happens for an
      # empty summary (1:length(x) would visit indices 1 and 0)
      for(obs_sg in obs_summary){
        fields = strsplit(obs_sg, "=")[[1]]
        obs_g = strsplit(fields[1], "")[[1]]
        obs_ng = as.numeric(fields[2])
        
        if(obs_g[1] == "_"){
          slot = if(obs_g[2] == ref_genome[2]) 5 else 6
        } else if(obs_g[2] == "_"){
          slot = if(obs_g[1] == ref_genome[1]) 7 else 8
        } else if(obs_g[1] == ref_genome[1]){
          slot = if(obs_g[2] == ref_genome[2]) 1 else 2
        } else {
          slot = if(obs_g[2] == ref_genome[2]) 3 else 4
        }
        nk[slot] = obs_ng
      }
      n[t, k, ] = nk
    }
  }
  return(n)
}




##############################
# Functions for Bayclone
##############################

# convert the Z matrix from PairClone's version to Bayclone's version
convert_Z_BC = function(Z){
  # input: Z - K * C matrix of genotype states in 1..10
  # output: (2K) * C matrix; rows 2k - 1 and 2k hold the two per-locus
  #         values (0, 0.5 or 1) for pair k. Prints "wrong Z matrix" and
  #         returns NULL at the first entry that is not in 1..10.
  #
  # NOTE(review): states 7 and 8 map to (0.5, 1) and (1, 0) here. The
  # matrix A built at the top of this file swaps its columns 7 and 8, so
  # this table appears to follow the order before that swap - confirm
  # which ordering the Z passed in uses.
  
  n_pairs = dim(Z)[1]
  n_clones = dim(Z)[2]
  
  # row q = c(value for row 2k - 1, value for row 2k) of state q
  state_table = rbind(c(0,   0),
                      c(0,   0.5),
                      c(0.5, 0),
                      c(0.5, 0.5),
                      c(0,   1),
                      c(0.5, 0.5),
                      c(0.5, 1),
                      c(1,   0),
                      c(1,   0.5),
                      c(1,   1))
  
  Z1 = matrix(0, 2 * n_pairs, n_clones)
  for(k in 1:n_pairs){
    for(cl in 1:n_clones){
      is_state = (Z[k, cl] == 1:10)
      if(any(is_state)){
        Z1[2 * k - 1, cl] = state_table[which(is_state), 1]
        Z1[2 * k, cl] = state_table[which(is_state), 2]
      } else {
        print("wrong Z matrix")
        return(NULL)
      }
    }
  }
  return(Z1)
}



##  convert the data (array n) from PairClone's version to Bayclone's version
convert_n_BC = function(n){
  # input: n - T * K * G array of pair-level read counts (patterns 1..8)
  # output: list(n, N), both (2K) * T matrices. Row 2k - 1 uses patterns
  #         (1, 2, 7) and (3, 4, 8); row 2k uses (1, 3, 5) and (2, 4, 6).
  #         $n holds the second group of each row, $N the sum of both.
  
  n_samples = dim(n)[1]
  n_pairs = dim(n)[2]
  
  odd_rows = 2 * (1:n_pairs) - 1
  even_rows = 2 * (1:n_pairs)
  
  # per-sample counts, one column per output row
  first_grp = matrix(0, n_samples, 2 * n_pairs)
  second_grp = matrix(0, n_samples, 2 * n_pairs)
  
  for(t in 1:n_samples){
    obs = matrix(n[t, , ], n_pairs)   # K * G slice for sample t
    
    first_grp[t, odd_rows] = obs[ , 1] + obs[ , 2] + obs[ , 7]
    second_grp[t, odd_rows] = obs[ , 3] + obs[ , 4] + obs[ , 8]
    first_grp[t, even_rows] = obs[ , 1] + obs[ , 3] + obs[ , 5]
    second_grp[t, even_rows] = obs[ , 2] + obs[ , 4] + obs[ , 6]
  }
  
  # transpose so rows index the 2K loci and columns the samples
  data_snv = list(n = t(second_grp), N = t(first_grp + second_grp))
  return(data_snv)
}



##############################
# Functions for PhyloWGS
##############################

convert_n_PhyloWGS = function(n){
  # input: n - T * K * G array of pair-level read counts
  # output: data frame with columns id, gene, a, d, mu_r, mu_v, one row
  #         per row of convert_n_BC(n)$n that has a non-zero row sum.
  #         a and d are comma-separated per-sample strings built from
  #         (N - n) and N respectively.
  
  n_samples = dim(n)[1]
  n_pairs = dim(n)[2]
  snv = convert_n_BC(n)
  
  variant = snv$n          # number of variant
  total = snv$N            # number of total
  reference = total - variant   # number of reference
  
  # keep only rows with at least one variant read across samples
  keep = (1:(2 * n_pairs))[rowSums(variant) != 0]
  S = length(keep)
  
  total = matrix(c(total[keep, ]), S, n_samples)
  reference = matrix(c(reference[keep, ]), S, n_samples)
  
  collapse_row = function(x){paste(x, collapse = ",")}
  
  ssm_data = data.frame(id = paste0("s", 0:(S-1)),
                        gene = NA,
                        a = apply(reference, 1, collapse_row),
                        d = apply(total, 1, collapse_row),
                        mu_r = 0.999,
                        mu_v = 0.499)
  return(ssm_data)
  
}

# write.table(ssm_data, file = "ssm_data.txt", sep = "\t", row.names = FALSE, quote = FALSE)


##################################################################################
# Simulation 1, 5 samples, 100 mutation pairs
# 3, 4, 5 subclones, 50, 200, 1000 depth, 9 datasets
##################################################################################

gen_simdata_sim1 = function(scenario = 1){
  # Generate and save the simulated datasets of simulation 1.
  # input: scenario
  #   1 - T = 5, K = 100; trees "011", "0122", "01124" (C = 3, 4, 5)
  #       crossed with N_mean = 50, 200, 1000
  #   2 - tree "0122", K = 100; T = 1, 3, 5 crossed with
  #       N_mean = 50, 200, 1000 (the T samples are the first T rows of
  #       one shared weight matrix)
  #   3 - tree "0122", T = 5, K = 100, N_mean = 200; six (v1, v2) settings
  # output: none; for every dataset the flattened counts are written to
  #   ./data/nsim1_*.txt and the full truth to ./data/simdata1_*.rds.
  # Side effects: runs the external program ./gen_TreeStateMat, which is
  # expected to write ./tmpfiles/TreeStateMat_<tree>.dat, and draws from
  # the random number generator (no seed is set here).
  # In rnbinom(size, prob) the mean is size * (1 - prob) / prob, which
  # equals N_mean for each (size, prob) pair used below.
  
  
  if(scenario == 1){
    T = 5
    K = 100
    
    for(parent_str in c("011", "0122", "01124")){
      output_file = paste0("./tmpfiles/TreeStateMat_", parent_str, ".dat")
      cmd = sprintf("./gen_TreeStateMat %s %s", parent_str, output_file)
      system(cmd) 
      
      
      # TreeCntMat is read here but not used further in this function
      nState = as.numeric(read.table(output_file, skip = 1, nrows = 1))
      TreeStateMat = as.matrix(read.table(output_file, skip = 4, nrows = nState))
      TreeCntMat = as.matrix(read.table(output_file, skip = 6 + nState, nrows = nState))
      
      # one random tree state per mutation pair; + 1 shifts the file's
      # values so they can be used as column indices of A in simulate_n
      Z_index = sample(x = 1:nState, size = K, replace = TRUE)
      Zsim = TreeStateMat[Z_index, ] + 1
      
      C = nchar(parent_str)
      if(C == 3) d_w = c(15, 10, 5)
      if(C == 4) d_w = c(15, 10, 8, 5)
      if(C == 5) d_w = c(15, 10, 8, 5, 3)
      # each sample gets a Dirichlet draw with d_w in random order plus a
      # small last parameter (0.01) for column C + 1
      wsim = matrix(0, T, C+1)
      for(t in 1:T){
        wsim[t, ] = rdirichlet(n = 1, alpha = c(sample(x = d_w, size = C, replace = FALSE), 0.01))
      }
      rhosim = c(rdirichlet(1, c(1,1,1,1)), rdirichlet(1, c(2,2)), rdirichlet(1, c(2,2)))
      vsim = array(0.3, c(T, K, 2))
      
      for(N_mean in c(50, 200, 1000)){
        if(N_mean == 50){
          size = 50
          prob = 1/2
          # sd = 10
        }else if(N_mean == 200){
          size = 200/7
          prob = 1/8
          # sd = 40
        }else if(N_mean == 1000){
          size = 1000/39
          prob = 1/40
          # sd = 200
        }
        
        Nsim = matrix(rnbinom(n = T*K, size = size, prob = prob), T, K)
        
        n_p_sim = simulate_n(Zsim, wsim, rhosim, vsim, Nsim)
        
        psim = n_p_sim$p
        nsim = n_p_sim$n
        
        # flatten the T * K * G count array (column-major) before writing
        nsim = c(nsim)
        nsim_file = sprintf("./data/nsim1_C%d_N%d.txt", C, N_mean)
        write.table(nsim, file = nsim_file, quote = FALSE, sep = " ", row.names = FALSE, col.names = FALSE)
        
        simdata = NULL
        simdata$Z = Zsim
        simdata$w = wsim
        simdata$rho = rhosim
        simdata$v = vsim
        simdata$n = nsim
        simdata$p = psim
        simdata$Tree = parent_str
        
        simdata_file = sprintf("./data/simdata1_C%d_N%d.rds", C, N_mean)
        saveRDS(simdata, file = simdata_file)
      }
      # remove the temporary files
      cmd2 = sprintf("rm %s", output_file)
      system(cmd2) 
      print("Temporary tree state matrix deleted.")
    }
  }else if(scenario == 2){
    
    K = 100
    
    # T = 1, 3, 5; N_mean = 50, 200, 1000
    T_max = 5
    
    parent_str = "0122"
    output_file = paste0("./tmpfiles/TreeStateMat_", parent_str, ".dat")
    cmd = sprintf("./gen_TreeStateMat %s %s", parent_str, output_file)
    system(cmd) 
    
    nState = as.numeric(read.table(output_file, skip = 1, nrows = 1))
    TreeStateMat = as.matrix(read.table(output_file, skip = 4, nrows = nState))
    TreeCntMat = as.matrix(read.table(output_file, skip = 6 + nState, nrows = nState))
      
    Z_index = sample(x = 1:nState, size = K, replace = TRUE)
    Zsim = TreeStateMat[Z_index, ] + 1
      
    C = nchar(parent_str)
    d_w = c(15, 10, 8, 5)
    
    # weights are drawn once for T_max samples; each T below reuses the
    # first T rows, and Zsim / rhosim are shared across all T
    wsim_max = matrix(0, T_max, C+1)
    for(t in 1:T_max){
      wsim_max[t, ] = rdirichlet(n = 1, alpha = c(sample(x = d_w, size = C, replace = FALSE), 0.01))
    }
      
    rhosim = c(rdirichlet(1, c(1,1,1,1)), rdirichlet(1, c(2,2)), rdirichlet(1, c(2,2)))
    
    for(T in c(1, 3, 5)){
    
      # matrix(...) keeps wsim two-dimensional when T == 1
      wsim = matrix(wsim_max[1:T, ], T, C + 1)
      vsim = array(0.3, c(T, K, 2))
      
      for(N_mean in c(50, 200, 1000)){
        if(N_mean == 50){
          size = 50
          prob = 1/2
          # sd = 10
        }else if(N_mean == 200){
          size = 200/7
          prob = 1/8
          # sd = 40
        }else if(N_mean == 1000){
          size = 1000/39
          prob = 1/40
          # sd = 200
        }
        
        Nsim = matrix(rnbinom(n = T*K, size = size, prob = prob), T, K)
        
        n_p_sim = simulate_n(Zsim, wsim, rhosim, vsim, Nsim)
        
        psim = n_p_sim$p
        nsim = n_p_sim$n
        
        nsim = c(nsim)
        nsim_file = sprintf("./data/nsim1_T%d_N%d.txt", T, N_mean)
        write.table(nsim, file = nsim_file, quote = FALSE, sep = " ", row.names = FALSE, col.names = FALSE)
        
        simdata = NULL
        simdata$Z = Zsim
        simdata$w = wsim
        simdata$rho = rhosim
        simdata$v = vsim
        simdata$n = nsim
        simdata$p = psim
        simdata$Tree = parent_str
        
        simdata_file = sprintf("./data/simdata1_T%d_N%d.rds", T, N_mean)
        saveRDS(simdata, file = simdata_file)
      }
    }
    
    # remove the temporary files
    cmd2 = sprintf("rm %s", output_file)
    system(cmd2) 
    print("Temporary tree state matrix deleted.")
    
  }else if(scenario == 3){
    
    K = 100
    T = 5
    
    parent_str = "0122"
    output_file = paste0("./tmpfiles/TreeStateMat_", parent_str, ".dat")
    cmd = sprintf("./gen_TreeStateMat %s %s", parent_str, output_file)
    system(cmd) 
    
    nState = as.numeric(read.table(output_file, skip = 1, nrows = 1))
    TreeStateMat = as.matrix(read.table(output_file, skip = 4, nrows = nState))
    TreeCntMat = as.matrix(read.table(output_file, skip = 6 + nState, nrows = nState))
      
    Z_index = sample(x = 1:nState, size = K, replace = TRUE)
    Zsim = TreeStateMat[Z_index, ] + 1
      
    C = nchar(parent_str)
    d_w = c(15, 10, 8, 5)
    
    wsim = matrix(0, T, C+1)
    for(t in 1:T){
      wsim[t, ] = rdirichlet(n = 1, alpha = c(sample(x = d_w, size = C, replace = FALSE), 0.01))
    }
      
    rhosim = c(rdirichlet(1, c(1,1,1,1)), rdirichlet(1, c(2,2)), rdirichlet(1, c(2,2)))
    
    # paired settings: (v1, v2) = (0, 0), (0.1, 0.1), (0.25, 0.25),
    # (0.4, 0.4), (0.5, 0.5), (0, 1)
    v1_all = c(0, 0.1, 0.25, 0.4, 0.5, 0)
    v2_all = c(0, 0.1, 0.25, 0.4, 0.5, 1)
    
    N_mean = 200
    size = 200/7
    prob = 1/8
    
    for(vv in 1:6){
      vsim = array(0, c(T, K, 2))
      vsim[ , , 1] = v1_all[vv]
      vsim[ , , 2] = v2_all[vv]
      
      Nsim = matrix(rnbinom(n = T*K, size = size, prob = prob), T, K)
      
      n_p_sim = simulate_n(Zsim, wsim, rhosim, vsim, Nsim)
      
      psim = n_p_sim$p
      nsim = n_p_sim$n
      
      nsim = c(nsim)
      # file names carry v1 and v2 as percentages
      nsim_file = sprintf("./data/nsim1_v1%d_v2%d_N%d.txt", v1_all[vv]*100, v2_all[vv]*100, N_mean)
      write.table(nsim, file = nsim_file, quote = FALSE, sep = " ", row.names = FALSE, col.names = FALSE)
      
      simdata = NULL
      simdata$Z = Zsim
      simdata$w = wsim
      simdata$rho = rhosim
      simdata$v = vsim
      simdata$n = nsim
      simdata$p = psim
      simdata$Tree = parent_str
      
      simdata_file = sprintf("./data/simdata1_v1%d_v2%d_N%d.rds", v1_all[vv]*100, v2_all[vv]*100, N_mean)
      saveRDS(simdata, file = simdata_file)
    }
    # NOTE(review): unlike scenarios 1 and 2, this branch never removes
    # output_file - confirm whether that is intended.
  }
}

###########################################################
# Simulation 2 200x depth ??
###########################################################

gen_simdata_sim2 = function(N_mean = 200, rep_rows = 1){
  # Generate and save the dataset of simulation 2 (T = 1, C = 4, tree
  # "0112") from a fixed 100-row Z template.
  # input: N_mean   - mean read depth; depths are negative binomial with
  #                   mean N_mean and sd N_mean / 5
  # input: rep_rows - number of times the 100-row template is stacked,
  #                   so K = 100 * rep_rows mutation pairs
  # output: none; writes ./data/nsim2_N<N_mean>_rep<rep_rows>.txt and
  #         ./data/simdata2_N<N_mean>_rep<rep_rows>.rds, plus a pdf of Z
  #         when rep_rows == 1.
  # Side effects: runs ./gen_TreeStateMat and removes its output file;
  # draws from the random number generator.
  
  K0 = 100
  K = K0 * rep_rows
  T = 1
  C = 4
  
  parent_str = "0112"
  output_file = paste0("./tmpfiles/TreeStateMat_", parent_str, ".dat")
  cmd = sprintf("./gen_TreeStateMat %s %s", parent_str, output_file)
  system(cmd) 
      
      
  # Only needed by the commented-out random-Z alternative below; kept so
  # the function still fails early if the tree file is missing.
  nState = as.numeric(read.table(output_file, skip = 1, nrows = 1))
  TreeStateMat = as.matrix(read.table(output_file, skip = 4, nrows = nState))
  TreeCntMat = as.matrix(read.table(output_file, skip = 6 + nState, nrows = nState))
  
  # The template always has K0 = 100 rows (4 columns of 100 values).
  # Building it with K rows recycled the 400 values for rep_rows > 1 and
  # produced a Zsim with K0 * rep_rows^2 rows that no longer matched
  # vsim and Nsim.
  Zsim0 = matrix(c(rep(1, 100), 
                  rep(2, 20), rep(3, 20), rep(1, 60),
                  rep(1, 30), rep(2, 10), rep(3, 20), rep(2, 10), rep(1, 30),
                  rep(2, 10), rep(4, 20), rep(3, 10), 
                  rep(1, 20), rep(3, 20), rep(2, 10), rep(1, 10)), K0, C)
      
  # Z_index = sample(x = 1:nState, size = K, replace = TRUE)
  # Zsim0 = TreeStateMat[Z_index, ] + 1
  # rowInd = heatmap_order_2(Zsim0)
  # Zsim0 = Zsim0[rowInd, ]
  
  
  
  # stack the template rep_rows times (K rows in total)
  Zsim = matrix(rep(t(Zsim0), rep_rows) , ncol =  ncol(Zsim0) , byrow = TRUE)
  
  if(rep_rows == 1){
    Z_plot_file = sprintf("./data/Zsim2_N%d_rep%d.pdf", N_mean, rep_rows)
    pdf(Z_plot_file)
    PairTree_plot_Z(Zsim)
    dev.off()
  }
  
  wsim = matrix(0, T, C+1)
  for(t in 1:T){
    wsim[t, ] = rdirichlet(n = 1, alpha = c(sample(x = c(15, 10, 8, 5), 
                                                   size = 4, replace = FALSE), 0.01))
  }
  
  rhosim = c(rdirichlet(1, c(1,1,1,1)), rdirichlet(1, c(2,2)), rdirichlet(1, c(2,2)))
  
  
  vsim = array(0.3, c(T, K, 2))
  # vsim[ , 51:100, ] = 0.25
  
  # mean = N_mean, sd = N_mean/5
  Nsim = matrix(rnbinom(n = T*K, size = 25*N_mean/(N_mean-25), prob = 25/N_mean), T, K)

  n_p_sim = simulate_n(Zsim, wsim, rhosim, vsim, Nsim)
  
  
  psim = n_p_sim$p
  nsim = n_p_sim$n
  
  # flatten the T * K * G count array before writing
  nsim = c(nsim)
  nsim_file = sprintf("./data/nsim2_N%d_rep%d.txt", N_mean, rep_rows)
  write.table(nsim, file = nsim_file, quote = FALSE, sep = " ", row.names = FALSE, col.names = FALSE)
  
  
  
  simdata = list()
  simdata$Z = Zsim
  simdata$w = wsim
  simdata$rho = rhosim
  simdata$v = vsim
  simdata$n = nsim
  simdata$p = psim
  simdata$Tree = parent_str
  
  simdata_file = sprintf("./data/simdata2_N%d_rep%d.rds", N_mean, rep_rows)
  saveRDS(simdata, file = simdata_file)
  
  # remove the temporary tree state file
  cmd2 = sprintf("rm %s", output_file)
  system(cmd2) 
  print("Temporary tree state matrix deleted.")
}


###########################################################
# Simulation 3
###########################################################

gen_simdata_sim3 = function(N_mean = 200){
  # Generate and save the dataset of simulation 3: K = 100 pairs,
  # T = 8 samples, C = 5 columns of Z, tree label "01233".
  # input: N_mean - mean read depth (negative binomial, sd = N_mean / 5)
  # output: none; writes ./data/nsim3.txt and ./data/simdata3.rds

  n_pairs = 100
  n_samples = 8
  n_clones = 5
  
  # fixed truth for Z, specified column by column (100 values each)
  z_columns = list(
    rep(1, 100),
    c(rep(2, 40), rep(3, 20), rep(1, 40)),
    c(rep(4, 20), rep(6, 20), rep(4, 20), rep(2, 20), rep(1, 20)),
    c(rep(8, 20), rep(6, 20), rep(4, 20), rep(5, 20), rep(1, 20)),
    c(rep(4, 20), rep(9, 20), rep(8, 20), rep(2, 20), rep(3, 20))
  )
  Zsim = matrix(unlist(z_columns), n_pairs, n_clones)
  
  # per-sample weights: Dirichlet with shuffled parameters and a small
  # last parameter for column C + 1
  wsim = matrix(0, n_samples, n_clones + 1)
  for(t in 1:n_samples){
    shuffled = sample(x = c(25, 15, 10, 8, 5), size = 5, replace = FALSE)
    wsim[t, ] = rdirichlet(n = 1, alpha = c(shuffled, 0.01))
  }
  
  rhosim = c(rdirichlet(1, c(1,1,1,1)), rdirichlet(1, c(2,2)), rdirichlet(1, c(2,2)))
  
  vsim = array(0.3, c(n_samples, n_pairs, 2))
  
  # mean = N_mean, sd = N_mean/5
  nb_size = 25*N_mean/(N_mean-25)
  nb_prob = 25/N_mean
  Nsim = matrix(rnbinom(n = n_samples*n_pairs, size = nb_size, prob = nb_prob), 
                n_samples, n_pairs)
  
  sim = simulate_n(Zsim, wsim, rhosim, vsim, Nsim)
  
  # counts are stored flattened
  nsim = c(sim$n)
  write.table(nsim, file = "./data/nsim3.txt", quote = FALSE, sep = " ", 
              row.names = FALSE, col.names = FALSE)
  
  simdata = list(Z = Zsim, w = wsim, rho = rhosim, v = vsim, 
                 n = nsim, p = sim$p, Tree = "01233")
  
  saveRDS(simdata, file = "./data/simdata3.rds")

}










gen_simdata_real = function(real_data_file = "./data/Lung_data.rds",
                            point_est_real_file = "./results/point_est_lung_run2.rds"){
  # Simulate a dataset that mimics a real one: Z and w come from a saved
  # point estimate, v from calc_emp_p_v() on the observed counts, and
  # the read depths are the observed totals per (t, k).
  # input: real_data_file      - rds file with element n (count array)
  # input: point_est_real_file - rds file with elements Z and w
  # output: none; writes ./data/nsim4.txt and ./data/simdata4.rds
  # The stored tree label is fixed to "012215".
  
  n_obs = readRDS(real_data_file)$n
  
  point_est_real = readRDS(point_est_real_file)
  Zsim = point_est_real$Z
  wsim = point_est_real$w
  
  rhosim = c(rdirichlet(1, c(1,1,1,1)), rdirichlet(1, c(2,2)), rdirichlet(1, c(2,2)))
  
  vsim = calc_emp_p_v(n_obs)$v
  
  # observed total reads for every sample / pair
  Nsim = apply(n_obs, c(1, 2), sum)
  
  sim = simulate_n(Zsim, wsim, rhosim, vsim, Nsim)
  
  # counts are stored flattened
  nsim = c(sim$n)
  write.table(nsim, file = "./data/nsim4.txt", quote = FALSE, sep = " ", 
              row.names = FALSE, col.names = FALSE)
  
  simdata = list(Z = Zsim, w = wsim, rho = rhosim, v = vsim, 
                 n = nsim, p = sim$p, Tree = "012215")
  
  saveRDS(simdata, file = "./data/simdata4.rds")
  
}




##############

theta_to_w = function(theta){
  # Divide every row of theta by its row sum, so each row sums to 1.
  sweep(theta, MARGIN = 1, STATS = rowSums(theta), FUN = "/")
}








####################################################################################
# Convergence Diagnostics
####################################################################################

# Gelman-Rubin statistics
calc_gelman_rubin = function(result_file_all, suffix = ""){
  # Gelman-Rubin diagnostic on the log posterior, restricted within each
  # chain to the draws whose sampled tree equals that chain's Mode().
  # input: result_file_all - rds files, one per chain, each holding
  #                          Spl_Tree and Result$logpost
  # input: suffix          - tag used in the trace plot file name
  # output: value of gelman.diag(); also writes ./traceplot_<suffix>.pdf
  # Chains are randomly subsampled (order kept) to the shortest length.
  
  n_chains = length(result_file_all)
  modal_logpost = list()
  n_modal = rep(0, n_chains)
  
  for(i in 1:n_chains){
    chain = readRDS(result_file_all[i])
    tree_trace = chain$Spl_Tree
    
    at_mode = which(tree_trace == Mode(tree_trace))
    
    n_modal[i] = length(at_mode)
    modal_logpost[[i]] = chain$Result$logpost[at_mode]
  } 
  
  shortest = min(n_modal)
  
  # thin every chain down to `shortest` draws, preserving their order
  for(i in 1:n_chains){
    kept = sort(sample(1:n_modal[i], size = shortest))
    modal_logpost[[i]] = mcmc(modal_logpost[[i]][kept])
  }
  
  chains = mcmc.list(modal_logpost)
  
  pdf(sprintf("./traceplot_%s.pdf", suffix))
  traceplot(chains)
  dev.off()
  
  return(gelman.diag(chains))
  
}

calc_gelman_rubin2 = function(result_file_all){
  # Gelman-Rubin diagnostic on the full log-posterior trace of every
  # chain (no restriction to the modal tree, no subsampling).
  # input: result_file_all - rds files, one per chain, each holding
  #                          Result$logpost
  # output: value of gelman.diag()
  
  n_chains = length(result_file_all)
  
  logpost_chains = lapply(1:n_chains, function(i){
    mcmc(readRDS(result_file_all[i])$Result$logpost)
  })
  
  return(gelman.diag(mcmc.list(logpost_chains)))
  
}

# Z_err and w_err. p_err??
calc_Z_w_err = function(point_est_file, simdata_file, type = "pair"){
  # Error of a point estimate against the simulation truth, minimised
  # over relabellings of columns 2..C (column 1 is kept fixed).
  # input: point_est_file - rds file with elements Z (K * C) and w
  # input: simdata_file   - rds file with the true Z (K * C) and
  #                         w (T * (C + 1))
  # input: type           - "SNV" compares only rows 2k - 1 of
  #                         convert_Z_BC(); any other value compares Z
  #                         as given
  # output: list(Z_err, w_err)
  #   Z_err - smallest fraction of mismatching entries, divided by K * (C - 1)
  #   w_err - mean absolute difference of the first C columns of w under
  #           the permutation that minimises Z_err
  # Requires C >= 2.
  
  point_est = readRDS(point_est_file)
  Zhat = point_est$Z
  what = point_est$w
  
  simdata = readRDS(simdata_file)
  Z = simdata$Z
  w = simdata$w
  
  K = dim(Z)[1]
  C = dim(Z)[2]
  T = dim(w)[1]
  
  if(C < 2) stop("calc_Z_w_err needs at least two columns in Z (C >= 2)")
  
  if(type == "SNV"){
    # drop = FALSE keeps the matrices two-dimensional when K == 1
    first_rows = 2*(1:K)-1
    Zhat = convert_Z_BC(Zhat)[first_rows, , drop = FALSE]
    Z = convert_Z_BC(Z)[first_rows, , drop = FALSE]
  }
  
  # permn() expands a single number x into 1:x, so permn(2:C) for C == 2
  # would permute c(1, 2) instead of the single label 2. Handle that
  # case explicitly.
  if(C == 2){
    permn_C = list(2)
  } else {
    permn_C = permn(2:C)
  }
  n_permn_C = length(permn_C)
  
  Z_err_permn = rep(0, n_permn_C)
  
  for(pc in 1:n_permn_C){
    permn_pc = permn_C[[pc]]
    Zhat_permn = Zhat[ , c(1, permn_pc)]
    Z_err_permn[pc] = sum((Zhat_permn - Z)!= 0)/(K * (C-1))
  }
  index_pc = which.min(Z_err_permn)
  permn_min = permn_C[[index_pc]]
  
  Z_err = min(Z_err_permn)
  w_err = sum(abs(what[ , c(1, permn_min)] - w[ , 1:C]))/(T*C)
  
  
  return(list(Z_err = Z_err, w_err = w_err))
}



# frequentist coverage rate of p_tkg
calc_coverage_rate = function(result_file_all, simdata_file, burnin = 0){
  # Coverage of the true p_tilde by 95% posterior intervals, per chain.
  # input: result_file_all - rds files, one per chain, each holding
  #                          Spl_Tree and Result (lists Z, theta, rho_star
  #                          indexed by iteration)
  # input: simdata_file    - rds file with the true Z, w and rho
  # input: burnin          - number of leading iterations to discard
  # output: numeric vector, one coverage proportion per chain
  # Relies on the compiled routine "calc_p_tilde" and on Mode(), both
  # defined outside this file section.
  
  simdata = readRDS(simdata_file)
  # parameters
  # Z - 1: shift the stored states down by one before handing them to
  # the compiled routine (presumably 0-based indexing there - confirm)
  Z = simdata$Z - 1
  w = simdata$w
  rho = simdata$rho
  K = dim(Z)[1]
  C = dim(Z)[2]
  T = dim(w)[1]
  G = 8
  p_tilde = rep(0, T*K*G)
  
  # true p_tilde: the true w is passed in the `theta` slot and the true
  # rho in the `rho_star` slot
  p_tilde_out = .C("calc_p_tilde", p_tilde = as.double(p_tilde),
                        Z = as.integer(Z),
                        theta = as.double(w),
                        rho_star = as.double(rho),
                        C = as.integer(C), T = as.integer(T), 
                        K = as.integer(K), G = as.integer(G))
  
  p_tilde = p_tilde_out$p_tilde
  
  # for these indexes, truth might be 0 but due to the noise rho, 
  # inference might give CI [0.000x, 0.000x]. shouldn't take these into account.
  index_not0 = (1:(T*K*G))[(p_tilde > 0.0001) & (p_tilde < 0.9999)]
  
  n_chains = length(result_file_all)
  coverage = rep(0, n_chains)
  
  for(i in 1:n_chains){
    result_i = readRDS(result_file_all[i])
    
    Spl_Tree = result_i$Spl_Tree
    Result = result_i$Result
    
    niter = length(Spl_Tree)
    
    # drop the burn-in; `index` below is relative to this shortened
    # vector, hence the `burnin +` offset when indexing Result
    Spl_Tree = Spl_Tree[(burnin+1):niter]
    
    TreeMode = Mode(Spl_Tree)
    
    # keep only the draws whose tree equals the chain's Mode()
    index = which(Spl_Tree == TreeMode)
    n_index = length(index)
    
    # not used below
    logpost = rep(0, n_index)
    
    # one column of p_tilde per retained draw
    p_tilde_spls = matrix(0, T*K*G, n_index) 
    
    for(i2 in 1:n_index){
      Z_i2 = Result$Z[[burnin + index[i2]]]
      theta_i2 = Result$theta[[burnin + index[i2]]]
      rho_star_i2 = Result$rho_star[[burnin + index[i2]]]
      
      p_tilde_out_i2 = .C("calc_p_tilde", p_tilde = as.double(p_tilde_spls[ , i2]),
                        Z = as.integer(Z_i2),
                        theta = as.double(theta_i2),
                        rho_star = as.double(rho_star_i2),
                        C = as.integer(C), T = as.integer(T), 
                        K = as.integer(K), G = as.integer(G))
      
      p_tilde_spls[ , i2] = p_tilde_out_i2$p_tilde
      
    }
    # element-wise 2.5% and 97.5% quantiles (rows 1 and 2 of bounds);
    # coverage = share of the retained entries whose truth lies inside
    bounds = apply(p_tilde_spls, 1, function(x){quantile(x, probs = c(0.025, 0.975))})
    coverage[i] = sum(((p_tilde >= bounds[1, ]) & (p_tilde <= bounds[2, ]))[index_not0]) / length(index_not0)
  }
  
  return(coverage)
}


# frequentist coverage rate
calc_coverage_rate2 = function(result_file_all, simdata_file, hp, burnin = 0){
  # For every chain, check whether the log posterior evaluated at the
  # true parameters lies inside the central 95% interval of the sampled
  # log posterior (draws of the most frequent tree only).
  # input: result_file_all - rds files, one per chain, each holding
  #                          Spl_Tree and Result$logpost
  # input: simdata_file    - rds file with the true n, Z, w, rho and Tree
  # input: hp              - hyperparameters; d, d0 and d1 are used here,
  #                          and hp is also passed on to
  #                          calc_logprior_tree_C_all()
  # input: burnin          - number of leading iterations to discard
  # output: 0/1 vector, one entry per chain (1 = truth covered)
  # Relies on the compiled routines "calc_count" and "calc_logpost".
  
  d = hp$d
  d0 = hp$d0
  d1 = hp$d1
  
  simdata = readRDS(simdata_file)
  # data
  n = simdata$n
  # parameters
  # Z - 1: shifted down by one before being passed to the C routines
  Z = simdata$Z - 1
  w = simdata$w
  rho = simdata$rho
  Tree = simdata$Tree  # let's save this
  K = dim(Z)[1]
  C = dim(Z)[2]
  T = dim(w)[1]
  G = 8
  
  # one digit of the tree string per entry
  parent_arr = as.numeric(strsplit(Tree, split = "")[[1]])
  
  count_out = .C("calc_count", count = as.integer(rep(0, C)),
             Z = as.integer(Z),
             parent_arr = as.integer(parent_arr),
             C = as.integer(C),
             K = as.integer(K))
  
  count = count_out$count
  
  
  # NOTE(review): lambda is set to 2 * K / C here rather than read from
  # hp - confirm this matches the value used when the chains were run.
  logpost_out = .C("calc_logpost", logpost = as.double(0),
                        Z = as.integer(Z),
                        count = as.integer(count),
                        theta = as.double(w),
                        rho_star = as.double(rho),
                        n = as.double(c(n)),
                        C = as.integer(C), T = as.integer(T), 
                        K = as.integer(K), G = as.integer(G),
                        lambda = as.double(2*K/C), d = as.double(d), 
                        d0 = as.double(d0), d1 = as.double(d1),
                        a_p = as.double(d), b_p = as.double(d0+(C-1)*d))
    
  logpriorTree = calc_logprior_tree_C_all(c(Tree), hp)
  
  logpost_truth = logpost_out$logpost + logpriorTree
    
  n_chains = length(result_file_all)
  coverage = rep(0, n_chains)
  
  for(i in 1:n_chains){
    result_i = readRDS(result_file_all[i])
    
    Spl_Tree = result_i$Spl_Tree
    Result = result_i$Result
    
    niter = length(Spl_Tree)
    
    # drop the burn-in; `index` below is relative to the shortened vector
    Spl_Tree = Spl_Tree[(burnin+1):niter]
    
    # Most frequent tree index. which.max() returns a single value even
    # when several trees tie for the maximum; which(tab == max(tab))
    # returned all of them and made the comparison below recycle.
    TreeMode = which.max(tabulate(Spl_Tree))
    index = which(Spl_Tree == TreeMode)
    
    logpost = vapply(index, function(ix){
      Result$logpost[[burnin + ix]]
    }, numeric(1))
    
    bounds = quantile(logpost, probs = c(0.025, 0.975))
    if((logpost_truth >= bounds[1]) && (logpost_truth <= bounds[2])){
      coverage[i] = 1
    }
  }
  
  return(coverage)
}



####################################################################################
# Generate a shell script that launches all simulation runs in the background
####################################################################################

gen_nohup_script = function(){
  # Write ./script.sh with one background "Rscript TreeClone_main.R ..."
  # line per simulation-1 run: scenario 1 (3 MCMC runs x C = 3..5 x
  # three depths), scenario 2 (T = 1, 3, 5 x three depths) and
  # scenario 3 (six (v1, v2) settings). The commands are also printed.

  # one command line: input file, suffix, T, K, then redirect to the log
  make_cmd = function(input_file, suffix, log_file, T, K){
    paste("Rscript", "TreeClone_main.R", input_file, suffix, T, K, ">", log_file, "&", sep = " ")
  }
  
  depths = c(50, 200, 1000)
  all_cmds = NULL
  
  # scenario 1
  for(MCMC_run in 1:3){
    for(C in 3:5){
      for(N_mean in depths){
        all_cmds = c(all_cmds, make_cmd(
          input_file = paste0("./data/nsim1_C", C, "_N", N_mean, ".txt"),
          suffix = paste0("sim1_C", C, "_N", N_mean, "_run", MCMC_run),
          log_file = paste0("./LOGS/log_C", C, "_N", N_mean, "_run", MCMC_run, ".out"),
          T = 5, K = 100))
      }
    }
  }
  
  # scenario 2
  for(T in c(1, 3, 5)){
    for(N_mean in depths){
      all_cmds = c(all_cmds, make_cmd(
        input_file = paste0("./data/nsim1_T", T, "_N", N_mean, ".txt"),
        suffix = paste0("sim1_T", T, "_N", N_mean),
        log_file = paste0("./LOGS/log_T", T, "_N", N_mean, ".out"),
        T = T, K = 100))
    }
  }
  
  # scenario 3: file names carry v1 and v2 as percentages
  v1_pct = c(0, 0.1, 0.25, 0.4, 0.5, 0) * 100
  v2_pct = c(0, 0.1, 0.25, 0.4, 0.5, 1) * 100
  N_mean = 200
  for(vv in 1:6){
    all_cmds = c(all_cmds, make_cmd(
      input_file = paste0("./data/nsim1_v1", v1_pct[vv], "_v2", v2_pct[vv], "_N", N_mean, ".txt"),
      suffix = paste0("sim1_v1", v1_pct[vv], "_v2", v2_pct[vv], "_N", N_mean),
      log_file = paste0("./LOGS/log_v1", v1_pct[vv], "_v2", v2_pct[vv], "_N", N_mean, ".out"),
      T = 5, K = 100))
  }
  
  print(all_cmds)
  
  script_con = file("./script.sh")
  writeLines(all_cmds, script_con)
  close(script_con)
}
  
  
print_sim1_results = function(parent_str_all, scenario = 1){
  # Print the error summaries of the simulation study "sim1".
  #
  # input: parent_str_all - vector of tree strings; the name of the largest
  #        entry of point_est$Tree_prob is converted to a number and used as
  #        an index into this vector (only scenarios 2 and 3 touch it)
  # input: scenario - 1: C in 3:5 crossed with N_mean in (50, 200, 1000),
  #                      three MCMC runs per cell
  #                   2: T in (1, 3, 5) crossed with N_mean in (50, 200, 1000)
  #                   3: six (v1, v2) settings with N_mean = 200
  #        Any other value prints nothing.
  # output: the tables are printed; input files are read from ./data and
  #         ./results (paths are relative to the working directory)
  
  N_mean_all = c(50, 200, 1000)
  
  # Errors of a single fit (shared by scenarios 2 and 3).
  # C_err / Tree_err are 1 when the number of characters of the estimated
  # tree string / the estimated tree string itself differs from the simulated
  # one; Z_err and w_err are taken from calc_Z_w_err.
  eval_one_fit = function(point_est_file, simdata_file){
    point_est = readRDS(point_est_file)
    Tree_hat = parent_str_all[as.numeric(names(which.max(point_est$Tree_prob)))]
    C_hat = nchar(Tree_hat)
    Tree = readRDS(simdata_file)$Tree
    C = nchar(Tree)
    
    fit = list(C_err = 0, Tree_err = 0)
    if(C != C_hat){
      fit$C_err = 1
    }
    if(Tree != Tree_hat){
      fit$Tree_err = 1
    }
    
    Z_w_err = calc_Z_w_err(point_est_file, simdata_file)
    fit$Z_err = Z_w_err$Z_err
    fit$w_err = Z_w_err$w_err
    return(fit)
  }
  
  if(scenario == 1){
    
    C_all = 3:5
    runs = 1:3
    
    Z_err_all = array(0, c(3, 3, 3))
    w_err_all = array(0, c(3, 3, 3))
    gelman_rubin_stat_all = matrix(0, 3, 3)
    # the coverage computation below is disabled, so this table stays all 0
    coverage_all = matrix(0, 3, 3)
    
    for(i in seq_along(C_all)){
      for(j in seq_along(N_mean_all)){
        setting = paste0("C", C_all[i], "_N", N_mean_all[j])
        simdata_file = paste0("./data/simdata1_", setting, ".rds")
        # one file per MCMC run
        MCMCspls_files = paste0("./results/MCMCspls_sim1_", setting, "_run", runs, ".rds")
        point_est_files = paste0("./results/point_est_sim1_", setting, "_run", runs, ".rds")
        
        for(k in runs){
          Z_w_err = calc_Z_w_err(point_est_files[k], simdata_file)
          Z_err_all[i, j, k] = Z_w_err$Z_err
          w_err_all[i, j, k] = Z_w_err$w_err
        }
        
        # first psrf entry computed from the three runs of this setting
        gelman_rubin_stat_all[i, j] = calc_gelman_rubin2(MCMCspls_files)$psrf[1]
        
        #coverage_all[i, j] = mean(calc_coverage_rate(MCMCspls_files, simdata_file, burnin = 0))
      }
    }
    
    C_N_names = list(paste0("C = ", C_all), paste0("N_mean = ", N_mean_all))
    
    # average the errors over the three runs
    Z_err = apply(Z_err_all, c(1,2), mean)
    w_err = apply(w_err_all, c(1,2), mean)
    dimnames(Z_err) = C_N_names
    dimnames(w_err) = C_N_names
    print(Z_err)
    print(w_err)
    
    dimnames(gelman_rubin_stat_all) = C_N_names
    print(gelman_rubin_stat_all)
    
    dimnames(coverage_all) = C_N_names
    print(coverage_all)
    
  }else if(scenario == 2){
    
    T_all = c(1, 3, 5)
    
    C_err = array(0, c(3, 3))
    Tree_err = array(0, c(3, 3))
    Z_err = array(0, c(3, 3))
    w_err = array(0, c(3, 3))
    
    for(i in seq_along(T_all)){
      for(j in seq_along(N_mean_all)){
        file_tail = paste0("_T", T_all[i], "_N", N_mean_all[j], ".rds")
        fit = eval_one_fit(paste0("./results/point_est_sim1", file_tail),
                           paste0("./data/simdata1", file_tail))
        
        C_err[i, j] = fit$C_err
        Tree_err[i, j] = fit$Tree_err
        Z_err[i, j] = fit$Z_err
        w_err[i, j] = fit$w_err
      }
    }
    
    Z_err = round(Z_err, digits = 2)
    w_err = round(w_err, digits = 2)
    
    # one cell per setting: "C_err, Tree_err, Z_err, w_err"
    all_err = matrix(paste(C_err, Tree_err, Z_err, w_err, sep = ", "), 3, 3)
    
    rownames(all_err) = paste0("T = ", T_all)
    colnames(all_err) = paste0("N_mean = ", N_mean_all)
    print(all_err)
    
  }else if(scenario == 3){
    
    N_mean = 200
    v1_all = c(0, 0.1, 0.25, 0.4, 0.5, 0)
    v2_all = c(0, 0.1, 0.25, 0.4, 0.5, 1)
    
    # common tail of the data / result file names of one (v1, v2) setting
    v_file_tail = function(v1, v2){
      paste0("_v1", v1*100, "_v2", v2*100, "_N", N_mean, ".rds")
    }
    
    C_err = rep(0, 6)
    Tree_err = rep(0, 6)
    Z_err = rep(0, 6)
    w_err = rep(0, 6)
    
    for(i in seq_along(v1_all)){
      file_tail = v_file_tail(v1_all[i], v2_all[i])
      fit = eval_one_fit(paste0("./results/point_est_sim1", file_tail),
                         paste0("./data/simdata1", file_tail))
      
      C_err[i] = fit$C_err
      Tree_err[i] = fit$Tree_err
      Z_err[i] = fit$Z_err
      w_err[i] = fit$w_err
    }
    
    Z_err = round(Z_err, digits = 2)
    w_err = round(w_err, digits = 2)
    
    all_err = paste(C_err, Tree_err, Z_err, w_err, sep = ", ")
    print(all_err)
    
    # the last setting (v1 = 0, v2 = 1) is evaluated once more with type = "SNV"
    file_tail = v_file_tail(v1_all[6], v2_all[6])
    Z_w_err = calc_Z_w_err(paste0("./results/point_est_sim1", file_tail),
                           paste0("./data/simdata1", file_tail), type = "SNV")
    Z_err_SNV = round(Z_w_err$Z_err, digits = 2)
    print(paste("Re-defined Z_err: ", Z_err_SNV))
  }
}

print_est_trees = function(point_est_file_all){
  # For every file in point_est_file_all, print the name(s) of the entries of
  # Tree_prob that attain the maximum (all of them in case of ties).
  # input: point_est_file_all - paths of rds files holding a Tree_prob element
  for(file_id in seq_along(point_est_file_all)){
    tree_prob = readRDS(point_est_file_all[[file_id]])$Tree_prob
    is_top = tree_prob == max(tree_prob)
    print(names(which(is_top)))
  }
}


print_Z_w_err = function(){
  # Print the Z and w error tables of simulation scenario 1: one row per
  # C in 3:5, one column per N_mean in (50, 200, 1000), each cell being the
  # mean over the three MCMC runs of the value returned by calc_Z_w_err.
  
  C_all = 3:5
  N_mean_all = c(50, 200, 1000)
  runs = 1:3
  
  Z_err_all = array(0, c(3, 3, 3))
  w_err_all = array(0, c(3, 3, 3))
  
  for(c_id in seq_along(C_all)){
    for(n_id in seq_along(N_mean_all)){
      setting = paste0("C", C_all[c_id], "_N", N_mean_all[n_id])
      truth_file = paste0("./data/simdata1_", setting, ".rds")
      
      for(run in runs){
        est_file = paste0("./results/point_est_sim1_", setting, "_run", run, ".rds")
        errs = calc_Z_w_err(est_file, truth_file)
        
        Z_err_all[c_id, n_id, run] = errs$Z_err
        w_err_all[c_id, n_id, run] = errs$w_err
      }
    }
  }
  
  # collapse the run dimension and label both tables the same way
  table_names = list(paste0("C = ", C_all), paste0("N_mean = ", N_mean_all))
  Z_err = apply(Z_err_all, c(1,2), mean)
  w_err = apply(w_err_all, c(1,2), mean)
  dimnames(Z_err) = table_names
  dimnames(w_err) = table_names
  
  print(Z_err)
  print(w_err)
}


###################################################################
# plot real data results
###################################################################
plot_result_real = function(point_est_file = "./results/point_est_lung_run2.rds"){
  # Draw the Z and w stored in point_est_file into two pdf files under
  # ./results (Zhat_lung.pdf and what_lung.pdf).
  # input: point_est_file - rds file with elements Z and w
  est = readRDS(point_est_file)
  Z_est = est$Z
  w_est = est$w
  
  # rows of Z are rearranged with the index vector returned by heatmap_order_2
  Z_sorted = Z_est[heatmap_order_2(Z_est), ]
  
  pdf("./results/Zhat_lung.pdf")
  PairTree_plot_Z(Z_sorted)
  dev.off()
  
  # only the first six columns of w are drawn
  pdf("./results/what_lung.pdf")
  PairTree_plot_w(w_est[ , 1:6], 0.75)
  dev.off()
}


calc_johnson_chisquare = function(T, K, n_file = "./data/n_lung.txt", 
                           MCMCspls_file = "./results/MCMCspls_lung_run2.rds"){
  # Goodness-of-fit check on the read counts: for every MCMC sample whose tree
  # equals Mode(Spl_Tree) (presumably the most frequently sampled tree), build
  # the G = 8 category probabilities and compute a chi-square type statistic
  # R^B, which is compared with a chi-square(G - 1) reference distribution.
  #
  # input: T, K - first two dimensions of the count array
  # input: n_file - text file with the T * K * G counts (read with scan)
  # input: MCMCspls_file - rds file with elements Spl_Tree and Result
  #        (Result$Z, Result$theta, Result$rho_star indexed by iteration)
  # output: R_B - vector with one statistic per retained MCMC sample
  # side effects: writes ./results/chi_square_qqplot_lung.pdf and
  #        ./results/pdiff_lung.pdf, and prints the fraction of R_B values
  #        above the 95% chi-square quantile
  # NOTE(review): needs calc_emp_p_v and Mode from elsewhere in the project
  #        and the compiled routine "calc_p_tilde" to be loaded - confirm.
  
  G = 8
  n = array(scan(n_file), c(T, K, G))
  MCMCspls = readRDS(MCMCspls_file)
  
  emp_p_v = calc_emp_p_v(n)
  emp_p = emp_p_v$p
  emp_v = emp_p_v$v
  
  Spl_Tree = MCMCspls$Spl_Tree
  Result = MCMCspls$Result
  
  # keep only the iterations that visited the modal tree
  index = which(Spl_Tree == Mode(Spl_Tree))
  n_index = length(index)
  if(n_index == 0){
    stop("no MCMC sample matches the modal tree")
  }
  
  # Weights of the three read types, flattened to length T * K. Multiplying a
  # T x K x g slab by such a vector recycles it over the third dimension, and
  # (unlike apply over dimension 3 of a dropped subset) also works when T or K
  # equals 1. They do not depend on the MCMC sample, so compute them once.
  w_complete = as.vector(1 - emp_v[ , , 1] - emp_v[ , , 2])
  w_left_mis = as.vector(emp_v[ , , 1])
  w_right_mis = as.vector(emp_v[ , , 2])
  
  p_spls = array(0, c(T, K, G, n_index))
  
  for(i in seq_len(n_index)){
    Z = Result$Z[[index[i]]]
    theta = Result$theta[[index[i]]]
    rho_star = Result$rho_star[[index[i]]]
    C = dim(Z)[2]
    
    # the compiled routine fills p_tilde (length T * K * G)
    p_tilde_out = .C("calc_p_tilde", p_tilde = as.double(rep(0, T*K*G)),
                 Z = as.integer(Z),
                 theta = as.double(theta),
                 rho_star = as.double(rho_star),
                 C = as.integer(C), T = as.integer(T), 
                 K = as.integer(K), G = as.integer(G))
    
    p_tilde = array(p_tilde_out$p_tilde, c(T, K, G))
    
    # categories 1-4, 5-6 and 7-8 are scaled by the complete / left-missing /
    # right-missing weights respectively
    p = array(0, c(T, K, G))
    p[ , , 1:4] = p_tilde[ , , 1:4, drop = FALSE] * w_complete
    p[ , , 5:6] = p_tilde[ , , 5:6, drop = FALSE] * w_left_mis
    p[ , , 7:8] = p_tilde[ , , 7:8, drop = FALSE] * w_right_mis
    
    p_spls[ , , , i] = p 
  }
  
  # per-cell 0.5% and 99.5% quantiles over the retained samples
  p_spls_lower = apply(p_spls, c(1, 2, 3), quantile, probs = 0.005)
  p_spls_upper = apply(p_spls, c(1, 2, 3), quantile, probs = 0.995)
  
  est_p = apply(p_spls, c(1, 2, 3), mean)
  
  #p_index = emp_p < 0.2
  # cells whose empirical proportion lies strictly inside that interval
  p_index = (emp_p > p_spls_lower) & (emp_p < p_spls_upper)
  
  n_subset = array(0, c(T, K, G))
  n_subset[p_index] = n[p_index]
  
  # observed counts per category (selected cells only) and totals per (t, k)
  m_G = apply(n_subset, 3, sum)
  N_TK = as.vector(apply(n, c(1, 2), sum))
  
  R_B = rep(0, n_index)
  
  for(i in seq_len(n_index)){
    
    # re-shape explicitly so a dimension of size 1 is not dropped
    p_i = array(p_spls[ , , , i], c(T, K, G))
    p_subset = array(0, c(T, K, G))
    p_subset[p_index] = p_i[p_index]
    
    # expected counts per category: sum over (t, k) of N_tk * p_tkg
    N_q_G = apply(p_subset * N_TK, 3, sum)
    
    R_B[i] = sum((m_G - N_q_G)^2 / N_q_G)
    
  }
  
  # degrees of freedom of the reference chi-square distribution (7 for G = 8)
  df_ref = G - 1
  
  pdf("./results/chi_square_qqplot_lung.pdf")
  par(mar = c(4.5, 5, 1, 1))
  qqplot(qchisq(ppoints(500), df = df_ref), R_B, xlab = "Expected order statistics", ylab = expression(paste("Sorted ", R^B, " values")), cex.axis = 1.6, cex.lab = 1.6)
  abline(0, 1)
  dev.off()
  
  pdf("./results/pdiff_lung.pdf")
  par(mar = c(4.5, 5, 1, 1))
  hist(est_p - emp_p, xlab = "", main = "", cex.axis = 1.6, cex.lab = 1.6, breaks = 40)
  dev.off()
  
  print(sprintf("prob of test statistic > chi-square 95%% quantile = %f", sum(R_B > qchisq(p = 0.95, df = df_ref))/length(R_B)))
  
  return(R_B)
}



