/*

	Code for PT version of PairCloneTree MCMC part using OMP 
	
	
	Update History:
	Sep 28: Tianjian, runs well
    Sep 13 Tianjian
	Sep 8: Subhajit
	Sep 6: Tianjian

*/


#include "stdio.h"
#include "assert.h"
#include "math.h"
#include "stdlib.h"
#include "string.h"
#include "iostream"
#include "omp.h"
#include "gsl/gsl_rng.h"
#include "gsl/gsl_randist.h"

#define RNG_SEED 523531

//#define MAX_TREE_NODE	5
#define MAX_TREE_NODE	7

#define MAX_CHARS_PER_STATE	11

#define MAX_TEMP 8
#define MAX_MCMC_ITERATION 30000

#define MAX_N_STR	50
#define	MAX_STR_LEN	10

/*************************************************************
Functions for generating random numbers, gsl library
*************************************************************/


/* Create a GSL generator of the default type, seed it with `seed`,
   and hand it back through *r. Release it with free_rng(). */
void init_rng(gsl_rng ** r,long int seed){
	/* must run before gsl_rng_default is read */
	gsl_rng_env_setup();

	gsl_rng *generator = gsl_rng_alloc(gsl_rng_default);
	gsl_rng_set(generator, seed);
	*r = generator;
}

/* Release the generator that *r points to (counterpart of init_rng). */
void free_rng(gsl_rng ** r){
	gsl_rng *generator = *r;
	gsl_rng_free(generator);
}

/* Normal draw: mu plus a gsl_ran_gaussian(r, sigma) deviate.
   p(x) dx = {1 \over \sqrt{2 \pi \sigma^2}} \exp (-x^2 / 2\sigma^2) dx for the zero-mean part */
double rnorm(double mu,double sigma,gsl_rng * r){
	return mu + gsl_ran_gaussian(r, sigma);
}

/* Exponential draw via gsl_ran_exponential.
   p(x) dx = {1 \over \mu} \exp(-x/\mu) dx */
double rexp(double mu,gsl_rng * r){
	return gsl_ran_exponential(r, mu);
}


/* Gamma draw via gsl_ran_gamma (a, b passed through unchanged).
   p(x) dx = {1 \over \Gamma(a) b^a} x^{a-1} e^{-x/b} dx */
double rgamma(double a,double b,gsl_rng * r){
	return gsl_ran_gamma(r, a, b);
}

/* Continuous uniform draw between a and b via gsl_ran_flat.
   p(x) dx = {1 \over (b-a)} dx */
double runif(double a,double b,gsl_rng * r){
	return gsl_ran_flat(r, a, b);
}

/* Beta draw via gsl_ran_beta.
   p(x) dx = {\Gamma(a+b) \over \Gamma(a) \Gamma(b)} x^{a-1} (1-x)^{b-1} dx */
double rbeta(double a,double b,gsl_rng * r){
	return gsl_ran_beta(r, a, b);
}

/* Poisson draw via gsl_ran_poisson.
   p(k) = {\mu^k \over k!} \exp(-\mu) */
unsigned int rpoi(double mu,gsl_rng * r){
	return gsl_ran_poisson(r, mu);
}
/* Geometric draw via gsl_ran_geometric.
   p(k) =  p (1-p)^(k-1) */
unsigned int rgeom(double p,gsl_rng * r){
	return gsl_ran_geometric(r, p);
}

/* Bernoulli draw: forwards p to gsl_ran_bernoulli and returns its result. */
unsigned int rbern(double p,gsl_rng * r){
	return gsl_ran_bernoulli(r, p);
}

/* Binomial draw: forwards (p, n) to gsl_ran_binomial and returns its result. */
unsigned int rbino(double p,unsigned int n,gsl_rng * r){
	return gsl_ran_binomial(r, p, n);
}

/* Multinomial draw: forwards to gsl_ran_multinomial, which fills n[0..K-1]
   from N trials with the K weights in p. */
void rmulti(size_t K, unsigned int N, double *p, unsigned int * n, gsl_rng * r){
	const double *weights = p;
	gsl_ran_multinomial(r, K, N, weights, n);
}

// Discrete uniform: a random integer in [0, n-1], obtained from
// gsl_rng_uniform_int and narrowed to int on return.
int rdunif(unsigned long int n, gsl_rng * r){
    return (int)gsl_rng_uniform_int(r, n);
}




/***************************************************
Functions for finding possible tree states
***************************************************/

// hashTable[i] is the string of candidate state letters used when the per-node
// entry passed to genWords()/genWordsRecur() equals i (one row per possible value).
const char hashTable[MAX_TREE_NODE][MAX_CHARS_PER_STATE] = {
	"a",
	"abc",
	"abcdefg",
	"abcdefghi",
	"abcdefghij",
	"abcdefghij",
	"abcdefghij"
};

// Map a state letter to its numeric level:
//   'a' -> 0, 'b'/'c' -> 1, 'd'..'g' -> 2, 'h'/'i' -> 3, 'j' -> 4.
// Any other character yields -1.
int findState(char ch)
{
	switch(ch)
	{
		case 'a':
			return 0;
		case 'b':
		case 'c':
			return 1;
		case 'd':
		case 'e':
		case 'f':
		case 'g':
			return 2;
		case 'h':
		case 'i':
			return 3;
		case 'j':
			return 4;
		default:
			return -1;
	}
}
 
 
// Recursively enumerate every word of length n whose character at position d
// is drawn from hashTable[number[d]], and store the words one per row in
// treeStateCharMat (each row must hold at least n+1 chars).
//
//   number            per-position index into hashTable
//   curr_digit        position being filled; external callers pass 0
//   output            scratch buffer of n+1 chars with output[n] == '\0'
//   n                 word length
//   treeStateCharMat  destination rows, one per generated word
//
// The row counter is static so that it survives the recursion. It is reset
// whenever a new enumeration starts (curr_digit == 0; recursive calls always
// pass curr_digit >= 1), so a second enumeration starts again at row 0
// instead of writing past the end of the new matrix.
void  genWordsRecur(int* number, int curr_digit, char output[], int n, char** treeStateCharMat)
{
	static int cnt;	// next free row of treeStateCharMat
	size_t i;

	if (curr_digit == 0)
		cnt = 0;

	// Base case: the current word is complete, copy it to the next row
	if (curr_digit == n)
	{
		size_t len = strlen(output);
		memcpy(treeStateCharMat[cnt], output, len);
		treeStateCharMat[cnt][len] = '\0';
		cnt++;
		return ;
	}

	// Try every candidate letter for this position and recurse on the rest
	const char *choices = hashTable[number[curr_digit]];
	const size_t n_choices = strlen(choices);
	for (i = 0; i < n_choices; i++)
	{
		output[curr_digit] = choices[i];
		genWordsRecur(number, curr_digit+1, output, n, treeStateCharMat);
	}
}
 
// Sets up the n+1 character scratch word (terminated at index n) and starts
// the recursive enumeration at position 0; results land in treeStateCharMat.
void genWords(int* parentArr, int n, char** treeStateCharMat)
{
    const int first_digit = 0;
    char word[n+1];

    word[n] = '\0';
    genWordsRecur(parentArr, first_digit, word, n, treeStateCharMat);
}



