#' Simulate Process Monitoring Data for Stream-Based Active Learning
#'
#' Generate a sequence of latent states and corresponding multivariate Gaussian
#' observations for process monitoring. The process has three possible states:
#' \enumerate{
#'   \item state 1: in-control (IC),
#'   \item state 2: out-of-control (OC),
#'   \item state 3: out-of-control (OC).
#' }
#'
#' The first \code{T0} observations are fixed in state 1 (IC). Then, in the
#' following \code{TT} observations, only state 2 appears in the first half, and
#' only state 3 appears in the second half. Within each half, runs of state 1
#' (IC) alternate with fixed-length runs of the corresponding OC state of length
#' \code{T_OC}. Each IC run spans \code{L + 1} observations, where \code{L} is
#' drawn uniformly from \code{T_min_IC:T_max_IC}; an IC run is truncated at the
#' end of its half. An OC run that would not end strictly before the last
#' observation of its half is replaced by IC observations.
#'
#' @param d Integer. Number of variables (dimension of the multivariate
#'   observations). Default is 10. Must be at least 2 when the default
#'   \code{mean} and \code{covariance} are used.
#' @param TT Integer. Length of the sequence after the initial IC portion.
#'   Default is 500.
#' @param T0 Integer. Length of the initial IC sequence known to belong to
#'   state 1. Default is 100.
#' @param T_min_IC,T_max_IC Integers. Minimum and maximum of the random draw
#'   \code{L} that sets the length (\code{L + 1}) of each run of consecutive IC
#'   observations before switching to an OC state. Setting
#'   \code{T_min_IC == T_max_IC} gives IC runs of fixed length.
#' @param T_OC Integer. Fixed length of each OC state sequence.
#' @param mean List of three numeric vectors of length \code{d}, representing
#'   the mean vectors of states 1 (IC), 2 (OC), and 3 (OC). If \code{NULL}
#'   (default), the IC mean is the zero vector, and states 2 and 3 shift the
#'   first and the second variable, respectively, by 1.8.
#' @param covariance List of three \code{d x d} covariance matrices, one for each
#'   state. If \code{NULL} (default), all states share the covariance matrix with
#'   entries \code{0.75^|i - j|}.
#'
#' @return A list with elements:
#'   \item{x}{Numeric vector of latent state labels (1, 2 or 3) of length
#'     \code{T0 + TT}.}
#'   \item{y}{Matrix of simulated multivariate observations with \code{T0 + TT}
#'     rows and \code{d} columns.}
#'
#' @examples
#' library(ActiveLearning4SPM)
#' sim <- simulate_stream()
#' table(sim$x)
#'
#' @export
simulate_stream <- function(d = 10,
                            TT = 500,
                            T0 = 100,
                            T_min_IC = 60,
                            T_max_IC = 85,
                            T_OC = 5,
                            mean = NULL,
                            covariance = NULL) {

  ## --- Input checks ---
  stopifnot(is.numeric(d), length(d) == 1, d > 0)
  stopifnot(is.numeric(TT), length(TT) == 1, TT > 0)
  stopifnot(is.numeric(T0), length(T0) == 1, T0 > 0)
  stopifnot(is.numeric(T_min_IC), length(T_min_IC) == 1,
            is.numeric(T_max_IC), length(T_max_IC) == 1,
            T_min_IC > 0, T_max_IC >= T_min_IC)
  stopifnot(is.numeric(T_OC), length(T_OC) == 1, T_OC > 0)

  if (!is.null(mean) || !is.null(covariance)) {
    if (is.null(mean) || is.null(covariance)) {
      stop("Both 'mean' and 'covariance' must be provided, or both NULL.")
    }
    if (!is.list(mean) || !is.list(covariance)) {
      stop("'mean' and 'covariance' must be lists.")
    }
    if (length(mean) != 3L || length(covariance) != 3L) {
      stop("'mean' and 'covariance' must each have length 3 (for the 3 states).")
    }
    for (m in mean) {
      if (!is.numeric(m) || length(m) != d) {
        stop("Each element of 'mean' must be a numeric vector of length d.")
      }
    }
    for (S in covariance) {
      if (!is.matrix(S) || nrow(S) != d || ncol(S) != d) {
        stop("Each element of 'covariance' must be a d x d matrix.")
      }
      if (!all(abs(S - t(S)) < 1e-8)) {
        stop("Each covariance matrix must be symmetric.")
      }
      ev <- eigen(S, symmetric = TRUE, only.values = TRUE)$values
      if (any(ev < -1e-8)) {
        stop("Each covariance matrix must be positive semi-definite.")
      }
    }
  } else {
    # Defaults: IC mean at the origin; the two OC states shift the first and
    # the second variable, respectively, so at least two variables are needed.
    if (d < 2) {
      stop("Default 'mean' and 'covariance' require d >= 2; ",
           "supply them explicitly when d < 2.")
    }
    shift <- 1.8
    mean <- list(rep(0, d),
                 c(shift, rep(0, d - 1)),
                 c(0, shift, rep(0, d - 2)))
    # Correlation 0.75^|i - j|, shared by all three states.
    idx <- seq_len(d)
    cov_matrix <- 0.75^abs(outer(idx, idx, "-"))
    covariance <- list(cov_matrix, cov_matrix, cov_matrix)
  }

  ## --- Latent state sequence ---
  # Possible values of the random IC draw L. Indexing with sample.int() gives
  # the same draws (same random-number consumption) as
  # sample(T_min_IC:T_max_IC, 1) when the range has several values, but does
  # not hit sample()'s special case where a single number n means 1:n
  # (i.e. when T_min_IC == T_max_IC).
  ic_draws <- T_min_IC:T_max_IC
  draw_ic <- function() ic_draws[sample.int(length(ic_draws), 1L)]

  # Fill positions from..to of `x` with alternating IC runs (draw + 1
  # observations) and OC runs of length T_OC labelled `oc_state`. An OC run
  # that would not end strictly before `to` is replaced by IC up to `to`.
  fill_half <- function(x, from, to, oc_state) {
    tt <- from
    while (tt <= to) {
      ic_end <- min(tt + draw_ic(), to)
      x[tt:ic_end] <- 1
      oc_start <- ic_end + 1
      oc_end <- oc_start + T_OC - 1
      if (oc_end < to) {
        x[oc_start:oc_end] <- oc_state
      } else if (oc_start <= to) {
        x[oc_start:to] <- 1
      }
      tt <- oc_end + 1
    }
    x
  }

  TT1 <- as.integer(TT / 2)
  x <- rep(NA, TT)
  x <- fill_half(x, from = 1, to = TT1, oc_state = 2)
  x <- fill_half(x, from = TT1 + 1, to = TT, oc_state = 3)

  ## --- Observations ---
  # The TT stream observations are drawn one at a time, in time order, before
  # the T0 initial IC block; keep this order to reproduce earlier simulations
  # under the same seed.
  y_stream <- do.call(rbind, lapply(seq_len(TT), function(ii) {
    mvnfast::rmvn(1, mu = mean[[x[ii]]], sigma = covariance[[x[ii]]])
  }))
  y_initial <- mvnfast::rmvn(T0, mu = mean[[1]], sigma = covariance[[1]])

  list(x = c(rep(1, T0), x), y = rbind(y_initial, y_stream))
}




