
#------------------------------------------------------------------------------#
####              Calculate cost-effectiveness outcomes                     ####
#------------------------------------------------------------------------------#
#' Calculate cost-effectiveness outcomes
#'
#' \code{calculate_ce_out} runs the Healthy-Sick-Dead cohort state-transition
#' model for standard of care (SoC), treatment A and treatment B, and returns
#' their discounted total costs, effectiveness (QALYs) and net monetary
#' benefits (NMB).
#'
#' @param l_params_all List with all parameters of the decision model. Its
#'   elements are made visible by name via \code{with()}; the body reads
#'   \code{n_cycles}, \code{cycle_length}, \code{r_HS_SoC},
#'   \code{r_HS_trtA}, \code{r_HS_trtB}, \code{r_SD}, \code{v_r_HD},
#'   \code{u_H}, \code{u_S}, \code{u_D}, \code{c_H}, \code{c_S},
#'   \code{c_D}, \code{c_trtA}, \code{c_trtB}, \code{v_names_str},
#'   \code{n_str}, \code{v_dwe}, \code{v_dwc} and \code{v_wcc}.
#' @param n_wtp Willingness-to-pay threshold used to compute net monetary
#'   benefits (NMB).
#' @param verbose Logical; forwarded to
#'   \code{check_transition_probability()} and
#'   \code{check_sum_of_transition_array()}.
#' @return A data.frame with one row per strategy and columns
#'   \code{Strategy}, \code{Cost}, \code{Effect} and \code{NMB}
#'   (all discounted).
#' @export
calculate_ce_out <- function(l_params_all, n_wtp = 10000, verbose = FALSE) { # User defined
  with(as.list(l_params_all), {
    # Convert the time horizon into a number of model cycles. round() guards
    # against floating-point error (a quotient landing just below an
    # integer), which would otherwise be silently truncated by `0:n_cycles`,
    # `matrix(nrow = )`, `array(dim = )` and `rep(each = )` below.
    n_cycles          <- round(n_cycles / cycle_length)
    n_cycles_per_unit <- round(1 / cycle_length) # cycles per unit of time
    v_names_states    <- c("H", "S", "D") # Healthy (H), Sick (S), Dead (D)
    n_states          <- length(v_names_states)

    ## Cycle names (cycle 0 is the initial state)
    v_names_cycles <- paste("cycle", 0:n_cycles)

    ### Convert rates to per-cycle probabilities
    # p = 1 - exp(-r * cycle_length)
    p_HS_SoC  <- rate_to_prob(r = r_HS_SoC,  time = cycle_length) # H -> S, SoC
    p_HS_trtA <- rate_to_prob(r = r_HS_trtA, time = cycle_length) # H -> S, treatment A
    p_HS_trtB <- rate_to_prob(r = r_HS_trtB, time = cycle_length) # H -> S, treatment B
    p_SD      <- rate_to_prob(r = r_SD,      time = cycle_length) # S -> D
    # Death when healthy: v_r_HD holds one rate per unit of time (presumably
    # per year of age -- confirm); each value is repeated for every cycle in
    # that unit so there is one probability per cycle.
    v_p_HD <- rate_to_prob(r = v_r_HD, time = cycle_length)
    v_p_HD <- rep(v_p_HD, each = n_cycles_per_unit)

    # The whole cohort starts in Healthy
    v_m_init <- c("H" = 1, "S" = 0, "D" = 0)

    ###################### Construct state-transition models ###################
    ## Cohort traces: one row per cycle (including cycle 0), one column per state
    m_M_SoC <- matrix(0,
                      nrow = (n_cycles + 1), ncol = n_states,
                      dimnames = list(v_names_cycles, v_names_states))
    m_M_SoC[1, ] <- v_m_init
    # Treatments A and B share the structure and initial state of SoC
    m_M_trtA <- m_M_trtB <- m_M_SoC

    ## Transition probability arrays indexed [from state, to state, cycle].
    # All transitions to a non-death state are assumed conditional on survival.
    a_P_SoC <- array(0,
                     dim = c(n_states, n_states, n_cycles),
                     dimnames = list(v_names_states, v_names_states,
                                     v_names_cycles[-length(v_names_cycles)]))

    ## Standard of Care
    # from Healthy
    a_P_SoC["H", "H", ] <- 1 - p_HS_SoC - v_p_HD
    a_P_SoC["H", "S", ] <- p_HS_SoC
    a_P_SoC["H", "D", ] <- v_p_HD
    # from Sick
    a_P_SoC["S", "S", ] <- 1 - p_SD
    a_P_SoC["S", "D", ] <- p_SD
    # from Dead (absorbing)
    a_P_SoC["D", "D", ] <- 1

    ## Treatments only change the Healthy -> Sick transition
    a_P_trtA <- a_P_SoC
    a_P_trtA["H", "H", ] <- 1 - p_HS_trtA - v_p_HD
    a_P_trtA["H", "S", ] <- p_HS_trtA

    a_P_trtB <- a_P_SoC
    a_P_trtB["H", "H", ] <- 1 - p_HS_trtB - v_p_HD
    a_P_trtB["H", "S", ] <- p_HS_trtB

    ## Validate the arrays
    # Transition probabilities must lie in [0, 1]
    check_transition_probability(a_P_SoC,  verbose = verbose)
    check_transition_probability(a_P_trtA, verbose = verbose)
    check_transition_probability(a_P_trtB, verbose = verbose)
    # Every row must sum to 1
    check_sum_of_transition_array(a_P_SoC,  n_states = n_states, n_cycles = n_cycles, verbose = verbose)
    check_sum_of_transition_array(a_P_trtA, n_states = n_states, n_cycles = n_cycles, verbose = verbose)
    check_sum_of_transition_array(a_P_trtB, n_states = n_states, n_cycles = n_cycles, verbose = verbose)

    ## Iterative solution of the age-dependent cSTM
    for (t in seq_len(n_cycles)) {
      m_M_SoC[t + 1, ]  <- m_M_SoC[t, ]  %*% a_P_SoC[, , t]
      m_M_trtA[t + 1, ] <- m_M_trtA[t, ] %*% a_P_trtA[, , t]
      m_M_trtB[t + 1, ] <- m_M_trtB[t, ] %*% a_P_trtB[, , t]
    }

    l_m_M <- list(SoC = m_M_SoC,
                  A   = m_M_trtA,
                  B   = m_M_trtB)
    names(l_m_M) <- v_names_str

    ### State rewards, scaled by the cycle length
    v_u_SoC  <- c(H = u_H,          S = u_S, D = u_D) * cycle_length
    v_c_SoC  <- c(H = c_H,          S = c_S, D = c_D) * cycle_length
    v_u_trtA <- c(H = u_H,          S = u_S, D = u_D) * cycle_length
    v_c_trtA <- c(H = c_H + c_trtA, S = c_S, D = c_D) * cycle_length
    v_u_trtB <- c(H = u_H,          S = u_S, D = u_D) * cycle_length
    v_c_trtB <- c(H = c_H + c_trtB, S = c_S, D = c_D) * cycle_length

    l_u <- list(SoC = v_u_SoC, A = v_u_trtA, B = v_u_trtB)
    l_c <- list(SoC = v_c_SoC, A = v_c_trtA, B = v_c_trtB)
    # Align strategy names with the cohort traces
    names(l_u) <- names(l_c) <- v_names_str

    v_tot_qaly <- v_tot_cost <- vector(mode = "numeric", length = n_str)
    names(v_tot_qaly) <- names(v_tot_cost) <- v_names_str

    for (i in seq_len(n_str)) {
      # Expected QALYs and costs per cycle (sum over states)
      v_qaly_str <- l_m_M[[i]] %*% l_u[[i]]
      v_cost_str <- l_m_M[[i]] %*% l_c[[i]]
      # Discounted totals, with within-cycle correction weights applied
      v_tot_qaly[i] <- t(v_qaly_str) %*% (v_dwe * v_wcc)
      v_tot_cost[i] <- t(v_cost_str) %*% (v_dwc * v_wcc)
    }

    ## Discounted net monetary benefits (NMB)
    v_nmb <- v_tot_qaly * n_wtp - v_tot_cost

    # Value of the with() block, returned by the function
    data.frame(Strategy = v_names_str,
               Cost     = v_tot_cost,
               Effect   = v_tot_qaly,
               NMB      = v_nmb)
  })
}


