/*
MCMC update steps and likelihood / posterior helper routines,
compiled for use from R.
*/

#include "stdio.h"
#include "R.h"
#include "Rmath.h"
#include "assert.h"
#include "math.h"
#include "stdlib.h"
#include "string.h"

extern "C"{

/***************************************************
 Preparation. Two transformations.
***************************************************/
void transform_theta(double *theta, double *w, int T, int C){
// Transform theta into the normalized weight matrix w.
//
// theta, w: T x (C+1) matrices stored column by column, i.e. element (t, c)
//           lives at index c * T + t.
// Column 0 is copied unchanged: w_t0 = theta_t0 (per the original note, this
// is the normalized weight of the normal clone).
// Columns 1..C are rescaled by their row sum:
//   w_tc = (1 - theta_t0) * theta_tc / sum_{j = 1..C} theta_tj
//
// Each row is handled independently, so the row sum is kept in a local
// variable and no scratch buffer is allocated (nothing to fail, nothing to
// free). If the row sum of columns 1..C is zero the division yields NaN/Inf;
// that is not guarded here.

  int c, t;

  for (t = 0; t < T; t++){
    // only columns 1..C enter the row sum
    double tail_sum = 0.0;
    for (c = 1; c <= C; c++){
      tail_sum = tail_sum + theta[c * T + t];
    }

    w[0 * T + t] = theta[0 * T + t];
    for (c = 1; c <= C; c++){
      w[c * T + t] = (1 - theta[0 * T + t]) * (theta[c * T + t] / tail_sum);
    }
  }

  return;
}


void transform_rho_star(double *rho_star, double *rho, int G){
// Normalize rho_star into rho, one group of entries at a time.
// rho_star, rho: vectors of length G.
// Three consecutive index ranges are normalized separately, each divided by
// its own sum:
//   [0, G/2),  [G/2, G/2 + 2),  [G/2 + 2, G)
// For G = 8 these are (rho_1..rho_4), (rho_5, rho_6), (rho_7, rho_8).

    int bounds[4];
    int seg, idx;

    bounds[0] = 0;
    bounds[1] = G / 2;
    bounds[2] = G / 2 + 2;
    bounds[3] = G;

    for(seg = 0; seg < 3; seg++){
        double group_total = 0.0;

        // first pass: total of this group
        for(idx = bounds[seg]; idx < bounds[seg + 1]; idx++){
            group_total += rho_star[idx];
        }
        // second pass: divide every member by that total
        for(idx = bounds[seg]; idx < bounds[seg + 1]; idx++){
            rho[idx] = rho_star[idx] / group_total;
        }
    }

    return;
}
  
  
// Draw one index from {0, ..., n-1} with probability proportional to prob[i].
// prob does not need to sum to one; it is normalized by its total here.
//
// Side effect: prob is overwritten in place. Entries up to the returned index
// are replaced by the normalized cumulative probabilities; later entries are
// left untouched.
//
// One uniform(0,1) variate is drawn from R's RNG (bracketed by
// GetRNGstate/PutRNGstate). n must be at least 1.
//
// The walk along the cumulative sum stops at the last index at the latest:
// rounding can leave the final cumulative value slightly below the uniform
// draw, and without the bound the loop would run past the end of prob.
int sample_TJ(double *prob, int n){

    double sum_prob = 0.0;
    double randomUnif;
    int i;
    for(i = 0; i < n; i++){
        sum_prob = sum_prob + prob[i];
    }

    i = 0;
    prob[i] = prob[i] / sum_prob;

    GetRNGstate();
    randomUnif = runif(0.0, 1.0);
    PutRNGstate();

    // prob[i] holds the cumulative probability of indices 0..i
    while((i < n - 1) && (randomUnif > prob[i])){
        i++;
        prob[i] = (prob[i] / sum_prob) + prob[i - 1];
    }
    return i;
}
  

  
double lfactorial_TJ(int x){
    // log(x!) computed as log(1) + log(2) + ... + log(x).
    // Returns 0 for x <= 0.

    double acc = 0.0;
    int i;

    for(i = 1; i <= x; i++){
        acc = log(i) + acc;
    }
    return acc;
}
  

  
  

void update_Z(int *Z, int *Z_vec, double *theta, double *rho_star,
    int *count, double *n, int C, int T, int K, int G, double lambda,
    double A[], int *treeStateMat, int *treeCntMat, int n_State, double Tmp){
    
    // count: current counts by column
    // Tmp: temprature for parallel tempering. If don't want tempering, set Tmp = 1
    
    double *w;
    w = (double *)calloc(T * (C + 1),sizeof(double));
    transform_theta(theta, w, T, C);
    
    double *rho; 
    rho = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, rho, G);
    
    double *prob;
    prob = (double *)calloc(n_State, sizeof(double));
    
    double loglik, logprior, p_tkg;
    int count_diff;
    double max_logprob = -INFINITY;
    
    int k, i, t, g, c, index;
    
    // update row by row
    for(k = 0; k < K; k++){
        for(i = 0; i < n_State; i++){
            loglik = 0.0;
            logprior = 0.0;
            for(t = 0; t < T; t++){
                for(g = 0; g < G; g++){
                    p_tkg = 0.0;
                    for(c = 0; c < C; c++){
                        //p_tkg = p_tkg + w[c * T + t] * A[treeStateMat[i][c] * G + g];
                        p_tkg = p_tkg + w[c * T + t] * A[treeStateMat[c * n_State + i] * G + g];
                    } 
                    p_tkg = p_tkg + w[C * T + t] * rho[g];
                    if(n[g * T * K + k * T + t] > 1.0E-6){
                        loglik = loglik + n[g * T * K + k * T + t] * log(p_tkg);
                    }
                    
                }
            }
            
            for(c = 1; c < C; c++){
                //count_diff = treeCntMat[i][c] - treeCntMat[Z_vec[k]][c];
                count_diff = treeCntMat[c * n_State + i] - treeCntMat[c * n_State + Z_vec[k]];  
                // Since this is truncated Poisson, if a subclone has no new mutation, that means it's the same with its parent, should reject.
                if((count[c] + count_diff) == 0){
                    logprior = -INFINITY;
                    break;
                }
                
                if(count_diff != 0){
                    logprior = logprior + log(lambda) * count_diff - lfactorial_TJ(count[c] + count_diff) + lfactorial_TJ(count[c]);
                } 
            }
            prob[i] = (loglik + logprior) / Tmp;
            
            
            
            // find the maximum log probability
            if(i == 0){
                max_logprob = prob[i];
            }
            else{
                if(max_logprob < prob[i]) max_logprob = prob[i];
            }
        }
        
        // transform prob from log scale to usual scale
        for(i = 0; i < n_State; i++){
            prob[i] = prob[i] - max_logprob;
            prob[i] = exp(prob[i]);
        }
        
        index = sample_TJ(prob, n_State);
        
        for(c = 0; c < C; c++){
            //Z[c * K + k] = treeStateMat[index][c];
            //count[c] = count[c] + treeCntMat[index][c] - treeCntMat[Z_vec[k]][c];
            Z[c * K + k] = treeStateMat[c * n_State + index];
            count[c] = count[c] + treeCntMat[c * n_State + index] - treeCntMat[c * n_State + Z_vec[k]];
        }
        
        Z_vec[k] = index;
    }
    
    free(w);
    free(rho);
    free(prob);
    return;
}


  
  
