
# Posterior probability that the mean exceeds theta0 for a continuous endpoint.
#
# The variance/mean expressions below are the normal-normal conjugate update:
# prior mean `a` with prior sd `b`, combined with `n` observations of sd `sig`
# whose average is `e`. The result is the upper-tail probability
# Pr(theta > theta0) of that normal posterior.
# When q != 0.5 the probability P is re-weighted to
# q * P / (q * P + (1 - q) * (1 - P)), i.e. prior mass q on H1.
# `n` may be a vector; the arithmetic is elementwise. `q` must be a scalar.
post_prob_H1_cont <- function(n, theta0, e, a, b, sig, q = 0.5){
  prior_prec <- 1/(b^2)
  data_prec <- n/(sig^2)
  post_var <- 1/(prior_prec + data_prec)
  post_mean <- post_var*(a/(b^2) + (n*e)/(sig^2))

  upper_tail <- 1 - pnorm((theta0 - post_mean)/(sqrt(post_var)))

  # Equal prior odds need no adjustment; otherwise shift the odds toward q.
  if (q != 0.5) {
    upper_tail <- (q*upper_tail)/(1 - q + (2*q - 1)*upper_tail)
  }
  upper_tail
}

# Sample-size search for a continuous endpoint.
#
# Evaluates post_prob_H1_cont() at every n in seq(n_min, n_max) and returns
# the smallest n whose posterior probability of H1 reaches the threshold `c`.
#
# Arguments: theta0, e, a, b, sig and q are forwarded to post_prob_H1_cont();
#   c is the required posterior probability; n_min / n_max bound the grid.
# Returns: list(n, prob, prob_h1_vec) on success, where prob_h1_vec holds the
#   probability for every n in the grid. Returns the string "Can lower n_min."
#   when the first n already reaches `c`, and
#   "Need to increase max sample size!" when the last n does not.
BESS_cont <- function(theta0, c, e, a, b, n_min=5, n_max=100, q = 0.5, sig = 1){

  n_seq <- seq(n_min, n_max)
  # Forward the caller's `q`; it was previously hard-coded to 0.5, so the
  # argument had no effect (BESS_count already forwards q = q).
  s_vec <- vapply(
    n_seq,
    function(n) post_prob_H1_cont(n, theta0, e, a, b, sig, q = q),
    numeric(1)
  )

  if (s_vec[1] >= c) {
    return("Can lower n_min.")
  }
  if (s_vec[length(s_vec)] < c) {
    return("Need to increase max sample size!")
  }

  # First grid point that meets the threshold.
  first_hit <- which(s_vec >= c)[1]
  list(n = n_seq[first_hit], prob = s_vec[first_hit], prob_h1_vec = s_vec)
}

# Monte Carlo posterior probability that the treatment rate minus the control
# rate exceeds theta0, for count data.
#
# Each arm's rate has a Gamma(shape, rate) prior: (a1, b1) for treatment and
# (a2, b2) for control. Together with the arm's total count Y over n units,
# this gives a Gamma(shape + Y, rate + n) posterior. That is the Poisson-Gamma
# conjugate update, so Y_t and Y_c are assumed to be totals of n Poisson
# counts per arm.
# `sim` draws are taken from each posterior. The returned value is the share
# of draws with lambda_t - lambda_c > theta0. When q != 0.5 that probability P
# is re-weighted to q * P / (q * P + (1 - q) * (1 - P)).
# The result is random; set a seed in the caller for reproducibility.
post_prob_H1_cnt_sim <- function(Y_t, Y_c, n, theta0, a1, b1, a2, b2, q = 0.5, sim = 10000){

  # Posterior of the treatment rate: Gamma(a1 + Y_t, rate = b1 + n)
  lambda_t_post <- rgamma(sim, a1 + Y_t, rate = b1 + n)
  # Posterior of the control rate: Gamma(a2 + Y_c, rate = b2 + n)
  lambda_c_post <- rgamma(sim, a2 + Y_c, rate = b2 + n)

  # Pr(lambda_t - lambda_c > theta0 | data), estimated as the fraction of draws
  post_prob <- mean((lambda_t_post - lambda_c_post) > theta0)

  if (q == 0.5) {
    return(post_prob)
  }
  # Re-weight from equal prior odds to prior probability q on H1.
  (q*post_prob)/(1 - q + (2*q - 1)*post_prob)
}

