## functions to fit a gam model, calculate in-sample residuals and out of sample predictions
## reusing some code from the IVDML package
## fit_gam: regress Y on the columns of X with mgcv::gam.
##   Y: response; becomes column "Y" of the model data frame.
##   X: predictor(s) (vector, matrix or data frame); its columns are renamed
##      X1, ..., Xp in the model data frame.
## How a predictor enters the formula depends on its number of distinct
## values (max_k_gam = min(30, number of distinct values - 1)):
##   max_k_gam == 0      -> predictor is constant and is left out,
##   max_k_gam in {1, 2} -> linear term,
##   max_k_gam > 2       -> smooth term s(Xj, k = max_k_gam).
## If no predictor qualifies, an intercept-only model "Y ~ 1" is fitted.
## Returns the object produced by mgcv::gam.
fit_gam <- function(Y, X){
  p <- NCOL(X)
  dat <- data.frame(Y, X)
  colnames(dat) <- c("Y", paste0("X", seq_len(p)))
  unique_lengths <- apply(as.matrix(X), 2, function(x){length(unique(x))})
  max_k_gam <- pmin(30, unique_lengths - 1)
  # one formula term per predictor; NA marks predictors that are left out
  rhs_terms <- vapply(seq_len(p), function(j){
    k <- max_k_gam[[j]]
    if(k > 2){
      paste0("s(X", j, ", k = ", k, ")")
    } else if(k %in% c(1, 2)){
      paste0("X", j)
    } else {
      NA_character_
    }
  }, character(1))
  rhs_terms <- rhs_terms[!is.na(rhs_terms)]
  if(length(rhs_terms) == 0){
    form <- "Y ~ 1"
  } else {
    form <- paste("Y ~", paste(rhs_terms, collapse = " + "))
  }
  mod <- mgcv::gam(formula = formula(form), data = dat)
  return(mod)
}

## resid_gam: fit Y on X with fit_gam and hand back the `residuals`
## component of the fitted object (the in-sample residuals).
resid_gam <- function(Y, X){
  fitted_model <- fit_gam(Y, X)
  fitted_model$residuals
}

## predict_gam: evaluate a fitted model at new predictor values.
##   mod:  fitted model object; handed to predict().
##   Xnew: new predictor values (vector, matrix or data frame). It is turned
##         into a data frame whose columns are named X1, ..., Xp, the same
##         names fit_gam gives its predictors.
## Returns whatever predict() returns for `mod` on the renamed data frame.
predict_gam <- function(mod, Xnew){
  n_pred <- NCOL(Xnew)
  new_data <- data.frame(Xnew)
  colnames(new_data) <- paste0("X", seq_len(n_pred))
  predict(mod, newdata = new_data)
}

## functions to fit an xgboost model, calculate in-sample residuals and out of sample predictions
## reusing some code from the IVDML package
## fit_xgboost: tune and fit an xgboost regression of Y on X.
##   Y: response, used as the label of the training DMatrix.
##   X: predictors, converted with as.matrix().
##   regr.pars: list of tuning settings. Entries that are absent (or NULL)
##     are filled with these defaults:
##       max_nrounds           = 500  (upper bound on boosting rounds in CV)
##       eta                   = c(0.1, 0.2, 0.3, 0.5)
##       max_depth             = 1:7
##       early_stopping_rounds = 10
##       k_cv                  = 10   (number of CV folds)
## Every (eta, max_depth) combination is cross-validated with xgb.cv. The
## combination and the number of rounds with the smallest mean test RMSE are
## used to refit on the full data with xgb.train.
## Warns if the selected number of rounds equals max_nrounds.
## Returns the model produced by xgboost::xgb.train.
fit_xgboost <- function(Y, X, regr.pars = list()){
  if(is.null(regr.pars)){
    regr.pars <- list()
  }
  defaults <- list(max_nrounds = 500,
                   eta = c(0.1, 0.2, 0.3, 0.5),
                   max_depth = c(1, 2, 3, 4, 5, 6, 7),
                   early_stopping_rounds = 10,
                   k_cv = 10)
  # `[[` matches names exactly and yields NULL for absent entries, so this
  # also works for empty or partially named lists
  for(par_name in names(defaults)){
    if(is.null(regr.pars[[par_name]])){
      regr.pars[[par_name]] <- defaults[[par_name]]
    }
  }
  par_grid <- expand.grid(eta = regr.pars[["eta"]],
                          max_depth = regr.pars[["max_depth"]])
  dtrain <- xgboost::xgb.DMatrix(data = as.matrix(X), label = Y, nthread = 1)
  # returns c(smallest mean test RMSE, round at which it is first attained)
  get_min_rmse_nrounds <- function(eta, max_depth){
    fit_cv <- xgboost::xgb.cv(params = list(nthread = 1, eta = eta, max_depth = max_depth),
                              data = dtrain, nrounds = regr.pars[["max_nrounds"]],
                              nfold = regr.pars[["k_cv"]],
                              early_stopping_rounds = regr.pars[["early_stopping_rounds"]],
                              verbose = FALSE)
    cv_rmse <- fit_cv$evaluation_log$test_rmse_mean
    indopt <- which.min(cv_rmse)
    return(c(cv_rmse[indopt], indopt))
  }
  # 2 x nrow(par_grid) matrix: row 1 = CV error, row 2 = number of rounds
  par_rmses <- vapply(seq_len(nrow(par_grid)), function(i){
    get_min_rmse_nrounds(par_grid$eta[i], par_grid$max_depth[i])
  }, numeric(2))
  ind_min <- which.min(par_rmses[1, ])
  min_nrounds <- par_rmses[2, ind_min]
  if(min_nrounds == regr.pars[["max_nrounds"]]){
    warning("CV chooses nrounds equal to max_nrounds. Consider increasing max_nrounds.")
  }
  best_pars <- list(eta = par_grid$eta[ind_min],
                    max_depth = par_grid$max_depth[ind_min],
                    nrounds = min_nrounds)
  mod <- xgboost::xgb.train(params = list(nthread = 1, eta = best_pars$eta, max_depth = best_pars$max_depth),
                            data = dtrain, nrounds = best_pars$nrounds, verbose = FALSE)
  return(mod)
}