// Helper for update_theta: for sample t, accumulate
//   sum_{k, g} n_tkg * log(p_tkg)
// once with the current weights (row t of the full matrix w) and once with a
// proposed row w_row_pro (length C + 1). Entries with n_tkg <= 1e-6 are skipped.
static void theta_row_logliks(const int *Z, const double *n, const double *w,
    const double *w_row_pro, const double *rho, const double A[],
    int C, int T, int K, int G, int t, double *ll_cur, double *ll_pro){

    double acc_cur = 0.0;
    double acc_pro = 0.0;
    int k, g, c;

    for(k = 0; k < K; k++){
        for(g = 0; g < G; g++){
            double reads = n[g * T * K + k * T + t];
            double p_cur = 0.0;
            double p_pro = 0.0;
            for(c = 0; c < C; c++){
                double a = A[Z[c * K + k] * G + g];
                p_cur = p_cur + w[c * T + t] * a;
                p_pro = p_pro + w_row_pro[c] * a;
            }
            p_cur = p_cur + w[C * T + t] * rho[g];
            p_pro = p_pro + w_row_pro[C] * rho[g];

            if(reads > 1.0E-6){
                acc_cur += reads * log(p_cur);
                acc_pro += reads * log(p_pro);
            }
        }
    }

    *ll_cur = acc_cur;
    *ll_pro = acc_pro;
}

