Source code for maldibatchkit.corrections.limma

r"""Limma-style ``removeBatchEffect`` (Ritchie et al. 2015).

This is a faithful port of the linear-model batch subtraction performed by
``limma::removeBatchEffect``. Unlike ComBat, Limma performs no empirical-Bayes
shrinkage; it simply fits the batch indicators (plus any protected
covariates) by ordinary least squares and subtracts the fitted batch
contribution from the data.

Mathematically, for each feature we solve

.. math::

    X = \alpha + B \gamma + C \beta + \varepsilon

where :math:`B` holds sum-to-zero contrasts on the batch indicator,
:math:`C` holds the protected covariates (design-of-interest), and we return
:math:`X - B \hat\gamma`. The intercept and covariate terms are left
untouched so biological signal encoded in ``design`` is preserved.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import numpy.linalg as la
import numpy.typing as npt
import pandas as pd

from .._base import BaseBatchCorrector
from .._utils import ArrayLike, _subset

__all__ = ["Limma"]


def _sum_to_zero_contrasts(levels: np.ndarray) -> np.ndarray:
    """Return the ``(n_levels, n_levels - 1)`` sum-to-zero contrast matrix."""
    n = len(levels)
    if n < 2:
        return np.zeros((n, 0))
    contrasts = np.eye(n)[:, :-1]
    contrasts[-1, :] = -1.0
    return contrasts


def _design_matrix(
    labels: np.ndarray, levels: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Build a sum-to-zero batch design matrix aligned to ``labels``.

    Parameters
    ----------
    labels : np.ndarray
        Per-sample batch labels. May be empty, in which case the returned
        design has zero rows.
    levels : np.ndarray
        The known batch levels; the position of a level in this array
        selects its row in the contrast matrix.

    Returns
    -------
    design : np.ndarray
        One contrast row per entry of ``labels``, in the same order.
    contrasts : np.ndarray
        The contrast matrix produced by ``_sum_to_zero_contrasts(levels)``.

    Raises
    ------
    ValueError
        If ``labels`` contains a value that is not in ``levels``.
    """
    contrasts = _sum_to_zero_contrasts(levels)
    index_of = {lvl: i for i, lvl in enumerate(levels)}
    unknown = [lvl for lvl in np.unique(labels) if lvl not in index_of]
    if unknown:
        raise ValueError(
            f"Unseen batch level(s) at transform: {unknown}. "
            f"Limma's batch subtraction is undefined for batches not "
            f"present at fit time."
        )
    # Pin an integer dtype: an empty list would otherwise become a float64
    # array, which NumPy rejects as an index (IndexError) instead of
    # yielding a zero-row design.
    row_idx = np.asarray([index_of[lvl] for lvl in labels], dtype=np.intp)
    design = contrasts[row_idx]
    return design, contrasts


class Limma(BaseBatchCorrector):
    """Batch-effect removal by least-squares subtraction, after limma.

    An intercept, the optional protected covariates and a sum-to-zero coding
    of the batch labels are fitted jointly to every feature. Only the fitted
    batch term is subtracted when transforming.

    Parameters
    ----------
    batch : array-like of shape (n_samples,)
        Batch label of every sample.
    design : array-like of shape (n_samples, n_covariates), optional
        Protected covariates (limma's "design of interest"). They enter the
        fit but their contribution is never subtracted.
    eps : float, default=1e-8
        Value added to the diagonal of the normal equations before solving,
        which keeps the system solvable for rank-deficient designs.

    Attributes
    ----------
    batch_levels_ : np.ndarray
        Unique batch levels seen during ``fit`` (as returned by
        ``np.unique``, i.e. sorted).
    gamma_ : np.ndarray
        Fitted batch coefficients in the sum-to-zero basis, of shape
        ``(n_contrasts, n_features)``.
    contrasts_ : np.ndarray
        Sum-to-zero contrast matrix built during ``fit``.

    Notes
    -----
    * Transforming data whose batch labels include a level that was not
      present at fit time raises ``ValueError``; the subtraction is not
      defined for such a batch.
    * A constant column is prepended to the design internally so that the
      intercept is estimated alongside the other terms; it is not removed
      from the output.

    References
    ----------
    Ritchie, M.E. et al. (2015) "limma powers differential expression
    analyses for RNA-sequencing and microarray studies." Nucleic Acids
    Research 43(7): e47.

    Examples
    --------
    >>> from maldibatchkit import Limma
    >>> corrector = Limma(batch=batches, design=species_dummies)
    >>> X_corrected = corrector.fit_transform(X)
    """

    def __init__(
        self,
        batch: ArrayLike,
        *,
        design: ArrayLike | None = None,
        eps: float = 1e-8,
    ) -> None:
        """Store the batch labels, protected covariates and ridge term."""
        super().__init__(batch=batch)
        self.design = design
        self.eps = eps

    def _design_at(self, idx: pd.Index) -> np.ndarray | None:
        """Return the protected covariates for ``idx`` as a 2-D float array.

        Returns ``None`` when no ``design`` was supplied. A one-dimensional
        result is reshaped into a single column.
        """
        if self.design is None:
            return None
        rows = _subset(self.design, idx)
        if hasattr(rows, "to_numpy"):
            values = rows.to_numpy(dtype=float)
        else:
            values = np.asarray(rows, dtype=float)
        return values.reshape(-1, 1) if values.ndim == 1 else values

    def _fit_impl(self, X_df: pd.DataFrame, batch: npt.NDArray[Any]) -> None:
        """Estimate the batch coefficients and store them in ``gamma_``."""
        self.batch_levels_ = np.unique(batch)
        batch_cols, self.contrasts_ = _design_matrix(batch, self.batch_levels_)

        # Column layout: intercept, then protected covariates (if any),
        # then the batch contrasts.
        leading = [np.ones((X_df.shape[0], 1))]
        protected = self._design_at(X_df.index)
        if protected is not None:
            leading.append(protected)
        n_leading = sum(block.shape[1] for block in leading)
        design_full = np.hstack([*leading, batch_cols])

        # Solve the normal equations with eps added on the diagonal.
        normal = design_full.T @ design_full
        normal += self.eps * np.eye(normal.shape[0])
        targets = design_full.T @ X_df.to_numpy(dtype=float)
        coefficients = la.solve(normal, targets)

        # Everything after the leading columns belongs to the batch term.
        self.gamma_ = coefficients[n_leading:, :]

    def _transform_impl(
        self, X_df: pd.DataFrame, batch: npt.NDArray[Any]
    ) -> pd.DataFrame:
        """Subtract the fitted batch contribution from ``X_df``.

        The result keeps the index and columns of ``X_df``.
        """
        batch_cols = _design_matrix(batch, self.batch_levels_)[0]
        corrected = X_df.to_numpy(dtype=float) - batch_cols @ self.gamma_
        return pd.DataFrame(corrected, index=X_df.index, columns=X_df.columns)