Source code for lightkit.estimator.mixins
from typing import Generic
from ._protocols import D_contra, Predictor, R_co, Transformer
class TransformerMixin(Generic[D_contra, R_co]):
    """
    Mixin adding a ``fit_transform`` convenience method to an estimator.

    The host class is expected to provide ``fit`` and ``transform`` itself; this mixin only
    combines the two into a single call.
    """

    def fit_transform(self: Transformer[D_contra, R_co], data: D_contra) -> R_co:
        """
        Fits the estimator on ``data`` and then transforms that same ``data``.

        This is a shorthand for calling :meth:`fit` followed by :meth:`transform` on the object
        that :meth:`fit` returns.

        Args:
            data: The data to fit the estimator on and to transform afterwards. It is passed
                unchanged to both :meth:`fit` and :meth:`transform`.

        Returns:
            Whatever :meth:`transform` returns for ``data``. See the :meth:`transform`
            documentation for details on the return type.
        """
        # ``transform`` is invoked on the return value of ``fit`` rather than on ``self``.
        fitted = self.fit(data)
        return fitted.transform(data)
class PredictorMixin(Generic[D_contra, R_co]):
    """
    Mixin adding a ``fit_predict`` convenience method to an estimator.

    The host class is expected to provide ``fit`` and ``predict`` itself; this mixin only
    combines the two into a single call.
    """

    def fit_predict(self: Predictor[D_contra, R_co], data: D_contra) -> R_co:
        """
        Fits the estimator on ``data`` and then makes predictions for that same ``data``.

        This is a shorthand for calling :meth:`fit` followed by :meth:`predict` on the object
        that :meth:`fit` returns.

        Args:
            data: The data to fit the estimator on and to predict for afterwards. It is passed
                unchanged to both :meth:`fit` and :meth:`predict`.

        Returns:
            Whatever :meth:`predict` returns for ``data``. See the :meth:`predict`
            documentation for details on the return type.
        """
        # ``predict`` is invoked on the return value of ``fit`` rather than on ``self``.
        fitted = self.fit(data)
        return fitted.predict(data)