# Source code for maldibatchkit.viz.peaks
"""Visualize per-batch peak position drift."""
from __future__ import annotations
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from .._utils import ArrayLike, _resolve_mz_axis
__all__ = ["plot_peak_shift"]
def _to_dataframe(X: ArrayLike) -> pd.DataFrame:
    """Return *X* as a :class:`pandas.DataFrame`.

    An input that already is a DataFrame is handed back as the very same
    object (no copy is made); anything else is first converted with
    ``np.asarray`` and then wrapped in a new DataFrame.
    """
    already_frame = isinstance(X, pd.DataFrame)
    return X if already_frame else pd.DataFrame(np.asarray(X))
# [docs]  (Sphinx viewcode back-link marker left by HTML extraction; not part of the module)
def plot_peak_shift(
    batches: ArrayLike,
    X: ArrayLike,
    reference: ArrayLike | None = None,
    *,
    mz_values: ArrayLike | None = None,
    ax: Any | None = None,
    max_batches: int | None = 6,
) -> tuple[Any, Any]:
    """Overlay per-batch median spectra against a reference spectrum.

    Parameters
    ----------
    batches : array-like of shape (n_samples,)
        Batch labels. The input is flattened, so a column vector of
        shape ``(n_samples, 1)`` is accepted as well.
    X : array-like of shape (n_samples, n_features)
        Binned intensities. Values are cast to float.
    reference : array-like of shape (n_features,), optional
        Reference spectrum. If None, the global median across all rows
        is used.
    mz_values : array-like of shape (n_features,), optional
        m/z coordinates for the x-axis. Defaults to column positions.
    ax : matplotlib Axes, optional
        Axis to draw on. If None, a new figure is created.
    max_batches : int or None, default=6
        Maximum number of batches to overlay. The distinct batch labels
        are sorted and only the first ``max_batches`` of them are drawn.
        Set to None to draw every batch (slow on many-batch studies).
        Must not be negative.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
        ``fig`` is the newly created figure, or ``ax.figure`` when an
        axis was supplied.

    Raises
    ------
    ValueError
        If ``max_batches`` is negative, or if ``reference`` does not
        have one value per column of ``X``.
    """
    # A negative value would otherwise slice as ``levels[:-k]`` and silently
    # drop the *last* k batches, which is never what "maximum" means.
    if max_batches is not None and max_batches < 0:
        raise ValueError(
            f"max_batches must be non-negative or None, got {max_batches!r}"
        )

    df = _to_dataframe(X).astype(float)
    # Flatten so that column-vector labels yield hashable scalars below.
    b = np.asarray(batches).ravel()
    # The helper returns the x-axis values together with a flag that is used
    # further down to choose between the "m/z" and "bin index" axis labels.
    mz, mz_is_real = _resolve_mz_axis(df, mz_values)

    if reference is None:
        ref = df.median(axis=0).to_numpy()
    else:
        ref = np.asarray(reference, dtype=float).ravel()
        n_features = df.shape[1]
        if ref.shape[0] != n_features:
            raise ValueError(
                f"reference has {ref.shape[0]} values but X has "
                f"{n_features} features"
            )

    levels = sorted(set(b.tolist()))
    if max_batches is not None:
        levels = levels[:max_batches]

    # Compute every median *before* touching matplotlib: if the batch labels
    # do not line up with the rows of X, pandas raises here and no orphan
    # figure is left registered with pyplot.
    medians = [df.loc[b == lvl].median(axis=0).to_numpy() for lvl in levels]

    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 4))
    else:
        fig = ax.figure

    # The reference is drawn on top (high zorder) so it stays visible.
    ax.plot(mz, ref, color="black", lw=1.3, label="reference", zorder=5)
    cmap = plt.get_cmap("tab10")
    for i, (lvl, med) in enumerate(zip(levels, medians)):
        # Colours are indexed modulo 10, so they repeat past ten batches.
        ax.plot(mz, med, lw=0.9, color=cmap(i % 10), label=str(lvl), alpha=0.85)

    ax.set_xlabel("m/z" if mz_is_real else "bin index")
    ax.set_ylabel("intensity (median per batch)")
    ax.legend(title="batch", fontsize=7, loc="best")
    ax.set_title("Per-batch peak-shape comparison")
    return fig, ax