void update_theta(int *Z, double *theta, double *rho_star, double *n,
    int C, int T, int K, int G, double d, double d0, double a_p, double b_p, 
    double A[], double Tmp){
  // Metropolis-Hastings update of every entry of theta, sample by sample.
  // theta: T * (C + 1) matrix, element (t, c) at index c * T + t.
  // (Original note: theta_tC means theta_t0 in the paper, the proportion of
  //  the background subclone in sample t.)
  //
  // For each t:
  //   1. Column 0: additive normal proposal (sd 0.1). Proposals outside (0, 1)
  //      are rejected without drawing the acceptance uniform. The prior term
  //      is (a_p-1) log(x) + (b_p-1) log(1-x).
  //   2. Columns 1..C: multiplicative proposal x * exp(N(0, 0.2)). The prior
  //      term is (dd-1) log(x) - x with dd = d0 for column C and d otherwise;
  //      log(x) is added as the Jacobian of the proposal.
  // In both steps the log posterior is divided by Tmp, and a proposal is only
  // accepted if the proposed last weight w_tC is <= 0.03.
  // On acceptance theta and row t of the cached weights are updated.

    double *bg = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, bg, G);

    double *w_all = (double *)calloc(T * (C + 1), sizeof(double));
    transform_theta(theta, w_all, T, C);

    // proposed weights for row t only
    double *w_row = (double *)calloc((C + 1), sizeof(double));

    int t, c, j;
    double cur, pro, step, tail, shape;
    double ll_cur, ll_pro;
    double lpost_cur, lpost_pro;
    double u;

    for(t = 0; t < T; t++){

        // ---- column 0 (normal cell contamination) ----
        cur = theta[0 * T + t];

        GetRNGstate();
        step = rnorm(0.0, 0.1);
        PutRNGstate();

        pro = cur + step;

        // only proposals strictly inside (0, 1) are considered
        if((pro > 0) && (pro < 1)){
            tail = 0.0;
            for(j = 1; j <= C; j++){
                tail += theta[j * T + t];
            }

            w_row[0] = pro;
            for(j = 1; j <= C; j++){
                w_row[j] = (1 - w_row[0]) * (theta[j * T + t] / tail);
            }

            theta_row_logliks(Z, n, w_all, w_row, bg, A, C, T, K, G, t, &ll_cur, &ll_pro);

            // tempered log posterior; symmetric proposal, so no Jacobian
            lpost_cur = ((a_p-1) * log(cur) + (b_p-1) * log(1 - cur) + ll_cur) / Tmp;
            lpost_pro = ((a_p-1) * log(pro) + (b_p-1) * log(1 - pro) + ll_pro) / Tmp;

            GetRNGstate();
            u = runif(0.0, 1.0);
            PutRNGstate();

            if((log(u) < lpost_pro - lpost_cur) && (w_row[C] <= 0.03)){
                theta[0 * T + t] = pro;
                for(j = 0; j <= C; j++){
                    w_all[j * T + t] = w_row[j];
                }
            }
        }

        // ---- columns 1..C ----
        for(c = 1; c <= C; c++){
            // hyperparameter: the last column uses d0, the others d
            shape = (c == C) ? d0 : d;

            cur = theta[c * T + t];

            GetRNGstate();
            step = rnorm(0.0, 0.2);
            PutRNGstate();

            pro = cur * exp(step);

            // row sum over columns 1..C with entry c replaced by the proposal
            tail = 0.0;
            for(j = 1; j <= C; j++){
                tail = tail + ((j != c) ? theta[j * T + t] : pro);
            }

            w_row[0] = theta[0 * T + t];
            for(j = 1; j <= C; j++){
                w_row[j] = (1 - w_row[0]) * ((j != c) ? theta[j * T + t] : pro) / tail;
            }

            theta_row_logliks(Z, n, w_all, w_row, bg, A, C, T, K, G, t, &ll_cur, &ll_pro);

            // tempered log posterior plus Jacobian of the multiplicative proposal
            lpost_cur = ((shape-1) * log(cur) - cur + ll_cur) / Tmp + log(cur);
            lpost_pro = ((shape-1) * log(pro) - pro + ll_pro) / Tmp + log(pro);

            GetRNGstate();
            u = runif(0.0, 1.0);
            PutRNGstate();

            if((log(u) < lpost_pro - lpost_cur) && (w_row[C] <= 0.03) && (pro > 0)){
                theta[c * T + t] = pro;
                for(j = 0; j <= C; j++){
                    w_all[j * T + t] = w_row[j];
                }
            }
        }
    }

    free(bg);
    free(w_all);
    free(w_row);
    return;
}
  


