# Source code for maldibatchkit.corrections.maldi

"""MALDI-specific batch corrections that leverage ``maldiamrkit`` internals."""

from __future__ import annotations

from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd

from .._base import BaseBatchCorrector
from .._utils import ArrayLike

# Names exported by ``from maldibatchkit.corrections.maldi import *``.
__all__ = [
    "BatchAwareWarping",
]


def _try_import_maldiamrkit():
    try:
        import maldiamrkit  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "This transformer requires 'maldiamrkit' (normally installed "
            "alongside MaldiBatchKit). Reinstall with "
            "`pip install -U maldibatchkit` or "
            "`pip install maldiamrkit>=0.13.0`."
        ) from exc
    return maldiamrkit


class BatchAwareWarping(BaseBatchCorrector):
    r"""Per-batch m/z warping to a single global reference.

    This transformer reuses :class:`maldiamrkit.alignment.Warping` under the
    hood, but fits one warper **per batch**. At ``fit`` time we learn a global
    reference spectrum from all training rows, then fit a dedicated
    ``Warping`` instance per batch with that shared reference. At
    ``transform`` time each sample is routed through the warper of its batch.

    Why this matters
    ----------------
    Mass-calibration drift is usually a per-acquisition-session phenomenon.
    Running a single global ``Warping`` can either be too aggressive
    (squashing real biological peaks in well-calibrated batches) or too
    conservative (leaving large shifts in a miscalibrated batch). Per-batch
    warpers share the reference but learn a batch-specific transformation.

    Parameters
    ----------
    batch : array-like of shape (n_samples,)
        Batch labels. The label ``"__global__"`` is reserved for the fallback
        warper and is rejected at ``fit`` time.
    reference : str, default="median"
        How to build the global reference. ``"median"`` uses the median
        spectrum across all training rows. ``"batch_mean"`` uses the mean of
        per-batch medians (reduces bias towards large batches).
    method : str, default="shift"
        Warping method passed through to ``maldiamrkit.alignment.Warping``.
    n_segments : int, default=5
        Piecewise warping segments.
    max_shift : int, default=50
        Maximum shift in bins.
    dtw_radius : int, default=10
        DTW radius constraint.
    smooth_sigma : float, default=2.0
        Gaussian smoothing for piecewise shifts.
    n_jobs : int, default=1
        Parallel jobs forwarded to ``Warping``.

    Attributes
    ----------
    reference\_ : np.ndarray
        The global reference spectrum.
    warpers\_ : dict[Any, Warping]
        One fitted ``Warping`` per batch level, plus the fallback warper
        stored under the ``"__global__"`` key.

    Raises
    ------
    ValueError
        At ``fit`` time, if ``reference`` is not ``"median"`` or
        ``"batch_mean"``, or if a batch label equals the reserved
        ``"__global__"`` key.

    Notes
    -----
    For batches that do not appear at ``fit`` time but show up during
    ``transform``, we fall back to a global warper trained on all training
    rows (stored as :attr:`warpers_` under the ``"__global__"`` key). This
    keeps the transformer defined on test folds without peeking at test data.

    Examples
    --------
    >>> from maldibatchkit import BatchAwareWarping
    >>> warper = BatchAwareWarping(batch=batches, method="piecewise")
    >>> X_aligned = warper.fit_transform(X)
    """

    #: Key in ``warpers_`` under which the all-rows fallback warper is stored.
    _GLOBAL_KEY = "__global__"

    def __init__(
        self,
        batch: ArrayLike,
        *,
        reference: str = "median",
        method: str = "shift",
        n_segments: int = 5,
        max_shift: int = 50,
        dtw_radius: int = 10,
        smooth_sigma: float = 2.0,
        n_jobs: int = 1,
    ) -> None:
        """Store the hyper-parameters verbatim; ``reference`` is validated in ``fit``."""
        super().__init__(batch=batch)
        self.reference = reference
        self.method = method
        self.n_segments = n_segments
        self.max_shift = max_shift
        self.dtw_radius = dtw_radius
        self.smooth_sigma = smooth_sigma
        self.n_jobs = n_jobs

    def _make_warper(self):
        """Return a fresh, unfitted ``maldiamrkit.alignment.Warping``.

        The warper is configured from this transformer's hyper-parameters.
        Its own ``reference`` stays at ``"median"``; the shared global
        reference is swapped in after fitting by ``_fit_warper_on``.
        """
        _try_import_maldiamrkit()
        from maldiamrkit.alignment import Warping  # type: ignore[import-not-found]

        return Warping(
            reference="median",
            method=self.method,
            n_segments=self.n_segments,
            max_shift=self.max_shift,
            dtw_radius=self.dtw_radius,
            smooth_sigma=self.smooth_sigma,
            n_jobs=self.n_jobs,
        )

    def _build_reference(
        self, X_df: pd.DataFrame, batch: npt.NDArray[Any]
    ) -> np.ndarray:
        """Compute the global reference spectrum from the training rows.

        Raises
        ------
        ValueError
            If ``self.reference`` is neither ``"median"`` nor ``"batch_mean"``.
        """
        if self.reference == "median":
            return X_df.median(axis=0).to_numpy()
        if self.reference == "batch_mean":
            # Take the median within each batch first, then average those
            # medians so every batch contributes equally regardless of size.
            per_batch = [
                X_df.loc[batch == lvl].median(axis=0).to_numpy()
                for lvl in np.unique(batch)
            ]
            return np.mean(per_batch, axis=0)
        raise ValueError(
            f"Unknown reference={self.reference!r}; expected 'median' or 'batch_mean'."
        )

    def _fit_warper_on(self, df: pd.DataFrame, reference: np.ndarray):
        """Fit a new warper on ``df``, then point it at ``reference``."""
        w = self._make_warper()
        w.fit(df)
        # Overwrite the reference learned from ``df`` with the shared global one.
        # NOTE(review): relies on ``Warping`` keeping its reference in
        # ``ref_spec_`` and reading it at transform time -- confirm against the
        # supported maldiamrkit versions.
        w.ref_spec_ = reference
        return w

    def _fit_impl(self, X_df: pd.DataFrame, batch: npt.NDArray[Any]) -> None:
        """Learn the global reference, one warper per batch, and a fallback.

        Raises
        ------
        ValueError
            If a batch label equals the reserved fallback key, or if
            ``reference`` is invalid.
        """
        levels = np.unique(batch)
        # A batch named like the fallback key would have its fitted warper
        # silently replaced by the global one below; fail loudly instead.
        if self._GLOBAL_KEY in levels.tolist():
            raise ValueError(
                f"Batch label {self._GLOBAL_KEY!r} is reserved for the global "
                "fallback warper; please rename that batch."
            )
        self.reference_ = self._build_reference(X_df, batch)
        warpers: dict[Any, Any] = {
            lvl: self._fit_warper_on(X_df.loc[batch == lvl], self.reference_)
            for lvl in levels
        }
        # Fallback used at transform time for batches unseen during fit.
        warpers[self._GLOBAL_KEY] = self._fit_warper_on(X_df, self.reference_)
        self.warpers_ = warpers

    def _transform_impl(
        self, X_df: pd.DataFrame, batch: npt.NDArray[Any]
    ) -> pd.DataFrame:
        """Warp each row with its batch's warper, or the fallback if unseen."""
        # ``astype`` already returns a new frame, so no separate ``copy()``.
        out = X_df.astype(float)
        fallback = self.warpers_[self._GLOBAL_KEY]
        for lvl in np.unique(batch):
            mask = batch == lvl
            warper = self.warpers_.get(lvl, fallback)
            out.loc[mask] = warper.transform(X_df.loc[mask]).to_numpy()
        return out