// Check a candidate word against the tree described by parentArr.
// Node i (i >= 1) is compared with its parent, node parentArr[i]-1.
// returns 0 if the letter combination is not a valid tree
// returns 1 if the letter combination is a valid tree
int validTree(char* treeToChk,int* parentArr,int n)
{
	int node;

	for(node = n - 1; node >= 1; node--)
	{
		const char child = treeToChk[node];
		const char parent = treeToChk[parentArr[node]-1];
		const int childLevel = findState(child);
		const int step = childLevel - findState(parent);

		// a child may stay on its parent's level or be exactly one level above it
		if(step > 1 || step < 0)
			return 0;

		// explicitly excluded child/parent letter pairs
		if((child == 'g' && parent == 'b') ||
		   (child == 'e' && parent == 'c') ||
		   (child == 'i' && parent == 'e') ||
		   (child == 'h' && parent == 'g'))
			return 0;

		// same level but a different letter ({b,c}, {d,e,f,g}, {h,i}) is not allowed;
		// letters unknown to findState (level -1) are deliberately not rejected here
		if(step == 0 && child != parent && childLevel != -1)
			return 0;
	}

	return 1;
}



// Enumerate all letter words for the tree in parentArr, keep the valid ones
// and append them to the caller-provided matrices.
//   *treeStateMat row *m receives letter - 97 for each node,
//   *treeCntMat   row *m receives findState(node) - findState(parent), 0 for node 0.
// m  : running number of valid combinations (incremented here, not reset)
// m2 : set to the number of all possible combinations
void getAllCombinations(int* parentArr,int n,int*** treeStateMat,int*** treeCntMat, int* m,int *m2)
{
	int node, word;
	int total = 1;

	for(node = 0; node < n; node++)
		total *= strlen(hashTable[parentArr[node]]);
	*m2 = total;

	// scratch storage: one NUL-terminated word per combination
	char **words = (char**)(calloc(total, sizeof(char*)));
	for(word = 0; word < total; word++)
		words[word] = (char*)(calloc((n+1), sizeof(char)));

	genWords(parentArr, n, words);

	for(word = 0; word < total; word++)
	{
		const char *cand = words[word];

		if(validTree(words[word], parentArr, n) != 1)
			continue;

		int *stateRow = (*treeStateMat)[*m];
		int *cntRow = (*treeCntMat)[*m];
		for(node = 0; node < n; node++)
		{
			stateRow[node] = (int)(cand[node]) - 97;
			if(node == 0)
				cntRow[node] = 0;
			else
				cntRow[node] = findState(cand[node]) - findState(cand[parentArr[node]-1]);
		}
		(*m)++;
	}
	fprintf(stdout, "Number of row states of Z: %d (out of %d).\n", (*m), total);

	for(word = 0; word < total; word++)
		free(words[word]);
	free(words);
}

// Free the m2 rows of both matrices, then the two row-pointer arrays themselves.
void cleanGarbage(int*** treeStateMat,int*** treeCntMat,int m2)
{
	int **stateRows = *treeStateMat;
	int **cntRows = *treeCntMat;
	int row = 0;

	while(row < m2)
	{
		free(stateRows[row]);
		free(cntRows[row]);
		row++;
	}
	free(stateRows);
	free(cntRows);
}



/***************************************************
 Preparation. Two transformations.
***************************************************/

// int MAX_TEMP = 8;

// One stored MCMC sample: for each of the MAX_TEMP slots it holds pointers to
// that slot's parameter arrays (allocated in init_mcmc_arr with the sizes
// noted below) plus the slot's log-posterior value.
typedef struct _mcmcUnit 
{
  int *Z[MAX_TEMP]; // K x C ints per slot
  int *Z_vec[MAX_TEMP]; // K ints per slot
  int *count[MAX_TEMP]; // C ints per slot
  double *theta[MAX_TEMP]; // T x (C+1) doubles per slot
  double *rho_star[MAX_TEMP]; // G doubles per slot
  double lpost[MAX_TEMP];	// one log-posterior value per slot (stored inline)

} mcmcUnit;

 

void transform_theta(double *theta, double *w, int T, int C){
// Transform theta to the normalized weights w.
// theta, w: T x (C+1) matrices stored column by column (element (t, c) at [c * T + t]).
// For every row t:
//   w[t, 0] = theta[t, 0]                                        (copied unchanged)
//   w[t, c] = (1 - theta[t, 0]) * theta[t, c] / sum_{j=1..C} theta[t, j],  c = 1..C
// so each row of w sums to 1 whenever the row sum of theta over columns 1..C is non-zero.
// No scratch memory is allocated; the row sum is accumulated per row in the
// same order as before, so results are unchanged.

  int c, t;

  for (t = 0; t < T; t++){
    // only columns 1..C enter the row sum
    double theta_rowsum = 0.0;
    for (c = 1; c <= C; c++){
      theta_rowsum = theta_rowsum + theta[c * T + t];
    }

    w[0 * T + t]  = theta[0*T + t];
    for (c = 1; c <= C; c++){
      w[c * T + t] = (1 - theta[0*T + t]) * (theta[c * T + t] / theta_rowsum);
    }
  }

  return;
}


void transform_rho_star(double *rho_star, double *rho, int G){
// Normalize rho_star into rho, separately within three consecutive groups:
//   [0, G/2), [G/2, G/2 + 2) and [G/2 + 2, G)
// (for G = 8: entries 1-4, 5-6 and 7-8, each group summing to one).
// rho_star, rho: G dimensional vectors

    const int G_half = G / 2;
    const int bounds[4] = {0, G_half, G_half + 2, G};
    int grp, g;

    for(grp = 0; grp < 3; grp++){
        const int lo = bounds[grp];
        const int hi = bounds[grp + 1];
        double group_sum = 0.0;

        for(g = lo; g < hi; g++){
            group_sum += rho_star[g];
        }
        for(g = lo; g < hi; g++){
            rho[g] = rho_star[g] / group_sum;
        }
    }

    return;
  }
  
  
// Draw an index in [0, n-1] with probability proportional to prob[i].
// prob need not sum to one; n must be at least 1.
// Side effect: the visited prefix prob[0..result] is overwritten with the
// normalized cumulative probabilities.
// The scan is capped at the last index: when rounding leaves the final
// cumulative value slightly below the uniform draw, the last index is
// returned instead of running past the end of prob.
int sample_TJ(double *prob, int n, gsl_rng *r){
    
    double sum_prob = 0.0;
    double randomUnif;
    int i;
    for(i = 0; i < n; i++){
        sum_prob = sum_prob + prob[i];
    }
    
    i = 0;
    prob[i] = prob[i] / sum_prob;
    randomUnif = runif(0.0, 1.0, r);
    
    while((i < n - 1) && (randomUnif > prob[i])){
        i++;
        prob[i] = (prob[i] / sum_prob) + prob[i - 1];
    }
    return i;
}
  

  
double lfactorial_TJ(int x){
    
    if(x <= 0){
        return 0;
    }
    return (log(x) + lfactorial_TJ(x - 1));
}
  

  
  

