Source code for maldibatchkit.viz.summary

"""Grid bar plot of before/after diagnostic metrics.

Different diagnostic metrics live on very different scales
(silhouette_batch ~ 0, LISI ~ 1-n_batches, kBET mean chi-square in the
tens to thousands...), so sharing a single y-axis makes the smaller
metrics unreadable.  We render one subplot per metric, each with its
own y-axis.
"""

from __future__ import annotations

from typing import Any, Iterable

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Public API of this module: only the plotting entry point is exported.
__all__ = ["plot_diagnostic_summary"]


# Bar fill colours used by plot_diagnostic_summary: grey for the "before"
# bars, blue for the "after" bars.
_BAR_BEFORE = "#b0b0b0"
_BAR_AFTER = "#2b8cbe"


def plot_diagnostic_summary(
    report_df: pd.DataFrame,
    *,
    scope: str | Iterable[str] = "overall",
    ncols: int | None = None,
    figsize_per_plot: tuple[float, float] = (3.2, 3.0),
    axes: Any | None = None,
) -> tuple[Any, np.ndarray]:
    """Plot before/after diagnostic values, one subplot per metric.

    Parameters
    ----------
    report_df : pd.DataFrame
        Output of :func:`maldibatchkit.diagnostics.diagnostic_report`.
        Must contain the columns ``metric``, ``scope``, ``value_before``
        and ``value_after``, with at most one row per (metric, scope)
        pair among the selected scopes.
    scope : str or iterable of str, default='overall'
        Which slice(s) of the report to plot.  Pass a single scope name
        (``"overall"``, ``"batch_00"``, ...) to render one pair of bars
        per metric, or a list (``["batch_00", "batch_01"]``) to group
        bars by scope inside each metric's subplot.  Repeated names are
        collapsed; the first occurrence fixes the bar order.
    ncols : int, optional
        Number of subplot columns.  Defaults to ``min(n_metrics, 4)``;
        values below 1 are clamped to 1.
    figsize_per_plot : (float, float), default=(3.2, 3.0)
        Width, height of each metric's subplot (inches).  The returned
        figure has size ``(ncols * w, nrows * h)``.  Ignored when
        ``axes`` is given.
    axes : array of matplotlib Axes, optional
        Pre-built axes grid.  Must have at least ``n_metrics`` entries
        when flattened.  The figure is taken from the first entry.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : np.ndarray of matplotlib.axes.Axes
        The first ``n_metrics`` entries of the flattened axes grid, one
        per metric in report order.  Unused slots from a non-rectangular
        grid are turned off.

    Raises
    ------
    ValueError
        If a required column is missing, ``scope`` is empty, no row
        matches the requested scope(s), a (metric, scope) pair occurs
        more than once, or ``axes`` has fewer slots than metrics.
    """
    required = {"metric", "scope", "value_before", "value_after"}
    missing = required - set(report_df.columns)
    if missing:
        raise ValueError(
            f"report_df is missing required columns: {sorted(missing)}."
        )

    # dict.fromkeys drops repeated scope names while keeping caller order,
    # so a scope listed twice does not produce two identical bar groups.
    scopes = [scope] if isinstance(scope, str) else list(dict.fromkeys(scope))
    if not scopes:
        raise ValueError("`scope` must name at least one scope.")

    df = report_df[report_df["scope"].isin(scopes)].copy()
    if df.empty:
        raise ValueError(
            f"No rows in report_df with scope in {scopes!r}. "
            f"Available scopes: {sorted(report_df['scope'].unique())}."
        )

    # Each (metric, scope) pair maps to exactly one bar pair.  Duplicates
    # would either broadcast silently into overlapping bars (single scope)
    # or fail deep inside matplotlib with a shape mismatch (several
    # scopes), so reject them here with an actionable message.
    dup_mask = df.duplicated(subset=["metric", "scope"], keep=False)
    if dup_mask.any():
        dup_pairs = list(
            df.loc[dup_mask, ["metric", "scope"]]
            .drop_duplicates()
            .itertuples(index=False, name=None)
        )
        raise ValueError(
            f"report_df has duplicate (metric, scope) rows: {dup_pairs}."
        )

    # Preserve the metric order as it appears in the report
    metrics: list[str] = list(dict.fromkeys(df["metric"].tolist()))
    n_metrics = len(metrics)

    if ncols is None:
        ncols = min(n_metrics, 4)
    ncols = max(1, ncols)
    nrows = int(np.ceil(n_metrics / ncols))

    if axes is None:
        fig, ax_grid = plt.subplots(
            nrows,
            ncols,
            figsize=(ncols * figsize_per_plot[0], nrows * figsize_per_plot[1]),
        )
        ax_flat = np.atleast_1d(ax_grid).ravel()
    else:
        ax_flat = np.atleast_1d(np.asarray(axes)).ravel()
        if ax_flat.size < n_metrics:
            raise ValueError(
                f"Provided axes has {ax_flat.size} slots but {n_metrics} "
                f"metrics need plotting."
            )
        fig = ax_flat[0].figure

    width = 0.38
    for i, metric in enumerate(metrics):
        ax = ax_flat[i]
        sub = df[df["metric"] == metric]

        # Keep the caller's scope order, restricted to scopes that have a
        # row for this metric.
        present = sub["scope"].unique()
        scope_order = [s for s in scopes if s in present]
        if not scope_order:
            ax.set_axis_off()
            continue

        values = (
            sub.set_index("scope")
            .loc[scope_order, ["value_before", "value_after"]]
            .astype(float)
        )
        n = len(scope_order)
        xs = np.arange(n)

        ax.bar(
            xs - width / 2,
            values["value_before"].to_numpy(),
            width,
            label="before",
            color=_BAR_BEFORE,
        )
        ax.bar(
            xs + width / 2,
            values["value_after"].to_numpy(),
            width,
            label="after",
            color=_BAR_AFTER,
        )

        if n > 1:
            ax.set_xticks(xs)
            ax.set_xticklabels(scope_order, rotation=30, ha="right")
        else:
            # A single scope is already named in the figure title.
            ax.set_xticks([])

        ax.set_title(metric, fontsize=10)
        ax.grid(axis="y", alpha=0.3)

        # Hairline guide at 0 for signed metrics
        if values.min().min() < 0:
            ax.axhline(0.0, color="k", lw=0.5)

        # One legend is enough; every subplot uses the same two colours.
        if i == 0:
            ax.legend(fontsize=8, loc="best")

    # Turn off any leftover subplot slots
    for leftover in ax_flat[n_metrics:]:
        leftover.set_axis_off()

    title_suffix = scopes[0] if len(scopes) == 1 else f"{len(scopes)} scopes"
    fig.suptitle(
        f"Batch-effect diagnostics ({title_suffix})",
        fontsize=12,
        y=1.02,
    )
    fig.tight_layout()
    return fig, ax_flat[:n_metrics]