## resid_xgboost: in-sample residuals of an xgboost fit of Y on X.
## The model comes from fit_xgboost(Y, X, regr.pars); its predictions on the
## training predictors are subtracted from Y.
resid_xgboost <- function(Y, X, regr.pars){
  booster <- fit_xgboost(Y, X, regr.pars)
  design <- xgboost::xgb.DMatrix(data = as.matrix(X), nthread = 1)
  fitted_values <- predict(booster, design)
  Y - fitted_values
}

## predict_xgboost: predictions of a fitted model `mod` at the new predictor
## values Xnew, which are wrapped in an xgb.DMatrix before calling predict().
predict_xgboost <- function(mod, Xnew){
  new_design <- xgboost::xgb.DMatrix(data = as.matrix(Xnew), nthread = 1)
  predict(mod, new_design)
}


## get_residuals: residuals of a regression of Y on X.
##   regr.meth: "gam" dispatches to resid_gam, "xgboost" to resid_xgboost;
##              any other value raises an error.
##   regr.pars: list of settings, forwarded to resid_xgboost only.
get_residuals <- function(Y, X, regr.meth, regr.pars = list()){
  if(regr.meth == "gam"){
    return(resid_gam(Y, X))
  }
  if(regr.meth == "xgboost"){
    return(resid_xgboost(Y, X, regr.pars))
  }
  stop("Invalid regression method. Use regr.meth = 'gam' or regr.meth = 'xgboost'.")
}





## function for multivariate wgcm.fix
## wgcm.fix.mult: p-value of the test with fixed weight functions for
## multivariate X and/or Y.
##   X, Y: data, coerced to matrices (rows = observations).
##   Z: conditioning variables or NULL. With NULL a warning is issued, X and
##      Y are only column-centred and a constant weight of 1 is used.
##      Otherwise every column of X and Y is replaced by its residuals from
##      get_residuals(., Z) and the weights come from weight_matrix().
##   regr.meth, regr.pars: passed on to get_residuals.
##   weight.num, weight.meth: passed on to weight_matrix.
##   nsim: number of simulated Gaussian draws for the null distribution.
## For every pair (column j of X, column l of Y) the residual product is
## multiplied by the weights; each resulting series is standardised. The
## statistic is the largest absolute scaled mean; it is compared with nsim
## simulated maxima.
## Returns (number of simulated maxima >= statistic + 1) / (nsim + 1).
wgcm.fix.mult <- function(X, Y, Z, regr.meth, regr.pars, weight.num,
                          weight.meth, nsim){
  X <- as.matrix(X)
  Y <- as.matrix(Y)
  n <- NROW(X)
  dx <- NCOL(X)
  dy <- NCOL(Y)
  if(is.null(Z)){
    warning("No Z specified. No weight functions can be calculated. Function simply tests for vanishing correlation between components of X and Y.")
    eps.mat <- sweep(X, 2, colMeans(X))
    xi.mat <- sweep(Y, 2, colMeans(Y))
    W <- rep(1, n)
  } else {
    Z <- as.matrix(Z)
    residualise <- function(V){
      as.numeric(get_residuals(V, Z, regr.pars = regr.pars, regr.meth = regr.meth))
    }
    eps.mat <- apply(X, 2, residualise)
    xi.mat <- apply(Y, 2, residualise)
    W <- weight_matrix(Z, weight.num, weight.meth)
  }
  # one block of weighted residual products per (j, l) pair, l running fastest
  blocks <- vector("list", dx * dy)
  pos <- 0
  for(j in 1:dx){
    for(l in 1:dy){
      pos <- pos + 1
      blocks[[pos]] <- eps.mat[, j] * xi.mat[, l] * W
    }
  }
  # rows = test series, columns = observations
  R <- t(do.call(cbind, blocks))
  row.means <- rowMeans(R)
  R.norm <- R / sqrt(rowMeans(R^2) - row.means^2)
  T.stat <- sqrt(n) * max(abs(rowMeans(R.norm)))
  noise <- matrix(rnorm(n * nsim), n, nsim)
  T.stat.sim <- apply(abs(R.norm %*% noise), 2, max) / sqrt(n)
  p.value <- (sum(T.stat.sim >= T.stat) + 1) / (nsim + 1)
  return(p.value)
}