void update_rho_star(int *Z, double *theta, double *rho_star, double *n,
    int C, int T, int K, int G, double d1, double A[], double Tmp){
    // Metropolis-Hastings update of every element of rho_star, one at a time.
    //
    // For each g:
    //   * proposal: rho_star_g * exp(N(0, 0.1));
    //   * prior term (dd-1) log(x) - x, with dd = d1 for g < G/2 and 2*d1 otherwise;
    //   * likelihood: sum over t, k and the genotypes g_iter in the same
    //     normalization group as g of n_tkg * log(p_tkg) (entries with
    //     n_tkg <= 1e-6 skipped). Only that group is visited because changing
    //     rho_star_g changes rho only inside its own group;
    //   * the log posterior is divided by Tmp and log(x) is added as Jacobian.
    //
    // The groups are the ones used by transform_rho_star:
    //   [0, G/2), [G/2, G/2 + 2), [G/2 + 2, G)   (0-4, 4-6, 6-8 for G = 8).
    //
    // State kept consistent across the loop over g:
    //   rho           always equals transform_rho_star(rho_star), i.e. it is
    //                 refreshed after every accepted move;
    //   rho_star_pro  equals rho_star except for the entry currently being
    //                 proposed; it is reverted after a rejected move.
    // Without this, later elements of the same group would be evaluated
    // against a stale "current" rho or a previously rejected proposal.

    double dd;
    double rho_star_g_cur, rho_star_g_pro;

    int c, t, k, g, g_iter;
    int seg_lo, seg_hi;
    int G_half = G / 2;
    double loglik_cur, loglik_pro;
    double logpost_cur, logpost_pro;
    double p_tkg_cur, p_tkg_pro;
    double norm_prop;
    double u;

    double *w;
    w = (double *)calloc(T * (C + 1), sizeof(double));
    transform_theta(theta, w, T, C);

    double *rho;  // current rho
    rho = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, rho, G);

    double *rho_star_pro;
    rho_star_pro = (double *)calloc(G, sizeof(double));
    for(g_iter = 0; g_iter < G; g_iter++){
        rho_star_pro[g_iter] = rho_star[g_iter];
    }

    double *rho_pro;
    rho_pro = (double *)calloc(G, sizeof(double));

    for(g = 0; g < G; g++){
        loglik_cur = 0.0;
        loglik_pro = 0.0;

        // set the hyperparameter
        if(g < G_half) dd = d1;
        else dd = 2 * d1;

        // normalization group that contains g
        if(g < G_half){
            seg_lo = 0;
            seg_hi = G_half;
        }
        else if(g < G_half + 2){
            seg_lo = G_half;
            seg_hi = G_half + 2;
        }
        else{
            seg_lo = G_half + 2;
            seg_hi = G;
        }

        rho_star_g_cur = rho_star[g];

        GetRNGstate();
        norm_prop = rnorm(0.0, 0.1);
        PutRNGstate();

        rho_star_g_pro = rho_star_g_cur * exp(norm_prop);

        rho_star_pro[g] = rho_star_g_pro;
        transform_rho_star(rho_star_pro, rho_pro, G);

        for(t = 0; t < T; t++){
            for(k = 0; k < K; k++){
                for(g_iter = seg_lo; g_iter < seg_hi; g_iter++){
                    // part of p_tkg shared by the current and proposed states
                    p_tkg_cur = 0.0;
                    for(c = 0; c < C; c++){
                        p_tkg_cur = p_tkg_cur + w[c * T + t] * A[Z[c * K + k] * G + g_iter];
                    }
                    p_tkg_pro = p_tkg_cur + w[C * T + t] * rho_pro[g_iter];
                    p_tkg_cur = p_tkg_cur + w[C * T + t] * rho[g_iter];

                    if(n[g_iter * T * K + k * T + t] > 1.0E-6){
                        loglik_cur = loglik_cur + n[g_iter * T * K + k * T + t] * log(p_tkg_cur);
                        loglik_pro = loglik_pro + n[g_iter * T * K + k * T + t] * log(p_tkg_pro);
                    }
                }
            }
        }

        // prior, likelihood
        logpost_cur = (dd-1) * log(rho_star_g_cur) - rho_star_g_cur + loglik_cur;
        logpost_pro = (dd-1) * log(rho_star_g_pro) - rho_star_g_pro + loglik_pro;
        // logpost * 1/Tmp + jacobian
        logpost_cur = logpost_cur / Tmp + log(rho_star_g_cur);
        logpost_pro = logpost_pro / Tmp + log(rho_star_g_pro);

        GetRNGstate();
        u = runif(0.0, 1.0);
        PutRNGstate();

        if((log(u) < logpost_pro - logpost_cur) && (rho_star_g_pro > 0)){
            // accepted: commit the proposal and refresh the current rho
            rho_star[g] = rho_star_g_pro;
            for(g_iter = seg_lo; g_iter < seg_hi; g_iter++){
                rho[g_iter] = rho_pro[g_iter];
            }
        }
        else{
            // rejected: undo the trial value so it cannot leak into later proposals
            rho_star_pro[g] = rho_star_g_cur;
        }
    }

    free(w);
    free(rho);
    free(rho_star_pro);
    free(rho_pro);
    return;
}



