"""Batch-aware downstream classifier metrics.
The functions in this module compute a sklearn metric **per batch** on a
held-out fold, then aggregate the per-batch scalars into a single number
according to a chosen weighting scheme. They are designed for selecting
batch correctors with :class:`~maldibatchkit.AutoCorrector` and
:class:`sklearn.model_selection.GridSearchCV` when the goal is a model
that performs well *on every site*, not one that just maximises the
overall pooled score.
Weighting modes
---------------
All ``batch_*_score`` functions accept ``weights=``:
* ``'uniform'`` - each batch contributes one entry to the aggregate
with equal weight regardless of its sample count
(``w_i = 1 / n_batches``). The right default when you want a
corrector that generalises across sites rather than one that
optimises the dominant site.
* ``'balanced'`` - weights inversely proportional to per-batch size
(``w_i ∝ 1 / n_i``). The **smallest site gets the loudest voice**,
mirroring sklearn's ``class_weight='balanced'`` formula applied at
the per-batch level. Use this when minority sites are the hardest
distributions to learn and the corrector must not crush them.
* ``'size'`` - weights proportional to per-batch sample counts.
The relationship to the standard pooled value depends on the metric
family:
- **Additive `correct / total` metrics** (accuracy, error rate):
``'size'`` exactly recovers the pooled value.
- **Class-conditional rates** (sensitivity, specificity): the pooled
value is recovered by *class-conditional* weights (positives-
per-batch for sensitivity, negatives-per-batch for specificity),
not by total-sample weights. Equal only when class prevalence is
constant across batches.
- **Non-linear metrics** (F1, precision, balanced accuracy, MCC):
no per-batch weighting recovers the pooled value exactly -
they are ratios of sums or non-linear functions of confusion-
matrix counts.
- **Pairwise / rank-based metrics** (AUROC, average precision):
the pooled value also counts *cross-batch* positive-vs-negative
pairs, which any per-batch reducer cannot see by construction;
those pairs are gone, not approximated.
In every non-additive case, ``'size'`` still gives the dominant
site the loudest voice - useful when you want to score the
classifier as a single-site practitioner would.
* ``Mapping[label, float]`` or array - explicit per-batch weights. The
array must be aligned with ``np.unique(batch)``.
Within each batch, the wrapped sklearn metric's own averaging parameter
(``average='binary'`` / ``'macro'`` / ``'weighted'`` etc.) is preserved
and untouched. ``weights=`` only controls the *across-batch* reducer.
Degenerate folds
----------------
When a batch contains only one class in the held-out fold, AUROC and
average precision are undefined; F1 / precision / recall fall back to
zero. Batches on which the metric is undefined (single-class batches for
AUROC, average precision and balanced accuracy; batches with fewer than
two samples for every metric) are **dropped from the aggregate with a
``UserWarning``** and the remaining weights are renormalised. Use
``weights=`` plus a careful split design if you need a different policy.
``make_batch_scorer`` factory
-----------------------------
:func:`make_batch_scorer` wraps the per-batch metric functions into a
:mod:`sklearn` ``scorer(estimator, X, y)`` callable suitable for
``GridSearchCV(scoring=...)``. It captures the full ``batch`` Series at
factory time and slices it by ``y.index`` for every fold.
"""
from __future__ import annotations
import warnings
from collections.abc import Callable, Mapping
from typing import Any, Literal
import numpy as np
import pandas as pd
from sklearn.metrics import (
average_precision_score,
balanced_accuracy_score,
f1_score,
matthews_corrcoef,
precision_score,
recall_score,
roc_auc_score,
)
from ._utils import ArrayLike
# Accepted forms of the ``weights=`` argument: one of the three named modes,
# a mapping from batch label to weight, or an explicit array / list of
# per-batch weights. This is evaluated at import time (it is a runtime
# expression, not an annotation).
WeightSpec = (
Literal["uniform", "balanced", "size"] | Mapping[Any, float] | np.ndarray | list
)
# Public API: the per-batch metric functions plus the scorer factory.
# Underscore-prefixed helpers in this module are intentionally not exported.
__all__ = [
"batch_average_precision_score",
"batch_balanced_accuracy_score",
"batch_f1_score",
"batch_matthews_corrcoef",
"batch_precision_score",
"batch_recall_score",
"batch_roc_auc_score",
"make_batch_scorer",
]
def _resolve_weights(
    batch_levels: np.ndarray,
    sizes: np.ndarray,
    spec: WeightSpec,
) -> np.ndarray:
    """Resolve a weight specification to a non-negative vector summing to 1.

    Parameters
    ----------
    batch_levels
        1-D array of the batch labels being aggregated.
    sizes
        1-D array of per-batch sample counts, aligned with ``batch_levels``.
    spec
        ``'uniform'``, ``'balanced'``, ``'size'``, a mapping from batch label
        to weight, or an array-like of weights aligned with ``batch_levels``.

    Returns
    -------
    numpy.ndarray
        Float weights aligned with ``batch_levels`` that sum to 1.

    Raises
    ------
    ValueError
        If ``spec`` is an unknown mode string, a mapping that lacks one of
        the batch levels, an array that is not 1-D or has the wrong length,
        or if the resulting weights contain NaN / infinite / negative values
        or do not have a positive sum. ``'balanced'`` additionally requires
        every batch size to be positive.
    """
    if isinstance(spec, str):
        if spec == "uniform":
            w = np.ones_like(sizes, dtype=float)
        elif spec == "balanced":
            sizes_f = sizes.astype(float)
            if np.any(sizes_f <= 0):
                raise ValueError(
                    "weights='balanced' requires every batch to have at "
                    "least one sample; got sizes "
                    f"{sizes_f.tolist()}."
                )
            w = 1.0 / sizes_f
        elif spec == "size":
            w = sizes.astype(float)
        else:
            raise ValueError(
                f"Unknown weights mode {spec!r}. "
                f"Expected one of 'uniform', 'balanced', 'size', or a dict / array."
            )
    elif isinstance(spec, Mapping):
        missing = [lvl for lvl in batch_levels if lvl not in spec]
        if missing:
            raise ValueError(
                f"weights dict is missing entries for batch levels: {missing}."
            )
        w = np.asarray([float(spec[lvl]) for lvl in batch_levels], dtype=float)
    else:
        w = np.asarray(spec, dtype=float)
        # A scalar (0-d) or multi-dimensional spec has no meaningful
        # per-batch alignment; reject it before indexing ``w.shape[0]``.
        if w.ndim != 1:
            raise ValueError(
                f"weights array must be 1-D; got shape {w.shape}."
            )
        if w.shape[0] != batch_levels.shape[0]:
            raise ValueError(
                f"weights array length {w.shape[0]} does not match the number "
                f"of batch levels {batch_levels.shape[0]}."
            )
    # NaN compares False against 0 and inf / inf is NaN, so non-finite
    # entries would otherwise slip through and poison the aggregate.
    if not np.isfinite(w).all():
        raise ValueError("weights must be finite.")
    if (w < 0).any():
        raise ValueError("weights must be non-negative.")
    total = w.sum()
    if total <= 0:
        raise ValueError("at least one weight must be positive.")
    return w / total
