
##########################################################################
# Source
##########################################################################
require(coda)
dyn.load("PairTree_MCMC_R.so")

# Return the most frequent value of x.
# When several values share the highest count, the one that occurs first in x
# is returned.
Mode = function(x) {
  distinct_vals = unique(x)
  # Map each element to the position of its value in distinct_vals and count
  # how many elements fall on each position.
  freq = tabulate(match(x, distinct_vals))
  distinct_vals[which.max(freq)]
}

##########################################################################
# Run PT for all possible trees
##########################################################################
# Run the external parallel-tempering sampler once for every candidate tree.
#
# For each entry of parent_str_all the shell command
#   ./PairTree_MCMC_PT <parent_str> <T> <K> <input_file> <output_file>
# is executed through system(), where the output file is
#   ./tmpfiles/PTfile_<parent_str>_<suffix>.dat
#
# Arguments:
#   input_file     - path passed to the executable as its input file.
#   suffix         - tag appended to every output file name.
#   parent_str_all - character vector of tree encodings; one run per entry.
#
# NOTE: T and K are not parameters of this function. They are looked up in the
# enclosing environment when the command line is built, so they must be
# defined (e.g. as global variables) before this function is called.
#
# A warning is raised for every run whose command returns a non-zero status.
# Returns (invisibly, via print) the "finished" message string.
TreeClone_runPT = function(input_file = './data/nsim1.txt', suffix = "sim1", parent_str_all){

  print(paste0("Parallel Tempering started. Time: ", date()))

  # seq_along() gives zero iterations for an empty vector, whereas
  # 1:length(x) would count 1, 0 and try to run the sampler on an NA tree.
  for(i in seq_along(parent_str_all)){
    parent_str = parent_str_all[i]
    output_file = paste0("./tmpfiles/PTfile_", parent_str, "_", suffix, ".dat")
    cmd = sprintf("./PairTree_MCMC_PT %s %d %d %s %s", parent_str, T, K, input_file, output_file)
    status = system(cmd)
    # Surface failed runs instead of silently continuing; later steps read
    # the output file of every tree.
    if(!identical(as.integer(status), 0L)){
      warning("PairTree_MCMC_PT returned status ", status, " for tree ",
              parent_str, call. = FALSE)
    }
  }

  print(paste0("Parallel Tempering finished. Time:", date()))

}

############################################
# calculate log-prior for a tree

# Depth of node i in a tree stored as a parent vector.
# parent_str_split[j] is the index of the parent of node j; a value of 0 marks
# a node without parent (depth 0). Every other node lies one level below its
# parent, which is resolved recursively.
calc_depth_tree = function(parent_str_split, i){
  parent = parent_str_split[i]
  if(parent == 0){
    return(0)
  }
  1 + calc_depth_tree(parent_str_split, parent)
}



# Unnormalised log prior of one tree:
#   -beta * sum_i log(1 + depth_i)
# where depth_i is the depth of node i as returned by calc_depth_tree().
#
# Arguments:
#   parent_str_split - numeric parent vector of the tree (0 = no parent).
#   hp               - list of hyper-parameters; only hp$beta is used here.
#
# Returns a single number. An empty parent vector has no nodes and yields 0.
calc_logprior_tree = function(parent_str_split, hp){

  # seq_along() keeps the empty-tree case well defined; 1:length(x) would
  # visit the non-existent nodes 1 and 0.
  node_depth = vapply(seq_along(parent_str_split),
                      function(i) calc_depth_tree(parent_str_split, i),
                      numeric(1))

  -hp$beta * sum(log(1 + node_depth))
}


