# Information ##################################################################
# Author:  Justin Dvorak
# Project: Introduction to CART and Random Forests
# Purpose: Example code
# Date:    2025-03-05

# Import dependencies ##########################################################

library(tidyverse)
library(readxl)
library(tidymodels)
library(parsnip)
library(magrittr)
library(rpart)
library(rpart.plot)
library(randomForest)

# Example 1: Linear regression #################################################

# Read dataset 1.
dat1 <- readxl::read_xlsx("dataset 1.xlsx")

# Scatter plot of the raw data with a least-squares regression line overlaid.
ggplot(dat1, aes(x = x, y = y)) +
  geom_smooth(method = "lm", formula = y ~ x) +
  geom_point() +
  labs(
    title = "Example of linear regression",
    x = "Predictor",
    y = "Outcome"
  )

# Fit a simple linear regression of y on x and examine the results.
fit1 <- lm(y ~ x, data = dat1)

summary(fit1)
tidy(fit1)

# Predict y for every observation, then attach predictions and residuals.
y_predicted_lm <- predict(fit1, newdata = dat1)

dat1_predicted_lm <- dat1 %>%
  mutate(
    y_predicted = y_predicted_lm,
    # Note: error is prediction minus observed (sign flipped vs. y - yhat);
    # this does not affect squared-error summaries below.
    error = y_predicted - y,
    method = "Linear regression (Y ~ X)"
  )

# Plot the fitted line with vertical segments marking each residual.
dat1_predicted_lm %>%
  arrange(x) %>%
  ggplot(aes(x = x, y = y)) +
  geom_point(color = "darkgrey") +
  geom_path(aes(y = y_predicted)) +
  geom_segment(aes(xend = x, yend = y_predicted)) +
  labs(
    title = "Residuals and estimated regression line",
    subtitle = "Linear regression",
    x = "Predictor",
    y = "Outcome"
  )

# The residuals of a least-squares fit average to zero
# (up to floating-point rounding error).
mean(dat1_predicted_lm$error)

# Mean squared error (MSE) of the fitted model.
mean(dat1_predicted_lm$error^2)

# Fit an intercept-only (null) model for comparison.
fit0 <- lm(y ~ 1, data = dat1)

summary(fit0)
tidy(fit0)

# The null model predicts the same value (the sample mean of y) everywhere.
y_predicted_null <- predict(fit0, newdata = dat1)

dat1_predicted_null <- dat1 %>%
  mutate(
    y_predicted = y_predicted_null,
    error = y_predicted - y,
    method = "Linear regression (Null model)"
  )

# MSE of the null model.
mean(dat1_predicted_null$error^2)

# Plot the flat "regression line" (the mean of y) and its residuals.
dat1_predicted_null %>%
  arrange(x) %>%
  ggplot(aes(x = x, y = y)) +
  geom_point(color = "darkgrey") +
  geom_path(aes(y = y_predicted)) +
  geom_segment(aes(xend = x, yend = y_predicted)) +
  labs(
    title = "Residuals and estimated regression line",
    subtitle = "Null model",
    x = "Predictor",
    y = "Outcome"
  )

# Stack both prediction sets and compare squared errors side by side.
bind_rows(dat1_predicted_lm, dat1_predicted_null) %>%
  mutate(squared_error = error^2) %>%
  ggplot(aes(x = method, y = squared_error)) +
  geom_boxplot() +
  labs(
    title = "Squared error by model method",
    x = "Method",
    y = "Squared error"
  )

# Example 2: Logistic regression ###############################################

# Read dataset 2.
dat2 <- readxl::read_xlsx("dataset 2.xlsx")

# Tabulate the outcome to examine the event rate.
table(dat2$y)

# Scatter plot: dichotomous outcome against the continuous predictor.
ggplot(dat2, aes(x = x, y = y)) +
  geom_point() +
  labs(
    title = "Dichotomous outcome",
    x = "Predictor",
    y = "Outcome"
  )

# Box plot of the predictor within each outcome group.
ggplot(dat2, aes(x = y, y = x, group = y)) +
  geom_boxplot() +
  labs(
    title = "Dichotomous outcome",
    y = "Predictor",
    x = "Outcome"
  )