// p_tilde function for R.
// Fills p_tilde[g * T * K + k * T + t] with
//   sum_{c < C} w_tc * A[Z_kc][g] + w_tC * rho[g]
// where w and rho are obtained from theta and rho_star through
// transform_theta / transform_rho_star. All dimensions are passed by pointer.
void calc_p_tilde(double *p_tilde, int *Z, double *theta, double *rho_star, 
  int *C, int *T, int *K, int *G){

    // 10 rows of 8 entries; row index comes from Z, column index is g
    static const double A[80] = {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,
                                 0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 1.0, 0.0,
                                 0.5, 0.0, 0.5, 0.0, 1.0, 0.0, 0.5, 0.5,
                                 0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5,
                                 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                                 0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5,
                                 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
                                 0.0, 0.5, 0.0, 0.5, 0.0, 1.0, 0.5, 0.5,
                                 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.0, 1.0,
                                 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0};

    const int nC = *C;
    const int nT = *T;
    const int nK = *K;
    const int nG = *G;

    double *weights = (double *)calloc(nT * (nC + 1),  sizeof(double));
    transform_theta(theta, weights, nT, nC);

    double *bg = (double *)calloc(nG, sizeof(double));
    transform_rho_star(rho_star, bg, nG);

    int sample, locus, geno, clone;

    for(sample = 0; sample < nT; sample++){
        for(locus = 0; locus < nK; locus++){
            for(geno = 0; geno < nG; geno++){
                double p = 0.0;
                for(clone = 0; clone < nC; clone++){
                    p += weights[clone * nT + sample] * A[Z[clone * nK + locus] * nG + geno];
                }
                p += weights[nC * nT + sample] * bg[geno];
                p_tilde[geno * nT * nK + locus * nT + sample] = p;
            }
        }
    }

    free(weights);
    free(bg);
    return;
}




