
Question: Simulating gene knockout on novel datasets (zero-shot inference) & Implementation feedback #5

@jklupup


Dear Authors,

First of all, thank you for your outstanding work on PerturbDiff! The concept and the methodology are truly inspiring.

I am currently trying to apply your pre-trained model (specifically the finetuned_tahoe100m_fixed.ckpt) to a novel, independent single-cell dataset (Colorectal Cancer data). My goal is to perform pure inference: simulating the knockout of a specific gene (e.g., TP53) on this unseen dataset, without having any actual ground-truth perturbation data or paired control cells.

My primary question is: Does the current theoretical framework and the pre-trained checkpoint support this kind of "zero-shot" simulation on a completely novel dataset?

While attempting to implement this inference pipeline, I ran into several engineering challenges. It seems that the current codebase is heavily optimized for training and reproducing benchmark metrics, which makes it quite difficult to decouple for pure, out-of-distribution inference. I would love to share a brief summary of the roadblocks I encountered, in hopes it might be helpful for future updates or an inference-only API:

Tight Coupling in DataLoader & Sampler: The dataset_core.py and sampler.py strictly expect paired "control" and "perturbed" cells to calculate metrics. Bypassing this to feed a simple .h5ad file of raw cells requires heavily modifying the dictionary mappings (e.g., grouped_num_cell, data_indices) to prevent KeyErrors and AssertionErrors.
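To illustrate the kind of stubbing this required, here is a minimal sketch of how I built permissive mappings for an unpaired dataset. The names grouped_num_cell and data_indices come from the repo; the single-pseudo-group fallback logic is my own workaround and may not match the authors' intended semantics:

```python
from collections import defaultdict

def build_unpaired_mappings(num_cells, group_name="novel_crc"):
    """Build the dictionary mappings the sampler expects, treating every
    cell in the unpaired .h5ad as one pseudo-group so lookups no longer
    raise KeyError when no paired control population exists."""
    # All cells fall into a single group; defaultdict avoids KeyError
    # for any group name the sampler probes that we did not define.
    grouped_num_cell = defaultdict(int, {group_name: num_cells})
    # Map the pseudo-group to the full index range of the raw cells.
    data_indices = {group_name: list(range(num_cells))}
    return grouped_num_cell, data_indices

grouped_num_cell, data_indices = build_unpaired_mappings(5)
```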

Hardcoded Dataset Names & Metadata: The codebase heavily relies on predefined dataset names (pbmc, tahoe100m, etc.). When feeding novel data, the framework automatically assigns names like dummy_plate_9, which later causes AssertionErrors in functions like get_short_dsname and embedder mapping.
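A tolerant fallback instead of an assertion would resolve this; here is a sketch of the patch I applied locally. KNOWN_SHORT_NAMES below is an illustrative stand-in, not the repo's actual mapping:

```python
# Illustrative subset only; the real mapping lives in the codebase.
KNOWN_SHORT_NAMES = {"pbmc": "pbmc", "tahoe100m": "tahoe"}

def get_short_dsname_tolerant(dsname, fallback="novel"):
    """Return the short dataset name, falling back to a generic label
    for auto-generated names like 'dummy_plate_9' instead of asserting."""
    if dsname in KNOWN_SHORT_NAMES:
        return KNOWN_SHORT_NAMES[dsname]
    return fallback
```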

Strict Checkpoint Loading & Embedder Dimensions: When adapting the model to my dataset's dimensions (e.g., 2000 HVGs), modifying the nn.Linear layers raises "Unexpected key(s) in state_dict" during load_state_dict because the checkpoint contains hardcoded dataset-specific embedders (e.g., x_embedder.pbmc.weight). This required setting strict=False to force fresh initialization of those layers.
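Rather than relying on strict=False alone, I found it safer to pre-filter the checkpoint so that only keys present in the model with matching shapes are loaded. This is my own workaround, sketched here with NumPy arrays standing in for tensors:

```python
import numpy as np

def filter_incompatible_keys(ckpt_state, model_state):
    """Keep only checkpoint entries whose key exists in the model and
    whose shape matches, so load_state_dict(strict=False) does not try
    to load dataset-specific embedders the new model lacks."""
    kept, dropped = {}, []
    for key, tensor in ckpt_state.items():
        target = model_state.get(key)
        if target is not None and tensor.shape == target.shape:
            kept[key] = tensor
        else:
            dropped.append(key)
    return kept, dropped

# Toy example: the pbmc embedder has no counterpart in the new model.
ckpt = {"x_embedder.pbmc.weight": np.zeros((2000, 512)),
        "blocks.0.weight": np.zeros((512, 512))}
model = {"x_embedder.novel.weight": np.zeros((2000, 512)),
         "blocks.0.weight": np.zeros((512, 512))}
kept, dropped = filter_incompatible_keys(ckpt, model)
# kept retains only 'blocks.0.weight'; the pbmc embedder is dropped
```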

Shape Assertions in Diffusion Core: During the forward pass, the Transformer blocks often output a 3D tensor [Batch, 1, Dim], but the diffusion_core.py strictly asserts x_t.shape == eps.shape (expecting [Batch, Dim]). This required manual squeeze/reshape operations at the model's output to prevent runtime crashes.
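The reshape I inserted at the model's output is trivial but worth showing, since it silently changes nothing when the shape is already correct. A sketch with NumPy standing in for torch tensors:

```python
import numpy as np

def squeeze_token_dim(x):
    """Collapse a [Batch, 1, Dim] transformer output to the [Batch, Dim]
    shape that the diffusion core asserts against eps; pass anything
    else through unchanged."""
    if x.ndim == 3 and x.shape[1] == 1:
        return x.reshape(x.shape[0], x.shape[2])
    return x

out = squeeze_token_dim(np.zeros((4, 1, 64)))  # shape becomes (4, 64)
```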

I wanted to ask if you have any plans to release a simplified predict.py script for users who just want to input an .h5ad and a perturbation condition to get the simulated results.

Thank you again for your time, your amazing research, and for making this repository open-source!
