diff --git a/chebai/cli.py b/chebai/cli.py
index 1aaba53c..fbbc5d39 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -31,6 +31,95 @@ def __init__(self, *args, **kwargs):
         """
         super().__init__(trainer_class=CustomTrainer, *args, **kwargs)
 
+    def before_instantiate_classes(self) -> None:
+        """
+        Hook called before instantiating classes (Lightning 2.6+ compatible).
+        Instantiate the datamodule early to compute num_labels and feature_vector_size.
+        """
+        # Get the current subcommand config (fit, test, validate, predict, etc.)
+        subcommand = self.config.get(self.config["subcommand"])
+
+        if not (subcommand and "data" in subcommand):
+            return
+
+        data_config = subcommand["data"]
+        if "class_path" not in data_config:
+            return
+
+        # Import and instantiate the datamodule class
+        module_path, class_name = data_config["class_path"].rsplit(".", 1)
+        import importlib
+        module = importlib.import_module(module_path)
+        data_class = getattr(module, class_name)
+
+        # Instantiate with init_args
+        init_args = data_config.get("init_args", {})
+        data_instance = data_class(**init_args)
+
+        # Call prepare_data and setup to initialize dynamic properties
+        # We need to check the private attribute to avoid calling the property which has an assert
+        if hasattr(data_instance, "_num_of_labels") and data_instance._num_of_labels is None:
+            data_instance.prepare_data()
+            data_instance.setup()
+
+        num_labels = data_instance.num_of_labels
+        feature_vector_size = data_instance.feature_vector_size
+
+        # Update model init args
+        self._update_model_args(subcommand, num_labels, feature_vector_size)
+
+        # Update trainer callbacks
+        self._update_trainer_callbacks(subcommand, num_labels)
+
+    def _update_model_args(self, subcommand: dict, num_labels: int, feature_vector_size: int) -> None:
+        """Helper method to update model initialization arguments."""
+        if "model" not in subcommand or "init_args" not in subcommand["model"]:
+            return
+
+        model_init_args = subcommand["model"]["init_args"]
+
+        # Set out_dim and input_dim if not already set
+        if model_init_args.get("out_dim") is None:
+            model_init_args["out_dim"] = num_labels
+        if model_init_args.get("input_dim") is None:
+            model_init_args["input_dim"] = feature_vector_size
+
+        # Update metrics num_labels in all metrics configurations
+        for kind in ("train", "val", "test"):
+            metrics_key = f"{kind}_metrics"
+            metrics_config = model_init_args.get(metrics_key)
+            if metrics_config:
+                self._update_metrics_num_labels(metrics_config, num_labels)
+
+    def _update_metrics_num_labels(self, metrics_config: dict, num_labels: int) -> None:
+        """Helper method to update num_labels in metrics configuration."""
+        init_args = metrics_config.get("init_args", {})
+        metrics_dict = init_args.get("metrics", {})
+
+        for metric_name, metric_config in metrics_dict.items():
+            metric_init_args = metric_config.get("init_args", {})
+            if "num_labels" in metric_init_args and metric_init_args["num_labels"] is None:
+                metric_init_args["num_labels"] = num_labels
+
+    def _update_trainer_callbacks(self, subcommand: dict, num_labels: int) -> None:
+        """Helper method to update num_labels in trainer callbacks."""
+        if "trainer" not in subcommand or "callbacks" not in subcommand["trainer"]:
+            return
+
+        callbacks = subcommand["trainer"]["callbacks"]
+
+        if isinstance(callbacks, list):
+            for callback in callbacks:
+                self._set_callback_num_labels(callback, num_labels)
+        else:
+            self._set_callback_num_labels(callbacks, num_labels)
+
+    def _set_callback_num_labels(self, callback: dict, num_labels: int) -> None:
+        """Helper method to set num_labels in a single callback configuration."""
+        init_args = callback.get("init_args", {})
+        if "num_labels" in init_args and init_args["num_labels"] is None:
+            init_args["num_labels"] = num_labels
+
     def add_arguments_to_parser(self, parser: LightningArgumentParser):
         """
         Link input parameters that are used by different classes (e.g. number of labels)
@@ -38,27 +127,16 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
 
         Args:
             parser (LightningArgumentParser): Argument parser instance.
+
+        Note:
+            In Lightning 2.6+, we use model.init_args.out_dim as the source for linking
+            because it's set during before_instantiate_classes() from the computed num_labels.
+            This avoids issues with linking from data.num_of_labels which is a property
+            that requires the datamodule to be instantiated.
         """
 
-        def call_data_methods(data: Type[XYBaseDataModule]):
-            if data._num_of_labels is None:
-                data.prepare_data()
-                data.setup()
-            return data.num_of_labels
-
-        parser.link_arguments(
-            "data",
-            "model.init_args.out_dim",
-            apply_on="instantiate",
-            compute_fn=call_data_methods,
-        )
-
-        parser.link_arguments(
-            "data.feature_vector_size",
-            "model.init_args.input_dim",
-            apply_on="instantiate",
-        )
-
+        # Link num_labels (via out_dim) to metrics configurations
+        # out_dim is set in before_instantiate_classes() from data.num_of_labels
         for kind in ("train", "val", "test"):
             for average in (
                 "micro-f1",
@@ -70,31 +148,17 @@ def call_data_methods(data: Type[XYBaseDataModule]):
                 "rmse",
                 "r2",
             ):
-                # When using lightning > 2.5.1 then need to uncomment all metrics that are not used
-                # for average in ("mse", "rmse","r2"): # for regression
-                # for average in ("f1", "roc-auc"): # for binary classification
-                # for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
-                # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
                 parser.link_arguments(
-                    "data.num_of_labels",
+                    "model.init_args.out_dim",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
-                    apply_on="instantiate",
                 )
 
+        # Link out_dim to trainer callbacks
         parser.link_arguments(
-            "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
+            "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
         )
 
-        # parser.link_arguments(
-        #     "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
-        # )
-        # parser.link_arguments(
-        #     "data", "model.init_args.criterion.init_args.data_extractor"
-        # )
-        # parser.link_arguments(
-        #     "data.init_args.chebi_version",
-        #     "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
-        # )
+        # Link datamodule to criterion's data extractor
         parser.link_arguments(
             "data", "model.init_args.criterion.init_args.data_extractor"
         )
diff --git a/pyproject.toml b/pyproject.toml
index b3652b00..f407f340 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "transformers",
     "pysmiles==1.1.2",
     "rdkit==2024.3.6",
-    "lightning==2.5.1",
+    "lightning==2.6.1",
 ]
 
 [project.optional-dependencies]
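For illustration, the standalone sketch below (not part of the patch) mirrors the back-fill that _update_metrics_num_labels performs on the parsed config: any metric whose num_labels is left as null in the YAML arrives as None after parsing and is replaced by the value computed from the datamodule, while explicitly set values are left untouched. The torchmetrics class paths and the label counts are illustrative assumptions; only the dict shapes and the fill logic come from the patch.

# Standalone sketch of the num_labels back-fill done by
# ChebaiCLI._update_metrics_num_labels (logic copied from the patch above).
# The torchmetrics class paths and the counts 854/500 are assumed examples.
metrics_config = {
    "class_path": "torchmetrics.MetricCollection",
    "init_args": {
        "metrics": {
            "micro-f1": {
                "class_path": "torchmetrics.classification.MultilabelF1Score",
                "init_args": {"num_labels": None, "average": "micro"},
            },
            "macro-f1": {
                "class_path": "torchmetrics.classification.MultilabelF1Score",
                # Explicitly configured value; the helper must not overwrite it.
                "init_args": {"num_labels": 500, "average": "macro"},
            },
        }
    },
}

num_labels = 854  # stand-in for data_instance.num_of_labels

# Same traversal as the helper: fill only keys that are present and still None.
for metric_config in metrics_config["init_args"]["metrics"].values():
    metric_init_args = metric_config.get("init_args", {})
    if "num_labels" in metric_init_args and metric_init_args["num_labels"] is None:
        metric_init_args["num_labels"] = num_labels

assert metrics_config["init_args"]["metrics"]["micro-f1"]["init_args"]["num_labels"] == 854
assert metrics_config["init_args"]["metrics"]["macro-f1"]["init_args"]["num_labels"] == 500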