// log-likelihood function for R.
// Stores in *loglik the sum over t, k, g of n_tkg * log(p_tkg), where
//   p_tkg = sum_{c < C} w_tc * A[Z_kc][g] + w_tC * rho[g]
// and entries with n_tkg <= 1e-6 are skipped. w and rho come from theta and
// rho_star through transform_theta / transform_rho_star.
void calc_loglik(double *loglik, int *Z, double *theta, double *rho_star, double *n,
  int *C, int *T, int *K, int *G){

    // 10 rows of 8 entries; row index comes from Z, column index is g
    static const double A[80] = {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,
                                 0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 1.0, 0.0,
                                 0.5, 0.0, 0.5, 0.0, 1.0, 0.0, 0.5, 0.5,
                                 0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5,
                                 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                                 0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5,
                                 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
                                 0.0, 0.5, 0.0, 0.5, 0.0, 1.0, 0.5, 0.5,
                                 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.0, 1.0,
                                 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0};

    const int nC = *C;
    const int nT = *T;
    const int nK = *K;
    const int nG = *G;

    double *weights = (double *)calloc(nT * (nC + 1), sizeof(double));
    transform_theta(theta, weights, nT, nC);

    double *bg = (double *)calloc(nG, sizeof(double));
    transform_rho_star(rho_star, bg, nG);

    double total = 0.0;
    int sample, locus, geno, clone;

    for(sample = 0; sample < nT; sample++){
        for(locus = 0; locus < nK; locus++){
            for(geno = 0; geno < nG; geno++){
                double reads = n[geno * nT * nK + locus * nT + sample];
                double p = 0.0;
                for(clone = 0; clone < nC; clone++){
                    p += weights[clone * nT + sample] * A[Z[clone * nK + locus] * nG + geno];
                }
                p += weights[nC * nT + sample] * bg[geno];
                if(reads > 1.0E-6){
                    total += reads * log(p);
                }
            }
        }
    }

    *loglik = total;

    free(weights);
    free(bg);
    return;
}



// log-posterior for R.
// Stores in *logpost the sum of
//   * the log-likelihood  sum_{t,k,g} n_tkg * log(p_tkg)  (entries with
//     n_tkg <= 1e-6 skipped),
//   * the prior for Z: for c = 1..C-1,
//       log(lambda) * count[c] - lambda - log(count[c]!),
//   * the prior for w: (a_p-1) log(w_t0) + (b_p-1) log(1 - w_t0), plus
//     (d-1) log(w_tc / (1 - w_t0)) for c = 1..C-1 and
//     (d0-1) log(w_tC / (1 - w_t0)),
//   * the prior for rho: (d1-1) log(rho_g) for g < G/2 and
//     (2*d1-1) log(rho_g) for the remaining g.
// If any count[c] (c = 1..C-1) is zero, *logpost is -INFINITY. A NaN result
// is also reported as -INFINITY.
// All scalar arguments are passed by pointer.
void calc_logpost(double *logpost, int *Z, int *count, double *theta, 
  double *rho_star, double *n, int *C, int *T, int *K, int *G, 
  double *lambda, double *d, double *d0, double *d1, double *a_p, double *b_p){

    double logpost1;
    double loglik = 0.0;
    double logprior = 0.0;

    int C_num = *C;
    int T_num = *T;
    int K_num = *K;
    int G_num = *G;

    double lambda_num = *lambda;
    double d_num = *d;
    double d0_num = *d0;
    double d1_num = *d1;
    double a_p_num = *a_p;
    double b_p_num = *b_p;

    double A[80] = 	  {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,
                       0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 1.0, 0.0,
                       0.5, 0.0, 0.5, 0.0, 1.0, 0.0, 0.5, 0.5,
                       0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5,
                       0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                       0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5,
                       0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
                       0.0, 0.5, 0.0, 0.5, 0.0, 1.0, 0.5, 0.5,
                       0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.0, 1.0,
                       0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0};

    double p_tkg;
    int c, t, k, g;

    // prior for Z. Evaluated before anything is allocated so that the early
    // exit below has nothing to release.
    for(c = 1; c < C_num; c++){
        // truncated poisson: can't have two same subclones
        if(count[c] == 0){
            *logpost = -INFINITY;
            return;
        }
        logprior = logprior + log(lambda_num) * count[c] - lambda_num - lfactorial_TJ(count[c]);
    }

    double *w;
    w = (double *)calloc(T_num * (C_num + 1), sizeof(double));
    transform_theta(theta, w, T_num, C_num);

    double *rho;
    rho = (double *)calloc(G_num, sizeof(double));
    transform_rho_star(rho_star, rho, G_num);

    // likelihood
    for(t = 0; t < T_num; t++){
        for(k = 0; k < K_num; k++){
            for(g = 0; g < G_num; g++){
                p_tkg = 0.0;
                for(c = 0; c < C_num; c++){
                    p_tkg = p_tkg + w[c * T_num + t] * A[Z[c * K_num + k] * G_num + g];
                }
                p_tkg = p_tkg + w[C_num * T_num + t] * rho[g];

                if(n[g * T_num * K_num + k * T_num + t] > 1.0E-6){
                    loglik = loglik + n[g * T_num * K_num + k * T_num + t] * log(p_tkg);
                }
            }
        }
    }

    // prior for w
    for(t = 0; t < T_num; t++){
        logprior = logprior + (a_p_num - 1) * log(w[0 * T_num + t]) + (b_p_num - 1) * log(1 - w[0 * T_num + t]);
        for(c = 1; c < C_num; c++){
            logprior = logprior + (d_num - 1) * log(w[c * T_num + t] / (1 - w[0 * T_num + t]));
        }
        // last column (index C_num) uses d0
        logprior = logprior + (d0_num - 1) * log(w[C_num * T_num + t] / (1 - w[0 * T_num + t]));
    }

    // prior for rho: first half uses d1, second half 2 * d1
    for(g = 0; g < G_num / 2; g++){
        logprior = logprior + (d1_num - 1) * log(rho[g]);
    }
    for(g = G_num / 2; g < G_num; g++){
        logprior = logprior + (2*d1_num - 1) * log(rho[g]);
    }

    logpost1 = loglik + logprior;
    if(isnan(logpost1)){
        logpost1 = -INFINITY;
    }

    free(w);
    free(rho);
    *logpost = logpost1;
    return;
}