def _to_1d_array(x: ArrayLike) -> np.ndarray:
    """Convert ``x`` to a numpy array, flattening pandas objects.

    A :class:`pandas.Series` or :class:`pandas.DataFrame` is converted and
    reshaped to one dimension. Any other input goes through
    :func:`numpy.asarray` and keeps whatever shape that yields.
    """
    converted = np.asarray(x)
    if isinstance(x, (pd.Series, pd.DataFrame)):
        converted = converted.reshape(-1)
    return converted
def _per_batch_metric(
    y_true: ArrayLike,
    y_pred_or_score: ArrayLike,
    batch: ArrayLike,
    metric_fn: Callable[..., float],
    *,
    weights: WeightSpec = "uniform",
    requires_two_classes: bool = False,
    **metric_kwargs: Any,
) -> float:
    """Apply ``metric_fn`` per batch and aggregate.

    Parameters
    ----------
    y_true, y_pred_or_score, batch
        Arrays of equal length. Either pandas Series or numpy arrays.
    metric_fn
        Sklearn-style ``(y_true, y_pred_or_score, **kwargs)`` -> scalar.
    weights
        See :data:`WeightSpec`. An explicit array / list must be aligned
        with ``np.unique(batch)`` (all batch levels of this call); entries
        belonging to dropped batches are discarded before renormalisation.
    requires_two_classes
        If True, batches that contain only one class in ``y_true`` are
        skipped (the metric is undefined). Skipped batches are reported via
        a ``UserWarning`` and the weights are renormalised.
    metric_kwargs
        Forwarded verbatim to ``metric_fn`` (e.g. ``average='macro'``,
        ``pos_label=0``).

    Returns
    -------
    float
        Weighted sum of the per-batch scores over the batches that produced
        a score, or NaN (with a warning) when no batch did.

    Raises
    ------
    ValueError
        If the three inputs differ in length, if an explicit weights array
        does not have one entry per batch level, or if the weights are
        rejected by ``_resolve_weights``.
    """
    y_true_arr = _to_1d_array(y_true)
    y_pred_arr = _to_1d_array(y_pred_or_score)
    batch_arr = _to_1d_array(batch)
    if not (len(y_true_arr) == len(y_pred_arr) == len(batch_arr)):
        raise ValueError(
            f"y_true, predictions, and batch must have the same length; got "
            f"{len(y_true_arr)}, {len(y_pred_arr)}, {len(batch_arr)}."
        )
    levels = np.unique(batch_arr)
    n_levels = len(levels)
    # NaN marks "no score for this batch"; such batches are dropped below.
    scores = np.full(n_levels, np.nan, dtype=float)
    sizes = np.zeros(n_levels, dtype=int)
    for i, lvl in enumerate(levels):
        mask = batch_arr == lvl
        sizes[i] = int(mask.sum())
        if sizes[i] < 2:
            continue
        if requires_two_classes and np.unique(y_true_arr[mask]).size < 2:
            continue
        try:
            scores[i] = float(
                metric_fn(y_true_arr[mask], y_pred_arr[mask], **metric_kwargs)
            )
        except ValueError:
            # Metric undefined for this fold (e.g. AUROC with one class).
            scores[i] = np.nan
    valid = ~np.isnan(scores)
    if not valid.any():
        warnings.warn(
            "All batches were degenerate for this metric (single-class folds "
            "or fewer than two samples). Returning NaN.",
            stacklevel=2,
        )
        return float("nan")
    n_dropped = int((~valid).sum())
    if n_dropped:
        warnings.warn(
            f"{n_dropped} batch(es) dropped from per-batch aggregation "
            f"(metric undefined on that fold). Renormalising weights across "
            f"the remaining {int(valid.sum())} batch(es).",
            stacklevel=2,
        )
    fold_weights = weights
    if not isinstance(weights, (str, Mapping)):
        # Explicit arrays are aligned with *all* batch levels. Slice them
        # down to the surviving batches; otherwise a correctly aligned
        # array triggers a spurious length mismatch whenever a batch is
        # dropped.
        full_weights = np.asarray(weights, dtype=float)
        if full_weights.ndim != 1 or full_weights.shape[0] != n_levels:
            raise ValueError(
                f"weights array of shape {full_weights.shape} does not match "
                f"the number of batch levels {n_levels}."
            )
        fold_weights = full_weights[valid]
    w = _resolve_weights(levels[valid], sizes[valid], fold_weights)
    return float(np.sum(w * scores[valid]))