#' Automatic Initialization and Fitting of a Partially Hidden Markov Model (pHMM)
#'
#' Fits a partially hidden Markov model (pHMM) to multivariate time series
#' observations \eqn{y} with partially observed states \eqn{x}, using the
#' constrained Baum-Welch algorithm. Unlike \code{\link{fit_pHMM}}, this function
#' does not require user-specified initial parameters. Instead, it implements
#' a customized initialization strategy designed for process monitoring with
#' highly imbalanced classes, as described in the supplementary material of
#' Capezza, Lepore, and Paynabar (2025).
#'
#' @details
#' The initialization procedure addresses the multimodality of the likelihood
#' and the sensitivity of the Baum-Welch algorithm to starting values:
#' \enumerate{
#'   \item The labelled states are renumbered internally by decreasing
#'         frequency, so that the most frequently labelled state is state 1.
#'         A one-state model (in-control process) is first fitted using robust
#'         estimators of location and scatter (\code{rrcov::CovRobust},
#'         falling back to \code{rrcov::CovMcd} if it fails).
#'   \item To introduce an additional state, candidate mean vectors are selected
#'         from observations that are least well represented by the current
#'         model. This is achieved by computing right-aligned moving averages
#'         of the data over window lengths \eqn{k = 1, \ldots, 9}, and then
#'         calculating the window-scaled (\eqn{k} times) squared Mahalanobis
#'         distances, based on the robust scatter estimate, of these smoothed
#'         points to the existing state means.
#'   \item The \code{ntry} time points with the largest minimum distances are
#'         retained; for each, the moving average over the window attaining the
#'         largest distance is a candidate initialization for the new state's
#'         mean.
#'   \item For each candidate, a pHMM is initialized with:
#'         \itemize{
#'           \item The IC mean set to the robust location estimate and the
#'                 other existing means set to their previous estimates.
#'           \item The new state's mean set to the candidate vector.
#'           \item A shared covariance matrix initialized at the robust
#'                 estimate from the in-control state.
#'           \item Initial state distribution \eqn{\pi} concentrated on the IC
#'                 state.
#'           \item Transition matrix with diagonal entries
#'                 \eqn{1 - 0.01 (N-1)} and off-diagonal entries \eqn{0.01}.
#'         }
#'   \item Each initialized model is fitted with the Baum-Welch algorithm, and
#'         the one achieving the highest log-likelihood is retained.
#'   \item States are added one at a time while the newest model has the
#'         lowest AIC (or while more states are needed to accommodate the
#'         labels in \code{xlabeled}), up to \code{max_nstates} states. The
#'         model with the lowest AIC is returned, with at least as many states
#'         as distinct labels in \code{xlabeled}, and its states are mapped
#'         back to the original labels.
#' }
#'
#' This strategy leverages prior process knowledge (dominant in-control regime)
#' and focuses the search on under-represented regions of the data space, which
#' improves convergence and reduces sensitivity to random initialization.
#'
#' @param y A numeric matrix of dimension \eqn{T \times d}, where each row
#'   corresponds to a \eqn{d}-dimensional observation at time \eqn{t}.
#' @param xlabeled An integer vector of length \eqn{T} with partially observed
#'   states. Known states must be integers in \eqn{1, \ldots, N}; unknown states
#'   should be coded as \code{NA}.
#' @param tol Convergence tolerance for log-likelihood and parameter change.
#'   Default is \code{1e-3}.
#' @param max_nstates Maximum number of hidden states to consider during the
#'   initialization procedure. Default is \code{5}; larger values are reset to
#'   5 with a warning. It is raised to the number of distinct labels in
#'   \code{xlabeled} and capped at that number plus the number of unlabelled
#'   observations.
#' @param ntry Number of candidate initializations for each new state. Default
#'   is \code{10}.
#'
#' @return A list with the same structure as returned by \code{\link{fit_pHMM}}:
#' \itemize{
#'   \item \code{y}, \code{xlabeled}: the input data.
#'   \item \code{log_lik}, \code{log_lik_vec}: final and trace of log-likelihood.
#'   \item \code{iter}: number of EM iterations performed.
#'   \item \code{logB}, \code{log_alpha}, \code{log_beta}, \code{log_gamma},
#'         \code{log_xi}: posterior quantities from the Baum-Welch algorithm.
#'   \item \code{logAhat}, \code{mean_hat}, \code{covariance_hat},
#'         \code{log_pi_hat}: estimated model parameters.
#'   \item \code{AIC}, \code{BIC}: information criteria for model selection.
#' }
#'
#' @examples
#' library(ActiveLearning4SPM)
#' set.seed(123)
#' dat <- simulate_stream(T0 = 100, TT = 500)
#' y <- dat$y
#' xlabeled <- dat$x
#' d <- ncol(dat$y)
#' xlabeled[sample(1:600, 300)] <- NA
#' obj <- fit_pHMM_auto(y = y,
#'                      xlabeled = xlabeled,
#'                      tol = 1e-3,
#'                      max_nstates = 5,
#'                      ntry = 10)
#' obj$AIC
#'
#' @references
#' Capezza, C., Lepore, A., & Paynabar, K. (2025).
#'   Stream-Based Active Learning for Process Monitoring.
#'   \emph{Technometrics}. <doi:10.1080/00401706.2025.2561744>.
#'
#' Supplementary Material, Section B: Initialization of the Partially Hidden
#' Markov Model. Available at
#' <https://doi.org/10.1080/00401706.2025.2561744>.
#'
#' @export
fit_pHMM_auto <- function(y,
                          xlabeled,
                          tol = 1e-3,
                          max_nstates = 5,
                          ntry = 10) {

  # The stage-wise search below supports at most five states.
  max_supported <- 5L
  if (max_nstates > max_supported) {
    warning("max_nstates cannot exceed 5; resetting to 5.")
    max_nstates <- max_supported
  }
  if (NROW(y) != length(xlabeled)) {
    stop("'xlabeled' must have one entry per row of 'y'.")
  }

  ## --- Relabel observed states by decreasing frequency ---
  # Model state ii corresponds to the original label states_observed[ii];
  # the mapping is undone on the returned model.
  xlabeled_old <- xlabeled
  states_observed <- as.numeric(names(sort(-table(xlabeled_old))))
  for (ii in seq_along(states_observed)) {
    xlabeled[which(xlabeled_old == states_observed[ii])] <- ii
  }

  nstates_min <- length(states_observed)
  if (nstates_min > max_supported) {
    stop("fit_pHMM_auto supports at most 5 states, but 'xlabeled' contains ",
         nstates_min, " distinct labels.")
  }
  max_nstates <- max(max_nstates, nstates_min)
  max_nstates <- min(max_nstates, nstates_min + sum(is.na(xlabeled)))

  ## --- Robust location/scatter estimate ---
  # The fallback must be the value returned by tryCatch(): an assignment inside
  # the error handler would only create a variable local to the handler.
  cov_robust <- tryCatch(
    rrcov::CovRobust(y),
    error = function(e) rrcov::CovMcd(y, alpha = 0.9)
  )
  center_1 <- cov_robust$center
  Sigma_hat_inv <- solve(cov_robust$cov)
  covariance_start <- list(cov_robust$cov)

  ## --- Moving averages of the data ---
  window_lengths <- 1:9
  ymat <- as.matrix(y)
  # One d x T matrix per window length k; for k > 1 the first k - 1 columns
  # are NA (right-aligned windows, endrule = "NA").
  ybar_k_list <- lapply(window_lengths, function(k) {
    if (k == 1) {
      t(ymat)
    } else {
      t(caTools::runmean(ymat, k = k, align = "right", endrule = "NA"))
    }
  })

  # T x 9 matrix whose column k holds k times the squared Mahalanobis distance
  # (robust scatter) between the window-k moving average and `center`.
  window_T2 <- function(center) {
    do.call(cbind, lapply(window_lengths, function(k) {
      dev <- ybar_k_list[[k]] - center
      k * colSums(dev * (Sigma_hat_inv %*% dev))
    }))
  }
  T2_1_mat <- window_T2(center_1)

  # AIC of a fitted (or placeholder) model; missing/NA values count as Inf so
  # that which.min() always returns a single index.
  aic_of <- function(m) {
    a <- m$AIC
    if (is.null(a) || is.na(a)) Inf else a
  }

  # Fit an s-state model starting from `prev` (the (s - 1)-state fit) plus a
  # new state initialised at each of the ntry most poorly represented windowed
  # points; return the fit with the highest log-likelihood, or a placeholder
  # with AIC = Inf if no candidate could be fitted.
  fit_stage <- function(s, prev) {
    if (is.null(prev$mean_hat)) {
      stop("Could not fit a ", s - 1, "-state model: no candidate ",
           "initialization could be fitted.")
    }
    # Distance of each windowed point to its closest existing state mean;
    # state 1 uses the robust centre, the others their fitted means.
    T2_list <- c(list(T2_1_mat), lapply(prev$mean_hat[-1], window_T2))
    T2_mat <- Reduce(pmin, T2_list)
    # Per time point: largest distance over the windows, and which window.
    T2 <- Rfast::rowMaxs(T2_mat, value = TRUE)
    T2_which_max <- Rfast::rowMaxs(T2_mat)
    order_T2 <- order(-T2)
    ntry_s <- min(sum(!is.na(T2)), ntry)

    mean_start <- prev$mean_hat
    mean_start[[1]] <- center_1
    A_start <- matrix(0.01, s, s)
    diag(A_start) <- 1 - 0.01 * (s - 1)
    xlabeled_s <- ifelse(xlabeled <= s, xlabeled, NA)

    best <- list(log_lik = -Inf, AIC = Inf)
    for (ii in seq_len(ntry_s)) {
      t_idx <- order_T2[ii]
      # The candidate is the moving average of the data itself, so it lives in
      # the same space as the other state means (a deviation from an existing
      # mean would only be a sensible location if that mean were ~0).
      mean_start[[s]] <- ybar_k_list[[T2_which_max[t_idx]]][, t_idx]
      fit <- fit_pHMM(y = y,
                      xlabeled = xlabeled_s,
                      nstates = s,
                      ppi_start = c(1, rep(0, s - 1)),
                      A_start = A_start,
                      mean_start = mean_start,
                      covariance_start = covariance_start,
                      equal_covariance = TRUE,
                      tol = tol)
      if (isTRUE(fit$log_lik > best$log_lik)) best <- fit
    }
    best
  }

  ## --- Stage-wise model search ---
  mods <- rep(list(list(log_lik = -Inf, AIC = Inf)), max_supported)

  # 1 state
  mods[[1]] <- fit_pHMM(y = y,
                        xlabeled = ifelse(xlabeled <= 1, xlabeled, NA),
                        nstates = 1,
                        ppi_start = 1,
                        A_start = matrix(1),
                        mean_start = list(center_1),
                        covariance_start = covariance_start,
                        equal_covariance = TRUE,
                        tol = tol)
  if (nstates_min > 1) mods[[1]]$AIC <- Inf
  nstates <- 1

  # 2, ..., 5 states: add a state while the newest model has the lowest AIC.
  for (s in 2:max_supported) {
    if (nstates == s - 1 && max_nstates >= s) {
      mods[[s]] <- fit_stage(s, mods[[s - 1]])
      nstates <- which.min(vapply(mods[seq_len(s)], aic_of, numeric(1)))
    }
    if (nstates_min > s) {
      # The observed labels need more than s states: discard this model and
      # force the search to continue.
      mods[[s]]$AIC <- Inf
      nstates <- s
    }
  }

  nstates <- max(which.min(vapply(mods, aic_of, numeric(1))), nstates_min)
  mod <- mods[[nstates]]
  if (is.null(mod$mean_hat)) {
    stop("No ", nstates, "-state model could be fitted.")
  }

  ## --- Map model states back to the original labels ---
  state_labels <- c(states_observed,
                    setdiff(seq_len(nstates), states_observed))
  # Position j of the output must hold the model state whose label is j, i.e.
  # the inverse of the relabelling permutation (not the permutation itself).
  states_selected <- match(seq_len(nstates), state_labels)
  if (length(state_labels) != nstates || anyNA(states_selected)) {
    stop("Labels in 'xlabeled' must be integers in 1, ..., ", nstates, ".")
  }

  mod$xlabeled <- xlabeled_old
  mod$logB <- mod$logB[, states_selected, drop = FALSE]
  mod$log_alpha <- mod$log_alpha[, states_selected, drop = FALSE]
  mod$log_beta <- mod$log_beta[, states_selected, drop = FALSE]
  mod$log_gamma <- mod$log_gamma[, states_selected, drop = FALSE]
  mod$log_xi <- mod$log_xi[, states_selected, states_selected, drop = FALSE]
  mod$logAhat <- mod$logAhat[states_selected, states_selected, drop = FALSE]
  mod$mean_hat <- mod$mean_hat[states_selected]
  mod$log_pi_hat <- mod$log_pi_hat[states_selected]

  mod
}