// Update Z one row (k) at a time by sampling a new row state from the
// n_State candidate rows in treeStateMat.
//
//   Z            K x C matrix, element (k, c) at Z[c * K + k]; overwritten row by row
//   Z_vec        per row k, index of the current row state in treeStateMat/treeCntMat; updated
//   theta        T x (C+1) matrix, converted to weights w via transform_theta
//   rho_star     G vector, converted to rho via transform_rho_star
//   count        current counts by column (length C); kept in sync with the sampled rows
//   n            data, element (t, k, g) read at n[g * T * K + k * T + t]
//   lambda       rate used in the count prior term
//   A            table read at A[state * G + g]
//   treeStateMat candidate row states, treeStateMat[i][c]
//   treeCntMat   per-column count contribution of candidate i, treeCntMat[i][c]
//   Tmp          temperature for parallel tempering; the log probability is divided by it.
//                If you do not want tempering, set Tmp = 1
//   r            random number generator used by sample_TJ
void update_Z(int *Z, int *Z_vec, double *theta, double *rho_star,
    int *count, double *n, int C, int T, int K, int G, double lambda,
    double A[], int **treeStateMat, int **treeCntMat, int n_State, double Tmp,
    gsl_rng *r){
    
    // normalized weights, T x (C+1)
    double *w;
    w = (double *)calloc(T * (C + 1),sizeof(double));
    transform_theta(theta, w, T, C);
    
    // normalized rho, length G
    double *rho; 
    rho = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, rho, G);
    
    // (log) probability of each candidate row state for the current row
    double *prob;
    prob = (double *)calloc(n_State, sizeof(double));
    
    double loglik, logprior, p_tkg;
    int count_diff;
    double max_logprob = -INFINITY;
    
    int k, i, t, g, c, index;
    
    // update row by row
    for(k = 0; k < K; k++){
        for(i = 0; i < n_State; i++){
            loglik = 0.0;
            logprior = 0.0;
            for(t = 0; t < T; t++){
                for(g = 0; g < G; g++){
                    // p_tkg = sum_{c < C} w_tc * A[state of candidate i in column c][g] + w_tC * rho_g
                    p_tkg = 0.0;
                    for(c = 0; c < C; c++){
                        p_tkg = p_tkg + w[c * T + t] * A[treeStateMat[i][c] * G + g];
                    } 
                    p_tkg = p_tkg + w[C * T + t] * rho[g];
                    // entries of n that are (numerically) zero contribute nothing
                    if(n[g * T * K + k * T + t] > 1.0E-6){
                        loglik = loglik + n[g * T * K + k * T + t] * log(p_tkg);
                    }
                }
            }
            
            // the first column count is always 0, no need to consider
            for(c = 1; c < C; c++){
                // change in the column count if row k switched from its current state to candidate i
                count_diff = treeCntMat[i][c] - treeCntMat[Z_vec[k]][c];
                // Since this is truncated Poisson, if a subclone has no new mutation, that means it's the same with its parent, should reject.
                if((count[c] + count_diff) == 0){
                    logprior = -INFINITY;
                    break;
                }
                if(count_diff != 0){
                    logprior = logprior + log(lambda) * count_diff - lfactorial_TJ(count[c] + count_diff) + lfactorial_TJ(count[c]);
                } 
            }
            // tempered log probability of candidate i
            prob[i] = (loglik + logprior) / Tmp;
            
            // find the maximum log probability
            if(i == 0){
                max_logprob = prob[i];
            }
            else{
                if(max_logprob < prob[i]) max_logprob = prob[i];
            }
        }
        
        // transform prob from log scale to usual scale
        // (the maximum is subtracted first so the largest entry becomes exp(0) = 1)
        for(i = 0; i < n_State; i++){
            prob[i] = prob[i] - max_logprob;
            prob[i] = exp(prob[i]);
        }
        
        // sample_TJ normalizes prob itself
        index = sample_TJ(prob, n_State, r);
        
        // write the sampled state into row k and move the counts from the old state to the new one
        for(c = 0; c < C; c++){
            Z[c * K + k] = treeStateMat[index][c];
            count[c] = count[c] + treeCntMat[index][c] - treeCntMat[Z_vec[k]][c];
        }
        
        Z_vec[k] = index;
    }
    
    free(w);
    free(rho);
    free(prob);
    return;
}


  
  

