diff --git a/.gitignore b/.gitignore index be471dd0..f1286a8c 100644 --- a/.gitignore +++ b/.gitignore @@ -175,7 +175,7 @@ chebai.egg-info lightning_logs logs .isort.cfg -/.vscode +/.vscode/launch.json *.out *.err diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..d1e06324 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,11 @@ +{ + "recommendations": [ + "ms-python.python", + "ms-python.vscode-pylance", + "charliermarsh.ruff", + "usernamehw.errorlens" + ], + "unwantedRecommendations": [ + "ms-python.vscode-python2" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..dbebc3c5 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,16 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true, + "python.analysis.typeCheckingMode": "basic", + "editor.formatOnSave": true, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff" + } +} diff --git a/README.md b/README.md index 7672bc28..401c7324 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf ``` A command with additional options may look like this: ``` -python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000 +python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000 ``` ### Fine-tuning for classification tasks, e.g. Toxicity prediction @@ -78,11 +78,16 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--save_to=[path-to-output]] ``` -The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the -one row for each SMILES string and one column for each class. -The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs. + +* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`). + +* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line. + +* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. Default path will be the current working directory with file name as `predictions.csv`. + +> **Note**: Newly created checkpoints after PR #148 must be used for this prediction pipeline. The list of ChEBI classes (classification labels) used during training is stored in new checkpoints, which are required. ## Evaluation @@ -96,7 +101,7 @@ An example notebook is provided at `tutorials/eval_model_basic.ipynb`. Alternatively, you can evaluate the model via the CLI: ```bash -python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file] +python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce_weighted.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file] ``` > **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once. diff --git a/chebai/cli.py b/chebai/cli.py index 1aaba53c..d65dd51e 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]): apply_on="instantiate", ) + parser.link_arguments( + "data.classes_txt_file_path", + "model.init_args.classes_txt_file_path", + apply_on="instantiate", + ) + for kind in ("train", "val", "test"): for average in ( "micro-f1", @@ -111,8 +117,6 @@ def subcommands() -> Dict[str, Set[str]]: "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, "validate": {"model", "dataloaders", "datamodule"}, "test": {"model", "dataloaders", "datamodule"}, - "predict": {"model", "dataloaders", "datamodule"}, - "predict_from_file": {"model"}, } diff --git a/chebai/models/base.py b/chebai/models/base.py index 82d84033..df060e9a 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -40,6 +40,7 @@ def __init__( pass_loss_kwargs: bool = True, optimizer_kwargs: Optional[Dict[str, Any]] = None, exclude_hyperparameter_logging: Optional[Iterable[str]] = None, + classes_txt_file_path: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -47,8 +48,8 @@ def __init__( if exclude_hyperparameter_logging is None: exclude_hyperparameter_logging = tuple() self.criterion = criterion - assert out_dim is not None, "out_dim must be specified" - assert input_dim is not None, "input_dim must be specified" + assert out_dim is not None and out_dim > 0, "out_dim must be specified" + assert input_dim is not None and input_dim > 0, "input_dim must be specified" self.out_dim = out_dim self.input_dim = input_dim print( @@ -62,6 +63,7 @@ def __init__( "train_metrics", "val_metrics", "test_metrics", + "classes_txt_file_path", *exclude_hyperparameter_logging, ] ) @@ -78,6 +80,23 @@ def __init__( self.test_metrics = test_metrics self.pass_loss_kwargs = pass_loss_kwargs + self.classes_txt_file_path = classes_txt_file_path + + # During prediction `classes_txt_file_path` is set to None + if classes_txt_file_path is not None: + with open(classes_txt_file_path, "r") as f: + self.labels_list = [cls.strip() for cls in f.readlines()] + assert len(self.labels_list) > 0, "Class labels list is empty." + assert len(self.labels_list) == out_dim, ( + f"Number of class labels ({len(self.labels_list)}) does not match " + f"the model output dimension ({out_dim})." + ) + + def on_save_checkpoint(self, checkpoint): + if self.classes_txt_file_path is not None: + # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere + checkpoint["classification_labels"] = self.labels_list + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a # different loss) @@ -100,7 +119,7 @@ def __init_subclass__(cls, **kwargs): def _get_prediction_and_labels( self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): + ) -> tuple[torch.Tensor, torch.Tensor]: """ Gets the predictions and labels from the model output. @@ -151,7 +170,7 @@ def _process_for_loss( model_output: torch.Tensor, labels: torch.Tensor, loss_kwargs: Dict[str, Any], - ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]): + ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: """ Processes the data for loss computation. @@ -237,7 +256,7 @@ def predict_step( Returns: Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. """ - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + return self._execute(batch, batch_idx, log=False) def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 9ff40748..e295a3ed 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] - # filter for missing features in resulting data, keep features length below token limit - data = [ - val - for val in data - if val["features"] is not None - and ( - self.n_token_limit is None or len(val["features"]) <= self.n_token_limit - ) - ] + data = [val for val in data if self._filter_to_token_limit(val)] return data + def _filter_to_token_limit(self, data_instance: dict) -> bool: + # filter for missing features in resulting data, keep features length below token limit + if data_instance["features"] is not None and ( + self.n_token_limit is None + or len(data_instance["features"]) <= self.n_token_limit + ): + return True + return False + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -401,22 +402,84 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs - ) -> Union[DataLoader, List[DataLoader]]: + self, + smiles_list: List[str], + model_hparams: dict, + **kwargs, + ) -> tuple[DataLoader, list[int]]: """ Returns the predict DataLoader. Args: - *args: Additional positional arguments (unused). + smiles_list (List[str]): List of SMILES strings to predict. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. **kwargs: Additional keyword arguments, passed to dataloader(). Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. + tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data, valid_indices = self._process_input_for_prediction( + smiles_list, model_hparams + ) + return ( + DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ), + valid_indices, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: dict + ) -> tuple[list, list]: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + model_hparams (dict): Model hyperparameters. + Some prediction pre-processing pipelines may require these. + + Returns: + tuple[list, list]: Processed input data and valid indices. + """ + data, valid_indices = [], [] + num_of_labels = int(model_hparams["out_dim"]) + self._dummy_labels: list = list(range(1, num_of_labels + 1)) + + for idx, smiles in enumerate(smiles_list): + result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) + if result is None or result["features"] is None: + continue + if not self._filter_to_token_limit(result): + continue + data.append(result) + valid_indices.append(idx) + + return data, valid_indices + + def _preprocess_smiles_for_pred( + self, idx: int, smiles: str, model_hparams: Optional[dict] = None + ) -> dict: + """Preprocess prediction data.""" + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. + return self.reader.to_data( + { + "id": f"smiles_{idx}", + "features": smiles, + "labels": self._dummy_labels, + } + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: @@ -563,6 +626,19 @@ def raw_file_names_dict(self) -> dict: """ raise NotImplementedError + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in following places: + # - chebai/result/prediction.py: to load class names for csv columns names + # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` + return os.path.join(self.processed_dir_main, "classes.txt") + class MergedDataset(XYBaseDataModule): MERGED = [] @@ -1189,7 +1265,8 @@ def _retrieve_splits_from_csv(self) -> None: print(f"Applying label filter from {self.apply_label_filter}...") with open(self.apply_label_filter, "r") as f: label_filter = [line.strip() for line in f] - with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + + with open(self.classes_txt_file_path, "r") as cf: classes = [line.strip() for line in cf] # reorder labels old_labels = np.stack(df_data["labels"]) diff --git a/chebai/preprocessing/migration/migrate_checkpoints.py b/chebai/preprocessing/migration/migrate_checkpoints.py new file mode 100644 index 00000000..3f9f5358 --- /dev/null +++ b/chebai/preprocessing/migration/migrate_checkpoints.py @@ -0,0 +1,56 @@ +""" +Docstring for chebai.preprocessing.migration.migrate_checkpoints + +This script migrates lightning checkpoints created before python-chebai +version 1.2.1 to be compatible with the new version. + +The main change is the addition of a new key "classification_labels" in the checkpoint, +which is required for the new version of python-chebai from version 1.2.1 onwards. + +For more details, see the pull request: https://github.com/ChEB-AI/python-chebai/pulls +""" + +import sys + +import torch + + +def add_class_labels_to_checkpoint(input_path, classes_file_path): + print(f"Loading checkpoint from {input_path}...") + print(f"Loading class labels from {classes_file_path}...") + + with open(classes_file_path, "r") as f: + class_labels = [line.strip() for line in f.readlines()] + + assert len(class_labels) > 0, "The classes file is empty." + + # 1. Load the checkpoint + checkpoint = torch.load( + input_path, map_location=torch.device("cpu"), weights_only=False + ) + + if "classification_labels" in checkpoint: + print( + "Warning: 'classification_labels' key already exists in the checkpoint and will be overwritten." + ) + + # 2. Add your custom key/value pair + checkpoint["classification_labels"] = class_labels + + # 3. Save the modified checkpoint + output_path = input_path.replace(".ckpt", "_modified.ckpt") + torch.save(checkpoint, output_path) + print(f"Successfully added classification_labels and saved to {output_path}") + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print("Usage: python migrate_checkpoints.py ") + sys.exit(1) + + input_ckpt = sys.argv[1] + classes_file = sys.argv[2] + + add_class_labels_to_checkpoint( + input_path=input_ckpt, classes_file_path=classes_file + ) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..60548c5e --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,195 @@ +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__( + self, + checkpoint_path: _PATH, + batch_size: Optional[int] = None, + compile_model: bool = True, + ): + """Initializes the Predictor with a model loaded from the checkpoint. + + Args: + checkpoint_path: Path to the model checkpoint. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. + compile_model: Whether to compile the model using torch.compile. Default is True. + """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + assert ( + "_class_path" in ckpt_file["datamodule_hyper_parameters"] + and "_class_path" in ckpt_file["hyper_parameters"] + ), ( + "Datamodule and Model hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) + + print("-" * 50) + print(f"Using device: {self.device}") + print(f"For Loaded checkpoint from: {checkpoint_path}") + print("Below are the modules loaded from the checkpoint:") + + self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] + self._dm_hparams.pop("splits_file_path", None) + self._dm_hparams.pop("augment_smiles", None) + self._dm_hparams.pop("aug_smiles_variations", None) + self._dm_hparams.pop("_instantiator", None) + self._dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, self._dm_hparams + ) + if batch_size is not None and int(batch_size) > 0: + self._dm.batch_size = int(batch_size) + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") + + self._model_hparams = ckpt_file["hyper_parameters"] + self._model_hparams.pop("_instantiator", None) + self._model_hparams.pop("classes_txt_file_path", None) + self._model = ChebaiBaseNet.load_from_checkpoint( + checkpoint_path, map_location=self.device + ) + assert ( + isinstance(self._model, ChebaiBaseNet) + and type(self._model) is not ChebaiBaseNet + ), ( + f"Loaded model must be a subclass of ChebaiBaseNet, not ChebaiBaseNet itself. " + f"Got {type(self._model).__name__}." + ) + self._model.to(self.device) + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + + self._classification_labels: list = ckpt_file.get("classification_labels") + if self._classification_labels is None: + raise KeyError( + "The checkpoint does not contain 'classification_labels'. " + "Make sure the checkpoint is compatible with python-chebai version 1.2.1 or later." + "See the migration script in `chebai.preprocessing.migration.migrate_checkpoints` for more details." + ) + + print(f"Loaded {len(self._classification_labels)} classification labels.") + assert len(self._classification_labels) > 0, ( + "Classification labels list is empty." + ) + assert len(self._classification_labels) == self._model.out_dim, ( + f"Number of class labels ({len(self._classification_labels)}) does not match " + f"the model output dimension ({self._model.out_dim})." + ) + + if compile_model: + self._model = torch.compile(self._model) # type: ignore + self._model.eval() + print("-" * 50) + + def predict_from_file( + self, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + smiles_file_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + """ + with open(smiles_file_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings) + if all(pred is None for pred in preds): + print("No valid predictions were made. (All predictions are None.)") + return + + num_of_cols = len(self._classification_labels) + rows = [ + pred.tolist() if pred is not None else [None] * num_of_cols + for pred in preds + ] + predictions_df = pd.DataFrame( + rows, columns=self._classification_labels, index=smiles_strings + ) + + predictions_df.to_csv(save_to) + print(f"Predictions saved to: {save_to}") + + @torch.inference_mode() + def predict_smiles( + self, + smiles: List[str], + ) -> list[torch.Tensor | None]: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. + """ + # For certain data prediction pipelines, we may need model hyperparameters + pred_dl, valid_indices = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams + ) + if valid_indices is None or len(valid_indices) == 0: + return [None] * len(smiles) + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + # For certain model prediction pipelines, we may need data module hyperparameters + result = self._model.predict_step( + batch, batch_idx, dm_hparams=self._dm_hparams + ) + preds.append(result["preds"]) + preds = torch.cat(preds) + + # Initialize output with None + output: list[torch.Tensor | None] = [None] * len(smiles) + + # Scatter predictions back + for pred, idx in zip(preds, valid_indices): + output[idx] = pred + + return output + + +class MainPredictor: + @staticmethod + def predict_from_file( + checkpoint_path: _PATH, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + batch_size: Optional[int] = None, + ) -> None: + predictor = Predictor(checkpoint_path, batch_size) + predictor.predict_from_file( + smiles_file_path, + save_to, + ) + + @staticmethod + def predict_smiles( + checkpoint_path: _PATH, + smiles: List[str], + batch_size: Optional[int] = None, + ) -> list[torch.Tensor | None]: + predictor = Predictor(checkpoint_path, batch_size) + return predictor.predict_smiles(smiles) + + +if __name__ == "__main__": + # python chebai/result/prediction.py predict_from_file --help + # python chebai/result/prediction.py predict_smiles --help + CLI(MainPredictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 5c960007..c81fda82 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,18 +1,13 @@ import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd -import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader log = logging.getLogger(__name__) @@ -74,68 +69,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - def predict_from_file( + def predict( self, - model: LightningModule, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). - """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) - - def _predict_smiles( - self, model: LightningModule, smiles: List[str] - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - smiles: A list of SMILES strings. - - Returns: - A tensor containing the predictions. - """ - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] - x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], - batch_first=True, + model=None, + dataloaders=None, + datamodule=None, + return_predictions=None, + ckpt_path=None, + ): + raise NotImplementedError( + "CustomTrainer.predict is not implemented." + "Use the script from `chebai/result/prediction.py` instead." ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) - * CLS_TOKEN - ) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - if model.model_type == "regression": - preds = model_output["logits"] - else: - preds = torch.sigmoid(model_output["logits"]) - - return preds @property def log_dir(self) -> Optional[str]: diff --git a/tests/unit/cli/classification_labels.txt b/tests/unit/cli/classification_labels.txt new file mode 100644 index 00000000..06d2d6d1 --- /dev/null +++ b/tests/unit/cli/classification_labels.txt @@ -0,0 +1,10 @@ +label_1 +label_2 +label_3 +label_4 +label_5 +label_6 +label_7 +label_8 +label_9 +label_10 diff --git a/tests/unit/cli/mock_dm.py b/tests/unit/cli/mock_dm.py index 25116e21..40e68bb4 100644 --- a/tests/unit/cli/mock_dm.py +++ b/tests/unit/cli/mock_dm.py @@ -1,16 +1,27 @@ +import os + import torch -from lightning.pytorch.core.datamodule import LightningDataModule from torch.utils.data import DataLoader from chebai.preprocessing.collate import RaggedCollator +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.reader import DataReader + + +class MockReader(DataReader): + def name(self) -> str: + return "mock_reader" -class MyLightningDataModule(LightningDataModule): +class MyLightningDataModule(XYBaseDataModule): + READER = MockReader + def __init__(self): super().__init__() self._num_of_labels = None self._feature_vector_size = None self.collator = RaggedCollator() + self.save_hyperparameters() def prepare_data(self): pass @@ -29,6 +40,14 @@ def num_of_labels(self): def feature_vector_size(self): return self._feature_vector_size + @property + def classes_txt_file_path(self) -> str: + return os.path.join("tests", "unit", "cli", "classification_labels.txt") + + @property + def _name(self) -> str: + return "mock_dm" + def train_dataloader(self): assert self.feature_vector_size is not None, "feature_vector_size must be set" # Dummy dataset for example purposes @@ -44,3 +63,19 @@ def train_dataloader(self): ] return DataLoader(datalist, batch_size=32, collate_fn=self.collator) + + def val_dataloader(self): + assert self.feature_vector_size is not None, "feature_vector_size must be set" + # Dummy validation dataset + + datalist = [ + { + "features": torch.randn(self._feature_vector_size), + "labels": torch.randint(0, 2, (self._num_of_labels,), dtype=torch.bool), + "ident": i, + "group": None, + } + for i in range(32) + ] + + return DataLoader(datalist, batch_size=32, collate_fn=self.collator) diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py index 863a6df3..92119878 100644 --- a/tests/unit/cli/testCLI.py +++ b/tests/unit/cli/testCLI.py @@ -1,6 +1,11 @@ import unittest +from pathlib import Path + +import torch +from lightning import LightningDataModule, LightningModule from chebai.cli import ChebaiCLI +from chebai.result.prediction import Predictor class TestChebaiCLI(unittest.TestCase): @@ -9,8 +14,9 @@ def setUp(self): "fit", "--trainer=configs/training/default_trainer.yml", "--model=configs/model/ffn.yml", - "--model.init_args.hidden_layers=[10]", + "--model.init_args.hidden_layers=[1]", "--model.train_metrics=configs/metrics/micro-macro-f1.yml", + "--model.val_metrics=configs/metrics/micro-macro-f1.yml", "--data=tests/unit/cli/mock_dm_config.yml", "--model.pass_loss_kwargs=false", "--trainer.min_epochs=1", @@ -20,14 +26,33 @@ def setUp(self): def test_mlp_on_chebai_cli(self): # Instantiate ChebaiCLI and ensure no exceptions are raised - try: - ChebaiCLI( - args=self.cli_args, - save_config_kwargs={"config_filename": "lightning_config.yaml"}, - parser_kwargs={"parser_mode": "omegaconf"}, - ) - except Exception as e: - self.fail(f"ChebaiCLI raised an unexpected exception: {e}") + cli = ChebaiCLI( + args=self.cli_args, + save_config_kwargs={"config_filename": "lightning_config.yaml"}, + parser_kwargs={"parser_mode": "omegaconf"}, + ) + assert cli.trainer.log_dir is not None + checkpoint_path = next( + Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None + ) + assert checkpoint_path is not None and checkpoint_path.is_file() + loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) + model_hparams = loaded_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + dm_hparams = loaded_checkpoint[LightningDataModule.CHECKPOINT_HYPER_PARAMS_KEY] + assert "classification_labels" in loaded_checkpoint, ( + "Checkpoint is missing 'classification_labels' key." + ) + assert "_class_path" in model_hparams, ( + "Model hyperparameters missing '_class_path' key." + ) + assert "_class_path" in dm_hparams, ( + "DataModule hyperparameters missing '_class_path' key." + ) + assert "classes_txt_file_path" not in model_hparams, ( + "Model hyperparameters should not contain 'classes_txt_file_path' key." + ) + + Predictor(checkpoint_path=checkpoint_path) if __name__ == "__main__":