{ "cells": [ { "cell_type": "markdown", "id": "51588049", "metadata": {}, "source": [ "# Using MaldiBatchKit correctors in a scikit-learn workflow\n", "\n", "Every MaldiBatchKit corrector is a `BaseEstimator` + `TransformerMixin`, so it plugs into the standard scikit-learn `fit` / `transform` / `Pipeline` contract. This notebook shows the two patterns you need:\n", "\n", "1. **`fit` on train, `transform` on test** - never call `fit_transform` on the full dataset before splitting.\n", "2. **Drop the corrector into `sklearn.pipeline.Pipeline`** so cross-validation refits it on each fold for you.\n", "\n", "Uses the **MALDI-Kleb-AI** dataset (Rocchi *et al.* 2026, [Zenodo DOI 10.5281/zenodo.17405072](https://zenodo.org/records/17405072)).\n" ] }, { "cell_type": "markdown", "id": "d54cfad9", "metadata": {}, "source": [ "## 1. Load the dataset" ] }, { "cell_type": "code", "execution_count": 9, "id": "5ed95a8e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing spectra: 100%|██████████| 743/743 [00:00<00:00, 4359.36spectrum/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "X: (741, 6000)\n", "batches: ['Catania', 'Milan', 'Rome']\n", "prevalence of resistance: 49.80%\n" ] } ], "source": [ "import sys, pathlib\n", "sys.path.insert(0, str(pathlib.Path.cwd().parent))\n", "from notebooks._demo import load_maldi_kleb_ai\n", "\n", "demo = load_maldi_kleb_ai(antibiotic='Amikacin', verbose=True)\n", "X = demo.X\n", "batch = demo.batch # Rome / Milan / Catania\n", "species = demo.species # K. pneumoniae / K. variicola / ...\n", "y = (demo.meta['Amikacin'] == 'R').astype(int) # binary: R vs S/I\n", "\n", "print(f'X: {X.shape}')\n", "print(f'batches: {sorted(batch.unique())}')\n", "print(f'prevalence of resistance: {y.mean():.2%}')" ] }, { "cell_type": "markdown", "id": "01e69c1b", "metadata": {}, "source": [ "## 2. 
Why `fit_transform` before splitting is wrong\n", "\n", "It is tempting to write:\n", "\n", "```python\n", "X_corrected = SpeciesAwareComBat(batch=batch, species=species).fit_transform(X)\n", "X_train, X_test, y_train, y_test = train_test_split(X_corrected, y, ...)\n", "```\n", "\n", "`fit_transform` uses every row - including the test rows - to estimate the per-batch shift / scale parameters, so the corrected training data already carries information from the test set and the held-out score is no longer an honest estimate. The fix is the scikit-learn contract: `fit` on the training rows only, then `transform` both splits with the fitted parameters.\n" ] }, { "cell_type": "markdown", "id": "b44b9491", "metadata": {}, "source": [ "## 3. The correct pattern: `fit(X_train)` + `transform(X_test)`" ] }, { "cell_type": "code", "execution_count": 10, "id": "f20134ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train: (555, 6000) | test: (186, 6000)\n" ] } ], "source": [ "import numpy as np\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import roc_auc_score\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from maldibatchkit import SpeciesAwareComBat\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.25, stratify=y, random_state=0,\n", ")\n", "print(f'train: {X_train.shape} | test: {X_test.shape}')" ] }, { "cell_type": "code", "execution_count": 11, "id": "6d6139bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Held-out AUROC: 0.7887\n" ] } ], "source": [ "# The corrector stores `batch` and `species` for the FULL dataset. 
They are\n", "# indexed by sample ID, so the corrector aligns them to X_train.index on\n", "# fit() and to X_test.index on transform() -- you never slice them manually.\n", "corrector = SpeciesAwareComBat(batch=batch, species=species)\n", "corrector.fit(X_train)\n", "\n", "X_train_c = corrector.transform(X_train)\n", "X_test_c = corrector.transform(X_test)  # same parameters, applied to held-out rows\n", "\n", "scaler = StandardScaler().fit(X_train_c)\n", "clf = RandomForestClassifier(\n", "    n_estimators=300, random_state=0, n_jobs=-1,\n", ").fit(scaler.transform(X_train_c), y_train)\n", "\n", "auc = roc_auc_score(\n", "    y_test, clf.predict_proba(scaler.transform(X_test_c))[:, 1]\n", ")\n", "print(f'Held-out AUROC: {auc:.4f}')" ] }, { "cell_type": "markdown", "id": "dc16a0d6", "metadata": {}, "source": [ "`batch` is a `pandas.Series` indexed by the same sample IDs as `X`. On `corrector.fit(X_train)` the base class selects `batch.loc[X_train.index]`; on `corrector.transform(X_test)` it selects `batch.loc[X_test.index]`. `species` is aligned the same way, so you never build `batch_train` / `batch_test` (or their `species` counterparts) by hand.\n" ] }, { "cell_type": "markdown", "id": "fa2ee382", "metadata": {}, "source": [ "## 4. 
The recommended pattern: `sklearn.pipeline.Pipeline` + cross-validation\n", "\n", "Dropping the corrector into a `Pipeline` removes the per-fold bookkeeping: scikit-learn's CV utilities clone the pipeline and fit it on the training fold only, then apply the fitted steps to the validation fold, so the corrector never sees validation rows while estimating its parameters.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "6c09b3f6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5-fold CV AUROC: 0.7897 +/- 0.0359\n" ] } ], "source": [ "from sklearn.model_selection import StratifiedKFold, cross_val_score\n", "from sklearn.pipeline import Pipeline\n", "\n", "pipe = Pipeline([\n", "    ('combat', SpeciesAwareComBat(batch=batch, species=species)),\n", "    ('scaler', StandardScaler()),\n", "    ('clf', RandomForestClassifier(n_estimators=300, random_state=0, n_jobs=-1)),\n", "])\n", "\n", "cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n", "scores = cross_val_score(pipe, X, y, cv=cv, scoring='roc_auc', n_jobs=-1)\n", "print(f'5-fold CV AUROC: {scores.mean():.4f} +/- {scores.std():.4f}')" ] }, { "cell_type": "markdown", "id": "d3536fd0", "metadata": {}, "source": [ "On each fold the pipeline calls `combat.fit(X_train_fold)` (using `batch.loc[X_train_fold.index]` only), then `combat.transform(X_train_fold)` and `combat.transform(X_val_fold)` with the parameters learned on the training fold. Swap `SpeciesAwareComBat` for any other corrector - `ComBat`, `Limma`, `Harmony`, `QualityWeightedComBat`, or any `BaseBatchCorrector` subclass you write yourself - and the pipeline shape is identical.\n" ] }, { "cell_type": "markdown", "id": "9660990f", "metadata": {}, "source": [ "## 5. 
Hyperparameter search\n", "\n", "`GridSearchCV` works out of the box: every step in the pipeline exposes its hyperparameters via `get_params` / `set_params`, and the pipeline addresses them as `step__parameter` (for example `combat__parametric`), so you can tune corrector and classifier parameters jointly.\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "f10faea5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "best AUROC: 0.7971\n", "best params: {'clf__n_estimators': 400, 'combat__parametric': False}\n" ] } ], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "param_grid = {\n", "    'combat__parametric': [True, False],\n", "    'clf__n_estimators': [200, 400],\n", "}\n", "\n", "grid = GridSearchCV(pipe, param_grid=param_grid, cv=cv,\n", "                    scoring='roc_auc', n_jobs=-1, refit=True)\n", "grid.fit(X, y)\n", "print(f'best AUROC: {grid.best_score_:.4f}')\n", "print(f'best params: {grid.best_params_}')" ] } ], "metadata": { "kernelspec": { "display_name": "maldibatchkit", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }