BaseEstimator

class lightkit.BaseEstimator(*, default_params=None, user_params=None, overwrite_params=None)[source]

Bases: abc.ABC

Base estimator class that all estimators should inherit from. This base estimator does not enforce the implementation of any methods, but users should follow the Scikit-learn developer guide on implementing custom estimators. Some of the methods mentioned in that guide are already implemented in this base estimator and work as expected, provided that the aspects listed below are followed.

In contrast to Scikit-learn's estimators, this estimator is strongly typed and integrates well with PyTorch Lightning. Most importantly, it provides the trainer() method, which returns a fully configured trainer to be used by other methods. The configuration is stored in the estimator and can be adjusted by passing parameters to default_params, user_params and overwrite_params when calling super().__init__(). By default, the base estimator sets the following flags:

  • Logging is disabled (logger=False).

  • Logging is performed at every step (log_every_n_steps=1).

  • The progress bar is only enabled (enable_progress_bar) if LightKit's logging level is INFO or more verbose.

  • Checkpointing is only enabled (enable_checkpointing) if LightKit's logging level is DEBUG or more verbose.

  • The model summary is only enabled (enable_model_summary) if LightKit's logging level is DEBUG or more verbose.

Note that the logging level can be changed via lightkit.set_logging_level().
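The three parameter dictionaries are merged with increasing precedence: the base estimator's defaults are overridden by default_params, which are overridden by user_params, which in turn are overridden by overwrite_params. A minimal sketch of this merging behaviour (an illustration of the documented precedence, not lightkit's actual implementation):

```python
# Precedence when configuring the trainer, lowest to highest:
# base defaults < default_params < user_params < overwrite_params.
base_defaults = {"logger": False, "log_every_n_steps": 1}
default_params = {"max_epochs": 100}                     # estimator-specific defaults
user_params = {"max_epochs": 50, "accelerator": "gpu"}   # passed via trainer_params
overwrite_params = {"max_epochs": 100}                   # enforced unconditionally

# Later dictionaries win on key collisions.
trainer_kwargs = {**base_defaults, **default_params, **user_params, **overwrite_params}
# max_epochs is 100 again: overwrite_params takes precedence over user_params,
# while the user-supplied accelerator setting is kept.
```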

When subclassing this base estimator, users should take care of the following aspects:

  • All parameters passed to the initializer must be assigned to attributes with the same name. This ensures that get_params() and set_params() work as expected. Parameters that are passed to the trainer must be named trainer_params and should not be manually assigned to an attribute (this is handled by the base estimator).

  • Fitted attributes must (1) have a single trailing underscore (e.g. model_) and (2) be defined as annotations. This ensures that save() and load() properly manage the estimator's persistence.
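The two conventions above can be sketched with a simplified stand-in for the base estimator. Here, get_params() inspects the initializer's signature and returns same-name attributes, which is why every init argument must be stored under its own name; lightkit's real implementation may differ, and MyEstimator with its parameters is a hypothetical example:

```python
import inspect

class SketchBaseEstimator:
    """Simplified stand-in illustrating why init arguments must be stored
    under their own names: get_params() reads the initializer's signature
    and looks up attributes with matching names."""

    def get_params(self):
        sig = inspect.signature(type(self).__init__)
        names = [n for n in sig.parameters if n not in ("self", "trainer_params")]
        return {n: getattr(self, n) for n in names}

class MyEstimator(SketchBaseEstimator):
    # Fitted attributes: trailing underscore AND declared as annotations,
    # so that persistence (save/load) can discover them.
    mean_: float

    def __init__(self, learning_rate=0.01, num_components=3, trainer_params=None):
        self.learning_rate = learning_rate    # same name as the init argument
        self.num_components = num_components  # same name as the init argument
        # trainer_params is NOT assigned manually; the real base estimator
        # handles it when it is forwarded to super().__init__().

est = MyEstimator(learning_rate=0.1)
est.get_params()  # {"learning_rate": 0.1, "num_components": 3}
```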

Parameters
  • default_params (dict[str, Any] | None) -- Estimator-specific parameters that provide defaults for configuring the PyTorch Lightning trainer. An example might be setting max_epochs. Overwrites the default parameters established by the base estimator.

  • user_params (dict[str, Any] | None) -- User-specific parameters that configure the PyTorch Lightning trainer. This dictionary should be passed through from a trainer_params init argument in subclasses. Overwrites any of the default parameters.

  • overwrite_params (dict[str, Any] | None) -- PyTorch Lightning trainer flags that are enforced regardless of user-provided parameters. For example, max_epochs could be fixed to a certain value.

Methods

  clone()
    Clones the estimator without copying any fitted attributes.

  get_params()
    Returns the estimator's parameters as passed to the initializer.

  load()
    Loads the estimator and (if available) the fitted model.

  load_attributes()
    Loads the fitted attributes that are stored at the fitted path.

  load_parameters()
    Initializes this estimator by loading its parameters.

  save()
    Saves the estimator to the provided directory.

  save_attributes()
    Saves the fitted attributes of this estimator.

  save_parameters()
    Saves the parameters of this estimator.

  set_params()
    Sets the provided values on the estimator.

  trainer()
    Returns the trainer as configured by the estimator.


Attributes

  persistent_attributes
    Returns the list of fitted attributes that ought to be saved and loaded.
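Because fitted attributes are required to carry a trailing underscore and be declared as class annotations, the set of persistent attributes can be derived from the annotations alone. A sketch of that discovery logic (illustrative only; the class and attribute names are made up and this is not lightkit's actual code):

```python
class Example:
    mean_: float         # fitted attribute: trailing underscore + annotation
    weights_: list       # also discovered for persistence
    num_components: int  # no trailing underscore -> treated as a parameter

# Collect annotated names ending in a single trailing underscore.
persistent = [
    name
    for name in Example.__annotations__
    if name.endswith("_") and not name.endswith("__")
]
# persistent == ["mean_", "weights_"]
```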