#' Fit a Partially Hidden Markov Model (pHMM)
#'
#' Fits a partially hidden Markov model (pHMM) to multivariate time series
#' observations \eqn{y} with partially observed process states \eqn{x}, using a
#' constrained Baum-Welch algorithm. The function allows the user to provide
#' custom initial parameters, and supports constraints on known means and/or
#' covariances, as well as equal or diagonal covariance structures.
#'
#' @param y A numeric matrix of dimension \eqn{T \times d}, where each row
#'   corresponds to a \eqn{d}-dimensional observation at time \eqn{t}.
#' @param xlabeled An integer vector of length \eqn{T} with partially observed
#'   states. Known states must be integers in \eqn{1, \ldots, N}; unknown states
#'   should be coded as \code{NA}.
#' @param nstates Integer. The total number of hidden states to fit.
#' @param ppi_start Numeric vector of length \code{nstates} giving the initial
#'   state distribution. If \code{NULL}, defaults to \code{c(1,0,...,0)}.
#' @param A_start Numeric \code{nstates} \eqn{\times} \code{nstates} transition probability
#'   matrix. If \code{NULL}, defaults to a transition matrix with diagonal
#'   entries equal to \code{1-0.01*(nstates-1)} and all off-diagonal entries equal to \code{0.01}.
#' @param mean_start List of length \code{nstates} containing numeric mean
#'   vectors for the emission distributions.
#' @param covariance_start List of covariance matrices for the emission
#'   distributions. Must be of length \code{nstates}, unless
#'   \code{equal_covariance = TRUE}, in which case it must be of length 1.
#'   If \code{NULL}, defaults to identity matrices.
#' @param known_mean Optional list of known mean vectors. Use \code{NA} for
#'   unknown elements.
#' @param known_covariance Optional list of known covariance matrices. Use
#'   \code{NA} for unknown elements.
#' @param equal_covariance Logical. If \code{TRUE}, all states are constrained
#'   to share a common covariance matrix.
#' @param covariance_structure Character string specifying the covariance
#'   structure. Either \code{"full"} (default) or \code{"diagonal"}.
#' @param max_iter Maximum number of EM iterations. Default is 200.
#' @param tol Convergence tolerance for log-likelihood and parameter change.
#'   Default is \code{1e-3}.
#' @param verbose Logical. If \code{TRUE}, prints log-likelihood progress at
#'   each iteration.
#'
#' @return A list with components:
#' \itemize{
#'   \item \code{y}, \code{xlabeled}: the input data.
#'   \item \code{log_lik}, \code{log_lik_vec}: final and trace of log-likelihood.
#'   \item \code{iter}: number of EM iterations performed.
#'   \item \code{logB}, \code{log_alpha}, \code{log_beta}, \code{log_gamma},
#'         \code{log_xi}: posterior quantities from the Baum-Welch algorithm.
#'   \item \code{logAhat}, \code{mean_hat}, \code{covariance_hat},
#'         \code{log_pi_hat}: estimated model parameters.
#'   \item \code{AIC}, \code{BIC}: information criteria for model selection.
#' }
#'
#' @examples
#' library(ActiveLearning4SPM)
#' set.seed(123)
#' dat <- simulate_stream(T0 = 100, TT = 500)
#' y <- dat$y
#' xlabeled <- dat$x
#' d <- ncol(dat$y)
#' xlabeled[sample(1:600, 300)] <- NA
#' out <- fit_pHMM(y = y,
#'                 xlabeled = xlabeled,
#'                 nstates = 3,
#'                 mean_start = list(rep(0, d), rep(1, d), rep(-1, d)),
#'                 equal_covariance = TRUE)
#' out$AIC
#'
#' @references Capezza, C., Lepore, A., & Paynabar, K. (2025).
#'   Stream-Based Active Learning for Process Monitoring.
#'   \emph{Technometrics}. <doi:10.1080/00401706.2025.2561744>.
#'
#' @export
fit_pHMM <- function(y,
                     xlabeled,
                     nstates,
                     ppi_start = NULL,
                     A_start = NULL,
                     mean_start,
                     covariance_start = NULL,
                     known_mean = NULL,
                     known_covariance = NULL,
                     equal_covariance = FALSE,
                     covariance_structure = "full",
                     max_iter = 200,
                     tol = 1e-3,
                     verbose = FALSE) {

  if (!is.null(known_mean)) {
    which_known_mean <- which(!is.na(known_mean))
  }
  if (!is.null(known_covariance)) {
    which_known_covariance <- which(!is.na(known_covariance))
  }

  if (is.null(covariance_start )) {
    if (equal_covariance) {
      covariance_start <- list(diag(ncol(y)))
    } else {
      covariance_start <- rep(list(diag(ncol(y))), nstates)
    }
  }

  if (is.null(ppi_start)) {
    ppi_start <- c(1, rep(0, nstates - 1))
  }

  if (is.null(A_start)) {
    A_start <- matrix(0.01, nstates, nstates)
    diag(A_start) <- 1 - 0.01 * (nstates - 1)
  }

  if (equal_covariance) {
    if (length(covariance_start) != 1) {
      stop("If equal_covariance is TRUE, covariance_start must have length 1")
    }
  }

  if (!equal_covariance & is.null(known_covariance)) {
    if (length(covariance_start) != nstates) {
      stop("covariance_start must be a list with length equal to nstates")
    }
  }

  if (covariance_structure == "diagonal") {
    for (jj in 1:length(covariance_start)) {
      if (sum(abs(covariance_start[[jj]] - diag(diag(covariance_start[[jj]])))) > 0) {
        stop("If covariance_structure is diagonal, covariance_start must contain only diagonal matrices")
      }
    }
  }

  states <- sort(unique(xlabeled[!is.na(xlabeled)]))
  if (!identical(states, 1:nstates)) {
    actual_states <- states
    actual_xlabeled <- xlabeled
    states <- seq_along(states)
    for (ss in 1:nstates) {
      xlabeled[xlabeled == actual_states[ss]] <- ss
    }
  } else {
    actual_states <- states
  }

  d <- ncol(y)
  TT <- length(xlabeled)

  log_ppi_hat <- log(ppi_start)
  logAhat <- log(A_start)
  mean_hat <- mean_start
  if (equal_covariance) {
    covariance_hat <- lapply(1:nstates, function(jj) covariance_start[[1]])
  } else {
    covariance_hat <- covariance_start
  }
  if (!is.null(known_mean)) {
    mean_hat[which_known_mean] <- known_mean[which_known_mean]
  }
  if (!is.null(known_covariance)) {
    covariance_hat[which_known_covariance] <- known_covariance[which_known_covariance]
  }

  log_lik <- -Inf
  log_lik_vec <- log_lik
  loglik_change <- 1e5
  par_change <- 1e5
  iter <- 0

  while (iter < max_iter & abs(loglik_change) > tol & par_change > tol) {
    iter <- iter + 1
    logB <- calculate_logB(y = y, mean = mean_hat, covariance = covariance_hat)
    log_alpha <- log_forward_algorithm_cpp(log_ppi = log_ppi_hat, logA = logAhat, logB = logB, xlabeled = xlabeled)
    log_beta <- log_backward_algorithm_cpp(logA = logAhat, logB = logB, xlabeled = xlabeled)
    log_ab <- log_alpha + log_beta

    log_rowsums_log_ab <- log_rowSums(log_ab)
    log_gamma <- log_ab - log_rowsums_log_ab
    gamma <- exp(log_gamma)

    log_xi <- log_xi_cpp(log_alpha = log_alpha, logAhat = logAhat,
                         log_beta = log_beta, logB = logB)
    log_xi_mat <- matrix(log_xi, nrow = dim(log_xi)[1])

    colSums_gamma <- colSums(gamma)

    log_colSums_gamma <- as.numeric(log_colSums_cpp(log_gamma))
    log_colSums_log_xi_mat <- log_colSums_cpp(log_xi_mat)
    logAhat <- matrix(as.numeric(log_colSums_log_xi_mat) - log_colSums_gamma, nrow = nrow(logAhat))
    logAhat <- logAhat - as.numeric(log_rowSums_cpp(logAhat))
    for (ii in 1:nstates) {
      if (sum(is.nan(logAhat[ii, ])) == nstates) {
        logAhat[ii, ] <- rep(- log(nstates), nstates)
      }
    }

    log_gamma_calc <- log_gamma
    y_calc <- y


    ww <- exp(t(t(log_gamma_calc) - log_colSums_gamma))

    mean_old <- unlist(mean_hat)
    covariance_old <- unlist(covariance_hat)

    mean_hat <- crossprod(ww, y_calc)
    mean_hat <- lapply(1:nstates, function(ii) mean_hat[ii, ])

    for (ii in 1:nstates) {
      wii <- ww[, ii]
      ycen_ii <- y_calc - t(matrix(mean_hat[[ii]], nrow = d, ncol = nrow(y_calc)))
      covariance_hat[[ii]] <- crossprod(ycen_ii * sqrt(wii))
      if (covariance_structure == "diagonal") {
        covariance_hat[[ii]] <- diag(diag(covariance_hat[[ii]]))
      }
    }
    if (!is.null(known_mean)) {
      mean_hat[which_known_mean] <- known_mean[which_known_mean]
    }
    if (equal_covariance) {
      props <- colSums(exp(log_gamma_calc)) / nrow(log_gamma_calc)
      covariance_weighted <- Reduce("+", lapply(1:nstates, function(jj) covariance_hat[[jj]] * props[jj]))
      covariance_hat <- lapply(1:nstates, function(jj) covariance_weighted)
    }
    if (!is.null(known_covariance)) {
      covariance_hat[which_known_covariance] <- known_covariance[which_known_covariance]
    }
    log_ppi_hat <- log_gamma[1, ]
    new_logLik <- log_sum_vec(log_alpha[TT,])

    mean_new <- unlist(mean_hat)
    covariance_new <- unlist(covariance_hat)

    mean_change <- sqrt(sum((mean_new - mean_old)^2)) / sqrt(sum(mean_old^2))
    cov_change <- sqrt(sum((covariance_new - covariance_old)^2)) / sqrt(sum(covariance_old^2))
    par_change <- max(mean_change, cov_change)

    loglik_change <- new_logLik - log_lik
    log_lik <- new_logLik
    log_lik_vec <- c(log_lik_vec, log_lik)
    if (verbose) {
      message(
        sprintf("Iter %d: logLik = %.4f | logLik diff. = %.4e | par diff. = %.4e",
                iter, log_lik, loglik_change, par_change)
      )
    }
  }

  logB <- calculate_logB(y = y, mean = mean_hat, covariance = covariance_hat)
  log_alpha <- log_forward_algorithm_cpp(log_ppi = log_ppi_hat, logA = logAhat, logB = logB, xlabeled = xlabeled)
  log_beta <- log_backward_algorithm_cpp(logA = logAhat, logB = logB, xlabeled = xlabeled)
  log_ab <- log_alpha + log_beta
  log_rowsums_log_ab <- log_rowSums(log_ab)
  log_gamma <- log_ab - log_rowsums_log_ab
  log_xi <- log_xi_cpp(log_alpha = log_alpha, logAhat = logAhat,
                       log_beta = log_beta, logB = logB)
  log_xi_mat <- matrix(log_xi, nrow = dim(log_xi)[1])
  new_logLik <- log_sum_vec(log_alpha[TT,])
  log_lik <- new_logLik
  log_lik_vec <- c(log_lik_vec, log_lik)


  ll <- new_logLik

  npar <-
    nstates - 1 + # pi
    nstates * (nstates - 1) + # A
    nstates * d # mu

  if (is.null(known_covariance) & covariance_structure == "full" & !equal_covariance) {
    npar <- npar +
      nstates * d * (d + 1) / 2 # Sigma
  }

  if (!is.null(known_covariance) & covariance_structure == "full" & !equal_covariance) {
    how_many_unknown_covariance <- nstates - sum(which_known_covariance)
    npar <- npar +
      how_many_unknown_covariance * d * (d + 1) / 2 # Sigma
  }

  if (is.null(known_covariance) & covariance_structure == "diagonal" & !equal_covariance) {
    npar <- npar +
      nstates * d # diagonal Sigma
  }

  if (!is.null(known_covariance) & covariance_structure == "diagonal" & !equal_covariance) {
    how_many_unknown_covariance <- nstates - sum(which_known_covariance)
    npar <- npar +
      how_many_unknown_covariance * d # diagonal Sigma
  }

  if (is.null(known_covariance) & covariance_structure == "full" & equal_covariance) {
    npar <- npar +
      d * (d + 1) / 2 # equal covariance
  }

  if (!is.null(known_covariance) & covariance_structure == "full" & equal_covariance) {
    how_many_unknown_covariance <- nstates - sum(which_known_covariance)
    npar <- npar +
      0 # it should return an error if you provide more than one covariance if equal_covariance is TRUE
  }

  if (is.null(known_covariance) & covariance_structure == "diagonal" & equal_covariance) {
    npar <- npar +
      d # equal diagonal covariance
  }

  if (!is.null(known_covariance) & covariance_structure == "diagonal" & equal_covariance) {
    how_many_unknown_covariance <- nstates - sum(which_known_covariance)
    npar <- npar +
      0 # it should return an error if you provide more than one covariance if equal_covariance is TRUE
  }

  AIC <- -2 * ll + 2*npar
  BIC <- -2 * ll + log(length(xlabeled))*npar

  if (!identical(actual_states, 1:nstates)) {
    xlabeled <- actual_xlabeled
  }

  out <- list(y = y,
              xlabeled = xlabeled,
              log_lik = log_lik,
              log_lik_vec = log_lik_vec,
              iter = iter,
              logB = logB,
              log_alpha = log_alpha,
              log_beta = log_beta,
              log_gamma = log_gamma,
              log_xi = log_xi,
              logAhat = logAhat,
              mean_hat = mean_hat,
              covariance_hat = covariance_hat,
              log_pi_hat = log_ppi_hat,
              AIC = AIC,
              BIC = BIC)
  return(out)

}