// Metropolis-Hastings update of every element of theta, one sample (row t) at a time.
//
//   Z         K x C matrix, element (k, c) read at Z[c * K + k]
//   theta     T x (C+1) matrix, element (t, c) at theta[c * T + t]; updated in place
//   rho_star  G vector, converted to rho via transform_rho_star
//   n         data, element (t, k, g) read at n[g * T * K + k * T + t]
//   d, d0     exponent hyperparameters for columns 1..C-1 and column C respectively
//   a_p, b_p  exponent hyperparameters for column 0
//   A         table read at A[state * G + g]
//   Tmp       temperature; the log posterior is divided by it (Tmp = 1 means no tempering)
//   r         random number generator
void update_theta(int *Z, double *theta, double *rho_star, double *n,
    int C, int T, int K, int G, double d, double d0, double a_p, double b_p, 
    double A[], double Tmp, gsl_rng *r){
  // Update theta
  // theta: T * (C + 1) matrix. 
  // theta_tC means theta_t0 in the paper, proportion of the background subclone in sample t.

    double *rho; 
    rho = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, rho, G);
    
    // current normalized weights; rows are refreshed whenever a proposal is accepted
    double *w;
    w = (double *)calloc(T * (C + 1), sizeof(double));
    transform_theta(theta, w, T, C);
    
    // only calculate proposed w for the t-th row; w_t_pro[c] corresponds to w[c * T + t]
    double *w_t_pro;
    w_t_pro = (double *)calloc((C + 1), sizeof(double));
    double sum_theta_t_pro;
    
    double dd;
    double theta_tc_cur, theta_tc_pro;
    
    int c, t, k, g, c_iter;
    double loglik_cur, loglik_pro;
    double logpost_cur, logpost_pro;
    double p_tkg_cur, p_tkg_pro;
    double norm_prop;
    double u;
    
    for(t = 0; t < T; t++){
        
        // ---- column 0 of row t (c = 0 implicitly): additive random-walk proposal ----
        
        loglik_cur = 0.0;
        loglik_pro = 0.0;
        sum_theta_t_pro = 0.0;
        
        theta_tc_cur = theta[0 * T + t];
        
        norm_prop = rnorm(0.0, 0.1, r);
        theta_tc_pro = theta_tc_cur + norm_prop;
        
        // proposals outside the open interval (0, 1) are rejected without further work
        if((theta_tc_pro < 1) && (theta_tc_pro > 0)){
            // columns 1..C are unchanged by this proposal, so their sum uses the current theta
            for(c_iter = 1; c_iter <= C; c_iter++){
                sum_theta_t_pro = sum_theta_t_pro + theta[c_iter * T + t];
            }
            
            // proposed weights for row t (same formula as transform_theta)
            w_t_pro[0] = theta_tc_pro;
            for(c_iter = 1; c_iter <= C; c_iter++){
                w_t_pro[c_iter] = (1 - w_t_pro[0]) * (theta[c_iter * T + t] / sum_theta_t_pro);
            }
            
            
            for(k = 0; k < K; k++){
                for(g = 0; g < G; g++){
                    p_tkg_cur = 0.0;
                    p_tkg_pro = 0.0;
                    for(c_iter = 0; c_iter < C; c_iter++){
                        p_tkg_cur = p_tkg_cur + w[c_iter * T + t] * A[Z[c_iter * K + k] * G + g];
                        p_tkg_pro = p_tkg_pro + w_t_pro[c_iter] * A[Z[c_iter * K + k] * G + g];
                    }
                    p_tkg_cur = p_tkg_cur + w[C * T + t] * rho[g];
                    p_tkg_pro = p_tkg_pro + w_t_pro[C] * rho[g];
                    
                    // entries of n that are (numerically) zero contribute nothing
                    if(n[g * T * K + k * T + t] > 1.0E-6){
                        // = Sum_{k, g} n_tkg log(p_cur_tkg)
                        loglik_cur = loglik_cur + n[g * T * K + k * T + t] * log(p_tkg_cur); 
                        // = Sum_{k, g} n_tkg log(p_pro_tkg)
                        loglik_pro = loglik_pro + n[g * T * K + k * T + t] * log(p_tkg_pro); 
                    }
                }
            }
            
            // Beta(a_p, b_p)-shaped log prior term plus log likelihood
            logpost_cur = (a_p-1) * log(theta_tc_cur) + (b_p-1) * log(1 - theta_tc_cur) + loglik_cur; 
            logpost_pro = (a_p-1) * log(theta_tc_pro) + (b_p-1) * log(1 - theta_tc_pro) + loglik_pro;
            
            // logpost * (1/Tmp) + Jacobian, no Jacobian here because symmetric prop
            logpost_cur = logpost_cur / Tmp;
            logpost_pro = logpost_pro / Tmp;
            
            u = runif(0.0, 1.0, r);   
  
            // accept only if the proposed weight of column C also stays at or below 0.03
            if((log(u) < logpost_pro - logpost_cur) && (w_t_pro[C] <= 0.03)){
                theta[0 * T + t] = theta_tc_pro;
                for(c_iter = 0; c_iter <= C; c_iter++){
                   w[c_iter * T + t] = w_t_pro[c_iter];
                }
            } 
            
        }
        
        
        
        
        // ---- columns 1..C of row t: multiplicative (log-scale) random-walk proposal ----
        for(c = 1; c <= C; c++){
            // update theta_tc
            // initialize the log likelihoods.
            loglik_cur = 0.0;
            loglik_pro = 0.0;
            sum_theta_t_pro = 0.0;
            
            // set the hyperparameter: d0 for the last column, d for the others
            if(c == C) dd = d0;
            else dd = d;
            
            theta_tc_cur = theta[c * T + t]; // current value of theta_tc
            
            norm_prop = rnorm(0.0, 0.2, r);
            theta_tc_pro = theta_tc_cur * exp(norm_prop);
            
            
            // calculate w_t for the proposed theta_tc
            // (row sum over columns 1..C with column c replaced by the proposal)
            for(c_iter = 1; c_iter <= C; c_iter++){
                if(c_iter != c) sum_theta_t_pro = sum_theta_t_pro + theta[c_iter * T + t];
                else sum_theta_t_pro = sum_theta_t_pro + theta_tc_pro;
            }
            
            w_t_pro[0] = theta[0*T + t];
            for(c_iter = 1; c_iter <= C; c_iter++){
                if(c_iter != c) w_t_pro[c_iter] = (1 - w_t_pro[0]) * theta[c_iter * T + t] / sum_theta_t_pro;
                else w_t_pro[c_iter] = (1 - w_t_pro[0]) * theta_tc_pro / sum_theta_t_pro;
            }

            for(k = 0; k < K; k++){
                for(g = 0; g < G; g++){
                    p_tkg_cur = 0.0;
                    p_tkg_pro = 0.0;
                    for(c_iter = 0; c_iter < C; c_iter++){
                        p_tkg_cur = p_tkg_cur + w[c_iter * T + t] * A[Z[c_iter * K + k] * G + g];
                        p_tkg_pro = p_tkg_pro + w_t_pro[c_iter] * A[Z[c_iter * K + k] * G + g];
                    }
                    p_tkg_cur = p_tkg_cur + w[C * T + t] * rho[g];
                    p_tkg_pro = p_tkg_pro + w_t_pro[C] * rho[g];
                    
                    // entries of n that are (numerically) zero contribute nothing
                    if(n[g * T * K + k * T + t] > 1.0E-6){
                        // = Sum_{k, g} n_tkg log(p_cur_tkg)
                        loglik_cur = loglik_cur + n[g * T * K + k * T + t] * log(p_tkg_cur); 
                        // = Sum_{k, g} n_tkg log(p_pro_tkg)
                        loglik_pro = loglik_pro + n[g * T * K + k * T + t] * log(p_tkg_pro); 
                    }
                }
            }
            
            
            // Gamma(dd, 1)-shaped log prior term plus log likelihood
            logpost_cur = (dd-1) * log(theta_tc_cur) - theta_tc_cur + loglik_cur; 
            logpost_pro = (dd-1) * log(theta_tc_pro) - theta_tc_pro + loglik_pro;
            
            // logpost * (1/Tmp) + Jacobian (log of the value, from the log-scale proposal)
            logpost_cur = logpost_cur / Tmp + log(theta_tc_cur);
            logpost_pro = logpost_pro / Tmp + log(theta_tc_pro);
            
            u = runif(0.0, 1.0, r);   
  
            // accept only if the proposed weight of column C stays at or below 0.03
            // and the proposed value is strictly positive
            if((log(u) < logpost_pro - logpost_cur) && (w_t_pro[C] <= 0.03) && (theta_tc_pro > 0)){
                theta[c * T + t] = theta_tc_pro;
                for(c_iter = 0; c_iter <= C; c_iter++){
                   w[c_iter * T + t] = w_t_pro[c_iter];
                }
            } 

        }
    }
  
  free(rho);
  free(w);
  free(w_t_pro);
  return;
}




// Metropolis-Hastings update of every element of rho_star, one component g at a time.
//
//   Z         K x C matrix, element (k, c) read at Z[c * K + k]
//   theta     T x (C+1) matrix, converted to weights w via transform_theta
//   rho_star  G vector; updated in place
//   n         data, element (t, k, g) read at n[g * T * K + k * T + t]
//   d1        hyperparameter: exponent d1 for g < G/2, 2 * d1 otherwise
//   A         table read at A[state * G + g]
//   Tmp       temperature; the log posterior is divided by it (Tmp = 1 means no tempering)
//   r         random number generator
//
// Changing rho_star[g] only changes the normalized rho inside g's own
// normalization group, so the likelihood ratio is evaluated over that group
// only. The groups are the ones transform_rho_star uses:
// [0, G/2), [G/2, G/2 + 2) and [G/2 + 2, G), i.e. 0-3, 4-5 and 6-7 for G = 8.
//
// Bookkeeping between components: after a rejection the proposal vector is
// reset to the current value, and after an acceptance the cached current rho
// is refreshed. Every component is therefore compared against the true
// current state.
void update_rho_star(int *Z, double *theta, double *rho_star, double *n,
    int C, int T, int K, int G, double d1, double A[], double Tmp,
    gsl_rng *r){
        
    double dd;
    double rho_star_g_cur, rho_star_g_pro;
    
    int c, t, k, g, g_iter;
    int g_lo, g_hi;
    const int G_half = G / 2;
    double loglik_cur, loglik_pro;
    double logpost_cur, logpost_pro;
    double p_tkg_cur, p_tkg_pro;
    double norm_prop;
    double u;

    double *w;
    w = (double *)calloc(T * (C + 1), sizeof(double));
    transform_theta(theta, w, T, C);
    
    double *rho;  // normalized rho for the current rho_star
    rho = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, rho, G);    
    
    // proposal vector: equals rho_star except in the component being updated
    double *rho_star_pro;
    rho_star_pro = (double *)calloc(G, sizeof(double));
    for(g_iter = 0; g_iter < G; g_iter++){
        rho_star_pro[g_iter] = rho_star[g_iter];
    }

    double *rho_pro;  // normalized rho for the proposal vector
    rho_pro = (double *)calloc(G, sizeof(double));
    

    for(g = 0; g < G; g++){
        // initialize the log likelihoods.
        loglik_cur = 0.0;
        loglik_pro = 0.0;
        
        // set the hyperparameter
        if(g < G_half) dd = d1;
        else dd = 2 * d1;
        
        rho_star_g_cur = rho_star[g]; // current value of rho_star_g
        
        // multiplicative (log-scale) random-walk proposal
        norm_prop = rnorm(0.0, 0.1, r);
        rho_star_g_pro = rho_star_g_cur * exp(norm_prop);
        
        rho_star_pro[g] = rho_star_g_pro;
        transform_rho_star(rho_star_pro, rho_pro, G);   
        
        // normalization group that contains g
        if(g < G_half){
            g_lo = 0;
            g_hi = G_half;
        }
        else if(g < G_half + 2){
            g_lo = G_half;
            g_hi = G_half + 2;
        }
        else{
            g_lo = G_half + 2;
            g_hi = G;
        }
        
        for(t = 0; t < T; t++){
            for(k = 0; k < K; k++){
                for(g_iter = g_lo; g_iter < g_hi; g_iter++){
                    // part of p_tkg shared by the current and the proposed state
                    p_tkg_cur = 0.0;
                    for(c = 0; c < C; c++){
                        p_tkg_cur = p_tkg_cur + w[c * T + t] * A[Z[c * K + k] * G + g_iter];
                    }
                    p_tkg_pro = p_tkg_cur + w[C * T + t] * rho_pro[g_iter];
                    p_tkg_cur = p_tkg_cur + w[C * T + t] * rho[g_iter];
                    
                    // entries of n that are (numerically) zero contribute nothing
                    if(n[g_iter * T * K + k * T + t] > 1.0E-6){
                        // = Sum_{t, k, g} n_tkg log(p_tkg)
                        loglik_cur = loglik_cur + n[g_iter * T * K + k * T + t] * log(p_tkg_cur); 
                        loglik_pro = loglik_pro + n[g_iter * T * K + k * T + t] * log(p_tkg_pro);
                    }
                }
            }
        }
        
        // prior, likelihood
        logpost_cur = (dd-1) * log(rho_star_g_cur) - rho_star_g_cur + loglik_cur;
        logpost_pro = (dd-1) * log(rho_star_g_pro) - rho_star_g_pro + loglik_pro;
        // logpost * 1/Tmp + jacobian
        logpost_cur = logpost_cur / Tmp + log(rho_star_g_cur);
        logpost_pro = logpost_pro / Tmp + log(rho_star_g_pro);
        
        u = runif(0.0, 1.0, r);
  
        if((log(u) < logpost_pro - logpost_cur) && (rho_star_g_pro > 0)){
            // accept: commit the value and refresh the cached current rho for this group
            rho_star[g] = rho_star_g_pro;
            for(g_iter = g_lo; g_iter < g_hi; g_iter++){
                rho[g_iter] = rho_pro[g_iter];
            }
        }
        else{
            // reject: the proposal vector must fall back to the current value
            rho_star_pro[g] = rho_star_g_cur;
        }
    }
    
    free(w);
    free(rho);
    free(rho_star_pro);
    free(rho_pro);
    return;
}



