Source code for lightkit.estimator.base

from __future__ import annotations
import copy
import inspect
import json
import logging
import pickle
import warnings
from abc import ABC
from pathlib import Path
from typing import Any, TypeVar
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from lightkit.utils.path import PathType
from .exception import NotFittedError

E = TypeVar("E", bound="BaseEstimator")  # type: ignore
T = TypeVar("T")

logger = logging.getLogger(__name__)


class BaseEstimator(ABC):
    """
    Base estimator class that all estimators should inherit from. This base estimator does not
    enforce the implementation of any methods, but users should follow the Scikit-learn guide on
    implementing estimators (which can be found
    `here <https://scikit-learn.org/stable/developers/develop.html>`_). Some of the methods
    mentioned in this guide are already implemented in this base estimator and work as expected
    if the aspects listed below are followed.

    In contrast to Scikit-learn's estimator, this estimator is strongly typed and integrates well
    with PyTorch Lightning. Most importantly, it provides the :meth:`trainer` method which
    returns a fully configured trainer to be used by other methods. The configuration is stored
    in the estimator and can be adjusted by passing parameters to ``default_params``,
    ``user_params`` and ``overwrite_params`` when calling ``super().__init__()``.

    By default, the base estimator sets the following flags:

    - Logging is disabled (``logger=False``).
    - Logging is performed at every step (``log_every_n_steps=1``).
    - The progress bar is only enabled (``enable_progress_bar``) if LightKit's logging level is
      ``INFO`` or more verbose.
    - Checkpointing is only enabled (``enable_checkpointing``) if LightKit's logging level is
      ``DEBUG`` or more verbose.
    - The model summary is only enabled (``enable_model_summary``) if LightKit's logging level is
      ``DEBUG`` or more verbose.

    Note that the logging level can be changed via :meth:`lightkit.set_logging_level`.

    When subclassing this base estimator, users should take care of the following aspects:

    - All parameters passed to the initializer must be assigned to attributes with the same name.
      This ensures that :meth:`get_params` and :meth:`set_params` work as expected. Parameters
      that are passed to the trainer *must* be named ``trainer_params`` and should not be
      manually assigned to an attribute (this is handled by the base estimator).
    - Fitted attributes must (1) have a single trailing underscore (e.g. ``model_``) and (2) be
      defined as annotations. This ensures that :meth:`save` and :meth:`load` properly manage the
      estimator's persistence.
    """

    def __init__(
        self,
        *,
        default_params: dict[str, Any] | None = None,
        user_params: dict[str, Any] | None = None,
        overwrite_params: dict[str, Any] | None = None,
    ):
        """
        Args:
            default_params: Estimator-specific parameters that provide defaults for configuring
                the PyTorch Lightning trainer. An example might be setting ``max_epochs``.
                Overwrites the default parameters established by the base estimator.
            user_params: User-specific parameters that configure the PyTorch Lightning trainer.
                This dictionary should be passed through from a ``trainer_params`` init argument
                in subclasses. Overwrites any of the default parameters.
            overwrite_params: PyTorch Lightning trainer flags that need to be ensured
                independently of user-provided parameters. For example, ``max_epochs`` could be
                fixed to a certain value.
        """
        self.trainer_params_user = user_params
        self.trainer_params = {
            **dict(
                logger=False,
                log_every_n_steps=1,
                enable_progress_bar=logger.getEffectiveLevel() <= logging.INFO,
                enable_checkpointing=logger.getEffectiveLevel() <= logging.DEBUG,
                enable_model_summary=logger.getEffectiveLevel() <= logging.DEBUG,
            ),
            **(default_params or {}),
            **(user_params or {}),
            **(overwrite_params or {}),
        }
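
    # Example subclass (a minimal, hypothetical sketch; ``MyDensityEstimator``, ``MyModel`` and
    # the ``lr`` parameter are illustrative, not part of LightKit): init parameters are assigned
    # to attributes of the same name, ``trainer_params`` is forwarded to ``user_params``, and the
    # fitted attribute ``model_`` is declared as an annotation so that it is picked up by
    # :meth:`save` and :meth:`load`.
    #
    #     class MyDensityEstimator(BaseEstimator):
    #         model_: "MyModel"  # fitted attribute: trailing underscore + annotation
    #
    #         def __init__(self, lr: float = 1e-3, trainer_params: dict[str, Any] | None = None):
    #             super().__init__(
    #                 default_params=dict(max_epochs=100),
    #                 user_params=trainer_params,
    #             )
    #             self.lr = lr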

    def trainer(self, **kwargs: Any) -> pl.Trainer:
        """
        Returns the trainer as configured by the estimator. Typically, this method is only called
        by functions in the estimator.

        Args:
            kwargs: Additional arguments that override the trainer arguments registered in the
                initializer of the estimator.

        Returns:
            A fully initialized PyTorch Lightning trainer.

        Note:
            This function should be preferred over initializing the trainer directly. It ensures
            that the returned trainer correctly deals with LightKit components that may be
            introduced in the future.
        """
        return pl.Trainer(**{**self.trainer_params, **kwargs})
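
    # A minimal sketch of how a subclass might use ``trainer`` inside a ``fit`` method
    # (``fit``, ``MyLightningModule`` and ``MyDensityEstimator`` are hypothetical names):
    #
    #     def fit(self, loader: DataLoader[Any]) -> "MyDensityEstimator":
    #         module = MyLightningModule(lr=self.lr)
    #         # keyword arguments override the trainer parameters stored on the estimator
    #         self.trainer(max_epochs=10).fit(module, loader)
    #         self.model_ = module.model
    #         return self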

    # ---------------------------------------------------------------------------------------------
    # PERSISTENCE

    @property
    def persistent_attributes(self) -> list[str]:
        """
        Returns the list of fitted attributes that ought to be saved and loaded. By default, this
        encompasses all annotations.
        """
        return list(self.__annotations__.keys())
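
    # For illustration: a subclass declaring ``model_: "MyModel"`` and ``converged_: bool`` as
    # class-level annotations would report ``["model_", "converged_"]`` here. Subclasses can
    # override this property to restrict which attributes are persisted.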

    def save(self, path: PathType) -> None:
        """
        Saves the estimator to the provided directory. It saves the estimator's configuration to
        ``params.json`` (or ``params.pickle``, see :meth:`save_parameters`) and additional files
        for the fitted model (if applicable). For more information on the files saved for the
        fitted model or for more customization, look at :meth:`get_params` and
        :meth:`lightkit.nn.Configurable.save`.

        Args:
            path: The directory to which all files should be saved.

        Note:
            This method may be called regardless of whether the estimator has already been
            fitted.

        Attention:
            If the dictionary returned by :meth:`get_params` is not JSON-serializable, this
            method uses :mod:`pickle` which is not necessarily backwards-compatible.
        """
        path = Path(path)
        assert not path.exists() or path.is_dir(), "Estimators can only be saved to a directory."

        path.mkdir(parents=True, exist_ok=True)
        self.save_parameters(path)
        try:
            self.save_attributes(path)
        except NotFittedError:
            # In case attributes are not fitted, we just don't save them
            pass

    def save_parameters(self, path: Path) -> None:
        """
        Saves the parameters of this estimator. By default, it uses JSON and falls back to
        :mod:`pickle`. If subclasses use non-primitive types as parameters, they should overwrite
        this method.

        Typically, this method should not be called directly. It is called as part of
        :meth:`save`.

        Args:
            path: The directory to which the parameters should be saved.
        """
        params = self.get_params()
        try:
            data = json.dumps(params, indent=4)
            with (path / "params.json").open("w+") as f:
                f.write(data)
        except TypeError:
            warnings.warn(
                f"Failed to serialize parameters of `{self.__class__.__name__}` to JSON. "
                "Falling back to `pickle`."
            )
            with (path / "params.pickle").open("wb+") as f:
                pickle.dump(params, f)

    def save_attributes(self, path: Path) -> None:
        """
        Saves the fitted attributes of this estimator. By default, it uses JSON and falls back to
        :mod:`pickle`. Subclasses should overwrite this method if non-primitive attributes are
        fitted.

        Typically, this method should not be called directly. It is called as part of
        :meth:`save`.

        Args:
            path: The directory to which the fitted attributes should be saved.

        Raises:
            NotFittedError: If the estimator has not been fitted.
        """
        if len(self.persistent_attributes) == 0:
            return

        attributes = {
            attribute: getattr(self, attribute) for attribute in self.persistent_attributes
        }
        try:
            data = json.dumps(attributes, indent=4)
            with (path / "attributes.json").open("w+") as f:
                f.write(data)
        except TypeError:
            warnings.warn(
                f"Failed to serialize fitted attributes of `{self.__class__.__name__}` to JSON. "
                "Falling back to `pickle`."
            )
            with (path / "attributes.pickle").open("wb+") as f:
                pickle.dump(attributes, f)

    @classmethod
    def load(cls: type[E], path: PathType) -> E:
        """
        Loads the estimator and (if available) the fitted model. This method should only be
        expected to work for estimators that have previously been saved via :meth:`save`.

        Args:
            path: The directory from which to load the estimator.

        Returns:
            The loaded estimator, either fitted or not.
        """
        path = Path(path)
        assert path.is_dir(), "Estimators can only be loaded from a directory."

        estimator = cls.load_parameters(path)
        try:
            estimator.load_attributes(path)
        except FileNotFoundError:
            warnings.warn(f"Failed to read fitted attributes of `{cls.__name__}` at path '{path}'")

        return estimator
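
    # Example round trip (illustrative only; ``MyDensityEstimator``, ``loader`` and the output
    # directory are hypothetical):
    #
    #     estimator = MyDensityEstimator(lr=1e-2)
    #     estimator.fit(loader)
    #     estimator.save("out/estimator")
    #     restored = MyDensityEstimator.load("out/estimator")  # parameters + fitted attributes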

    @classmethod
    def load_parameters(cls: type[E], path: Path) -> E:
        """
        Initializes this estimator by loading its parameters. If subclasses overwrite
        :meth:`save_parameters`, this method should also be overwritten.

        Typically, this method should not be called directly. It is called as part of
        :meth:`load`.

        Args:
            path: The directory from which the parameters should be loaded.
        """
        json_path = path / "params.json"
        pickle_path = path / "params.pickle"
        if json_path.exists():
            with json_path.open() as f:
                params = json.load(f)
        else:
            with pickle_path.open("rb") as f:
                params = pickle.load(f)
        return cls(**params)

    def load_attributes(self, path: Path) -> None:
        """
        Loads the fitted attributes that are stored at the provided path. If subclasses overwrite
        :meth:`save_attributes`, this method should also be overwritten.

        Typically, this method should not be called directly. It is called as part of
        :meth:`load`.

        Args:
            path: The directory from which the fitted attributes should be loaded.

        Raises:
            FileNotFoundError: If no fitted attributes have been stored.
        """
        json_path = path / "attributes.json"
        pickle_path = path / "attributes.pickle"
        if json_path.exists():
            with json_path.open() as f:
                self.set_params(json.load(f))
        else:
            with pickle_path.open("rb") as f:
                self.set_params(pickle.load(f))
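
    # Sketch of a subclass override for non-primitive fitted attributes (assumes a tensor-valued
    # attribute ``means_`` and uses ``torch.save``/``torch.load``; both the attribute and the
    # choice of ``torch`` serialization are illustrative assumptions, not prescribed here):
    #
    #     def save_attributes(self, path: Path) -> None:
    #         torch.save(self.means_, path / "means.pt")
    #
    #     def load_attributes(self, path: Path) -> None:
    #         self.means_ = torch.load(path / "means.pt")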

    # ---------------------------------------------------------------------------------------------
    # SKLEARN INTERFACE

    def get_params(self, deep: bool = True) -> dict[str, Any]:  # pylint: disable=unused-argument
        """
        Returns the estimator's parameters as passed to the initializer.

        Args:
            deep: Ignored. For Scikit-learn compatibility.

        Returns:
            The mapping from init parameters to values.
        """
        signature = inspect.signature(self.__class__.__init__)
        parameters = [p.name for p in signature.parameters.values() if p.name != "self"]
        return {p: getattr(self, p) for p in parameters}

    def set_params(self: E, values: dict[str, Any]) -> E:
        """
        Sets the provided values on the estimator. The estimator on which this method is called
        is modified in place and returned as well.

        Args:
            values: The values to set.

        Returns:
            The estimator where the values have been set.
        """
        for key, value in values.items():
            setattr(self, key, value)
        return self

    def clone(self: E) -> E:
        """
        Clones the estimator without copying any fitted attributes. All parameters of this
        estimator are copied via :func:`copy.deepcopy`, except for nested estimators, which are
        cloned via their own :meth:`clone`.

        Returns:
            The cloned estimator with the same parameters.
        """
        return self.__class__(
            **{
                name: param.clone() if isinstance(param, BaseEstimator) else copy.deepcopy(param)
                for name, param in self.get_params().items()
            }
        )
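
    # Illustrative usage of the Scikit-learn-style interface (``MyDensityEstimator`` is the
    # hypothetical subclass sketched above):
    #
    #     estimator = MyDensityEstimator(lr=1e-2)
    #     estimator.get_params()             # mapping of init parameters to current values
    #     estimator.set_params({"lr": 0.1})  # modifies in place and returns the estimator
    #     fresh = estimator.clone()          # same parameters, no fitted attributes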

    # ---------------------------------------------------------------------------------------------
    # SPECIAL METHODS

    def __getattr__(self, key: str) -> Any:
        if key in self.__dict__:
            return self.__dict__[key]
        if key.endswith("_") and not key.endswith("__") and key in self.__annotations__:
            raise NotFittedError(f"`{self.__class__.__name__}` has not been fitted yet")
        raise AttributeError(
            f"Attribute `{key}` does not exist on type `{self.__class__.__name__}`."
        )

    # ---------------------------------------------------------------------------------------------
    # PRIVATE

    def _num_batches_per_epoch(self, loader: DataLoader[Any]) -> int:
        """
        Returns the number of batches that are run for the given data loader across all processes
        when using the trainer provided by the :meth:`trainer` method. If ``n`` processes run
        ``k`` batches each, this method returns ``k * n``.
        """
        trainer = self.trainer()
        num_batches = len(loader)  # type: ignore
        kwargs = trainer.distributed_sampler_kwargs
        if kwargs is None:
            return num_batches
        return num_batches * kwargs.get("num_replicas", 1)