#' SIS combo
#'
#' Screens exposure-mediator pairs for a survival outcome in three steps
#' (SIS + SIS combo) and tests the retained pairs:
#' \enumerate{
#'   \item Marginal screening. Exposures are ranked by the absolute
#'     coefficient from a log-normal \code{survreg} fit (exposure +
#'     covariates), and, separately for every mediator, by the absolute
#'     coefficient from a linear model (exposure + covariates). The top
#'     \code{round(n / log(n))} exposures are kept in each ranking.
#'   \item Pair screening. Every (mediator, exposure) pair kept in step 1 is
#'     scored by \code{abs(alpha * beta)}, where \code{alpha} is the
#'     exposure coefficient in the mediator model and \code{beta} is the
#'     mediator coefficient in the survival model. The top
#'     \code{round(n / log(n))} pairs are kept.
#'   \item Joint testing. Normal-approximation p-values are computed for the
#'     retained mediators (survival model) and exposures (mediator models);
#'     the raw p-value of a pair is the larger of the two.
#' }
#'
#' @param X An n by p numeric matrix of exposures with unique column names.
#' @param M An n by q numeric matrix of mediators with unique column names.
#' @param C An n by r matrix (or data frame) of covariates.
#' @param time A vector of survival times, one per sample.
#' @param status A vector of status indicators: 0=alive, 1=dead.
#' @param p_adjust_option The multiple-testing correction method. Options
#' are q-value, holm, hochberg, hommel, bonferroni, BH, BY, and fdr.
#' Default is BH.
#' @param thres Threshold applied to the adjusted p-values to declare a pair
#' significant. Default is 0.05.
#' @return A list with three matrices whose rows are the exposures:
#' \code{p_final_matrix} (raw p-values; one column per mediator retained
#' after step 2, \code{NA} for pairs that were not tested),
#' \code{p_adjusted_matrix} (adjusted p-values, same layout) and
#' \code{p_med_matrix} (0/1 indicator of significant pairs with one column
#' for every mediator in \code{M}).
#' @import survival
#' @import fdrtool
#' @import data.table
#' @import stats
#' @export
#' @examples
#' data(example_dat)
#' surv_dat <- example_dat$surv_dat
#' SIS_combo(example_dat$X, example_dat$M, example_dat$C, time = surv_dat$time,
#' status = surv_dat$status)
#'
SIS_combo <- function(X, M, C, time, status, p_adjust_option = "BH", thres = 0.05){
  ## ---- input checks -------------------------------------------------------
  X <- as.matrix(X)
  M <- as.matrix(M)
  col_X <- colnames(X)
  col_M <- colnames(M)
  n <- nrow(X)
  d_x <- ncol(X)
  d_m <- ncol(M)
  if (is.null(col_X) || is.null(col_M) ||
      anyDuplicated(col_X) > 0 || anyDuplicated(col_M) > 0) {
    stop("`X` and `M` must have unique column names.", call. = FALSE)
  }
  if (!is.numeric(X) || !is.numeric(M)) {
    stop("`X` and `M` must be numeric.", call. = FALSE)
  }
  if (n < 2 || nrow(M) != n || length(time) != n || length(status) != n) {
    stop("`X`, `M`, `time` and `status` must describe the same samples.",
         call. = FALSE)
  }
  # Row names are irrelevant here and would only leak into the model frames.
  dimnames(X) <- list(NULL, col_X)
  dimnames(M) <- list(NULL, col_M)

  ## ---- local helpers ------------------------------------------------------
  # Turn a block of columns into a data frame with syntactic, role-prefixed
  # names (e.g. X_1, X_2). Model coefficients are then looked up by these
  # names, so the result does not depend on how the user named the columns
  # (non-syntactic names, covariates starting with "M" or "X", ...).
  safe_frame <- function(block, prefix) {
    out <- as.data.frame(block)
    names(out) <- paste0(prefix, seq_len(ncol(out)))
    rownames(out) <- NULL
    out
  }
  # Names of the `k` largest scores (NA scores are ranked last).
  keep_top <- function(score, k) {
    ranked <- names(score)[order(score, decreasing = TRUE)]
    ranked[seq_len(min(k, length(ranked)))]
  }
  # Two-sided normal p-values (estimate / std. error) for the given terms of a
  # coefficient table; terms missing from the table (e.g. aliased) give NA.
  # pnorm(-|z|) keeps precision in the tail, unlike 1 - pnorm(|z|).
  wald_p <- function(coef_table, terms) {
    idx <- match(terms, rownames(coef_table))
    unname(2 * pnorm(-abs(coef_table[idx, 1] / coef_table[idx, 2])))
  }

  if (is.null(C)) {
    C <- matrix(numeric(0), nrow = n, ncol = 0)
  }
  C <- safe_frame(C, "C_")
  if (nrow(C) != n) {
    stop("`C` must have one row per sample.", call. = FALSE)
  }
  surv_data <- data.frame(surv_time = time, status = status)
  d_keep <- round(n / log(n))

  ## ---- step 1: marginal screening ----------------------------------------
  ### exposures ranked by |coefficient| for the survival outcome
  beta_x <- vapply(col_X, function(x) {
    dat <- data.frame(expo = X[, x], C, surv_data)
    fit <- survreg(Surv(surv_time, status) ~ ., data = dat, dist = "lognormal")
    abs(coef(fit)[["expo"]])
  }, numeric(1))
  X_sel_Y_s1 <- keep_top(beta_x, d_keep)

  ### exposures ranked by |coefficient| for every mediator
  M_X_s1 <- lapply(col_M, function(m) {
    alpha_x <- vapply(col_X, function(x) {
      dat <- data.frame(med = M[, m], expo = X[, x], C)
      abs(coef(lm(med ~ ., data = dat))[["expo"]])
    }, numeric(1))
    keep_top(alpha_x, d_keep)
  })
  names(M_X_s1) <- col_M

  ## ---- step 2: rank (mediator, exposure) pairs by |alpha * beta| ----------
  X_for_Y <- safe_frame(X[, X_sel_Y_s1, drop = FALSE], "X_")
  beta_m_s2 <- vapply(col_M, function(m) {
    dat <- data.frame(med = M[, m], X_for_Y, surv_data)
    fit <- survreg(Surv(surv_time, status) ~ ., data = dat, dist = "lognormal")
    coef(fit)[["med"]]
  }, numeric(1))

  pair_list <- lapply(col_M, function(m) {
    X_sel <- M_X_s1[[m]]
    dat <- data.frame(med = M[, m], safe_frame(X[, X_sel, drop = FALSE], "X_"))
    # Every non-intercept coefficient belongs to an exposure, in the order of
    # `X_sel` (aliased terms are kept as NA by coef()).
    alpha <- unname(coef(lm(med ~ ., data = dat))[-1])
    data.frame(M = m, X = X_sel, alpha = alpha, beta = beta_m_s2[[m]],
               stringsAsFactors = FALSE)
  })
  M_X_s2 <- do.call(rbind, pair_list)
  M_X_s2$effect <- abs(M_X_s2$alpha * M_X_s2$beta)
  ranking <- order(M_X_s2$effect, decreasing = TRUE)
  M_X_sel_s2 <- M_X_s2[ranking[seq_len(min(d_keep, length(ranking)))], ,
                       drop = FALSE]

  ## ---- step 3: joint models and p-values ----------------------------------
  M_sel <- unique(M_X_sel_s2$M)
  M_for_Y <- safe_frame(M[, M_sel, drop = FALSE], "M_")
  surv_model_s3 <- survreg(Surv(surv_time, status) ~ .,
                           data = data.frame(M_for_Y, X_for_Y, C, surv_data),
                           dist = "lognormal")
  p_beta_m <- wald_p(summary(surv_model_s3)$table, names(M_for_Y))

  p_alpha_x <- matrix(NA_real_, nrow = d_x, ncol = length(M_sel),
                      dimnames = list(col_X, M_sel))
  for (j in seq_along(M_sel)) {
    X_sel <- M_X_sel_s2$X[M_X_sel_s2$M == M_sel[j]]
    X_for_M <- safe_frame(X[, X_sel, drop = FALSE], "X_")
    dat <- data.frame(med = M[, M_sel[j]], X_for_M, C)
    lm_table <- summary(lm(med ~ ., data = dat))$coefficients
    p_alpha_x[X_sel, j] <- wald_p(lm_table, names(X_for_M))
  }

  # raw p-value of a pair: the larger of its two path p-values (NA if either
  # path was not tested)
  p_final_matrix <- p_alpha_x
  for (j in seq_along(M_sel)) {
    p_final_matrix[, j] <- pmax(p_alpha_x[, j], p_beta_m[j])
  }

  ## ---- multiple-testing adjustment ---------------------------------------
  tested <- !is.na(p_final_matrix)
  p_raw <- p_final_matrix[tested]
  if (p_adjust_option == "q-value") {
    p_adjusted_v <- fdrtool(p_raw, statistic = "pvalue", plot = FALSE,
                            verbose = FALSE, cutoff.method = "locfdr")$qval
  } else {
    # `n` counts every cell of the matrix, tested or not, as in the original
    # implementation.
    p_adjusted_v <- p.adjust(p_raw, method = p_adjust_option,
                             n = length(p_final_matrix))
  }
  p_adjusted_matrix <- p_final_matrix
  p_adjusted_matrix[] <- NA_real_
  p_adjusted_matrix[tested] <- p_adjusted_v

  ## ---- significant pairs over the full exposure x mediator grid ----------
  p_med_matrix <- matrix(0, nrow = d_x, ncol = d_m,
                         dimnames = list(col_X, col_M))
  significant <- !is.na(p_adjusted_matrix) & p_adjusted_matrix < thres
  p_med_matrix[, M_sel] <- as.numeric(significant)

  list(p_final_matrix = p_final_matrix,
       p_adjusted_matrix = p_adjusted_matrix,
       p_med_matrix = p_med_matrix)
}
