Source code for maldibatchkit.diagnostics.benchmark

"""Class-based diagnostic comparison of batch correctors.

:class:`BatchCorrectionBenchmark` mirrors the
``GridSearchCV``-style configure-then-``fit``-then-inspect shape: build
it with a dict of unfitted correctors and a list of metric names (or
callables), then call ``.fit(X, batch=..., species=...)`` to compute
every metric on every corrector.

The benchmark is **diagnostic-only by design**: it does not score a
downstream classifier. Use
:class:`~maldibatchkit.corrections.auto.AutoCorrector` together with
``sklearn.model_selection.GridSearchCV`` if you want the downstream
metric (AUROC for AMR) to pick the winner.

Aggregation
-----------
``fit`` always populates two tables:

* ``results_long_`` - one row per (method, metric, repeat, bootstrap).
  Use it when you want to compute your own statistics or feed a
  raincloud / strip plot.
* ``results_`` - one row per (method, metric) with the central tendency
  and a confidence interval. The convention is:

  - When ``n_bootstrap >= 2``, ``ci_lo`` / ``ci_hi`` are the 2.5 / 97.5
    **percentiles** of the bootstrap distribution and ``value`` is its
    mean (this is the standard non-parametric bootstrap percentile
    interval, e.g. Efron & Tibshirani 1993). The ``bootstrap == -1``
    point-estimate row is pooled with the resampled rows when these
    statistics are computed.
  - With ``n_bootstrap < 2`` but ``n_repeats >= 3`` (only meaningful for
    ``protocol='stratified_split'``), ``ci_lo`` / ``ci_hi`` are the
    2.5 / 97.5 percentiles **across repeats**, so repeat-to-repeat
    instability shows up as a wider interval. ``value`` is the mean
    across repeats.
  - Below those thresholds the interval columns hold ``NaN`` and
    ``value`` is the mean of the one or two available observations.
    Don't read a CI from one or two points.

  ``std`` and ``n`` are always populated so you can post-hoc compute
  a different summary if you prefer.
"""

from __future__ import annotations

import inspect
import warnings
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.model_selection import StratifiedShuffleSplit

from .._base import BaseBatchCorrector
from .._utils import ArrayLike
from .generic import (
    kbet,
    lisi,
    lisi_normalized,
    silhouette_batch,
    species_preservation,
)
from .maldi import peak_position_drift, tic_cov_per_batch

__all__ = ["BatchCorrectionBenchmark"]


def _kbet_acceptance(X: ArrayLike, batch: ArrayLike, **kw: Any) -> float:
    """Return the kBET acceptance rate as a plain float.

    ``kbet`` returns a mapping; this wrapper forwards only the ``k`` and
    ``alpha`` keywords (when the caller supplied them), drops every other
    keyword, and extracts the ``"acceptance_rate"`` entry.
    """
    forwarded: dict[str, Any] = {}
    for key in ("k", "alpha"):
        if key in kw:
            forwarded[key] = kw[key]
    report = kbet(X, batch, **forwarded)
    return float(report["acceptance_rate"])


def _peak_drift_mean(X: ArrayLike, batch: ArrayLike, **kw: Any) -> float:
    """Average the per-batch ``mean_delta_mz`` column of ``peak_position_drift``.

    Reads ``mz_values`` (default ``None``) and ``top_k_peaks`` (default
    ``50``) from ``kw``; returns ``NaN`` when the drift table is empty.
    """
    drift = peak_position_drift(
        X,
        batch,
        mz_values=kw.get("mz_values"),
        top_k=kw.get("top_k_peaks", 50),
    )
    return float("nan") if drift.empty else float(drift["mean_delta_mz"].mean())


def _tic_cov_mean(X: ArrayLike, batch: ArrayLike, **kw: Any) -> float:
    """Average the output of ``tic_cov_per_batch`` into one scalar.

    Extra keywords are accepted (so the registry can call every metric
    uniformly) but unused. Returns ``NaN`` when the per-batch result is
    empty.
    """
    per_batch = tic_cov_per_batch(X, batch)
    return float("nan") if per_batch.empty else float(per_batch.mean())


def _lisi_score(X, batch, **kw):
    """LISI on the batch labels; ``perplexity`` defaults to 30.0."""
    return float(lisi(X, batch, perplexity=kw.get("perplexity", 30.0)))


