
#' @title Global Laplace VB
#' @description
#'  Coordinate-ascent variational Bayes for multi-source heterogeneous linear
#'  models under a Laplace spike-and-slab prior, enabling simultaneous variable
#'  selection for both homogeneous and heterogeneous covariates.
#'
#' @details
#'  The noise level is estimated per source with
#'  `selectiveInference::estimateSigma()` and averaged; `X`, `Z` and `Y` are
#'  divided by that average before fitting. Each sweep first updates the
#'  homogeneous coefficients on the stacked data (with the current
#'  heterogeneous fit removed from the response), then updates the
#'  heterogeneous coefficients source by source. Iteration stops when the
#'  largest absolute change in `entropy(gamma1)` is at most `tol`, or after
#'  `max_iter` sweeps.
#'
#' @param X Homogeneous covariates, an array of dimension n x p x K
#'   (observations x covariates x sources).
#' @param Z Heterogeneous covariates, an array of dimension n x q x K.
#' @param Y Responses, a two-dimensional object with n rows and one column
#'   per source.
#' @param max_iter Maximum number of iterations. Default: 1000.
#' @param tol Convergence tolerance. Default: 1e-6.
#' @param a First parameter of the Beta prior. Default: 1.
#' @param b Second parameter of the Beta prior. Default: 10.
#' @param lambda Parameter of the Laplace prior. Default: 1.
#'
#' @return A list with components
#'   \describe{
#'     \item{mu1}{p x 1 matrix, variational means of the homogeneous
#'       coefficients.}
#'     \item{sigma1}{p x 1 matrix, variational scale (standard deviation)
#'       parameters of the homogeneous coefficients.}
#'     \item{gamma1}{p x 1 matrix, inclusion probabilities of the homogeneous
#'       coefficients.}
#'     \item{mu2}{q x K matrix, variational means of the heterogeneous
#'       coefficients.}
#'     \item{gamma2}{q x K matrix, inclusion probabilities of the
#'       heterogeneous coefficients.}
#'     \item{sigma2}{q x K matrix, variational scale (standard deviation)
#'       parameters of the heterogeneous coefficients.}
#'   }
#'   `NA` is returned instead when a coordinate optimisation has failed and
#'   all entries of `gamma1` are identical at the end of a sweep.
#'
vb_lap_global <- function(X, Z, Y, max_iter = 1000, tol = 1e-6, a = 1, b = 10,
                          lambda = 1) {

  # Builds the per-coordinate objective. `x = c(mu, sigma)`; the returned
  # closure yields the objective value and its analytic gradient.
  make_objective <- function(coef_sq, coef_lin, lambda) {
    function(x) {
      mu <- x[1]
      sigma <- x[2]

      exp_term <- lambda * (1 / sqrt(2)) * (2 / sqrt(pi)) *
        exp(-0.5 * (mu / sigma)^2)
      erf_term <- lambda * pracma::erf(sqrt(0.5) * mu / sigma)

      gradient <- numeric(2)
      gradient[1] <- 2 * coef_sq * mu + erf_term + coef_lin
      gradient[2] <- 2 * coef_sq * sigma + exp_term - 1 / sigma

      # sigma is optimised without bounds, hence abs() inside the log.
      obj_value <- (erf_term + coef_lin) * mu +
        coef_sq * (mu^2 + sigma^2) -
        log(abs(sigma)) + sigma * exp_term

      list(value = obj_value, gradient = gradient)
    }
  }

  # Runs one (mu, sigma) optimisation; returns the optim() result, or NULL if
  # optim() raised an error (e.g. a non-finite objective value).
  solve_coordinate <- function(coef_sq, coef_lin, start) {
    objective <- make_objective(coef_sq, coef_lin, lambda)
    tryCatch(
      stats::optim(
        par = start,
        fn = function(par) objective(par)$value,
        gr = function(par) objective(par)$gradient,
        method = "L-BFGS-B"
      ),
      error = function(e) NULL
    )
  }

  if (length(dim(X)) != 3L || length(dim(Z)) != 3L) {
    stop("`X` and `Z` must be three-dimensional arrays.", call. = FALSE)
  }
  if (length(dim(Y)) != 2L) {
    stop("`Y` must be two-dimensional (observations x sources).",
         call. = FALSE)
  }

  n <- dim(X)[1]
  p <- dim(X)[2]
  K <- dim(X)[3]
  q <- dim(Z)[2]

  if (K < 1L) {
    stop("`X` must contain at least one source.", call. = FALSE)
  }
  if (dim(Z)[1] != n || dim(Y)[1] != n) {
    stop("`X`, `Z` and `Y` must have the same number of rows.", call. = FALSE)
  }

  # Per-source design matrices. Rebuilding each slice with explicit dimensions
  # keeps it a matrix even when n, p or q equals 1 (plain `X[, , k]` would
  # drop to a vector and break the stacking below).
  X_list <- lapply(seq_len(K), function(k) matrix(X[, , k], nrow = n, ncol = p))
  Z_list <- lapply(seq_len(K), function(k) matrix(Z[, , k], nrow = n, ncol = q))

  # Average the per-source noise estimates, then rescale all data by it.
  noisy_sd <- 0
  for (k in seq_len(K)) {
    noisy_sd <- noisy_sd +
      selectiveInference::estimateSigma(
        cbind(X_list[[k]], Z_list[[k]]), Y[, k]
      )$sigmahat
  }
  noisy_sd <- noisy_sd / K
  X_list <- lapply(X_list, function(m) m / noisy_sd)
  Z_list <- lapply(Z_list, function(m) m / noisy_sd)
  Y <- Y / noisy_sd

  # Variational parameters: homogeneous (p x 1) and heterogeneous (q x K).
  mu1 <- matrix(0, nrow = p, ncol = 1)
  sigma1 <- matrix(1, nrow = p, ncol = 1)
  gamma1 <- matrix(0.5, nrow = p, ncol = 1)
  mu2 <- matrix(0, nrow = q, ncol = K)
  sigma2 <- matrix(1, nrow = q, ncol = K)
  gamma2 <- matrix(0.5, nrow = q, ncol = K)

  # `entropy`, `gram_diag` and `sigmoid` are helpers defined outside this
  # function. NOTE(review): `gram_diag(M)` is indexed per column of `M` below,
  # so it presumably returns the diagonal of t(M) %*% M - confirm.
  old_entr <- entropy(gamma1)

  # Stack the sources row-wise for the homogeneous update.
  all_X <- do.call(rbind, X_list)
  all_Y <- unlist(lapply(seq_len(K), function(k) Y[, k]))

  half_diag <- 0.5 * gram_diag(all_X)
  approx_mean <- gamma1 * mu1
  # Running stacked linear predictor all_X %*% (gamma1 * mu1).
  X_appm <- all_X %*% approx_mean

  half_diag_k <- matrix(0, nrow = q, ncol = K)
  approx_mean_k <- matrix(0, nrow = q, ncol = K)
  # Running per-source linear predictor Z_k %*% (gamma2[, k] * mu2[, k]).
  Z_appm <- matrix(0, nrow = n, ncol = K)
  for (k in seq_len(K)) {
    half_diag_k[, k] <- 0.5 * gram_diag(Z_list[[k]])
    approx_mean_k[, k] <- gamma2[, k] * mu2[, k]
    Z_appm[, k] <- Z_list[[k]] %*% approx_mean_k[, k]
  }

  # Sticky flag: set once any coordinate optimisation has failed.
  optim_failed <- FALSE

  const_lodds <- (log(a) - log(b)) + 0.5 + 0.5 * log(pi) + log(lambda) -
    0.5 * log(2)

  for (iter in seq_len(max_iter)) {
    # ---- Homogeneous coefficients --------------------------------------
    # Stacked heterogeneous fit, removed from the response.
    all_Zm <- unlist(lapply(seq_len(K), function(k) {
      as.vector(Z_list[[k]] %*% (mu2[, k] * gamma2[, k]))
    }))
    YX_vec <- t(all_Y - all_Zm) %*% all_X

    for (j in seq_len(p)) {
      # Remove coordinate j from the running predictor before updating it.
      X_appm <- X_appm - approx_mean[j] * all_X[, j]
      fit <- solve_coordinate(
        half_diag[j],
        as.numeric(all_X[, j] %*% X_appm - YX_vec[j]),
        c(mu1[j], sigma1[j])
      )
      if (is.null(fit)) {
        # Put the unchanged contribution back so the predictor stays
        # consistent with the stored coefficients, then abandon this sweep.
        X_appm <- X_appm + approx_mean[j] * all_X[, j]
        optim_failed <- TRUE
        break
      }
      mu1[j] <- fit$par[1]
      sigma1[j] <- fit$par[2]
      gamma1[j] <- sigmoid(const_lodds - fit$value)
      approx_mean[j] <- gamma1[j] * mu1[j]
      X_appm <- X_appm + approx_mean[j] * all_X[, j]
    }

    # ---- Heterogeneous coefficients, one source at a time ---------------
    for (k in seq_len(K)) {
      Z_k <- Z_list[[k]]
      # Response of source k with the homogeneous fit removed, projected
      # onto the columns of Z_k.
      YZ_vec <- t(Y[, k] - X_list[[k]] %*% approx_mean) %*% Z_k

      for (j in seq_len(q)) {
        Z_appm[, k] <- Z_appm[, k] - approx_mean_k[j, k] * Z_k[, j]
        fit <- solve_coordinate(
          half_diag_k[j, k],
          as.numeric(Z_k[, j] %*% Z_appm[, k] - YZ_vec[j]),
          c(mu2[j, k], sigma2[j, k])
        )
        if (is.null(fit)) {
          # Same restoration as above; only the remaining coordinates of
          # this source are skipped.
          Z_appm[, k] <- Z_appm[, k] + approx_mean_k[j, k] * Z_k[, j]
          optim_failed <- TRUE
          break
        }
        mu2[j, k] <- fit$par[1]
        sigma2[j, k] <- fit$par[2]
        gamma2[j, k] <- sigmoid(const_lodds - fit$value)
        approx_mean_k[j, k] <- mu2[j, k] * gamma2[j, k]
        Z_appm[, k] <- Z_appm[, k] + approx_mean_k[j, k] * Z_k[, j]
      }
    }

    # An optimisation failed and the homogeneous inclusion probabilities are
    # all identical: report failure.
    if (optim_failed && all(gamma1 == gamma1[1])) {
      return(NA)
    }

    # Convergence is monitored on the homogeneous inclusion probabilities.
    new_entr <- entropy(gamma1)
    if (max(abs(new_entr - old_entr)) <= tol) {
      break
    }
    old_entr <- new_entr
  }

  list(mu1 = mu1, sigma1 = sigma1, gamma1 = gamma1,
       mu2 = mu2, gamma2 = gamma2, sigma2 = sigma2)
}
