diff --git a/chebai/models/base.py b/chebai/models/base.py index cb254570..263d6f93 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -49,6 +49,11 @@ def __init__( if exclude_hyperparameter_logging is None: exclude_hyperparameter_logging = tuple() self.criterion = criterion + assert out_dim is not None, "out_dim must be specified" + assert input_dim is not None, "input_dim must be specified" + self.out_dim = out_dim + self.input_dim = input_dim + self.save_hyperparameters( ignore=[ "criterion", @@ -59,10 +64,8 @@ def __init__( ] ) - self.out_dim = out_dim - self.input_dim = input_dim - assert out_dim is not None, "out_dim must be specified" - assert input_dim is not None, "input_dim must be specified" + self.hparams["out_dim"] = out_dim + self.hparams["input_dim"] = input_dim if optimizer_kwargs: self.optimizer_kwargs = optimizer_kwargs