Source code for maldibatchkit.corrections.quality_weighted

r"""Quality-weighted ComBat.

This module implements a *weighted* empirical-Bayes variant of the
classical Johnson-et-al. (2007) ComBat algorithm. Samples with higher
acquisition quality (typically a signal-to-noise ratio, or any
user-supplied scalar score) contribute proportionally more to the
batch-level moment estimates, and the shrinkage prior on each batch is
scaled by the effective (weighted) batch size rather than the raw
sample count.

Mathematical formulation
------------------------
For feature :math:`g` in batch :math:`i`, let :math:`w_{ij} \\ge 0` be the
per-sample weight (weights are normalised so the weight vector sums to
the number of samples). Write :math:`\\bar w_i = \\sum_j w_{ij}` for the
effective batch size.

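*Grand mean and pooled variance* are estimated with weights. Unlike
classical ComBat, the pooled variance is taken around the weighted grand
mean rather than around the per-batch means, so it also absorbs the
between-batch spread: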

.. math::

    \\hat\\alpha_g &= \\frac{\\sum_{i,j} w_{ij} X_{ij,g}}{\\sum_{i,j} w_{ij}}

    \\hat\\sigma_g^2 &= \\frac{\\sum_{i,j} w_{ij} (X_{ij,g} - \\hat\\alpha_g)^2}
                              {\\sum_{i,j} w_{ij}}

*Per-batch L/S model.* After standardisation
:math:`Z_{ij,g} = (X_{ij,g} - \\hat\\alpha_g) / \\hat\\sigma_g`, the batch-
and-feature effects are estimated with weighted sums:

.. math::

    \\hat\\gamma_{i,g} &= \\frac{\\sum_j w_{ij} Z_{ij,g}}{\\bar w_i}

    \\hat\\delta_{i,g}^2 &= \\frac{\\sum_j w_{ij} (Z_{ij,g} - \\hat\\gamma_{i,g})^2}
                                {\\bar w_i - 1}

(``max(bar_w_i - 1, eps)`` is used to protect against tiny effective
batch sizes.)

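*Empirical-Bayes priors.* The Normal-Inverse-Gamma hyperparameters
are fit separately for each batch :math:`i`, by method of moments
across features: :math:`\\bar\\gamma_i, \\bar\\tau_i^2` are the mean and
variance of :math:`\\hat\\gamma_{i,g}` over :math:`g`, and the
inverse-gamma :math:`\\bar\\lambda_i, \\bar\\theta_i` are matched to the
mean and variance of :math:`\\hat\\delta_{i,g}^2` over :math:`g`,
exactly as in classical ComBat.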

*Posteriors.* The posterior mean and variance are then the usual
weighted combination of prior and data, but with the **effective** batch
size :math:`\\bar w_i` replacing :math:`n_i`:

.. math::

    \\gamma^*_{i,g} &= \\frac{\\bar w_i \\bar\\tau_i^2 \\hat\\gamma_{i,g}
                               + \\delta^{*2}_{i,g} \\bar\\gamma_i}
                             {\\bar w_i \\bar\\tau_i^2 + \\delta^{*2}_{i,g}}

    \\delta^{*2}_{i,g} &= \\frac{\\bar\\theta_i + 0.5 \\sum_j w_{ij}
                                 (Z_{ij,g} - \\gamma^*_{i,g})^2}
                               {\\bar w_i / 2 + \\bar\\lambda_i - 1}

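with :math:`\\bar\\gamma_i, \\bar\\tau_i^2, \\bar\\lambda_i, \\bar\\theta_i`
the prior hyperparameters fit as in Johnson et al. (2007). The
iterative scheme follows Johnson's original fixed-point updates; only
the summation weights differ. Iteration stops when the maximum
*absolute* (rather than relative) change in
:math:`\\gamma^*, \\delta^{*2}` drops below ``tol``, or after
``max_iter`` rounds.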

Rationale
---------
In MALDI-TOF AMR studies, not every spectrum is equally informative:
low-SNR acquisitions or spectra flagged by the lab QC are often kept in
the analysis to avoid discarding hard-won samples, but they should not
drive the per-batch correction. Down-weighting low-quality spectra in
the empirical-Bayes shrinkage keeps the posterior pulled towards the
estimate obtained from the high-quality spectra.

This corrector is **not** a replacement for more general covariate
modelling: like classical Johnson ComBat it does not preserve
biological signal encoded in a covariate. If you need that, either
residualise your covariate first or use
:class:`maldibatchkit.ComBat` with ``method='fortin'``.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd

from .._base import BaseBatchCorrector
from .._utils import ArrayLike, _subset

__all__ = ["QualityWeightedComBat"]


def _normalise_weights(
    w: npt.NDArray[np.float64], *, eps: float
) -> npt.NDArray[np.float64]:
    w = np.asarray(w, dtype=float)
    if np.any(w < 0):
        raise ValueError("quality weights must be non-negative.")
    if not np.any(w > 0):
        raise ValueError("at least one quality weight must be positive.")
    # Normalise so weights sum to the number of samples: equal weights
    # should recover the unweighted estimates exactly.
    w = w * (w.size / max(w.sum(), eps))
    return w


def _aprior(gamma_hat: np.ndarray) -> np.ndarray:
    """Method-of-moments hyperparameter for the gamma prior (per batch)."""
    m = gamma_hat.mean(axis=1)
    s2 = gamma_hat.var(axis=1, ddof=1)
    return (2 * s2 + m**2) / np.clip(s2, 1e-12, None)


def _bprior(gamma_hat: np.ndarray) -> np.ndarray:
    m = gamma_hat.mean(axis=1)
    s2 = gamma_hat.var(axis=1, ddof=1)
    return (m * s2 + m**3) / np.clip(s2, 1e-12, None)


def _postmean(
    g_hat: np.ndarray,
    g_bar: float,
    n_bar: float,
    t2: float,
    d_star: np.ndarray,
) -> np.ndarray:
    return (n_bar * t2 * g_hat + d_star * g_bar) / (n_bar * t2 + d_star)


def _postvar(
    sum_sq: np.ndarray,
    n_bar: float,
    a: float,
    b: float,
) -> np.ndarray:
    return (0.5 * sum_sq + b) / (n_bar / 2.0 + a - 1.0)


class QualityWeightedComBat(BaseBatchCorrector):
    r"""Weighted empirical-Bayes extension of Johnson-ComBat.

    Parameters
    ----------
    batch : array-like of shape (n_samples,)
        Batch labels.
    quality : array-like of shape (n_samples,)
        Non-negative per-sample quality scores (typically SNR). Higher
        values mean a sample should contribute more to per-batch
        estimation. Internally rescaled to sum to ``n_samples``.
    parametric : bool, default=True
        Use the parametric EB fixed-point iteration (as in Johnson 2007).
        ``False`` falls back to the data estimates :math:`\\hat\\gamma` /
        :math:`\\hat\\delta^2` directly - this is useful as a sanity check.
        The parametric priors are fit across features, so
        ``parametric=True`` needs at least two features.
    reference_batch : Any, optional
        Batch level to leave unchanged. Every other batch seen at fit
        time is mapped onto the reference batch's posterior location and
        scale, so corrected samples line up with the untouched reference
        samples. Batches unseen at fit time are passed through
        (numerically) unchanged.
    eps : float, default=1e-8
        Numerical jitter.
    max_iter : int, default=50
        Hard cap on the parametric fixed-point iterations.
    tol : float, default=1e-4
        Convergence tolerance on the max absolute change in
        :math:`\\gamma^*, \\delta^*`.

    Attributes
    ----------
    batch_levels_ : np.ndarray
        Batch levels observed at fit time.
    reference_batch_ : Any or None
        ``reference_batch`` as validated at fit time.
    grand_mean_ : np.ndarray
        Weighted feature-wise grand mean (:math:`\\hat\\alpha`).
    pooled_var_ : np.ndarray
        Weighted feature-wise pooled variance (:math:`\\hat\\sigma^2`).
    gamma_star_ : np.ndarray
        Posterior batch means, shape ``(n_batches, n_features)``.
    delta_star_ : np.ndarray
        Posterior batch variances, shape ``(n_batches, n_features)``.
    effective_batch_sizes_ : np.ndarray
        :math:`\\bar w_i` per batch (weighted sample counts).
    n_iter_ : int
        Iterations actually taken by the parametric EB loop (``0`` when
        ``parametric=False``).

    Examples
    --------
    >>> from maldibatchkit import QualityWeightedComBat
    >>> corrector = QualityWeightedComBat(batch=batches, quality=snr)  # doctest: +SKIP
    >>> X_corrected = corrector.fit_transform(X)  # doctest: +SKIP

    References
    ----------
    Johnson, W.E., Li, C. and Rabinovic, A. (2007). "Adjusting batch
    effects in microarray expression data using empirical Bayes
    methods." *Biostatistics* 8(1): 118-127. - the unweighted
    predecessor of the scheme implemented here.
    """

    def __init__(
        self,
        batch: ArrayLike,
        *,
        quality: ArrayLike,
        parametric: bool = True,
        reference_batch: Any | None = None,
        eps: float = 1e-8,
        max_iter: int = 50,
        tol: float = 1e-4,
    ) -> None:
        super().__init__(batch=batch)
        self.quality = quality
        self.parametric = parametric
        self.reference_batch = reference_batch
        self.eps = eps
        self.max_iter = max_iter
        self.tol = tol

    def _weights_at(self, idx: pd.Index) -> np.ndarray:
        """Return normalised quality weights for the samples in ``idx``.

        A single-column DataFrame is accepted and flattened. The weights
        are validated and rescaled by ``_normalise_weights``.

        Raises
        ------
        ValueError
            If ``quality`` has more than one column or is missing.
        """
        w = _subset(self.quality, idx)
        if isinstance(w, pd.DataFrame):
            if w.shape[1] != 1:
                raise ValueError("quality must be 1-D.")
            w = w.iloc[:, 0]
        if w is None:
            raise ValueError("quality weights must be provided.")
        w = np.asarray(w, dtype=float)
        return _normalise_weights(w, eps=self.eps)

    def _fit_impl(self, X_df: pd.DataFrame, batch: npt.NDArray[Any]) -> None:
        """Estimate the weighted moments and the per-batch EB posteriors.

        Sets ``batch_levels_``, ``reference_batch_``, ``grand_mean_``,
        ``pooled_var_``, ``effective_batch_sizes_``, ``gamma_star_``,
        ``delta_star_`` and ``n_iter_``.

        Raises
        ------
        ValueError
            If ``reference_batch`` is not a training level, if
            ``parametric=True`` with fewer than two features, or if a
            batch has (near-)zero total quality weight.
        """
        X = X_df.to_numpy(dtype=float)
        n_features = X.shape[1]
        w = self._weights_at(X_df.index)  # shape (n_samples,)

        levels = np.unique(batch)
        self.batch_levels_ = levels
        if self.reference_batch is not None and self.reference_batch not in levels:
            raise ValueError(
                f"reference_batch={self.reference_batch!r} not found among "
                f"training batch levels {list(levels)}."
            )
        self.reference_batch_ = self.reference_batch
        if self.parametric and n_features < 2:
            # The priors are moments *across features*. With a single
            # feature the ddof=1 variances are NaN and poison every
            # posterior, so fail loudly instead of returning all-NaN data.
            raise ValueError(
                "parametric=True requires at least 2 features to fit the "
                "empirical-Bayes priors; use parametric=False instead."
            )

        # Weighted grand mean and variance, both taken around the grand mean.
        total_w = max(float(w.sum()), self.eps)
        self.grand_mean_ = (w[:, None] * X).sum(axis=0) / total_w
        centred = X - self.grand_mean_
        self.pooled_var_ = (w[:, None] * centred**2).sum(axis=0) / total_w
        std = np.sqrt(self.pooled_var_ + self.eps)
        Z = centred / std

        # Row masks per level, built once and reused by the EB loop below.
        masks = [batch == lvl for lvl in levels]
        n_batches = len(levels)
        eff = np.zeros(n_batches)
        gamma_hat = np.zeros((n_batches, n_features))
        delta_hat2 = np.zeros((n_batches, n_features))
        for i, (lvl, mask) in enumerate(zip(levels, masks)):
            w_i = w[mask]
            w_bar = float(w_i.sum())
            eff[i] = w_bar
            if w_bar < self.eps:
                raise ValueError(
                    f"Effective batch size for {lvl!r} is ~0 "
                    f"(all quality weights are zero in this batch)."
                )
            Z_i = Z[mask]
            g_hat = (w_i[:, None] * Z_i).sum(axis=0) / w_bar
            gamma_hat[i] = g_hat
            resid = Z_i - g_hat
            # Protect the (w_bar - 1) denominator against tiny effective sizes.
            denom = max(w_bar - 1.0, self.eps)
            delta_hat2[i] = (w_i[:, None] * resid**2).sum(axis=0) / denom
        self.effective_batch_sizes_ = eff

        if not self.parametric:
            self.gamma_star_ = gamma_hat
            self.delta_star_ = delta_hat2
            self.n_iter_ = 0
            return

        # Per-batch priors, fit by method of moments across features.
        gamma_bar = gamma_hat.mean(axis=1)
        t2 = gamma_hat.var(axis=1, ddof=1)
        a_prior = _aprior(delta_hat2)
        b_prior = _bprior(delta_hat2)

        # Johnson's fixed-point iteration, with the effective batch size
        # and weighted residual sums in place of n_i and plain sums.
        gamma_star = gamma_hat.copy()
        delta_star = delta_hat2.copy()
        n_iter = 0
        for _ in range(self.max_iter):
            n_iter += 1
            prev_g = gamma_star.copy()
            prev_d = delta_star.copy()
            for i, mask in enumerate(masks):
                w_i = w[mask]
                g_new = _postmean(
                    gamma_hat[i], gamma_bar[i], eff[i], t2[i], delta_star[i]
                )
                resid = Z[mask] - g_new
                sum_sq = (w_i[:, None] * resid**2).sum(axis=0)
                d_new = _postvar(sum_sq, eff[i], a_prior[i], b_prior[i])
                gamma_star[i] = g_new
                delta_star[i] = np.clip(d_new, self.eps, None)
            change = max(
                np.max(np.abs(gamma_star - prev_g)),
                np.max(np.abs(delta_star - prev_d)),
            )
            if change < self.tol:
                break
        self.n_iter_ = n_iter
        self.gamma_star_ = gamma_star
        self.delta_star_ = delta_star

    def _transform_impl(
        self, X_df: pd.DataFrame, batch: npt.NDArray[Any]
    ) -> pd.DataFrame:
        """Apply the fitted location/scale correction batch by batch.

        Samples of the reference batch are copied unchanged. Samples of
        other batches seen at fit time are standardised, have their
        posterior batch effect removed, and are then placed onto the
        target distribution before being mapped back to the original
        scale. The target is the reference batch's posterior location and
        scale when a reference is set, and zero / unit otherwise.
        Batches unseen at fit time get a zero per-batch effect (identity).
        """
        X = X_df.to_numpy(dtype=float)
        std = np.sqrt(self.pooled_var_ + self.eps)
        Z = (X - self.grand_mean_) / std
        out = np.empty_like(X)
        level_to_idx = {lvl: i for i, lvl in enumerate(self.batch_levels_)}

        # Target distribution in standardised space. Aligning to the
        # reference batch keeps corrected batches consistent with the
        # reference samples, which are left untouched below.
        target_loc = 0.0
        target_scale = 1.0
        if self.reference_batch_ is not None:
            for i, lvl in enumerate(self.batch_levels_):
                if lvl == self.reference_batch_:
                    target_loc = self.gamma_star_[i]
                    target_scale = np.sqrt(
                        np.clip(self.delta_star_[i], self.eps, None)
                    )
                    break

        for lvl in np.unique(batch):
            mask = batch == lvl
            if lvl == self.reference_batch_:
                out[mask] = X[mask]
                continue
            i = level_to_idx.get(lvl)
            if i is None:
                # Unknown batch at transform time: correct using zero
                # per-batch effect (identity in standardised space).
                out[mask] = Z[mask] * std + self.grand_mean_
                continue
            scale = np.sqrt(np.clip(self.delta_star_[i], self.eps, None))
            z_corr = (Z[mask] - self.gamma_star_[i]) / scale
            z_aligned = z_corr * target_scale + target_loc
            out[mask] = z_aligned * std + self.grand_mean_
        return pd.DataFrame(out, index=X_df.index, columns=X_df.columns)