Extending MaldiBatchKit
Every corrector shipped with MaldiBatchKit is a subclass of
BaseBatchCorrector. Custom correctors plug into
the same contract: subclass the base, implement two methods, and you
get a scikit-learn-compatible, train/test-safe transformer for free.
Why Inherit?
BaseBatchCorrector solves the two awkward parts
of sklearn + batch correction:
- Index alignment. batch (and any covariates) is passed once at construction time. On every fit / transform call the stored series is aligned to X.index, so sklearn pipelines and cross-validation can pass only X around; you never have to slice batch manually for train / test folds.
- Leakage safety. fit learns parameters only from the rows it actually receives; transform never re-fits. Fitted attributes (those ending in _) are used verbatim on unseen rows, even when a new batch level appears in the test set.
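To make the index-alignment point concrete, here is a minimal pandas-only sketch of what aligning a stored batch series to X.index amounts to. It does not use MaldiBatchKit; the data and the names batch_full, X_train and X_test are made up for illustration, and the base class may implement the alignment differently.

```python
import pandas as pd

# Full-length batch labels, indexed by sample ID (hypothetical data).
batch_full = pd.Series(
    ["A", "A", "B", "B", "C"],
    index=["s1", "s2", "s3", "s4", "s5"],
    name="batch",
)

X = pd.DataFrame(
    {"mz_1": [1.0, 2.0, 3.0, 4.0, 5.0], "mz_2": [10.0, 20.0, 30.0, 40.0, 50.0]},
    index=batch_full.index,
)

# A cross-validation splitter hands over row subsets of X only.
X_train = X.loc[["s1", "s3", "s4"]]
X_test = X.loc[["s2", "s5"]]

# Looking the labels up by index recovers the matching batch for each fold.
batch_train = batch_full.loc[X_train.index]
batch_test = batch_full.loc[X_test.index]

print(batch_train.tolist())  # ['A', 'B', 'B']
print(batch_test.tolist())   # ['A', 'C']
```

Note that level "C" occurs only in the test fold here, which is exactly the situation the leakage-safety rule has to cope with.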
The base class additionally handles NaN / finite checks,
DataFrame-vs-ndarray round-tripping, and populates
feature_names_in_ / n_features_in_ / get_feature_names_out
so your corrector behaves like any other sklearn transformer.
Minimal Subclass
Two methods are required: _fit_impl()
to learn parameters, and
_transform_impl() to apply them.
Store fitted attributes as ..._ members so
sklearn.utils.validation.check_is_fitted() picks them up:

    import pandas as pd
    from maldibatchkit import BaseBatchCorrector


    class MeanCentering(BaseBatchCorrector):
        """Subtract per-batch means from each feature."""

        def _fit_impl(self, X_df, batch):
            self.batch_means_ = X_df.groupby(batch).mean()
            self.grand_mean_ = X_df.mean(axis=0)

        def _transform_impl(self, X_df, batch):
            out = X_df.copy().astype(float)
            known = set(self.batch_means_.index)
            for lvl in pd.unique(batch):
                mask = batch == lvl
                offset = (
                    self.batch_means_.loc[lvl].to_numpy()
                    if lvl in known
                    else self.grand_mean_.to_numpy()
                )
                out.loc[mask] = out.loc[mask].to_numpy() - offset
            return out

The inherited fit(),
transform(), and
fit_transform (via TransformerMixin) do the rest:
corrector = MeanCentering(batch=batch)
X_corrected = corrector.fit_transform(X)
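The arithmetic inside MeanCentering can be checked by hand with plain pandas. The sketch below repeats the fit and transform steps on toy data without the base class (so there is no validation or index alignment); the data and variable names are illustrative only.

```python
import pandas as pd

X = pd.DataFrame(
    {"mz_1": [1.0, 3.0, 10.0, 14.0], "mz_2": [2.0, 4.0, 20.0, 24.0]},
    index=["s1", "s2", "s3", "s4"],
)
batch = pd.Series(["A", "A", "B", "B"], index=X.index)

# "Fit": one row of feature means per batch level, plus the grand mean.
batch_means = X.groupby(batch).mean()
grand_mean = X.mean(axis=0)

# "Transform" on the training rows: subtract each row's own batch mean.
centered = X - batch_means.loc[batch.to_numpy()].to_numpy()
print(centered.groupby(batch).mean())  # every entry is 0.0

# A row from an unseen level has no entry in batch_means,
# so the grand mean is subtracted instead.
X_new = pd.DataFrame({"mz_1": [8.0], "mz_2": [13.0]}, index=["s9"])
centered_new = X_new - grand_mean.to_numpy()
print(centered_new)  # mz_1 = 1.0, mz_2 = 0.5
```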
Adding Hyperparameters
Override __init__ to add method-specific hyperparameters. Always
call super().__init__(batch=batch) - the base class stores batch
under self.batch so
get_params() and
set_params() work:

    from maldibatchkit._utils import ArrayLike


    class ScaledMeanCentering(BaseBatchCorrector):
        """Mean centering with a user-controlled scale factor."""

        def __init__(self, batch: ArrayLike, *, alpha: float = 1.0) -> None:
            super().__init__(batch=batch)
            # Store the argument unmodified under its own name so that
            # get_params() / set_params() can find it.
            self.alpha = alpha

        def _fit_impl(self, X_df, batch):
            self.batch_means_ = X_df.groupby(batch).mean()

        def _transform_impl(self, X_df, batch):
            out = X_df.copy().astype(float)
            for lvl in pd.unique(batch):
                mask = batch == lvl
                # .loc raises KeyError for a level not seen during fit; add a
                # fallback as in MeanCentering if folds can contain new batches.
                offset = self.batch_means_.loc[lvl].to_numpy() * self.alpha
                out.loc[mask] = out.loc[mask].to_numpy() - offset
            return out

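The reason each constructor argument is stored under an attribute of the same name is scikit-learn's parameter introspection: get_params() reads the __init__ signature and looks each name up on the instance. The sketch below shows the mechanism with a bare BaseEstimator subclass; TinyCorrector is a hypothetical stand-in, not a MaldiBatchKit class.

```python
from sklearn.base import BaseEstimator, clone


class TinyCorrector(BaseEstimator):
    """Hypothetical stand-in that only demonstrates parameter handling."""

    def __init__(self, batch=None, *, alpha=1.0):
        # Store arguments unmodified under their own names; no validation here.
        self.batch = batch
        self.alpha = alpha


est = TinyCorrector(batch=["A", "A", "B"], alpha=0.5)
print(est.get_params())  # {'alpha': 0.5, 'batch': ['A', 'A', 'B']}

# set_params updates the attribute in place and returns the estimator.
est.set_params(alpha=2.0)
print(est.alpha)  # 2.0

# clone() rebuilds an unfitted estimator from get_params(), which is what
# cross-validation and grid search do for every fold or candidate.
est_copy = clone(est)
print(est_copy.alpha)  # 2.0
```

If __init__ renamed or transformed an argument (for example storing alpha as self._alpha), get_params() and clone() would no longer round-trip, which is why the convention matters.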
Dropping It Into a Pipeline
Custom correctors are drop-in sklearn Pipeline steps. Because the
base class enforces train/test separation, you can safely
cross-validate without leakage:

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_score

    pipe = Pipeline([
        ("mean", MeanCentering(batch=batch)),
        ("scaler", StandardScaler()),
        ("clf", RandomForestClassifier(n_estimators=200)),
    ])
    scores = cross_val_score(pipe, X, y, cv=5, scoring="roc_auc")

Adding Covariates
If your corrector needs more than batch (for example species or a
quality score), follow the same pattern the shipped correctors use:
take the auxiliary vector in __init__, store it on self, and
subset it inside _fit_impl / _transform_impl via
maldibatchkit._utils._subset():

    import pandas as pd
    from maldibatchkit._utils import _subset


    class QualityWeightedMean(BaseBatchCorrector):
        def __init__(self, batch, *, quality) -> None:
            super().__init__(batch=batch)
            self.quality = quality

        def _fit_impl(self, X_df, batch):
            q = _subset(self.quality, X_df.index).to_numpy(dtype=float)
            # Re-attach the row index so groupby(batch) aligns weights with rows.
            weights = pd.Series(q, index=X_df.index)
            # weighted per-batch mean -- low quality contributes less
            self.means_ = (
                X_df.mul(weights, axis=0).groupby(batch).sum()
                .div(weights.groupby(batch).sum(), axis=0)
            )

        def _transform_impl(self, X_df, batch):
            out = X_df.copy().astype(float)
            for lvl in pd.unique(batch):
                mask = batch == lvl
                out.loc[mask] = (
                    out.loc[mask].to_numpy() - self.means_.loc[lvl].to_numpy()
                )
            return out

See QualityWeightedComBat for a production-grade
example of the same pattern (iterative empirical-Bayes with quality
weights).
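The weighted per-batch mean used in QualityWeightedMean is easy to verify on a toy example. This sketch uses plain pandas with made-up numbers and names; it keeps the quality weights indexed like X, because groupby aligns the grouping key on the index and weights with a default integer index would not line up with the batch labels.

```python
import pandas as pd

X = pd.DataFrame(
    {"mz_1": [1.0, 3.0, 10.0, 20.0]},
    index=["s1", "s2", "s3", "s4"],
)
batch = pd.Series(["A", "A", "B", "B"], index=X.index)
quality = pd.Series([1.0, 3.0, 1.0, 1.0], index=X.index)

# Weighted sum of each feature per batch, divided by the total weight per batch.
weighted_sum = X.mul(quality, axis=0).groupby(batch).sum()
weight_total = quality.groupby(batch).sum()
means = weighted_sum.div(weight_total, axis=0)

print(means)
# Batch A: (1*1 + 3*3) / (1 + 3) = 2.5
# Batch B: (10*1 + 20*1) / (1 + 1) = 15.0
```

In batch A the high-quality sample s2 pulls the mean from the unweighted 2.0 up to 2.5.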
Conventions
See Contributing for the full style guide. The important points for custom correctors:
- NumPy-style docstrings on every public class (rendered by napoleon).
- Trailing-underscore fitted attributes (self.batch_means_, not self.means) - required by sklearn’s check_is_fitted().
- No side effects outside fit. transform must be idempotent: calling it twice on the same input must return the same output.
- Unseen batches at transform time should fall back to a training statistic (grand mean, identity, …) rather than raising, so cross-validation folds that split off a full batch still work.
- Clear ImportError for optional dependencies. If your corrector wraps a library that is not a hard dependency, raise an explanatory ImportError rather than letting a bare ModuleNotFoundError propagate. See Harmony for the reference pattern.
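One way to implement the optional-dependency convention is a small lazy-import helper. The sketch below is a generic pattern, not the code used by Harmony; require_optional and the DemoCorrector label are hypothetical names.

```python
import importlib


def require_optional(module_name: str, feature: str):
    """Import an optional dependency or fail with an explanatory message."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        raise ImportError(
            f"{feature} needs the optional dependency {module_name!r}, "
            "which is not installed. Install it and try again."
        ) from exc


# Call the helper lazily (for example inside __init__ or _fit_impl) so that
# importing the package never fails just because an extra is missing.
json_mod = require_optional("json", "DemoCorrector")  # stdlib, always present

try:
    require_optional("surely_missing_module_xyz", "DemoCorrector")
except ImportError as err:
    message = str(err)
    print(message)
```

Chaining with "from exc" keeps the original traceback available while the user sees the explanatory message first.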
Reference Implementations
The simplest real-world examples are in the baselines module - read
them top-to-bottom to see the idiom in a few dozen lines:
- MedianCentering - fitted table + grand-median fallback.
- ZScorePerBatch - per-batch mean and stddev plus an eps hyperparameter.
- ReferenceScaling - non-trivial reference selection logic at fit time.
For a corrector with an iterative fit and an auxiliary covariate,
study QualityWeightedComBat.