Diagnostics Module
Quantitative checks that a correction step actually removed batch structure. The subpackage exposes two families of metrics plus a convenience report helper:
- maldibatchkit.diagnostics.generic: classic batch-mixing metrics (silhouette by batch, kBET, LISI).
- maldibatchkit.diagnostics.maldi: MALDI-specific summaries (per-batch peak drift, TIC coefficient of variation, spectrum count).
- maldibatchkit.diagnostics.diagnostic_report(): run every metric on a (before, after) pair and collapse it into a tidy DataFrame suitable for downstream tables and plots.
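All metrics take the same (X, batch) positional signature, with options passed as keyword-only arguments, and return scalars, small dicts of scalars (as kbet() does), or tidy pandas objects.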
Warning
Batch-mixing metrics say nothing about whether biological signal is preserved. Always pair them with a supervised metric (AMR classifier AUROC, VME, …) in any real comparison of correctors.
Generic Batch-Mixing Metrics
- maldibatchkit.diagnostics.silhouette_batch(X, batch, *, metric='euclidean')
Silhouette coefficient using batch labels as clusters.
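Values close to 0 indicate good mixing; values close to 1 indicate strong batch separation; values close to -1 indicate batches sitting inside each other’s clusters (typically an artefact).
- Parameters:
  - X (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Feature matrix.
  - batch (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Batch labels, used as cluster assignments.
  - metric – Distance metric used to compute the silhouette (default 'euclidean').
- Returns:
  Silhouette coefficient, or 0.0 if there are fewer than two distinct batches (the silhouette is undefined in that case).
- Return type:
  float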
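To make the definition concrete, here is a minimal NumPy-only sketch of the silhouette-by-batch computation. It assumes Euclidean distances and the usual convention that a sample in a single-member batch scores 0; the name silhouette_batch_sketch is hypothetical, and this is not the library's implementation.

```python
import numpy as np


def silhouette_batch_sketch(X, batch):
    """Mean silhouette coefficient with batch labels used as clusters (Euclidean)."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(batch)
    uniq = np.unique(labels)
    if len(uniq) < 2:
        return 0.0  # silhouette is undefined for a single batch
    # Pairwise Euclidean distance matrix, shape (n, n).
    dist = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    scores = np.zeros(len(X))
    for i in range(len(X)):
        own = labels == labels[i]
        n_own = own.sum() - 1  # other members of the same batch
        if n_own == 0:
            continue  # singleton batch: score stays 0 by convention
        # Mean distance to the rest of the own batch (self-distance is 0).
        a = dist[i, own].sum() / n_own
        # Mean distance to the closest other batch.
        b = min(dist[i, labels == other].mean() for other in uniq if other != labels[i])
        denom = max(a, b)
        scores[i] = 0.0 if denom == 0 else (b - a) / denom
    return float(scores.mean())
```

On two well-separated batches the result approaches 1; when batch labels are unrelated to position it hovers around 0.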
- maldibatchkit.diagnostics.kbet(X, batch, *, k=None, alpha=0.05)
k-nearest-neighbours Batch Effect Test (kBET; Büttner et al. 2019).
For each sample we compute a chi-square statistic testing whether the batch composition of its k nearest neighbours matches the global batch frequencies. The reported statistics are the acceptance rate (the fraction of samples whose p-value exceeds alpha; higher is better) and the mean chi-square statistic (lower is better).
- Parameters:
  - X (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Feature matrix.
  - batch (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Batch labels.
  - k (int | None) – Number of nearest neighbours. Defaults to max(10, int(0.1 * n_samples)).
  - alpha (float) – Significance threshold used to compute the acceptance rate.
- Returns:
  {"acceptance_rate": float, "mean_chi2": float, "k": int}.
- Return type:
  dict
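The acceptance-rate computation described above can be sketched with NumPy and the standard library. The choices here are assumptions, not taken from the library: brute-force Euclidean neighbours that exclude the sample itself, expected counts of k times the global batch frequency, and n_batches - 1 degrees of freedom. The chi-square tail probability uses the standard closed forms for integer degrees of freedom, so no SciPy import is needed. kbet_sketch and chi2_sf are hypothetical names.

```python
import math

import numpy as np


def chi2_sf(x: float, df: int) -> float:
    """P(chi2_df > x) for integer df >= 1, via closed forms (no SciPy)."""
    if x <= 0:
        return 1.0
    if df % 2 == 0:
        # Even df: exp(-x/2) * sum_{j < df/2} (x/2)^j / j!
        term = total = 1.0
        for j in range(1, df // 2):
            term *= (x / 2) / j
            total += term
        return math.exp(-x / 2) * total
    # Odd df: normal upper tail plus a finite correction series.
    chi = math.sqrt(x)
    term, series = chi, 0.0
    for r in range(1, (df - 1) // 2 + 1):
        if r > 1:
            term *= x / (2 * r - 1)  # chi^(2r-1) / (1*3*...*(2r-1))
        series += term
    return math.erfc(chi / math.sqrt(2)) + math.sqrt(2 / math.pi) * math.exp(-x / 2) * series


def kbet_sketch(X, batch, k=None, alpha=0.05):
    X = np.asarray(X, dtype=float)
    _, codes = np.unique(np.asarray(batch), return_inverse=True)
    codes = codes.ravel()
    n, n_batches = len(X), int(codes.max()) + 1
    if k is None:
        k = max(10, int(0.1 * n))
    k = min(k, n - 1)
    # Expected neighbour counts: k times the global batch frequency.
    expected = k * np.bincount(codes, minlength=n_batches) / n
    dist = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a sample is not its own neighbour
    neighbours = np.argsort(dist, axis=1)[:, :k]
    chi2_stats, p_values = [], []
    for i in range(n):
        observed = np.bincount(codes[neighbours[i]], minlength=n_batches)
        stat = float(((observed - expected) ** 2 / expected).sum())
        chi2_stats.append(stat)
        p_values.append(chi2_sf(stat, max(n_batches - 1, 1)))
    return {
        "acceptance_rate": float(np.mean(np.array(p_values) > alpha)),
        "mean_chi2": float(np.mean(chi2_stats)),
        "k": k,
    }
```

With two fully segregated batches every neighbourhood is pure, so each test rejects and the acceptance rate drops to 0; with randomly assigned labels most tests accept.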
- maldibatchkit.diagnostics.lisi(X, batch, *, perplexity=30.0)
Local Inverse Simpson’s Index for batch mixing.
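LISI is the effective number of batches represented in each sample’s local neighbourhood (Gaussian-kernel weighted to the requested perplexity). The returned value is the median LISI across samples. It lies in [1, n_batches]; values close to n_batches indicate strong mixing, values close to 1 indicate batch-segregated neighbourhoods.
- Parameters:
  - X (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Feature matrix.
  - batch (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Batch labels.
  - perplexity (float) – Target perplexity of the Gaussian kernel that weights each neighbourhood.
- Returns:
  Median LISI across samples.
- Return type:
  float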
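The neighbourhood weighting can be sketched in the style of t-SNE: for each sample, a Gaussian bandwidth is found by bisection so that the perplexity of the kernel weights over its nearest neighbours matches the target, and LISI is the inverse Simpson’s index of the per-batch weight totals. This is a hypothetical NumPy sketch; the neighbourhood size (3 × perplexity) and the use of squared Euclidean distances are assumptions, not documented behaviour of lisi().

```python
import numpy as np


def lisi_sketch(X, batch, perplexity=30.0, tol=1e-5, max_iter=100):
    X = np.asarray(X, dtype=float)
    _, codes = np.unique(np.asarray(batch), return_inverse=True)
    codes = codes.ravel()
    n, n_batches = len(X), int(codes.max()) + 1
    k = min(int(3 * perplexity), n - 1)  # neighbourhood size (assumption)
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # exclude the sample itself
    neighbours = np.argsort(d2, axis=1)[:, :k]
    target = np.log(perplexity)  # perplexity = exp(entropy)
    scores = np.empty(n)
    for i in range(n):
        d = d2[i, neighbours[i]]
        d = d - d.min()  # shift for stability; normalised weights are unchanged
        beta, lo, hi = 1.0, 0.0, np.inf
        for _ in range(max_iter):
            p = np.exp(-beta * d)
            p /= p.sum()
            entropy = -np.sum(p * np.log(np.maximum(p, 1e-300)))
            if abs(entropy - target) < tol:
                break
            if entropy > target:  # weights too flat: sharpen the kernel
                lo = beta
                beta = beta * 2 if np.isinf(hi) else (beta + hi) / 2
            else:  # weights too peaked: widen the kernel
                hi = beta
                beta = (beta + lo) / 2
        # Total kernel weight falling on each batch.
        batch_weight = np.bincount(codes[neighbours[i]], weights=p, minlength=n_batches)
        scores[i] = 1.0 / np.sum(batch_weight ** 2)  # inverse Simpson's index
    return float(np.median(scores))
```

Fully segregated batches give exactly 1, and randomly assigned labels push the value towards n_batches.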
MALDI-Specific Metrics
- maldibatchkit.diagnostics.peak_position_drift(X, batch, *, mz_values=None, top_k=50)
Per-batch peak-position drift relative to a global reference.
We first identify the top_k peaks of the global median spectrum. For each batch we then compute its median spectrum and locate the local maximum of that median nearest to each global peak. The returned table summarises the distribution of (signed) position shifts per batch.
- Parameters:
  - X (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Binned intensities.
  - batch (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Batch labels.
  - mz_values (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]] | None) – m/z values for the columns of X. If None, integer column positions are used, and the returned *_delta_mz columns are in bin units.
  - top_k (int) – Number of global peaks to track.
- Returns:
  One row per batch, with columns mean_delta_mz, median_delta_mz, max_abs_delta_mz.
- Return type:
  DataFrame
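The matching step can be sketched with NumPy alone. This hypothetical peak_drift_sketch treats a peak as a bin strictly higher than its left neighbour and at least as high as its right one, and returns a plain dict keyed by batch instead of the documented DataFrame; both choices are assumptions made to keep the example dependency-free.

```python
import numpy as np


def local_maxima(y):
    """Indices i with y[i-1] < y[i] >= y[i+1] (endpoints are never peaks)."""
    y = np.asarray(y, dtype=float)
    inner = (y[1:-1] > y[:-2]) & (y[1:-1] >= y[2:])
    return np.flatnonzero(inner) + 1


def peak_drift_sketch(X, batch, mz_values=None, top_k=50):
    X = np.asarray(X, dtype=float)
    batch = np.asarray(batch)
    # Fall back to bin positions, so deltas are in bin units when mz_values is None.
    mz = np.arange(X.shape[1], dtype=float) if mz_values is None else np.asarray(mz_values, dtype=float)
    global_median = np.median(X, axis=0)
    peaks = local_maxima(global_median)
    # Keep the top_k most intense global peaks.
    peaks = peaks[np.argsort(global_median[peaks])[::-1][:top_k]]
    summary = {}
    for b in dict.fromkeys(batch.tolist()):  # unique labels, first-seen order
        local = local_maxima(np.median(X[batch == b], axis=0))
        if local.size == 0 or peaks.size == 0:
            continue  # nothing to match for this batch
        # For each global peak, the nearest local maximum of this batch's median.
        nearest = local[np.abs(local[None, :] - peaks[:, None]).argmin(axis=1)]
        delta = mz[nearest] - mz[peaks]  # signed shift per tracked peak
        summary[b] = {
            "mean_delta_mz": float(delta.mean()),
            "median_delta_mz": float(np.median(delta)),
            "max_abs_delta_mz": float(np.abs(delta).max()),
        }
    return summary
```

With two equally sized batches whose peaks differ by two bins, the global median places each peak halfway between them, so one batch reports a shift of -1 bin and the other +1 bin.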
Combined Report
- maldibatchkit.diagnostics.diagnostic_report(before, after, batch, *, mz_values=None, k=None, lisi_perplexity=30.0, top_k_peaks=50)
Run every diagnostic on a (before, after) pair.
- Parameters:
  - before (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Feature matrix prior to batch correction.
  - after (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Feature matrix after batch correction. Must have the same shape as before.
  - batch (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Batch labels.
  - mz_values (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]] | None) – m/z positions for the feature columns (passed to peak_position_drift()).
  - k (int | None) – Number of nearest neighbours for kBET (passed to kbet()).
  - lisi_perplexity (float) – Perplexity for LISI.
  - top_k_peaks (int) – Number of peaks tracked for drift.
- Returns:
  Tidy report with columns metric, scope, value_before, value_after (and delta where both values are numeric and the metric’s improvement direction is well-defined).
- Return type:
  DataFrame
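How the delta column is populated can be illustrated with a small, dependency-free sketch: it pairs before and after values per (metric, scope) and fills delta (after minus before) only when both values are numbers and the metric has a known improvement direction. The helper name tidy_rows and the DIRECTIONS table are hypothetical; the metric names and the better field mirror the example output further down, but this is not the library’s implementation.

```python
from numbers import Real

# Hypothetical direction table; only metrics listed here receive a delta.
DIRECTIONS = {"kbet_acceptance": "higher", "lisi": "higher"}


def tidy_rows(before, after, directions=DIRECTIONS):
    """before/after map (metric, scope) -> value; returns one dict per report row."""
    rows = []
    for (metric, scope), value_before in before.items():
        value_after = after.get((metric, scope))
        row = {"metric": metric, "scope": scope,
               "value_before": value_before, "value_after": value_after}
        numeric = all(isinstance(v, Real) and not isinstance(v, bool)
                      for v in (value_before, value_after))
        if numeric and metric in directions:
            row["delta"] = value_after - value_before  # after minus before
            row["better"] = directions[metric]
        rows.append(row)
    return rows
```

Passing the resulting list to pandas.DataFrame would give a table with the same columns as the report.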
Benchmark
BatchCorrectionBenchmark runs a
fixed set of metrics across multiple correctors under a single
protocol, returning tidy per-(method, metric) summaries plus the raw
long-form observations. See Choosing a corrector for the recipe.
- class maldibatchkit.diagnostics.BatchCorrectionBenchmark(correctors, *, metrics=('kbet', 'lisi_normalized', 'species_preservation'), protocol='full_data', test_size=0.2, n_repeats=1, n_bootstrap=0, bootstrap_mode='resample_metric', random_state=None)
Bases: object
Diagnostic comparison of multiple batch correctors.
- Parameters:
  - correctors (Mapping[str, BaseBatchCorrector]) – Mapping from a display name to an unfitted BaseBatchCorrector. Each is cloned per protocol iteration, so calling fit on the benchmark does not leave the input correctors fitted.
  - metrics (Sequence[Any]) – Metric specifications. Strings are resolved against the registry in maldibatchkit.diagnostics; callables are invoked as metric(X_corrected, batch, species=species, **extra) and receive only the keyword arguments they actually accept.
  - protocol (str) – 'full_data' fits each corrector on all rows and scores on the same rows (the Büttner et al. 2019 convention). 'stratified_split' fits on a stratified train split and scores on the held-out test split; every batch is forced into both folds.
  - test_size (float) – Test fraction for 'stratified_split'.
  - n_repeats (int) – Number of repeated splits for 'stratified_split'.
  - n_bootstrap (int) – If non-zero, every metric is recomputed on n_bootstrap stratified row-resamples of the (corrected) matrix to give confidence intervals. 0 disables bootstrapping.
  - bootstrap_mode (str) – 'resample_metric' (the fast, default mode) fits each corrector once and resamples rows of the corrected matrix to score the metric repeatedly; the CIs then reflect only the metric’s sampling noise. 'refit' resamples rows of X and refits the corrector for every bootstrap iteration; this is slower (n_correctors × n_bootstrap extra fits per repeat), but the CIs also capture corrector stability.
  - random_state (int | Generator | None) – Seed or generator for splits and bootstrap.
- Variables:
  - results_long_ (pd.DataFrame) – Tidy raw observations with columns method, metric, repeat, bootstrap, value. bootstrap == -1 marks the point-estimate row (no bootstrap resampling).
  - results_ (pd.DataFrame) – Per-(method, metric) summary: value (mean), ci_lo / ci_hi (2.5 / 97.5 percentiles), std (sample standard deviation), n (number of observations), and better ('higher', 'lower', 'zero' for metrics where both signs are bad, such as silhouette_batch, or 'n/a' for user callables; annotate via _METRIC_DIRECTION for registered names).
  - corrected_ (dict[str, pd.DataFrame]) – For protocol='full_data', the fitted-and-transformed matrix from each corrector. For protocol='stratified_split', the corrected test matrix from the last repeat (provided for downstream inspection and plotting; use results_long_ for per-repeat statistics).
  - baseline_ (pd.DataFrame) – One-row-per-metric report on the uncorrected X, mirroring the results_ schema.
Examples
>>> from maldibatchkit import ComBat, NoOpCorrector
>>> from maldibatchkit.diagnostics import BatchCorrectionBenchmark
>>> bench = BatchCorrectionBenchmark(
...     correctors={
...         "none": NoOpCorrector(batch=b),
...         "combat-fortin": ComBat(batch=b, method="fortin"),
...     },
...     metrics=("kbet", "species_preservation"),
...     n_bootstrap=200,
...     random_state=0,
... )
>>> bench.fit(X, batch=b, species=s)
>>> bench.rank(by="species_preservation")
- __init__(correctors, *, metrics=('kbet', 'lisi_normalized', 'species_preservation'), protocol='full_data', test_size=0.2, n_repeats=1, n_bootstrap=0, bootstrap_mode='resample_metric', random_state=None)
- fit(X, *, batch, species=None, y=None, **extra)
Run every corrector under the chosen protocol and score each metric.
- Parameters:
  - X (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Feature matrix.
  - batch (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]]) – Batch labels.
  - species (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]] | None) – Forwarded to metrics that need it (e.g. species_preservation).
  - y (DataFrame | Series | ndarray[tuple[Any, ...], dtype[Any]] | None) – Ignored at the benchmark level (no classifier scoring). Kept in the signature so the call site mirrors scikit-learn.
  - **extra (Any) – Forwarded to every metric callable that accepts the given keyword (e.g. mz_values= for peak_position_drift).
- Returns:
  self – Fitted benchmark with results_long_, results_, corrected_, and baseline_ populated.
- Return type:
  BatchCorrectionBenchmark
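The 'resample_metric' bootstrap can be sketched as follows: rows are resampled with replacement within each batch (so every batch keeps its size), the metric is rescored on each resample, and the 2.5 / 97.5 percentiles give the CI. The function name and arguments are hypothetical, per-batch resampling is one reading of "stratified", and any metric(X, batch) -> float callable can be plugged in; the library’s own implementation may differ.

```python
import numpy as np


def stratified_bootstrap_ci(metric, X, batch, n_bootstrap=200, seed=0):
    """Point estimate plus 2.5/97.5 percentile CI from per-batch row resamples."""
    X = np.asarray(X, dtype=float)
    batch = np.asarray(batch)
    rng = np.random.default_rng(seed)
    groups = [np.flatnonzero(batch == b) for b in np.unique(batch)]
    point = metric(X, batch)  # analogous to the bootstrap == -1 row
    values = []
    for _ in range(n_bootstrap):
        # Resample rows with replacement inside each batch, keeping batch sizes fixed.
        idx = np.concatenate([rng.choice(g, size=g.size, replace=True) for g in groups])
        values.append(metric(X[idx], batch[idx]))
    values = np.asarray(values)
    return {
        "point": float(point),
        "value": float(values.mean()),  # mean over resamples only (an assumption)
        "ci_lo": float(np.percentile(values, 2.5)),
        "ci_hi": float(np.percentile(values, 97.5)),
        "std": float(values.std(ddof=1)),  # sample standard deviation
        "n": int(values.size),
    }
```

Because the corrected matrix is scored repeatedly without refitting, this interval reflects only the metric’s sampling noise, which matches the trade-off described for bootstrap_mode.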
- rank(by, *, ascending=None)
Return results_ sorted by one metric’s mean value.
- Parameters:
  - by (str) – Metric name to rank on.
  - ascending (bool | None) – Sort direction. If omitted, the metric’s registered “better” direction is used: 'higher' → descending, 'lower' → ascending, 'zero' → ascending by |value| (e.g. silhouette_batch, where both positive and negative extremes are bad and only 0 is well-mixed).
- Return type:
  DataFrame
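The default ordering applied when ascending is omitted can be sketched on a plain list of summary rows: 'higher' sorts descending, 'lower' ascending, and 'zero' ascending by absolute value. The row layout and the name rank_rows are hypothetical, and raising an error for 'n/a' is an assumption; with pandas the same effect could come from DataFrame.sort_values, using its key argument for the 'zero' case.

```python
def rank_rows(rows, by, direction):
    """rows: dicts with 'method', 'metric', 'value'; returns rows for metric `by`, best first."""
    selected = [r for r in rows if r["metric"] == by]
    if direction == "higher":
        return sorted(selected, key=lambda r: r["value"], reverse=True)
    if direction == "lower":
        return sorted(selected, key=lambda r: r["value"])
    if direction == "zero":
        # Both signs are bad: the value closest to 0 ranks first.
        return sorted(selected, key=lambda r: abs(r["value"]))
    # Assumption: no default order for 'n/a' (e.g. user callables).
    raise ValueError(f"no default direction for {by!r}; pass ascending explicitly")
```

Ranking by absolute value is what keeps a strongly negative silhouette from being mistaken for the best-mixed result.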
Example
from maldibatchkit import SpeciesAwareComBat
from maldibatchkit.diagnostics import diagnostic_report
from maldibatchkit.viz import plot_diagnostic_summary
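# Assumes X (samples × m/z bins), per-sample batch and species labels,
# and mz (the m/z value of each column) are already defined.
# Correct and summarise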
corrector = SpeciesAwareComBat(batch=batch, species=species)
X_corrected = corrector.fit_transform(X)
report = diagnostic_report(
X, X_corrected, batch,
mz_values=mz, top_k_peaks=40,
)
print(report.head())
# metric scope value_before value_after delta better
# 0 silhouette_batch overall 0.311 0.042 -0.269 lower
# 1 kbet_acceptance overall 0.124 0.561 0.437 higher
# 2 lisi overall 1.420 2.310 0.890 higher
# Quick bar-chart visualisation of the overall metrics
plot_diagnostic_summary(report, scope="overall")