void PairTree_MCMC_2(int *Z, int *Z_vec, double *theta, double *rho_star,
    int *count, double *n, int *C, int *T, int *K, int *G, double *lambda, 
    double *d, double *d0, double *d1, double *a_p, double *b_p, int *niter, 
    int *treeStateMat, int *treeCntMat, int *n_State){

    // Run *niter sweeps of the sampler. Each sweep calls, in this order,
    // update_Z, update_theta and update_rho_star, all with temperature 1.0.
    // Z, Z_vec, count, theta and rho_star are modified in place by those calls.
    // Scalar settings arrive by pointer and are read once before the loop.

    const int n_clones = *C;
    const int n_samples = *T;
    const int n_loci = *K;
    const int n_geno = *G;
    const int n_states = *n_State;
    const int n_sweeps = *niter;

    const double lambda_val = *lambda;
    const double d_val = *d;
    const double d0_val = *d0;
    const double d1_val = *d1;
    const double a_p_val = *a_p;
    const double b_p_val = *b_p;

    // no tempering in this driver
    const double temperature = 1.0;

    // Original author's note: remember two things are swapped here.
    double A[80] = {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,
                    0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 1.0, 0.0,
                    0.5, 0.0, 0.5, 0.0, 1.0, 0.0, 0.5, 0.5,
                    0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5,
                    0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                    0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5,
                    0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
                    0.0, 0.5, 0.0, 0.5, 0.0, 1.0, 0.5, 0.5,
                    0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.0, 1.0,
                    0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0};

    int sweep;
    for(sweep = 0; sweep < n_sweeps; sweep++){
        update_Z(Z, Z_vec, theta, rho_star, count, n,
                 n_clones, n_samples, n_loci, n_geno, lambda_val,
                 A, treeStateMat, treeCntMat, n_states, temperature);

        update_theta(Z, theta, rho_star, n,
                     n_clones, n_samples, n_loci, n_geno,
                     d_val, d0_val, a_p_val, b_p_val, A, temperature);

        update_rho_star(Z, theta, rho_star, n,
                        n_clones, n_samples, n_loci, n_geno,
                        d1_val, A, temperature);
    }

    return;
}


void calc_count(int *count, int *Z, int *parent_arr, int *C, int *K){

    // For every column c = 1..C-1 of Z (K rows, element (k, c) at c * K + k),
    // count the rows in which column c differs from its parent column.
    // parent_arr[c] is the 1-based index of the parent column.
    // count[0] is always set to 0.

    const int n_cols = *C;
    const int n_rows = *K;
    int col, row;

    count[0] = 0;

    for(col = 1; col < n_cols; col++){
        const int own_off = col * n_rows;
        const int parent_off = (parent_arr[col] - 1) * n_rows;
        int differing = 0;

        for(row = 0; row < n_rows; row++){
            // Z_{kc} != Z_{k, Tau_c}
            differing += (Z[own_off + row] != Z[parent_off + row]);
        }
        count[col] = differing;
    }
    return;
}


} // closing extern "C"