#' Stream-Based Active Learning with a Partially Hidden Markov Model (pHMM)
#'
#' Implements the stream-based active learning strategy of Capezza, Lepore, and
#' Paynabar (2025) for process monitoring with partially observed states. At
#' each time step, the method fits a pHMM to the available data and balances
#' \emph{exploitation} (reducing predictive uncertainty) against
#' \emph{exploration} (detecting potential out-of-control shifts) to decide
#' whether to request the true label of the current observation. Labeling
#' requests are constrained by a user-defined budget.
#'
#' The exploitation criterion is based on the entropy of the state sequence,
#' while the exploration criterion uses a multivariate exponentially weighted
#' moving average (MEWMA) statistic. The two criteria are combined with a
#' user-defined weighting. A label is requested only while the number of
#' labels acquired so far is below \code{B} times the number of observations
#' (after the first \code{T0}) already processed, which paces labeling over
#' the stream; no labels are requested once the budget is exhausted.
#'
#' @param y A numeric matrix of dimension \eqn{T \times d} (or an object
#'   coercible with \code{as.matrix()}), where each row corresponds to a
#'   \eqn{d}-dimensional observation at time \eqn{t}.
#' @param true_x Integer vector of true states of length \code{nrow(y)}. It is
#'   used as the labeling oracle when a label is requested and to assess model
#'   predictions. The first \code{T0} values, assumed to be from an
#'   in-control process, must be 1.
#' @param T0 Integer, \code{1 <= T0 < nrow(y)}. Number of initial observations
#'   assumed to be labeled as in-control (state 1).
#' @param B Numeric between 0 and 1. Labeling budget, expressed as the maximum
#'   fraction of observations (after the first \code{T0}) for which labels may
#'   be acquired. Default is \code{0.1}.
#' @param weight_exploration Numeric between 0 and 1. Weight assigned to the
#'   exploration criterion. The exploitation weight is computed as
#'   \code{1 - weight_exploration}. Default is \code{0.5}.
#' @param lambda_MEWMA Numeric in (0, 1]. Smoothing parameter for the MEWMA
#'   statistic used in the exploration criterion. Default is \code{0.3}.
#' @param verbose Logical. If \code{TRUE}, emits a message at each time step
#'   with the time index, the remaining labeling budget, the exploration and
#'   exploitation p-values, the true state and the decision taken. Default is
#'   \code{FALSE}.
#'
#' @return A list with components:
#' \itemize{
#'   \item \code{decision}: character vector indicating the action taken at
#'         each time (\code{"label_exploitation"}, \code{"label_exploration"},
#'         or the predicted state). The first \code{T0} entries are \code{"1"}.
#'   \item \code{xlabeled}: updated state sequence including acquired labels
#'         (\code{NA} where no label was acquired).
#'   \item \code{xhat}: final predicted state sequence, equal to the acquired
#'         label wherever one is available.
#'   \item \code{scores}: classification performance metrics (accuracy,
#'         precision, recall, F1, AUC) computed against the true states.
#' }
#'
#' @examples
#' \donttest{
#' library(ActiveLearning4SPM)
#' set.seed(123)
#' dat <- simulate_stream(T0 = 50, TT = 100, T_min_IC = 20, T_max_IC = 30)
#' out <- active_learning_pHMM(y = dat$y,
#'                             true_x = dat$x,
#'                             T0 = 50,
#'                             B = 0.1)
#' table(out$decision)
#' out$scores$f1
#' }
#'
#' @note This function is intended for simulation studies where the entire
#' observation sequence \code{y} and the corresponding true states
#' \code{true_x} are available in advance. The function uses these to evaluate
#' the active learning strategy under a given budget. In real-time
#' applications, data and labels would arrive sequentially, and labels would
#' only be obtained if requested by the strategy.
#'
#' @references
#' Capezza, C., Lepore, A., & Paynabar, K. (2025).
#'   Stream-Based Active Learning for Process Monitoring.
#'   \emph{Technometrics}. <doi:10.1080/00401706.2025.2561744>.
#'
#' @export
active_learning_pHMM <- function(y,
                                 true_x,
                                 T0,
                                 B = 0.1,
                                 weight_exploration = 0.5,
                                 lambda_MEWMA = 0.3,
                                 verbose = FALSE) {

  y <- as.matrix(y)
  TT_tot <- nrow(y)
  d <- ncol(y)

  # Validate inputs up front: T0 >= nrow(y) would make the main loop index
  # past the data, and T0 = 0 would silently label observation 1 via 1:0.
  if (length(true_x) != TT_tot) {
    stop("length(true_x) must be equal to nrow(y).", call. = FALSE)
  }
  if (length(T0) != 1 || is.na(T0) || T0 != round(T0) ||
      T0 < 1 || T0 >= TT_tot) {
    stop("T0 must be a single integer with 1 <= T0 < nrow(y).", call. = FALSE)
  }
  if (length(B) != 1 || is.na(B) || B < 0 || B > 1) {
    stop("B must be a single number between 0 and 1.", call. = FALSE)
  }
  if (length(weight_exploration) != 1 || is.na(weight_exploration) ||
      weight_exploration < 0 || weight_exploration > 1) {
    stop("weight_exploration must be a single number between 0 and 1.",
         call. = FALSE)
  }
  if (length(lambda_MEWMA) != 1 || is.na(lambda_MEWMA) ||
      lambda_MEWMA <= 0 || lambda_MEWMA > 1) {
    stop("lambda_MEWMA must be a single number in (0, 1].", call. = FALSE)
  }

  xlabeled <- rep(NA_real_, TT_tot)
  xlabeled[seq_len(T0)] <- 1
  decision <- rep(NA_character_, TT_tot)
  decision[seq_len(T0)] <- "1"
  post_T0 <- -seq_len(T0) # negative index selecting observations after T0

  for (tt in (T0 + 1):TT_tot) {

    # The model is fitted with states labelled 1..K without gaps. If the
    # labels acquired so far are, e.g., {1, 3}, relabel them temporarily as
    # {1, 2}; the original labels are restored before the labeling decision.
    states_observed <- sort(unique(xlabeled[!is.na(xlabeled)]))
    nstates_observed <- length(states_observed)
    relabeled <- !are_equal(states_observed, seq_len(nstates_observed))
    if (relabeled) {
      actual_xlabeled <- xlabeled
      for (ss in seq_len(nstates_observed)) {
        xlabeled[xlabeled == states_observed[ss]] <- ss
      }
    }

    mod_proposed <- fit_pHMM_auto(y = y[1:tt, , drop = FALSE],
                                  xlabeled = xlabeled[1:tt])
    nstates <- length(mod_proposed$mean_hat)

    # Map the fitted model's state indices to the labels reported in
    # `decision`.
    if (nstates == 1) {
      chosen_states <- 1
    }
    if (nstates == 2) {
      if (nstates_observed == 1) {
        chosen_states <- 1:2
      }
      if (nstates_observed == 2) {
        chosen_states <- states_observed
      }
    }
    if (nstates >= 3) {
      chosen_states <- 1:nstates
    }

    # Remaining budget, spread over the observations still to be processed.
    available_budget <- B * (TT_tot - T0) - sum(!is.na(xlabeled[post_T0]))
    labellable_samples <- TT_tot - tt + 1
    this_B <- max(min(available_budget / labellable_samples, 1), 0)

    # Exploitation: Monte Carlo p-value of the entropy of the current state,
    # against entropies of sequences simulated from the fitted model. The
    # number of simulations grows as the per-step budget shrinks, in [20, 200].
    nsim_sequences <- min(200, max(20, 4 / this_B))
    which_labeled <- which(!is.na(xlabeled))
    which_labeled_before <- which_labeled[which_labeled < tt]
    if (length(which_labeled_before) == 0) {
      last_labeled <- 1
      log_ppi <- mod_proposed$log_pi_hat
    } else {
      # Start the window at the last labeled observation, whose state is known.
      last_labeled <- max(which_labeled_before)
      log_ppi <- rep(-Inf, nstates)
      log_ppi[xlabeled[last_labeled]] <- 0
    }
    window_idx <- last_labeled:tt
    nnn <- length(window_idx)
    x_window <- xlabeled[window_idx]

    entropy_sim <- numeric(nsim_sequences)
    for (nn in seq_along(entropy_sim)) {
      dat_sim <- simulate_hmm(nnn,
                              log_ppi = log_ppi,
                              logA = mod_proposed$logAhat,
                              mean = mod_proposed$mean_hat,
                              covariance = mod_proposed$covariance_hat,
                              xlabeled = x_window,
                              n_last_y = nnn)
      entropy_sim[nn] <- entropy_of_sequence_hmm(mod_proposed,
                                                 xlabeled = x_window,
                                                 y = dat_sim$y,
                                                 window = nnn:nnn)
    }
    entropy_obs <- entropy_of_sequence_hmm(mod_proposed,
                                           xlabeled = x_window,
                                           y = y[window_idx, , drop = FALSE],
                                           window = nnn:nnn)
    p_value_entropy <- mean(entropy_obs < entropy_sim)

    # Exploration: MEWMA statistic compared with a chi-squared(d) reference.
    V2n_min <- min_ewma_cpp(y = y,
                            mean_list = mod_proposed$mean_hat,
                            cov_list = mod_proposed$covariance_hat,
                            TT = TT_tot,
                            p = d,
                            lambda = lambda_MEWMA,
                            ww_end = tt)
    p_value_exploration <- stats::pchisq(V2n_min, df = d, lower.tail = FALSE)

    # With a single fitted state there is no state uncertainty to exploit.
    if (nstates == 1) {
      p_value_entropy <- 1
    }

    if (relabeled) {
      xlabeled <- actual_xlabeled
    }

    Bentropy <- (1 - weight_exploration) * this_B
    Bexploration <- weight_exploration * this_B
    # this_B > B holds exactly when the labels acquired so far are fewer than
    # B times the post-T0 observations already processed (budget pacing).
    can_label <- this_B > B

    if (can_label && p_value_entropy < Bentropy) {
      decision[tt] <- "label_exploitation"
      xlabeled[tt] <- true_x[tt]
    } else if (can_label && p_value_exploration < Bexploration) {
      decision[tt] <- "label_exploration"
      xlabeled[tt] <- true_x[tt]
    } else {
      # Predict the most probable state at time tt under the fitted model.
      last_row <- mod_proposed$log_gamma[nrow(mod_proposed$log_gamma), ]
      decision[tt] <- chosen_states[which.max(last_row)]
    }

    if (verbose) {
      message(sprintf(
        "t=%d | Available labels: %d | Explor. p-value = %.3f | Exploit. p-value = %.3f | True state = %d | Decision = %s",
        tt,
        round(available_budget),
        p_value_exploration,
        p_value_entropy,
        true_x[tt],
        decision[tt]
      ))
    }
  }

  # Predictions come from non-label decisions; acquired labels override them.
  xhat <- rep(NA_real_, TT_tot)
  is_prediction <- !startsWith(decision, "label_")
  xhat[is_prediction] <- as.numeric(decision[is_prediction])
  xhat[!is.na(xlabeled)] <- xlabeled[!is.na(xlabeled)]

  scores <- get_classification_scores(xhat[post_T0], true_x[post_T0])

  list(
    decision = decision,
    xlabeled = xlabeled,
    xhat = xhat,
    scores = scores
  )
}


