Using MaldiBatchKit correctors in a scikit-learn workflow

Every MaldiBatchKit corrector subclasses scikit-learn's BaseEstimator and TransformerMixin, so it plugs into the standard scikit-learn fit / transform / Pipeline contract. This notebook shows the two patterns you need:

  1. ``fit`` on the training split, then ``transform`` both splits. Never call ``fit_transform`` on the full dataset before splitting.

  2. Drop the corrector into ``sklearn.pipeline.Pipeline`` so cross-validation refits it on each fold for you.

Uses the MALDI-Kleb-AI dataset (Rocchi et al. 2026, Zenodo DOI 10.5281/zenodo.17405072).
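The contract above can be sketched with a toy stand-in. `BatchMeanCenterer` below is hypothetical and is not a MaldiBatchKit class: it only subtracts per-batch means learned in `fit`, looks up each row's batch through `X.index` (mirroring how the correctors in this notebook are described as aligning `batch`), and runs on synthetic data. It shows both patterns: an explicit train/test split, and a `Pipeline` that `cross_val_score` clones and refits on every fold.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.pipeline import Pipeline


class BatchMeanCenterer(BaseEstimator, TransformerMixin):
    """Hypothetical toy corrector (NOT part of MaldiBatchKit).

    Subtracts each batch's per-feature mean, learned in fit(). `batch` covers
    ALL samples and is looked up by X.index, so no manual slicing is needed.
    """

    def __init__(self, batch=None):
        self.batch = batch

    def fit(self, X, y=None):
        batch_of_rows = self.batch.loc[X.index]
        # Parameters come only from the rows passed to fit().
        self.batch_means_ = X.groupby(batch_of_rows).mean()
        return self

    def transform(self, X):
        batch_of_rows = self.batch.loc[X.index]
        # A KeyError here means a batch was never seen during fit().
        means = self.batch_means_.loc[batch_of_rows.to_numpy()].to_numpy()
        return X - means


# Synthetic stand-in data: three batches with different offsets plus a label signal.
rng = np.random.default_rng(0)
n, p = 120, 5
ids = pd.Index([f"s{i}" for i in range(n)])
batch = pd.Series(np.repeat(["Rome", "Milan", "Catania"], n // 3), index=ids)
y = pd.Series(rng.integers(0, 2, size=n), index=ids)
offsets = batch.map({"Rome": 0.0, "Milan": 3.0, "Catania": -3.0}).to_numpy()[:, None]
X = pd.DataFrame(
    rng.normal(size=(n, p)) + offsets + y.to_numpy()[:, None],
    index=ids, columns=[f"f{j}" for j in range(p)],
)

# Pattern 1: fit on train, transform both splits with the same parameters.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0,
)
corrector = BatchMeanCenterer(batch=batch).fit(X_train)
X_test_c = corrector.transform(X_test)

# Pattern 2: inside a Pipeline, cross_val_score clones and refits the corrector per fold.
pipe = Pipeline([
    ("correct", BatchMeanCenterer(batch=batch)),
    ("clf", LogisticRegression()),
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv, scoring="roc_auc")
print(scores.round(3))
```

Because the toy class stores `batch` for every sample, cloning it per fold needs no extra bookkeeping; each fold's `fit` only ever sees that fold's training rows.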

1. Load the dataset

```python
import sys
import pathlib

# Put the parent of the working directory on sys.path so the notebooks._demo helper can be imported.
sys.path.insert(0, str(pathlib.Path.cwd().parent))
from notebooks._demo import load_maldi_kleb_ai

demo = load_maldi_kleb_ai(antibiotic='Amikacin', verbose=True)
X = demo.X
batch = demo.batch          # Rome / Milan / Catania
species = demo.species      # K. pneumoniae / K. variicola / ...
y = (demo.meta['Amikacin'] == 'R').astype(int)   # binary: R vs S/I

print(f'X:       {X.shape}')
print(f'batches: {sorted(batch.unique())}')
print(f'prevalence of resistance: {y.mean():.2%}')
```

```text
Processing spectra: 100%|██████████| 743/743 [00:00<00:00, 4359.36spectrum/s]
X:       (741, 6000)
batches: ['Catania', 'Milan', 'Rome']
prevalence of resistance: 49.80%
```

2. Why fit_transform before splitting is wrong

It is tempting to write:

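```python
from maldibatchkit import SpeciesAwareComBat
from sklearn.model_selection import train_test_split
X_corrected = SpeciesAwareComBat(batch=batch, species=species).fit_transform(X)  # leaks test rows into the fit
X_train, X_test, y_train, y_test = train_test_split(X_corrected, y, ...)
```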

fit_transform uses every row, including the test rows, to estimate the per-batch shift / scale parameters. The test set has therefore already influenced the features the classifier is trained and evaluated on, and held-out scores computed this way can be optimistically biased. The fix is the scikit-learn contract: fit on the training rows only, then transform both splits with the fitted parameters.
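A tiny synthetic check makes the leak concrete. It uses per-batch feature means as a stand-in for the correction parameters (a simplification: the real corrector also estimates scale and takes a species argument) and made-up sample IDs. Shifting only the test rows changes the parameters of a full-data fit but leaves a train-only fit untouched.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n = 40
idx = pd.Index([f"s{i}" for i in range(n)])
batch = pd.Series(np.tile(["Rome", "Milan"], n // 2), index=idx)
X = pd.DataFrame(rng.normal(size=(n, 3)), index=idx, columns=["m1", "m2", "m3"])

train_idx, test_idx = idx[:30], idx[30:]


def batch_shift(X_part):
    """Per-batch feature means: a simplified stand-in for a corrector's shift parameters."""
    return X_part.groupby(batch.loc[X_part.index]).mean()


# Perturb ONLY the test rows, as if a different test set had been drawn.
X_perturbed = X.copy()
X_perturbed.loc[test_idx] += 10.0

leaky_before, leaky_after = batch_shift(X), batch_shift(X_perturbed)
clean_before = batch_shift(X.loc[train_idx])
clean_after = batch_shift(X_perturbed.loc[train_idx])

print("full-data fit changed:", not np.allclose(leaky_before, leaky_after))   # True
print("train-only fit changed:", not np.allclose(clean_before, clean_after))  # False
```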

3. The correct pattern: fit(X_train) + transform(X_test)

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from maldibatchkit import SpeciesAwareComBat

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0,
)
print(f'train: {X_train.shape} | test: {X_test.shape}')
```

```text
train: (555, 6000) | test: (186, 6000)
```
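Because the per-batch parameters come only from the training rows, a batch that shows up in the test split but not in the training split has nothing to be corrected with (how SpeciesAwareComBat handles that case is not covered here). With three sites and `stratify=y` this is unlikely, but a stricter option is to stratify on batch and label together. The helper below is hypothetical, not part of MaldiBatchKit or scikit-learn, and runs on synthetic data.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def split_stratified_by_batch_and_label(X, y, batch, test_size=0.25, random_state=0):
    """Hypothetical helper: stratify on batch x label so every batch lands in
    both splits with a similar resistance prevalence."""
    strata = batch.loc[X.index].astype(str) + "|" + y.loc[X.index].astype(str)
    return train_test_split(
        X, y, test_size=test_size, stratify=strata, random_state=random_state,
    )


# Synthetic stand-in data with three unevenly sized batches.
rng = np.random.default_rng(2)
n = 200
ids = pd.Index([f"s{i}" for i in range(n)])
batch = pd.Series(rng.choice(["Rome", "Milan", "Catania"], size=n, p=[0.5, 0.3, 0.2]), index=ids)
y = pd.Series(rng.integers(0, 2, size=n), index=ids)
X = pd.DataFrame(rng.normal(size=(n, 4)), index=ids)

X_train, X_test, y_train, y_test = split_stratified_by_batch_and_label(X, y, batch)

# Every batch seen at transform time should have parameters learned at fit time.
train_batches = set(batch.loc[X_train.index])
test_batches = set(batch.loc[X_test.index])
print("test batches covered by train:", test_batches <= train_batches)
```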
```python
# The corrector stores `batch` and `species` for the FULL dataset. They are
# indexed by sample ID, so the corrector aligns them to X_train.index on
# fit() and to X_test.index on transform(); you never slice them manually.
corrector = SpeciesAwareComBat(batch=batch, species=species)
corrector.fit(X_train)                   # parameters estimated from training rows only

X_train_c = corrector.transform(X_train)
X_test_c = corrector.transform(X_test)   # same parameters, applied to held-out rows

# The scaler follows the same rule (fit on training rows only). A random forest
# does not need scaled inputs, so this step is harmless but optional here.
scaler = StandardScaler().fit(X_train_c)
clf = RandomForestClassifier(
    n_estimators=300, random_state=0, n_jobs=-1,
).fit(scaler.transform(X_train_c), y_train)

auc = roc_auc_score(
    y_test, clf.predict_proba(scaler.transform(X_test_c))[:, 1]
)
print(f'Held-out AUROC: {auc:.4f}')
```

```text
Held-out AUROC: 0.7887
```

batch is a pandas.Series indexed by the same sample IDs as X. On corrector.fit(X_train) the base class selects batch.loc[X_train.index]; on corrector.transform(X_test) it selects batch.loc[X_test.index]. species is aligned the same way, so you never build batch_train / batch_test (or species_train / species_test) by hand.
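That lookup is ordinary pandas label indexing. The sketch below uses made-up sample IDs (not the MALDI-Kleb-AI ones) to show why it is safer than slicing by position: `train_test_split` shuffles rows by default, so a positionally sliced `batch` would pair spectra with the wrong site, whereas `.loc` follows whatever row order the split produced and raises `KeyError` for an ID it does not know.

```python
import numpy as np
import pandas as pd

# Full-dataset metadata, indexed by sample ID (synthetic IDs for illustration).
ids = [f"sample_{i}" for i in range(6)]
batch = pd.Series(["Rome", "Rome", "Milan", "Milan", "Catania", "Catania"], index=ids)
X = pd.DataFrame(np.arange(12.0).reshape(6, 2), index=ids, columns=["m1", "m2"])

# A split whose rows come out in an arbitrary order, as a shuffled split would produce.
X_test = X.loc[["sample_5", "sample_0", "sample_3"]]

aligned = batch.loc[X_test.index]        # label-based: follows X_test's row order
positional = batch.iloc[: len(X_test)]   # position-based: ignores which rows are in X_test

print(aligned.tolist())     # ['Catania', 'Rome', 'Milan']
print(positional.tolist())  # ['Rome', 'Rome', 'Milan']

# An ID missing from the metadata fails loudly instead of silently misaligning.
try:
    batch.loc[["sample_99"]]
except KeyError as err:
    print("unknown sample ID:", err)
```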