# For sample size n, simulate s_sim scenarios and return the one with the
# smallest posterior probability of H1.
#
# Each scenario does three things:
#   - draws a control rate from the Gamma(a2, rate = b2) prior;
#   - sets the control total y_c to the sum of n Poisson(rate) counts;
#   - sets the treatment total to y_c + floor(n * e), a fixed count difference.
# The (y_t, y_c) pair is then scored with post_prob_H1_cnt_sim().
#
# Returns list(y1, y2, min_xi): the treatment total, the control total and the
# posterior probability of the first scenario that attains the minimum.
find_min_H1_count <- function(e, n, theta0, a1, b1, a2, b2, q = 0.5, s_sim = 10000, sim = 10000){
  # Guard: with s_sim < 1 the old `1:s_sim` loop ran over c(1, 0) and
  # produced NA results plus warnings instead of a clear error.
  if (length(s_sim) != 1L || is.na(s_sim) || s_sim < 1) {
    stop("`s_sim` must be a single number >= 1.", call. = FALSE)
  }

  lambda_ctl <- rgamma(s_sim, a2, rate = b2)
  # Filled inside the loop; element types follow the assigned values.
  s_vec <- rep(NA, s_sim)
  y1_vec <- rep(NA, s_sim)
  y2_vec <- rep(NA, s_sim)
  for (i in seq_len(s_sim)) {
    y_c <- sum(rpois(n, lambda_ctl[i]))
    y_t <- y_c + floor(n*e)
    s_vec[i] <- post_prob_H1_cnt_sim(y_t, y_c, n, theta0, a1, b1, a2, b2, q = q, sim = sim)
    y1_vec[i] <- y_t
    y2_vec[i] <- y_c
  }

  # First scenario with the minimum posterior probability (worst case).
  min_xi_idx <- which.min(s_vec)
  list(y1 = y1_vec[min_xi_idx], y2 = y2_vec[min_xi_idx], min_xi = s_vec[min_xi_idx])
}

#set.seed(12345)

# Sample-size search for a count endpoint.
#
# Walks n over seq(n_min, n_max). For each n it takes the worst-case (minimum)
# posterior probability of H1 from find_min_H1_count(). The search stops at the
# first n where that minimum reaches `c`. n is printed whenever it is a
# multiple of 10.
#
# Returns list(n, yd, prob_h1_vec, yd_vec, y1_vec, y2_vec):
#   - n and yd: the chosen n and its count difference y1 - y2, both NA if no
#     n qualifies;
#   - the remaining vectors: per-n records for every n visited, up to and
#     including the stopping point.
BESS_count <- function(theta0, c, e, a1, b1, a2, b2, n_min=5, n_max=100, q = 0.5, s_sim = 1000, sim = 10000){

  n_grid <- seq(n_min, n_max)
  # One result record per grid point; only the visited prefix is reported.
  records <- vector("list", length(n_grid))
  n_visited <- 0L

  chosen_n <- NA
  chosen_yd <- NA

  for (k in seq_along(n_grid)) {
    n_cur <- n_grid[k]
    if (n_cur %% 10 == 0) {
      print(n_cur)
    }

    worst <- find_min_H1_count(e, n_cur, theta0, a1, b1, a2, b2,
                               q = q, s_sim = s_sim, sim = sim)
    records[[k]] <- worst
    n_visited <- k

    # Stop at the first n whose worst-case probability clears the threshold.
    if (worst$min_xi >= c) {
      chosen_n <- n_cur
      chosen_yd <- worst$y1 - worst$y2
      break
    }
  }

  visited <- records[seq_len(n_visited)]
  pick <- function(field) unlist(lapply(visited, "[[", field))

  list(
    n = chosen_n,
    yd = chosen_yd,
    prob_h1_vec = pick("min_xi"),
    yd_vec = unlist(lapply(visited, function(r) r$y1 - r$y2)),
    y1_vec = pick("y1"),
    y2_vec = pick("y2")
  )
}

#BESS_count(0.05,0.6,0.2,3,3,3,3)