[docs]
def batch_roc_auc_score(
    y_true: ArrayLike,
    y_score: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
    **kwargs: Any,
) -> float:
    """Compute AUROC inside each batch and reduce the results with ``weights``.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        Ground-truth binary labels.
    y_score : array-like of shape (n_samples,)
        Probability of the positive class (or any monotone score).
    batch : array-like of shape (n_samples,)
        Batch label of every sample.
    weights : str or mapping or array, default='uniform'
        Across-batch weighting; see the module docstring.
    **kwargs
        Passed through to :func:`sklearn.metrics.roc_auc_score`
        (``average``, ``multi_class``, ...).

    Returns
    -------
    float
        Weighted aggregate of the per-batch AUROC values.
    """
    # Single-class batches are skipped: AUROC is undefined on them.
    reducer_options = {"weights": weights, "requires_two_classes": True}
    return _per_batch_metric(
        y_true, y_score, batch, roc_auc_score, **reducer_options, **kwargs
    )
[docs]
def batch_average_precision_score(
    y_true: ArrayLike,
    y_score: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
    **kwargs: Any,
) -> float:
    """Compute average precision (PR-AUC) per batch and reduce with ``weights``.

    Extra keyword arguments are forwarded to
    :func:`sklearn.metrics.average_precision_score`. Batches whose
    ``y_true`` holds a single class are skipped.
    """
    reducer_options = {"weights": weights, "requires_two_classes": True}
    return _per_batch_metric(
        y_true,
        y_score,
        batch,
        average_precision_score,
        **reducer_options,
        **kwargs,
    )
