Using MaldiBatchKit correctors in a scikit-learn workflow
Every MaldiBatchKit corrector subclasses both BaseEstimator and TransformerMixin, so it plugs into the standard scikit-learn fit / transform / Pipeline contract. This notebook shows the two patterns you need:
- ``fit`` on train, ``transform`` on test. Never call ``fit_transform`` on the full dataset before splitting.
- Drop the corrector into ``sklearn.pipeline.Pipeline`` so cross-validation refits it on each fold for you.
The examples use the MALDI-Kleb-AI dataset (Rocchi et al. 2026, Zenodo DOI 10.5281/zenodo.17405072).
1. Load the dataset
[9]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path.cwd().parent))
from notebooks._demo import load_maldi_kleb_ai
demo = load_maldi_kleb_ai(antibiotic='Amikacin', verbose=True)
X = demo.X
batch = demo.batch # Rome / Milan / Catania
species = demo.species # K. pneumoniae / K. variicola / ...
y = (demo.meta['Amikacin'] == 'R').astype(int) # binary: R vs S/I
print(f'X: {X.shape}')
print(f'batches: {sorted(batch.unique())}')
print(f'prevalence of resistance: {y.mean():.2%}')
Processing spectra: 100%|██████████| 743/743 [00:00<00:00, 4359.36spectrum/s]
X: (741, 6000)
batches: ['Catania', 'Milan', 'Rome']
prevalence of resistance: 49.80%
2. Why fit_transform before splitting is wrong
It is tempting to write:
X_corrected = SpeciesAwareComBat(batch=batch, species=species).fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X_corrected, y, ...)
fit_transform uses every row, including the test rows, to estimate the per-batch shift / scale parameters. The corrected training features therefore already carry information about the test rows, and the held-out score is no longer an honest estimate of performance on unseen data. The fix is the scikit-learn contract: fit on the training rows only, then transform both splits with the fitted parameters.
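The leak can be made concrete with a toy sketch. It does not use MaldiBatchKit: the hypothetical `batch_means` helper below is only a stand-in for the per-batch shift estimate, and the data are synthetic. Changing one held-out row moves the estimate computed on all rows but leaves the train-only estimate untouched.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: 8 samples, 1 feature, two batches with different offsets.
batch = np.array(["A", "A", "A", "A", "B", "B", "B", "B"])
X = rng.normal(size=(8, 1)) + np.where(batch == "A", 0.0, 5.0)[:, None]

train = np.array([0, 1, 4, 5])  # rows used for fitting
test = np.array([2, 3, 6, 7])   # held-out rows


def batch_means(X, batch):
    """Per-batch feature means: a hypothetical stand-in for shift parameters."""
    return {b: X[batch == b].mean(axis=0) for b in np.unique(batch)}


leaky = batch_means(X, batch)                # estimated on ALL rows
clean = batch_means(X[train], batch[train])  # estimated on training rows only

# Perturb a single held-out row (it belongs to batch "A") and re-estimate.
X_perturbed = X.copy()
X_perturbed[test[0]] += 10.0
leaky_after = batch_means(X_perturbed, batch)
clean_after = batch_means(X_perturbed[train], batch[train])

leaky_moved = not np.allclose(leaky["A"], leaky_after["A"])
clean_moved = not np.allclose(clean["A"], clean_after["A"])
print("estimate from all rows moved:  ", leaky_moved)
print("estimate from train rows moved:", clean_moved)
```

Any parameter that reacts to a held-out row has, by definition, seen the test set.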
3. The correct pattern: fit(X_train) + transform(X_test)
[10]:
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from maldibatchkit import SpeciesAwareComBat
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0,
)
print(f'train: {X_train.shape} | test: {X_test.shape}')
train: (555, 6000) | test: (186, 6000)
[11]:
# The corrector stores `batch` and `species` for the FULL dataset. They are
# indexed by sample ID, so the corrector aligns them to X_train.index on
# fit() and to X_test.index on transform() -- you never slice them manually.
corrector = SpeciesAwareComBat(batch=batch, species=species)
corrector.fit(X_train)
X_train_c = corrector.transform(X_train)
X_test_c = corrector.transform(X_test) # same parameters, applied to held-out rows
scaler = StandardScaler().fit(X_train_c)
clf = RandomForestClassifier(
    n_estimators=300, random_state=0, n_jobs=-1,
).fit(scaler.transform(X_train_c), y_train)
auc = roc_auc_score(
    y_test, clf.predict_proba(scaler.transform(X_test_c))[:, 1]
)
print(f'Held-out AUROC: {auc:.4f}')
Held-out AUROC: 0.7887
batch is a pandas.Series indexed by the same sample IDs as X. On corrector.fit(X_train) the base class selects batch.loc[X_train.index]; on corrector.transform(X_test) it selects batch.loc[X_test.index]. You never build batch_train / batch_test by hand.
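The index-alignment idea can be sketched with a small transformer. `ToyBatchCenterer` is a hypothetical class written for this illustration, not MaldiBatchKit's base class: it only subtracts per-batch means, but it looks up batch labels through `X.index` in the same way the paragraph above describes.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class ToyBatchCenterer(TransformerMixin, BaseEstimator):
    """Hypothetical stand-in: subtracts per-batch feature means learned in fit()."""

    def __init__(self, batch):
        self.batch = batch  # labels for the FULL dataset, indexed by sample ID

    def fit(self, X, y=None):
        labels = self.batch.loc[X.index]        # align labels to the rows being fitted
        self.means_ = X.groupby(labels).mean()  # one row of feature means per batch
        return self

    def transform(self, X):
        labels = self.batch.loc[X.index]        # align labels to the rows being transformed
        shifts = self.means_.loc[labels.to_numpy()].to_numpy()
        return X - shifts


# Synthetic example: six samples, one feature, two batches.
ids = [f"s{i}" for i in range(6)]
X = pd.DataFrame({"mz_1": [1.0, 2.0, 3.0, 11.0, 12.0, 13.0]}, index=ids)
batch = pd.Series(["Rome"] * 3 + ["Milan"] * 3, index=ids)

X_train = X.loc[["s0", "s1", "s3", "s4"]]
X_test = X.loc[["s2", "s5"]]

centerer = ToyBatchCenterer(batch=batch).fit(X_train)
X_test_c = centerer.transform(X_test)  # uses means learned from X_train only
print(X_test_c)
```

Because the labels travel with the sample IDs, the same `batch` object serves any subset of rows, which is what lets the transformer sit inside a Pipeline without per-fold slicing.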
4. The recommended pattern: sklearn.Pipeline + cross-validation
Dropping the corrector into a Pipeline removes the per-fold bookkeeping: every sklearn CV utility clones the pipeline, calls fit on the training fold only, and applies transform to the validation fold, so the corrector's parameters are never estimated from validation rows.
[12]:
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
pipe = Pipeline([
    ('combat', SpeciesAwareComBat(batch=batch, species=species)),
    ('scaler', StandardScaler()),
    ('clf', RandomForestClassifier(n_estimators=300, random_state=0, n_jobs=-1)),
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv, scoring='roc_auc', n_jobs=-1)
print(f'5-fold CV AUROC: {scores.mean():.4f} +/- {scores.std():.4f}')
5-fold CV AUROC: 0.7897 +/- 0.0359
On each fold the pipeline calls combat.fit(X_train_fold) (using batch.loc[X_train_fold.index] only), then combat.transform(X_train_fold) and combat.transform(X_val_fold) with the parameters learned on the training fold. Swap SpeciesAwareComBat for any other corrector - ComBat, Limma, Harmony, QualityWeightedComBat, or any BaseBatchCorrector subclass you write yourself - and the pipeline shape is identical.
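The per-fold behaviour described above can be spelled out as an explicit loop. This is a sketch on synthetic data: StandardScaler stands in for the batch corrector and LogisticRegression for the classifier, so it runs without MaldiBatchKit. The manual loop should reproduce the scores from cross_val_score.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=200, n_features=10, random_state=0)

pipe = Pipeline([
    ("scaler", StandardScaler()),               # stand-in for the batch corrector
    ("clf", LogisticRegression(max_iter=1000)),
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

manual_scores = []
for train_idx, val_idx in cv.split(X, y):
    fold_pipe = clone(pipe)                     # fresh, unfitted copy for this fold
    fold_pipe.fit(X[train_idx], y[train_idx])   # every step is fitted on training rows only
    # The fitted steps transform the validation rows, then the classifier scores them.
    proba = fold_pipe.predict_proba(X[val_idx])[:, 1]
    manual_scores.append(roc_auc_score(y[val_idx], proba))

auto_scores = cross_val_score(pipe, X, y, cv=cv, scoring="roc_auc")
scores_match = np.allclose(manual_scores, auto_scores)
print("manual loop matches cross_val_score:", scores_match)
```

Writing the loop by hand is rarely needed; it is shown only to make visible what the CV utilities do with the pipeline on each fold.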
5. Hyperparameter search
GridSearchCV works out of the box: every step in the pipeline exposes its hyperparameters via get_params / set_params, so you can tune corrector and classifier parameters jointly.
[13]:
from sklearn.model_selection import GridSearchCV
param_grid = {
    'combat__parametric': [True, False],
    'clf__n_estimators': [200, 400],
}
grid = GridSearchCV(pipe, param_grid=param_grid, cv=cv,
                    scoring='roc_auc', n_jobs=-1, refit=True)
grid.fit(X, y)
print(f'best AUROC: {grid.best_score_:.4f}')
print(f'best params: {grid.best_params_}')
best AUROC: 0.7971
best params: {'clf__n_estimators': 400, 'combat__parametric': False}
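The `step__parameter` keys used in the grid come straight from the pipeline's get_params. The sketch below uses stand-in steps (StandardScaler and LogisticRegression rather than a MaldiBatchKit corrector) to show how to list the tunable names and set them directly.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([
    ("scaler", StandardScaler()),   # stand-in for the corrector step
    ("clf", LogisticRegression()),
])

# Nested parameters are addressed as '<step name>__<parameter name>'.
tunable = sorted(name for name in pipe.get_params() if "__" in name)
print(tunable)

# set_params accepts the same keys that a param_grid uses.
pipe.set_params(scaler__with_mean=False, clf__C=0.1)
print(pipe.named_steps["scaler"].with_mean, pipe.named_steps["clf"].C)
```

Running the same listing on a pipeline that contains a corrector is a quick way to check which of its constructor arguments can be placed in a param_grid.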