# Fit a logistic regression of y on x.
logistic_fit <- glm(y ~ x, data = dat2, family = binomial())

summary(logistic_fit)
tidy(logistic_fit)

# Predicted probabilities: type = "response" returns P(y = 1),
# not the linear predictor (log-odds).
y_predicted <- predict(logistic_fit, newdata = dat2, type = "response")

dat2_predicted_logistic <- dat2 %>%
  mutate(
    y_predicted = y_predicted,
    error = y_predicted - y,
    method = "Logistic regression (Y ~ X)"
  )

# MSE of predicted probability vs. 0/1 outcome (a.k.a. the Brier score).
mean(dat2_predicted_logistic$error^2)

# Compare to the intercept-only (null) logistic model.
fit_logistic_null <- glm(
  formula = y ~ 1,
  data = dat2,
  family = binomial()
)

summary(fit_logistic_null)
tidy(fit_logistic_null)

# Its predicted probability is constant across observations...
y_predicted <- predict(fit_logistic_null, newdata = dat2, type = "response")

# ...and matches the observed event rate.
mean(dat2$y)

dat2_predicted_logistic_null <- dat2 %>%
  mutate(
    y_predicted = y_predicted,
    error = y_predicted - y,
    method = "Logistic regression (null model)"
  )

# MSE of the null model.
mean(dat2_predicted_logistic_null$error^2)

# Stack both prediction sets and compare squared errors side by side.
bind_rows(dat2_predicted_logistic, dat2_predicted_logistic_null) %>%
  mutate(squared_error = error^2) %>%
  ggplot(aes(x = method, y = squared_error)) +
  geom_boxplot() +
  labs(
    title = "Squared error by model method",
    x = "Method",
    y = "Squared error"
  )

# Example 3: Dichotomous classifier performance ################################

# Read dataset 3.
dat3 <- readxl::read_xlsx("dataset 3.xlsx")

# Convert to factors. "Positive" is listed first so yardstick treats it
# as the event level by default.
dat3 <- dat3 %>%
  mutate(
    true = factor(true, levels = c("Positive", "Negative")),
    test = factor(test, levels = c("Positive", "Negative"))
  )

# 2x2 confusion table (rows = test result, columns = truth).
with(dat3, table(test, true))

# Standard classification metrics (yardstick: data, truth, estimate).
accuracy(dat3, true, test)
sensitivity(dat3, true, test)
specificity(dat3, true, test)
ppv(dat3, true, test)
npv(dat3, true, test)

# MSE of the 0/1 classifications. Squaring a misclassification indicator
# leaves it unchanged, so this equals the misclassification rate.
mean((dat3$test != dat3$true)^2)

# Example 4: CART ##############################################################

# Read dataset 4.
readxl::read_xlsx("dataset 4.xlsx") ->
  dat4

# Convert a 0/1 (or TRUE/FALSE) vector to a "Yes"/"No" factor.
#
# Values of 1/TRUE map to "Yes", 0/FALSE to "No", and NA stays NA.
# "Yes" is listed first so downstream yardstick metrics treat it as the
# event level.
tf_to_yn <- function(x) {
  case_when(
    x == 1 ~ "Yes",
    x == 0 ~ "No",
    # Use a typed NA so every case_when branch shares the character type;
    # a bare logical NA here errors under dplyr versions before 1.1.
    is.na(x) ~ NA_character_
  ) %>%
    factor(levels = c("Yes", "No"))
}

# Recode every 0/1 variable to a Yes/No factor in a single pass.
dat4 <- dat4 %>%
  mutate(across(c(death, shock, malnutr, alcohol, bowelinf), tf_to_yn))

# In class example: dichotomize age at 65.
dat4 <- dat4 %>%
  mutate(age65 = tf_to_yn(age >= 65))

# Tabulate the outcome.
table(dat4$death)

# Examine outcome by categorical predictors.
dat4 %$%
  table(death, shock)

dat4 %$%
  table(death, malnutr)

dat4 %$%
  table(death, alcohol)

dat4 %$%
  table(death, bowelinf)