# Numerically stable log(sum(exp(log_vec))) (log-sum-exp) of a vector of
# log-values, computed by shifting by the maximum.
#
# When the maximum is not finite the shift is unusable (-Inf - -Inf and
# Inf - Inf are NaN), and the answer is the maximum itself:
#   * all -Inf       -> -Inf (log of a zero sum),
#   * any +Inf       -> +Inf,
#   * NA/NaN maximum -> propagated.
log_sum_vec <- function(log_vec) {
  max_log_vec <- max(log_vec)
  if (!is.finite(max_log_vec)) {
    return(max_log_vec)
  }
  log(sum(exp(log_vec - max_log_vec))) + max_log_vec
}

# Row-wise log-sum-exp of a matrix of log-values: returns, for each row i,
# log(sum(exp(log_mat[i, ]))).
#
# Each row is shifted by its maximum for numerical stability. A non-finite
# maximum is replaced by a shift of 0, because subtracting -Inf from itself
# gives NaN. This way a row that is entirely -Inf yields -Inf, consistent
# with log_sum_vec().
log_rowSums <- function(log_mat) {
  row_maxs <- Rfast::rowMaxs(log_mat, value = TRUE)
  shift <- ifelse(is.finite(row_maxs), row_maxs, 0)
  # `log_mat - shift` recycles `shift` down the columns, i.e. row i minus shift[i].
  log(Rfast::rowsums(exp(log_mat - shift))) + shift
}