// Fill p_tilde with the mixture value for every (t, k, g):
//   p_tilde[g * T * K + k * T + t] = sum_{c < C} w[t, c] * A[Z[k, c] * G + g] + w[t, C] * rho[g]
// where w and rho are theta and rho_star after transform_theta / transform_rho_star.
// Z is read at Z[c * K + k]; p_tilde must hold G * T * K doubles.
void calc_p_tilde(double *p_tilde, int *Z, double *theta, double *rho_star, int C, int T, int K, int G, double A[]){

    double *weights = (double *)calloc(T * (C + 1),  sizeof(double));
    double *bg = (double *)calloc(G, sizeof(double));
    int sample, locus, geno, clone;

    transform_theta(theta, weights, T, C);
    transform_rho_star(rho_star, bg, G);

    for(sample = 0; sample < T; sample++){
        for(locus = 0; locus < K; locus++){
            for(geno = 0; geno < G; geno++){
                double acc = 0.0;
                for(clone = 0; clone < C; clone++){
                    acc += weights[clone * T + sample] * A[Z[clone * K + locus] * G + geno];
                }
                acc += weights[C * T + sample] * bg[geno];
                p_tilde[geno * T * K + locus * T + sample] = acc;
            }
        }
    }

    free(bg);
    free(weights);
    return;
}





// Log posterior (log likelihood + log prior) of the current state.
//
//   Z         K x C matrix, element (k, c) read at Z[c * K + k]
//   count     per-column counts (length C); column 0 is not used
//   theta     T x (C+1) matrix, converted to weights w via transform_theta
//   rho_star  G vector, converted to rho via transform_rho_star
//   n         data, element (t, k, g) read at n[g * T * K + k * T + t]
//   lambda    rate of the count prior
//   d, d0     exponents for weight columns 1..C-1 and column C
//   d1        exponent d1 for rho[g], g < G/2, and 2 * d1 for the rest
//   a_p, b_p  exponents for weight column 0
//   A         table read at A[state * G + g]
//
// Returns -INFINITY when any count[c] (c >= 1) is zero or when the result is NaN.
// The count prior is evaluated before anything is allocated, so the early
// -INFINITY return no longer leaks the w and rho buffers. The rho prior uses
// G/2 and G as bounds (4 and 8 for G = 8).
double calc_logpost(int *Z, int *count, double *theta, double *rho_star, double *n,
  int C, int T, int K, int G, double lambda, double d, 
  double d0, double d1, double a_p, double b_p, double A[]){
    
    double logpost;
    double loglik = 0.0;
    double logprior = 0.0;
    double p_tkg;
    int c, t, k, g;
    const int G_half = G / 2;
    
    // prior for Z
    for(c = 1; c < C; c++){
        // truncated poisson: can't have two same subclones
        if(count[c] == 0){
            return -INFINITY;
        }
        logprior = logprior + log(lambda) * count[c] - lambda - lfactorial_TJ(count[c]);
    }
    
    double *w;
    w = (double *)calloc(T * (C + 1), sizeof(double));
    transform_theta(theta, w, T, C);
    
    double *rho;
    rho = (double *)calloc(G, sizeof(double));
    transform_rho_star(rho_star, rho, G);
    
    // log likelihood
    for(t = 0; t < T; t++){
        for(k = 0; k < K; k++){
            for(g = 0; g < G; g++){
                p_tkg = 0.0;
                for(c = 0; c < C; c++){
                    p_tkg = p_tkg + w[c * T + t] * A[Z[c * K + k] * G + g];
                }
                p_tkg = p_tkg + w[C * T + t] * rho[g];
                
                // entries of n that are (numerically) zero contribute nothing
                if(n[g * T * K + k * T + t] > 1.0E-6){
                    loglik = loglik + n[g * T * K + k * T + t] * log(p_tkg);
                }
            }
        }
    }
    
    // prior for w
    for(t = 0; t < T; t++){
        logprior = logprior + (a_p - 1) * log(w[0 * T + t]) + (b_p - 1) * log(1 - w[0 * T + t]);
        for(c = 1; c < C; c++){
            logprior = logprior + (d - 1) * log(w[c * T + t] / (1 - w[0 * T + t]) );
        }
        // last column (index C) uses d0
        logprior = logprior + (d0 - 1) * log(w[C * T + t] / (1 - w[0 * T + t]) );
    }
    
    // prior for rho
    for(g = 0; g < G_half; g++){
        logprior = logprior + (d1 - 1) * log(rho[g]);
    }
    for(g = G_half; g < G; g++){
        logprior = logprior + (2*d1 - 1) * log(rho[g]);
    }
    
    logpost = loglik + logprior;
    
    if(isnan(logpost)){
        logpost = -INFINITY;
    }
    
    free(w);
    free(rho);
    return logpost;
}


