# Source code for lightkit.nn.configurable

from __future__ import annotations
import dataclasses
import json
from pathlib import Path
from typing import Any, Generic
import torch
from torch import jit, nn
from lightkit.utils import get_generic_type, PathType
from ._protocols import C_co, ConfigurableModule, M


class Configurable(Generic[C_co]):
    """
    A mixin for any PyTorch module to extend it with storage capabilities. By passing a single
    configuration object to the initializer, this mixin allows the module to be extended with
    :meth:`save` and :meth:`load` methods. These methods allow to (1) save the model along with
    its configuration (i.e. architecture) and (2) to load the model without instantiating an
    instance of the class.
    """

    def __init__(self, config: C_co, *args: Any, **kwargs: Any):
        """
        Args:
            config: The configuration of the architecture. Must be a dataclass instance.
            args: Positional arguments that ought to be passed to the superclass.
            kwargs: Keyword arguments that ought to be passed to the superclass.
        """
        # NOTE: these checks use `assert` and are therefore skipped when Python runs with -O.
        assert dataclasses.is_dataclass(config), "Configuration is not a dataclass."
        assert isinstance(
            self, nn.Module
        ), "Configurable mixin can only be applied to subclasses of `torch.nn.Module`."
        # The superclass (an `nn.Module`) must be initialized before attributes are assigned.
        super().__init__(*args, **kwargs)
        self.config = config

    @jit.unused
    def save_config(self: ConfigurableModule[C_co], path: PathType) -> None:
        """
        Saves only the module's configuration to a file named ``config.json`` in the specified
        directory. This method should not be called directly. It is called as part of
        :meth:`save`.

        Args:
            path: The directory to which to save the configuration file. The directory may or may
                not exist but no parent directories are created.
        """
        path = Path(path)
        path.mkdir(parents=False, exist_ok=True)
        with (path / "config.json").open("w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(self.config), f, indent=4)

    @jit.unused
    def save(self: ConfigurableModule[C_co], path: PathType, compile_model: bool = False) -> None:
        """
        Saves the module's configuration and parameters to files in the specified directory. It
        creates two files, namely ``config.json`` and ``parameters.pt`` which contain the
        configuration and parameters, respectively.

        Args:
            path: The directory to which to save the configuration and parameter files. The
                directory may or may not exist; missing parent directories are created as well.
            compile_model: Whether the model should be compiled via TorchScript. An additional file
                called ``model.ptc`` will then be stored. Note that you can simply load the compiled
                model via :meth:`torch.jit.load` at a later point.
        """
        path = Path(path)
        assert not path.exists() or path.is_dir(), "Modules can only be saved to a directory."
        path.mkdir(parents=True, exist_ok=True)

        # Store the model's configuration and all parameters
        self.save_config(path)
        with (path / "parameters.pt").open("wb") as f:
            torch.save(self.state_dict(), f)  # pylint: disable=no-member

        # Optionally store the compiled model
        if compile_model:
            compiled_model = jit.script(self)
            with (path / "model.ptc").open("wb") as f:
                jit.save(compiled_model, f)

    @classmethod
    def load_config(cls: type[M], path: PathType) -> M:
        """
        Loads the module by reading the configuration. Parameters are initialized randomly as if
        the module would be initialized from scratch. This method should not be called directly.
        It is called as part of :meth:`load`.

        Args:
            path: The directory which contains the ``config.json`` to load.

        Returns:
            The loaded model.

        Attention:
            This method must only be called if the module is initializable solely from a
            configuration.
        """
        path = Path(path)
        # The configuration class is the type argument the subclass passed to `Configurable[...]`.
        config_cls = get_generic_type(cls, Configurable)
        with (path / "config.json").open("r", encoding="utf-8") as f:
            config_args = json.load(f)
        config = _init_config(config_cls, config_args)
        return cls(config)  # type: ignore

    @classmethod
    def load(cls: type[M], path: PathType) -> M:
        """
        Loads the module's configurations and parameters from files in the specified directory at
        first. Then, it initializes the model with the stored configurations and loads the
        parameters. This method is typically used after calling :meth:`save` on the model.

        Args:
            path: The directory which contains the ``config.json`` and ``parameters.pt`` files to
                load.

        Returns:
            The loaded model.

        Note:
            You can load modules even after you changed their configuration class. The only
            requirement is that any new configuration options have a default value.
        """
        path = Path(path)
        assert path.is_dir(), "Modules can only be loaded from a directory."

        # Initialize the model from its configuration (delegating so that overrides are honored)
        model = cls.load_config(path)  # type: ignore

        # Load the parameters. Tensors are mapped to the CPU so that checkpoints written on a GPU
        # machine can also be read on CPU-only hosts; `load_state_dict` copies the values into the
        # model's existing parameters either way.
        # NOTE(review): `torch.load` unpickles the file, so only load directories you trust.
        with (path / "parameters.pt").open("rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        model.load_state_dict(state_dict)  # pylint: disable=no-member
        return model

    def clone(self: M, copy_parameters: bool = True) -> M:
        """
        Clones this module by initializing another module with the same configuration.

        Args:
            copy_parameters: Whether to copy this module's parameters or initialize the new module
                with random parameters.

        Returns:
            The cloned module.

        Note:
            The configuration object itself is shared (not copied) between this module and the
            clone. The clone is built through the class initializer and no ``.to(...)`` is applied,
            so it is not moved to this module's device.
        """
        cloned = self.__class__(self.config)  # type: ignore
        if copy_parameters:
            cloned.load_state_dict(self.state_dict())  # pylint: disable=no-member
        return cloned
def _init_config(target: type[Any], args: dict[str, Any]) -> Any:
    """
    Instantiates the configuration dataclass ``target`` from a plain dictionary, e.g. one obtained
    via :func:`dataclasses.asdict` and a JSON round trip. Fields whose type is itself a dataclass
    are instantiated recursively.

    Args:
        target: The dataclass type to instantiate.
        args: Mapping from field names of ``target`` to their values. Fields missing from the
            mapping are left to the dataclass initializer (i.e. they need a default value).

    Returns:
        The ``target`` instance built from ``args``.

    Raises:
        KeyError: If ``args`` contains a key that is not a field of ``target``.
    """
    import typing  # pylint: disable=import-outside-toplevel

    declared_types = {field.name: field.type for field in dataclasses.fields(target)}
    # With postponed evaluation of annotations (``from __future__ import annotations``),
    # ``Field.type`` is a plain string that ``dataclasses.is_dataclass`` never recognizes, so
    # nested configurations would silently stay dictionaries. Resolve the annotations where
    # possible and fall back to the raw declarations if they cannot be evaluated.
    try:
        resolved_types = typing.get_type_hints(target)
    except (AttributeError, NameError, SyntaxError, TypeError):
        resolved_types = {}

    result = {}
    for key, val in args.items():
        # `declared_types[key]` raises KeyError for options the dataclass does not define.
        field_type = resolved_types.get(key, declared_types[key])
        # Only recurse into dictionaries; `None` or an already-built instance is kept as is.
        if dataclasses.is_dataclass(field_type) and isinstance(val, dict):
            result[key] = _init_config(field_type, val)
        else:
            result[key] = val
    return target(**result)