# Examine age by outcome.
# Bug fix: labs() takes `x` and `y` arguments, not `xlab`/`ylab` (those are
# the names of the standalone helpers xlab()/ylab()); the original axis
# labels were silently ignored.
dat4 %>%
  ggplot(aes(x = death, y = age)) +
  geom_boxplot() +
  labs(
    title = "Age by outcome",
    y = "Age (years)",
    x = "Death"
  )

# Fit a classification tree predicting death from the clinical covariates.
cart_fit <- rpart(
  formula = death ~ shock + malnutr + alcohol + age + bowelinf,
  data = dat4,
  method = "class"
)

# Examine results.
summary(cart_fit)

# View the generated tree.
rpart.plot(cart_fit)

# Predicted probability of death: column "Yes" of the per-class
# probability matrix that predict() returns by default.
p_death_predicted <- as.data.frame(predict(cart_fit, newdata = dat4))$Yes

# Predicted class label for death.
death_predicted <- predict(cart_fit, newdata = dat4, type = "class")

# Attach predictions, then compute classification metrics.
dat4_predicted <- dat4 %>%
  mutate(
    p_death_predicted = p_death_predicted,
    death_predicted = death_predicted
  )

accuracy(dat4_predicted, death, death_predicted)
sensitivity(dat4_predicted, death, death_predicted)
specificity(dat4_predicted, death, death_predicted)
ppv(dat4_predicted, death, death_predicted)
npv(dat4_predicted, death, death_predicted)

# MSE of the predicted probability against the 0/1 outcome.
mean((p_death_predicted - if_else(dat4$death == "Yes", 1, 0))^2)

# Example 5: Cross-validation ##################################################

# Reproducible 80/20 train/test split: draw one uniform number per row
# and threshold it at 0.8.
set.seed(54321)
dat4 <- dat4 %>%
  mutate(
    r = runif(n(), 0, 1),
    set = if_else(r <= 0.8, "Train", "Test")
  )

dat4_train <- dat4 %>% filter(set == "Train")

dat4_test <- dat4 %>% filter(set == "Test")

# Fit the tree on the training set only.
cart_fit <- rpart(
  formula = death ~ shock + malnutr + alcohol + age + bowelinf,
  data = dat4_train,
  method = "class"
)

# View the generated tree.
rpart.plot(cart_fit)

# Generate predictions on the held-out testing set:
# probability of death ("Yes" column) and predicted class.
p_death_predicted <- as.data.frame(predict(cart_fit, newdata = dat4_test))$Yes

death_predicted <- predict(cart_fit, newdata = dat4_test, type = "class")

dat4_test <- dat4_test %>%
  mutate(
    p_death_predicted = p_death_predicted,
    death_predicted = death_predicted
  )
  
# Evaluate on the testing set.
accuracy(dat4_test, death, death_predicted)
sensitivity(dat4_test, death, death_predicted)
specificity(dat4_test, death, death_predicted)
ppv(dat4_test, death, death_predicted)
npv(dat4_test, death, death_predicted)

# MSE of the predicted probability against the 0/1 outcome.
mean((dat4_test$p_death_predicted - if_else(dat4_test$death == "Yes", 1, 0))^2)

# Notice a problem with sensitivity?
table(dat4_test$death)

# Using only one split can give results with poor event rates.

# Solution 1: multiple splits (k-fold cross-validation).
# Reuse the uniform draw r to assign each row to one of n_splits folds.
n_splits <- 5
dat4 %<>%
  mutate(split = ceiling(r * n_splits))

# Collect per-fold predictions in a preallocated list instead of growing a
# data frame with rbind() inside the loop (repeated rbind copies the whole
# accumulator on every iteration).
fold_predictions <- vector("list", n_splits)
for (nth_split in seq_len(n_splits)) {
  # Train on every fold except the current one; test on the held-out fold.
  dat4 %>%
    filter(split != nth_split) ->
    dat4_train
  dat4 %>%
    filter(split == nth_split) ->
    dat4_test

  rpart(
    formula = death ~ shock + malnutr + alcohol + age + bowelinf,
    data = dat4_train,
    method = "class"
  ) ->
    cart_fit

  # Predicted probability of death and predicted class on the held-out fold.
  cart_fit %>%
    predict(newdata = dat4_test) %>%
    as.data.frame() %>%
    pull(Yes) ->
    p_death_predicted

  cart_fit %>%
    predict(newdata = dat4_test, type = "class") ->
    death_predicted

  dat4_test %<>%
    mutate(
      p_death_predicted = p_death_predicted,
      death_predicted = death_predicted
    ) %>%
    select(split, death, p_death_predicted, death_predicted)

  fold_predictions[[nth_split]] <- dat4_test
}
# Combine all folds at once; every observation appears exactly once.
results <- bind_rows(fold_predictions)

