Skip to content

[Feature]: Generalize Prediction pipeline for Lightning CLI models#148

Draft
aditya0by0 wants to merge 11 commits intodevfrom
feature/general_pred_pipeline
Draft

[Feature]: Generalize Prediction pipeline for Lightning CLI models#148
aditya0by0 wants to merge 11 commits intodevfrom
feature/general_pred_pipeline

Conversation

@aditya0by0
Copy link
Member

@aditya0by0 aditya0by0 commented Jan 30, 2026

Generalize prediction logic

Please merge below PRs after this PR:

Related Discussion

Related bugs rectified in Lightning for the pipeline

Additional changes

@aditya0by0 aditya0by0 added the enhancement New feature or request label Jan 30, 2026
@aditya0by0 aditya0by0 requested a review from sfluegel05 February 3, 2026 10:12
@aditya0by0 aditya0by0 marked this pull request as ready for review February 3, 2026 10:12
@aditya0by0
Copy link
Member Author

@sfluegel05,

Could you confirm our agreed approach for handling old_checkpoint files which don't classification labels stored in them?

  1. Update code to handle legacy checkpoints: This requires adding logic to prediction.py and the chebifier repo to ingest external class files.

    • Concerns: Adds boilerplate and permanent complexity to handle a temporary issue.
  2. Patch old checkpoints (Preferred): Use the below one-time script to inject labels into the existing files.

    • Benefits: Keeps the codebase clean and ensures all checkpoints follow a standardized schema.

I’m in favor of Option 2 to avoid carrying technical debt in the prediction logic. Does this match your understanding?

I'm willing to add this script to the repo, and small readme note for old checkpoints for option 2.


import sys

import torch


def add_class_labels_to_checkpoint(input_path, classes_file_path):
    with open(classes_file_path, "r") as f:
        class_labels = [line.strip() for line in f.readlines()]

    assert len(class_labels) > 0, "The classes file is empty."

    # 1. Load the checkpoint
    checkpoint = torch.load(
        input_path, map_location=torch.device("cpu"), weights_only=False
    )

    if "classification_labels" in checkpoint:
        print(
            "Warning: 'classification_labels' key already exists in the checkpoint and will be overwritten."
        )


    # 2. Add your custom key/value pair
    checkpoint["classification_labels"] = class_labels

    # 3. Save the modified checkpoint
    output_path = input_path.replace(".ckpt", "_modified.ckpt")
    torch.save(checkpoint, output_path)
    print(f"Successfully added classification_labels and saved to {output_path}")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python modify_checkpoints.py <input_checkpoint> <classes_file>")
        sys.exit(1)

    input_ckpt = sys.argv[1]
    classes_file = sys.argv[2]

    add_class_labels_to_checkpoint(
        input_path=input_ckpt, classes_file_path=classes_file
    )


Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces a new generalized prediction pipeline intended to work with LightningCLI-saved models/checkpoints, including persisting classification label names into checkpoints for consistent prediction output formatting.

Changes:

  • Add checkpoint persistence of classification_labels (derived from a dataset classes.txt) and wire the dataset path into model init via LightningCLI argument linking.
  • Introduce a new SMILES prediction entrypoint (chebai/result/prediction.py) that reconstructs model/datamodule from checkpoint hyperparameters.
  • Refactor XYBaseDataModule.predict_dataloader to build a prediction dataloader from an in-memory SMILES list, plus update docs/tests and add VS Code workspace files.

Reviewed changes

Copilot reviewed 11 out of 12 changed files in this pull request and generated 11 comments.