[docs]
def batch_balanced_accuracy_score(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
    **kwargs: Any,
) -> float:
    """Compute balanced accuracy per batch and reduce with ``weights``.

    Extra keyword arguments are forwarded to
    :func:`sklearn.metrics.balanced_accuracy_score`. Batches whose
    ``y_true`` holds a single class are skipped.
    """
    reducer_options = {"weights": weights, "requires_two_classes": True}
    return _per_batch_metric(
        y_true,
        y_pred,
        batch,
        balanced_accuracy_score,
        **reducer_options,
        **kwargs,
    )
[docs]
def batch_matthews_corrcoef(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
) -> float:
    """Compute the Matthews correlation coefficient per batch, then aggregate.

    MCC lies in ``[-1, 1]`` and 0 corresponds to random prediction. Unlike
    AUROC, a batch with a constant ``y_pred`` or a single class present is
    **not** treated as degenerate here: sklearn reports MCC = 0 for it
    (with a warning), so the batch stays in the aggregate.
    """
    reducer_options = {"weights": weights, "requires_two_classes": False}
    return _per_batch_metric(
        y_true, y_pred, batch, matthews_corrcoef, **reducer_options
    )
[docs]
def batch_f1_score(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
    average: str | None = "binary",
    **kwargs: Any,
) -> float:
    """Compute the F1 score per batch and reduce with ``weights``.

    Parameters
    ----------
    average : str, default='binary'
        Handed to :func:`sklearn.metrics.f1_score`; it governs the
        averaging *inside* each batch (``'binary'``, ``'macro'``,
        ``'weighted'``, ``'micro'``). How batches are combined is decided
        by ``weights=`` alone.
    """
    forwarded = {"weights": weights, "average": average}
    return _per_batch_metric(
        y_true, y_pred, batch, f1_score, **forwarded, **kwargs
    )
[docs]
def batch_precision_score(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
    average: str | None = "binary",
    pos_label: int | str = 1,
    **kwargs: Any,
) -> float:
    """Compute precision per batch and reduce with ``weights``.

    Use ``pos_label=0`` for negative-class precision, or
    ``average='macro'`` / ``'micro'`` / ``'weighted'`` for an aggregated
    precision within each batch; see
    :func:`sklearn.metrics.precision_score`.
    """
    forwarded = {
        "weights": weights,
        "average": average,
        "pos_label": pos_label,
    }
    return _per_batch_metric(
        y_true, y_pred, batch, precision_score, **forwarded, **kwargs
    )
[docs]
def batch_recall_score(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    *,
    batch: ArrayLike,
    weights: WeightSpec = "uniform",
    average: str | None = "binary",
    pos_label: int | str = 1,
    **kwargs: Any,
) -> float:
    """Compute recall per batch and reduce with ``weights``.

    ``average``, ``pos_label`` and any extra keyword arguments are
    forwarded to :func:`sklearn.metrics.recall_score`.
    """
    forwarded = {
        "weights": weights,
        "average": average,
        "pos_label": pos_label,
    }
    return _per_batch_metric(
        y_true, y_pred, batch, recall_score, **forwarded, **kwargs
    )
