Source code for lightkit.estimator.configurable

from pathlib import Path
from typing import Any, Generic, TypeVar
from lightkit.nn._protocols import ConfigurableModule
from lightkit.utils import get_generic_type
from .base import BaseEstimator
from .exception import NotFittedError

# Type variable for the model managed by `ConfigurableBaseEstimator`, bound to the
# `ConfigurableModule` protocol. The estimator calls `save(...)` on the model instance and
# `load(...)` on the resolved model class -- presumably both are provided by that protocol
# (TODO confirm against `lightkit.nn._protocols`).
M = TypeVar("M", bound=ConfigurableModule)  # type: ignore


class ConfigurableBaseEstimator(BaseEstimator, Generic[M]):
    """
    Extension of the base estimator that manages a single model which uses the
    :class:`lightkit.nn.Configurable` mixin.

    The model is persisted alongside the estimator's other attributes: it is written to and
    read from the ``"model"`` entry below the path passed to :meth:`save_attributes` /
    :meth:`load_attributes`.
    """

    # The managed model. It is assigned by `load_attributes` (and presumably by the fitting
    # logic of subclasses). While it is unset, accessing it goes through `__getattr__` below,
    # which raises `NotFittedError` for trailing-underscore names.
    model_: M

    def save_attributes(self, path: Path) -> None:
        """
        Saves the estimator's attributes, including its model.

        Args:
            path: Location to write attributes to. The simple attributes are handled by the
                base estimator; the model is saved to ``path / "model"``.

        Raises:
            NotFittedError: If the estimator has no ``model_`` yet (raised via
                ``__getattr__`` unless a base class resolves the attribute).
        """
        # First, store simple attributes
        super().save_attributes(path)

        # Then, store the model
        self.model_.save(path / "model")

    def load_attributes(self, path: Path) -> None:
        """
        Loads the estimator's attributes, including its model.

        The model class is obtained via ``get_generic_type`` for this estimator's class --
        presumably the concrete type bound to ``M`` by the subclass.

        Args:
            path: Location previously passed to :meth:`save_attributes`. The model is loaded
                from ``path / "model"``.
        """
        # First, load simple attributes
        super().load_attributes(path)

        # Then, load the model; its class is not stored on disk, so it is resolved from the
        # estimator's generic parametrization.
        model_cls = get_generic_type(self.__class__, ConfigurableBaseEstimator)
        self.model_ = model_cls.load(path / "model")  # type: ignore

    def __getattr__(self, key: str) -> Any:
        """
        Fallback attribute lookup, invoked only when normal lookup fails.

        The lookup is delegated to a base class ``__getattr__`` if one exists. If that fails
        with an :class:`AttributeError` and ``key`` looks like a fitted attribute (it ends with
        a single ``_`` and does not start with ``_``, e.g. ``model_``), a
        :class:`NotFittedError` is raised instead.

        Args:
            key: The name of the attribute being looked up.

        Returns:
            Whatever a base class ``__getattr__`` returns for ``key``.

        Raises:
            NotFittedError: If ``key`` names a fitted attribute that is not available.
            AttributeError: If ``key`` is any other missing attribute.
        """
        # `object` defines no `__getattr__`. Calling `super().__getattr__` unconditionally
        # would, in that case, surface an unrelated error about the `super` object rather than
        # about `key`. Only delegate when a base class actually provides the hook.
        parent_getattr = getattr(super(), "__getattr__", None)
        try:
            if parent_getattr is None:
                raise AttributeError(
                    f"'{self.__class__.__name__}' object has no attribute '{key}'"
                )
            return parent_getattr(key)
        except AttributeError as e:
            # Dunder and private names are excluded so that protocol probes (copy, pickle,
            # hasattr on internals) keep seeing a plain AttributeError.
            if key.endswith("_") and not key.endswith("__") and not key.startswith("_"):
                raise NotFittedError(
                    f"`{self.__class__.__name__}` has not been fitted yet"
                ) from e
            raise