#------------------------------------------------------------------------------#
####             Generate a PSA input parameter dataset                     ####
#------------------------------------------------------------------------------#
#' Generate parameter sets for the probabilistic sensitivity analysis (PSA)
#'
#' \code{generate_psa_params} draws \code{n_sim} random parameter sets for the
#' cost-effectiveness analysis. Transition rates are drawn from log-normal
#' distributions, state costs from gamma distributions and state utilities
#' from beta distributions. All remaining parameters are held fixed.
#'
#' @param n_sim Number of parameter sets (rows) in the PSA dataset.
#' @param seed Seed passed to \code{set.seed()} so the dataset can be
#'   reproduced. This resets the global random-number stream. R has no octal
#'   literals, so the default \code{071818} is the number 71818.
#' @return A data.frame with \code{n_sim} rows and one column per parameter of
#'   the cost-effectiveness analysis.
#' @export
generate_psa_params <- function(n_sim = 1000, seed = 071818) {
  set.seed(seed) # makes the draws below reproducible

  # The draws must stay in this order: for a given seed, reordering them would
  # change which random numbers each parameter receives.

  # Transition rates (log-normal)
  rate_s_to_d      <- rlnorm(n_sim, meanlog = log(0.1),  sdlog = 0.1)  # dying when sick
  rate_h_to_s_soc  <- rlnorm(n_sim, meanlog = log(0.05), sdlog = 0.05) # sick when healthy, SoC
  rate_h_to_s_trtA <- rlnorm(n_sim, meanlog = log(0.04), sdlog = 0.04) # sick when healthy, trt A
  rate_h_to_s_trtB <- rlnorm(n_sim, meanlog = log(0.02), sdlog = 0.02) # sick when healthy, trt B

  # State costs (gamma); calculate_ce_out() scales these by cycle_length
  cost_healthy <- rgamma(n_sim, shape = 16,  scale = 25)
  cost_sick    <- rgamma(n_sim, shape = 100, scale = 10)

  # State utilities (beta)
  util_healthy <- rbeta(n_sim, shape1 = 1.5,  shape2 = 0.0015)
  util_sick    <- rbeta(n_sim, shape1 = 49.5, shape2 = 49.5)

  data.frame(
    r_SD      = rate_s_to_d,
    r_HS_SoC  = rate_h_to_s_soc,
    r_HS_trtA = rate_h_to_s_trtA,
    r_HS_trtB = rate_h_to_s_trtB,
    c_H       = cost_healthy,
    c_S       = cost_sick,
    c_D       = 0,    # cost in the dead state
    c_trtA    = 800,  # extra cost of treatment A, added to the healthy state
    c_trtB    = 1500, # extra cost of treatment B, added to the healthy state
    u_H       = util_healthy,
    u_S       = util_sick,
    u_D       = 0     # utility when dead
  )
}