# Log prior mass of every candidate tree together with its number of nodes C:
#   log p(Tau, C) = log p(Tau | C) + log p(C)
#
# log p(Tau | C) is the value of calc_logprior_tree(), normalised so that the
#   candidate trees sharing the same C have prior masses summing to one.
# log p(C) = (C - 1) * log(1 - alpha) + log(alpha).
#
# Arguments:
#   parent_str_all - character vector of tree encodings. Each string is split
#                    into single characters and every character is converted
#                    to a number, so each parent index must be one digit.
#                    C is the number of characters of the string.
#   hp             - list of hyper-parameters; hp$alpha is used here and hp is
#                    passed on to calc_logprior_tree().
#
# Returns a numeric vector with one entry per element of parent_str_all
# (length zero for empty input).
calc_logprior_tree_C_all = function(parent_str_all, hp){

  alpha = hp$alpha

  L_arr = length(parent_str_all)
  C_all = nchar(parent_str_all)

  logprior_tree_all = rep(0, L_arr)
  logprior_C_all = rep(0, L_arr)

  parent_str_all_split = strsplit(parent_str_all, split = "")

  for(i in seq_len(L_arr)){
    parent_str_split = as.numeric(parent_str_all_split[[i]])
    logprior_tree_all[i] = calc_logprior_tree(parent_str_split, hp)
    logprior_C_all[i] = (C_all[i] - 1) * log(1 - alpha) + log(alpha)
  }

  # Normalise within each group of equally sized trees. Iterating over the
  # sizes that actually occur avoids the downward-counting 2:1 sequence when
  # the largest tree has a single node, and also covers size 1.
  for(C in unique(C_all)){
    same_size = C_all == C
    logprior_tree_all[same_size] = logprior_tree_all[same_size] -
      log(sum(exp(logprior_tree_all[same_size])))
  }

  return(logprior_tree_all + logprior_C_all)
}



###########################################