// Allocate storage for niter MCMC samples and return it through *mcmcSpl.
// Every sample gets, for each of the MAX_TEMP slots, zero-initialized arrays:
// Z (K*C ints), Z_vec (K ints), count (C ints), theta (T*(C+1) doubles)
// and rho_star (G doubles).
void init_mcmc_arr(mcmcUnit** mcmcSpl,int niter,int C,int T, int K,int G)
{
	// one contiguous array of sample records
	mcmcUnit *samples = (mcmcUnit*)(calloc(niter,sizeof(mcmcUnit)));
	*mcmcSpl = samples;

	for(int iter = 0; iter < niter; iter++){
		mcmcUnit *unit = &samples[iter];

		for(int slot = 0; slot < MAX_TEMP; slot++){
			unit->Z[slot] = (int*)(calloc(K*C, sizeof(int)));
			unit->Z_vec[slot] = (int*)(calloc(K, sizeof(int)));
			unit->count[slot] = (int*)(calloc(C, sizeof(int)));
			unit->theta[slot] = (double*)(calloc(T*(C+1), sizeof(double)));
			unit->rho_star[slot] = (double*)(calloc(G, sizeof(double)));
		}
	}
}

// Release everything allocated by init_mcmc_arr: the per-slot arrays of each
// of the niter samples and then the sample array itself.
// *mcmcSpl is set to NULL afterwards, so a later free(*mcmcSpl) by the caller
// is a harmless no-op and a repeated call returns immediately.
// C, T, K and G are unused; they are kept so the signature mirrors init_mcmc_arr.
void free_mcmc_arr(mcmcUnit** mcmcSpl,int niter,int C,int T, int K,int G)
{
    int i, j;

	(void)C; (void)T; (void)K; (void)G;

	if(*mcmcSpl == NULL)
		return;

	for(i = 0; i < niter; i++){
		for(j = 0; j < MAX_TEMP; j++){			
			free((*mcmcSpl)[i].Z[j]);
			free((*mcmcSpl)[i].Z_vec[j]);
			free((*mcmcSpl)[i].count[j]);
			free((*mcmcSpl)[i].theta[j]);
			free((*mcmcSpl)[i].rho_star[j]);
		}
	}

	// the outer array was previously leaked
	free(*mcmcSpl);
	*mcmcSpl = NULL;
}



// Exchange the complete state of two chains: the five array pointers
// (Z, theta, rho_star, Z_vec, count) are swapped, not the array contents,
// and the two log-posterior values are swapped by value.
// The log posterior goes through a double temporary; an int temporary would
// truncate the value moved into *lpost2.
void PairTree_swap(int **Z1, double **theta1, double **rho_star1, int **Z_vec1, int **count1, double *lpost1, int **Z2, double **theta2, double **rho_star2, int **Z_vec2, int **count2, double *lpost2){
    
    int *temp_int;
    double *temp_double;
    double temp_lpost;
    
    temp_int = *Z1;
    *Z1 = *Z2;
    *Z2 = temp_int;
    
    temp_double = *theta1;
    *theta1 = *theta2;
    *theta2 = temp_double;
    
    temp_double = *rho_star1;
    *rho_star1 = *rho_star2;
    *rho_star2 = temp_double;
    
    temp_int = *Z_vec1;
    *Z_vec1 = *Z_vec2;
    *Z_vec2 = temp_int;
    
    temp_int = *count1;
    *count1 = *count2;
    *count2 = temp_int;
    
    temp_lpost = *lpost1;
    *lpost1 = *lpost2;
    *lpost2 = temp_lpost;
    
}



/// entry function ??

/// This should be written so that it can run on any machine where more than
/// one thread is available; the number of threads is capped at MAX_TEMP.
/// If only a single core is available, this is not a true parallel-tempering (PT) run.


