# Source code for maldibatchkit.viz.umap

"""Side-by-side UMAP before/after batch correction."""

from __future__ import annotations

from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

from .._utils import ArrayLike

# Public API of this module: only the plotting entry point is exported.
__all__ = ["plot_batch_umap"]


def _require_umap():
    """Import the optional ``umap`` module lazily and return it.

    Returns
    -------
    module
        The imported ``umap`` module.

    Raises
    ------
    ImportError
        If ``umap`` cannot be imported. The message names the ``viz``
        extras and the original import failure is chained as the cause.
    """
    try:
        import umap as umap_module
    except ImportError as err:
        message = (
            "plot_batch_umap requires `umap-learn`. Install the MaldiBatchKit "
            "viz extras with `pip install maldibatchkit[viz]`."
        )
        raise ImportError(message) from err
    else:
        return umap_module


def _embed(X: np.ndarray, random_state: int) -> np.ndarray:
    """Compute a two-component UMAP embedding of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Matrix handed to ``UMAP.fit_transform``.
    random_state : int
        Seed forwarded to the UMAP constructor.

    Returns
    -------
    np.ndarray
        The result of ``fit_transform`` coerced to an ndarray.
    """
    umap_cls = _require_umap().UMAP
    model = umap_cls(n_components=2, random_state=random_state)
    coords = model.fit_transform(X)
    return np.asarray(coords)


def _to_ndarray(X: ArrayLike) -> np.ndarray:
    """Coerce ``X`` to a NumPy array.

    Objects exposing a ``to_numpy`` method are converted through that
    method first; everything else goes straight to ``np.asarray``.
    """
    raw = X.to_numpy() if hasattr(X, "to_numpy") else X
    return np.asarray(raw)


def plot_batch_umap(
    before: ArrayLike,
    after: ArrayLike,
    batch: ArrayLike,
    *,
    color_by: str = "batch",
    species: ArrayLike | None = None,
    random_state: int = 42,
    pca_preprocess: int | None = 50,
    ax: tuple[Any, Any] | None = None,
) -> tuple[Any, Any]:
    """Plot UMAP embeddings of ``X`` before and after batch correction.

    Parameters
    ----------
    before : array-like of shape (n_samples, n_features)
        Feature matrix prior to correction.
    after : array-like of shape (n_samples, n_features)
        Feature matrix after correction. Must share ``before``'s shape.
    batch : array-like of shape (n_samples,)
        Batch labels.
    color_by : {'batch', 'species'}, default='batch'
        Which label drives colouring. ``'species'`` requires passing
        ``species`` explicitly.
    species : array-like of shape (n_samples,), optional
        Species labels. Required when ``color_by='species'``.
    random_state : int, default=42
        Seed for UMAP (also passed to the optional PCA step).
    pca_preprocess : int or None, default=50
        If not None, reduce to this many PCs before UMAP. Gives a large
        speed-up on typical MALDI-TOF matrices without hurting the plot
        qualitatively. The number of components is capped at
        ``min(n_samples, n_features)``; PCA is skipped when that cap is
        below 2.
    ax : tuple of matplotlib Axes, optional
        Two axes to draw on. If None, a new 1x2 figure is created.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure used (or the parent figure of the provided axes).
    axes : tuple of matplotlib.axes.Axes
        The two axes that were drawn on.

    Raises
    ------
    ValueError
        If ``color_by`` is not ``'batch'`` or ``'species'``; if
        ``color_by='species'`` is requested without ``species``; if
        ``before`` and ``after`` differ in shape or are not 2-D; or if the
        selected labels are not 1-D with one entry per sample.
    """
    # Validate the cheap keyword arguments first so misuse fails fast,
    # before any array conversion or (expensive) embedding work.
    if color_by not in ("batch", "species"):
        raise ValueError(
            f"color_by must be 'batch' or 'species'; got {color_by!r}."
        )
    if color_by == "species" and species is None:
        raise ValueError("color_by='species' requires the `species` argument.")

    before_arr = _to_ndarray(before).astype(float)
    after_arr = _to_ndarray(after).astype(float)
    if before_arr.shape != after_arr.shape:
        raise ValueError(
            f"before and after must share shape; got {before_arr.shape} vs {after_arr.shape}."
        )
    if before_arr.ndim != 2:
        raise ValueError(
            "before and after must be 2-D (n_samples, n_features); "
            f"got shape {before_arr.shape}."
        )

    labels = np.asarray(batch if color_by == "batch" else species)
    n_samples = before_arr.shape[0]
    # A label/sample mismatch would otherwise only surface as an opaque
    # boolean-mask IndexError after both UMAP fits have already run.
    if labels.ndim != 1 or labels.shape[0] != n_samples:
        raise ValueError(
            f"{color_by} labels must be 1-D with one entry per sample "
            f"({n_samples}); got shape {labels.shape}."
        )

    def _maybe_pca(X: np.ndarray) -> np.ndarray:
        # Optional PCA step; falls back to the raw matrix when disabled or
        # when fewer than two components are available.
        if pca_preprocess is None:
            return X
        n = min(pca_preprocess, X.shape[0], X.shape[1])
        if n < 2:
            return X
        return np.asarray(
            PCA(n_components=n, random_state=random_state).fit_transform(X)
        )

    b_emb = _embed(_maybe_pca(before_arr), random_state)
    a_emb = _embed(_maybe_pca(after_arr), random_state)

    if ax is None:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=False, sharey=False)
    else:
        axes = ax
        fig = axes[0].figure

    # The label levels are identical for both panels; compute them once.
    levels = np.unique(labels)
    for title, emb, ax_ in (("Before", b_emb, axes[0]), ("After", a_emb, axes[1])):
        for lvl in levels:
            mask = labels == lvl
            ax_.scatter(emb[mask, 0], emb[mask, 1], s=8, label=str(lvl), alpha=0.7)
        ax_.set_title(f"UMAP - {title}")
        ax_.set_xlabel("UMAP1")
        ax_.set_ylabel("UMAP2")
    # Both panels share the same colour mapping, so one legend suffices.
    axes[0].legend(title=color_by, loc="best", fontsize=7)

    # plt.subplots returns an ndarray; normalise list/ndarray to a tuple and
    # hand back a caller-supplied tuple unchanged.
    return fig, tuple(axes) if isinstance(axes, (list, np.ndarray)) else axes