# Column-wise log-sum-exp of a matrix of log-values: returns, for each column
# j, log(sum(exp(log_mat[, j]))).
#
# Each column is shifted by its maximum for numerical stability. A non-finite
# maximum is replaced by a shift of 0, because subtracting -Inf from itself
# gives NaN. This way a column that is entirely -Inf yields -Inf, consistent
# with log_sum_vec().
log_colSums <- function(log_mat) {
  col_maxs <- Rfast::colMaxs(log_mat, value = TRUE)
  shift <- ifelse(is.finite(col_maxs), col_maxs, 0)
  # Work on the transpose so that `shift` recycles along rows of t(log_mat).
  log(Rfast::rowsums(exp(t(log_mat) - shift))) + shift
}

# Matrix of Gaussian log-densities: entry [t, k] is the log-density of
# observation y[t, ] under state k, with mean mean[[k]] and covariance
# covariance[[k]]. It has one column per element of `mean`.
#
# `mean` must be a non-empty list, and `covariance` must have at least as many
# elements. Without this check an empty `mean` would make 1:0 iterate over
# c(1, 0) and fail with an opaque subscript error.
calculate_logB <- function(y, mean, covariance) {
  nstates <- length(mean)
  if (nstates == 0 || length(covariance) < nstates) {
    stop("mean and covariance must be non-empty lists with one element per state.",
         call. = FALSE)
  }
  logBlist <- lapply(seq_len(nstates), function(ii) {
    mvnfast::dmvn(y, mu = mean[[ii]], sigma = covariance[[ii]], log = TRUE)
  })
  do.call(cbind, logBlist)
}