## function for multivariate wgcm.est
## wgcm.est.mult: p-value of the test with estimated weight functions for
## multivariate X and/or Y.
##   X, Y: data, coerced to matrices (rows = observations).
##   Z: conditioning variables or NULL.
##   beta: fraction of the observations used to estimate the weight
##         functions; ceiling(beta * n) rows are sampled as training set and
##         the remaining rows form the test set.
##   regr.meth, regr.pars: passed on to get_residuals and predict_weight.
##   nsim: number of simulated Gaussian draws for the null distribution.
## Z = NULL: a warning is issued, X and Y are column-centred and a two-sided
## test on the unweighted products is carried out on all observations.
## Otherwise: for every pair (column j of X, column l of Y) a weight function
## is estimated on the training set with predict_weight and evaluated at
## Ztest; the residuals on the test set are multiplied by it and a one-sided
## max-type test is carried out.
## Stops with an error if beta leaves the training or the test set empty.
## Returns (number of simulated maxima >= statistic + 1) / (nsim + 1).
wgcm.est.mult <- function(X, Y, Z, beta, regr.meth, regr.pars, nsim){
  X <- as.matrix(X)
  Y <- as.matrix(Y)
  n <- NROW(X)
  dx <- NCOL(X)
  dy <- NCOL(Y)
  if(is.null(Z)){
    warning("No Z specified. No weight functions can be estimated. Function simply tests for vanishing correlation between components of X and Y.")
    eps.mat <- t(t(X)-colMeans(X))
    xi.mat <- t(t(Y)-colMeans(Y))
    R.cols <- vector("list", dx * dy)
    for (j in seq_len(dx)){
      for (l in seq_len(dy)){
        R.cols[[(j - 1) * dy + l]] <- eps.mat[,j]*xi.mat[,l]
      }
    }
    R <- t(do.call(cbind, R.cols))
    R.norm <- R / sqrt(rowMeans(R^2) - rowMeans(R)^2)
    T.stat <- sqrt(n) * max(abs(rowMeans(R.norm)))
    T.stat.sim <- apply(abs(R.norm %*% matrix(rnorm(n * nsim), n, nsim)),
                        2, max) / sqrt(n)
    p.value <- (sum(T.stat.sim >= T.stat) + 1) / (nsim + 1)
  } else {
    Z <- as.matrix(Z)
    ntrain <- ceiling(beta*n)
    # an empty index vector would make X[-ind.train, ] select zero rows
    if(is.na(ntrain) || ntrain < 1 || ntrain >= n){
      stop("beta must be chosen such that both the training and the test set are non-empty.")
    }
    ind.train <- sample(1:n, ntrain)
    # drop = FALSE keeps one-column X, Y, Z as matrices, so that apply() and
    # column indexing below also work when dx, dy or the number of columns
    # of Z equals 1
    Xtrain <- X[ind.train, , drop = FALSE]
    Xtest <- X[-ind.train, , drop = FALSE]
    Ytrain <- Y[ind.train, , drop = FALSE]
    Ytest <- Y[-ind.train, , drop = FALSE]
    Ztrain <- Z[ind.train, , drop = FALSE]
    Ztest <- Z[-ind.train, , drop = FALSE]
    calc.res.Z <- function(V){
      return(as.numeric(get_residuals(V, Ztest, regr.pars = regr.pars, regr.meth = regr.meth)))
    }
    eps.mat <- apply(Xtest, 2, calc.res.Z)
    xi.mat <- apply(Ytest, 2, calc.res.Z)
    R.cols <- vector("list", dx * dy)
    for (j in seq_len(dx)){
      for (l in seq_len(dy)){
        W <- predict_weight(Xtrain[,j], Ytrain[,l], Ztrain, Ztest, regr.meth, regr.pars)
        R.cols[[(j - 1) * dy + l]] <- eps.mat[,j]*xi.mat[,l]*W
      }
    }
    R <- t(do.call(cbind, R.cols))
    R.norm <- R/sqrt(rowMeans(R^2)-rowMeans(R)^2)
    ntest <- n-length(ind.train)
    #The estimated w-functions aim at making a positive test statistic. Hence
    #a one-sided test.
    T.stat <- sqrt(ntest) * max(rowMeans(R.norm))
    T.stat.sim <- apply(R.norm %*% matrix(rnorm(ntest * nsim), ntest, nsim),
                        2, max) / sqrt(ntest)
    p.value <- (sum(T.stat.sim >= T.stat) + 1)/(nsim + 1)
  }
  return(p.value)
}