#' Tornado plot of a one-way sensitivity analysis (OWSA)
#'
#' Builds a tornado diagram from an OWSA result, centred on the base-case
#' value of the outcome. The base case is obtained by running \code{FUN} on
#' \code{params_basecase}. The plot shows either one strategy's outcome or,
#' when two strategies are selected, the difference in outcome
#' (\code{select_str[2]} minus \code{select_str[1]}). Relies on dampack
#' internals (\code{is_owsa}, \code{add_common_aes}, \code{offset_trans}).
#'
#' @param owsa An owsa object (as created by dampack's \code{owsa()}) with
#'   columns \code{parameter}, \code{strategy}, \code{param_val} and
#'   \code{outcome_val}.
#' @param return \code{"plot"} for a ggplot object, \code{"data"} for the
#'   per-parameter low/high outcome table.
#' @param txtsize,col,n_y_ticks,ylim,ybreaks Passed to dampack's
#'   \code{add_common_aes()}.
#' @param min_rel_diff Number in [0, 1]. Parameters whose absolute relative
#'   difference (\code{|rel_diff|}) is below this value are dropped. With the
#'   default of 0, no parameter is dropped.
#' @param select_str One or two strategy names.
#' @param outcome_name Column of \code{FUN}'s output holding the outcome.
#'   \code{n_wtp} is required when it is \code{"IMB"}, \code{"NMB"} or
#'   \code{"ICER"}.
#' @param params_basecase Base-case parameters passed to \code{FUN}.
#' @param FUN Model function returning a data.frame with a \code{Strategy}
#'   column; called with \code{n_wtp} when the outcome needs it.
#' @param n_wtp Willingness-to-pay threshold.
#' @return A ggplot object, or a data.frame when \code{return = "data"}.
#' @noRd
owsa_tornado_new <- function(owsa, return = c("plot", "data"), txtsize = 12,
                             min_rel_diff = 0, col = c("full", "bw"),
                             n_y_ticks = 8, ylim = NULL, ybreaks = NULL,
                             select_str = NULL, outcome_name = NULL,
                             params_basecase = NULL, FUN = calculate_ce_out,
                             n_wtp = NULL) {
  # Bind NSE column names to NULL so R CMD check reports no undefined
  # globals; dplyr/ggplot2 resolve them to data columns first.
  parameter <- param_val <- outcome_val <- strategy <- outcome_val.low <-
    outcome_val.high <- abs_diff <- rel_diff <- NULL

  needs_wtp <- any(outcome_name %in% c("IMB", "NMB", "ICER"))

  ## Validate every input before running the (potentially slow) base case
  if (is.null(params_basecase)) {
    stop("Please provide base case parameters")
  }
  if (is.null(n_wtp) && needs_wtp) {
    stop("Please specify a willingness to pay threshold")
  }
  if (length(select_str) == 0) { # NULL or empty
    stop("Please provide at least 1 strategy name")
  }
  if (length(select_str) > 2) {
    stop("Please select a max of 2 strategies")
  }
  if (min_rel_diff < 0 || min_rel_diff > 1) {
    stop("min_rel_diff must be between 0 and 1")
  }
  if (!dampack:::is_owsa(owsa)) {
    stop("must provide an owsa object created with owsa()")
  }

  # Axis label for the outcome; defaults to "Outcome" when none is given
  outcome_label <- if (is.null(outcome_name)) "Outcome" else outcome_name

  ## Base-case outcomes (the centre line of the tornado)
  res_out <- if (needs_wtp) {
    FUN(params_basecase, n_wtp = n_wtp)
  } else {
    FUN(params_basecase)
  }
  # Replace spaces by dots in strategy names (presumably to match the
  # strategy names stored in `owsa` -- confirm)
  res_out <- res_out %>%
    mutate(Strategy = gsub(" ", ".", Strategy))
  base_value <- function(str) res_out[res_out$Strategy %in% str, outcome_name]

  ## Restrict the OWSA to the selected strategy, or to the difference of two
  if (length(select_str) == 2) {
    owsa_base <- owsa %>% filter(strategy == select_str[1])
    owsa_comp <- owsa %>% filter(strategy == select_str[2])
    owsa_join <- inner_join(owsa_base, owsa_comp,
                            by = c("parameter", "param_val"))
    owsa_join$outcome_val <- owsa_join$outcome_val.y - owsa_join$outcome_val.x
    owsa_join$strategy    <- owsa_join$strategy.x
    owsa    <- owsa_join %>% select(parameter, strategy, param_val, outcome_val)
    y_label <- paste0(outcome_label, " (", select_str[2], " - ",
                      select_str[1], ")")
    avg     <- base_value(select_str[2]) - base_value(select_str[1])
  } else {
    owsa    <- owsa %>% filter(strategy == select_str)
    y_label <- paste0(outcome_label, " (", select_str, ")")
    avg     <- base_value(select_str)
  }

  ## Keep the largest outcome per parameter value. Only one strategy remains
  ## at this point, so this normally just drops the strategy column.
  owsa_filt <- owsa %>%
    group_by(parameter, param_val) %>%
    arrange(outcome_val) %>%
    slice(n()) %>%
    select(-strategy) %>%
    ungroup()

  ## Outcome at the lowest and highest value of each parameter
  mins <- owsa_filt %>%
    group_by(parameter) %>%
    filter(param_val == min(param_val))
  maxes <- owsa_filt %>%
    group_by(parameter) %>%
    filter(param_val == max(param_val))

  min_max <- inner_join(mins, maxes, by = c("parameter"),
                        suffix = c(".low", ".high")) %>%
    mutate(abs_diff = abs(outcome_val.high - outcome_val.low),
           rel_diff = abs_diff / outcome_val.low) %>%
    arrange(-abs_diff)

  # Drop parameters with a small relative effect. This is skipped for the
  # default of 0 so that rows whose rel_diff is negative or non-finite (low-end
  # outcome <= 0) are kept.
  if (min_rel_diff > 0) {
    min_max <- min_max %>% filter(abs(rel_diff) >= min_rel_diff)
  }

  ret <- match.arg(return)
  if (ret == "plot") {
    col <- match.arg(col)
    g <- ggplot(min_max, aes(x = reorder(parameter, abs_diff))) +
      geom_bar(aes(y = outcome_val.low, fill = "Low"), stat = "identity") +
      geom_bar(aes(y = outcome_val.high, fill = "High"), stat = "identity") +
      labs(x = "Parameter", y = y_label) +
      coord_flip()
    # offset_trans(avg) presumably anchors the bars at the base-case value;
    # the dotted line marks `avg`.
    g <- dampack:::add_common_aes(g, txtsize, col = col, col_aes = "fill",
                                  scale_name = "Parameter\nLevel",
                                  continuous = "y",
                                  ytrans = dampack:::offset_trans(offset = avg),
                                  n_y_ticks = n_y_ticks, ybreaks = ybreaks,
                                  ylim = ylim) +
      geom_hline(yintercept = avg, linetype = 3)
    g
  } else {
    min_max
  }
}
