BaseEstimator¶
- class lightkit.BaseEstimator(*, default_params=None, user_params=None, overwrite_params=None)[source]¶
Bases: abc.ABC
Base estimator class that all estimators should inherit from. This base estimator does not enforce the implementation of any methods, but users should follow the Scikit-learn guide on implementing estimators. Some of the methods mentioned in this guide are already implemented in this base estimator and work as expected if the aspects listed below are followed.
In contrast to Scikit-learn's estimators, this estimator is strongly typed and integrates well with PyTorch Lightning. Most importantly, it provides the trainer() method, which returns a fully configured trainer to be used by other methods. The configuration is stored in the estimator and can be adjusted by passing parameters to default_params, user_params and overwrite_params when calling super().__init__(). By default, the base estimator sets the following flags:

- Logging is disabled (logger=False).
- Logging is performed at every step (log_every_n_steps=1).
- The progress bar is only enabled (enable_progress_bar) if LightKit's logging level is INFO or more verbose.
- Checkpointing is only enabled (enable_checkpointing) if LightKit's logging level is DEBUG or more verbose.
- The model summary is only enabled (enable_model_summary) if LightKit's logging level is DEBUG or more verbose.
Note that the logging level can be changed via lightkit.set_logging_level().

When subclassing this base estimator, users should take care of the following aspects:

- All parameters passed to the initializer must be assigned to attributes with the same name. This ensures that get_params() and set_params() work as expected. Parameters that are passed to the trainer must be named trainer_params and should not be manually assigned to an attribute (this is handled by the base estimator).
- Fitted attributes must (1) have a single trailing underscore (e.g. model_) and (2) be defined as annotations. This ensures that save() and load() properly manage the estimator's persistence.
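The two requirements above can be sketched as follows. This is an illustrative subclass only: MyEstimator and its num_components parameter are hypothetical, and a minimal stdlib stand-in replaces lightkit.BaseEstimator so the snippet is self-contained; in practice, inherit from the real base class.

```python
# Illustrative sketch of the subclassing rules above. The real base class is
# lightkit.BaseEstimator; this stand-in only mimics the documented behavior.
from abc import ABC


class BaseEstimator(ABC):
    """Stand-in for lightkit.BaseEstimator (assumption, for illustration)."""

    def __init__(self, *, default_params=None, user_params=None, overwrite_params=None):
        # The base estimator stores the trainer configuration; user params
        # override estimator defaults, overwrite params override everything.
        self.trainer_params = {
            **(default_params or {}),
            **(user_params or {}),
            **(overwrite_params or {}),
        }


class MyEstimator(BaseEstimator):
    # Fitted attribute: single trailing underscore, declared as an annotation
    # so that save() and load() can discover it.
    model_: list

    def __init__(self, num_components: int = 2, *, trainer_params=None):
        super().__init__(
            default_params={"max_epochs": 100},  # estimator-specific default
            user_params=trainer_params,          # passed through from the init arg
        )
        # Init parameters must be assigned to attributes of the same name so
        # that get_params() and set_params() work as expected.
        self.num_components = num_components


est = MyEstimator(num_components=3, trainer_params={"max_epochs": 5})
print(est.num_components)                # 3
print(est.trainer_params["max_epochs"])  # 5: user params override the default
```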
- Parameters
  - default_params (dict[str, Any] | None) -- Estimator-specific parameters that provide defaults for configuring the PyTorch Lightning trainer. An example might be setting max_epochs. Overwrites the default parameters established by the base estimator.
  - user_params (dict[str, Any] | None) -- User-specific parameters that configure the PyTorch Lightning trainer. This dictionary should be passed through from a trainer_params init argument in subclasses. Overwrites any of the default parameters.
  - overwrite_params (dict[str, Any] | None) -- PyTorch Lightning trainer flags that need to be ensured independently of user-provided parameters. For example, max_epochs could be fixed to a certain value.
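A sketch of how the three dictionaries interact, assuming plain dict-merge semantics in which later sources take precedence (the values below are hypothetical):

```python
# Documented precedence, from weakest to strongest:
# base-estimator defaults < default_params < user_params < overwrite_params.
base_defaults = {"logger": False, "log_every_n_steps": 1}
default_params = {"max_epochs": 100}  # estimator-specific default
user_params = {"max_epochs": 10}      # forwarded from a trainer_params init arg
overwrite_params = {"max_epochs": 1}  # enforced regardless of user input

trainer_kwargs = {
    **base_defaults,
    **default_params,
    **user_params,
    **overwrite_params,
}
print(trainer_kwargs["max_epochs"])  # 1: overwrite_params wins
print(trainer_kwargs["logger"])      # False: base default is untouched
```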
Methods

- Clones the estimator without copying any fitted attributes.
- Returns the estimator's parameters as passed to the initializer.
- Loads the estimator and (if available) the fitted model.
- Loads the fitted attributes that are stored at the fitted path.
- Initializes this estimator by loading its parameters.
- Saves the estimator to the provided directory.
- Saves the fitted attributes of this estimator.
- Saves the parameters of this estimator.
- Sets the provided values on the estimator.
- Returns the trainer as configured by the estimator.
Inherited Methods
Attributes

- Returns the list of fitted attributes that ought to be saved and loaded.