# Classification metrics of predicted states `xhat` against true states `x`.
#
# All states above 1 are collapsed into a single positive (out-of-control)
# class 2 before scoring. Returns a list with accuracy, precision, recall, F1
# and AUC. With two classes these are scalars, with class 2 as the positive
# class. With more than two classes they are one-vs-rest vectors, one entry
# per observed class in sorted order. With a single class, the metrics other
# than accuracy are left at 0.
get_classification_scores <- function(xhat, x) {

  xhat[xhat > 1] <- 2
  x[x > 1] <- 2

  classes <- sort(unique(c(x, xhat)))
  nclasses <- length(classes)

  # AUC of hard-label predictions. lv[1] is the control level, and the
  # direction is fixed ("<": cases have larger predictor values). As a result,
  # a worse-than-random classifier gets AUC < 0.5 instead of being flipped by
  # pROC's default direction = "auto". AUC is undefined (NA) when `response`
  # does not contain both levels; pROC::roc would error in that case.
  auc_or_na <- function(response, predictor, lv) {
    if (!all(lv %in% response)) {
      return(NA_real_)
    }
    roc_obj <- pROC::roc(response, predictor, levels = lv,
                         direction = "<", quiet = TRUE)
    as.numeric(pROC::auc(roc_obj))
  }

  accuracy <- mean(xhat == x) # (TP + TN) / ALL

  precision <- recall <- f1 <- auc <- numeric(nclasses)

  if (nclasses > 2) {
    for (jj in seq_len(nclasses)) {
      cl <- classes[jj]
      precision[jj] <- mean(x[xhat == cl] == cl) # TP / PREDICTED_P
      recall[jj] <- mean(xhat[x == cl] == cl) # TP / ALL_P
      # Harmonic mean of precision and recall. The values are combined with
      # c() because the second argument of mean() is `trim`.
      f1[jj] <- 1 / mean(c(1 / precision[jj], 1 / recall[jj]))
      xjj <- ifelse(x == cl, 1, 2)
      xjjhat <- ifelse(xhat == cl, 1, 2)
      auc[jj] <- auc_or_na(xjj, xjjhat, c(1, 2))
    }
  }

  if (nclasses == 2) { # assuming 2 is minority, positive class
    precision <- mean(x[xhat == 2] == 2) # TP / PREDICTED_P
    recall <- mean(xhat[x == 2] == 2) # TP / ALL_P
    f1 <- 1 / mean(c(1 / precision, 1 / recall))
    auc <- auc_or_na(x, as.numeric(xhat), classes)
  }

  list(accuracy = accuracy,
       precision = precision,
       recall = recall,
       f1 = f1,
       auc = auc)
}

# Exact equality of two vectors after coercing both to double. Integer and
# double representations of the same values (e.g. 1:2 and c(1, 2)) therefore
# compare equal. The coercion also drops names and other attributes.
are_equal <- function(vec1, vec2) {
  lhs <- as.numeric(vec1)
  rhs <- as.numeric(vec2)
  identical(lhs, rhs)
}

# Simulate hidden states and observations from a Gaussian HMM, conditioning
# the simulated state path on the known labels in `xlabeled`.
#
# Arguments:
#   n                 number of time points; must equal length(xlabeled).
#   log_ppi           log initial state distribution (its length sets the
#                     number of states).
#   logA              log transition matrix; row = from-state, column =
#                     to-state (as used by the forward recursion below).
#   mean, covariance  lists with one mean vector / covariance matrix per state.
#   xlabeled          known states (NA where unknown).
#   n_last_y          observations are drawn only for the last n_last_y time
#                     points.
#
# Returns list(x = simulated state path, y = matrix of simulated observations,
# one row per drawn time point).
#
# The path is simulated segment by segment between consecutive cut points
# (first time point, labeled time points, last time point):
#   * segment whose final state is unknown -> forward simulation from an
#     initial distribution (via simulate_state_sequence(), assumed to return a
#     path of the requested length — verify against its definition);
#   * segment whose final state is known -> each interior state is drawn from
#     its conditional distribution given the previous state and the known
#     final state.
simulate_hmm <- function(n, log_ppi, logA, mean, covariance, xlabeled, n_last_y = n) {

  if (n != length(xlabeled)) {
    stop("length of xlabeled must be equal to n")
  }

  which_label <- which(!is.na(xlabeled))
  nstates <- length(log_ppi)
  TT <- length(xlabeled)
  sequence_list <- list()
  cut_points <- unique(c(1, which_label, TT))

  xsim <- rep(NA, TT)
  xsim[which_label] <- xlabeled[which_label]

  if (n == 1) {
    if (!is.na(xlabeled)) {
      xsim <- xlabeled
    } else {
      xsim <- sample(1:nstates, 1, prob = exp(log_ppi))
      ysim <- mvnfast::rmvn(1, mu = mean[[xsim]], sigma = covariance[[xsim]])
      ret <- list(x = xsim, y = ysim)
      return(ret)
    }
  }

  A <- exp(logA)

  # seq_len(): with a single cut point (n == 1 and a known label) there is no
  # segment to simulate, whereas 1:0 would iterate over c(1, 0) and index past
  # cut_points.
  for (ii_known in seq_len(length(cut_points) - 1)) {
    sequence_list[[ii_known]] <- cut_points[ii_known]:cut_points[ii_known + 1]
    start_ii <- sequence_list[[ii_known]][1]
    end_ii <- max(sequence_list[[ii_known]])
    TT_ii <- length(sequence_list[[ii_known]])

    if (is.na(xlabeled[end_ii])) { # Cases 1 and 2 (final x unknown)
      if (is.na(xlabeled[start_ii])) {
        # Case 1 (first x unknown too): start from the initial distribution.
        ppi <- exp(log_ppi)
      } else {
        # Case 2: start from the known first state.
        ppi <- rep(0, length(log_ppi))
        ppi[xlabeled[start_ii]] <- 1
      }
      # A single forward simulation of the whole segment.
      xsim[start_ii:end_ii] <- simulate_state_sequence(ppi, TT_ii, A)
    } else { # Cases 3 and 4 (final x known)

      if (is.na(xlabeled[start_ii])) { # Case 4 (first x unknown)
        # Draw the first state from p(x_start | x_end known), combining the
        # prior propagated forward (alpha) with P(x_end | x_t) (beta).
        log_alpha_ii <- matrix(NA, nrow = TT_ii, ncol = nstates)
        log_beta_ii <- matrix(NA, nrow = TT_ii, ncol = nstates)
        log_gamma_ii <- matrix(NA, nrow = TT_ii, ncol = nstates)

        log_alpha_ii[1, ] <- log_ppi
        for (tt in 2:TT_ii) {
          log_alpha_ii[tt, ] <- log_colSums_cpp(log_alpha_ii[tt - 1, ] + logA)
        }
        log_beta_ii[TT_ii - 1, ] <- logA[, xlabeled[end_ii]]
        if (TT_ii > 2) {
          for (tt in (TT_ii - 2):1) {
            log_beta_ii[tt, ] <- log_colSums_cpp(t(logA) + log_beta_ii[tt + 1, ])
          }
        }
        log_gamma_ii <- log_alpha_ii + log_beta_ii
        log_gamma_ii <- log_gamma_ii - as.numeric(log_rowSums_cpp(log_gamma_ii))

        xsim[start_ii] <- sample(1:nstates, 1, prob = exp(log_gamma_ii[1, ]))
      }

      # if TT_ii == 2 there are no unknown states
      if (TT_ii > 2) {
        # Case 3 (first and last x known): draw each interior state from its
        # conditional distribution given the previously simulated state and
        # the known final state.
        for (ttii in 1:(TT_ii - 2)) {
          current_tt <- sequence_list[[ii_known]][ttii]
          current_length <- length(ttii:TT_ii)
          log_alpha_ii <- matrix(NA, nrow = current_length, ncol = nstates)
          log_beta_ii <- matrix(NA, nrow = current_length, ncol = nstates)
          log_gamma_ii <- matrix(NA, nrow = current_length, ncol = nstates)

          log_alpha_ii[1, ] <- -Inf
          log_alpha_ii[1, xsim[current_tt]] <- 0
          for (ttkk in 2:nrow(log_alpha_ii)) {
            log_alpha_ii[ttkk, ] <- log_colSums_cpp(log_alpha_ii[ttkk - 1, ] + logA)
          }
          log_beta_ii[current_length - 1, ] <- logA[, xlabeled[end_ii]]
          for (tt in (current_length - 2):1) {
            log_beta_ii[tt, ] <- log_colSums_cpp(t(logA) + log_beta_ii[tt + 1, ])
          }
          log_gamma_ii <- log_alpha_ii + log_beta_ii
          log_gamma_ii <- log_gamma_ii - as.numeric(log_rowSums_cpp(log_gamma_ii))

          xsim[current_tt + 1] <- sample(1:nstates, 1, prob = exp(log_gamma_ii[2, ]))
        }
      }
    }
  }

  # Observations are drawn only for the last n_last_y time points.
  idx <- (TT - n_last_y + 1):TT
  ysim <- lapply(idx, function(ii) mvnfast::rmvn(1, mu = mean[[xsim[ii]]], sigma = covariance[[xsim[[ii]]]]))
  ysim <- do.call(rbind, ysim)

  ret <- list(x = xsim, y = ysim)
  return(ret)
}