# Pooled metrics across all folds.
accuracy(results, death, death_predicted)
sensitivity(results, death, death_predicted)
specificity(results, death, death_predicted)
ppv(results, death, death_predicted)
npv(results, death, death_predicted)

# Pooled MSE of the predicted probabilities.
mean((results$p_death_predicted - if_else(results$death == "Yes", 1, 0))^2)
  
# Solution 2: Leave-one-out cross-validation (LOOCV).
# Preallocate one slot per observation rather than rbind()-ing inside the
# loop: growing a data frame row by row is O(n^2) in total copying.
# seq_len() is also safe if the data frame were ever empty (1:0 is not).
loo_predictions <- vector("list", nrow(dat4))
for (i in seq_len(nrow(dat4))) {
  # Hold out observation i; train on everything else.
  dat4[i, ] ->
    dat4_test
  dat4[-i, ] ->
    dat4_train

  rpart(
    formula = death ~ shock + malnutr + alcohol + age + bowelinf,
    data = dat4_train,
    method = "class"
  ) ->
    cart_fit

  cart_fit %>%
    predict(newdata = dat4_test) %>%
    as.data.frame() %>%
    pull(Yes) ->
    p_death_predicted

  cart_fit %>%
    predict(newdata = dat4_test, type = "class") ->
    death_predicted

  dat4_test %<>%
    mutate(
      p_death_predicted = p_death_predicted,
      death_predicted = death_predicted,
      split = i
    ) %>%
    select(split, death, p_death_predicted, death_predicted)

  loo_predictions[[i]] <- dat4_test
}
# Combine all single-row predictions at once.
results <- bind_rows(loo_predictions)

# Pooled LOOCV metrics.
accuracy(results, death, death_predicted)
sensitivity(results, death, death_predicted)
specificity(results, death, death_predicted)
ppv(results, death, death_predicted)
npv(results, death, death_predicted)

# Pooled MSE of the predicted probabilities.
mean((results$p_death_predicted - if_else(results$death == "Yes", 1, 0))^2)

# Example 6: Random forest #####################################################

# Build a random forest on the sepsis data. Because `death` is a factor,
# randomForest() fits a classification forest automatically — `method` is
# not a randomForest() argument, so the original `method = "class"` had no
# effect and is removed. `importance = TRUE` is required to compute the
# permutation-based (mean decrease in accuracy) importance from the OOB
# samples; without it, only the Gini importance is available to
# importance() and varImpPlot().
randomForest(
  formula = death ~ shock + malnutr + alcohol + age + bowelinf,
  data = dat4,
  importance = TRUE
) ->
  rf_fit

# Create a variable importance table and plot (permuted OOB error + Gini).
importance(rf_fit)
varImpPlot(rf_fit)

# Predict the class label for death.
death_predicted <- predict(rf_fit, newdata = dat4)

# Predicted probability of death ("Yes" column of the probability matrix).
p_death_predicted <- as.data.frame(
  predict(rf_fit, newdata = dat4, type = "prob")
)$Yes

dat4_predicted <- dat4 %>%
  mutate(
    death_predicted = death_predicted,
    p_death_predicted = p_death_predicted
  )

# Compute metrics (naive: evaluated on the same data the forest was fit to).
accuracy(dat4_predicted, death, death_predicted)
sensitivity(dat4_predicted, death, death_predicted)
specificity(dat4_predicted, death, death_predicted)
ppv(dat4_predicted, death, death_predicted)
npv(dat4_predicted, death, death_predicted)

# MSE of the predicted probability against the 0/1 outcome.
mean(
  (dat4_predicted$p_death_predicted -
     if_else(dat4_predicted$death == "Yes", 1, 0))^2
)