void PairTree_MCMC(double *n, int C, int T, int K, int G, 
    double lambda, double d, double d0, double d1, double a_p, double b_p,
    int niter, int **treeStateMat, int **treeCntMat, int n_State, 
    double *Delta, int nThread, char *parent_str, char *output_file)
{
    
    int i,k,c,t,g,rank,rank_partner;
	double u1, u2, lalpha;
	
    gsl_rng * r;
 	init_rng(&r,RNG_SEED);
    
    // remember two things are swapped here.
    double A[80] = 	  {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,
                       0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 1.0, 0.0,
                       0.5, 0.0, 0.5, 0.0, 1.0, 0.0, 0.5, 0.5,
                       0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5,
                       0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                       0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5,
                       0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
                       0.0, 0.5, 0.0, 0.5, 0.0, 1.0, 0.5, 0.5,
                       0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 0.0, 1.0,
                       0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0};
    
	
	/// initialize all the structure
    mcmcUnit *mcmcSpl;
	init_mcmc_arr(&mcmcSpl, niter, C, T, K, G);    

	//fd = fopen("./mcmc_data.bin","wb");
	//if(fd == NULL)
	//	exit(0);
    
    // write Z MAP to Z_hat
    // Z_hat = fopen("./mcmc_data.bin","wb");

    // an array of pointers??
    int *Z[nThread];
    int *Z_vec[nThread];
    int *count[nThread]; // need to initialize this to all 0
    double *theta[nThread];
    double theta_rowsum[nThread];
    double *rho_star[nThread];
    double lpost[nThread];
    
	for(rank = 0; rank < nThread; rank++){
		Z[rank] = (int*)(calloc(K*C, sizeof(int)));
		Z_vec[rank] = (int*)(calloc(K, sizeof(int)));	
		theta[rank] = (double*)(calloc(T*(C+1),sizeof(double)));
		rho_star[rank] = (double*)(calloc(G,sizeof(double)));
	
		count[rank] = (int*)(calloc(C,sizeof(int)));

		for(c = 0; c < C; c++)
			count[rank][c] = 0;
	}  
    
    
    
    
    // initialization
    // 8 is the number of thread, so we can run at most 8 tempratures?
    for(rank = 0; rank < nThread; rank++)
    {
        for(k = 0; k < K; k++){
            Z_vec[rank][k] = rdunif(n_State, r);
        }
        
        
        for(k = 0; k < K; k++){
            for(c = 0; c < C; c++){
                Z[rank][c * K + k] = treeStateMat[Z_vec[rank][k]][c];
                count[rank][c] = count[rank][c] + treeCntMat[Z_vec[rank][k]][c];
            }
        }

        
        // Delta is a 8-dimensional vector, tempratures.
        for(t = 0; t < T; t++){
            theta_rowsum[rank] = 0.0;
            for(c = 0; c < C; c++){
                theta[rank][c * T + t] = rgamma((d-1)/Delta[rank]+1, 1.0, r);
                theta_rowsum[rank] = theta_rowsum[rank] + theta[rank][c * T + t];
            }
            theta[rank][C * T + t] = theta_rowsum[rank] / 999; /// what is 999 ??
        }
    
        for(g = 0; g < G; g++){
            rho_star[rank][g] = rgamma((d1-1)/Delta[rank]+1, 1.0, r);
        }
   
		 
        for(k = 0; k < K; k++){
            for(c = 0; c < C; c++){
        		mcmcSpl[0].Z[rank][c * K + k] = Z[rank][c * K + k];
			}
			mcmcSpl[0].Z_vec[rank][k] = Z_vec[rank][k];
		}
		
		for(c = 0; c < C; c++){
        	mcmcSpl[0].count[rank][c] = count[rank][c];
		}
		
        for(t = 0; t < T; t++){
            for(c = 0; c <= C; c++){	
        		mcmcSpl[0].theta[rank][c * T + t] = theta[rank][c * T + t];
     		}
		}
        for(g = 0; g < G; g++){
        	mcmcSpl[0].rho_star[rank][g] = rho_star[rank][g];
		}
		
		// only proportional to, not equal, but not a big deal
    	mcmcSpl[0].lpost[rank] = (1/Delta[rank]) * calc_logpost(Z[rank], count[rank], theta[rank], rho_star[rank], n, C, T, K, G, lambda, d, d0, d1, a_p, b_p, A);
	}
	
	
    
    //// parallel step
#pragma omp parallel private(i,rank,rank_partner,u1,u2,lalpha) shared(Z, Z_vec, theta, rho_star, count, lpost, Delta)
    {
        rank = omp_get_thread_num();


        for(i = 1; i < niter; i++) // i=1?
        {
            update_Z(Z[rank], Z_vec[rank], theta[rank], rho_star[rank], count[rank], n, C, T, K, G, lambda, A, treeStateMat, treeCntMat, n_State, Delta[rank], r);
            update_theta(Z[rank], theta[rank], rho_star[rank], n, C, T, K, G, d, d0, a_p, b_p, A, Delta[rank], r);
            update_rho_star(Z[rank], theta[rank], rho_star[rank], n, C, T, K, G, d1, A, Delta[rank], r);
            lpost[rank] = calc_logpost(Z[rank], count[rank], theta[rank], rho_star[rank], n, C, T, K, G, lambda, d, d0, d1, a_p, b_p, A);
    
// Synchronise threads
#pragma omp barrier
#pragma omp critical // by individual thread 
            {
            
                // the u should be random unif(0,1)
            	u1 = runif(0.0, 1.0, r);
            	u2 = runif(0.0, 1.0, r);
   
                if(u1 < 0.5){
                    rank_partner = rank + 1;
                    if(rank_partner < nThread){
                        lalpha = (1/Delta[rank] - 1/Delta[rank_partner]) * (lpost[rank_partner] - lpost[rank]);
                        if(log(u2) < lalpha){
                            PairTree_swap(&Z[rank], &theta[rank], &rho_star[rank], &Z_vec[rank], &count[rank], &lpost[rank], &Z[rank_partner], &theta[rank_partner], &rho_star[rank_partner], &Z_vec[rank_partner], &count[rank_partner], &lpost[rank_partner]);
                        }
                    }
                }
                
                for(k = 0; k < K; k++){
            		for(c = 0; c < C; c++){
        				mcmcSpl[i].Z[rank][c * K + k] = Z[rank][c * K + k];
					}
					mcmcSpl[i].Z_vec[rank][k] = Z_vec[rank][k];
				}
				
				for(c = 0; c < C; c++){
        	        mcmcSpl[i].count[rank][c] = count[rank][c];
		        }
		        
        		for(t = 0; t < T; t++){
            		for(c = 0; c <= C; c++){	
        				mcmcSpl[i].theta[rank][c * T + t] = theta[rank][c * T + t];
     				}
				}

        		for(g = 0; g < G; g++){
        			mcmcSpl[i].rho_star[rank][g] = rho_star[rank][g];
				}
				
				mcmcSpl[i].lpost[rank] = (1/Delta[rank]) * lpost[rank];
            }
#pragma omp barrier  // sync threads
    
        } // ends the for loop
    }

    
    double max_lpost = -INFINITY;
    int max_index = 0;
    
    
    // only need to write the samples with temprature = 1
    //write(mcmcSpl[0:niter][rank = 8]) 
   	for(i=0;i<niter;i++){
	//	fwrite(&mcmcSpl[i], sizeof(mcmcUnit), 1, fd);
	//  fprintf(stdout, "log posterior: %f.\n", mcmcSpl[i].lpost[nThread - 1]);	 
	    if(mcmcSpl[i].lpost[nThread - 1] > max_lpost){
	        max_lpost = mcmcSpl[i].lpost[nThread - 1];
	        max_index = i;
	    }
	}
	
	
	// fprintf(stdout, "max_index: %d, nThread: %d.\n", max_index, nThread);
	
	
	// save MAP estimates
	FILE *point_est;
	    
	fprintf(stdout, "Output file: %s\n", output_file);
    point_est = fopen(output_file, "w");
    
    // Point estimate of Z
    fprintf(point_est, "Z_hat: \n");
	for(k = 0; k < K; k++){
	    for(c = 0; c < C; c++){
	        fprintf(point_est, "%d ", mcmcSpl[max_index].Z[nThread - 1][c * K + k]);
	    }
	    fprintf(point_est, "\n");
	}
    
    fprintf(point_est, "\n");
    
    // for Z_vec
    fprintf(point_est, "Z_vec_hat: \n");
	for(k = 0; k < K; k++){
	    fprintf(point_est, "%d ", mcmcSpl[max_index].Z_vec[nThread - 1][k]);
	}
    fprintf(point_est, "\n\n");
    
    // for count
    fprintf(point_est, "count_hat: \n");
	for(c = 0; c < C; c++){
	    fprintf(point_est, "%d ", mcmcSpl[max_index].count[nThread - 1][c]);
	}
    fprintf(point_est, "\n\n");
    
    // point estimate for theta
    fprintf(point_est, "theta_hat: \n");
    for(t = 0; t < T; t++){
	    for(c = 0; c <= C; c++){
	        fprintf(point_est, "%f ", mcmcSpl[max_index].theta[nThread - 1][c * T + t]);
	    }
	    fprintf(point_est, "\n");
	}
	fprintf(point_est, "\n");
	
	
    double *w_hat;
    w_hat = (double *)calloc(T * (C + 1),sizeof(double));
    transform_theta(mcmcSpl[max_index].theta[nThread - 1], w_hat, T, C);
    
	fprintf(point_est, "w_hat: \n");
    for(t = 0; t < T; t++){
	    for(c = 0; c <= C; c++){
	        fprintf(point_est, "%f ", w_hat[c * T + t]);
	    }
	    fprintf(point_est, "\n");
	}
	fprintf(point_est, "\n");
	
	
	fprintf(point_est, "rho_star_hat: \n");
	for(g = 0; g < G; g++){
	    fprintf(point_est, "%f ", mcmcSpl[max_index].rho_star[nThread - 1][g]);
	}
	fprintf(point_est, "\n");
	fprintf(point_est, "\n");
	
	fprintf(point_est, "log posterior: \n");
	fprintf(point_est, "%f ", mcmcSpl[max_index].lpost[nThread - 1]);
	fprintf(point_est, "\n");
	fprintf(point_est, "\n");

	double *p_tilde;
	p_tilde = (double *)calloc(T*K*G, sizeof(double));
	calc_p_tilde(p_tilde, mcmcSpl[max_index].Z[nThread - 1], mcmcSpl[max_index].theta[nThread - 1], mcmcSpl[max_index].rho_star[nThread - 1], C, T, K, G, A);
	
	fprintf(point_est, "p_tilde_hat: \n");
	for(t = 0; t < T; t++){
        for(k = 0; k < K; k++){
            for(g = 0; g < G; g++){
                fprintf(point_est, "%f ", p_tilde[g * T * K + k * T + t]);
            }
        }
    }
    
    fprintf(point_est, "\n");
	fprintf(point_est, "\n");
    
    fprintf(point_est, "Number of States: \n");
	fprintf(point_est, "%d", n_State);
    fprintf(point_est, "\n");
	fprintf(point_est, "\n");
    
    fprintf(point_est, "Tree State Matrix: \n");
	for(k = 0; k < n_State; k++){
	    for(c = 0; c < C; c++){
	        fprintf(point_est, "%d ", treeStateMat[k][c]);
	    }
	    fprintf(point_est, "\n");
	}
	fprintf(point_est, "\n");
	
	
	fprintf(point_est, "Tree Count Matrix: \n");
	for(k = 0; k < n_State; k++){
	    for(c = 0; c < C; c++){
	        fprintf(point_est, "%d ", treeCntMat[k][c]);
	    }
	    fprintf(point_est, "\n");
	}
	fprintf(point_est, "\n");
    
	fclose(point_est);

	
	for(rank = 0; rank < nThread; rank++){
		free(Z[rank]);
		free(Z_vec[rank]);
		free(theta[rank]);
		free(rho_star[rank]);
		free(count[rank]);
	}
	/// freeing up all the allocated structure
	free_mcmc_arr(&mcmcSpl, niter, C, T, K, G);
	free(mcmcSpl);
	free(w_hat);
    free(p_tilde);
    return;
}