# Log-probability distribution of the hidden state at time min(window),
# obtained by propagating from time 1 without using any observations.
# A labeled time point (non-NA entry of `xlabeled`) pins the distribution to
# its state. An unlabeled one propagates the previous distribution through
# the transition matrix hmm$logAhat. At time 1 an unlabeled point uses
# hmm$log_pi_hat. The number of states is taken from ncol(hmm$logB).
get_logp_start <- function(hmm, xlabeled, window) {
  t_first <- min(window)
  n_states <- ncol(hmm$logB)
  trans_log <- hmm$logAhat

  # Degenerate distribution with all mass on `state`.
  point_mass <- function(state) {
    out <- rep(-Inf, n_states)
    out[state] <- 0
    out
  }

  current <- if (is.na(xlabeled[1])) {
    as.numeric(hmm$log_pi_hat)
  } else {
    point_mass(xlabeled[1])
  }

  for (step in seq_len(t_first - 1) + 1L) {
    current <- if (is.na(xlabeled[step])) {
      as.numeric(log_colSums_cpp(trans_log + current))
    } else {
      point_mass(xlabeled[step])
    }
  }

  as.numeric(current)
}


# Entropy (in nats) of the hidden states over `window`, given the
# observations `y`, the known labels `xlabeled` and the fitted model `hmm`.
# `hmm` is a list with logAhat, logB, log_pi_hat, mean_hat and covariance_hat.
#
# State probabilities at the first window position come from filtering when
# the window ends at nrow(y), and from forward-backward smoothing otherwise.
# The entropy is then accumulated with the forward HMM-entropy recursion
# (cf. Hernando et al., 2005):
#   H_t(j) = sum_i p(x_{t-1} = i | x_t = j) * (H_{t-1}(i) - log p(i | j))
#   H      = sum_j p(x_T = j) * (H_T(j) - log p(x_T = j)).
# Labeled positions contribute no uncertainty.
entropy_of_sequence_hmm <- function(hmm, xlabeled, y, window) {

  # sum_i exp(log_p[i]) * values[i] over states with non-zero probability.
  # Restricting to finite log-probabilities applies 0 * log(0) = 0. It also
  # keeps values attached to impossible states (possibly NaN) out of the sum,
  # so a distribution with only some zero-probability states yields a
  # finite entropy instead of NaN.
  weighted_sum <- function(log_p, values) {
    keep <- is.finite(log_p)
    sum(exp(log_p[keep]) * values[keep])
  }

  log_ppi <- get_logp_start(hmm, xlabeled, 1)
  logAhat <- hmm$logAhat
  nstates <- ncol(hmm$logB)
  logB <- calculate_logB(
    y = y,
    mean = hmm$mean_hat,
    covariance = hmm$covariance_hat)

  log_alpha <- log_forward_algorithm_cpp(log_ppi, logAhat, logB, xlabeled)[window, , drop = FALSE]
  if (all(window == nrow(y))) {
    # Window at the end of the sequence: filtered probabilities suffice.
    log_ab <- log_alpha
  } else {
    log_beta <- log_backward_algorithm_cpp(logAhat, logB, xlabeled)[window, , drop = FALSE]
    log_ab <- log_alpha + log_beta
  }
  # Normalize row-wise: one state distribution per window position.
  log_gamma <- log_ab - log_rowSums(log_ab)

  ntest <- length(window)

  # log_cc[t, ]: filtered log state distribution along the window, starting
  # from the (re-normalized) distribution at the first window position.
  log_cc <- matrix(0, nrow = ntest, ncol = nstates)
  log_cc1 <- log_gamma[1, ]
  log_cc[1, ] <- log_cc1 - log_sum_vec(log_cc1)
  if (ntest > 1) {
    for (tt in 2:ntest) {
      if (!is.na(xlabeled[window[tt]])) {
        log_cc[tt, xlabeled[window[tt]]] <- 0
        log_cc[tt, -xlabeled[window[tt]]] <- -Inf
      } else {
        log_cc_tt <- log_colSums_cpp(log_cc[tt-1, ] + t(t(logAhat) + logB[window[tt], ]))
        log_cc[tt, ] <- log_cc_tt - log_sum_cpp(log_cc_tt)
      }
    }
  }

  # log_p_back[t - 1, i, j] = log p(x_{t-1} = i | x_t = j, data up to t - 1),
  # needed for every t >= 2 (including ntest == 2).
  if (ntest > 1) {
    log_p_back <- array(0, dim = c(ntest - 1, nstates, nstates))
    for (tt in 2:ntest) {
      log_p_back_ttm1 <- logAhat + log_cc[tt - 1, ]
      log_p_back[tt - 1, , ] <- t(t(log_p_back_ttm1) - as.numeric(log_colSums(log_p_back_ttm1)))
    }
  }

  H <- matrix(NA, nrow = ntest, ncol = nstates)
  H[1, ] <- 0
  if (ntest > 1) {
    for (tt in 2:ntest) {
      possible_states <- seq_len(nstates)
      if (!is.na(xlabeled[window[tt]])) {
        H[tt, -xlabeled[window[tt]]] <- 0
        possible_states <- xlabeled[window[tt]]
      }
      for (jj in possible_states) {
        lp <- log_p_back[tt - 1, , jj]
        H[tt, jj] <- weighted_sum(lp, H[tt - 1, ] - lp)
      }
    }
  }

  log_cc_last <- log_cc[ntest, ]
  entropy_of_sequence <- weighted_sum(log_cc_last, H[ntest, ] - log_cc_last)
  return(entropy_of_sequence)
}