# Metropolis-Hastings sampler over the candidate trees.
#
# Every candidate tree keeps its own parameter state (Z, Z_vec, count, theta,
# rho_star), initialised from the file written for that tree by the
# parallel-tempering step. Each iteration
#   1. draws a proposal tree uniformly from all candidates,
#   2. advances the proposal tree's state by one step of the compiled routine
#      "PairTree_MCMC_2" on the training counts b * n,
#   3. scores current and proposed tree with "calc_loglik" on the held-out
#      counts (1 - b) * n plus the tree prior, and accepts or rejects,
#   4. records the current tree, its state, and "calc_logpost" evaluated on
#      the full counts n plus the tree prior.
# The advanced state of a proposal is stored only when the proposal is
# accepted.
#
# Arguments:
#   n              - numeric count data, passed to the compiled code as double.
#   T, K, G        - integer dimensions passed to the compiled routines. K and
#                    T also give the number of file rows read for Z and theta;
#                    recorded Z is reshaped to K x C and theta to T x (C + 1).
#   hp             - hyper-parameter list; d, d0, d1 are used here and the
#                    whole list is passed to calc_logprior_tree_C_all().
#   b              - fraction of the counts used for training.
#   niter          - number of iterations.
#   parent_str_all - character vector of tree encodings (at least one).
#   suffix         - tag used in the PT file names.
#   tmpdirectory   - directory containing the PT files.
#
# Returns list(Spl_Tree, Result): Spl_Tree holds the index (position in
# parent_str_all) of the current tree per iteration; Result holds the
# per-iteration logpost, Tree, Z, theta and rho_star.
#
# Requires the shared library providing "calc_loglik", "PairTree_MCMC_2" and
# "calc_logpost" to be loaded.
TreeClone_sampleTree = function(n, T, K, G, hp, b, niter, parent_str_all, 
                         suffix = "sim1", tmpdirectory = "./tmpfiles/"){

  # paste0() turns a zero-length vector into "", which would silently yield a
  # single bogus file name; reject that case explicitly.
  if(length(parent_str_all) == 0){
    stop("parent_str_all must contain at least one tree", call. = FALSE)
  }

  d = hp$d
  d0 = hp$d0
  d1 = hp$d1

  n_tr = b * n
  n_te = (1-b) * n

  Files = paste0(tmpdirectory, "PTfile_", parent_str_all, "_", suffix, ".dat")
  nFiles = length(Files)

  # Per-tree sampler state, one slot per candidate tree.
  Spl = list(Z = vector("list", nFiles),
             Z_vec = vector("list", nFiles),
             count = vector("list", nFiles),
             theta = vector("list", nFiles),
             rho_star = vector("list", nFiles),
             StateMat = vector("list", nFiles),
             CntMat = vector("list", nFiles),
             nState = rep(0, nFiles))

  # Per-iteration output; elements are created on first assignment in the loop.
  Result = list()

  # Tree -> C
  C_Tree = nchar(parent_str_all)

  # log p(Tree, C)
  logpriorTree = calc_logprior_tree_C_all(parent_str_all, hp)

  # Read the initial values from the PT files (to avoid local optima).
  # Each quantity is located by a fixed line offset; the offsets must match
  # the layout produced by the PT writer (TODO confirm against the writer).
  for(i in seq_len(nFiles)){
    Spl$Z[[i]] = c(as.matrix(read.table(Files[i], skip = 1, nrows = K)))
    Spl$Z_vec[[i]] = c(as.matrix(read.table(Files[i], skip = K + 3, nrows = 1)))
    Spl$count[[i]] = c(as.matrix(read.table(Files[i], skip = K + 6, nrows = 1)))
    theta1 = c(as.matrix(read.table(Files[i], skip = K + 9, nrows = T)))
    # exact zeros are replaced by a small positive value
    theta1[theta1 == 0] = 1e-5
    Spl$theta[[i]] = theta1
    rho1 = c(as.matrix(read.table(Files[i], skip = K + 2*T + 13, nrows = 1)))
    rho1[rho1 == 0] = 1e-5
    Spl$rho_star[[i]] = rho1
    Spl$nState[i] = c(as.matrix(read.table(Files[i], skip = K + 2*T + 22, nrows = 1)))

    Spl$StateMat[[i]] = c(as.matrix(read.table(Files[i], skip = K + 2*T + 25, nrows = Spl$nState[i])))
    Spl$CntMat[[i]] = c(as.matrix(read.table(Files[i], skip = K + 2*T + Spl$nState[i] + 27, nrows = Spl$nState[i])))
  }
  # The PT files are left on disk.

  # Trees are referred to by their position in parent_str_all, not by their
  # parent string.
  Spl_Tree = rep(0, niter)

  Tree_cur = sample(1:nFiles, 1)
  C_cur = C_Tree[Tree_cur]

  loglik_cur_out = .C("calc_loglik", loglik = as.double(0),
                  Z = as.integer(Spl$Z[[Tree_cur]]), 
                  theta = as.double(Spl$theta[[Tree_cur]]),
                  rho_star = as.double(Spl$rho_star[[Tree_cur]]),
                  n = as.double(c(n_te)),
                  C = as.integer(C_cur), T = as.integer(T), 
                  K = as.integer(K), G = as.integer(G))

  loglik_cur = loglik_cur_out$loglik

  for(i in seq_len(niter)){
    if((i%%100)==0)
    {
      print(paste(round(i/niter*100,2), "% MCMC sampling has been done." ))
      print(date())
    }

    Tree_pro = sample(1:nFiles, 1)
    C_pro = C_Tree[Tree_pro]

    # One update step of the proposal tree's parameters on the training counts.
    Spl_pro = .C("PairTree_MCMC_2", Z = as.integer(Spl$Z[[Tree_pro]]), 
                    Z_vec = as.integer(Spl$Z_vec[[Tree_pro]]),
                    theta = as.double(Spl$theta[[Tree_pro]]), 
                    rho_star = as.double(Spl$rho_star[[Tree_pro]]),
                    count = as.integer(Spl$count[[Tree_pro]]), 
                    n = as.double(c(n_tr)),
                    C = as.integer(C_pro), T = as.integer(T), 
                    K = as.integer(K), G = as.integer(G),
                    lambda = as.double(2*K/C_pro), d = as.double(d), 
                    d0 = as.double(d0), d1 = as.double(d1),
                    a_p = as.double(d), b_p = as.double(d0+(C_pro-1)*d),
                    niter = as.integer(1), 
                    treeStateMat = as.integer(Spl$StateMat[[Tree_pro]]),
                    treeCntMat = as.integer(Spl$CntMat[[Tree_pro]]),
                    n_State = as.integer(Spl$nState[Tree_pro]))

    # Held in temporaries: written back to Spl only if the proposal is accepted.
    Z_tmp = Spl_pro$Z
    Z_vec_tmp = Spl_pro$Z_vec
    count_tmp = Spl_pro$count
    theta_tmp = Spl_pro$theta
    rho_star_tmp = Spl_pro$rho_star

    loglik_pro_out = .C("calc_loglik", loglik = as.double(0),
                     Z = as.integer(Z_tmp), 
                     theta = as.double(theta_tmp),
                     rho_star = as.double(rho_star_tmp),
                     n = as.double(c(n_te)),
                     C = as.integer(C_pro), T = as.integer(T), 
                     K = as.integer(K), G = as.integer(G))

    loglik_pro = loglik_pro_out$loglik

    logpost_cur = loglik_cur + logpriorTree[Tree_cur]
    logpost_pro = loglik_pro + logpriorTree[Tree_pro]

    if(is.na(logpost_cur) || is.na(logpost_pro)){
      # A missing (NA/NaN) current score is always replaced by the proposal;
      # a missing proposal score is rejected. No uniform is drawn here, so the
      # random number stream is consumed only in the regular branch.
      accept = is.na(logpost_cur)
    } else {
      u = runif(1, 0, 1)
      # Accept with probability min(1, p_acc). isTRUE() treats the undefined
      # comparison that arises when both scores are -Inf as a rejection
      # instead of aborting the run.
      accept = isTRUE(log(u) < logpost_pro - logpost_cur)
    }

    if(accept){
      loglik_cur = loglik_pro
      Tree_cur = Tree_pro
      C_cur = C_pro

      Spl$Z[[Tree_pro]] = Z_tmp
      Spl$Z_vec[[Tree_pro]] = Z_vec_tmp
      Spl$count[[Tree_pro]] = count_tmp
      Spl$theta[[Tree_pro]] = theta_tmp
      Spl$rho_star[[Tree_pro]] = rho_star_tmp
    }

    # Log posterior of the (possibly updated) current state on the full counts.
    logpost_cur_out = .C("calc_logpost", logpost = as.double(0),
                        Z = as.integer(Spl$Z[[Tree_cur]]),
                        count = as.integer(Spl$count[[Tree_cur]]),
                        theta = as.double(Spl$theta[[Tree_cur]]),
                        rho_star = as.double(Spl$rho_star[[Tree_cur]]),
                        n = as.double(c(n)),
                        C = as.integer(C_cur), T = as.integer(T), 
                        K = as.integer(K), G = as.integer(G),
                        lambda = as.double(2*K/C_cur), d = as.double(d), 
                        d0 = as.double(d0), d1 = as.double(d1),
                        a_p = as.double(d), b_p = as.double(d0+(C_cur-1)*d))

    Result$logpost[[i]] = logpost_cur_out$logpost + logpriorTree[Tree_cur]

    Result$Tree[[i]] = Tree_cur
    Result$Z[[i]] = matrix(Spl$Z[[Tree_cur]], K, C_cur)
    Result$theta[[i]] = matrix(Spl$theta[[Tree_cur]], T, C_cur+1)
    Result$rho_star[[i]] = Spl$rho_star[[Tree_cur]]

    Spl_Tree[i] = Tree_cur
  }

  return(list(Spl_Tree = Spl_Tree, Result = Result))

}