Show a summary per file
File Description
tests/unit/cli/testCLI.py Adjusts CLI unit test model args (smaller hidden layer).
tests/unit/cli/mock_dm.py Adds classes_txt_file_path for CLI linking in tests.
tests/unit/cli/classification_labels.txt Adds sample classification labels used by CLI unit tests.
chebai/trainer/CustomTrainer.py Removes prior bespoke prediction logic and overrides predict().
chebai/result/prediction.py Adds new prediction script/class for SMILES/file inference from checkpoint.
chebai/preprocessing/datasets/base.py Refactors prediction dataloader flow and adds classes_txt_file_path.
chebai/models/base.py Adds label-file loading + saving classification_labels into checkpoints.
chebai/cli.py Links data.classes_txt_file_path into model.init_args.classes_txt_file_path.
README.md Updates prediction instructions to use the new prediction script.
.vscode/settings.json Adds VS Code project settings (currently invalid JSON).
.vscode/extensions.json Adds recommended VS Code extensions.
.gitignore Stops ignoring the entire .vscode directory (only ignores launch.json).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +440 to +444
def _process_input_for_prediction(
self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> tuple[list, list]:
"""
Process input data for prediction.
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new prediction preprocessing path (predict_dataloader returning (DataLoader, valid_indices) plus _process_input_for_prediction / _preprocess_smiles_for_pred) isn't covered by unit tests. Since XYBaseDataModule already has unit tests, adding coverage for valid-index scattering and the empty/all-filtered case would help prevent regressions.

Copilot uses AI. Check for mistakes.
Comment on lines 408 to +412
def predict_dataloader(
self, *args, **kwargs
) -> Union[DataLoader, List[DataLoader]]:
self,
smiles_list: List[str],
model_hparams: Optional[dict] = None,
**kwargs,
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XYBaseDataModule.predict_dataloader now requires smiles_list (and returns a tuple), which is incompatible with PyTorch Lightning's expected predict_dataloader(self) signature. Any use of Trainer.predict(...) / LightningCLI predict will fail with a missing-argument error. Keep the Lightning-compatible predict_dataloader() signature and add a separate helper (e.g., predict_dataloader_from_smiles(...)) or make smiles_list optional and fall back to the previous behavior when it is None.

Copilot uses AI. Check for mistakes.
Comment on lines +58 to +65
self._model_hparams = ckpt_file["hyper_parameters"]
self._model_hparams.pop("_instantiator", None)
self._model_hparams.pop("classes_txt_file_path", None)
self._model: ChebaiBaseNet = instantiate_module(
ChebaiBaseNet, self._model_hparams
)
self._model.to(self.device)
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The predictor instantiates a fresh model via instantiate_module(...) but never loads the checkpoint weights (ckpt_file["state_dict"]) into it. As written, inference will run with randomly initialized weights. Load the state_dict into self._model (and consider strict=False with a clear warning if keys mismatch), or use Lightning's load_from_checkpoint(...) so weights are restored correctly.

Copilot uses AI. Check for mistakes.
print("Below are the modules loaded from the checkpoint:")

self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
self._dm_hparams.pop("splits_file_path")
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._dm_hparams.pop("splits_file_path") will raise KeyError for checkpoints/datamodules that don't have this key. Use a default (pop(..., None)) or gate this removal on key presence so inference works across datamodules.

Suggested change
self._dm_hparams.pop("splits_file_path")
self._dm_hparams.pop("splits_file_path", None)

Copilot uses AI. Check for mistakes.
Comment on lines +80 to +82
raise NotImplementedError(
"CustomTrainer.predict is not implemented."
"Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead."
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CustomTrainer.predict() now unconditionally raises NotImplementedError, which breaks Lightning's predict workflow and the existing chebai CLI predict subcommand (it will always crash). If prediction is intentionally unsupported through the Trainer, remove/disable the CLI predict subcommand; otherwise implement predict() by delegating to super().predict(...) (or a supported prediction path). Also, the error message points to Prediction.* but the new class is named Predictor.

Suggested change
raise NotImplementedError(
"CustomTrainer.predict is not implemented."
"Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead."
"""
Runs prediction using the underlying Lightning Trainer implementation.
This restores the default `Trainer.predict` behavior so that Lightning's
prediction workflow and the `chebai` CLI `predict` subcommand work as expected.
"""
return super().predict(
model=model,
dataloaders=dataloaders,
datamodule=datamodule,
return_predictions=return_predictions,
ckpt_path=ckpt_path,

Copilot uses AI. Check for mistakes.
Comment on lines +440 to +457
def _process_input_for_prediction(
self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> tuple[list, list]:
"""
Process input data for prediction.

Args:
smiles_list (List[str]): List of SMILES strings.
model_hparams (Optional[dict]): Model hyperparameters.
Some prediction pre-processing pipelines may require these.

Returns:
tuple[list, list]: Processed input data and valid indices.
"""
data, valid_indices = [], []
num_of_labels = int(model_hparams["out_dim"])
self._dummy_labels: list = list(range(1, num_of_labels + 1))

Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_process_input_for_prediction treats model_hparams as optional, but then immediately indexes model_hparams["out_dim"]. If model_hparams is omitted/None this will crash. Either make model_hparams required in the public API (and type it accordingly) or add a clear validation/error early in predict_dataloader/_process_input_for_prediction.

Copilot uses AI. Check for mistakes.
Comment on lines +67 to +72
self._classification_labels: list = ckpt_file.get("classification_labels")
print(f"Loaded {len(self._classification_labels)} classification labels.")
assert len(self._classification_labels) > 0, (
"Classification labels list is empty."
)
assert len(self._classification_labels) == self._model.out_dim, (
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ckpt_file.get("classification_labels") can return None; the next line calls len(self._classification_labels) which will raise TypeError. Handle the missing key explicitly (e.g., default to [] and raise a clear error if absent) so users get an actionable message when using older checkpoints.

Copilot uses AI. Check for mistakes.
aditya0by0 and others added 3 commits February 10, 2026 00:32
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@aditya0by0 aditya0by0 marked this pull request as draft February 17, 2026 12:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request priority: high

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants