Leveraging Low-cost Pathological Spatial Knowledge for Unbiased Interpretability: Diagnosis-Guided Hypergraph Learning for Tumor Prognosis
This repo is the PyTorch implementation for the DGP described in the paper "Leveraging Low-cost Pathological Spatial Knowledge for Unbiased Interpretability: Diagnosis-Guided Hypergraph Learning for Tumor Prognosis".
Diagnostic_model_weight
This folder contains the weights of our diagnostic models, trained with the CLAM framework for the binary classification task (normal vs. tumor)
on datasets of three cancer types (BRCA, HCC, CRC). These models provide the low-cost pathological spatial knowledge used in the
subsequent prognosis task and in validating the model's interpretability.
DGP-master
│ environment.yaml
│ main.py # training and evaluating DGP
├─dataset_modules
│ dataset_generic.py
│ dataset_h5.py
│ wsi_dataset.py
│
├─HGAT # the model's architecture
│ H_GNN.py
│ hypergraph_util.py
│ layers.py
│
├─models # architectures of the compared models
│ ├─ABMIL
│ abmil.py
│ ├─CLAM
│ clam.py
│ ├─DSMIL
│ dsmil.py
│ ├─GDFMIL
│ gdfmil.py
│ ├─RRTMIL
│ rrtmil.py
│ ├─Patch_GCN
│ graph_construction.py
│ patch_gcn.py
│
├─runs # the training scripts
│ scratch.sh
│
├─utils
│ __init__.py
│ general.py # helper functions
│ losses.py # loss functions
│ utils.py
Ubuntu 22.04 LTS, environment.yml
The input of graph-based methods generally contains two parts: image features and graph structure.
Supported formats include OpenSlide and NDPI formats. The following backbones are supported: R50, VIT-S, CTRANSPATH, PLIP, CONCH, UNI, GIGAPATH, VIRCHOW, VIRCHOW-V2 and CONCH-V1.5.
We recommend using repositories such as PIANO or TRIDENT for your feature extraction work.
The graph/hypergraph structure represents the global architecture of the pathological tissue within each slide. Generally, the image patches are regarded as the nodes of the graph. For methods that represent each slide with a fixed graph structure, such as Patch_GCN and HGSurvNet, the graph structure can be constructed before training or testing the models.
python ./models/Patch_GCN/graph_construction.py
python ./models/HGSurvNet/feature_to_hypergraph.py
For our proposed method, we can directly use the case_id_blockmap.h5 file generated by the script creat_heatmap.py in the CLAM code.
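To make the idea of a fixed, precomputed graph concrete, here is a minimal, dependency-free sketch that links each patch to its k nearest neighbours by patch-centre coordinates. It is an illustration only: the function name `knn_edges`, the value of `k`, and the toy coordinates are assumptions, and the actual logic in graph_construction.py may differ.

```python
from typing import List, Sequence, Tuple


def knn_edges(coords: Sequence[Tuple[float, float]], k: int = 8) -> List[Tuple[int, int]]:
    """Return directed edges (i, j) linking each patch i to its k nearest patches j.

    Hypothetical helper for illustration; not taken from the repository.
    """
    n = len(coords)
    if k >= n:
        raise ValueError("k must be smaller than the number of patches")
    edges: List[Tuple[int, int]] = []
    for i, (xi, yi) in enumerate(coords):
        # Squared Euclidean distance from patch i to every other patch.
        dists = [((xi - xj) ** 2 + (yi - yj) ** 2, j)
                 for j, (xj, yj) in enumerate(coords) if j != i]
        dists.sort()  # ties are broken by the smaller patch index
        edges.extend((i, j) for _, j in dists[:k])
    return edges


# Toy example: a 3 x 3 grid of patch centres spaced 256 pixels apart.
grid = [(x * 256, y * 256) for y in range(3) for x in range(3)]
edges = knn_edges(grid, k=4)
print(len(edges))  # one edge per (patch, neighbour) pair
```

This brute-force version compares every pair of patches, so it is only practical for small slides; a spatial index would be the usual choice for slides with many thousands of patches.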
There is an example in the directory dataset_csv. The input CSV file has the following format:
| case_id | slide_id | OS_status | OS_time | cohort_name | diagnosis_result | label |
|---|---|---|---|---|---|---|
| 1407010 | 1407010-11-HE-DX1 | 0 | 1627 | TCGA | tumor | 3 |
| ... | ... | ... | ... | ... | ... | ... |
- case_id: [str] the ID of each patient.
- slide_id: [str] the ID of each slide.
- OS_status: [bool] the patient's overall survival status (1: death, 0: alive).
- OS_time: [int] the patient's overall survival time.
- cohort_name: [str] the cohort that each slide comes from.
- diagnosis_result: [str] the diagnostic result from the prior diagnosis task.
- label: [int] the group ID (0, 1, 2, or 3) assigned by a binning strategy.
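The exact binning strategy behind the label column is not described here. As a rough sketch, one common choice in discrete-time survival modelling is to cut OS_time at the quartiles of the uncensored patients (OS_status equal to 1). The function `bin_survival_times`, the quartile rule, and the toy values below are assumptions for illustration, not the repository's confirmed procedure.

```python
import bisect
import statistics
from typing import List, Sequence


def bin_survival_times(os_time: Sequence[float],
                       os_status: Sequence[int],
                       n_bins: int = 4) -> List[int]:
    """Assign each patient a group ID in [0, n_bins - 1] from OS_time.

    Assumption: cut points are quantiles of the uncensored survival times.
    """
    # Keep only the patients with an observed event (OS_status == 1).
    event_times = [t for t, s in zip(os_time, os_status) if s == 1]
    # statistics.quantiles returns n_bins - 1 cut points.
    cuts = statistics.quantiles(event_times, n=n_bins)
    # bisect_right places a time equal to a cut point in the higher bin.
    return [bisect.bisect_right(cuts, t) for t in os_time]


# Toy values only; they are not taken from any real cohort.
times = [100, 400, 800, 1200, 1627, 2000]
status = [1, 1, 1, 1, 0, 1]
labels = bin_survival_times(times, status)
print(labels)
```

Censored patients do not influence the cut points in this sketch, but they still receive a label based on their last follow-up time.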