# Alias table used by ``_resolve_metric`` / ``make_batch_scorer``.
# Each value is (metric_fn, default response_method): the batch-aware metric
# function and the estimator method whose output is fed to it when the caller
# does not pass ``response_method`` explicitly.
# "mcc" and "matthews_corrcoef" are two aliases for the same function.
_SCORER_REGISTRY: dict[str, tuple[Callable[..., float], str]] = {
"roc_auc": (batch_roc_auc_score, "predict_proba"),
"average_precision": (batch_average_precision_score, "predict_proba"),
"balanced_accuracy": (batch_balanced_accuracy_score, "predict"),
"matthews_corrcoef": (batch_matthews_corrcoef, "predict"),
"mcc": (batch_matthews_corrcoef, "predict"),
"f1": (batch_f1_score, "predict"),
"precision": (batch_precision_score, "predict"),
"recall": (batch_recall_score, "predict"),
}
def _resolve_metric(
    metric: str | Callable[..., float],
    response_method: str | None,
) -> tuple[Callable[..., float], str]:
    """Turn a metric alias or callable into ``(metric_fn, response_method)``.

    A callable is returned as is, but only when ``response_method`` is
    given. A string is looked up in ``_SCORER_REGISTRY``; an explicit,
    non-empty ``response_method`` overrides the registered default.

    Raises
    ------
    ValueError
        If a callable is passed without ``response_method``, or the alias
        is not registered.
    """
    if callable(metric):
        if response_method is not None:
            return metric, response_method
        raise ValueError(
            "When passing a callable metric, `response_method` must be set "
            "to 'predict_proba', 'predict', or 'decision_function'."
        )
    entry = _SCORER_REGISTRY.get(metric)
    if entry is None:
        raise ValueError(
            f"Unknown metric alias {metric!r}. Registered: {sorted(_SCORER_REGISTRY)}."
        )
    batch_metric, natural_response = entry
    chosen = response_method if response_method else natural_response
    return batch_metric, chosen