def _lisi_normalized_score(X, batch, **kw):
    """Normalised LISI on the batch labels; ``perplexity`` defaults to 30.0."""
    return float(lisi_normalized(X, batch, perplexity=kw.get("perplexity", 30.0)))


def _silhouette_batch_score(X, batch, **kw):
    """Batch silhouette; extra keywords are accepted and ignored."""
    return silhouette_batch(X, batch)


def _species_preservation_score(X, batch, *, species=None, **kw):
    """Species preservation; scores against ``batch`` when ``species`` is None."""
    labels = batch if species is None else species
    return float(
        species_preservation(X, labels, perplexity=kw.get("perplexity", 30.0))
    )


# String alias -> scalar metric callable. Every entry is invoked as
# ``fn(X, batch, **kwargs)``; "kbet" and "kbet_acceptance_rate" are synonyms.
_METRIC_REGISTRY: dict[str, Callable[..., float]] = {
    "kbet": _kbet_acceptance,
    "kbet_acceptance_rate": _kbet_acceptance,
    "lisi": _lisi_score,
    "lisi_normalized": _lisi_normalized_score,
    "silhouette_batch": _silhouette_batch_score,
    "species_preservation": _species_preservation_score,
    "peak_position_drift": _peak_drift_mean,
    "tic_cov_per_batch": _tic_cov_mean,
}

# Which direction counts as "better" for each registered alias:
# "higher", "lower", or "zero" (best when the value is closest to 0).
_METRIC_DIRECTION: dict[str, str] = {
    "kbet": "higher",
    "kbet_acceptance_rate": "higher",
    "lisi": "higher",
    "lisi_normalized": "higher",
    "silhouette_batch": "zero",
    "species_preservation": "higher",
    "peak_position_drift": "lower",
    "tic_cov_per_batch": "lower",
}


def _resolve_metric(metric: Any) -> tuple[str, Callable[..., float]]:
    """Return ``(name, fn)`` for a string alias or callable."""
    if isinstance(metric, str):
        if metric not in _METRIC_REGISTRY:
            raise ValueError(
                f"Unknown metric {metric!r}. Registered: {sorted(_METRIC_REGISTRY)} "
                f"(or pass a callable)."
            )
        return metric, _METRIC_REGISTRY[metric]
    if callable(metric):
        name = getattr(metric, "__name__", repr(metric))
        return name, metric
    raise TypeError(
        f"metric must be a string or callable; got {type(metric).__name__}."
    )


def _call_metric(
    fn: Callable[..., float],
    X: pd.DataFrame,
    batch: np.ndarray,
    *,
    species: Any,
    extra: Mapping[str, Any],
) -> float:
    """Invoke ``fn`` with whichever supported kwargs it accepts."""
    try:
        sig = inspect.signature(fn)
        accepted = {
            name
            for name, p in sig.parameters.items()
            if p.kind
            not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        }
        has_var_kw = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
    except (TypeError, ValueError):
        accepted, has_var_kw = set(), True

    kw: dict[str, Any] = {}
    if "species" in accepted or has_var_kw:
        kw["species"] = species
    for k, v in extra.items():
        if k in accepted or has_var_kw:
            kw[k] = v
    return float(fn(X, batch, **kw))


def _bootstrap_indices(
    rng: np.random.Generator,
    batch: np.ndarray,
    n_bootstrap: int,
) -> list[np.ndarray]:
    """Stratified bootstrap row indices, sampling with replacement within each batch."""
    levels = np.unique(batch)
    by_level = {lvl: np.flatnonzero(batch == lvl) for lvl in levels}
    out = []
    for _ in range(n_bootstrap):
        chunks = [
            rng.choice(idxs, size=len(idxs), replace=True) for idxs in by_level.values()
        ]
        out.append(np.concatenate(chunks))
    return out


def _summarise(group: pd.DataFrame, has_bootstrap: bool) -> pd.Series:
    """Reduce a long-form group to value / ci_lo / ci_hi / std / n."""
    vals = group["value"].to_numpy(dtype=float)
    n = int(len(vals))
    if n == 0:
        return pd.Series(
            {"value": np.nan, "ci_lo": np.nan, "ci_hi": np.nan, "std": np.nan, "n": 0}
        )
    if has_bootstrap and n >= 2:
        ci_lo, ci_hi = np.nanpercentile(vals, [2.5, 97.5])
        value = float(np.nanmean(vals))
    elif n >= 3:
        ci_lo, ci_hi = np.nanpercentile(vals, [2.5, 97.5])
        value = float(np.nanmean(vals))
    else:
        ci_lo, ci_hi = np.nan, np.nan
        value = float(np.nanmean(vals))
    std = float(np.nanstd(vals, ddof=1)) if n >= 2 else float("nan")
    return pd.Series(
        {"value": value, "ci_lo": ci_lo, "ci_hi": ci_hi, "std": std, "n": n}
    )