## function to calculate weight matrix
## weight_matrix: evaluate the fixed weight functions at the rows of Z.
##   Z: matrix of conditioning variables (columns are indexed as Z[, i]).
##   weight.num: number of shifted sign functions per column of Z.
##   weight.meth: only "sign" is implemented; anything else is an error.
## The first weight is constant 1. For every column of Z, weight.num further
## weights sign(Z[, i] - a) are appended, with a running over the empirical
## quantiles of that column at levels k / (weight.num + 1), k = 1..weight.num.
## Returns a vector of ones if weight.num < 1, otherwise a matrix with one
## row per observation and 1 + NCOL(Z) * weight.num columns.
weight_matrix <- function(Z, weight.num, weight.meth) {
  if (weight.meth != "sign") {
    stop("Only method \"sign\" implemented yet to calculate weight function")
  }
  ones <- rep(1, NROW(Z))
  if (!(weight.num >= 1)) {
    return(ones)
  }
  probs <- (1:weight.num) / (weight.num + 1)
  sign.blocks <- lapply(1:NCOL(Z), function(i) {
    z.col <- Z[, i]
    cutoffs <- quantile(z.col, probs, names = FALSE)
    outer(z.col, cutoffs, function(x, a) sign(x - a))
  })
  # the constant weight goes first and keeps the column name "W"
  do.call(cbind, c(list(W = ones), sign.blocks))
}




## translated sign weight function
## signa(x, a) = sign(x - a): -1 below a, 0 at a, 1 above a (vectorised).
signa <- function(x, a) {
  shifted <- x - a
  sign(shifted)
}


## function to calculate a 1sided p-value for wgcm.fix, since we expect
## the test statistic to be positive under the alternative
##   Xtest, Ytest: univariate data; each is replaced by its residuals from
##                 get_residuals(., Ztest).
##   Ztest: conditioning variables; NROW(Ztest) is used as sample size.
##   W: weights that multiply the product of the two residual vectors.
##   regr.meth, regr.pars: passed on to get_residuals.
## Returns the upper-tail standard normal probability of the studentised
## mean of the weighted residual products.
wgcm.1d.1sided <- function(Xtest, Ytest, Ztest, W, regr.meth, regr.pars) {
  n <- NROW(Ztest)
  eps <- as.numeric(get_residuals(Xtest, Ztest, regr.pars = regr.pars, regr.meth = regr.meth))
  xi <- as.numeric(get_residuals(Ytest, Ztest, regr.pars = regr.pars, regr.meth = regr.meth))
  R <- eps*xi*W
  T.stat <- sqrt(n)*mean(R)/sqrt(mean(R^2)-mean(R)^2)
  # upper tail computed directly: 1 - pnorm(T.stat) rounds to exactly 0 for
  # large statistics, losing all information about very small p-values
  p.value <- pnorm(T.stat, lower.tail = FALSE)
  return(p.value)
}

## function to estimate weight function for wgcm.est
##   Xtrain, Ytrain: univariate training data; each is replaced by its
##                   residuals from get_residuals(., Ztrain).
##   Ztrain: conditioning variables of the training set.
##   Ztest:  conditioning variables at which the weight is evaluated.
##   regr.meth: "gam" or "xgboost"; also selects the method used to regress
##              the product of the two residual vectors on Ztrain.
##   regr.pars: passed on to get_residuals and fit_xgboost.
## Returns the predictions of that regression at Ztest.
predict_weight <- function(Xtrain, Ytrain, Ztrain, Ztest,
                           regr.meth, regr.pars) {
  resid.x <- as.numeric(get_residuals(Xtrain, Ztrain, regr.pars = regr.pars, regr.meth = regr.meth))
  resid.y <- as.numeric(get_residuals(Ytrain, Ztrain, regr.pars = regr.pars, regr.meth = regr.meth))
  resid.prod <- resid.x * resid.y
  W <- switch(regr.meth,
    "gam" = predict_gam(fit_gam(resid.prod, Ztrain), Ztest),
    "xgboost" = predict_xgboost(fit_xgboost(resid.prod, Ztrain, regr.pars), Ztest)
  )
  return(W)
}