[docs]
def make_batch_scorer(
    batch: ArrayLike,
    metric: str | Callable[..., float] = "roc_auc",
    *,
    weights: WeightSpec = "uniform",
    response_method: str | None = None,
    greater_is_better: bool = True,
    pos_class_index: int = 1,
    **metric_kwargs: Any,
) -> Callable[..., float]:
    """Build a sklearn-compatible ``scorer(estimator, X, y)`` callable.

    Parameters
    ----------
    batch : array-like or pandas.Series
        Batch label of every sample in the **full** dataset. On each fold
        the scorer selects the matching entries through ``y.index`` (or
        ``X.index`` when ``y`` is not a pandas object), so a
        :class:`pandas.Series` indexed by the same sample IDs as ``X`` and
        ``y`` gives safe alignment. When neither ``y`` nor ``X`` is a
        pandas object, the first ``len(y)`` entries are used and a warning
        is emitted.
    metric : str or callable, default='roc_auc'
        A registered alias (``'roc_auc'``, ``'average_precision'``,
        ``'balanced_accuracy'``, ``'mcc'`` / ``'matthews_corrcoef'``,
        ``'f1'``, ``'precision'``, ``'recall'``) or a callable following
        the ``batch_*_score`` signature
        (``(y_true, y_pred_or_score, *, batch, weights, **kw)``).
    weights : str or mapping or array, default='uniform'
        Across-batch weighting (see module docstring).
    response_method : {'predict_proba', 'predict', 'decision_function'}, optional
        Estimator method used to obtain predictions. When omitted, the
        alias' registered default applies (``'predict_proba'`` for
        AUROC / AP, ``'predict'`` otherwise). Mandatory for a callable
        ``metric``.
    greater_is_better : bool, default=True
        When False the scorer returns the negated metric, so that larger
        is always better (sklearn convention).
    pos_class_index : int, default=1
        Column of ``predict_proba(...)`` taken as the positive-class score
        when ``response_method='predict_proba'`` and the output has at
        least two columns.
    **metric_kwargs
        Passed unchanged to the underlying metric (e.g.
        ``average='macro'``, ``pos_label=0``).

    Returns
    -------
    scorer : callable
        ``scorer(estimator, X, y)`` -> float, signed according to
        ``greater_is_better``. The captured batch Series, metric function
        and weights are exposed as ``_batch``, ``_metric`` and ``_weights``
        attributes.

    Raises
    ------
    ValueError
        If the resolved response method is not one of the three supported
        names. At scoring time, a ``ValueError`` is raised when the fold
        index cannot be found in ``batch``'s index.

    Examples
    --------
    >>> from sklearn.model_selection import GridSearchCV
    >>> from sklearn.pipeline import Pipeline
    >>> from sklearn.linear_model import LogisticRegression
    >>> from maldibatchkit import AutoCorrector
    >>> from maldibatchkit.metrics import make_batch_scorer
    >>>
    >>> scorer = make_batch_scorer(batch=batch, metric='roc_auc',
    ...                            weights='uniform')
    >>> grid = GridSearchCV(
    ...     Pipeline([
    ...         ('correct', AutoCorrector(batch=batch,
    ...                                   discrete_covariates=species)),
    ...         ('clf', LogisticRegression()),
    ...     ]),
    ...     param_grid={'correct__method': ['noop', 'combat-fortin', 'harmony']},
    ...     scoring=scorer,
    ... )  # doctest: +SKIP
    """
    metric_fn, response = _resolve_metric(metric, response_method)
    supported_responses = ("predict_proba", "predict", "decision_function")
    if response not in supported_responses:
        raise ValueError(
            f"response_method must be 'predict_proba', 'predict', or "
            f"'decision_function'; got {response!r}."
        )

    # Keep a Series so folds can be matched by index label.
    batch_lookup = (
        batch if isinstance(batch, pd.Series) else pd.Series(np.asarray(batch))
    )
    pandas_types = (pd.Series, pd.DataFrame)

    def _batch_for_fold(y: Any, X: Any) -> np.ndarray:
        # Prefer y's index, then X's; otherwise fall back to position.
        if isinstance(y, pandas_types):
            fold_index = y.index
        elif isinstance(X, pandas_types):
            fold_index = X.index
        else:
            n = len(np.asarray(y))
            warnings.warn(
                "`y` is not a pandas object; aligning `batch` positionally on "
                f"the first {n} entries of the captured batch Series. Pass `y` "
                "as a pandas Series indexed by sample IDs for safe CV alignment.",
                stacklevel=2,
            )
            return np.asarray(batch_lookup)[:n]
        try:
            selected = batch_lookup.loc[fold_index]
        except KeyError as exc:
            raise ValueError(
                "y's index does not align with `batch`'s index. Index both by "
                "the same sample IDs."
            ) from exc
        return selected.to_numpy()

    def _scorer(estimator: Any, X: Any, y: Any, sample_weight: Any = None) -> float:
        # ``sample_weight`` is accepted in the signature but not used.
        fold_batch = _batch_for_fold(y, X)
        output = getattr(estimator, response)(X)
        if response == "predict_proba":
            output = np.asarray(output)
            if output.ndim == 2 and output.shape[1] >= 2:
                # Keep only the positive-class column.
                output = output[:, pos_class_index]
            else:
                output = output.reshape(-1)
        value = float(
            metric_fn(
                np.asarray(y),
                np.asarray(output),
                batch=fold_batch,
                weights=weights,
                **metric_kwargs,
            )
        )
        if greater_is_better:
            return value
        return -value

    # Expose the captured configuration for introspection.
    _scorer._batch = batch_lookup  # type: ignore[attr-defined]
    _scorer._metric = metric_fn  # type: ignore[attr-defined]
    _scorer._weights = weights  # type: ignore[attr-defined]
    return _scorer