/*
 * Run one parallel-tempering analysis for a single parent array.
 *
 * parent_str_arr  parent array as a string of decimal digits, one digit per
 *                 node; its length is used as C
 * T, K            dimensions forwarded to PairTree_MCMC (G is fixed to 8)
 * input_file      text file with up to T*K*G numbers (read with "%lf")
 * output_file     path forwarded to PairTree_MCMC for the point estimates
 *
 * The number of chains is the largest even number <= min(8, available
 * OpenMP threads); the temperature ladder is chosen so that the last chain
 * always runs at temperature 1.0.  The program exits if fewer than two
 * threads are available, if the parent string is empty or contains a
 * non-digit, or if the input file cannot be opened.
 */
void run_for_each_parent_str(char* parent_str_arr, int T, int K, char *input_file, char *output_file){

    int C, G;

    G = 8;

    // hyperparameters
    double lambda, d, d0, d1, a_p, b_p;

    d = 0.5;
    d0 = 0.03;
    d1 = 1;

    int* parentArr;

    C = (int)strlen(parent_str_arr);
    if(C < 1){
        // C is used as a divisor and as an array length below
        fprintf(stderr, "Error: empty parent array string.\n");
        exit(EXIT_FAILURE);
    }

    a_p = d;
    b_p = d0 + (C-1)*d;
    lambda = 2.0 * K / C; // floating-point division: 2*K / C would truncate

    fprintf(stdout, ">>>>>> Running Parallel Tempering. Parent array: %s.\n", parent_str_arr);
    parentArr = (int*)(calloc(C,sizeof(int)));

    for(int ii=0;ii<C;ii++){
        // the values index hashTable below, so only digits are acceptable
        if(parent_str_arr[ii] < '0' || parent_str_arr[ii] > '9'){
            fprintf(stderr, "Error: parent array must consist of digits only, got '%c'.\n", parent_str_arr[ii]);
            exit(EXIT_FAILURE);
        }
        parentArr[ii] = parent_str_arr[ii] - '0';
    }

    // Temperature ladder for the full 8-chain run; the last entry (1.0) is
    // the chain whose samples are reported.
    double preset_Delta[] = {4.5,3.2,2.5,2.0,1.6,1.35,1.1,1.0};
    const int maxChains = (int)(sizeof(preset_Delta) / sizeof(preset_Delta[0]));

    int nThread = omp_get_max_threads();
    fprintf(stdout, "Number of threads: %d. ", nThread);
    if(nThread < 2){
        fprintf(stdout, "Not enough thread to run PT. Exit. \n");
        exit(0);
    }

    // Use an even number of chains, at most maxChains.  The rounding is done
    // BEFORE the ladder is selected so that e.g. 3 threads get the 2-chain
    // ladder ending in temperature 1.0.
    if(nThread > maxChains){
        nThread = maxChains;
    }
    nThread -= nThread % 2;

    switch(nThread){
        case 2: preset_Delta[1] = 1.0; preset_Delta[0] = 2.0; break;
        case 4: preset_Delta[3] = 1.0; preset_Delta[2] = 1.35; preset_Delta[1] = 2.0; preset_Delta[0] = 3.2; break;
        case 6: preset_Delta[5] = 1.0; preset_Delta[4] = 1.35; preset_Delta[3] = 1.6; preset_Delta[2] = 2.5; preset_Delta[1] = 3.2; preset_Delta[0] = 4.5; break;
        default: break; // 8 chains: use the full ladder as is
    }
    fprintf(stdout, "And we use %d threads to run PT.\n", nThread);

    ///// we need this two lines for a machine > 8 cores @subhajit
    omp_set_dynamic(0);     // Explicitly disable dynamic teams
    omp_set_num_threads(nThread); // set it to nThread

    // temperatures
    double *Delta;
    Delta = (double*)(calloc(nThread, sizeof(double)));

    for(int i = 0;i<nThread;i++){
        Delta[i] = preset_Delta[i];
    }

    int niter = 500;

    int c;

    // Question: is it necessary to define both tot_State and N_combinations?? They are essentially the same thing!

    int** treeStateMat;
    int** treeCntMat;
    int n_State = 0;
    int tot_State = 0;
    int N_combinations = 1;

    for(c = 0; c < C; c++)
        N_combinations *= strlen(hashTable[parentArr[c]]);

    int i;

    treeStateMat = (int**)(calloc(N_combinations,sizeof(int*)));
    treeCntMat = (int**)(calloc(N_combinations,sizeof(int*)));
    for(i = 0; i < N_combinations; i++)
    {
        treeStateMat[i] = (int*)(calloc((C+1),sizeof(int)));
        treeCntMat[i]   = (int*)(calloc((C+1),sizeof(int)));
    }


    getAllCombinations(parentArr, C, &treeStateMat, &treeCntMat, &n_State, &tot_State);


    // read data
    double number;
    FILE* dataFile;

    dataFile = fopen(input_file, "r");
    if(dataFile == NULL){
        fprintf(stderr, "Error: cannot open input file %s.\n", input_file);
        exit(EXIT_FAILURE);
    }

    const int n_expected = T*K*G;
    double *n;
    n = (double*)(calloc(n_expected,sizeof(double)));

    // never read more values than the buffer can hold
    i = 0;
    while( i < n_expected && fscanf(dataFile, "%lf\n", &number) > 0 )
    {
      n[i] = number;
      i++;
    }
    fclose(dataFile);

    if(i < n_expected){
        fprintf(stderr, "Warning: expected %d values in %s but read only %d; the rest are treated as 0.\n", n_expected, input_file, i);
    }

    // need a way to specify T, K, G, and hyperparameters lambda, ...
    PairTree_MCMC(n, C, T, K, G, lambda, d, d0, d1, a_p, b_p, niter, treeStateMat, treeCntMat, n_State, Delta, nThread, parent_str_arr, output_file);

    cleanGarbage(&treeStateMat, &treeCntMat, N_combinations);

    free(parentArr);
    free(n);
    free(Delta);

    fprintf(stdout, "Finishing Parallel Tempering. Parent array: %s.\n", parent_str_arr);

}



/*
 * Parse s as a strictly positive decimal integer that fits in an int.
 * Returns 1 and stores the value in *out on success, 0 otherwise
 * (empty string, trailing characters, value <= 0 or out of int range).
 */
static int parse_positive_int(const char *s, int *out)
{
    char *end = NULL;
    long v = strtol(s, &end, 10);

    if(end == s || *end != '\0' || v <= 0 || v != (long)(int)v){
        return 0;
    }
    *out = (int)v;
    return 1;
}

/*
 * Command line entry point.
 *
 * Usage: <prog_name> <parent_str> <T> <K> <input_file> <output_file>
 *
 * T and K must be positive integers; they are validated here because they
 * size every buffer allocated downstream.
 */
int main(int argc, char* argv[])
{
    if(argc != 6){
        fprintf(stdout, "Usage: <prog_name> <parent_str> <T> <K> <input_file> <output_file>\n");
        exit(0);
    }

    int T = 0;
    int K = 0;

    if(!parse_positive_int(argv[2], &T) || !parse_positive_int(argv[3], &K)){
        fprintf(stderr, "Error: <T> and <K> must be positive integers (got \"%s\" and \"%s\").\n", argv[2], argv[3]);
        return EXIT_FAILURE;
    }

    run_for_each_parent_str(argv[1], T, K, argv[4], argv[5]);

    return 0;
}