################################

# Rescale every row of theta so that it sums to one.
transform_theta = function(theta){
  row_totals = rowSums(theta)
  # divide each row (MARGIN = 1) by its own total
  sweep(theta, MARGIN = 1, STATS = row_totals, FUN = "/")
}

################################

# Empirical proportions computed from a 3-d count array n.
# For every cell (i, j) the vector n[i, j, ] is turned into
#   p: its entries divided by their sum,
#   v: the sums of entries 5:6 and of entries 7:8, each divided by the total.
# apply() puts the per-cell result on the first axis, so both arrays are
# permuted to bring that axis last: p[i, j, ] and v[i, j, ].
# p is the "real" p, not p_tilde.
calc_emp_p_v = function(n){
  result_axis_last = function(a) aperm(a, c(2, 3, 1))

  cell_proportions = function(x){x / sum(x)}
  cell_grouped = function(x){c(sum(x[5:6]), sum(x[7:8])) / sum(x)}

  p = result_axis_last(apply(n, c(1,2), cell_proportions))
  v = result_axis_last(apply(n, c(1,2), cell_grouped))

  list(p = p, v = v)
}

####
# Run the tree sampler and derive point estimates from its output.
#
# Steps:
#   1. read the counts from input_file with scan(),
#   2. run TreeClone_sampleTree() (with G fixed to 8) on the PT files found in
#      ./tmpfiles/,
#   3. discard the first `burnin` iterations,
#   4. take the tree visited most often afterwards (the first such tree if
#      several are tied) and, among the iterations spent on it, the one with
#      the highest recorded log posterior,
#   5. report that iteration's Z (shifted by +1) and row-normalised theta.
#
# Arguments:
#   input_file     - file holding the counts.
#   T, K, hp, b, niter, parent_str_all, suffix
#                  - forwarded to TreeClone_sampleTree().
#   burnin         - number of leading iterations to drop; 0 <= burnin < niter.
#   save_MCMCspls  - save the raw sampler output to
#                    ./results/MCMCspls_<suffix>.rds.
#   save_point_est - save the returned list to ./results/point_est_<suffix>.rds.
#   save_Z_plot    - draw Zhat into ./results/Zhat_<suffix>.pdf using
#                    PairTree_plot_Z(), which is not defined in this part of
#                    the file and must be available when this flag is TRUE.
#
# Returns a list with
#   Z         - estimated Z matrix,
#   w         - estimated row-normalised theta,
#   Tree_prob - table() of the post-burn-in tree indices (visit counts, not
#               normalised).
TreeClone_MCMC = function(input_file = "./data/nsim.txt", T, K, hp, b, niter, burnin, parent_str_all,
                      suffix = "sim1", save_MCMCspls = TRUE, save_point_est = TRUE, save_Z_plot = FALSE){

  # Checked before the expensive sampling: with burnin >= niter the sequence
  # (burnin + 1):niter would run backwards and select the wrong iterations.
  if(burnin < 0 || burnin >= niter){
    stop("burnin must satisfy 0 <= burnin < niter", call. = FALSE)
  }

  nsim = scan(input_file)

  # the directory for the results from PT
  tmpdirectory = "./tmpfiles/"

  MCMCspls = TreeClone_sampleTree(nsim, T = T, K = K, G = 8, hp = hp, b = b, niter = niter, 
               parent_str_all = parent_str_all, suffix = suffix, tmpdirectory = tmpdirectory)

  if(save_MCMCspls){
    MCMCspls_file = sprintf("./results/MCMCspls_%s.rds", suffix)
    saveRDS(MCMCspls, MCMCspls_file)
  }

  # Iterations kept after burn-in, in the sampler's own numbering.
  keep = (burnin + 1):niter
  Spl_Tree = MCMCspls$Spl_Tree[keep]
  Result = MCMCspls$Result

  Tree_postprob = table(Spl_Tree)

  # which.max() yields exactly one index even when several trees are tied;
  # a vector of tied indices would be recycled in the comparison below.
  TreeMode = which.max(tabulate(Spl_Tree))

  # Kept iterations spent on the modal tree, and their log posteriors.
  index = keep[Spl_Tree == TreeMode]
  logpost = vapply(index, function(j) Result$logpost[[j]], numeric(1))

  map_iter = index[which.max(logpost)]

  Zhat = Result$Z[[map_iter]]+1
  what = transform_theta(Result$theta[[map_iter]])

  if(save_Z_plot){
    # name for the output genotype file
    Zhat_file = sprintf("./results/Zhat_%s.pdf", suffix)

    pdf(Zhat_file)
    # close the device even if plotting fails
    tryCatch(PairTree_plot_Z(Zhat), finally = dev.off())
  }

  point_est = list(Z = Zhat, w = what, Tree_prob = Tree_postprob)

  if(save_point_est){
    point_est_file = sprintf("./results/point_est_%s.rds", suffix)
    saveRDS(point_est, file = point_est_file)
  }

  return(point_est)

}


