EpiClass accurately predicts EpiATLAS assay and biospecimen metadata

Author

Joanny Raby

Results section 1 figures

Formatting of the figures may not be identical to the paper, but they contain the same data points.

All code is folded by default; click on “Code” to expand it.

IMPORTANT

THIS IS A WORK IN PROGRESS. Most figures are here, but some figure captions and basic code cell descriptions are still missing.

Article text

The harmonized EpiATLAS data and metadata used to develop EpiClass comprise 7,464 datasets (experiments) from 2,216 epigenomes (biological samples) generated by consortia such as ENCODE, Blueprint and CEEHRC. Each epigenome included data from up to nine different assays, which we hereafter refer to as our ‘core assays’: six ChIP-Seq histone modifications, the single control Input file they share, RNA-Seq and WGBS (Fig. 1A, Supplementary Table 1). The total training set included 20,922 signal files, comprising multiple normalization outputs per ChIP-Seq dataset (raw, fold change and p-value) and strand-specific files for RNA-Seq and WGBS assays. We included all three track types and the stranded tracks to increase the robustness of the classifiers and to enlarge the training set.

Using five different machine learning approaches, we evaluated classification performance through stratified 10-fold cross-validation on 100 kb non-overlapping genome-wide bins (excluding the Y chromosome) (Methods). The Assay classifiers achieved ~99% accuracy, F1-score, and Area Under the Receiver Operating Characteristic Curve (AUROC), while the Biospecimen classifiers reached ~95% on the 16 most abundant classes, which comprise 84% of the epigenomes (the remaining epigenomes, distributed across 46 smaller classes, were excluded) (Fig. 1B, Supplementary Fig. 1A-C).
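A minimal sketch of the binning step, assuming each signal file has already been loaded as one NumPy array of per-base values per chromosome, and that a bin's value is the mean signal over its span. The function name `bin_genome` and this input layout are illustrative assumptions, not the EpiClass preprocessing code (see Methods for the actual procedure). At 100 kb, the human genome yields on the order of 30,000 bins, which sets the classifier input size.

```python
from typing import Dict, Tuple

import numpy as np


def bin_genome(
    signal_by_chrom: Dict[str, np.ndarray],
    bin_size: int = 100_000,
    exclude: Tuple[str, ...] = ("chrY",),
) -> np.ndarray:
    """Average per-base signal into non-overlapping bins, concatenated across chromosomes.

    Chromosomes listed in `exclude` are skipped. The last bin of each chromosome
    may be shorter than `bin_size` and is averaged over its actual length.
    """
    features = []
    for chrom in sorted(signal_by_chrom):  # fixed order -> same feature layout for every file
        if chrom in exclude:
            continue
        values = np.asarray(signal_by_chrom[chrom], dtype=float)
        starts = np.arange(0, len(values), bin_size)
        # np.add.reduceat sums each slice values[starts[i]:starts[i + 1]] (the last one to the end)
        sums = np.add.reduceat(values, starts)
        lengths = np.diff(np.append(starts, len(values)))
        features.append(sums / lengths)
    return np.concatenate(features)


toy = {
    "chr1": np.ones(250_000),                 # 3 bins: 100 kb, 100 kb, 50 kb
    "chr2": np.arange(100_000, dtype=float),  # 1 bin
    "chrY": np.ones(50_000),                  # excluded
}
vector = bin_genome(toy)
print(vector.shape)  # (4,)
```

Every file binned this way produces a vector with the same layout, which is what lets a single model compare tracks across assays and samples.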

The Multi-Layer Perceptron (MLP, or dense feedforward neural network) showed marginally superior performance on the more complex Biospecimen classification task, which has substantial class imbalance (some classes are over-represented and others under-represented) (Supplementary Table 2). As our primary goal was to establish a proof of concept for this approach, we selected the MLP and focused our subsequent efforts on assessing the model’s performance across different genomic resolutions, rather than on exhaustive hyperparameter optimization. Further analysis with this approach revealed that larger genome-wide bins (1 Mb and 10 Mb) substantially decreased performance, while smaller bins (10 kb and 1 kb) offered minimal improvements despite greatly increasing computational demand (Fig. 1C, Supplementary Table 2). Additional data preprocessing steps, including blacklisted region removal and winsorization, had no significant impact on performance (Supplementary Fig. 1D, Supplementary Table 3, Methods), so we adopted the 100 kb bin resolution without further filtering to simplify subsequent analyses.
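The two preprocessing steps tested can be sketched as follows. The exact blacklist and winsorization bounds are given in Methods rather than here, so this sketch assumes a boolean mask marking blacklisted bins and an upper-tail cap at a per-file quantile (99th percentile by default). The name `preprocess_bins` is hypothetical.

```python
from typing import Optional

import numpy as np


def preprocess_bins(
    features: np.ndarray,
    blacklist_mask: Optional[np.ndarray] = None,
    upper_quantile: float = 0.99,
) -> np.ndarray:
    """Drop blacklisted bins, then winsorize the upper tail of each file.

    features: shape (n_files, n_bins).
    blacklist_mask: boolean array of length n_bins, True for bins to remove.
    """
    if blacklist_mask is not None:
        features = features[:, ~np.asarray(blacklist_mask, dtype=bool)]
    # Per-file cap: values above that file's quantile are clipped down to it
    caps = np.quantile(features, upper_quantile, axis=1, keepdims=True)
    return np.minimum(features, caps)


X = np.array([[1.0, 2.0, 3.0, 100.0],
              [5.0, 5.0, 5.0, 5.0]])
mask = np.array([False, False, True, False])
print(preprocess_bins(X, mask, upper_quantile=0.5))
```

The mask removes the third bin. In the first row, the outlier 100 is then clipped to that row's median (2.0). The constant second row is unchanged.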

We also evaluated alternative genomic features, including protein-coding genes (~68 kb on average), cis-regulatory elements showing high correlation between H3K27ac level and gene expression (avg. ~2.3 kb)26, and highly variable DNA methylation segments (200 bp)27. For both Assay and Biospecimen classifiers, none of these alternative feature sets improved the average accuracy by more than 1% compared to the 100 kb bins. The notable exception was WGBS data, for which accuracy improved substantially, from 85% with 100 kb bins to 93% with smaller bin sizes and more relevant features (Supplementary Fig. 2, Supplementary Table 2). These findings validated our choice of 100 kb bins as an effective compromise: they provide comprehensive genome-wide coverage without introducing selection bias from predefined regions, while maintaining strong classification performance and simplifying data processing.
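The highly variable methylation segments come from ref. 27, and this section does not give the exact selection procedure. As a hypothetical illustration of the general idea, the sketch below ranks candidate features by their variance across samples and keeps the top N. The feature set names used later in this notebook (`..._cpg_topvar_200bp_..._n30k` and `..._n300k`) suggest subsets of roughly 30,000 and 300,000 segments. The function name is an assumption.

```python
import numpy as np


def top_variable_features(matrix: np.ndarray, n_features: int) -> np.ndarray:
    """Return column indices of the `n_features` most variable features.

    matrix: shape (n_samples, n_candidate_features), e.g. methylation levels.
    Ties are broken by original column order.
    """
    variances = matrix.var(axis=0)
    # Sort by descending variance; a stable sort keeps tied columns in their original order
    order = np.argsort(-variances, kind="stable")
    return order[:n_features]


methylation = np.array([[0.1, 0.9, 0.5],
                        [0.1, 0.1, 0.5],
                        [0.1, 0.5, 0.5]])
print(top_variable_features(methylation, 1))  # [1]
```

Unlike fixed genome-wide bins, this kind of selection depends on the samples used to compute the variances. That dependence is the selection bias the 100 kb approach avoids.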

Interestingly, the confusion matrix of the Assay classifier revealed that its very few prediction errors on individual files occur mainly in specific scenarios: between different protocol types of RNA-Seq (mRNA vs total RNA) and WGBS (standard vs PBAT), in misclassifications involving control Input datasets, or between the activating histone marks (H3K27ac, H3K4me3, H3K4me1) that are typically localized around promoters/enhancers (Fig. 1D). These confusion patterns are all biologically understandable given the functional similarities. For the ChIP-Seq and RNA-Seq assays, the vast majority of prediction scores exceeded 0.98 for Assay prediction and 0.9 for Biospecimen prediction, with much lower scores for Input and WGBS assays, as expected at the chosen resolution (Supplementary Fig. 1E-F, Supplementary File 1). Importantly, classifier performance is positively correlated with the prediction scores, so the score can be used as a reliable confidence metric (Supplementary Fig. 1H-I). Increasing the prediction score threshold empirically increases performance, even though the scores should not be directly interpreted as true probabilities.
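One way to check the score/performance relationship is to keep only predictions whose score reaches a threshold and recompute accuracy on the remainder. A minimal sketch, assuming the "prediction score" is the highest class probability output by the classifier. The helper name `accuracy_above_threshold` and the toy labels are hypothetical.

```python
from typing import List, Sequence, Tuple

import numpy as np


def accuracy_above_threshold(
    y_true: Sequence[str],
    y_pred: Sequence[str],
    scores: Sequence[float],
    thresholds: Sequence[float],
) -> List[Tuple[float, float]]:
    """For each threshold, return (accuracy on retained predictions, fraction retained)."""
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    scores_arr = np.asarray(scores, dtype=float)
    results = []
    for threshold in thresholds:
        kept = scores_arr >= threshold
        if not kept.any():
            results.append((float("nan"), 0.0))  # nothing left to evaluate
            continue
        accuracy = float(np.mean(y_true_arr[kept] == y_pred_arr[kept]))
        results.append((accuracy, float(kept.mean())))
    return results


truth = ["H3K4me3", "H3K4me3", "H3K27ac", "H3K27ac"]
preds = ["H3K4me3", "H3K27ac", "H3K27ac", "H3K27ac"]
scores = [0.99, 0.55, 0.95, 0.70]
thresholds = [0.0, 0.6, 0.9]
for t, (acc, kept) in zip(thresholds, accuracy_above_threshold(truth, preds, scores, thresholds)):
    print(f"threshold={t:.1f} accuracy={acc:.2f} retained={kept:.2f}")
```

In this toy case, accuracy rises from 0.75 to 1.00 once the single low-scoring error is filtered out. The trade-off is that fewer files are retained, which is why the threshold has to be chosen empirically.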

EpiClass demonstrated practical utility during the development phase by identifying eleven datasets with potentially incorrect assay annotations. After reviewing our findings, the data generators examined their original datasets, decided to correct one sample swap between two datasets, and excluded eight contaminated datasets from subsequent EpiATLAS versions (Fig. 1E, Supplementary Fig. 3, Supplementary Table 4). The Assay classifier also validated the imputed ChIP-Seq datasets from EpiATLAS, achieving perfect predictions and very high prediction scores across all assays (Supplementary Fig. 1G, Supplementary File 2, Methods). Additionally, notably low prediction scores (or high Input prediction scores) indicated noisy signal, which allowed EpiClass to contribute to the identification of 134 low-quality ChIP-Seq datasets. The EpiATLAS harmonization working group also excluded these datasets (Supplementary Fig. 4, Supplementary Table 4).
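The screening signals described above can be expressed as a simple filter over the classifier's per-class scores: a predicted assay that disagrees with the annotation, a low top score, or a high Input score for a file annotated as something else. The thresholds below (0.6 and 0.3), the label names, and the function `flag_suspect_files` are hypothetical placeholders, not the values used for EpiATLAS. Flagged files would still need manual review, as was done with the data generators.

```python
from typing import Sequence

import numpy as np


def flag_suspect_files(
    probabilities: np.ndarray,
    class_names: Sequence[str],
    annotated: Sequence[str],
    min_top_score: float = 0.6,
    max_input_score: float = 0.3,
    input_label: str = "input",
) -> np.ndarray:
    """Return a boolean mask of files worth a manual look.

    probabilities: shape (n_files, n_classes), columns ordered as `class_names`.
    A file is flagged if its predicted class disagrees with its annotation,
    if its top score is low, or if a non-Input file gets a high Input score.
    """
    probs = np.asarray(probabilities, dtype=float)
    names = np.asarray(class_names)
    labels = np.asarray(annotated)

    predicted = names[probs.argmax(axis=1)]
    top_score = probs.max(axis=1)
    input_score = probs[:, list(class_names).index(input_label)]

    mismatch = predicted != labels
    low_confidence = top_score < min_top_score
    input_like = (labels != input_label) & (input_score > max_input_score)
    return mismatch | low_confidence | input_like


classes = ["h3k4me3", "h3k27ac", "input"]
probs = np.array([
    [0.95, 0.03, 0.02],  # confident and consistent with its annotation
    [0.10, 0.85, 0.05],  # predicted label disagrees with annotation
    [0.55, 0.05, 0.40],  # low top score and Input-like signal
    [0.02, 0.03, 0.95],  # a genuine Input file
])
labels = ["h3k4me3", "h3k4me3", "h3k4me3", "input"]
print(flag_suspect_files(probs, classes, labels).tolist())  # [False, True, True, False]
```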

Setup Code - Imports and co.

Setup imports.

Code
from __future__ import annotations

import copy
import logging
import re
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Image, display
from plotly.subplots import make_subplots
from sklearn.metrics import auc, confusion_matrix as sk_cm, roc_curve
from sklearn.preprocessing import label_binarize

from epiclass.core.confusion_matrix import ConfusionMatrixWriter
from epiclass.utils.notebooks.paper.metrics_per_assay import MetricsPerAssay
from epiclass.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    SEX,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    extract_input_sizes_from_output_files,
    merge_similar_assays,
)

Setup paths.

Code
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir
if not paper_dir.exists():
    raise FileNotFoundError(f"Directory {paper_dir} does not exist.")

base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"

Setup colors.

Code
IHECColorMap = IHECColorMap(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

Setup metadata and prediction files handlers.

Code
split_results_handler = SplitResultsHandler()

metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")
metadata_v2_df = metadata_v2.to_df()

Setup data directories.

Code
gen_data_dir = base_data_dir / "training_results" / "dfreeze_v2"
if not gen_data_dir.exists():
    raise FileNotFoundError(f"Directory {gen_data_dir} does not exist.")

data_dir_100kb = gen_data_dir / "hg38_100kb_all_none"
if not data_dir_100kb.exists():
    raise FileNotFoundError(f"Directory {data_dir_100kb} does not exist.")

Setup figures general settings.

Code
# Shared layout settings: centered title and tight margins
main_title_settings = {
    "title": dict(
        automargin=True,
        x=0.5,
        xanchor="center",
        yanchor="top",
        y=0.98,
    ),
    "margin": dict(t=50, l=10, r=10),
}

Figure 1

Performance of EpiClass Assay and Biospecimen classifiers.

A - EpiClass training overview

Fig. 1A: Overview of the EpiClass training process for various classifiers and their inference on external data. Each classifier is trained independently.

B-C Prep

Path setup.

Code
mixed_data_dir = gen_data_dir / "mixed"
if not mixed_data_dir.exists():
    raise FileNotFoundError(f"Directory {mixed_data_dir} does not exist.")

Feature sets setup.

Code
feature_sets_14 = [
    "hg38_10mb_all_none_1mb_coord",
    "hg38_100kb_random_n316_none",
    "hg38_1mb_all_none",
    "hg38_100kb_random_n3044_none",
    "hg38_100kb_all_none",
    "hg38_gene_regions_100kb_coord_n19864",
    "hg38_10kb_random_n30321_none",
    "hg38_regulatory_regions_n30321",
    "hg38_1kb_random_n30321_none",
    "hg38_cpg_topvar_200bp_10kb_coord_n30k",
    "hg38_10kb_all_none",
    "hg38_regulatory_regions_n303114",
    "hg38_1kb_random_n303114_none",
    "hg38_cpg_topvar_200bp_10kb_coord_n300k",
]
fig1_sets = [
    "hg38_10mb_all_none_1mb_coord",
    "hg38_100kb_random_n316_none",
    "hg38_1mb_all_none",
    "hg38_100kb_random_n3044_none",
    "hg38_100kb_all_none",
    "hg38_10kb_random_n30321_none",
    "hg38_1kb_random_n30321_none",
    "hg38_10kb_all_none",
    "hg38_1kb_random_n303114_none",
]
flagship_selection_4cat = [
    "hg38_100kb_all_none",
    "hg38_gene_regions_100kb_coord_n19864",
    "hg38_regulatory_regions_n30321",
    "hg38_cpg_topvar_200bp_10kb_coord_n30k",
]
different_nature_sets = [
    "hg38_regulatory_regions_n30321",
    "hg38_regulatory_regions_n303114",
    "hg38_cpg_topvar_200bp_10kb_coord_n30k",
    "hg38_cpg_topvar_200bp_10kb_coord_n300k",
    "hg38_cpg_topvar_2bp_10kb_coord_n30k",
    "hg38_cpg_topvar_2bp_10kb_coord_n300k",
    "hg38_gene_regions_100kb_coord_n19864",
    "hg38_100kb_all_none",
    "hg38_10kb_all_none",
    "hg38_10kb_random_n30321_none",
    "hg38_1kb_random_n30321_none",
    "hg38_1kb_random_n303114_none",
]

metric_orders_map = {
    "flagship_selection_4cat": flagship_selection_4cat,
    "fig1_sets": fig1_sets,
    "feature_sets_14": feature_sets_14,
    "different_nature_sets": different_nature_sets,
}

Compute input sizes for each feature set.

Code
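# Each feature set maps to a set of input sizes found in its output files;
# keep only feature sets with a single, unambiguous size.
raw_input_sizes = extract_input_sizes_from_output_files(mixed_data_dir)  # type: ignore
input_sizes: Dict[str, int] = {k: v.pop() for k, v in raw_input_sizes.items() if len(v) == 1}  # type: ignore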

Set selection.

Code
set_selection_name = "feature_sets_14"

logdir = (
    base_fig_dir
    / "fig2_EpiAtlas_other"
    / "fig2--reduced_feature_sets"
    / "test"
    / set_selection_name
)
logdir.mkdir(parents=True, exist_ok=True)

Compute metrics.

Code
all_metrics = split_results_handler.obtain_all_feature_set_data(
    parent_folder=mixed_data_dir,
    merge_assays=True,
    return_type="metrics",
    include_categories=[ASSAY, CELL_TYPE],
    include_sets=metric_orders_map[set_selection_name],
    exclude_names=["16ct", "27ct", "7c", "chip-seq-only"],
)

# Order the metrics
all_metrics = {
    name: all_metrics[name]  # type: ignore
    for name in metric_orders_map[set_selection_name]
    if name in all_metrics
}

Label correction.

Code
# The 100 kb feature set stores its assay metrics under the non-standard key
# f"{ASSAY}_11c"; rename it to ASSAY so it matches the other feature sets.
try:
    all_metrics["hg38_100kb_all_none"][ASSAY] = all_metrics["hg38_100kb_all_none"].pop(  # type: ignore
        f"{ASSAY}_11c"
    )
except KeyError:
    pass

Resolution/feature set → color mapping.

Code
resolution_colors = {
    "100kb": px.colors.qualitative.Safe[0],
    "10kb": px.colors.qualitative.Safe[1],
    "1kb": px.colors.qualitative.Safe[2],
    "regulatory": px.colors.qualitative.Safe[3],
    "gene": px.colors.qualitative.Safe[4],
    "cpg": px.colors.qualitative.Safe[5],
    "1mb": px.colors.qualitative.Safe[6],
    "5mb": px.colors.qualitative.Safe[7],
    "10mb": px.colors.qualitative.Safe[8],
}

Define function graph_feature_set_metrics.

Code
def graph_feature_set_metrics(
    all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    input_sizes: Dict[str, int],
    logdir: Path | None = None,
    sort_by_input_size: bool = False,
    name: str | None = None,
    y_range: Tuple[float, float] | None = None,
    boxpoints: str = "all",
    width: int = 1200,
    height: int = 1200,
) -> None:
    """Graph the metrics for all feature sets as box plots, one figure per task.

    Args:
        all_metrics (Dict[str, Dict[str, Dict[str, Dict[str, float]]]]): A dictionary containing all metrics for all feature sets.
            Format: {feature_set: {task_name: {split_name: metric_dict}}}
        input_sizes (Dict[str, int]): A dictionary containing the input sizes for all feature sets.
        logdir (Path|None): The directory where the figure will be saved. If None, the figure will only be displayed.
        sort_by_input_size (bool): Whether to sort the feature sets by input size.
        name (str|None): The name of the figure.
        y_range (Tuple[float, float]|None): The y-axis range for the figure. If None, a default range is used for the assay and biospecimen tasks.
        boxpoints (str): The type of boxpoints to display. Can be "all" or "outliers". Defaults to "all".
        width (int): Figure width in pixels.
        height (int): Figure height in pixels.
    """
    if boxpoints not in ["all", "outliers"]:
        raise ValueError("Invalid boxpoints value.")

    reference_hdf5_type = "hg38_100kb_all_none"
    metadata_categories = list(all_metrics[reference_hdf5_type].keys())

    non_standard_names = {ASSAY: f"{ASSAY}_11c", SEX: f"{SEX}_w-mixed"}
    non_standard_assay_task_names = ["hg38_100kb_all_none"]
    non_standard_sex_task_name = [
        "hg38_100kb_all_none",
        "hg38_regulatory_regions_n30321",
        "hg38_regulatory_regions_n303114",
    ]
    used_resolutions = set()
    for category_idx in range(len(metadata_categories)):
        category_fig = make_subplots(
            rows=1,
            cols=2,
            shared_yaxes=True,
            subplot_titles=["Accuracy", "F1-score (macro)"],
            horizontal_spacing=0.01,
        )

        trace_names = []
        order = list(all_metrics.keys())
        if sort_by_input_size:
            order = sorted(
                all_metrics.keys(),
                key=lambda x: input_sizes[x],
            )
        for feature_set_name in order:
            tasks_dicts = all_metrics[feature_set_name]

            if feature_set_name not in input_sizes:
                print(f"Skipping {feature_set_name}, no input size found.")
                continue

            task_name = metadata_categories[category_idx]
            if "split" in task_name:
                raise ValueError("Split in task name. Wrong metrics dict.")

            try:
                task_dict = tasks_dicts[task_name]
            except KeyError as err:
                if SEX in str(err) and feature_set_name in non_standard_sex_task_name:
                    task_dict = tasks_dicts[non_standard_names[SEX]]
                elif (
                    ASSAY in str(err)
                    and feature_set_name in non_standard_assay_task_names
                ):
                    task_dict = tasks_dicts[non_standard_names[ASSAY]]
                else:
                    print("Skipping", feature_set_name, task_name)
                    continue

            input_size = input_sizes[feature_set_name]

            feature_set_name = feature_set_name.replace("_none", "").replace("hg38_", "")
            feature_set_name = re.sub(r"\_[\dmkb]+\_coord", "", feature_set_name)

            resolution = feature_set_name.split("_")[0]
            used_resolutions.add(resolution)

            trace_name = f"{input_size}|{feature_set_name}"
            trace_names.append(trace_name)

            # Accuracy (left panel) and macro F1-score (right panel)
            for col, metric in enumerate(["Accuracy", "F1_macro"], start=1):
                y_vals = [task_dict[split][metric] for split in task_dict]
                hovertext = [
                    f"{split}: {metrics_dict[metric]:.4f}"
                    for split, metrics_dict in task_dict.items()
                ]
                category_fig.add_trace(
                    go.Box(
                        y=y_vals,
                        name=trace_name,
                        boxmean=True,
                        boxpoints=boxpoints,
                        marker=dict(size=3, color="black"),
                        line=dict(width=1, color="black"),
                        fillcolor=resolution_colors[resolution],
                        hovertemplate="%{text}",
                        text=hovertext,
                        legendgroup=resolution,
                        showlegend=False,
                    ),
                    row=1,
                    col=col,
                )

        title = f"{metadata_categories[category_idx]} classification"
        title = title.replace(CELL_TYPE, "biospecimen")
        if name is not None:
            title += f" - {name}"
        category_fig.update_layout(
            width=width,
            height=height,
            title_text=title,
            **main_title_settings
        )

        # dummy scatters for resolution colors
        for resolution, color in resolution_colors.items():
            if resolution not in used_resolutions:
                continue
            category_fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode="markers",
                    name=resolution,
                    marker=dict(color=color, size=5),
                    showlegend=True,
                    legendgroup=resolution,
                )
            )

        category_fig.update_layout(legend=dict(itemsizing="constant"))

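        # y-axis: explicit range if given, otherwise a task-specific default.
        # Use the category name directly: task_name is unset if every feature set was skipped.
        task_label = metadata_categories[category_idx]
        if y_range:
            category_fig.update_yaxes(range=y_range)
        else:
            if ASSAY in task_label:
                category_fig.update_yaxes(range=[0.96, 1.001])
            if CELL_TYPE in task_label:
                category_fig.update_yaxes(range=[0.75, 1])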

        category_fig.update_layout(**main_title_settings)

        # Save figure
        if logdir:
            base_name = f"feature_set_metrics_{metadata_categories[category_idx]}"
            if name is not None:
                base_name = base_name + f"_{name}"
            category_fig.write_html(logdir / f"{base_name}.html")
            category_fig.write_image(logdir / f"{base_name}.svg")
            category_fig.write_image(logdir / f"{base_name}.png")

        category_fig.show()

B - Assay MLP performance

Graph 100kb resolution MLP metrics.

Code
metrics_fig1b = {name: all_metrics[name] for name in ["hg38_100kb_all_none"]}

metrics_fig1b_1 = {
    "hg38_100kb_all_none": {ASSAY: metrics_fig1b["hg38_100kb_all_none"][ASSAY]}
}
graph_feature_set_metrics(
    all_metrics=metrics_fig1b_1,  # type: ignore
    input_sizes=input_sizes,
    boxpoints="all",
    width=425,
    height=400,
    y_range=(0.98, 1.001),
)

metrics_fig1b_2 = {
    "hg38_100kb_all_none": {CELL_TYPE: metrics_fig1b["hg38_100kb_all_none"][CELL_TYPE]}
}
graph_feature_set_metrics(
    all_metrics=metrics_fig1b_2,  # type: ignore
    input_sizes=input_sizes,
    boxpoints="all",
    width=425,
    height=400,
    y_range=(0.93, 1.001),
)

Fig. 1B: Distribution of accuracy and F1-score for each of the ten training folds (dots) for the Assay and Biospecimen MLP classifiers.
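Each dot above corresponds to one fold of the stratified 10-fold cross-validation described in the article text. As a minimal illustration of what "stratified" means here, the sketch below assigns samples to folds so that every class is spread as evenly as possible across the ten folds. It is a hypothetical NumPy-only helper (`stratified_fold_ids`), not the EpiClass splitting code, and it ignores any grouping of related files; see Methods for how the folds were actually built. scikit-learn's `StratifiedKFold` is an established implementation of the same idea.

```python
from typing import Sequence

import numpy as np


def stratified_fold_ids(labels: Sequence[str], n_folds: int = 10, seed: int = 42) -> np.ndarray:
    """Assign each sample a fold id in [0, n_folds) with every class spread evenly."""
    rng = np.random.default_rng(seed)
    labels_arr = np.asarray(labels)
    fold_ids = np.empty(len(labels_arr), dtype=int)
    offset = 0
    for cls in np.unique(labels_arr):
        idx = np.flatnonzero(labels_arr == cls)
        rng.shuffle(idx)
        # Deal this class's samples round-robin across folds; the running offset
        # keeps small classes from all landing in fold 0, so fold sizes stay balanced
        fold_ids[idx] = (np.arange(len(idx)) + offset) % n_folds
        offset += len(idx)
    return fold_ids


labels = np.array(["T cell"] * 25 + ["liver"] * 7)
folds = stratified_fold_ids(labels)
print(np.bincount(folds, minlength=10))  # [4 4 3 3 3 3 3 3 3 3]
```

Within each class, per-fold counts differ by at most one. This matters for the imbalanced biospecimen classes: without stratification, a rare class could be entirely absent from some validation folds.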

C - MLP performance at varying resolution

Graph.

Code
metrics_fig1c = {name: all_metrics[name] for name in fig1_sets}

graph_feature_set_metrics(
    all_metrics=metrics_fig1c,  # type: ignore
    input_sizes=input_sizes,
    boxpoints="all",
    width=900,
    height=600,
)

Fig. 1C-alt: Distribution of accuracy and F1-score per training fold for different bin resolutions for the Assay and Biospecimen classifiers.


Define function parse_bin_size to extract a numerical bin size in base pairs.

Code
def parse_bin_size(feature_set_name: str) -> Optional[float]:
    """Extract the bin size, in base pairs, from a feature set name.

    Handles genomic-range prefixes such as '100kb' or '5mb'. Returns None for
    region-based feature sets ('regulatory', 'gene', 'cpg') and for unrecognized
    formats, so that they are left out of plots with a numeric bin-size axis.
    """
    name_parts = feature_set_name.replace("hg38_", "").split("_")
    if not name_parts:
        return None

    resolution_str = name_parts[0].lower()

    # Standard genomic ranges, e.g. "100kb" -> 100_000, "10mb" -> 10_000_000
    match_kb = re.match(r"(\d+)kb", resolution_str)
    if match_kb:
        return float(match_kb.group(1)) * 1_000
    match_mb = re.match(r"(\d+)mb", resolution_str)
    if match_mb:
        return float(match_mb.group(1)) * 1_000_000

    # Region-based feature sets have no single bin size
    if resolution_str in ["regulatory", "gene", "cpg"]:
        return None

    # Fallback: a bare number (e.g. a window size in bp), otherwise unparseable
    try:
        return float(resolution_str)
    except ValueError:
        return None

Define function graph_feature_set_scatter to graph performance metrics as a scatter plot instead of a box plot.

Code
def graph_feature_set_scatter(
    all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    input_sizes: Dict[str, int],
    logdir: Optional[Path] = None,
    metric_to_plot: str = "Accuracy",
    name: Optional[str] = None,
    metric_range: Optional[Tuple[float, float]] = None,
    assay_task_key: str = ASSAY,
    sex_task_key: str = SEX,
    cell_type_task_key: str = CELL_TYPE,
    verbose: bool = True,
) -> None:
    """
    Graph performance metrics as a scatter plot, one figure per task.

    X-axis: Number of features (log scale).
    Y-axis: Average performance metric (e.g., Accuracy, F1_macro) across folds.
            Vertical error bars indicate the min/max range across folds.
    Color: Bin size category (1 kb to 10 Mb).
    Marker: Cross for randomly subsampled feature sets, circle otherwise.

    Args:
        all_metrics: Nested dict {feature_set: {task_name: {split_name: metric_dict}}}.
        input_sizes: Dict {feature_set: num_features}.
        logdir: Directory to save figures. If None, display only.
        metric_to_plot: The metric key to use for the Y-axis ('Accuracy', 'F1_macro').
        name: Optional suffix for figure titles and filenames.
        metric_range: Optional tuple (min, max) to set the Y-axis range.
        assay_task_key: Key used for the assay prediction task.
        sex_task_key: Key used for the sex prediction task.
        cell_type_task_key: Key used for the cell type prediction task.
        verbose: If True, print the tasks being plotted.
    """
    if metric_to_plot not in ["Accuracy", "F1_macro"]:
        raise ValueError("metric_to_plot must be 'Accuracy' or 'F1_macro'")

    # Some feature sets store tasks under non-standard keys; map them back to standard names
    non_standard_names = {ASSAY: f"{ASSAY}_11c", SEX: f"{SEX}_w-mixed"}

    # --- Find reference and task names ----
    reference_hdf5_type = next(iter(all_metrics), None)
    if reference_hdf5_type is None or not all_metrics.get(reference_hdf5_type):
        print(
            "Warning: Could not determine tasks from all_metrics. Trying default tasks."
        )
        cleaned_metadata_categories = {assay_task_key, sex_task_key, cell_type_task_key}
    else:
        metadata_categories = list(all_metrics[reference_hdf5_type].keys())
        cleaned_metadata_categories = set()
        for cat in metadata_categories:
            original_name = cat
            for standard, non_standard in non_standard_names.items():
                if cat == non_standard:
                    original_name = standard
                    break
            cleaned_metadata_categories.add(original_name)

    # --- Define Bin size categories and Colors ---
    bin_category_names = ["1Kb", "10Kb", "100Kb", "1Mb", "10Mb"]
    bin_category_values = [1000, 10000, 100 * 1000, 1000 * 1000, 10000 * 1000]
    discrete_colors = px.colors.sequential.Viridis_r
    color_map = {
        name: discrete_colors[i * 2] for i, name in enumerate(bin_category_names)
    }

    if verbose:
        print(f"Plotting for tasks: {list(cleaned_metadata_categories)}")

    for category_name in cleaned_metadata_categories:
        plot_data_points = []

        for feature_set_name_orig in all_metrics.keys():
            try:
                num_features = input_sizes[feature_set_name_orig]
            except KeyError as e:
                raise ValueError(
                    f"Feature set '{feature_set_name_orig}' not found in input_sizes"
                ) from e

            # Parse Bin Size
            bin_size = parse_bin_size(feature_set_name_orig)
            if bin_size is None:
                print(
                    f"Skipping {feature_set_name_orig}, could not parse numeric bin size."
                )
                continue

            # Get metric values (average, min, max) across folds
            tasks_dicts = all_metrics[feature_set_name_orig]

            # --- Task Name Lookup ---
            # 1. Try the standard category name first
            # 2. If standard name not found, use non-standard name
            task_dict = None
            task_name = category_name
            if category_name in tasks_dicts:
                task_dict = tasks_dicts[category_name]
            else:
                non_standard_task_name = non_standard_names.get(category_name)
                if non_standard_task_name and non_standard_task_name in tasks_dicts:
                    task_name = non_standard_task_name
                    task_dict = tasks_dicts[non_standard_task_name]

                if task_dict is None:
                    raise ValueError(
                        f"Task '{category_name}' not found in feature set '{feature_set_name_orig}'"
                    )
            # --- End Task Name Lookup ---

            # Calculate average, min, max metric value across splits
            try:
                metric_values = []
                for split, split_data in task_dict.items():
                    if metric_to_plot in split_data:
                        metric_values.append(split_data[metric_to_plot])
                    else:
                        print(
                            f"Warning: Metric '{metric_to_plot}' not found in split '{split}' for {feature_set_name_orig} / {task_name}"
                        )

                if not metric_values:
                    print(
                        f"Warning: No metric values found for {feature_set_name_orig} / {task_name} / {metric_to_plot}"
                    )
                    continue

                avg_metric = np.mean(metric_values)
                min_metric = np.min(metric_values)
                max_metric = np.max(metric_values)

            except Exception as e:  # pylint: disable=broad-except
                raise ValueError(
                    f"Error calculating metrics for {feature_set_name_orig} / {task_name}: {e}"
                ) from e

            # Clean feature set name for hover text
            clean_name = feature_set_name_orig.replace("_none", "").replace("hg38_", "")
            clean_name = re.sub(r"\_[\dmkb]+\_coord", "", clean_name)

            # Store data for this point
            plot_data_points.append(
                {
                    "bin_size": bin_size,
                    "num_features": num_features,
                    "metric_value": avg_metric,
                    "min_metric": min_metric,  # For error bar low
                    "max_metric": max_metric,  # For error bar high
                    "name": clean_name,
                    "raw_name": feature_set_name_orig,
                }
            )

        if not plot_data_points:
            raise ValueError(
                f"No suitable data points found to plot for task: {category_name}"
            )

        # --- Determine Marker Symbols ---
        marker_symbols = []
        default_symbol = "circle"
        random_symbol = "cross"
        for p in plot_data_points:
            if "random" in p["raw_name"]:
                marker_symbols.append(random_symbol)
            else:
                marker_symbols.append(default_symbol)

        # --- Group Data by Category ---
        points_by_category = {name: [] for name in bin_category_names}
        for i, point_data in enumerate(plot_data_points):
            bin_size = point_data["bin_size"]
            assigned_category = None
            for cat_name, cat_value in zip(bin_category_names, bin_category_values):
                if bin_size == cat_value:
                    assigned_category = cat_name
                    break
            else:
                raise ValueError(f"Could not find category for bin size: {bin_size}")

            points_by_category[assigned_category].append(
                {
                    "x": point_data["num_features"],  # X is Num Features
                    "y": point_data["metric_value"],
                    "error_up": point_data["max_metric"] - point_data["metric_value"],
                    "error_down": point_data["metric_value"] - point_data["min_metric"],
                    "text": point_data["name"],
                    "customdata": [
                        point_data["min_metric"],
                        point_data["max_metric"],
                        point_data["bin_size"],
                    ],  # Keep bin size for hover
                    "symbol": marker_symbols[i],  # Assign symbol determined earlier
                }
            )

        # --- Create Figure and Add Traces PER CATEGORY ---
        fig = go.Figure()
        traces = []

        for cat_name in bin_category_names:  # Iterate in defined order for legend
            points_in_cat = points_by_category[cat_name]
            if not points_in_cat:
                continue

            category_color = color_map[cat_name]

            # Extract data for all points in this category
            x_vals = [p["x"] for p in points_in_cat]
            y_vals = [p["y"] for p in points_in_cat]
            error_up_vals = [p["error_up"] for p in points_in_cat]
            error_down_vals = [p["error_down"] for p in points_in_cat]
            text_vals = [p["text"] for p in points_in_cat]
            customdata_vals = [p["customdata"] for p in points_in_cat]
            symbols_vals = [p["symbol"] for p in points_in_cat]

            trace = go.Scatter(
                x=x_vals,
                y=y_vals,
                mode="markers",
                name=cat_name,
                showlegend=False,
                legendgroup=cat_name,  # Group legend entries
                marker=dict(
                    color=category_color,
                    size=15,
                    symbol=symbols_vals,
                    line=dict(width=1, color="DarkSlateGrey"),
                ),
                error_y=dict(
                    type="data",
                    symmetric=False,
                    array=error_up_vals,
                    arrayminus=error_down_vals,
                    visible=True,
                    thickness=1.5,
                    width=15,
                    color=category_color,
                ),
                text=text_vals,
                customdata=customdata_vals,
                hovertemplate=(
                    f"<b>%{{text}}</b><br><br>"
                    f"Num Features: %{{x:,.0f}}<br>"
                    f"{metric_to_plot}: %{{y:.4f}}<br>"
                    f"Bin Size: %{{customdata[2]:,.0f}} bp<br>"
                    f"{metric_to_plot} Range (10-fold): %{{customdata[0]:.4f}} - %{{customdata[1]:.4f}}"
                    "<extra></extra>"
                ),
            )
            traces.append(trace)

        fig.add_traces(traces)

        # --- Add Legend ---
        # Add a hidden scatter trace with square markers for legend
        for cat_name in bin_category_names:
            category_color = color_map[cat_name]
            legend_trace = go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=cat_name,
                marker=dict(
                    color=category_color,
                    size=15,
                    symbol="square",
                    line=dict(width=1, color="DarkSlateGrey"),
                ),
                legendgroup=cat_name,
                showlegend=True,
            )
            fig.add_trace(legend_trace)

        # --- Update layout ---
        title_name = category_name.replace(CELL_TYPE, "biospecimen")

        plot_title = f"{metric_to_plot} vs Number of Features - {title_name}"
        if name:
            plot_title += f" - {name}"
        xaxis_title = "Number of Features (log scale)"
        xaxis_type = "log"

        yaxis_title = metric_to_plot.replace("_", " ").title()
        yaxis_type = "linear"

        fig.update_layout(
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
            xaxis_type=xaxis_type,
            yaxis_type=yaxis_type,
            yaxis_range=metric_range,
            width=500,
            height=500,
            hovermode="closest",
            legend_title_text="Bin Size",
            title_text=plot_title,
            **main_title_settings
        )

        if category_name == CELL_TYPE:
            fig.update_yaxes(range=[0.75, 1.005])
        elif category_name == ASSAY:
            fig.update_yaxes(range=[0.96, 1.001])

        # --- Save or show figure ---
        if logdir:
            logdir.mkdir(parents=True, exist_ok=True)
            # Include "modified" or similar in filename to distinguish
            base_name = f"feature_scatter_MODIFIED_v2_{category_name}_{metric_to_plot}"
            if name:
                base_name += f"_{name}"
            html_path = logdir / f"{base_name}.html"
            svg_path = logdir / f"{base_name}.svg"
            png_path = logdir / f"{base_name}.png"

            print(f"Saving modified plot for {category_name} to {html_path}")
            fig.write_html(html_path)
            fig.write_image(svg_path)
            fig.write_image(png_path)

        fig.show()

Graph

Code
for metric in ["Accuracy", "F1_macro"]:
    graph_feature_set_scatter(
        all_metrics=metrics_fig1c,  # type: ignore
        input_sizes=input_sizes,
        metric_to_plot=metric,
        verbose=False,
    )

Fig. 1C: Distribution of accuracy per training fold for different bin resolutions for the Assay and Biospecimen classifiers. The circles represent the means and the whiskers the min and max values of the ten training folds.
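The whiskers in Fig. 1C are asymmetric: `graph_feature_set_scatter` passes the distance from the plotted mean to the best and worst folds as Plotly's `error_y.array` and `error_y.arrayminus`. Below is a minimal sketch of that summary. It assumes the ten per-fold scores are available as a plain list. The helper name `fold_whiskers` and the example scores are hypothetical.

```python
from statistics import mean


def fold_whiskers(fold_scores: list[float]) -> dict[str, float]:
    """Summarize per-fold scores as a mean with asymmetric min/max whiskers.

    `error_up` / `error_down` correspond to Plotly's `error_y.array` /
    `error_y.arrayminus` when `symmetric=False`.
    """
    if not fold_scores:
        raise ValueError("Need at least one fold score.")
    centre = mean(fold_scores)
    return {
        "mean": centre,
        "error_up": max(fold_scores) - centre,  # distance up to the best fold
        "error_down": centre - min(fold_scores),  # distance down to the worst fold
    }


# Hypothetical accuracies for a 10-fold run (not real EpiClass results).
scores = [0.95, 0.96, 0.97, 0.96, 0.95, 0.97, 0.96, 0.96, 0.95, 0.97]
print(fold_whiskers(scores))
```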

D - Confusion matrix

Define create_confusion_matrix to create and show a confusion matrix.

Code
def create_confusion_matrix(
    df: pd.DataFrame,
    name: str = "confusion_matrix",
    logdir: Path | None = None,
    min_pred_score: float = 0,
    majority: bool = False,
    verbose: bool = False,
) -> None:
    """Create a confusion matrix for the given DataFrame, save it to the logdir and display it.

    Args:
        df (pd.DataFrame): The DataFrame containing the results.
        name (str): The base name for the saved figures.
        logdir (Path | None): The directory for saving the figures. Defaults to the system temporary directory.
        min_pred_score (float): Only predictions with a max score strictly above this value are kept.
        majority (bool): Whether to use majority vote (uuid-wise) for the predicted class.
        verbose (bool): Whether to print the names of the saved files.
    """
    # Compute confusion matrix
    classes = sorted(df["True class"].unique())
    if "Max pred" not in df.columns:
        df = df.copy()  # avoid mutating the caller's DataFrame
        df["Max pred"] = df[classes].max(axis=1)  # type: ignore
    filtered_df = df[df["Max pred"] > min_pred_score]

    if majority:
        # Majority vote for predicted class
        groupby_uuid = filtered_df.groupby(["uuid", "True class", "Predicted class"])[
            "Max pred"
        ].aggregate(["size", "mean"])

        if groupby_uuid["size"].max() > 3:
            raise ValueError("More than three predictions for the same uuid.")

        groupby_uuid = groupby_uuid.reset_index().sort_values(
            ["uuid", "True class", "size"], ascending=[True, True, False]
        )
        groupby_uuid = groupby_uuid.drop_duplicates(
            subset=["uuid", "True class"], keep="first"
        )
        filtered_df = groupby_uuid

    confusion_mat = sk_cm(
        filtered_df["True class"], filtered_df["Predicted class"], labels=classes
    )

    mat_writer = ConfusionMatrixWriter(labels=classes, confusion_matrix=confusion_mat)

    if logdir is None:
        logdir = Path(tempfile.gettempdir())

    files = mat_writer.to_all_formats(logdir, name=f"{name}_n{len(filtered_df)}")

    if verbose:
        print(f"Saved confusion matrix to {logdir}:")
        for file in files:
            print(Path(file).name)

    for file in files:
        file = Path(file)  # to_all_formats may return str paths
        if file.suffix == ".png":
            scale = 0.6
            display(Image(filename=str(file), width=1250 * scale, height=1000 * scale))

Prepare prediction data for confusion matrix.

Code
assay_split_dfs = split_results_handler.gather_split_results_across_methods(
    results_dir=data_dir_100kb, label_category=ASSAY, only_NN=True
)
concat_assay_df = split_results_handler.concatenate_split_results(assay_split_dfs)["NN"]

df_with_meta = metadata_handler.join_metadata(concat_assay_df, metadata_v2)  # type: ignore
if "Predicted class" not in df_with_meta.columns:
    raise ValueError("`Predicted class` not in DataFrame")

classifier_name = "MLP"
min_pred_score = 0
majority = False

name = f"{classifier_name}_pred>{min_pred_score}"

logdir = base_fig_dir / "fig1_EpiAtlas_assay" / "fig1_supp_D-assay_c11_confusion_matrices"
if majority:
    logdir = logdir / "per_uuid"
else:
    logdir = logdir / "per_file"
logdir.mkdir(parents=True, exist_ok=True)

Graph.

Code
create_confusion_matrix(
    df=df_with_meta,
    name=name,
    logdir=logdir,
    min_pred_score=min_pred_score,
    majority=majority,
)

Fig. 1D: Confusion matrix aggregating all cross-validation folds (therefore including every file), without applying a prediction score threshold. During initial training, RNA-seq and WGBS data were each split into two protocol-specific classes (11 classes in total). These were merged into nine assays thereafter.
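In later analyses (for example Supplementary Fig. 1A, where `merge_similar_assays` is applied), the protocol-specific RNA-seq and WGBS classes are folded into nine assays before scoring. Below is a minimal sketch of that folding followed by scikit-learn's `confusion_matrix`. It assumes the merge is a simple label-to-label mapping. The class names and the `MERGE` dictionary are hypothetical stand-ins for the notebook's `ASSAY_MERGE_DICT`, which is not shown here.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical protocol-level -> assay-level mapping (not the real ASSAY_MERGE_DICT).
MERGE = {"mrna_seq": "rna_seq", "wgbs-pbat": "wgbs", "wgbs-standard": "wgbs"}


def merged_confusion(y_true, y_pred, merge=MERGE):
    """Collapse protocol-level classes, then build a confusion matrix.

    Rows are true classes and columns are predicted classes, both in sorted order.
    """
    true_m = [merge.get(label, label) for label in y_true]
    pred_m = [merge.get(label, label) for label in y_pred]
    labels = sorted(set(true_m) | set(pred_m))
    return labels, confusion_matrix(true_m, pred_m, labels=labels)


labels, cm = merged_confusion(
    ["rna_seq", "mrna_seq", "wgbs-pbat", "h3k27ac"],
    ["mrna_seq", "rna_seq", "wgbs-standard", "h3k4me3"],
)
print(labels)
print(cm)
```

The confusion between the two RNA-seq protocols, and between the two WGBS protocols, disappears once the labels are merged. The genuine H3K27ac/H3K4me3 error remains off the diagonal.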

E - Mislabeled target assays

Fig. 1E: Genome browser representation over a representative region. Shown in black are the H3K4me3 and H3K27ac datasets of IHECRE00001897 that were swapped in metadata freeze v1.0, alongside typical correctly labeled datasets.
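A swap like the one in Fig. 1E shows up as files that are confidently predicted as an assay other than the one in their metadata. Below is a minimal sketch of how such candidates could be surfaced from pooled predictions. It assumes the `True class`, `Predicted class` and `Max pred` columns used in this notebook. The function name and the 0.9 cutoff are hypothetical, and this is not presented as the procedure used to find IHECRE00001897.

```python
import pandas as pd


def flag_disagreements(df: pd.DataFrame, min_score: float = 0.9) -> pd.DataFrame:
    """Return files whose confident prediction contradicts the metadata label.

    Expects "True class", "Predicted class" and "Max pred" columns.
    The default cutoff is an arbitrary example value.
    """
    mismatch = df["True class"] != df["Predicted class"]
    confident = df["Max pred"] >= min_score
    return df.loc[mismatch & confident].sort_values("Max pred", ascending=False)


# Hypothetical predictions: file_a and file_b look like a label swap.
preds = pd.DataFrame(
    {
        "True class": ["h3k4me3", "h3k27ac", "h3k4me1"],
        "Predicted class": ["h3k27ac", "h3k4me3", "h3k4me1"],
        "Max pred": [0.98, 0.97, 0.99],
    },
    index=["file_a", "file_b", "file_c"],
)
print(flag_disagreements(preds))
```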

Supplementary Figure 1

More detailed performance of the EpiClass Assay and Biospecimen classifiers.

A,B - All classifiers metrics on EpiAtlas

Fig. 1A,B data points are included in these two graphs (MLP data points).


Define graphing function plot_multiple_models_split_metrics.

Code
def plot_multiple_models_split_metrics(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    label_category: str,
    logdir: Path | None = None,
    filename: str = "fig1_all_classifiers_metrics",
) -> None:
    """Render box plots of the metrics per classifier and split, one subplot per metric.

    Args:
        split_metrics: A dictionary containing metric scores for each split and classifier.
            Format: {split_name: {classifier_name: {metric_name: value}}}
        label_category: The label category for the classification task.
        logdir: The directory to save the figure to. If None, the figure is only displayed.
        filename: The base name of the saved figure files.

    Returns:
        None: Displays the figure and saves it to the logdir if provided.
    """
    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    # Fixed display order; "NN" is relabeled as "MLP" in the plot
    classifier_names = ["NN", "LR", "LGBM", "LinearSVC", "RF"]

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metrics),
        subplot_titles=metrics,
        horizontal_spacing=0.075,
    )

    for i, metric in enumerate(metrics):
        for classifier in classifier_names:
            values = [split_metrics[split][classifier][metric] for split in split_metrics]
            display_name = "MLP" if classifier == "NN" else classifier
            fig.add_trace(
                go.Box(
                    y=values,
                    name=display_name,
                    line=dict(color="black", width=1.5),
                    marker=dict(size=3, color="black"),
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=False,
                    width=0.5,
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text=f"{label_category} classification",
        boxmode="group",
        **main_title_settings,
    )

    # Adjust y-axis
    if label_category == ASSAY:
        range_acc = [0.95, 1.001]
        range_AUC = [0.992, 1.0001]
    elif label_category == CELL_TYPE:
        range_acc = [0.81, 1]
        range_AUC = [0.96, 1]
    else:
        range_acc = [0.6, 1.001]
        range_AUC = [0.9, 1.0001]

    fig.update_layout(
        yaxis=dict(range=range_acc),
        yaxis2=dict(range=range_acc),
        yaxis3=dict(range=range_AUC),
        yaxis4=dict(range=range_AUC),
        height=450,
    )

    fig.update_layout(margin=dict(l=20, r=20))

    # Save figure
    if logdir:
        fig.write_image(logdir / f"{filename}.svg")
        fig.write_image(logdir / f"{filename}.png")
        fig.write_html(logdir / f"{filename}.html")

    fig.show()

Graph.

Code
merge_assays = True

for label_category in [ASSAY, CELL_TYPE]:
    all_split_dfs = split_results_handler.gather_split_results_across_methods(
        results_dir=data_dir_100kb,
        label_category=label_category,
        only_NN=False,
    )

    if merge_assays and label_category == ASSAY:
        for split_name, split_dfs in all_split_dfs.items():
            for classifier_type, df in split_dfs.items():
                split_dfs[classifier_type] = merge_similar_assays(df)

    split_metrics = split_results_handler.compute_split_metrics(all_split_dfs)

    plot_multiple_models_split_metrics(
        split_metrics,
        label_category=label_category,
    )

Supplementary Figure 1A,B: Distribution of performance scores (accuracy, F1 as well as micro and macro AUROC) per training fold (dots) for each machine learning approach used for training on the Assay (A) and Biospecimen (B) metadata. Micro-averaging aggregates contributions from all classes (global true positive rate and false positive rate); macro-averaging averages the true positive rate from each class. Dashed lines represent means, solid lines the medians, boxes the quartiles, and whiskers the farthest points within 1.5× the interquartile range.
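The difference between the two averages is easiest to see on a toy example. Below is a minimal sketch using scikit-learn's `roc_auc_score` on a binarized (one-vs-rest) label matrix. The three classes and their scores are invented for illustration. Note that scikit-learn's `average="macro"` takes the mean of the per-class AUCs, whereas `plot_roc_curves` below averages interpolated TPR curves. The two are related but not identical computations.

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

classes = ["a", "b", "c"]
y_true = ["a", "b", "c"]
# Rows are files, columns are per-class prediction scores (invented values).
y_score = np.array(
    [
        [0.6, 0.3, 0.1],
        [0.2, 0.7, 0.1],
        [0.5, 0.1, 0.4],
    ]
)

# One-vs-rest indicator matrix: column i is 1 where the file belongs to class i.
y_bin = label_binarize(y_true, classes=classes)

# Micro: pool every (file, class) cell into a single binary problem.
micro_auc = roc_auc_score(y_bin, y_score, average="micro")
# Macro: one AUC per class column, then the unweighted mean.
macro_auc = roc_auc_score(y_bin, y_score, average="macro")
print(f"micro={micro_auc:.4f}, macro={macro_auc:.4f}")
```

Each class column is perfectly ranked on its own, so the macro average is 1.0. In the pooled (micro) ranking, however, the third file's score for its true class (0.4) falls below its score for class "a" (0.5). One of the 18 positive/negative pairs is therefore misordered, giving 17/18.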


Going forward, all results are for MLP classifiers.

C - ROC curves

Define graphing function plot_roc_curves. Computes macro-average ROC curves manually.

Code
def plot_roc_curves(
    results_df: pd.DataFrame,
    label_category: str,
    logdir: Path | None = None,
    name: str = "roc_curve",
    title: str | None = None,
    colors_dict: Dict | None = None,  # Optional specific colors
    verbose: bool = False,
) -> None:
    """
    Generates and plots ROC curves for multi-class classification results using Plotly.

    Calculates and plots individual class ROC curves, micro-average, and macro-average ROC curves.

    Args:
        results_df (pd.DataFrame): DataFrame with true labels and prediction probabilities for each class.
                                   Must contain a 'True class' column with the ground-truth labels
                                   and probability columns named after each class.
        label_category (str): The classification task (e.g., ASSAY, CELL_TYPE), used in the plot title and file names.
        logdir (Path | None): Directory to save the figure. If None, only displays the figure.
        name (str): Base name for saved files (e.g., "supp_fig1e").
        title (str | None): Title suffix for the plot. If None, only the default title based on label_category is used.
        colors_dict (Dict | None): Optional dictionary mapping class names to colors. If None or a class
                                   is missing, default Plotly colors are used.
        verbose (bool): Whether to print the classes used.
    """
    df = results_df.copy()
    true_label_col = "True class"  # Assuming 'True class' holds the ground truth labels

    if true_label_col not in df.columns:
        raise ValueError(f"True label column '{true_label_col}' not found in DataFrame.")

    classes = sorted(df[true_label_col].unique())
    if verbose:
        print(f"Using classes: {classes}")

    n_classes = len(classes)
    if n_classes < 2:
        print(
            f"Warning: Only {n_classes} class found after processing. Cannot generate ROC curve."
        )
        return

    # Check if probability columns exist for all determined classes
    missing_cols = [c for c in classes if c not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing probability columns for classes: {missing_cols}")

    # Binarize the true labels against the final set of classes
    try:
        y_true = label_binarize(df[true_label_col], classes=classes)
    except ValueError as e:
        raise ValueError(
            f"Error binarizing labels for classes {classes}. Check if all labels in '{true_label_col}' are included in 'classes'."
        ) from e

    if n_classes == 2 and y_true.shape[1] == 1:
        # Adjust for binary case where label_binarize might return one column
        y_true = np.hstack((1 - y_true, y_true))  # type: ignore
    elif y_true.shape[1] != n_classes:
        raise ValueError(
            f"Binarized labels shape {y_true.shape} does not match number of classes {n_classes}"
        )

    # Get the predicted probabilities for each class
    # Ensure columns are in the same order as 'classes'
    y_score = df[classes].values

    # --- Compute ROC curve and ROC area for each class ---
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i, class_name in enumerate(classes):
        try:
            fpr[class_name], tpr[class_name], _ = roc_curve(
                y_true=y_true[:, i], y_score=y_score[:, i]  # type: ignore
            )
            roc_auc[class_name] = auc(fpr[class_name], tpr[class_name])
        except ValueError as e:
            raise ValueError(f"Could not compute ROC for class {class_name}.") from e

    # --- Compute micro-average ROC curve and ROC area ---
    try:
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_score.ravel())  # type: ignore
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    except ValueError as e:
        raise ValueError("Could not compute micro-average ROC.") from e

    # --- Compute macro-average ROC curve and ROC area ---
    try:
        # Aggregate all false positive rates
        all_fpr = np.unique(
            np.concatenate(
                [fpr[class_name] for class_name in classes if class_name in fpr]
            )
        )
        # Interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        valid_classes_count = 0
        for class_name in classes:
            if class_name in fpr and class_name in tpr:
                mean_tpr += np.interp(all_fpr, fpr[class_name], tpr[class_name])
                valid_classes_count += 1

        # Average it and compute AUC
        if valid_classes_count > 0:
            mean_tpr /= valid_classes_count
            fpr["macro"] = all_fpr
            tpr["macro"] = mean_tpr
            roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
        else:
            raise ValueError("No valid classes found for macro averaging.")

    except ValueError as e:
        raise ValueError("Could not compute macro-average ROC.") from e

    # --- Plot all ROC curves ---
    fig = go.Figure()

    # Plot diagonal line for reference
    fig.add_shape(
        type="line", line=dict(dash="dash", color="grey", width=1), x0=0, x1=1, y0=0, y1=1
    )

    # Define colors for plotting
    color_cycle = px.colors.qualitative.Plotly  # Default cycle
    plot_colors = {}
    for i, cls_name in enumerate(classes):
        if colors_dict and cls_name in colors_dict:
            plot_colors[cls_name] = colors_dict[cls_name]
        else:
            plot_colors[cls_name] = color_cycle[i % len(color_cycle)]

    # Plot Micro-average ROC curve first (often plotted thicker/dashed)
    fig.add_trace(
        go.Scatter(
            x=fpr["micro"],
            y=tpr["micro"],
            mode="lines",
            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.5f})',
            line=dict(color="deeppink", width=3, dash="dash"),
            hoverinfo="skip",  # Less important for hover usually
        )
    )

    # Plot Macro-average ROC curve
    fig.add_trace(
        go.Scatter(
            x=fpr["macro"],
            y=tpr["macro"],
            mode="lines",
            name=f'Macro-average ROC (AUC = {roc_auc["macro"]:.5f})',
            line=dict(color="navy", width=3, dash="dash"),
            hoverinfo="skip",
        )
    )

    # Plot individual class ROC curves
    for class_name in classes:
        if class_name not in fpr or class_name not in tpr or class_name not in roc_auc:
            continue  # Skip if calculation failed
        fig.add_trace(
            go.Scatter(
                x=fpr[class_name],
                y=tpr[class_name],
                mode="lines",
                name=f"{class_name} (AUC = {roc_auc[class_name]:.5f})",
                line=dict(width=1.5, color=plot_colors.get(class_name)),
                hovertemplate=f"<b>{class_name}</b><br>FPR=%{{x:.5f}}<br>TPR=%{{y:.5f}}<extra></extra>",  # Show class name and values on hover
            )
        )

    # --- Update layout ---
    base_title = f"ROC Curves<br>{label_category}"
    plot_title = f"{base_title} - {title}" if title else base_title

    title_settings=dict(
        yanchor="top",
        yref="paper",
        y=0.97,
        xanchor="center",
        xref="paper",
        x=0.5,
    )

    fig.update_layout(
        title=title_settings,
        title_text=plot_title,
        xaxis_title="False Positive Rate (1 - Specificity)",
        yaxis_title="True Positive Rate (Sensitivity)",
        xaxis=dict(range=[0.0, 1.0], constrain="domain"),  # Ensure axes range 0-1
        yaxis=dict(
            range=[0.0, 1.01], scaleanchor="x", scaleratio=1, constrain="domain"
        ),  # Make it square-ish, slight top margin
        width=800,
        height=650,
        hovermode="closest",
        legend=dict(
            traceorder="reversed",  # Averages were added first, so they are listed last
            title="Classes & Averages",
            font=dict(size=9),
            itemsizing="constant",
            y=0.8,
            yref="paper",
        ),
        margin=dict(l=60, r=30, t=0, b=0),
    )

    # --- Save figure if logdir is provided ---
    if logdir:
        logdir.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
        filename_base = f"{name}_{label_category}_roc"
        filepath_base = logdir / filename_base

        fig.write_html(f"{filepath_base}.html")
        fig.write_image(f"{filepath_base}.svg", width=800, height=750)
        fig.write_image(f"{filepath_base}.png", width=800, height=750, scale=2)

        print(f"Saved ROC curve plots for {label_category} to {logdir}")
        print(f" -> {filename_base}.html / .svg / .png")

    fig.show()
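The macro-average step in `plot_roc_curves` can be isolated as follows. Each one-vs-rest curve is linearly interpolated onto the union of all per-class false positive rates, and the interpolated TPRs are then averaged. This is a condensed sketch of that idea. The function name `macro_average_roc` and the toy data are hypothetical.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize


def macro_average_roc(y_bin: np.ndarray, y_score: np.ndarray):
    """Macro-average ROC curve built from one-vs-rest curves.

    Returns the shared FPR grid, the mean interpolated TPR and its AUC.
    """
    curves = [roc_curve(y_bin[:, i], y_score[:, i])[:2] for i in range(y_bin.shape[1])]
    # Shared grid: every FPR value reached by any class curve
    fpr_grid = np.unique(np.concatenate([fpr for fpr, _ in curves]))
    # Piecewise-linear interpolation of each class TPR onto the grid, then the mean
    mean_tpr = np.mean([np.interp(fpr_grid, fpr, tpr) for fpr, tpr in curves], axis=0)
    return fpr_grid, mean_tpr, auc(fpr_grid, mean_tpr)


# Invented three-class example: each class ranks its own positive file first.
y_bin = label_binarize(["a", "b", "c"], classes=["a", "b", "c"])
y_score = np.array([[0.6, 0.3, 0.1], [0.2, 0.7, 0.1], [0.5, 0.1, 0.4]])
fpr_grid, mean_tpr, macro_auc = macro_average_roc(y_bin, y_score)
print(f"macro-average AUC = {macro_auc:.4f}")
```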

Prepare assay data for plotting.

Code
data_dir = (
    mixed_data_dir
    / "hg38_100kb_all_none"
    / f"{ASSAY}_1l_3000n"
    / "11c"
    / "10fold-oversampling"
)
if not data_dir.exists():
    raise FileNotFoundError(f"Directory {data_dir} does not exist.")

dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1)  # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)

df = merge_similar_assays(concat_df_w_meta.copy())

Graph assay results.

Code
plot_roc_curves(
    results_df=df.copy(),
    label_category=ASSAY,
    title="Aggregated 10fold",  # Title suffix
    colors_dict=assay_colors,
    verbose=False,
)

Prepare biospecimen data for plotting.

Code
data_dir = (
    mixed_data_dir
    / "hg38_100kb_all_none"
    / f"{CELL_TYPE}_1l_3000n"
    / "10fold-oversampling"
)
if not data_dir.exists():
    raise FileNotFoundError(f"Directory {data_dir} does not exist.")

dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1)  # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)

Graph biospecimen results.

Code
plot_roc_curves(
    results_df=concat_df_w_meta,
    label_category=CELL_TYPE,
    title="Aggregated 10fold",  # Title suffix
    colors_dict=cell_type_colors,
    verbose=False,
)

Supplementary Figure 1C: ROC curves from aggregated cross-validation results for the Assay and Biospecimen classifiers. Curves for each class are computed in a one-vs-rest scheme.

D - Alternative signal pre-processing

Define graphing function create_blklst_graphs.

Code
def create_blklst_graphs(
    feature_set_metrics_dict: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    logdir: Path | None = None,
) -> List[go.Figure]:
    """Create boxplots of accuracy and macro F1-score for the blacklist-related feature sets.

    Args:
        feature_set_metrics_dict (Dict[str, Dict[str, Dict[str, Dict[str, float]]]]): The dictionary containing all metrics for all blacklist-related feature sets.
            Format: {feature_set: {task_name: {split_name: metric_dict}}}
        logdir (Path, Optional): The directory to save the figures to. If None, the figures are not saved.

    Returns:
        List[go.Figure]: One figure per task (classification category).
    """
    figs = []

    # Assume names exist in all feature sets
    task_names = list(feature_set_metrics_dict.values())[0].keys()

    traces_names_dict = {
        "hg38_100kb_all_none": "observed",
        "hg38_100kb_all_none_0blklst": "0blklst",
        "hg38_100kb_all_none_0blklst_winsorized": "0blklst_winsorized",
    }

    for task_name in task_names:
        category_fig = make_subplots(
            rows=1,
            cols=2,
            shared_yaxes=False,
            subplot_titles=["Accuracy", "F1-score (macro)"],
            horizontal_spacing=0.1,
        )
        for feature_set_name, tasks_dicts in feature_set_metrics_dict.items():
            task_dict = tasks_dicts[task_name]
            trace_name = traces_names_dict[feature_set_name]

            # One box per feature set, for accuracy (col 1) and macro F1-score (col 2)
            for col, metric in enumerate(["Accuracy", "F1_macro"], start=1):
                y_vals = [task_dict[split][metric] for split in task_dict]  # type: ignore
                hovertext = [
                    f"{split}: {metrics_dict[metric]:.4f}"  # type: ignore
                    for split, metrics_dict in task_dict.items()
                ]
                category_fig.add_trace(
                    go.Box(
                        y=y_vals,
                        name=trace_name,
                        boxmean=True,
                        boxpoints="all",
                        showlegend=False,
                        marker=dict(size=3, color="black"),
                        line=dict(width=1, color="black"),
                        hovertemplate="%{text}",
                        text=hovertext,
                    ),
                    row=1,
                    col=col,
                )

        category_fig.update_xaxes(
            categoryorder="array",
            categoryarray=list(traces_names_dict.values()),
        )
        category_fig.update_yaxes(range=[0.9, 1.001])

        category_fig.update_layout(
            title_text=task_name,
            height=600,
            width=500,
            **main_title_settings
        )

        # Save figure
        if logdir:
            task_name = task_name.replace("_1l_3000n-10fold", "")
            base_name = f"metrics_{task_name}"

            category_fig.write_html(logdir / f"{base_name}.html")
            category_fig.write_image(logdir / f"{base_name}.svg")
            category_fig.write_image(logdir / f"{base_name}.png")

        figs.append(category_fig)

    return figs

Prepare paths.

Code
include_sets = [
    "hg38_100kb_all_none",
    "hg38_100kb_all_none_0blklst",
    "hg38_100kb_all_none_0blklst_winsorized",
]

results_folder_blklst = base_data_dir / "training_results" / "2023-01-epiatlas-freeze"
if not results_folder_blklst.exists():
    raise FileNotFoundError(f"Folder '{results_folder_blklst}' not found")

Compute metrics.

Code
# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: metrics_dict}}}
all_metrics_blklst: Dict[
    str, Dict[str, Dict[str, Dict[str, float]]]
] = split_results_handler.obtain_all_feature_set_data(
    return_type="metrics",
    parent_folder=results_folder_blklst,
    merge_assays=True,
    include_categories=[ASSAY, CELL_TYPE],
    include_sets=include_sets,
    oversampled_only=False,
    verbose=False,
)  # type: ignore

Graph.

Code
figs = create_blklst_graphs(all_metrics_blklst)

figs[0].show()
figs[1].show()

Supplementary Figure 1D: Distribution of accuracy and F1-score per training fold (dots) for the Assay and Biospecimen classifiers under three signal conditions: the unmodified signal (observed), the signal with blacklisted regions removed (0blklst), and the signal with blacklisted regions removed and then winsorized at 0.1% (0blklst_winsorized). Dashed lines represent means, solid lines the medians, boxes the quartiles, and whiskers the farthest points within 1.5× the interquartile range.
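The labels "0blklst" and "0blklst_winsorized" suggest that blacklisted bins are set to zero and that the result is then winsorized. Below is a minimal sketch of those two steps on a single binned track. It rests on two assumptions. First, a boolean mask marks the bins overlapping blacklisted regions. Second, winsorization means capping values above the 99.9th percentile of that track. Whether the EpiATLAS preprocessing clipped one or both tails, and whether bins were zeroed or dropped, is not stated in this section (see Methods). The function name is hypothetical.

```python
import numpy as np


def zero_blacklist_and_winsorize(signal, blacklisted, upper_quantile=0.999):
    """Set blacklisted bins to 0, then cap values above `upper_quantile` of the track.

    Assumption: one-sided (upper-tail) winsorization, quantile computed after zeroing.
    """
    values = np.asarray(signal, dtype=float)
    mask = np.asarray(blacklisted, dtype=bool)
    cleaned = np.where(mask, 0.0, values)  # blacklisted bins zeroed
    cap = np.quantile(cleaned, upper_quantile)  # 99.9th percentile by default
    return np.minimum(cleaned, cap)


# Toy track of 1,000 bins whose first 10 bins are flagged as blacklisted.
signal = np.arange(1, 1001, dtype=float)
blacklisted = np.zeros(1000, dtype=bool)
blacklisted[:10] = True
processed = zero_blacklist_and_winsorize(signal, blacklisted)
print(processed[:12], processed.max())
```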

E-G - Distribution of average prediction scores per assay

  • E: Assay training 10-fold cross-validation
  • F: Biospecimen 10-fold cross-validation
  • G: Assay complete training (mixed tracks), predictions on imputed data (all p-value tracks)

Define graphing function plot_prediction_scores_distribution.

Supplementary Figure 1E-G: Distribution of the average prediction score per file (dots), computed either for the majority-vote class over up to three track-type files (E, F) or for individual files (G). Scores come from the MLP approach for the Assay (E, G) and Biospecimen (F) classifiers. They are based on aggregated cross-validation results from observed data (E, F), or on results from the classifier trained on all observed data and applied to EpiATLAS imputed data (G). Dashed lines represent means, solid lines the medians, boxes the quartiles, and whiskers the farthest points within 1.5× the interquartile range, with a violin representation on top.
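For panels E and F, each uuid contributes up to three track-type files. The score shown is the mean prediction score of the files that voted for the majority class, similar to the groupby in `create_confusion_matrix`. Below is a minimal pandas sketch of that aggregation. It assumes the `uuid`, `Predicted class` and `Max pred` columns used elsewhere in this notebook. The function name is hypothetical. The tie-breaking rule (more votes, then higher mean score) is an assumption of this sketch and not necessarily what `plot_prediction_scores_distribution` does.

```python
import pandas as pd


def majority_vote_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Per uuid, keep the most frequent predicted class and its mean "Max pred".

    Ties on vote count are broken by the higher mean score (an assumption).
    """
    votes = (
        df.groupby(["uuid", "Predicted class"])["Max pred"]
        .agg(["size", "mean"])
        .reset_index()
        .sort_values(["uuid", "size", "mean"], ascending=[True, False, False])
    )
    return votes.drop_duplicates(subset="uuid", keep="first").set_index("uuid")


# Hypothetical predictions: u1 has three track-type files, u2 only one.
preds = pd.DataFrame(
    {
        "uuid": ["u1", "u1", "u1", "u2"],
        "Predicted class": ["h3k4me3", "h3k4me3", "h3k27ac", "wgbs"],
        "Max pred": [0.9, 0.8, 0.6, 0.99],
    }
)
print(majority_vote_scores(preds))
```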

E - Assay

Gather prediction scores.

Code
data_dir = (
    mixed_data_dir
    / "hg38_100kb_all_none"
    / f"{ASSAY}_1l_3000n"
    / "11c"
    / "10fold-oversampling"
)
if not data_dir.exists():
    raise FileNotFoundError(f"Directory {data_dir} does not exist.")

dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1)  # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)

Graph.

Code
plot_prediction_scores_distribution(
    results_df=concat_df_w_meta,
    group_by_column=ASSAY,
    merge_assay_pairs=True,
    min_y=0.7,
    title="11 classes assay training<br>Prediction scores for 10-fold cross-validation",
)

F - Biospecimen

Gather prediction scores.

Code
data_dir = data_dir_100kb / f"{CELL_TYPE}_1l_3000n" / "10fold-oversampling"
if not data_dir.exists():
    raise FileNotFoundError(f"Directory {data_dir} does not exist.")

dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1)  # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)
concat_df_w_meta.replace({ASSAY: ASSAY_MERGE_DICT}, inplace=True)

Graph.

Code
plot_prediction_scores_distribution(
    results_df=concat_df_w_meta,
    group_by_column=ASSAY,
    min_y=0,
    title="Biospecimen training<br>Prediction scores for 10-fold cross-validation",
)
Skipping assay merging: Wrong results dataframe, rna or wgbs columns missing.

G - Assay imputed

Gather imputed signal metadata.

Code
metadata_path = (
    paper_dir
    / "data"
    / "metadata"
    / "epiatlas"
    / "imputed"
    / "hg38_epiatlas_imputed_pval_chip_2024-02.json"
)
metadata_imputed: pd.DataFrame = metadata_handler.load_any_metadata(metadata_path, as_dataframe=True)  # type: ignore

Gather prediction scores.

Code
data_dir = (
    gen_data_dir
    / "hg38_100kb_all_none"
    / f"{ASSAY}_1l_3000n"
    / "11c"
    / "complete_no_valid_oversample"
    / "predictions"
    / "epiatlas_imputed"
    / "ChIP"
)
if not data_dir.exists():
    raise FileNotFoundError(f"Directory {data_dir} does not exist.")

df_pred = pd.read_csv(
    data_dir / "complete_no_valid_oversample_prediction.csv",
    index_col=0,
)

Prepare dataframe for graphing.

Code
assay_classes = list(metadata_v2_df[ASSAY].unique())
df_pred = split_results_handler.add_max_pred(df_pred, expected_classes=assay_classes)

augmented_df = pd.merge(df_pred, metadata_imputed, left_index=True, right_on="md5sum")
augmented_df["True class"] = augmented_df[ASSAY]
print("Number of files per assay:")
print(augmented_df["True class"].value_counts(dropna=False).to_string())
Number of files per assay:
h3k36me3    1703
h3k27me3    1703
h3k9me3     1700
h3k4me1     1688
h3k4me3     1688
h3k27ac     1088
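The call `pd.merge(df_pred, metadata_imputed, left_index=True, right_on="md5sum")` attaches metadata to each prediction. It matches the index of the prediction table (the file md5sum) against the `md5sum` column of the metadata. With the default `how="inner"`, files missing from either side are dropped, so the counts above only include files that have both a prediction and metadata. The standard-library sketch below mimics that join and the per-assay count. The record layout, the `assay` key and the md5 values are illustrative assumptions.

```python
from collections import Counter
from typing import Any, Dict, List


def inner_join_on_md5(
    predictions: Dict[str, Dict[str, float]], metadata: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Join predictions (keyed by md5sum) with metadata records on their "md5sum" field.

    As with an inner join, metadata records without a matching prediction
    (and predictions without metadata) do not appear in the result.
    """
    joined = []
    for record in metadata:
        scores = predictions.get(record["md5sum"])
        if scores is None:
            continue
        joined.append({**scores, **record})
    return joined


predictions = {
    "md5_a": {"h3k4me3": 0.97, "h3k27ac": 0.03},
    "md5_b": {"h3k4me3": 0.10, "h3k27ac": 0.90},
}
metadata = [
    {"md5sum": "md5_a", "assay": "h3k4me3"},
    {"md5sum": "md5_b", "assay": "h3k27ac"},
    {"md5sum": "md5_c", "assay": "h3k27ac"},  # no prediction: dropped by the inner join
]
joined = inner_join_on_md5(predictions, metadata)
print(Counter(row["assay"] for row in joined))
```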



Graph.

Code
plot_prediction_scores_distribution(
    results_df=augmented_df,
    group_by_column=ASSAY,
    merge_assay_pairs=True,
    min_y=0.79,
    use_aggregate_vote=False,
    title="Complete 11c assay classifier<br>inference on imputed data",
)

H,I - Prediction score thresholds

For the code that produced the figures, see src/python/epiclass/utils/notebooks/paper/confidence_threshold.ipynb.

Supplementary Figure 1H,I: Distribution of aggregated accuracy, F1-score and corresponding file subset size across varying prediction score thresholds, based on pooled predictions from all cross-validation folds for the Assay (H) and Biospecimen (I) classifiers.
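Supplementary Figure 1H,I summarizes how performance changes when only predictions whose score reaches a threshold are kept. The minimal sketch below shows such a sweep. It assumes pooled `(true label, predicted label, max score)` triples from all folds and computes three values at each threshold: accuracy, macro F1 over the labels present in the kept subset, and the fraction of files retained. The function names and the macro-F1 label convention are our assumptions and may differ from those in `confidence_threshold.ipynb`.

```python
from typing import Dict, List, Sequence, Tuple

Prediction = Tuple[str, str, float]  # (true label, predicted label, max prediction score)


def macro_f1(pairs: Sequence[Tuple[str, str]]) -> float:
    """Unweighted mean of per-class F1 over all labels seen as true or predicted."""
    labels = {t for t, _ in pairs} | {p for _, p in pairs}
    f1_scores = []
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


def threshold_sweep(
    predictions: List[Prediction], thresholds: Sequence[float]
) -> List[Dict[str, float]]:
    """For each threshold, score only the predictions with max score >= threshold."""
    rows = []
    for threshold in thresholds:
        kept = [(t, p) for t, p, score in predictions if score >= threshold]
        if not kept:
            nan = float("nan")
            rows.append({"threshold": threshold, "accuracy": nan, "f1_macro": nan, "subset_fraction": 0.0})
            continue
        rows.append(
            {
                "threshold": threshold,
                "accuracy": sum(t == p for t, p in kept) / len(kept),
                "f1_macro": macro_f1(kept),
                "subset_fraction": len(kept) / len(predictions),
            }
        )
    return rows


pooled = [
    ("h3k4me3", "h3k4me3", 0.99),
    ("h3k27ac", "h3k27ac", 0.95),
    ("h3k4me1", "h3k27ac", 0.55),
    ("input", "input", 0.80),
]
for row in threshold_sweep(pooled, [0.0, 0.6, 0.9]):
    print(row)
```

Raising the threshold trades coverage (subset fraction) for accuracy. In this toy example, the single misclassified file is also the least confident one, so it disappears once the threshold exceeds 0.55.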

Supplementary Figure 2 - Biospecimen performance on various feature sets.

A,B

Code
metrics_supp2 = {name: all_metrics[name] for name in feature_sets_14}

graph_feature_set_metrics(
    all_metrics=metrics_supp2,  # type: ignore
    input_sizes=input_sizes,
    boxpoints="all",
    width=900,
    height=600,
)

C,D

Code
def prepare_metric_sets_per_assay(
    all_results: Dict[str, Dict[str, Dict[str, pd.DataFrame]]], verbose: bool = False
) -> Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]]:
    """Compute per-split metrics separately for each assay, for every feature set.

    Args:
        all_results (Dict[str, Dict[str, Dict[str, pd.DataFrame]]]): All split results for all feature sets.
            Format: {feature_set: {task_name: {split_name: results_df}}}
        verbose (bool): Whether to print progress messages. Defaults to False.

    Returns:
        Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]]: All metrics per assay for all feature sets.
            Format: {assay: {feature_set: {task_name: {split_name: metric_dict}}}}
    """
    if verbose:
        print("Loading metadata.")
    metadata = metadata_handler.load_metadata("v2")
    metadata.convert_classes(ASSAY, ASSAY_MERGE_DICT)
    md5_per_assay = metadata.md5_per_class(ASSAY)
    md5_per_assay = {k: set(v) for k, v in md5_per_assay.items()}

    if verbose:
        print("Getting results per assay.")
    results_per_assay = {}
    for assay_label in ASSAY_ORDER:
        if verbose:
            print(assay_label)
        results_per_assay[assay_label] = {}
        for feature_set, task_dict in all_results.items():
            if verbose:
                print(feature_set)
            results_per_assay[assay_label][feature_set] = {}
            for task_name, split_dict in task_dict.items():
                if verbose:
                    print(task_name)
                results_per_assay[assay_label][feature_set][task_name] = {}

                # Only keep the relevant assay
                for split_name, split_df in split_dict.items():
                    if verbose:
                        print(split_name)
                    assay_df = split_df[split_df.index.isin(md5_per_assay[assay_label])]
                    results_per_assay[assay_label][feature_set][task_name][
                        split_name
                    ] = assay_df

    if verbose:
        print("Finished getting results per assay. Now computing metrics.")
    metrics_per_assay = {}
    for assay_label in ASSAY_ORDER:
        if verbose:
            print(assay_label)
        metrics_per_assay[assay_label] = {}
        for feature_set, task_dict in results_per_assay[assay_label].items():
            if verbose:
                print(feature_set)
            assay_metrics = split_results_handler.compute_split_metrics(
                task_dict, concat_first_level=True
            )
            inverted_dict = split_results_handler.invert_metrics_dict(assay_metrics)
            metrics_per_assay[assay_label][feature_set] = inverted_dict

    return metrics_per_assay
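`prepare_metric_sets_per_assay` first restricts each split's predictions to the files of one assay (by md5sum), then recomputes the metrics. It also inverts the metric nesting so that tasks come before splits. The sketch below shows both steps on plain dictionaries. Several pieces are our own assumptions: the helper names, the toy accuracy metric, the task name `"biospecimen"`, and the idea that `compute_split_metrics` returns `{split: {task: metrics}}`. That last point is inferred from how `graph_feature_set_metrics_per_assay` indexes the result.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical minimal record: (md5sum, true label, predicted label)
Record = Tuple[str, str, str]


def filter_to_assay(records: List[Record], assay_md5s: Set[str]) -> List[Record]:
    """Keep only the files belonging to one assay, like `split_df.index.isin(...)`."""
    return [r for r in records if r[0] in assay_md5s]


def accuracy(records: List[Record]) -> float:
    """Fraction of correct predictions (NaN for an empty subset)."""
    if not records:
        return float("nan")
    return sum(true == pred for _, true, pred in records) / len(records)


def invert_nesting(
    metrics: Dict[str, Dict[str, Dict[str, float]]]
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Swap the two outer levels: {split: {task: m}} -> {task: {split: m}}."""
    inverted: Dict[str, Dict[str, Dict[str, float]]] = {}
    for split_name, task_metrics in metrics.items():
        for task_name, metric_dict in task_metrics.items():
            inverted.setdefault(task_name, {})[split_name] = metric_dict
    return inverted


splits = {
    "split0": [("a", "liver", "liver"), ("b", "blood", "liver")],
    "split1": [("c", "blood", "blood"), ("d", "liver", "blood")],
}
assay_md5s = {"a", "b", "c"}  # illustrative md5sums of one assay's files

per_split = {
    split: {"biospecimen": {"Accuracy": accuracy(filter_to_assay(recs, assay_md5s))}}
    for split, recs in splits.items()
}
print(invert_nesting(per_split))
```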
Code
def graph_feature_set_metrics_per_assay(
    all_metrics_per_assay: Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]],
    input_sizes: Dict[str, int],
    logdir: Path | None = None,
    sort_by_input_size: bool = False,
    name: str | None = None,
    y_range: Tuple[float, float] | None = None,
    boxpoints: str = "outliers",
) -> None:
    """Graph the metrics for all feature sets, per assay, with separate plots for accuracy and F1-score.

    Args:
        all_metrics_per_assay (Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]]): A dictionary containing all metrics per assay for all feature sets.
            Format: {assay: {feature_set: {task_name: {split_name: metric_dict}}}}
        input_sizes (Dict[str, int]): A dictionary containing the input sizes for all feature sets.
        logdir (Path | None): The directory where the figures are saved (HTML, SVG and PNG). If None, the figures are only displayed.
        sort_by_input_size (bool): Whether to sort the feature sets by input size.
        name (str|None): The name of the figure.
        y_range (Tuple[float, float]|None): The y-axis range for the plots.
        boxpoints (str): The type of points to display in the box plots. Defaults to "outliers".
    """
    valid_boxpoints = ["all", "outliers"]
    if boxpoints not in valid_boxpoints:
        raise ValueError(f"Invalid boxpoints value. Choose from {valid_boxpoints}.")

    fig_assay_order = [
        "rna_seq",
        "h3k27ac",
        "h3k4me1",
        "h3k4me3",
        "h3k36me3",
        "h3k27me3",
        "h3k9me3",
        "input",
        "wgbs",
    ]

    reference_assay = next(iter(all_metrics_per_assay))
    reference_feature_set = next(iter(all_metrics_per_assay[reference_assay]))
    metadata_categories = list(
        all_metrics_per_assay[reference_assay][reference_feature_set].keys()
    )

    for category in metadata_categories:
        for metric, metric_name in [
            ("Accuracy", "Accuracy"),
            ("F1_macro", "F1-score (macro)"),
        ]:
            fig = go.Figure()

            feature_sets = list(all_metrics_per_assay[reference_assay].keys())
            unique_feature_sets = set(feature_sets)
            for assay in fig_assay_order:
                if set(all_metrics_per_assay[assay].keys()) != unique_feature_sets:
                    raise ValueError("Different feature sets through assays.")

            feature_set_order = feature_sets
            if sort_by_input_size:
                feature_set_order = sorted(
                    feature_set_order, key=lambda x: input_sizes[x]
                )

            # Give each assay group dedicated x-axis space based on the number of
            # feature sets; increase the multiplier to add more spacing between groups.
            spacing_multiplier = 1.1
            x_positions = {
                assay: i * len(feature_set_order) * spacing_multiplier
                for i, assay in enumerate(fig_assay_order)
            }

            for i, feature_set_name in enumerate(feature_set_order):
                resolution = (
                    feature_set_name.replace("_none", "")
                    .replace("hg38_", "")
                    .split("_")[0]
                )
                color = resolution_colors[resolution]
                display_name = feature_set_name.replace("_none", "").replace("hg38_", "")

                for assay in fig_assay_order:
                    if feature_set_name not in all_metrics_per_assay[assay]:
                        continue

                    tasks_dicts = all_metrics_per_assay[assay][feature_set_name]

                    if feature_set_name not in input_sizes:
                        print(f"Skipping {feature_set_name}, no input size found.")
                        continue

                    task_name = category
                    if "split" in task_name:
                        raise ValueError("Split in task name. Wrong metrics dict.")

                    try:
                        task_dict = tasks_dicts[task_name]
                    except KeyError:
                        print(
                            f"Skipping {feature_set_name}, {task_name} for assay {assay}"
                        )
                        continue

                    y_vals = [task_dict[split][metric] for split in task_dict]
                    hovertext = [
                        f"{assay} - {display_name} - {split}: {metrics_dict[metric]:.4f}"
                        for split, metrics_dict in task_dict.items()
                    ]

                    x_position = x_positions[assay] + i
                    fig.add_trace(
                        go.Box(
                            x=[x_position] * len(y_vals),
                            y=y_vals,
                            name=f"{assay}|{display_name}",
                            boxmean=True,
                            boxpoints=boxpoints,
                            marker=dict(size=3, color="black"),
                            line=dict(width=1, color="black"),
                            fillcolor=color,
                            hovertemplate="%{text}",
                            text=hovertext,
                            showlegend=False,
                            legendgroup=display_name,
                        )
                    )

                    # separate box groups
                    fig.add_vline(
                        x=x_positions[assay] - 1, line_width=1, line_color="black"
                    )

            # Add dummy traces for the legend
            for feature_set_name in feature_set_order:
                resolution = (
                    feature_set_name.replace("_none", "")
                    .replace("hg38_", "")
                    .split("_")[0]
                )
                color = resolution_colors[resolution]
                display_name = feature_set_name.replace("_none", "").replace("hg38_", "")
                display_name = re.sub(r"\_[\dmkb]+\_coord", "", display_name)

                fig.add_trace(
                    go.Scatter(
                        name=display_name,
                        x=[None],
                        y=[None],
                        mode="markers",
                        marker=dict(size=10, color=color),
                        showlegend=True,
                        legendgroup=display_name,
                    )
                )

            title = f"{category} - {metric_name} (per assay)"
            if name is not None:
                title += f" - {name}"

            fig.update_layout(
                width=1250,
                height=900,
                title_text=title,
                xaxis_title="Assay",
                yaxis_title=metric_name,
                **main_title_settings
            )

            # Create x-axis labels
            fig.update_xaxes(
                tickmode="array",
                tickvals=[
                    x_positions[assay] + len(feature_set_order) / 2
                    for assay in fig_assay_order
                ],
                ticktext=list(x_positions.keys()),
                title="Assay",
            )

            fig.update_layout(
                legend=dict(
                    title="Feature Sets", itemsizing="constant", traceorder="normal"
                )
            )
            if y_range:
                fig.update_yaxes(range=y_range)

            if logdir:
                base_name = f"feature_set_metrics_{category}_{metric}_per_assay"
                if name is not None:
                    base_name = base_name + f"_{name}"
                fig.write_html(logdir / f"{base_name}.html")
                fig.write_image(logdir / f"{base_name}.svg")
                fig.write_image(logdir / f"{base_name}.png")

            fig.show()
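The plotting function gives each assay group an x offset of `i * len(feature_set_order) * spacing_multiplier`. It draws one box per feature set at `offset + j` and puts the assay tick label at `offset + len(feature_set_order) / 2`. It also derives the colour key (the resolution) and the display names from the feature-set name. The sketch below reproduces only that arithmetic and string handling, without Plotly. `"hg38_100kb_all_none"` appears in this notebook. The `_coord` name is a made-up example used only to exercise the legend regex.

```python
import re
from typing import Dict, List, Tuple


def layout_positions(
    assays: List[str], feature_sets: List[str], spacing_multiplier: float = 1.1
) -> Tuple[Dict[str, float], Dict[Tuple[str, str], float], Dict[str, float]]:
    """Return group offsets, per-box x positions and tick positions, as in the plot code."""
    n_sets = len(feature_sets)
    offsets = {assay: i * n_sets * spacing_multiplier for i, assay in enumerate(assays)}
    boxes = {
        (assay, fs): offsets[assay] + j
        for assay in assays
        for j, fs in enumerate(feature_sets)
    }
    ticks = {assay: offsets[assay] + n_sets / 2 for assay in assays}
    return offsets, boxes, ticks


def parse_feature_set_name(name: str) -> Tuple[str, str, str]:
    """Return (resolution, box display name, legend name) derived as in the plot code."""
    display = name.replace("_none", "").replace("hg38_", "")
    resolution = display.split("_")[0]
    legend = re.sub(r"\_[\dmkb]+\_coord", "", display)
    return resolution, display, legend


offsets, boxes, ticks = layout_positions(
    ["rna_seq", "h3k27ac"], ["hg38_1mb_all_none", "hg38_100kb_all_none"]
)
print(offsets, ticks)
print(parse_feature_set_name("hg38_100kb_all_none"))
print(parse_feature_set_name("hg38_10kb_genes_10kb_coord_none"))  # hypothetical name
```

In the function above, the boxes use the unstripped display name as `legendgroup`, while the legend entries use the regex-stripped name. For names that match the `_coord` pattern, the two therefore differ.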
Code
set_selection_name = "feature_sets_14"
all_results = split_results_handler.obtain_all_feature_set_data(
    parent_folder=mixed_data_dir,
    merge_assays=True,
    return_type="split_results",
    include_categories=[CELL_TYPE],
    include_sets=metric_orders_map[set_selection_name],
    exclude_names=["16ct", "27ct", "7c", "chip-seq-only"],
)
Code
metrics_per_assay = prepare_metric_sets_per_assay(all_results)  # type: ignore
Code
# Reorder each assay's feature sets to follow the predefined display order
feature_set_order = metric_orders_map[set_selection_name]
for assay, assay_metrics in list(metrics_per_assay.items()):
    metrics_per_assay[assay] = {
        feature_set_name: assay_metrics[feature_set_name]
        for feature_set_name in feature_set_order
    }
Code
graph_feature_set_metrics_per_assay(
    all_metrics_per_assay=metrics_per_assay,  # type: ignore
    input_sizes=input_sizes,
    boxpoints="all",
    sort_by_input_size=False,
    y_range=(0.1, 1.01)
)