Selecting Effective Performance Metrics


Many clinical applications of ML involve the prediction of rare events. In these cases, the classic metrics of discriminatory performance (e.g. accuracy, sensitivity, specificity, and the area under the ROC curve) may provide overly optimistic estimates of real-world performance.

Let’s explore the effect of class imbalance on various performance metrics by simulating a model’s predictions across varying degrees of imbalance.


Measuring Performance on Imbalanced Classes


  1. Simulate 100,000 predicted scores per dataset, drawing negatives and positives from two overlapping normal distributions.

  2. Vary the share of positives (the prevalence) across datasets: 0.1%, 1%, 10%, and 50%.

  3. Calculate classic performance metrics and build ROC/PR curves.

  4. Compare the effect of class imbalance on each metric.

As our positive class – the event we are trying to predict – becomes rarer and rarer, we will need to adjust the performance metrics we use to evaluate our machine learning pipeline.
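
To see why rarity matters, it helps to write the predictive values in terms of prevalence. The relations below follow from Bayes' rule; the symbol p for prevalence is notation introduced here, not taken from the code.

```latex
\mathrm{PPV} = \frac{\mathrm{Sens}\cdot p}{\mathrm{Sens}\cdot p + (1-\mathrm{Spec})\,(1-p)}
\qquad
\mathrm{NPV} = \frac{\mathrm{Spec}\,(1-p)}{\mathrm{Spec}\,(1-p) + (1-\mathrm{Sens})\,p}
\qquad
\mathrm{Accuracy} = \mathrm{Sens}\cdot p + \mathrm{Spec}\,(1-p)
```

Sensitivity and specificity are each computed within a single class, so they do not involve p. As an illustration with made-up values (not the simulation's), a test with Sens = Spec = 0.9 at p = 0.001 has PPV = 0.0009 / (0.0009 + 0.0999) ≈ 0.009, so fewer than 1 in 100 positive calls is a true positive.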

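# Load packages: tidyverse (tibble, dplyr, tidyr, ggplot2, forcats) and yardstick (metrics and curves)
library(tidyverse)
library(yardstick)

# Fix the random seed so the simulation is reproducible (the seed value is arbitrary)
set.seed(42)

# Define prevalence levels for class imbalance
prevalences <- c(0.001, 0.01, 0.1, 0.5)
tmp <- tibble()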

# Build multiple datasets with varying class imbalance
for (prevalence in prevalences) {
  
  # Round to whole counts to guard against floating-point error in the products
  n_pos <- round(100000 * prevalence)
  n_neg <- 100000 - n_pos
  pos <- tibble(result = rnorm(n_pos, mean = 0.65, sd = 0.15), label = "Positive", prevalence = prevalence)
  neg <- tibble(result = rnorm(n_neg, mean = 0.35, sd = 0.15), label = "Negative", prevalence = prevalence)
  tmp <- bind_rows(tmp, pos, neg)
  
}

# Assign class labels based on predicted probability for each dataset
gg_input <- 
  tmp |> 
    mutate(label = factor(label, levels = c("Positive", "Negative")), 
           estimate = factor(ifelse(result > 0.5, "Positive", "Negative"), levels = c("Positive", "Negative")),
           prevalence = factor(prevalence))

# Calculate performance metrics for each dataset
gg_metric_input <- 
  gg_input |>
  group_by(prevalence) |>
  summarise(
    Accuracy = accuracy_vec(label, estimate),
    Sens = sens_vec(label, estimate),
    Spec = spec_vec(label, estimate),
    PPV = ppv_vec(label, estimate),
    NPV = npv_vec(label, estimate),
    MCC = mcc_vec(label, estimate),
    auROC = roc_auc_vec(label, result), 
    auPRC = pr_auc_vec(label, result)) |> 
  pivot_longer(cols = -prevalence, names_to = "Metric", values_to = "Value") |> 
  mutate(Value = round(Value, 3), Metric = factor(Metric, levels = c("auROC", "Accuracy", "Sens", "Spec", "NPV", "PPV", "MCC", "auPRC")))

# Build ROC curves for each dataset
gg_roc_input <- 
  gg_input |>
  group_by(prevalence) |>
  roc_curve(label, result)

# Build PR curves for each dataset
gg_pr_input <- 
  gg_input |>
  group_by(prevalence) |>
  pr_curve(label, result)

# Plot ROC curves for each dataset
gg_rocs <- 
  ggplot(gg_roc_input, aes(1-specificity, sensitivity, color = prevalence)) +
    geom_path(linewidth = 2) +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "grey50") +
    scico::scale_color_scico_d(palette = "lipari", begin = 0.1, end = 0.9, labels = c("Extremely Imbalanced", "Imbalanced", "Mildly Imbalanced", "Balanced")) +
    guides(color = guide_legend(title = "Class Imbalance", reverse = TRUE)) +
    labs(x = "1 - Specificity", y = "Sensitivity", title = "ROC Curves") + 
    theme(axis.text = element_blank(), legend.position = c(0.8, 0.3), legend.background = element_blank())
  
# Plot PR curves for each dataset
gg_prs <- 
  ggplot(gg_pr_input, aes(recall, precision, color = prevalence)) +
    geom_path(linewidth = 2) +
    scico::scale_color_scico_d(palette = "lipari", begin = 0.1, end = 0.9) +
    labs(x = "Recall (Sensitivity)", y = "Precision (PPV)", title = "PR Curves") + 
    theme(axis.text = element_blank(), legend.position = "none")

# Plot the ROC and PR curves side-by-side and save the figure as PNG, PDF, and SVG
gg_imbalance_curves <- ggpubr::ggarrange(gg_rocs, gg_prs, nrow = 1, ncol = 2)
ggsave("../../figures/imbalanced_curves.png", gg_imbalance_curves, width = 10, height = 4, dpi = 600)
ggsave("../../figures/imbalanced_curves.pdf", gg_imbalance_curves, width = 10, height = 4)
ggsave("../../figures/imbalanced_curves.svg", gg_imbalance_curves, width = 10, height = 4)

# Plot the performance metrics for each dataset
gg_metrics <-
  ggplot(gg_metric_input, aes(Metric, Value, fill = fct_rev(prevalence))) +
    geom_bar(stat = "identity", position = "dodge") +
    scico::scale_fill_scico_d(palette = "lipari", begin = 0.9, end = 0.1) +
    scale_y_continuous(name = "Metric Value", limits = c(0, 1), breaks = c(0, 1)) +
    scale_x_discrete(name = NULL) +
    ggtitle("Binary Metrics") + 
    theme(legend.position = "none", axis.text.x.bottom = element_text(angle = 0, face = "bold", hjust = 0.5))

# Stack the curves above the metrics and save the figure as PNG, PDF, and SVG
gg_imbalance_metrics_combined <- 
  ggpubr::ggarrange(
    gg_imbalance_curves, 
    gg_metrics,
      ncol = 1, nrow = 2, heights = c(0.6, 0.4))

ggsave("../../figures/imbalanced_metrics_combined.png", gg_imbalance_metrics_combined, width = 10, height = 6, dpi = 600, bg = NULL)
ggsave("../../figures/imbalanced_metrics_combined.pdf", gg_imbalance_metrics_combined, width = 10, height = 6, dpi = 600, bg = NULL)
ggsave("../../figures/imbalanced_metrics_combined.svg", gg_imbalance_metrics_combined, width = 10, height = 6, dpi = 600, bg = NULL)

# Display the combined plot
gg_imbalance_metrics_combined

Figure 1: The effect of class imbalance on various performance metrics. Gray lines and bars highlight a balanced classification task, while red represents a task where negatives far outnumber positives.

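Here, we see a stark contrast between metrics that change with class imbalance (e.g. PPV/NPV, MCC, and the Precision-Recall curve) and those that do not (e.g. Sensitivity, Specificity, and the ROC curve). Accuracy also looks stable here, but only because sensitivity and specificity are nearly equal at the 0.5 threshold in this simulation; in general, accuracy is a prevalence-weighted average of the two and shifts with class balance.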
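
As a cross-check on the simulated values, the expected metrics can be worked out directly from the two score distributions. The sketch below uses base R only. `expected_metrics` is a hypothetical helper, and it assumes the normal distributions from the simulation code above; the 0.6 threshold is an extra value added purely for comparison.

```r
# Analytic cross-check of the simulation (base R only).
# Assumes positives ~ N(0.65, 0.15) and negatives ~ N(0.35, 0.15), as simulated above.
# `expected_metrics` is a hypothetical helper, not part of any package.
expected_metrics <- function(prevalence, threshold = 0.5) {
  # Probability that a positive scores above / a negative scores below the threshold
  sens <- pnorm(threshold, mean = 0.65, sd = 0.15, lower.tail = FALSE)
  spec <- pnorm(threshold, mean = 0.35, sd = 0.15)

  # Expected confusion-matrix cells as proportions of all samples
  tp <- prevalence * sens
  fn <- prevalence * (1 - sens)
  tn <- (1 - prevalence) * spec
  fp <- (1 - prevalence) * (1 - spec)

  data.frame(
    prevalence = prevalence,
    threshold = threshold,
    sens = sens,
    spec = spec,
    accuracy = tp + tn,
    ppv = tp / (tp + fp),
    npv = tn / (tn + fn),
    mcc = (tp * tn - fp * fn) /
      sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  )
}

prevalences <- c(0.001, 0.01, 0.1, 0.5)

# Threshold used in the simulation: sensitivity equals specificity
metrics_t50 <- expected_metrics(prevalences, threshold = 0.5)

# Shifted threshold (an assumption for comparison): sensitivity and specificity differ
metrics_t60 <- expected_metrics(prevalences, threshold = 0.6)

print(round(metrics_t50, 3))
print(round(metrics_t60, 3))
```

At the 0.5 threshold, sensitivity and specificity coincide, so expected accuracy is identical at every prevalence, while expected PPV falls below 1% at a prevalence of 0.1%. At the 0.6 threshold the two differ, and accuracy moves with prevalence as well.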

Key Takeaway:
When predicting rare events, use Precision-Recall Curves, Positive Predictive Value, and the Matthews Correlation Coefficient (MCC) to better assess clinical utility in real-world applications.