Optimizing Decision Boundaries


The output of most ML models is a continuous “probability” score, which must be converted into a class label by applying a decision boundary. This can be done in a number of ways, ranging from the simple to the quite complex. As an illustrative example, we will use the normal saline prediction task, treating the outputs of the real-time model as the predictions and the outputs of the retrospective model as the ground truth.
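
In the simplest case, this conversion is a comparison against a single fixed cut-off. A minimal sketch in R (with made-up probabilities, not the study data):

## Convert continuous scores to class labels at a fixed 0.5 cut-off
## (illustrative values only, not the study data)
prob <- c(0.02, 0.47, 0.63, 0.91)
ifelse(prob > 0.5, "Positive", "Negative")
#> [1] "Negative" "Negative" "Positive" "Positive"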

Visual Inspection

A relatively simple approach to setting a decision boundary is to visually inspect the distribution of predicted probabilities for each class and draw a separator where you judge the balance of sensitivity and specificity to be appropriate. Below, we use R to visualize the distributions as a standard density plot.

suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))

## Load models
model_realtime <- read_rds("https://figshare.com/ndownloader/files/45631488") |> pluck("model") |> bundle::unbundle()
model_retrospective <- read_rds("https://figshare.com/ndownloader/files/45631491") |> pluck("model") |> bundle::unbundle()
predictors <- model_retrospective |> extract_recipe() |> pluck("term_info") |> dplyr::filter(role == "predictor") |> pluck("variable")

## Load validation data
options(timeout=300)
validation <- arrow::read_feather("https://figshare.com/ndownloader/files/45407398") |> select(any_of(predictors))
validation_no_NA <- validation |> drop_na(matches("prior|post|potassium"))

## Make predictions using real-time model
validation_predictions <- 
  bind_cols(
    validation_no_NA,
    predicted_probability = predict(model_realtime, new_data = validation_no_NA, type = "prob") |> pluck(".pred_1"),
    ground_truth = factor(predict(model_retrospective, new_data = validation_no_NA, type = "class") |> pluck(".pred_class"), labels = c("Negative", "Positive"))
    ) |> 
  mutate(predicted_class = factor(ifelse(predicted_probability > 0.5, "Positive", "Negative")))

## Calculate performance metrics at a threshold of 0.5
sens <- sensitivity(validation_predictions, estimate = predicted_class, truth = ground_truth, event_level = "second")
spec <- specificity(validation_predictions, estimate = predicted_class, truth = ground_truth, event_level = "second")
pos_pred_value <- ppv(validation_predictions, estimate = predicted_class, truth = ground_truth, event_level = "second")
neg_pred_value <- npv(validation_predictions, estimate = predicted_class, truth = ground_truth, event_level = "second")
matthews_corr_coef <- mcc(validation_predictions, estimate = predicted_class, truth = ground_truth, event_level = "second")
                                    
## Visualizing the differences in the distributions between positive and negative classes. 
gg_dist <- 
  ggplot(validation_predictions, aes(x = predicted_probability, fill = ground_truth)) +
    geom_density(bw = 0.1) +
    geomtextpath::geom_textdensity(aes(label = ground_truth, color = ground_truth, hjust = ground_truth), 
                                   linewidth = 1.25, alpha = 0.75, bw = 0.1, 
                                   vjust = -0.5, fontface = "bold", size = 8) +
    scale_y_continuous(name = "Proportion", expand = c(0, 0), labels = NULL) +
    scale_x_continuous(name = "Predicted Probability", expand = c(0.01, 0.01), breaks = c(0, 0.5, 1)) +
    scale_discrete_manual(aes = "hjust", values = c(0.1, 0.95)) +
    scico::scale_fill_scico_d(palette = "lipari", begin = 0.1, end = 0.5) +
    scico::scale_color_scico_d(palette = "lipari", begin = 0.1, end = 0.5) +
    labs(x = "Predicted Probability", y = "Density", fill = "Ground Truth") +
    ggtitle("Distribution of Predicted Probabilities by Class") +
    theme(legend.position = "none", axis.line.y.left = element_blank())

gg_dist +   
  geom_vline(xintercept = 0.5, linetype = "dashed", linewidth = 1.5) +
  annotate("text", x = 0.52, y = 3, label = "Threshold", fontface = "bold", angle = -90, size = 6)

Performance Metrics

Sens: 0.793
Spec: 0.996
PPV: 0.547
NPV: 0.999
MCC: 0.656

Youden’s J Index and the ROC Curve

The Receiver Operating Characteristic (ROC) curve has long been a staple in the evaluation of binary classifiers, and it is another useful tool for setting a decision boundary. When using an ROC curve to establish a decision boundary, it is common to choose the point at which the sum of sensitivity and specificity is maximized. This can be achieved by calculating Youden’s J index at each threshold and selecting the threshold where it peaks.

## Calculate Sensitivity and Specificity Across All Thresholds
roc <- 
  roc_curve(validation_predictions, predicted_probability, truth = ground_truth, event_level = "second") |> 
  mutate(youden = sensitivity + specificity - 1) # Add Youden's J at each threshold

## Find the Threshold that maximizes Youden's J
threshold <- roc[which.max(roc$youden), ]

## Plot ROC Curve
ggplot(roc, aes(x = 1 - specificity, y = sensitivity)) +
  geom_step(linewidth = 1.25) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black", linewidth = 1) +
  ## Youden's J is the vertical distance between the ROC curve and the chance line
  annotate("segment", x = 1 - threshold[["specificity"]], xend = 1 - threshold[["specificity"]], 
           y = 1 - threshold[["specificity"]], yend = threshold[["sensitivity"]], 
           color = "grey80", linewidth = 1.25) + 
  annotate("text", x = 1 - threshold[["specificity"]] + 0.02, y = 0.6, label = "Youden's J", fontface = "bold.italic", color = "grey80", size = 6, hjust = 0) +
  labs(x = "1 - Specificity", y = "Sensitivity") +
  ggtitle("ROC Curve")

Youden’s J Index

Definition:
J = Sensitivity + Specificity - 1
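
Worked Example:
At the 0.5 threshold above, J = 0.793 + 0.996 - 1 = 0.789.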

Key Consideration:
Maximizing Youden’s J yields the optimal threshold if and only if sensitivity and specificity are equally important.

Decision Boundaries Using Imbalance-Sensitive Metrics

Maximizing Youden’s J index works well if we care equally about sensitivity and specificity, but real-life clinical scenarios in which a false positive and a false negative are equally disruptive are rare. Additionally, the section on Measuring Performance highlighted that metrics like sensitivity and specificity can be misleading in the setting of a class imbalance. Let’s explore how we might instead use metrics like PPV, NPV, and MCC to set a decision boundary. For this, we’ll use the {probably}1 package.
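
To see why imbalance matters, consider a condition with 5% prevalence: a classifier with 80% sensitivity and 95% specificity yields a PPV of only (0.80 × 0.05) / (0.80 × 0.05 + 0.05 × 0.95) ≈ 0.46, so nearly half of its positive calls are false alarms despite seemingly strong sensitivity and specificity. (Illustrative numbers, not the study data.)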

library(probably)

# Define our thresholds to test
thresholds <- seq(0.01, 0.99, by = 0.01)
  
# Define our metrics of interest
decision_metrics <- metric_set(ppv, mcc, j_index)

# Calculate each metric across each threshold
decision_curves <- 
  validation_predictions |> 
  threshold_perf(truth = ground_truth, estimate = predicted_probability, 
                 metrics = decision_metrics, 
                 thresholds = thresholds, 
                 event_level = "second")
max_mcc <- decision_curves |> dplyr::filter(.metric == "mcc") |> arrange(desc(.estimate)) |> slice_head(n = 1)
max_J <- decision_curves |> dplyr::filter(.metric == "j_index") |> arrange(desc(.estimate)) |> slice_head(n = 1)

# Plot Results
ggplot(decision_curves, aes(x = .threshold, y = .estimate, color = .metric)) + 
  geomtextpath::geom_textline(aes(label = str_to_upper(.metric)), linewidth = 1.5, fontface = "bold", size = 8, hjust = 0.25) + 
  geom_vline(xintercept = max_mcc[[".threshold"]], linetype = "dashed") +
  geom_text(data = max_mcc, x = max_mcc[[".threshold"]] - 0.02, y = 0.1, hjust = 1, fontface = "bold", label = glue::glue("Max MCC at a cut-off of ", max_mcc[[".threshold"]])) + 
  geom_point(data = max_mcc, size = 6, aes(color = "mcc")) + 
  geom_vline(xintercept = max_J[[".threshold"]], linetype = "dashed") +
  geom_text(data = max_J, x = max_J[[".threshold"]] + 0.02, y = 0.1, hjust = 0, fontface = "bold", label = glue::glue("Max Youden's J at a cut-off of ", max_J[[".threshold"]])) + 
  geom_point(data = max_J, size = 6, aes(color = "j_index")) + 
  scico::scale_color_scico_d(palette = "lipari", begin = 0.1, end = 0.9) +
  scale_x_continuous(name = "Prediction Threshold", breaks = c(0, 0.5, 1)) + 
  scale_y_continuous(name = "Metric Value", breaks = c(0, 0.5, 1)) + 
  ggtitle("Performance Metrics Across a Range of Thresholds") + 
  theme(legend.position = "none")

Using the more imbalance-sensitive metrics, we can see that the threshold that maximizes the Matthews Correlation Coefficient (MCC) is substantially different from the threshold that maximizes Youden’s J index.
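
To act on this, the MCC-optimal cut-off can simply replace the default 0.5 when forming class predictions. A minimal sketch, assuming the validation_predictions and max_mcc objects from the code above:

## Re-derive class predictions at the MCC-optimal cut-off
## (assumes validation_predictions and max_mcc from the code above)
validation_predictions_mcc <- 
  validation_predictions |> 
  mutate(predicted_class = factor(ifelse(predicted_probability > max_mcc[[".threshold"]], "Positive", "Negative")))

## Confirm the improvement over the default 0.5 cut-off
mcc(validation_predictions_mcc, estimate = predicted_class, truth = ground_truth, event_level = "second")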

Equivocal Zones and the No-Prediction Rate

In some clinical scenarios, where the cost of a false positive and/or a false negative prediction is high, it may be beneficial to define an “equivocal zone” where the model’s continuous output is not converted into a class prediction. This is particularly helpful when planning for fully automated implementations, where extreme confidence is needed before deciding to forgo a human review step.

gg_dist_with_equiv <-
  gg_dist + 
    geom_vline(xintercept = c(0.25, 0.75), linetype = "dashed") + 
    annotate("rect", xmin = 0.25, xmax = 0.75, ymin = 0, ymax = 5, fill = "gray70", alpha = 0.5) + 
    annotate("text", x = 0.5, y = 2.5, label = "Equivocal", hjust = 0.5, size = 6, fontface = "bold") + 
    ## Add a segment with an arrow on either side
    annotate("segment", x = 0.6, xend = 0.73, y = 2.5, yend = 2.5, arrow = arrow(type = "closed", length = unit(0.1, "inches"))) +
    annotate("segment", x = 0.4, xend = 0.27, y = 2.5, yend = 2.5, arrow = arrow(type = "closed", length = unit(0.1, "inches"))) + 
    ggtitle("Decision Boundaries with an Equivocal Zone")
gg_dist_with_equiv

For example, let’s suppose that we can tolerate not making predictions on a subset of results in order to improve our PPV and NPV. We can again use the {probably}1 package to set an equivocal zone between these thresholds, as shown in the code below.

# Define our equivocal zone
predictions_with_equivocal_zone <-
  validation_predictions |> 
    mutate(
      .pred = make_two_class_pred(
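        ## make_two_class_pred() expects the probability of the FIRST factor
        ## level ("Negative" here), hence estimate = 1 - predicted_probability.
        ## With threshold = 0.4 and buffer = 0.25, Negative-class probabilities
        ## in [0.15, 0.65] (Positive-class probabilities in [0.35, 0.85]) are
        ## flagged as equivocal.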
        estimate = 1 - predicted_probability, 
        levels = levels(ground_truth),
        threshold = 0.4,
        buffer = 0.25
      )
    )

# Calculate the reportable rate and performance metrics
class_metrics <- metric_set(ppv, npv, mcc)

performance_without_equivocal <- class_metrics(validation_predictions, truth = ground_truth, estimate = predicted_class, event_level = "second") |> mutate(type = "Standard", reportable_rate = 1)
performance_with_equivocal <- class_metrics(predictions_with_equivocal_zone, truth = ground_truth, estimate = .pred, event_level = "second") |> mutate(type = "With Equivocal Zone", reportable_rate = round(reportable_rate(predictions_with_equivocal_zone$.pred), digits = 3))

# Combine the results
performance <- 
  bind_rows(performance_without_equivocal, performance_with_equivocal) |> 
  pivot_wider(names_from = .metric, values_from = c(.estimate)) |> 
  select(-.estimator)

# Print the results
performance |> knitr::kable(digits = 3)
type                  reportable_rate    ppv    npv    mcc
Standard                        1.000  0.586  0.999  0.669
With Equivocal Zone             0.995  0.859  0.999  0.793

Key Takeaway
Incorporating an equivocal zone can improve performance metrics at the cost of a lower reportable rate, i.e., the proportion of results that receive a prediction.

References

1. Kuhn M, Vaughan D, Ruiz E. probably: Tools for post-processing predicted values [Internet]. 2024. Available from: https://github.com/tidymodels/probably