Extending MaldiBatchKit

Every corrector shipped with MaldiBatchKit is a subclass of BaseBatchCorrector. Custom correctors plug into the same contract: subclass the base, implement two methods, and you get a scikit-learn compatible, train/test-safe transformer for free.

Why Inherit?

BaseBatchCorrector solves the two awkward parts of sklearn + batch correction:

  1. Index alignment. batch (and any covariates) is passed once at construction time. On every fit / transform call the stored series is aligned to X.index, so sklearn pipelines and cross-validation can pass only X around - you never have to manually slice batch for train / test folds.

  2. Leakage safety. fit learns parameters only from the rows it actually receives; transform never re-fits. Fitted attributes (those ending in _) are used verbatim on unseen rows, even when a new batch level appears in the test set.
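The alignment step in point 1 can be pictured with plain pandas. The following is a minimal sketch, not the library's actual code: align_to_index is a hypothetical helper, and it assumes batch is kept as a Series whose index labels match those of X.

```python
import pandas as pd


def align_to_index(stored: pd.Series, index: pd.Index) -> pd.Series:
    """Return the stored per-sample labels for exactly the rows in `index`.

    Hypothetical helper for illustration; the real base class may differ.
    """
    missing = index.difference(stored.index)
    if len(missing) > 0:
        raise KeyError(f"no batch label for samples: {list(missing)}")
    return stored.loc[index]


# Toy data: six spectra measured in two batches.
X = pd.DataFrame(
    {"mz_1": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]},
    index=["s1", "s2", "s3", "s4", "s5", "s6"],
)
batch = pd.Series(["A", "A", "A", "B", "B", "B"], index=X.index)

# A cross-validation fold only hands rows of X around, in any order.
X_train = X.loc[["s5", "s1", "s2", "s4"]]
X_test = X.loc[["s6", "s3"]]

# The batch labels follow the rows because they are looked up by index label.
batch_train = align_to_index(batch, X_train.index)
batch_test = align_to_index(batch, X_test.index)
```

Because the lookup is by label rather than by position, shuffled or subsampled folds still receive the right batch for every row.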

The base class additionally handles NaN / finite checks and DataFrame-vs-ndarray round-tripping, populates feature_names_in_ and n_features_in_, and provides get_feature_names_out(), so your corrector behaves like any other sklearn transformer.

Minimal Subclass

Two methods are required: _fit_impl() to learn parameters, and _transform_impl() to apply them. Give every fitted attribute a trailing underscore (for example batch_means_) so sklearn.utils.validation.check_is_fitted() picks it up:

import pandas as pd
from maldibatchkit import BaseBatchCorrector

class MeanCentering(BaseBatchCorrector):
    """Subtract per-batch means from each feature."""

    def _fit_impl(self, X_df, batch):
        self.batch_means_ = X_df.groupby(batch).mean()
        self.grand_mean_ = X_df.mean(axis=0)

    def _transform_impl(self, X_df, batch):
        out = X_df.copy().astype(float)
        known = set(self.batch_means_.index)
        for lvl in pd.unique(batch):
            mask = batch == lvl
            offset = (
                self.batch_means_.loc[lvl].to_numpy()
                if lvl in known
                else self.grand_mean_.to_numpy()
            )
            out.loc[mask] = out.loc[mask].to_numpy() - offset
        return out

The inherited fit(), transform(), and fit_transform() (the last via TransformerMixin) do the rest:

corrector = MeanCentering(batch=batch)
X_corrected = corrector.fit_transform(X)

Adding Hyperparameters

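Override __init__ to add method-specific hyperparameters. Always call super().__init__(batch=batch) - the base class stores batch under self.batch so get_params() and set_params() work. For the same reason, assign each new hyperparameter unmodified to an attribute with the same name as its constructor argument: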

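import pandas as pd
from maldibatchkit import BaseBatchCorrector
from maldibatchkit._utils import ArrayLike

class ScaledMeanCentering(BaseBatchCorrector):
    """Mean centering with a user-controlled scale factor."""

    def __init__(self, batch: ArrayLike, *, alpha: float = 1.0) -> None:
        super().__init__(batch=batch)
        # Stored unmodified so get_params() / set_params() can see it.
        self.alpha = alpha

    def _fit_impl(self, X_df, batch):
        self.batch_means_ = X_df.groupby(batch).mean()
        # Fallback offset for batch levels not seen during fit.
        self.grand_mean_ = X_df.mean(axis=0)

    def _transform_impl(self, X_df, batch):
        out = X_df.copy().astype(float)
        known = set(self.batch_means_.index)
        for lvl in pd.unique(batch):
            mask = batch == lvl
            # Unseen levels use the grand mean instead of raising KeyError.
            mean = self.batch_means_.loc[lvl] if lvl in known else self.grand_mean_
            offset = mean.to_numpy() * self.alpha
            out.loc[mask] = out.loc[mask].to_numpy() - offset
        return out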
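The naming rule for hyperparameters matters because scikit-learn reads them back by name: in essence, get_params() inspects the __init__ signature and fetches the attribute with the same name for each argument. The sketch below imitates that lookup with the standard library only. sketch_get_params and both toy classes are hypothetical stand-ins, not scikit-learn or MaldiBatchKit code.

```python
import inspect


def sketch_get_params(estimator) -> dict:
    """Simplified imitation of how get_params() finds hyperparameters."""
    sig = inspect.signature(type(estimator).__init__)
    skip = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    names = [
        name
        for name, param in sig.parameters.items()
        if name != "self" and param.kind not in skip
    ]
    # Each constructor argument must exist as an attribute of the same name.
    return {name: getattr(estimator, name) for name in names}


class Good:
    def __init__(self, batch, *, alpha=1.0):
        self.batch = batch
        self.alpha = alpha  # same name as the argument


class Bad:
    def __init__(self, batch, *, alpha=1.0):
        self.batch = batch
        self._scale = alpha  # renamed, so it cannot be read back by name


params = sketch_get_params(Good(batch=["A", "B"], alpha=0.5))

try:
    sketch_get_params(Bad(batch=["A", "B"], alpha=0.5))
    bad_is_readable = True
except AttributeError:
    bad_is_readable = False
```

The Bad class fails the lookup, which is the kind of breakage that would surface when a pipeline or grid search tries to clone the estimator.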

Dropping It Into a Pipeline

Custom correctors are drop-in sklearn Pipeline steps. Because the base class enforces train/test separation, you can safely cross-validate without leakage:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

pipe = Pipeline([
    ("mean", MeanCentering(batch=batch)),
    ("scaler", StandardScaler()),
    ("clf", RandomForestClassifier(n_estimators=200)),
])
scores = cross_val_score(pipe, X, y, cv=5, scoring="roc_auc")
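To make the leakage guarantee concrete, the sketch below replays by hand what a single fold does when an entire batch ends up in the test split. It uses standalone functions that mirror the MeanCentering logic rather than the library classes; the function names and toy data are assumptions made for illustration.

```python
import pandas as pd


def fit_means(X: pd.DataFrame, batch: pd.Series):
    """Learn per-batch means and a grand-mean fallback from the given rows."""
    return X.groupby(batch).mean(), X.mean(axis=0)


def apply_means(X, batch, batch_means, grand_mean):
    """Subtract learned offsets; unseen batch levels use the grand mean."""
    out = X.astype(float)
    for lvl in pd.unique(batch):
        mask = (batch == lvl).to_numpy()
        known = lvl in batch_means.index
        offset = batch_means.loc[lvl] if known else grand_mean
        out.loc[mask] = out.loc[mask].to_numpy() - offset.to_numpy()
    return out


X = pd.DataFrame(
    {"mz_1": [1.0, 3.0, 10.0, 12.0, 100.0, 102.0]},
    index=["s1", "s2", "s3", "s4", "s5", "s6"],
)
batch = pd.Series(["A", "A", "B", "B", "C", "C"], index=X.index)

# Fold: batches A and B are used for training, batch C is held out entirely.
train_idx, test_idx = X.index[:4], X.index[4:]
batch_means, grand_mean = fit_means(X.loc[train_idx], batch.loc[train_idx])

train_out = apply_means(
    X.loc[train_idx], batch.loc[train_idx], batch_means, grand_mean
)
test_out = apply_means(
    X.loc[test_idx], batch.loc[test_idx], batch_means, grand_mean
)
```

The held-out rows of batch C are shifted by the training grand mean (6.5 here), so nothing about the test fold enters the fitted parameters.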

Adding Covariates

If your corrector needs more than batch (for example species or a quality score), follow the same pattern the shipped correctors use: take the auxiliary vector in __init__, store it on self, and subset it inside _fit_impl / _transform_impl via maldibatchkit._utils._subset():

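import pandas as pd
from maldibatchkit import BaseBatchCorrector
from maldibatchkit._utils import _subset

class QualityWeightedMean(BaseBatchCorrector):
    """Subtract quality-weighted per-batch means from each feature."""

    def __init__(self, batch, *, quality) -> None:
        super().__init__(batch=batch)
        self.quality = quality

    def _fit_impl(self, X_df, batch):
        q = _subset(self.quality, X_df.index).to_numpy(dtype=float)
        # Index the weights like X_df so groupby(batch) lines up row by row.
        q = pd.Series(q, index=X_df.index)
        # weighted per-batch mean -- low quality contributes less
        self.means_ = (
            X_df.mul(q, axis=0).groupby(batch).sum()
            .div(q.groupby(batch).sum(), axis=0)
        )
        # Weighted grand mean, used for batch levels not seen during fit.
        self.grand_mean_ = X_df.mul(q, axis=0).sum() / q.sum()

    def _transform_impl(self, X_df, batch):
        out = X_df.copy().astype(float)
        known = set(self.means_.index)
        for lvl in pd.unique(batch):
            mask = batch == lvl
            mean = self.means_.loc[lvl] if lvl in known else self.grand_mean_
            out.loc[mask] = out.loc[mask].to_numpy() - mean.to_numpy()
        return out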

See QualityWeightedComBat for a production-grade example of the same pattern (iterative empirical-Bayes with quality weights).
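The weighted mean computed in _fit_impl is the usual sum(q * x) / sum(q), taken within each batch. As a sanity check on that formula, here is a standalone sketch on toy data that compares the groupby expression with numpy.average. The data and variable names are made up for illustration.

```python
import numpy as np
import pandas as pd

X = pd.DataFrame(
    {"mz_1": [1.0, 3.0, 10.0, 20.0], "mz_2": [2.0, 2.0, 4.0, 8.0]},
    index=["s1", "s2", "s3", "s4"],
)
batch = pd.Series(["A", "A", "B", "B"], index=X.index)
quality = pd.Series([1.0, 3.0, 1.0, 1.0], index=X.index)

# Numerator and denominator of the weighted mean, one row per batch level.
weighted_sum = X.mul(quality, axis=0).groupby(batch).sum()
weight_total = quality.groupby(batch).sum()
means = weighted_sum.div(weight_total, axis=0)

# Reference value for batch A computed directly with numpy.
in_a = (batch == "A").to_numpy()
expected_a = np.average(
    X.to_numpy()[in_a], axis=0, weights=quality.to_numpy()[in_a]
)
```

Keeping the weights in a Series that shares the index of X is what lets groupby(batch) pair every weight with the right row.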

Conventions

See Contributing for the full style guide. The important points for custom correctors:

  • NumPy-style docstrings on every public class (rendered by napoleon).

  • Trailing-underscore fitted attributes (self.batch_means_, not self.batch_means) - required by sklearn’s check_is_fitted().

  • No side effects outside fit. transform must be deterministic and must not modify fitted state: calling it twice on the same input must return the same output.

  • Unseen batches at transform time should fall back to a training statistic (grand mean, identity, …) rather than raising, so cross-validation folds that split off a full batch still work.

  • Clear ImportError for optional dependencies. If your corrector wraps a library that is not a hard dependency, raise an explanatory ImportError rather than letting a bare ModuleNotFoundError propagate. See Harmony for the reference pattern.
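The optional-dependency rule in the last bullet usually comes down to a small guard around the import. The sketch below shows one way to write it. The helper _require, the module name maldibatchkit_hypothetical_backend, and the install hint are all hypothetical, and the Harmony corrector may implement the guard differently.

```python
import importlib


def _require(module_name: str, *, install_hint: str):
    """Import an optional dependency or raise an explanatory ImportError."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        raise ImportError(
            f"{module_name!r} is required for this corrector but is not "
            f"installed. Install it with: {install_hint}"
        ) from exc


class OptionalBackendCorrector:
    """Toy class; a real corrector would subclass BaseBatchCorrector."""

    def fit(self, X):
        # Import lazily so the package stays importable without the extra.
        self._backend = _require(
            "maldibatchkit_hypothetical_backend",
            install_hint="pip install maldibatchkit-hypothetical-backend",
        )
        return self


try:
    OptionalBackendCorrector().fit([[0.0]])
    message = ""
except ImportError as err:
    message = str(err)
```

Importing inside fit rather than at module level means users who never touch this corrector are not forced to install the extra package.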

Reference Implementations

The simplest real-world examples are in the baselines module - read them top-to-bottom to see the idiom in a few dozen lines.

For a corrector with an iterative fit and an auxiliary covariate, study QualityWeightedComBat.