[Feature]: Generalize Prediction pipeline for Lightning CLI models #148
aditya0by0 wants to merge 11 commits into dev from
Conversation
Could you confirm our agreed approach for handling
I'm in favor of Option 2 to avoid carrying technical debt in the prediction logic. Does this match your understanding? For Option 2, I'm willing to add this script to the repo along with a small README note about old checkpoints.
Pull request overview
This PR introduces a new generalized prediction pipeline intended to work with LightningCLI-saved models/checkpoints, including persisting classification label names into checkpoints for consistent prediction output formatting.
Changes:
- Add checkpoint persistence of `classification_labels` (derived from a dataset `classes.txt`) and wire the dataset path into model init via LightningCLI argument linking.
- Introduce a new SMILES prediction entrypoint (`chebai/result/prediction.py`) that reconstructs the model/datamodule from checkpoint hyperparameters.
- Refactor `XYBaseDataModule.predict_dataloader` to build a prediction dataloader from an in-memory SMILES list, plus update docs/tests and add VS Code workspace files.
Reviewed changes
Copilot reviewed 11 out of 12 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| tests/unit/cli/testCLI.py | Adjusts CLI unit test model args (smaller hidden layer). |
| tests/unit/cli/mock_dm.py | Adds classes_txt_file_path for CLI linking in tests. |
| tests/unit/cli/classification_labels.txt | Adds sample classification labels used by CLI unit tests. |
| chebai/trainer/CustomTrainer.py | Removes prior bespoke prediction logic and overrides predict(). |
| chebai/result/prediction.py | Adds new prediction script/class for SMILES/file inference from checkpoint. |
| chebai/preprocessing/datasets/base.py | Refactors prediction dataloader flow and adds classes_txt_file_path. |
| chebai/models/base.py | Adds label-file loading + saving classification_labels into checkpoints. |
| chebai/cli.py | Links data.classes_txt_file_path into model.init_args.classes_txt_file_path. |
| README.md | Updates prediction instructions to use the new prediction script. |
| .vscode/settings.json | Adds VS Code project settings (currently invalid JSON). |
| .vscode/extensions.json | Adds recommended VS Code extensions. |
| .gitignore | Stops ignoring the entire .vscode directory (only ignores launch.json). |
```python
def _process_input_for_prediction(
    self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> tuple[list, list]:
    """
    Process input data for prediction.
```
The new prediction preprocessing path (predict_dataloader returning (DataLoader, valid_indices) plus _process_input_for_prediction / _preprocess_smiles_for_pred) isn't covered by unit tests. Since XYBaseDataModule already has unit tests, adding coverage for valid-index scattering and the empty/all-filtered case would help prevent regressions.
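A minimal sketch of the valid-index scattering behavior such a test could pin down. The function name and `None` fill value are hypothetical illustrations, not taken from the PR:

```python
from typing import Any, List, Optional


def scatter_by_valid_indices(
    predictions: List[Any],
    valid_indices: List[int],
    total: int,
    fill: Optional[Any] = None,
) -> List[Any]:
    """Place each prediction back at its original input position.

    Inputs that failed preprocessing (e.g. invalid SMILES) receive `fill`.
    """
    out: List[Any] = [fill] * total
    for pred, idx in zip(predictions, valid_indices):
        out[idx] = pred
    return out


# Input 1 of 3 was filtered out as an invalid SMILES before prediction:
result = scatter_by_valid_indices(["p0", "p2"], [0, 2], total=3)
# The empty/all-filtered case degenerates to a list of fills:
empty = scatter_by_valid_indices([], [], total=2)
```

Asserting on exactly these two cases (partial filtering and all-filtered) would cover the regressions the comment warns about.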
```diff
 def predict_dataloader(
-    self, *args, **kwargs
-) -> Union[DataLoader, List[DataLoader]]:
+    self,
+    smiles_list: List[str],
+    model_hparams: Optional[dict] = None,
+    **kwargs,
```
XYBaseDataModule.predict_dataloader now requires smiles_list (and returns a tuple), which is incompatible with PyTorch Lightning's expected predict_dataloader(self) signature. Any use of Trainer.predict(...) / LightningCLI predict will fail with a missing-argument error. Keep the Lightning-compatible predict_dataloader() signature and add a separate helper (e.g., predict_dataloader_from_smiles(...)) or make smiles_list optional and fall back to the previous behavior when it is None.
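One way to keep the hook Lightning-compatible, sketched with placeholder helpers. The class name suffix and both helper methods here are illustrative stand-ins, not the PR's actual code:

```python
from typing import List, Optional


class XYBaseDataModuleSketch:
    """Hypothetical sketch: keep predict_dataloader() Lightning-compatible."""

    def _default_predict_dataloader(self):
        # Placeholder for the previous (Lightning-style) behavior.
        return "default-dataloader"

    def _build_smiles_dataloader(self, smiles_list, model_hparams):
        # Placeholder for the new SMILES-based preprocessing path.
        return ("smiles-dataloader", list(range(len(smiles_list))))

    def predict_dataloader(
        self,
        smiles_list: Optional[List[str]] = None,
        model_hparams: Optional[dict] = None,
        **kwargs,
    ):
        if smiles_list is None:
            # Trainer.predict(...) invokes predict_dataloader() with no
            # arguments; fall back so Lightning's workflow keeps working.
            return self._default_predict_dataloader()
        return self._build_smiles_dataloader(smiles_list, model_hparams)
```

With this shape, `Trainer.predict(...)` hits the no-argument path, while the new prediction script can pass `smiles_list` explicitly.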
```python
self._model_hparams = ckpt_file["hyper_parameters"]
self._model_hparams.pop("_instantiator", None)
self._model_hparams.pop("classes_txt_file_path", None)
self._model: ChebaiBaseNet = instantiate_module(
    ChebaiBaseNet, self._model_hparams
)
self._model.to(self.device)
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
```
The predictor instantiates a fresh model via instantiate_module(...) but never loads the checkpoint weights (ckpt_file["state_dict"]) into it. As written, inference will run with randomly initialized weights. Load the state_dict into self._model (and consider strict=False with a clear warning if keys mismatch), or use Lightning's load_from_checkpoint(...) so weights are restored correctly.
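A sketch of restoring weights after manual instantiation, using `load_state_dict` with `strict=False` and a warning on mismatched keys. It assumes a plain `torch.nn.Module`; the chebai-specific names and the checkpoint loading itself are omitted:

```python
import warnings

import torch


def load_checkpoint_weights(model: torch.nn.Module, ckpt: dict) -> torch.nn.Module:
    """Load `state_dict` weights from a Lightning checkpoint dict into `model`."""
    missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint key mismatch: missing={missing}, unexpected={unexpected}"
        )
    return model


# Round-trip demo with a tiny module standing in for the real model:
src = torch.nn.Linear(2, 2)
ckpt = {"state_dict": src.state_dict()}
dst = load_checkpoint_weights(torch.nn.Linear(2, 2), ckpt)
```

`ChebaiBaseNet.load_from_checkpoint(...)` would handle both instantiation and weight restoration in one step, at the cost of less control over hyperparameter filtering.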
```python
print("Below are the modules loaded from the checkpoint:")

self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
self._dm_hparams.pop("splits_file_path")
```
self._dm_hparams.pop("splits_file_path") will raise KeyError for checkpoints/datamodules that don't have this key. Use a default (pop(..., None)) or gate this removal on key presence so inference works across datamodules.
```diff
-self._dm_hparams.pop("splits_file_path")
+self._dm_hparams.pop("splits_file_path", None)
```
```python
raise NotImplementedError(
    "CustomTrainer.predict is not implemented."
    "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead."
```
CustomTrainer.predict() now unconditionally raises NotImplementedError, which breaks Lightning's predict workflow and the existing chebai CLI predict subcommand (it will always crash). If prediction is intentionally unsupported through the Trainer, remove/disable the CLI predict subcommand; otherwise implement predict() by delegating to super().predict(...) (or a supported prediction path). Also, the error message points to Prediction.* but the new class is named Predictor.
```diff
-raise NotImplementedError(
-    "CustomTrainer.predict is not implemented."
-    "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead."
+"""
+Runs prediction using the underlying Lightning Trainer implementation.
+This restores the default `Trainer.predict` behavior so that Lightning's
+prediction workflow and the `chebai` CLI `predict` subcommand work as expected.
+"""
+return super().predict(
+    model=model,
+    dataloaders=dataloaders,
+    datamodule=datamodule,
+    return_predictions=return_predictions,
+    ckpt_path=ckpt_path,
```
```python
def _process_input_for_prediction(
    self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> tuple[list, list]:
    """
    Process input data for prediction.

    Args:
        smiles_list (List[str]): List of SMILES strings.
        model_hparams (Optional[dict]): Model hyperparameters.
            Some prediction pre-processing pipelines may require these.

    Returns:
        tuple[list, list]: Processed input data and valid indices.
    """
    data, valid_indices = [], []
    num_of_labels = int(model_hparams["out_dim"])
    self._dummy_labels: list = list(range(1, num_of_labels + 1))
```
_process_input_for_prediction treats model_hparams as optional, but then immediately indexes model_hparams["out_dim"]. If model_hparams is omitted/None this will crash. Either make model_hparams required in the public API (and type it accordingly) or add a clear validation/error early in predict_dataloader/_process_input_for_prediction.
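A minimal validation sketch (the function name is hypothetical) that fails fast with an actionable message instead of a `TypeError`/`KeyError` deep inside preprocessing:

```python
from typing import Optional


def require_out_dim(model_hparams: Optional[dict]) -> int:
    """Validate that prediction preprocessing has the hyperparameters it needs."""
    if not model_hparams or "out_dim" not in model_hparams:
        raise ValueError(
            "model_hparams with an 'out_dim' entry is required for prediction "
            "preprocessing; pass the checkpoint's hyperparameters explicitly."
        )
    return int(model_hparams["out_dim"])
```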
```python
self._classification_labels: list = ckpt_file.get("classification_labels")
print(f"Loaded {len(self._classification_labels)} classification labels.")
assert len(self._classification_labels) > 0, (
    "Classification labels list is empty."
)
assert len(self._classification_labels) == self._model.out_dim, (
```
ckpt_file.get("classification_labels") can return None; the next line calls len(self._classification_labels) which will raise TypeError. Handle the missing key explicitly (e.g., default to [] and raise a clear error if absent) so users get an actionable message when using older checkpoints.
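A sketch of explicit handling for checkpoints saved before label persistence existed; the function name and error wording are illustrative:

```python
from typing import List


def get_classification_labels(ckpt_file: dict) -> List[str]:
    """Return persisted labels, with a clear error for older checkpoints."""
    labels = ckpt_file.get("classification_labels") or []
    if not labels:
        raise ValueError(
            "Checkpoint has no 'classification_labels'; it was likely saved "
            "before label persistence was added. Re-save the checkpoint with "
            "the current pipeline or supply a classes.txt file manually."
        )
    return labels
```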
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Generalize prediction logic
Please merge the PRs below after this one:
Related Discussion
Related bugs rectified in Lightning for the pipeline
- `LightningDataModule.load_from_checkpoint` does not restore the subclass from `datamodule_hyper_parameters` (Lightning-AI/pytorch-lightning#21477)
- `save_hyperparameters(ignore=...)` is not persistent across inheritance; ignored params reappear when the base class also calls `save_hyperparameters` (Lightning-AI/pytorch-lightning#21488)

Additional changes
- Save class labels in the checkpoint under the key "classification_labels"
- Wrap inference with `torch.inference_mode()` to avoid gradient tracking (see "Avoid gradient tracking" python-chebifier#21; related reading: `model.eval()` in PyTorch, `torch.no_grad()` and `torch.inference_mode()`)
- Use `torch.compile` for faster inference
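A sketch of how these inference changes combine; the model and batch shapes are placeholders, not chebai's actual modules:

```python
import torch


def run_inference(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Evaluate without autograd bookkeeping.

    `inference_mode` is stricter (and typically faster) than `no_grad`
    because tensors created inside it can never re-enter autograd.
    `model.eval()` additionally switches layers like dropout and
    batch-norm into evaluation behavior.
    """
    model.eval()
    with torch.inference_mode():
        return model(batch)


model = torch.nn.Linear(4, 2)
# Optionally: model = torch.compile(model) can speed up repeated inference
# calls on recent PyTorch versions (incurs a one-time compilation cost).
out = run_inference(model, torch.randn(3, 4))
```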