Source code for maldibatchkit.corrections.harmony

r"""Harmony integration (Korsunsky et al. 2019) via ``harmonypy``.

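Harmony is iterative and, in its published form, has no closed-form
"transform on new data" step.  Its fitted state does, however, contain
everything such a step needs:

* L2-normalised cluster centroids ``Y`` (``d × K``, where ``d`` is the
  number of PCA components; see *Built-in PCA preprocessing* below),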
* a soft-assignment bandwidth ``sigma`` (``K,``),
* a per-cluster ridge-regression batch correction ``W_batch``
  (``K × B × d``) that was applied to the training rows during the
  last ``moe_correct_ridge`` call.
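The last item, ``W_batch``, is the solution of a per-cluster weighted
ridge regression of the embedding on Harmony's design matrix
``Phi_moe``.  That design matrix has an intercept row first, followed by
one indicator row per batch level.  The NumPy sketch below is a minimal
illustration that mirrors ``Harmony._extract_W_batch``.  It assumes the
ridge penalty is given as a per-row vector, with 0 for the intercept.
The helper ``ridge_w_batch`` and the toy data are hypothetical and are
not part of MaldiBatchKit or ``harmonypy``.

```python
import numpy as np


def ridge_w_batch(Z, Phi_moe, R, lamb, n_batches):
    # Hypothetical helper. Z: (d, N) embedding; Phi_moe: (1 + levels, N)
    # design with the intercept row first; R: (K, N) soft assignments;
    # lamb: (1 + levels,) per-row ridge penalties (0 for the intercept).
    K, d = R.shape[0], Z.shape[0]
    W_batch = np.zeros((K, n_batches, d))
    for k in range(K):
        Phi_Rk = Phi_moe * R[k]                    # weight each sample by r_ik
        A = Phi_Rk @ Phi_moe.T + np.diag(lamb)     # weighted Gram + ridge
        W_k = np.linalg.solve(A, Phi_Rk @ Z.T)     # (1 + levels, d)
        W_batch[k] = W_k[1:1 + n_batches]          # drop intercept, keep batch rows
    return W_batch


# Toy data: 2 batches, 2 clusters, 3-dimensional embedding.
rng = np.random.default_rng(0)
N, d, K = 40, 3, 2
batch = np.repeat([0, 1], N // 2)
Phi_moe = np.vstack([np.ones(N), batch == 0, batch == 1]).astype(float)
Z = rng.normal(size=(d, N)) + 2.0 * (batch == 1)  # batch 1 shifted by +2
R = rng.dirichlet(np.ones(K), size=N).T            # (K, N), columns sum to 1
W = ridge_w_batch(Z, Phi_moe, R, np.array([0.0, 1.0, 1.0]), n_batches=2)
print(W.shape)  # (2, 2, 3)
```

The intercept is unpenalised and both batch rows share the same penalty.
As a result, the two batch coefficients in each cluster come out equal
and opposite, and the shifted batch receives the positive one.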

Given those, the correction for a *new* sample reduces to a closed
form:

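.. math::

    \hat{x}_i &= x_i / \lVert x_i \rVert_2

    r_{ik} &= \mathrm{softmax}_k\left(-\lVert \hat{x}_i - Y_k
              \rVert_2^2 / \sigma_k\right)

    x_i^{\text{corr}} &= x_i - \sum_k r_{ik}\, W_{\text{batch}}^{k,\, b(i)}

where :math:`x_i` is sample :math:`i` projected into the frozen PCA
space and :math:`b(i)` is its batch index.  Because :math:`\hat{x}_i`
and :math:`Y_k` both have unit norm, the squared distance equals
:math:`2(1 - \hat{x}_i^\top Y_k)`.  That is the form
``Harmony._softmax_assignment`` evaluates.  This mirrors what Harmony
does to training rows at its last iteration, but with :math:`Y` and
:math:`W_{\text{batch}}` frozen at fit time, so test rows cannot leak
into the correction.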
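The three equations can be sketched in NumPy as follows.  This is a
minimal illustration that mirrors ``Harmony._softmax_assignment`` and
``Harmony._transform_impl``, and it assumes ``X`` is already in the
frozen PCA space.  The function ``closed_form_correct`` and the random
inputs are hypothetical.

```python
import numpy as np


def closed_form_correct(X, batch_idx, Y, sigma, W_batch):
    # Hypothetical helper. X: (N, d) rows in PCA space; Y: (d, K)
    # unit-norm centroids; sigma: (K,) bandwidths; W_batch: (K, B, d).
    X_hat = X / np.clip(np.linalg.norm(X, axis=1, keepdims=True), 1e-12, None)
    # For unit vectors, ||x_hat - Y_k||^2 == 2 * (1 - x_hat . Y_k).
    dist = 2.0 * (1.0 - X_hat @ Y)                  # (N, K)
    logits = -dist / sigma[None, :]
    logits -= logits.max(axis=1, keepdims=True)     # stabilise the softmax
    R = np.exp(logits)
    R /= R.sum(axis=1, keepdims=True)               # r_ik, rows sum to 1
    W_sel = W_batch[:, batch_idx, :]                # (K, N, d): W^{k, b(i)}
    return X - np.einsum("nk,knd->nd", R, W_sel), R


rng = np.random.default_rng(0)
N, d, K, B = 5, 4, 3, 2
Y = rng.normal(size=(d, K))
Y /= np.linalg.norm(Y, axis=0, keepdims=True)        # unit-norm columns
sigma = np.full(K, 0.1)
W_batch = rng.normal(size=(K, B, d))
X = rng.normal(size=(N, d))
X_corr, R = closed_form_correct(X, np.array([0, 1, 1, 0, 1]), Y, sigma, W_batch)
```

Only frozen quantities (``Y``, ``sigma``, ``W_batch``) enter the
computation.  Each new row is therefore corrected independently of every
other row passed to ``transform``.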

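For the common case where ``fit`` and ``transform`` receive the same
rows (``fit_transform`` or simply re-``transform``), we short-circuit
to ``harmonypy``'s own iterative output.  That output is mapped back to
the original feature space through the same frozen PCA (and scaler), so
users see the upstream correction rather than the closed-form
approximation.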

Built-in PCA preprocessing
--------------------------
Harmony was designed for low-dimensional dense embeddings (~20-50 PCA
components).  Applying it directly to a 6000-bin MALDI-TOF intensity
matrix degrades the soft-cluster assignment, because cosine distances
in the raw feature space are dominated by a handful of high-intensity
bins.

To sidestep that, :class:`Harmony` **always** fits a ``PCA`` on the
training rows at ``fit`` time, after a ``StandardScaler`` unless
``scale_before_pca=False``.  It then runs Harmony in PCA space and maps
the corrected result back to the original feature space via
``PCA.inverse_transform``, followed by
``StandardScaler.inverse_transform`` when scaling is enabled.

The scaler and PCA are **frozen** at ``fit`` (like :math:`Y` and
:math:`W_{\text{batch}}`) and re-used verbatim on unseen rows, so the
closed-form transform stays leakage-safe by construction.  Because the
output is reconstructed from the retained components, any variation
outside them is discarded.  Use the ``n_components`` argument to tune
the size of the PCA basis; ``None`` (the default) picks
``min(50, n_samples - 1, n_features)``.
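The frozen-preprocessing pattern can be sketched with scikit-learn as
follows.  It assumes ``scale_before_pca=True``, and a random gamma
matrix stands in for a MALDI-TOF intensity matrix.  The helpers
``to_pca`` and ``to_features`` are hypothetical names, not MaldiBatchKit
API.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_train = rng.gamma(2.0, size=(30, 200))  # stand-in (samples x bins) matrix
X_test = rng.gamma(2.0, size=(5, 200))

# Default rule: min(50, n_samples - 1, n_features).
n_pcs = min(50, X_train.shape[0] - 1, X_train.shape[1])  # -> 29

# Fit the scaler and PCA on training rows only, then keep them frozen.
scaler = StandardScaler().fit(X_train)
pca = PCA(n_components=n_pcs, random_state=0).fit(scaler.transform(X_train))


def to_pca(X):
    # Hypothetical helper: original space -> frozen PCA space.
    return pca.transform(scaler.transform(X))


def to_features(Z):
    # Hypothetical helper: PCA space -> original measurement scale.
    return scaler.inverse_transform(pca.inverse_transform(Z))


Z_test = to_pca(X_test)       # (5, 29); test rows never touch the fit
X_back = to_features(Z_test)  # (5, 200), back on the original scale
```

With ``n_components = n_samples - 1``, the centred training rows lie
entirely in the retained subspace, so their round trip is exact.  Test
rows lose whatever variance falls outside it.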
"""

from __future__ import annotations

import contextlib
import logging
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from .._base import BaseBatchCorrector
from .._utils import ArrayLike

_PCA_DEFAULT_N_COMPONENTS = 50

__all__ = ["Harmony"]


@contextlib.contextmanager
def _silence_harmonypy_logger():
    """Temporarily raise the ``harmonypy`` logger above INFO.

    ``harmonypy`` configures its own logger at module-import time with a
    DEBUG level and a StreamHandler attached, so setting ``verbose=False``
    on ``run_harmony`` alone does not suppress its progress messages.
    This context manager elevates the logger to WARNING for the duration
    of the call and restores the original state afterwards, leaving the
    user's broader logging configuration untouched.
    """
    logger = logging.getLogger("harmonypy")
    prev_level = logger.level
    prev_propagate = logger.propagate
    logger.setLevel(logging.WARNING)
    logger.propagate = False
    try:
        yield
    finally:
        logger.setLevel(prev_level)
        logger.propagate = prev_propagate


class Harmony(BaseBatchCorrector):
    """Sklearn-compatible Harmony with a closed-form ``transform``.

    Parameters
    ----------
    batch : array-like of shape (n_samples,)
        Batch labels.
    covariates : array-like, optional
        Additional categorical covariates forwarded to
        ``harmonypy.run_harmony`` via the ``vars_use`` mechanism. Harmony
        treats every ``vars_use`` variable as a nuisance variable, so
        covariate effects are also removed from the training rows. The
        closed-form ``transform`` on unseen rows removes the batch effect
        only.
    theta : float or list of float, default=2.0
        Diversity clustering penalty in Harmony. Higher values encourage
        more aggressive batch mixing.
    max_iter : int, default=20
        Harmony iteration cap (forwarded as ``max_iter_harmony``).
    nclust : int, optional
        Number of Harmony clusters. When ``None``, ``harmonypy``'s default
        heuristic is used.
    n_components : int, optional
        Number of PCA components used for the built-in
        dimensionality-reduction stage. ``None`` (the default) picks
        ``min(50, n_samples - 1, n_features)``. Explicit values are capped
        at ``min(n_samples - 1, n_features)``.
    random_state : int, optional
        Seed for the PCA stage. It is also used to seed NumPy's global
        random state immediately before ``harmonypy`` runs; it is not
        passed to ``run_harmony`` itself.
    verbose : bool, default=True
        If ``False``, suppress ``harmonypy``'s own progress messages during
        ``fit``. ``harmonypy`` configures its module-level logger eagerly at
        import time, so passing ``verbose=False`` to ``run_harmony`` is not
        enough on its own. This flag also raises the ``harmonypy`` logger
        above INFO for the duration of the call. The user's broader logging
        configuration is left untouched.
    scale_before_pca : bool, default=True
        If ``True`` (the default), fit a
        :class:`~sklearn.preprocessing.StandardScaler` on the training rows
        and apply it before the internal PCA. This gives every feature equal
        weight in the principal components. It also matches the
        single-cell genomics convention (Seurat's ``ScaleData`` / Scanpy's
        ``pp.scale`` followed by PCA) in which Harmony was originally
        developed. Without scaling, a handful of high-intensity MALDI-TOF
        peaks dominate the PCs, and Harmony's soft-cluster geometry
        inherits that bias.

        The scaler is **frozen** at ``fit``, like the PCA, the centroids
        ``Y`` and the per-batch ridge ``W_batch``. It is reused verbatim on
        unseen rows at ``transform`` time, and the corrected output is
        un-scaled back to the original measurement scale. Set to ``False``
        to recover the pre-v0.2 behaviour (centred but not scaled before
        PCA).

    Attributes
    ----------
    batch_levels_ : np.ndarray
        Batch levels in the order of the batch-indicator rows of
        ``harmonypy``'s ``Phi_moe`` design matrix.
    Y_ : np.ndarray of shape (n_components, K)
        L2-normalised cluster centroids frozen at fit time (in PCA space).
    sigma_ : np.ndarray of shape (K,)
        Soft-assignment bandwidths per cluster.
    W_batch_ : np.ndarray of shape (K, B, n_components)
        Per-cluster, per-batch ridge-regression correction vectors, in the
        same PCA space as ``Y_``.
    scaler_ : sklearn.preprocessing.StandardScaler or None
        Fitted scaler applied before the PCA, or ``None`` when
        ``scale_before_pca=False``.
    pca_ : sklearn.decomposition.PCA
        Fitted PCA, re-used verbatim on unseen rows at transform time.
    X_fit_ : np.ndarray
        Training feature matrix (original space) cached for same-matrix
        short-circuiting.
    n_clusters_ : int
        Number of clusters actually used by Harmony.
    n_iter_ : int
        Harmony iterations observed at fit time.

    Notes
    -----
    * The closed-form output on ``X_train`` does not exactly reproduce
      ``harmonypy``'s iterative correction. During training, Harmony's
      soft assignment is refined by a batch-diversity penalty (the
      ``update_R`` loop). We deliberately use the raw cosine-softmax
      assignment at transform time, because the diversity correction is a
      *training-time regulariser* tied to the training batch distribution
      and would not apply to a held-out sample. When ``transform`` is
      called on the exact training matrix, we short-circuit to
      ``harmonypy``'s own corrected embedding, mapped back through the
      frozen PCA (and scaler). ``fit_transform`` therefore reflects the
      upstream correction rather than the closed-form approximation.
    * Unseen batch levels at transform raise ``ValueError``, mirroring
      ``ComBat`` / ``Limma``.
    * Requires ``harmonypy``, a core dependency of MaldiBatchKit that is
      range-pinned to ``>=0.2.0,<2``. The 2.x line is a C++ rewrite that
      drops private attributes our closed-form transform relies on
      (including ``_Phi_moe``, ``_lamb`` and ``sigma``). We will revisit
      this constraint once those are re-exposed upstream.

    Examples
    --------
    >>> from maldibatchkit import Harmony
    >>> corrector = Harmony(batch=batches, random_state=0)
    >>> X_train_c = corrector.fit_transform(X_train)
    >>> X_test_c = corrector.transform(X_test)  # closed-form on new rows
    """

    def __init__(
        self,
        batch: ArrayLike,
        *,
        covariates: ArrayLike | None = None,
        theta: float | list[float] = 2.0,
        max_iter: int = 20,
        nclust: int | None = None,
        n_components: int | None = None,
        random_state: int | None = None,
        verbose: bool = True,
        scale_before_pca: bool = True,
    ) -> None:
        super().__init__(batch=batch)
        self.covariates = covariates
        self.theta = theta
        self.max_iter = max_iter
        self.nclust = nclust
        self.n_components = n_components
        self.random_state = random_state
        self.verbose = verbose
        self.scale_before_pca = scale_before_pca

    def _resolve_n_components(self, n_samples: int, n_features: int) -> int:
        """Pick a safe number of PCs for a given (n_samples, n_features)."""
        # Centred data has rank at most min(n_samples - 1, n_features);
        # never request fewer than one component.
        cap = max(1, min(n_samples - 1, n_features))
        if self.n_components is None:
            return min(_PCA_DEFAULT_N_COMPONENTS, cap)
        return int(min(self.n_components, cap))

    @staticmethod
    def _require_harmonypy():
        try:
            import harmonypy
        except ImportError as exc:
            raise ImportError(
                "Harmony requires the 'harmonypy' package (>=0.2.0,<2). "
                "It is installed automatically with `pip install maldibatchkit`; "
                "to install it manually, run: pip install 'harmonypy>=0.2.0,<2'."
            ) from exc
        return harmonypy

    def _build_meta(
        self, idx: pd.Index, batch: npt.NDArray[Any]
    ) -> tuple[pd.DataFrame, list[str]]:
        """Construct the harmonypy ``meta_data`` DataFrame and ``vars_use`` list.

        ``"batch"`` is always the first entry of ``vars_use``;
        ``_extract_W_batch`` relies on that ordering.
        """
        meta = pd.DataFrame(
            {"batch": np.asarray(batch, dtype=object).astype(str)},
            index=idx,
        )
        vars_use = ["batch"]
        if self.covariates is not None:
            cov = self.covariates
            if isinstance(cov, pd.Series):
                cov = cov.to_frame()
            if not isinstance(cov, pd.DataFrame):
                cov = pd.DataFrame(np.asarray(cov), index=idx)
            # Align covariates to the fitted rows and treat them as
            # categorical labels, like the batch column.
            cov_aligned = cov.loc[idx].copy().astype(str)
            cov_aligned.columns = [f"cov_{i}" for i in range(cov_aligned.shape[1])]
            meta = pd.concat([meta, cov_aligned], axis=1)
            vars_use.extend(list(cov_aligned.columns))
        return meta, vars_use

    @staticmethod
    def _l2_normalise_rows(X: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        """Scale each row to unit L2 norm (norms below ``eps`` are clipped)."""
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        norms = np.clip(norms, eps, None)
        return X / norms

    def _softmax_assignment(self, X: np.ndarray) -> np.ndarray:
        """Compute the frozen-centroid soft assignment R (N, K)."""
        Z_cos = self._l2_normalise_rows(X)
        # Squared Euclidean distance between unit vectors:
        # ||z - y||^2 = 2 * (1 - z . y), as the columns of Y_ are L2-normalised.
        dist = 2.0 * (1.0 - Z_cos @ self.Y_)  # (N, K)
        scaled = -dist / self.sigma_[None, :]
        scaled -= scaled.max(axis=1, keepdims=True)  # numerical stability
        R = np.exp(scaled)
        R /= np.clip(R.sum(axis=1, keepdims=True), 1e-12, None)
        return R

    @staticmethod
    def _extract_W_batch(harmony_out, batch_levels: np.ndarray) -> np.ndarray:
        """Recompute the final-iteration per-cluster, per-batch ridge ``W``.

        Harmony's ``Phi_moe`` has shape ``(1 + ΣL_v, N)``, where ``L_v`` is
        the number of levels of each variable in ``vars_use``. Row 0 is the
        intercept. The batch indicator occupies **the first** level block
        (rows ``1 : 1 + n_batches``), because ``_build_meta`` always places
        ``"batch"`` first in ``vars_use``.

        Additional covariates contribute further level blocks.
        ``harmonypy`` regresses those out of the training rows as well, but
        the closed-form path reapplies only the batch block to unseen rows,
        i.e. it removes batch only.

        Returns an array of shape ``(K, n_batches, d)`` such that
        ``W_batch[k, b, :]`` is the per-cluster, per-batch coefficient
        subtracted for training samples that belong to batch ``b``.
        """
        Z_orig = np.asarray(harmony_out._Z_orig.cpu().numpy(), dtype=float)  # (d, N)
        Phi_moe = np.asarray(
            harmony_out._Phi_moe.cpu().numpy(), dtype=float
        )  # (rows, N)
        R = np.asarray(harmony_out._R.cpu().numpy(), dtype=float)  # (K, N)
        lamb = np.asarray(harmony_out._lamb.cpu().numpy(), dtype=float)
        # The ridge penalty may be stored as a per-row vector or as the
        # corresponding diagonal matrix (an upstream implementation
        # detail); accept either form.
        lamb_mat = lamb if lamb.ndim == 2 else np.diag(lamb)  # (rows, rows)

        K = R.shape[0]
        d = Z_orig.shape[0]
        n_batches = int(len(batch_levels))
        W_batch = np.zeros((K, n_batches, d), dtype=float)
        batch_slice = slice(1, 1 + n_batches)  # skip intercept, take batch block
        for k in range(K):
            Phi_Rk = Phi_moe * R[k]  # (rows, N): design weighted by cluster k
            cov = Phi_Rk @ Phi_moe.T + lamb_mat  # (rows, rows)
            rhs = Phi_Rk @ Z_orig.T  # (rows, d)
            # Ridge coefficients; the intercept row is never used because
            # only the batch block is kept below.
            W_k = np.linalg.solve(cov, rhs)  # (rows, d)
            W_batch[k] = W_k[batch_slice, :]
        return W_batch

    def _extract_batch_levels(
        self, harmony_out, batch: npt.NDArray[Any]
    ) -> np.ndarray:
        """Extract the batch-level ordering of ``Phi_moe``'s batch rows.

        ``harmonypy`` uses ``pd.factorize(meta[vars_use[0]])`` under the
        hood, so the order is the first-occurrence order in the meta
        DataFrame we provided. ``harmony_out`` is currently not inspected.
        """
        # dict.fromkeys keeps insertion order, i.e. first occurrence.
        seen = dict.fromkeys(str(b) for b in batch.tolist())
        return np.asarray(list(seen), dtype=object)

    def _to_feature_space(self, Z_work: np.ndarray) -> np.ndarray:
        """Map PCA-space rows back through the frozen PCA (and scaler)."""
        Z_scaled = self.pca_.inverse_transform(Z_work)
        if self.scaler_ is not None:
            return self.scaler_.inverse_transform(Z_scaled)
        return Z_scaled

    def _fit_impl(self, X_df: pd.DataFrame, batch: npt.NDArray[Any]) -> None:
        harmonypy = self._require_harmonypy()
        if self.random_state is not None:
            # Seeds NumPy's *global* RNG; random_state is not passed to
            # run_harmony itself.
            np.random.seed(self.random_state)

        meta, vars_use = self._build_meta(X_df.index, batch)
        X_orig = X_df.to_numpy(dtype=float)

        # Optional StandardScaler before PCA: gives every feature equal
        # weight in the principal components, matching the single-cell
        # genomics "scale -> PCA" convention where Harmony was first
        # developed. The scaler is frozen at fit time and reused on
        # unseen rows at transform.
        if self.scale_before_pca:
            self.scaler_ = StandardScaler().fit(X_orig)
            X_scaled = self.scaler_.transform(X_orig)
        else:
            self.scaler_ = None
            X_scaled = X_orig

        # PCA preprocessing is mandatory and fit on training rows only.
        # It is then frozen and re-used verbatim by ``transform``, so no
        # information leaks from test rows back into the basis.
        n_pcs = self._resolve_n_components(X_scaled.shape[0], X_scaled.shape[1])
        self.pca_ = PCA(n_components=n_pcs, random_state=self.random_state)
        X_work = self.pca_.fit_transform(X_scaled)

        kwargs: dict[str, Any] = {
            "vars_use": vars_use,
            "theta": self.theta,
            "max_iter_harmony": self.max_iter,
            "verbose": self.verbose,
        }
        if self.nclust is not None:
            kwargs["nclust"] = self.nclust

        log_ctx = (
            contextlib.nullcontext() if self.verbose else _silence_harmonypy_logger()
        )
        with log_ctx:
            harmony_out = harmonypy.run_harmony(X_work, meta, **kwargs)

        self.batch_levels_ = self._extract_batch_levels(harmony_out, batch)

        # Y_ / W_batch_ live in PCA space. The closed-form transform
        # projects new rows into the same space before applying them.
        self.Y_ = np.asarray(harmony_out.Y, dtype=float)  # (n_components, K)
        self.sigma_ = np.asarray(harmony_out.sigma, dtype=float).reshape(-1)  # (K,)
        self.W_batch_ = self._extract_W_batch(
            harmony_out, self.batch_levels_
        )  # (K, B, n_components)
        self.n_clusters_ = int(self.Y_.shape[1])
        self.n_iter_ = int(len(getattr(harmony_out, "kmeans_rounds", []) or [0]))

        # Same-matrix fast-path cache in the original feature space: map
        # harmonypy's Z_corr back through the fitted PCA (and scaler) so
        # users who call ``transform(X_train)`` or ``fit_transform`` get
        # ``(n_samples, n_features)`` on the original measurement scale,
        # matching the closed-form transform path.
        Z_corr_work = np.asarray(harmony_out.Z_corr)  # (N, n_components)
        self.X_fit_ = X_orig
        self._corrected_cache_ = self._to_feature_space(Z_corr_work)  # (N, n_features)

    def _transform_impl(
        self, X_df: pd.DataFrame, batch: npt.NDArray[Any]
    ) -> pd.DataFrame:
        X_orig = X_df.to_numpy(dtype=float)

        # Fast path: the caller transforms the exact matrix we fit on, so
        # return the cached harmonypy correction unchanged.
        if X_orig.shape == self.X_fit_.shape and np.array_equal(X_orig, self.X_fit_):
            return pd.DataFrame(
                self._corrected_cache_,
                index=X_df.index,
                columns=X_df.columns,
            )

        # Closed-form correction on new rows --------------------------
        batch_str = np.asarray([str(b) for b in batch])
        known = {lvl: i for i, lvl in enumerate(self.batch_levels_)}
        unseen = sorted(set(batch_str) - set(known))
        if unseen:
            raise ValueError(
                f"Unseen batch level(s) at transform: {unseen}. "
                "Harmony's closed-form correction is only defined for "
                "batches that were present at fit time."
            )
        batch_idx = np.asarray([known[b] for b in batch_str], dtype=int)

        # Project into PCA space using the frozen scaler / PCA. Nothing
        # from ``X_orig`` is fed back into them, so the train/test split
        # stays clean.
        if self.scaler_ is not None:
            X_scaled = self.scaler_.transform(X_orig)
        else:
            X_scaled = X_orig
        X_work = self.pca_.transform(X_scaled)
        R_new = self._softmax_assignment(X_work)  # (N, K)

        # W_batch_ has shape (K, B, n_components); select each sample's
        # batch -> (N, K, n_components), weight by R_new and sum over K.
        W_selected = self.W_batch_[:, batch_idx, :].transpose(1, 0, 2)
        correction = np.einsum("nk,nkd->nd", R_new, W_selected)
        Z_corr_orig = self._to_feature_space(X_work - correction)
        return pd.DataFrame(Z_corr_orig, index=X_df.index, columns=X_df.columns)