[docs] class BatchCorrectionBenchmark: """Diagnostic comparison of multiple batch correctors. Parameters ---------- correctors : dict[str, BaseBatchCorrector] Mapping from a display name to an *unfitted* :class:`BaseBatchCorrector`. Each will be cloned per protocol iteration (so calling ``fit`` on the benchmark does not leave the input correctors fitted). metrics : sequence of str or callable, default=('kbet', 'lisi_normalized', 'species_preservation') Metric specifications. Strings are resolved against the registry in :mod:`maldibatchkit.diagnostics`; callables are invoked as ``metric(X_corrected, batch, species=species, **extra)`` and only receive the kwargs they actually accept. protocol : {'full_data', 'stratified_split'}, default='full_data' ``'full_data'`` fits each corrector on all rows and scores on the same rows (Büttner-2019 convention). ``'stratified_split'`` fits on a stratified train split and scores on the held-out test split; every batch is forced into both folds. test_size : float, default=0.2 Test fraction for ``'stratified_split'``. n_repeats : int, default=1 Number of repeated splits for ``'stratified_split'``. n_bootstrap : int, default=0 If non-zero, every metric is recomputed on ``n_bootstrap`` stratified row-resamples of the (corrected) matrix to give confidence intervals. ``0`` disables bootstrapping. bootstrap_mode : {'resample_metric', 'refit'}, default='resample_metric' ``'resample_metric'`` (the fast, default mode) fits each corrector once and resamples rows of the corrected matrix to score the metric repeatedly - CIs reflect the metric's sampling noise only. ``'refit'`` resamples rows of ``X`` and refits the corrector for every bootstrap iteration; this is slower (``n_correctors × n_bootstrap`` extra fits per repeat) but the CI also captures corrector stability. random_state : int or np.random.Generator, optional Seed / generator for splits and bootstrap. 
Attributes ---------- results_long_ : pd.DataFrame Tidy raw observations: columns ``method``, ``metric``, ``repeat``, ``bootstrap``, ``value``. ``bootstrap == -1`` marks the point-estimate row (no bootstrap resampling). results_ : pd.DataFrame Per-(method, metric) summary: ``value`` (mean), ``ci_lo`` / ``ci_hi`` (2.5 / 97.5 percentile), ``std`` (sample std), ``n`` (number of observations), ``better`` (``'higher'``, ``'lower'``, ``'zero'`` for metrics where both signs are bad like ``silhouette_batch``, or ``'n/a'`` for user callables - annotate via :attr:`_METRIC_DIRECTION` for registered names). corrected_ : dict[str, pd.DataFrame] For ``protocol='full_data'``, the fitted-and-transformed matrix from each corrector. For ``protocol='stratified_split'``, the corrected *test* matrix from the **last** repeat (provided for downstream inspection / plotting; use ``results_long_`` for per-repeat statistics). baseline_ : pd.DataFrame One-row-per-metric report on the **uncorrected** ``X``, mirroring the ``results_`` schema. Examples -------- >>> from maldibatchkit import ComBat, NoOpCorrector >>> from maldibatchkit.diagnostics import BatchCorrectionBenchmark >>> bench = BatchCorrectionBenchmark( ... correctors={ ... "none": NoOpCorrector(batch=b), ... "combat-fortin": ComBat(batch=b, method="fortin"), ... }, ... metrics=("kbet", "species_preservation"), ... n_bootstrap=200, ... random_state=0, ... ) >>> bench.fit(X, batch=b, species=s) # doctest: +SKIP >>> bench.rank(by="species_preservation") # doctest: +SKIP """
[docs] def __init__( self, correctors: Mapping[str, BaseBatchCorrector], *, metrics: Sequence[Any] = ("kbet", "lisi_normalized", "species_preservation"), protocol: str = "full_data", test_size: float = 0.2, n_repeats: int = 1, n_bootstrap: int = 0, bootstrap_mode: str = "resample_metric", random_state: int | np.random.Generator | None = None, ) -> None: if protocol not in ("full_data", "stratified_split"): raise ValueError( f"protocol must be 'full_data' or 'stratified_split'; got {protocol!r}." ) if bootstrap_mode not in ("resample_metric", "refit"): raise ValueError( "bootstrap_mode must be 'resample_metric' or 'refit'; " f"got {bootstrap_mode!r}." ) if not correctors: raise ValueError("`correctors` must contain at least one entry.") self.correctors = dict(correctors) self.metrics = tuple(metrics) self.protocol = protocol self.test_size = test_size self.n_repeats = int(n_repeats) self.n_bootstrap = int(n_bootstrap) self.bootstrap_mode = bootstrap_mode self.random_state = random_state
def _rng(self) -> np.random.Generator: if isinstance(self.random_state, np.random.Generator): return self.random_state return np.random.default_rng(self.random_state) def _resolve_metrics(self) -> list[tuple[str, Callable[..., float], str]]: resolved = [] for m in self.metrics: name, fn = _resolve_metric(m) direction = _METRIC_DIRECTION.get(name, "n/a") resolved.append((name, fn, direction)) return resolved
[docs] def fit( self, X: ArrayLike, *, batch: ArrayLike, species: ArrayLike | None = None, y: ArrayLike | None = None, **extra: Any, ) -> BatchCorrectionBenchmark: """Run every corrector under the chosen protocol and score each metric. Parameters ---------- X : array-like of shape (n_samples, n_features) Feature matrix. batch : array-like of shape (n_samples,) Batch labels. species : array-like, optional Forwarded to metrics that need it (e.g. ``species_preservation``). y : array-like, optional Ignored at the benchmark level (no classifier scoring). Kept on the signature so the call site mirrors sklearn. **extra : Any Forwarded to every metric callable that accepts the given keyword (e.g. ``mz_values=`` for ``peak_position_drift``). Returns ------- self : BatchCorrectionBenchmark Fitted benchmark with ``results_long_``, ``results_``, ``corrected_``, ``baseline_`` populated. """ if isinstance(X, pd.DataFrame): X_df = X.copy() else: X_df = pd.DataFrame(np.asarray(X)) idx = X_df.index batch_arr = ( batch.loc[idx].to_numpy() if isinstance(batch, pd.Series | pd.DataFrame) else np.asarray(batch) ) species_aligned: Any if species is None: species_aligned = None elif isinstance(species, pd.Series | pd.DataFrame): species_aligned = species.loc[idx] else: species_aligned = pd.Series(np.asarray(species), index=idx) metrics_resolved = self._resolve_metrics() rng = self._rng() if self.protocol == "full_data": long_rows, corrected, baseline_rows = self._run_full_data( X_df, batch_arr, species_aligned, metrics_resolved, extra, rng ) else: long_rows, corrected, baseline_rows = self._run_stratified( X_df, batch_arr, species_aligned, metrics_resolved, extra, rng ) results_long = pd.DataFrame(long_rows) self.results_long_ = results_long direction_lookup = {name: d for name, _, d in metrics_resolved} has_bootstrap = self.n_bootstrap >= 2 summary = ( results_long.groupby(["method", "metric"], sort=False) .apply( lambda g: _summarise(g, has_bootstrap=has_bootstrap), 
include_groups=False, ) .reset_index() ) summary["better"] = summary["metric"].map(direction_lookup).fillna("n/a") self.results_ = summary self.corrected_ = corrected self.baseline_ = pd.DataFrame(baseline_rows) return self
def _score_corrected( self, X_corrected: pd.DataFrame, batch: np.ndarray, species_aligned: Any, metrics_resolved: list[tuple[str, Callable[..., float], str]], extra: Mapping[str, Any], method_name: str, repeat: int, rng: np.random.Generator, ) -> list[dict[str, Any]]: """Score a single corrected matrix; bootstrap rows when requested. Always emits the point estimate (``bootstrap=-1``) plus ``n_bootstrap`` resampled scores when ``bootstrap_mode`` is ``'resample_metric'``. """ rows: list[dict[str, Any]] = [] species_arr = ( species_aligned.to_numpy() if isinstance(species_aligned, pd.Series) else (None if species_aligned is None else np.asarray(species_aligned)) ) for name, fn, _direction in metrics_resolved: value = _call_metric( fn, X_corrected, batch, species=species_arr, extra=extra ) rows.append( { "method": method_name, "metric": name, "repeat": repeat, "bootstrap": -1, "value": value, } ) if self.n_bootstrap > 0 and self.bootstrap_mode == "resample_metric": boot_idx_sets = _bootstrap_indices(rng, batch, self.n_bootstrap) for b, sel in enumerate(boot_idx_sets): Xb = X_corrected.iloc[sel] bb = batch[sel] sb = species_arr[sel] if species_arr is not None else None for name, fn, _direction in metrics_resolved: val = _call_metric(fn, Xb, bb, species=sb, extra=extra) rows.append( { "method": method_name, "metric": name, "repeat": repeat, "bootstrap": b, "value": val, } ) return rows def _fit_transform( self, corrector: BaseBatchCorrector, X_train: pd.DataFrame, X_eval: pd.DataFrame ) -> pd.DataFrame: """Fit on ``X_train`` and transform ``X_eval`` (may be the same). ``clone`` is used so the user's original corrector instances stay unfitted, and so per-protocol-iteration fits are independent. 
""" c = clone(corrector) c.fit(X_train) out = c.transform(X_eval) if isinstance(out, pd.DataFrame): return out return pd.DataFrame(np.asarray(out), index=X_eval.index, columns=X_eval.columns) def _run_full_data( self, X_df: pd.DataFrame, batch: np.ndarray, species_aligned: Any, metrics_resolved: list[tuple[str, Callable[..., float], str]], extra: Mapping[str, Any], rng: np.random.Generator, ) -> tuple[list[dict[str, Any]], dict[str, pd.DataFrame], list[dict[str, Any]]]: long_rows: list[dict[str, Any]] = [] corrected: dict[str, pd.DataFrame] = {} baseline_rows = self._score_baseline( X_df, batch, species_aligned, metrics_resolved, extra ) for name, corrector in self.correctors.items(): X_corr = self._fit_transform(corrector, X_df, X_df) corrected[name] = X_corr long_rows.extend( self._score_corrected( X_corr, batch, species_aligned, metrics_resolved, extra, method_name=name, repeat=0, rng=rng, ) ) if self.n_bootstrap > 0 and self.bootstrap_mode == "refit": boot_idx_sets = _bootstrap_indices(rng, batch, self.n_bootstrap) for b, sel in enumerate(boot_idx_sets): X_boot = X_df.iloc[sel] b_boot = batch[sel] sp_boot = ( species_aligned.iloc[sel] if isinstance(species_aligned, pd.Series) else None ) try: X_corr_b = self._fit_transform(corrector, X_boot, X_boot) except Exception as exc: # noqa: BLE001 warnings.warn( f"Bootstrap refit failed for method {name!r} on " f"iteration {b}: {exc}. 
Skipping.", stacklevel=2, ) continue sp_arr = sp_boot.to_numpy() if sp_boot is not None else None for mname, fn, _ in metrics_resolved: val = _call_metric( fn, X_corr_b, b_boot, species=sp_arr, extra=extra ) long_rows.append( { "method": name, "metric": mname, "repeat": 0, "bootstrap": b, "value": val, } ) return long_rows, corrected, baseline_rows def _run_stratified( self, X_df: pd.DataFrame, batch: np.ndarray, species_aligned: Any, metrics_resolved: list[tuple[str, Callable[..., float], str]], extra: Mapping[str, Any], rng: np.random.Generator, ) -> tuple[list[dict[str, Any]], dict[str, pd.DataFrame], list[dict[str, Any]]]: seed = int(rng.integers(0, 2**31 - 1)) sss = StratifiedShuffleSplit( n_splits=self.n_repeats, test_size=self.test_size, random_state=seed ) long_rows: list[dict[str, Any]] = [] corrected: dict[str, pd.DataFrame] = {} baseline_rows = self._score_baseline( X_df, batch, species_aligned, metrics_resolved, extra ) for repeat, (train_idx, test_idx) in enumerate(sss.split(X_df, batch)): X_train = X_df.iloc[train_idx] X_test = X_df.iloc[test_idx] batch_test = batch[test_idx] species_test = ( species_aligned.iloc[test_idx] if isinstance(species_aligned, pd.Series) else None ) for name, corrector in self.correctors.items(): X_corr = self._fit_transform(corrector, X_train, X_test) if repeat == self.n_repeats - 1: corrected[name] = X_corr long_rows.extend( self._score_corrected( X_corr, batch_test, species_test, metrics_resolved, extra, method_name=name, repeat=repeat, rng=rng, ) ) return long_rows, corrected, baseline_rows def _score_baseline( self, X_df: pd.DataFrame, batch: np.ndarray, species_aligned: Any, metrics_resolved: list[tuple[str, Callable[..., float], str]], extra: Mapping[str, Any], ) -> list[dict[str, Any]]: species_arr = ( species_aligned.to_numpy() if isinstance(species_aligned, pd.Series) else (None if species_aligned is None else np.asarray(species_aligned)) ) rows = [] for name, fn, direction in metrics_resolved: value = 
_call_metric(fn, X_df, batch, species=species_arr, extra=extra) rows.append( { "method": "__baseline__", "metric": name, "value": value, "ci_lo": float("nan"), "ci_hi": float("nan"), "std": float("nan"), "n": 1, "better": direction, } ) return rows
[docs] def rank(self, by: str, *, ascending: bool | None = None) -> pd.DataFrame: """Return ``results_`` sorted by one metric's mean value. Parameters ---------- by : str Metric name to rank on. ascending : bool, optional Sort direction. If omitted, the metric's registered "better" direction is used: ``'higher'`` → descending, ``'lower'`` → ascending, ``'zero'`` → ascending by ``|value|`` (e.g. ``silhouette_batch``, where both positive and negative extremes are bad - only 0 is well-mixed). """ if not hasattr(self, "results_"): raise RuntimeError("Call .fit(...) before .rank().") sub = self.results_[self.results_["metric"] == by] if sub.empty: raise ValueError(f"No metric named {by!r} in results_.") direction = sub["better"].iloc[0] if direction == "zero": asc = True if ascending is None else ascending ordered = sub.assign(_abs=sub["value"].abs()).sort_values( "_abs", ascending=asc ) return ordered.drop(columns="_abs").reset_index(drop=True) if ascending is None: ascending = direction != "higher" return sub.sort_values("value", ascending=ascending).reset_index(drop=True)
[docs] def to_dataframe(self) -> pd.DataFrame: """Alias for :attr:`results_`.""" if not hasattr(self, "results_"): raise RuntimeError("Call .fit(...) before .to_dataframe().") return self.results_
[docs] def plot(self, ax: Any = None) -> Any: """Bar-plot the summary results, faceted by metric. Requires the ``viz`` extra (seaborn / matplotlib). Returns the matplotlib ``Axes`` (single metric) or ``Figure`` (multiple). """ if not hasattr(self, "results_"): raise RuntimeError("Call .fit(...) before .plot().") try: import matplotlib.pyplot as plt import seaborn as sns except ImportError as exc: # pragma: no cover raise ImportError( "BatchCorrectionBenchmark.plot() needs matplotlib + seaborn; " "install the `viz` extra (pip install 'maldibatchkit[viz]')." ) from exc df = self.results_ metrics = list(df["metric"].unique()) if ax is not None and len(metrics) > 1: raise ValueError( "Pass `ax` only when results_ contains a single metric; " f"got {len(metrics)} metrics." ) if len(metrics) == 1: ax = ax or plt.gca() sns.barplot(data=df, x="method", y="value", ax=ax, color="#4C72B0") if df[["ci_lo", "ci_hi"]].notna().all().all(): ax.errorbar( x=range(len(df)), y=df["value"], yerr=[df["value"] - df["ci_lo"], df["ci_hi"] - df["value"]], fmt="none", ecolor="black", capsize=3, ) ax.set_title(metrics[0]) # Anchor-rotated, right-aligned labels stay under their bars. plt.setp( ax.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor", ) return ax fig, axes = plt.subplots(1, len(metrics), figsize=(4 * len(metrics), 4)) for ax_i, metric in zip(axes, metrics, strict=True): sub = df[df["metric"] == metric] sns.barplot(data=sub, x="method", y="value", ax=ax_i, color="#4C72B0") if sub[["ci_lo", "ci_hi"]].notna().all().all(): ax_i.errorbar( x=range(len(sub)), y=sub["value"], yerr=[sub["value"] - sub["ci_lo"], sub["ci_hi"] - sub["value"]], fmt="none", ecolor="black", capsize=3, ) ax_i.set_title(metric) plt.setp( ax_i.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor", ) fig.tight_layout() return fig