diff --git a/examples/synthetic_ehr_generation/synthetic_ehr_baselines.py b/examples/synthetic_ehr_generation/synthetic_ehr_baselines.py new file mode 100644 index 000000000..495413ae7 --- /dev/null +++ b/examples/synthetic_ehr_generation/synthetic_ehr_baselines.py @@ -0,0 +1,383 @@ +""" +Synthetic EHR Generation Baselines using PyHealth + +This script demonstrates how to use PyHealth's infrastructure with various +baseline generative models for synthetic EHR data. + +Supported models: +- GReaT: Tabular data generation using language models +- CTGAN: Conditional GAN for tabular data +- TVAE: Variational Autoencoder for tabular data +- TransformerBaseline: Custom transformer for sequential EHR + +Usage: + # Using GReaT model + python synthetic_ehr_baselines.py \\ + --mimic_root /path/to/mimic3 \\ + --train_patients /path/to/train_ids.txt \\ + --test_patients /path/to/test_ids.txt \\ + --output_dir /path/to/output \\ + --mode great + + # Using CTGAN model + python synthetic_ehr_baselines.py \\ + --mimic_root /path/to/mimic3 \\ + --train_patients /path/to/train_ids.txt \\ + --test_patients /path/to/test_ids.txt \\ + --output_dir /path/to/output \\ + --mode ctgan + + # Using PyHealth TransformerEHRGenerator + python synthetic_ehr_baselines.py \\ + --mimic_root /path/to/mimic3 \\ + --train_patients /path/to/train_ids.txt \\ + --test_patients /path/to/test_ids.txt \\ + --output_dir /path/to/output \\ + --mode transformer_baseline +""" + +import os +import argparse +import pandas as pd +import torch +from tqdm import tqdm, trange + +from pyhealth.synthetic_ehr_utils.synthetic_ehr_utils import ( + process_mimic_for_generation, + tabular_to_sequences, + sequences_to_tabular, +) + + +def train_great_model(train_flattened, args): + """Train GReaT model on flattened EHR data.""" + try: + import be_great + except ImportError: + raise ImportError( + "be_great is not installed. 
Install with: pip install be-great" + ) + + print("\n=== Training GReaT Model ===") + model = be_great.GReaT( + llm=args.great_llm, + batch_size=args.batch_size, + epochs=args.epochs, + dataloader_num_workers=args.num_workers, + fp16=torch.cuda.is_available(), + ) + + model.fit(train_flattened) + + # Save model + save_path = os.path.join(args.output_dir, "great") + os.makedirs(save_path, exist_ok=True) + model.save(save_path) + + # Generate synthetic data + print(f"\n=== Generating {args.num_synthetic_samples} synthetic samples ===") + synthetic_data = model.sample(n_samples=args.num_synthetic_samples) + + # Save + synthetic_data.to_csv( + os.path.join(save_path, "great_synthetic_flattened_ehr.csv"), index=False + ) + + print(f"Saved synthetic data to {save_path}") + return synthetic_data + + +def train_ctgan_model(train_flattened, args): + """Train CTGAN model on flattened EHR data.""" + try: + from sdv.metadata import Metadata + from sdv.single_table import CTGANSynthesizer + except ImportError: + raise ImportError("sdv is not installed. 
Install with: pip install sdv") + + print("\n=== Training CTGAN Model ===") + + # Auto-detect metadata + metadata = Metadata.detect_from_dataframe(data=train_flattened) + + # Set all columns as numerical + for column in train_flattened.columns: + metadata.update_column(column_name=column, sdtype="numerical") + + # Initialize and train + synthesizer = CTGANSynthesizer( + metadata, epochs=args.epochs, batch_size=args.batch_size + ) + synthesizer.fit(train_flattened) + + # Save model + save_path = os.path.join(args.output_dir, "ctgan") + os.makedirs(save_path, exist_ok=True) + synthesizer.save(filepath=os.path.join(save_path, "synthesizer.pkl")) + + # Generate synthetic data + print(f"\n=== Generating {args.num_synthetic_samples} synthetic samples ===") + synthetic_data = synthesizer.sample(num_rows=args.num_synthetic_samples) + + # Save + synthetic_data.to_csv( + os.path.join(save_path, "ctgan_synthetic_flattened_ehr.csv"), index=False + ) + + print(f"Saved synthetic data to {save_path}") + return synthetic_data + + +def train_tvae_model(train_flattened, args): + """Train TVAE model on flattened EHR data.""" + try: + from sdv.metadata import Metadata + from sdv.single_table import TVAESynthesizer + except ImportError: + raise ImportError("sdv is not installed. 
Install with: pip install sdv") + + print("\n=== Training TVAE Model ===") + + # Auto-detect metadata + metadata = Metadata.detect_from_dataframe(data=train_flattened) + + # Set all columns as numerical + for column in train_flattened.columns: + metadata.update_column(column_name=column, sdtype="numerical") + + # Initialize and train + synthesizer = TVAESynthesizer( + metadata, epochs=args.epochs, batch_size=args.batch_size + ) + synthesizer.fit(train_flattened) + + # Save model + save_path = os.path.join(args.output_dir, "tvae") + os.makedirs(save_path, exist_ok=True) + synthesizer.save(filepath=os.path.join(save_path, "synthesizer.pkl")) + + # Generate synthetic data + print(f"\n=== Generating {args.num_synthetic_samples} synthetic samples ===") + synthetic_data = synthesizer.sample(num_rows=args.num_synthetic_samples) + + # Save + synthetic_data.to_csv( + os.path.join(save_path, "tvae_synthetic_flattened_ehr.csv"), index=False + ) + + print(f"Saved synthetic data to {save_path}") + return synthetic_data + + +def train_transformer_baseline(train_ehr, args): + """Train PyHealth TransformerEHRGenerator on sequential EHR data.""" + from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient + from pyhealth.tasks import SyntheticEHRGenerationMIMIC3 + from pyhealth.models import TransformerEHRGenerator + from pyhealth.trainer import Trainer + + print("\n=== Training Transformer Baseline with PyHealth ===") + + # Load MIMIC-III dataset + print("Loading MIMIC-III dataset...") + base_dataset = MIMIC3Dataset( + root=args.mimic_root, tables=["DIAGNOSES_ICD"], num_workers=args.num_workers + ) + + # Apply task + print("Applying SyntheticEHRGenerationMIMIC3 task...") + task = SyntheticEHRGenerationMIMIC3(min_visits=2) + sample_dataset = base_dataset.set_task(task, num_workers=args.num_workers) + + # Split dataset + train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) + + # Create dataloaders + # Use smaller batch size for 
transformer (sequences are long after flattening) + transformer_batch_size = 64 # Much smaller than tabular models + train_loader = get_dataloader( + train_dataset, batch_size=transformer_batch_size, shuffle=True + ) + val_loader = get_dataloader( + val_dataset, batch_size=transformer_batch_size, shuffle=False + ) + + # Initialize model + print("Initializing TransformerEHRGenerator...") + model = TransformerEHRGenerator( + dataset=sample_dataset, + embedding_dim=256, + num_heads=8, + num_layers=6, + dim_feedforward=1024, + dropout=0.1, + max_seq_length=512, + ) + + # Train + print("Training model...") + trainer = Trainer( + model=model, + device="cuda" if torch.cuda.is_available() else "cpu", + output_path=args.output_dir, + exp_name="transformer_baseline", + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.epochs, + monitor="loss", + monitor_criterion="min", + optimizer_params={"lr": 1e-4}, + ) + + # Generate synthetic data + print(f"\n=== Generating {args.num_synthetic_samples} synthetic samples ===") + model.eval() + with torch.no_grad(): + synthetic_nested_codes = model.generate( + num_samples=args.num_synthetic_samples, + max_visits=10, + max_codes_per_visit=20, + max_length=512, + temperature=1.0, + top_k=50, + top_p=0.95, + ) + + # Convert to sequences and tabular + from pyhealth.synthetic_ehr_utils.synthetic_ehr_utils import ( + nested_codes_to_sequences, + sequences_to_tabular, + ) + + synthetic_sequences = nested_codes_to_sequences(synthetic_nested_codes) + synthetic_df = sequences_to_tabular(synthetic_sequences) + + # Save + save_path = os.path.join(args.output_dir, "transformer_baseline") + os.makedirs(save_path, exist_ok=True) + + synthetic_df.to_csv( + os.path.join(save_path, "transformer_baseline_synthetic_ehr.csv"), index=False + ) + + print(f"Saved synthetic data to {save_path}") + return synthetic_df + + +def main(): + parser = argparse.ArgumentParser( + description="Train baseline models for 
synthetic EHR generation" + ) + + # Data paths + parser.add_argument( + "--mimic_root", + type=str, + required=True, + help="Path to MIMIC data directory", + ) + parser.add_argument( + "--train_patients", + type=str, + default=None, + help="Path to train patient IDs file", + ) + parser.add_argument( + "--test_patients", + type=str, + default=None, + help="Path to test patient IDs file", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./synthetic_data", + help="Output directory for synthetic data", + ) + + # Model selection + parser.add_argument( + "--mode", + type=str, + default="transformer_baseline", + choices=["great", "ctgan", "tvae", "transformer_baseline"], + help="Baseline model to use", + ) + + # Training parameters + parser.add_argument("--epochs", type=int, default=2, help="Number of epochs") + parser.add_argument("--batch_size", type=int, default=512, help="Batch size") + parser.add_argument("--num_workers", type=int, default=4, help="Number of workers") + parser.add_argument( + "--num_synthetic_samples", + type=int, + default=10000, + help="Number of synthetic samples to generate", + ) + + # Model-specific parameters + parser.add_argument( + "--great_llm", + type=str, + default="tabularisai/Qwen3-0.3B-distil", + help="Language model for GReaT", + ) + + args = parser.parse_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Process MIMIC data + print("=" * 80) + print("Processing MIMIC Data") + print("=" * 80) + + if args.mode == "transformer_baseline": + # For transformer baseline, we process data through PyHealth + # Dataset will be loaded in the training function + print("Will load data through PyHealth dataset...") + train_transformer_baseline(None, args) + + else: + # For tabular models, we need flattened representation + print("Processing MIMIC data for tabular models...") + data = process_mimic_for_generation( + args.mimic_root, + args.train_patients, + args.test_patients, + ) + + 
train_flattened = data["train_flattened"] + print(f"Train flattened shape: {train_flattened.shape}") + + # Train selected model + print("\n" + "=" * 80) + print(f"Training {args.mode.upper()} Model") + print("=" * 80) + + if args.mode == "great": + synthetic_data = train_great_model(train_flattened, args) + elif args.mode == "ctgan": + synthetic_data = train_ctgan_model(train_flattened, args) + elif args.mode == "tvae": + synthetic_data = train_tvae_model(train_flattened, args) + + print("\n" + "=" * 80) + print("Synthetic Data Statistics") + print("=" * 80) + print(f"Shape: {synthetic_data.shape}") + print(f"Columns: {len(synthetic_data.columns)}") + print(f"\nFirst few rows:") + print(synthetic_data.head()) + + print("\n" + "=" * 80) + print("COMPLETED") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/synthetic_ehr_generation/synthetic_ehr_mimic3_transformer.py b/examples/synthetic_ehr_generation/synthetic_ehr_mimic3_transformer.py new file mode 100644 index 000000000..6bc91fd84 --- /dev/null +++ b/examples/synthetic_ehr_generation/synthetic_ehr_mimic3_transformer.py @@ -0,0 +1,342 @@ +""" +Example of training a Transformer-based synthetic EHR generator on MIMIC-III data. + +This example demonstrates the complete workflow for training a generative model +that can create synthetic patient histories: + +1. Loading MIMIC-III data +2. Applying the SyntheticEHRGenerationMIMIC3 task +3. Training a TransformerEHRGenerator model +4. Generating synthetic patient histories +5. 
Converting synthetic data to different formats

Usage:
    python synthetic_ehr_mimic3_transformer.py \\
        --mimic_root /path/to/mimic3 \\
        --output_dir /path/to/output \\
        --epochs 50 \\
        --batch_size 32
"""

import os
import argparse
import pandas as pd
import torch

from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient
from pyhealth.tasks import SyntheticEHRGenerationMIMIC3
from pyhealth.models import TransformerEHRGenerator
from pyhealth.trainer import Trainer
from pyhealth.synthetic_ehr_utils.synthetic_ehr_utils import (
    nested_codes_to_sequences,
    sequences_to_tabular,
)


def main(args):
    """Main training and generation pipeline.

    Runs the full workflow end-to-end: load MIMIC-III, apply the generation
    task, split by patient, train a TransformerEHRGenerator, evaluate, then
    sample synthetic patients and write them out.

    Args:
        args: parsed argparse namespace (see the __main__ block for flags).
    """

    print("\n" + "=" * 80)
    print("STEP 1: Load MIMIC-III Dataset")
    print("=" * 80)

    # Load MIMIC-III base dataset
    base_dataset = MIMIC3Dataset(
        root=args.mimic_root,
        tables=["DIAGNOSES_ICD"],  # Only need diagnosis codes
        code_mapping={"ICD9CM": "CCSCM"} if args.use_ccs else None,
        num_workers=args.num_workers,
    )

    print(f"\nDataset loaded:")
    print(f"  Total patients: {len(base_dataset.patient_to_index)}")
    # NOTE(review): assumes iterating the dataset yields patient objects with
    # admission events even though only DIAGNOSES_ICD was loaded — confirm
    # against the MIMIC3Dataset API.
    print(f"  Total admissions: {sum(len(p.get_events('admissions')) for p in base_dataset)}")

    print("\n" + "=" * 80)
    print("STEP 2: Apply Synthetic EHR Generation Task")
    print("=" * 80)

    # Create task for synthetic generation
    task = SyntheticEHRGenerationMIMIC3(
        min_visits=args.min_visits, max_visits=args.max_visits
    )

    # Generate samples
    sample_dataset = base_dataset.set_task(task, num_workers=args.num_workers)

    print(f"\nTask applied:")
    print(f"  Total samples: {len(sample_dataset)}")
    print(f"  Input schema: {sample_dataset.input_schema}")
    print(f"  Output schema: {sample_dataset.output_schema}")

    # Inspect a sample; 'visit_codes' is indexed as (visits, codes-per-visit)
    # below, so it is expected to be 2-D.
    sample = sample_dataset[0]
    print(f"\nSample structure:")
    print(f"  Patient ID: {sample['patient_id']}")
    print(f"  Visit codes shape: {sample['visit_codes'].shape}")
    print(f"  Number of visits: {sample['visit_codes'].shape[0]}")
    print(f"  Max codes per visit: {sample['visit_codes'].shape[1]}")

    print("\n" + "=" * 80)
    print("STEP 3: Split Dataset")
    print("=" * 80)

    # Split by patient (important to prevent data leakage)
    train_dataset, val_dataset, test_dataset = split_by_patient(
        sample_dataset, [args.train_ratio, args.val_ratio, 1 - args.train_ratio - args.val_ratio]
    )

    print(f"\nDataset split:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Val: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    # Create dataloaders
    train_loader = get_dataloader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    val_loader = get_dataloader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = get_dataloader(
        test_dataset, batch_size=args.batch_size, shuffle=False
    )

    print("\n" + "=" * 80)
    print("STEP 4: Initialize TransformerEHRGenerator Model")
    print("=" * 80)

    # Create the generative model
    model = TransformerEHRGenerator(
        dataset=sample_dataset,
        embedding_dim=args.embedding_dim,
        num_heads=args.num_heads,
        num_layers=args.num_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        max_seq_length=args.max_seq_length,
    )

    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel initialized:")
    print(f"  Total parameters: {num_params:,}")
    print(f"  Vocabulary size: {model.vocab_size}")
    print(f"  Embedding dim: {args.embedding_dim}")
    print(f"  Num layers: {args.num_layers}")
    print(f"  Num heads: {args.num_heads}")

    print("\n" + "=" * 80)
    print("STEP 5: Train the Model")
    print("=" * 80)

    # Create trainer
    trainer = Trainer(
        model=model,
        device=args.device,
        output_path=args.output_dir,
        exp_name="transformer_ehr_generator",
    )

    # Train
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=args.epochs,
        monitor="loss",  # Monitor validation loss
        monitor_criterion="min",
        optimizer_params={"lr": args.learning_rate, "weight_decay": args.weight_decay},
    )

    print("\n" + "=" * 80)
    print("STEP 6: Evaluate on Test Set")
    print("=" * 80)

    # Evaluate
    test_results = trainer.evaluate(test_loader)
    print("\nTest Results:")
    for metric, value in test_results.items():
        print(f"  {metric}: {value:.4f}")

    print("\n" + "=" * 80)
    print("STEP 7: Generate Synthetic Patient Histories")
    print("=" * 80)

    print(f"\nGenerating {args.num_synthetic_samples} synthetic patients...")

    # Generate synthetic samples (autoregressive sampling; no gradients)
    model.eval()
    with torch.no_grad():
        synthetic_nested_codes = model.generate(
            num_samples=args.num_synthetic_samples,
            max_visits=args.max_visits,
            max_codes_per_visit=args.max_codes_per_visit,
            max_length=args.max_seq_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )

    print(f"Generated {len(synthetic_nested_codes)} synthetic patients")

    # Convert to different formats
    print("\nConverting to different formats...")

    # 1. Convert to text sequences
    synthetic_sequences = nested_codes_to_sequences(synthetic_nested_codes)

    # 2. 
Convert to tabular format + synthetic_df = sequences_to_tabular(synthetic_sequences) + + # Display statistics + print(f"\nSynthetic data statistics:") + print(f" Total patients: {len(synthetic_nested_codes)}") + print(f" Total visits: {synthetic_df['HADM_ID'].nunique()}") + print(f" Total codes: {len(synthetic_df)}") + print(f" Avg visits per patient: {len(synthetic_df) / len(synthetic_nested_codes):.2f}") + print(f" Unique codes: {synthetic_df['ICD9_CODE'].nunique()}") + + # Save synthetic data + print("\n" + "=" * 80) + print("STEP 8: Save Synthetic Data") + print("=" * 80) + + os.makedirs(args.output_dir, exist_ok=True) + + # Save as CSV + synthetic_csv_path = os.path.join( + args.output_dir, "synthetic_ehr_transformer.csv" + ) + synthetic_df.to_csv(synthetic_csv_path, index=False) + print(f"\nSaved synthetic data to: {synthetic_csv_path}") + + # Save sequences as text file + synthetic_seq_path = os.path.join( + args.output_dir, "synthetic_sequences_transformer.txt" + ) + with open(synthetic_seq_path, "w") as f: + for seq in synthetic_sequences: + f.write(seq + "\n") + print(f"Saved synthetic sequences to: {synthetic_seq_path}") + + # Display sample synthetic patient + print("\n" + "=" * 80) + print("Sample Synthetic Patient History") + print("=" * 80) + print(f"\nPatient 0:") + for visit_idx, visit_codes in enumerate(synthetic_nested_codes[0]): + print(f" Visit {visit_idx + 1}: {visit_codes}") + + print("\n" + "=" * 80) + print("COMPLETED") + print("=" * 80) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Train Transformer-based synthetic EHR generator on MIMIC-III" + ) + + # Dataset arguments + parser.add_argument( + "--mimic_root", + type=str, + required=True, + help="Path to MIMIC-III data directory", + ) + parser.add_argument( + "--use_ccs", + action="store_true", + help="Map ICD9 codes to CCS categories", + ) + parser.add_argument( + "--min_visits", type=int, default=2, help="Minimum visits per patient" + ) + 
parser.add_argument( + "--max_visits", + type=int, + default=None, + help="Maximum visits per patient (None = no limit)", + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="Number of workers for data loading" + ) + + # Model arguments + parser.add_argument( + "--embedding_dim", type=int, default=256, help="Embedding dimension" + ) + parser.add_argument( + "--num_heads", type=int, default=8, help="Number of attention heads" + ) + parser.add_argument( + "--num_layers", type=int, default=6, help="Number of transformer layers" + ) + parser.add_argument( + "--dim_feedforward", + type=int, + default=1024, + help="Feedforward network dimension", + ) + parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate") + parser.add_argument( + "--max_seq_length", + type=int, + default=512, + help="Maximum sequence length", + ) + + # Training arguments + parser.add_argument("--epochs", type=int, default=50, help="Number of epochs") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size") + parser.add_argument( + "--learning_rate", type=float, default=1e-4, help="Learning rate" + ) + parser.add_argument( + "--weight_decay", type=float, default=0.01, help="Weight decay" + ) + parser.add_argument( + "--train_ratio", type=float, default=0.8, help="Training set ratio" + ) + parser.add_argument( + "--val_ratio", type=float, default=0.1, help="Validation set ratio" + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to use (cuda/cpu)", + ) + + # Generation arguments + parser.add_argument( + "--num_synthetic_samples", + type=int, + default=1000, + help="Number of synthetic samples to generate", + ) + parser.add_argument( + "--max_codes_per_visit", + type=int, + default=20, + help="Maximum codes per visit during generation", + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature" + ) + parser.add_argument( + 
"--top_k", type=int, default=50, help="Top-k sampling (0 = disabled)" + ) + parser.add_argument( + "--top_p", type=float, default=0.95, help="Nucleus sampling threshold" + ) + + # Output arguments + parser.add_argument( + "--output_dir", + type=str, + default="./synthetic_ehr_output", + help="Output directory for model and synthetic data", + ) + + args = parser.parse_args() + + # Run main pipeline + main(args) diff --git a/examples/synthetic_ehr_generation/transformer_mimic3_colab.ipynb b/examples/synthetic_ehr_generation/transformer_mimic3_colab.ipynb new file mode 100644 index 000000000..cb0dff829 --- /dev/null +++ b/examples/synthetic_ehr_generation/transformer_mimic3_colab.ipynb @@ -0,0 +1,630 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "xpY_dOGVk5aF" + }, + "source": [ + "# Transformer Baseline for Synthetic EHR Generation on MIMIC-III\n", + "\n", + "This notebook demonstrates how to train a Transformer-based generative model on MIMIC-III data and generate synthetic patient records.\n", + "\n", + "## Overview\n", + "- **Model**: TransformerEHRGenerator (decoder-only transformer, GPT-style)\n", + "- **Dataset**: MIMIC-III diagnosis codes\n", + "- **Output**: CSV file with columns: `SUBJECT_ID`, `VISIT_NUM`, `ICD9_CODE`\n", + "\n", + "## Setup\n", + "Designed for Google Colab with GPU support." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m40snV-Yk5aG" + }, + "source": [ + "## 1. 
Environment Setup" + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "# Where to clone from\n", + "clone_repo = \"https://github.com/ethanrasmussen/PyHealth.git\"\n", + "clone_branch = \"implement_baseline\"\n", + "\n", + "# Where to save repo/package\n", + "repo_dir = \"/content/PyHealth\"\n", + "\n", + "if not os.path.exists(repo_dir):\n", + " !git clone -b {clone_branch} {clone_repo} {repo_dir}\n", + "%cd /content/PyHealth\n", + "\n", + "# install your repo without letting pip touch torch/cuda stack\n", + "%pip install -e . --no-deps\n", + "\n", + "# now install the runtime deps you actually need for this notebook\n", + "# %pip install -U --no-cache-dir --force-reinstall \"numpy==2.2.0\"\n", + "%pip install -U \"transformers==4.53.2\" \"tokenizers\" \"accelerate\" \"peft\"\n", + "%pip install -U \"pandas\" \"tqdm\" \"litdata\" \"mne\" \"rdkit\"" + ], + "metadata": { + "collapsed": true, + "id": "PrzAv7pMlksS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f1LCYYnNk5aH" + }, + "outputs": [], + "source": [ + "# Check GPU availability\n", + "import torch\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", + "if torch.cuda.is_available():\n", + " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", + " device = \"cuda\"\n", + "else:\n", + " print(\"WARNING: Running on CPU. 
Training will be very slow.\")\n", + " device = \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sziIjDsVk5aI" + }, + "outputs": [], + "source": [ + "# Mount Google Drive (optional - for persistent storage)\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "# Set paths for persistent storage\n", + "DRIVE_ROOT = \"/content/drive/MyDrive/PyHealth_Synthetic_EHR\"\n", + "!mkdir -p \"{DRIVE_ROOT}\"\n", + "!mkdir -p \"{DRIVE_ROOT}/data\"\n", + "!mkdir -p \"{DRIVE_ROOT}/models\"\n", + "!mkdir -p \"{DRIVE_ROOT}/output\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i4WQ5LO1k5aI" + }, + "source": [ + "## 2. Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OJFg_122k5aJ" + }, + "outputs": [], + "source": [ + "# Configuration parameters\n", + "class Config:\n", + " # Paths\n", + " MIMIC_ROOT = f\"{DRIVE_ROOT}/data/mimic3\" # Update this to your MIMIC-III path\n", + " OUTPUT_DIR = f\"{DRIVE_ROOT}/output\"\n", + " MODEL_SAVE_PATH = f\"{DRIVE_ROOT}/models/transformer_ehr_best.pth\"\n", + "\n", + " # Dataset parameters\n", + " MIN_VISITS = 2 # Minimum visits per patient\n", + " MAX_VISITS = None # Maximum visits to include (None = all)\n", + "\n", + " # Model architecture\n", + " EMBEDDING_DIM = 256\n", + " NUM_HEADS = 8\n", + " NUM_LAYERS = 6\n", + " DIM_FEEDFORWARD = 1024\n", + " DROPOUT = 0.1\n", + " MAX_SEQ_LENGTH = 512\n", + "\n", + " # Training parameters\n", + " EPOCHS = 5 # Set to 50-80 for production\n", + " BATCH_SIZE = 64 # Reduce to 32 if OOM errors occur\n", + " LEARNING_RATE = 1e-4\n", + " WEIGHT_DECAY = 1e-5\n", + "\n", + " # Data split\n", + " TRAIN_RATIO = 0.8\n", + " VAL_RATIO = 0.1\n", + " TEST_RATIO = 0.1\n", + "\n", + " # Generation parameters\n", + " NUM_SYNTHETIC_SAMPLES = 1000 # Set to 10000 for production\n", + " MAX_GEN_VISITS = 10\n", + " MAX_CODES_PER_VISIT = 20\n", + " TEMPERATURE = 1.0\n", + " 
TOP_K = 50\n", + " TOP_P = 0.95\n", + "\n", + "config = Config()\n", + "print(\"Configuration loaded successfully!\")\n", + "print(f\"Training for {config.EPOCHS} epochs\")\n", + "print(f\"Will generate {config.NUM_SYNTHETIC_SAMPLES} synthetic patients\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9pZOHkWgk5aJ" + }, + "source": [ + "## 3. Data Upload\n", + "\n", + "Upload your MIMIC-III data files to the specified directory. You need:\n", + "- `ADMISSIONS.csv`\n", + "- `DIAGNOSES_ICD.csv`\n", + "\n", + "These files should be placed in the directory specified by `config.MIMIC_ROOT`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sq0QKL19k5aJ" + }, + "outputs": [], + "source": [ + "# Check if MIMIC-III files exist\n", + "import os\n", + "\n", + "required_files = ['ADMISSIONS.csv', 'DIAGNOSES_ICD.csv']\n", + "files_exist = all(os.path.exists(os.path.join(config.MIMIC_ROOT, f)) for f in required_files)\n", + "\n", + "if files_exist:\n", + " print(\"✓ All required MIMIC-III files found!\")\n", + "else:\n", + " print(\"✗ Missing MIMIC-III files. Please upload:\")\n", + " for f in required_files:\n", + " path = os.path.join(config.MIMIC_ROOT, f)\n", + " status = \"✓\" if os.path.exists(path) else \"✗\"\n", + " print(f\" {status} {f}\")\n", + " print(f\"\\nUpload files to: {config.MIMIC_ROOT}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zn8o4OoIk5aK" + }, + "source": [ + "## 4. 
Load and Preprocess Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "caCNoM5lk5aK" + }, + "outputs": [], + "source": [ + "# Import PyHealth modules\n", + "from pyhealth.datasets import MIMIC3Dataset\n", + "from pyhealth.tasks import SyntheticEHRGenerationMIMIC3\n", + "from pyhealth.datasets import split_by_patient, get_dataloader\n", + "\n", + "print(\"Loading MIMIC-III dataset...\")\n", + "# Load base dataset\n", + "base_dataset = MIMIC3Dataset(\n", + " root=config.MIMIC_ROOT,\n", + " tables=[\"DIAGNOSES_ICD\"],\n", + ")\n", + "\n", + "# Apply synthetic EHR generation task\n", + "print(f\"\\nApplying task with min_visits={config.MIN_VISITS}...\")\n", + "task = SyntheticEHRGenerationMIMIC3(\n", + " min_visits=config.MIN_VISITS,\n", + " max_visits=config.MAX_VISITS\n", + ")\n", + "sample_dataset = base_dataset.set_task(task)\n", + "\n", + "print(f\"Created {len(sample_dataset)} samples\")\n", + "\n", + "# Split by patient to prevent data leakage\n", + "print(f\"\\nSplitting data: {config.TRAIN_RATIO}/{config.VAL_RATIO}/{config.TEST_RATIO}\")\n", + "train_dataset, val_dataset, test_dataset = split_by_patient(\n", + " sample_dataset,\n", + " [config.TRAIN_RATIO, config.VAL_RATIO, config.TEST_RATIO]\n", + ")\n", + "\n", + "print(f\"Train: {len(train_dataset)} samples\")\n", + "print(f\"Val: {len(val_dataset)} samples\")\n", + "print(f\"Test: {len(test_dataset)} samples\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bLzBM2Gek5aK" + }, + "outputs": [], + "source": [ + "# Create data loaders\n", + "print(\"Creating data loaders...\")\n", + "train_loader = get_dataloader(\n", + " train_dataset,\n", + " batch_size=config.BATCH_SIZE,\n", + " shuffle=True\n", + ")\n", + "val_loader = get_dataloader(\n", + " val_dataset,\n", + " batch_size=config.BATCH_SIZE,\n", + " shuffle=False\n", + ")\n", + "test_loader = get_dataloader(\n", + " test_dataset,\n", + " 
batch_size=config.BATCH_SIZE,\n", + " shuffle=False\n", + ")\n", + "\n", + "print(f\"Train batches: {len(train_loader)}\")\n", + "print(f\"Val batches: {len(val_loader)}\")\n", + "print(f\"Test batches: {len(test_loader)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lfcDLLJCk5aL" + }, + "outputs": [], + "source": [ + "# Inspect a sample batch\n", + "sample_batch = next(iter(train_loader))\n", + "print(\"Sample batch structure:\")\n", + "for key, value in sample_batch.items():\n", + " if isinstance(value, torch.Tensor):\n", + " print(f\" {key}: shape {value.shape}, dtype {value.dtype}\")\n", + " else:\n", + " print(f\" {key}: {type(value)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mkWW6j8Qk5aL" + }, + "source": [ + "## 5. Initialize Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2O0XGcvnk5aL" + }, + "outputs": [], + "source": [ + "from pyhealth.models import TransformerEHRGenerator\n", + "\n", + "print(\"Initializing TransformerEHRGenerator...\")\n", + "model = TransformerEHRGenerator(\n", + " dataset=sample_dataset,\n", + " embedding_dim=config.EMBEDDING_DIM,\n", + " num_heads=config.NUM_HEADS,\n", + " num_layers=config.NUM_LAYERS,\n", + " dim_feedforward=config.DIM_FEEDFORWARD,\n", + " dropout=config.DROPOUT,\n", + " max_seq_length=config.MAX_SEQ_LENGTH\n", + ")\n", + "\n", + "# Move model to device\n", + "model = model.to(device)\n", + "\n", + "# Print model info\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "print(f\"\\nModel initialized successfully!\")\n", + "print(f\"Total parameters: {total_params:,}\")\n", + "print(f\"Trainable parameters: {trainable_params:,}\")\n", + "print(f\"Vocabulary size: {model.vocab_size}\")\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"mG-Y5T_ak5aL" + }, + "source": [ + "## 6. Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ya3RTumXk5aM" + }, + "outputs": [], + "source": [ + "from pyhealth.trainer import Trainer\n", + "\n", + "print(f\"Starting training for {config.EPOCHS} epochs...\\n\")\n", + "\n", + "# Initialize trainer\n", + "trainer = Trainer(\n", + " model=model,\n", + " device=device,\n", + " output_path=config.OUTPUT_DIR,\n", + " exp_name=\"transformer_ehr_generator\"\n", + ")\n", + "\n", + "# Train model\n", + "trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " epochs=config.EPOCHS,\n", + " monitor=\"loss\",\n", + " monitor_criterion=\"min\",\n", + " load_best_model_at_last=True\n", + ")\n", + "\n", + "print(\"\\n\" + \"=\"*50)\n", + "print(\"Training completed!\")\n", + "print(\"=\"*50)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bcVdhyM-k5aM" + }, + "outputs": [], + "source": [ + "# Save the best model\n", + "torch.save(model.state_dict(), config.MODEL_SAVE_PATH)\n", + "print(f\"Model saved to: {config.MODEL_SAVE_PATH}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2CT5rZLtk5aM" + }, + "source": [ + "## 7. Evaluation on Test Set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8TIMLkHhk5aM" + }, + "outputs": [], + "source": [ + "# Evaluate on test set\n", + "print(\"Evaluating on test set...\")\n", + "test_results = trainer.evaluate(test_loader)\n", + "\n", + "print(\"\\nTest Results:\")\n", + "for metric, value in test_results.items():\n", + " print(f\" {metric}: {value:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0-f71_B0k5aM" + }, + "source": [ + "## 8. 
Generate Synthetic Patients" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1doD0vdZk5aM" + }, + "outputs": [], + "source": [ + "# Generate synthetic patient histories\n", + "print(f\"Generating {config.NUM_SYNTHETIC_SAMPLES} synthetic patients...\\n\")\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " synthetic_nested_codes = model.generate(\n", + " num_samples=config.NUM_SYNTHETIC_SAMPLES,\n", + " max_visits=config.MAX_GEN_VISITS,\n", + " max_codes_per_visit=config.MAX_CODES_PER_VISIT,\n", + " temperature=config.TEMPERATURE,\n", + " top_k=config.TOP_K,\n", + " top_p=config.TOP_P\n", + " )\n", + "\n", + "print(f\"Generated {len(synthetic_nested_codes)} synthetic patients\")\n", + "print(f\"\\nExample synthetic patient (first 3 visits):\")\n", + "if len(synthetic_nested_codes) > 0 and len(synthetic_nested_codes[0]) > 0:\n", + " for i, visit in enumerate(synthetic_nested_codes[0][:3]):\n", + " print(f\" Visit {i+1}: {visit[:10]}{'...' if len(visit) > 10 else ''}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4Cd2Cauvk5aN" + }, + "source": [ + "## 9. 
Convert to DataFrame Format"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "PVB9reUUk5aN"
+      },
+      "outputs": [],
+      "source": [
+        "import pandas as pd\n",
+        "\n",
+        "# Get the processor to convert token IDs back to codes\n",
+        "input_processor = sample_dataset.input_processors[\"visit_codes\"]\n",
+        "index_to_code = {v: k for k, v in input_processor.code_vocab.items()}\n",
+        "\n",
+        "print(\"Converting synthetic data to CSV format...\")\n",
+        "\n",
+        "# Convert nested codes to tabular format\n",
+        "rows = []\n",
+        "for patient_idx, patient_visits in enumerate(synthetic_nested_codes):\n",
+        "    synthetic_subject_id = f\"SYNTHETIC_{patient_idx:06d}\"\n",
+        "\n",
+        "    for visit_num, visit_codes in enumerate(patient_visits, start=1):\n",
+        "        for code_idx in visit_codes:\n",
+        "            # Convert token ID to actual code\n",
+        "            code = index_to_code.get(code_idx, str(code_idx))\n",
+        "\n",
+        "            # Skip special tokens\n",
+        "            if code in ['<pad>', '<bos>', '<eos>', 'VISIT_DELIM']:\n",
+        "                continue\n",
+        "\n",
+        "            rows.append({\n",
+        "                'SUBJECT_ID': synthetic_subject_id,\n",
+        "                'VISIT_NUM': visit_num,\n",
+        "                'ICD9_CODE': code\n",
+        "            })\n",
+        "\n",
+        "# Create DataFrame\n",
+        "synthetic_df = pd.DataFrame(rows)\n",
+        "\n",
+        "print(f\"\\nCreated DataFrame with {len(synthetic_df)} rows\")\n",
+        "print(f\"Number of unique patients: {synthetic_df['SUBJECT_ID'].nunique()}\")\n",
+        "print(f\"Number of unique codes: {synthetic_df['ICD9_CODE'].nunique()}\")\n",
+        "\n",
+        "# Display sample\n",
+        "print(\"\\nSample of synthetic data:\")\n",
+        "print(synthetic_df.head(20))"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "SdSZKXoVk5aN"
+      },
+      "source": [
+        "## 10. 
Save CSV File" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P0c-yP2Vk5aN" + }, + "outputs": [], + "source": [ + "# Save to CSV\n", + "output_csv_path = f\"{config.OUTPUT_DIR}/synthetic_ehr_transformer.csv\"\n", + "synthetic_df.to_csv(output_csv_path, index=False)\n", + "\n", + "print(f\"Synthetic data saved to: {output_csv_path}\")\n", + "print(f\"\\nFile info:\")\n", + "print(f\" Rows: {len(synthetic_df):,}\")\n", + "print(f\" Columns: {list(synthetic_df.columns)}\")\n", + "print(f\" File size: {os.path.getsize(output_csv_path) / 1024:.2f} KB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oFEV5z8Qk5aO" + }, + "source": [ + "## 11. Summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pMAgdOTYk5aO" + }, + "outputs": [], + "source": [ + "# Print final summary\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"SYNTHETIC EHR GENERATION SUMMARY\")\n", + "print(\"=\"*60)\n", + "\n", + "print(f\"\\nModel: TransformerEHRGenerator\")\n", + "print(f\" - Embedding dim: {config.EMBEDDING_DIM}\")\n", + "print(f\" - Layers: {config.NUM_LAYERS}\")\n", + "print(f\" - Attention heads: {config.NUM_HEADS}\")\n", + "print(f\" - Parameters: {total_params:,}\")\n", + "\n", + "print(f\"\\nTraining:\")\n", + "print(f\" - Epochs: {config.EPOCHS}\")\n", + "print(f\" - Batch size: {config.BATCH_SIZE}\")\n", + "print(f\" - Training samples: {len(train_dataset)}\")\n", + "print(f\" - Validation samples: {len(val_dataset)}\")\n", + "\n", + "print(f\"\\nGeneration:\")\n", + "print(f\" - Synthetic patients: {synthetic_df['SUBJECT_ID'].nunique()}\")\n", + "print(f\" - Total diagnosis records: {len(synthetic_df)}\")\n", + "print(f\" - Unique ICD-9 codes: {synthetic_df['ICD9_CODE'].nunique()}\")\n", + "print(f\" - Avg visits per patient: {visits_per_patient.mean():.2f}\")\n", + "print(f\" - Avg codes per visit: {codes_per_visit.mean():.2f}\")\n", + "\n", + "print(f\"\\nOutput:\")\n", + 
"print(f\" - CSV file: {output_csv_path}\")\n", + "print(f\" - Model checkpoint: {config.MODEL_SAVE_PATH}\")\n", + "\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"Pipeline completed successfully!\")\n", + "print(\"=\"*60)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..ab1dea92b 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -43,3 +43,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .synthetic_ehr import TransformerEHRGenerator diff --git a/pyhealth/models/synthetic_ehr.py b/pyhealth/models/synthetic_ehr.py new file mode 100644 index 000000000..0023dd9d7 --- /dev/null +++ b/pyhealth/models/synthetic_ehr.py @@ -0,0 +1,446 @@ +""" +Transformer-based models for synthetic EHR generation. + +This module implements autoregressive generative models for creating synthetic +Electronic Health Records. The models learn to generate realistic patient visit +sequences by training on real EHR data. +""" + +from typing import Dict, Optional, List, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel +from pyhealth.processors import NestedSequenceProcessor + + +class TransformerEHRGenerator(BaseModel): + """Transformer-based autoregressive model for synthetic EHR generation. 
+ + This model uses a decoder-only transformer architecture (similar to GPT) to learn + patient visit sequence patterns. It can generate synthetic patient histories by + sampling from the learned distribution. + + The model processes nested sequences of medical codes (visits containing diagnosis + codes) and learns to predict future codes autoregressively. + + Architecture: + - Token embedding layer for medical codes + - Positional encoding for sequential modeling + - Multi-layer transformer decoder + - Output projection to vocabulary + + Args: + dataset: SampleDataset containing training data + embedding_dim: Dimension of code embeddings. Default is 256. + num_heads: Number of attention heads. Default is 8. + num_layers: Number of transformer layers. Default is 6. + dim_feedforward: Hidden dimension of feedforward network. Default is 1024. + dropout: Dropout probability. Default is 0.1. + max_seq_length: Maximum sequence length for positional encoding. Default is 512. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import SyntheticEHRGenerationMIMIC3 + >>> from pyhealth.datasets import get_dataloader + >>> + >>> # Load dataset and apply task + >>> base_dataset = MIMIC3Dataset( + ... root="/path/to/mimic3", + ... tables=["DIAGNOSES_ICD"], + ... ) + >>> task = SyntheticEHRGenerationMIMIC3(min_visits=2) + >>> sample_dataset = base_dataset.set_task(task) + >>> + >>> # Create model + >>> model = TransformerEHRGenerator( + ... dataset=sample_dataset, + ... embedding_dim=256, + ... num_heads=8, + ... num_layers=6, + ... ) + >>> + >>> # Training + >>> train_loader = get_dataloader(sample_dataset, batch_size=32) + >>> for batch in train_loader: + ... output = model(**batch) + ... loss = output["loss"] + ... 
loss.backward()
+        >>>
+        >>> # Generation
+        >>> synthetic_codes = model.generate(num_samples=100, max_visits=10)
+    """
+
+    def __init__(
+        self,
+        dataset: SampleDataset,
+        embedding_dim: int = 256,
+        num_heads: int = 8,
+        num_layers: int = 6,
+        dim_feedforward: int = 1024,
+        dropout: float = 0.1,
+        max_seq_length: int = 512,
+    ):
+        super().__init__(dataset)
+
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.dim_feedforward = dim_feedforward
+        self.dropout = dropout
+        self.max_seq_length = max_seq_length
+
+        # Get vocabulary size from the processor
+        input_processor = dataset.input_processors["visit_codes"]
+        assert isinstance(
+            input_processor, NestedSequenceProcessor
+        ), "Expected NestedSequenceProcessor for visit_codes"
+
+        self.vocab_size = input_processor.vocab_size()
+        self.pad_idx = input_processor.code_vocab.get("<pad>", 0)
+
+        # Special tokens
+        self.bos_token = input_processor.code_vocab.get("<bos>", self.vocab_size)
+        self.eos_token = input_processor.code_vocab.get("<eos>", self.vocab_size + 1)
+        self.visit_delim_token = input_processor.code_vocab.get(
+            "VISIT_DELIM", self.vocab_size + 2
+        )
+
+        # Adjust vocab size to include special tokens if needed
+        extended_vocab_size = max(
+            self.vocab_size, self.bos_token + 1, self.eos_token + 1, self.visit_delim_token + 1
+        )
+
+        # Token embedding
+        self.token_embedding = nn.Embedding(
+            extended_vocab_size, embedding_dim, padding_idx=self.pad_idx
+        )
+
+        # Positional encoding
+        self.pos_encoding = nn.Parameter(
+            torch.zeros(1, max_seq_length, embedding_dim)
+        )
+        nn.init.normal_(self.pos_encoding, std=0.02)
+
+        # Transformer decoder layers
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=embedding_dim,
+            nhead=num_heads,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.transformer_decoder = nn.TransformerDecoder(
+            decoder_layer, num_layers=num_layers
+        )
+
+        # Output projection
+        self.output_projection = 
nn.Linear(embedding_dim, extended_vocab_size) + + # Dropout + self.dropout_layer = nn.Dropout(dropout) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + """Initialize model weights.""" + nn.init.normal_(self.token_embedding.weight, std=0.02) + if self.pad_idx is not None: + self.token_embedding.weight.data[self.pad_idx].zero_() + nn.init.normal_(self.output_projection.weight, std=0.02) + nn.init.zeros_(self.output_projection.bias) + + def flatten_nested_sequence( + self, nested_seq: torch.Tensor, visit_delim: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Flatten nested visit sequences into 1D sequences with visit delimiters. + + Args: + nested_seq: Tensor of shape (batch, num_visits, codes_per_visit) + visit_delim: Token ID for visit delimiter + + Returns: + Tuple of: + - Flattened sequence (batch, seq_len) + - Attention mask (batch, seq_len) + """ + batch_size, num_visits, codes_per_visit = nested_seq.shape + device = nested_seq.device + + # Initialize output sequence + max_seq_len = num_visits * (codes_per_visit + 1) # +1 for delimiter + flat_seq = torch.full( + (batch_size, max_seq_len), + self.pad_idx, + dtype=torch.long, + device=device, + ) + mask = torch.zeros(batch_size, max_seq_len, dtype=torch.bool, device=device) + + for b in range(batch_size): + pos = 0 + for v in range(num_visits): + # Add codes from this visit + visit_codes = nested_seq[b, v] + valid_codes = visit_codes[visit_codes != self.pad_idx] + + if len(valid_codes) > 0: + flat_seq[b, pos : pos + len(valid_codes)] = valid_codes + mask[b, pos : pos + len(valid_codes)] = True + pos += len(valid_codes) + + # Add visit delimiter (except after last visit) + if v < num_visits - 1 and len(valid_codes) > 0: + flat_seq[b, pos] = visit_delim + mask[b, pos] = True + pos += 1 + + return flat_seq, mask + + def unflatten_to_nested_sequence( + self, flat_seq: torch.Tensor, visit_delim: int, max_codes_per_visit: int + ) -> List[List[List[int]]]: + """Convert flattened sequences 
back to nested visit structure. + + Args: + flat_seq: Flattened sequence (batch, seq_len) + visit_delim: Token ID for visit delimiter + max_codes_per_visit: Maximum codes per visit + + Returns: + List of patient histories, each containing visits, each containing codes + """ + batch_size = flat_seq.shape[0] + nested_sequences = [] + + for b in range(batch_size): + seq = flat_seq[b].cpu().tolist() + patient_visits = [] + current_visit = [] + + for token in seq: + if token == self.pad_idx or token == self.eos_token: + # End of sequence + if current_visit: + patient_visits.append(current_visit) + break + elif token == visit_delim: + # End of visit + if current_visit: + patient_visits.append(current_visit) + current_visit = [] + elif token != self.bos_token: + # Regular code + current_visit.append(token) + + # Add last visit if exists + if current_visit: + patient_visits.append(current_visit) + + nested_sequences.append(patient_visits) + + return nested_sequences + + def forward(self, visit_codes: torch.Tensor, future_codes: torch.Tensor = None, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass for training or generation. 
+ + Args: + visit_codes: Input nested sequences (batch, num_visits, codes_per_visit) + future_codes: Target nested sequences for teacher forcing (batch, num_visits, codes_per_visit) + **kwargs: Additional arguments (ignored) + + Returns: + Dictionary containing: + - logit: Raw predictions (batch, seq_len, vocab_size) + - loss: Cross-entropy loss (scalar) if future_codes provided + - y_true: True next tokens if future_codes provided + - y_prob: Predicted probabilities (batch, seq_len, vocab_size) + """ + # Move inputs to model's device + visit_codes = visit_codes.to(self.device) + if future_codes is not None: + future_codes = future_codes.to(self.device) + + # Flatten nested sequences + flat_input, input_mask = self.flatten_nested_sequence( + visit_codes, self.visit_delim_token + ) + + # Ensure flattened tensors are on the correct device + flat_input = flat_input.to(self.device) + input_mask = input_mask.to(self.device) + + # Get sequence length + seq_len = flat_input.size(1) + if seq_len > self.max_seq_length: + flat_input = flat_input[:, : self.max_seq_length] + input_mask = input_mask[:, : self.max_seq_length] + seq_len = self.max_seq_length + + # Embed tokens + embeddings = self.token_embedding(flat_input) # (batch, seq_len, embed_dim) + + # Add positional encoding + embeddings = embeddings + self.pos_encoding[:, :seq_len, :] + embeddings = self.dropout_layer(embeddings) + + # Create causal mask for autoregressive generation + causal_mask = nn.Transformer.generate_square_subsequent_mask( + seq_len, device=embeddings.device + ) + + # Create padding mask + padding_mask = ~input_mask # Invert: True = padding + + # Pass through transformer decoder + # For decoder-only, memory is None, so it uses self-attention + transformer_out = self.transformer_decoder( + tgt=embeddings, + memory=embeddings, # Use same sequence as memory for self-attention + tgt_mask=causal_mask, + tgt_key_padding_mask=padding_mask, + ) + + # Project to vocabulary + logits = 
self.output_projection(transformer_out) # (batch, seq_len, vocab_size) + + # Prepare output dictionary + output = { + "logit": logits, + "y_prob": F.softmax(logits, dim=-1), + } + + # Calculate loss if targets provided + if future_codes is not None: + flat_target, target_mask = self.flatten_nested_sequence( + future_codes, self.visit_delim_token + ) + + # Ensure target tensors are on the correct device + flat_target = flat_target.to(self.device) + target_mask = target_mask.to(self.device) + + if flat_target.size(1) > self.max_seq_length: + flat_target = flat_target[:, : self.max_seq_length] + target_mask = target_mask[:, : self.max_seq_length] + + # Shift target by 1 for next-token prediction + target_shifted = flat_target[:, 1:] # Remove first token + logits_shifted = logits[:, :-1, :] # Remove last prediction + mask_shifted = target_mask[:, 1:] + + # Flatten for loss calculation + logits_flat = logits_shifted.reshape(-1, logits_shifted.size(-1)) + target_flat = target_shifted.reshape(-1) + mask_flat = mask_shifted.reshape(-1) + + # Calculate loss only on non-padded tokens + loss = F.cross_entropy( + logits_flat[mask_flat], target_flat[mask_flat], ignore_index=self.pad_idx + ) + + output["loss"] = loss + output["y_true"] = flat_target + + return output + + @torch.no_grad() + def generate( + self, + num_samples: int = 1, + max_visits: int = 10, + max_codes_per_visit: int = 20, + max_length: int = 512, + temperature: float = 1.0, + top_k: int = 50, + top_p: float = 0.95, + ) -> List[List[List[int]]]: + """Generate synthetic patient histories. 
+ + Args: + num_samples: Number of synthetic patients to generate + max_visits: Maximum number of visits per patient + max_codes_per_visit: Maximum codes per visit + max_length: Maximum total sequence length + temperature: Sampling temperature (higher = more random) + top_k: Keep only top k tokens for sampling (0 = disabled) + top_p: Nucleus sampling threshold (1.0 = disabled) + + Returns: + List of synthetic patient histories, each containing visits with diagnosis codes + """ + self.eval() + device = self.device + + generated_sequences = [] + + for _ in range(num_samples): + # Start with BOS token + current_seq = torch.tensor([[self.bos_token]], dtype=torch.long, device=device) + + for step in range(max_length - 1): + # Get sequence length + seq_len = current_seq.size(1) + + # Embed and add positional encoding + embeddings = self.token_embedding(current_seq) + embeddings = embeddings + self.pos_encoding[:, :seq_len, :] + + # Create causal mask + causal_mask = nn.Transformer.generate_square_subsequent_mask( + seq_len, device=device + ) + + # Pass through transformer + transformer_out = self.transformer_decoder( + tgt=embeddings, + memory=embeddings, + tgt_mask=causal_mask, + ) + + # Get logits for next token + logits = self.output_projection(transformer_out[:, -1, :]) # (1, vocab_size) + + # Apply temperature + logits = logits / temperature + + # Apply top-k filtering + if top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = float("-inf") + + # Apply top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, 
sorted_indices_to_remove + ) + logits[indices_to_remove] = float("-inf") + + # Sample next token + probs = F.softmax(logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + # Append to sequence + current_seq = torch.cat([current_seq, next_token], dim=1) + + # Stop if EOS token generated + if next_token.item() == self.eos_token: + break + + generated_sequences.append(current_seq) + + # Convert to nested structure + generated_sequences = torch.cat(generated_sequences, dim=0) + nested_output = self.unflatten_to_nested_sequence( + generated_sequences, self.visit_delim_token, max_codes_per_visit + ) + + return nested_output diff --git a/pyhealth/synthetic_ehr_utils/__init__.py b/pyhealth/synthetic_ehr_utils/__init__.py new file mode 100644 index 000000000..b3e6fb140 --- /dev/null +++ b/pyhealth/synthetic_ehr_utils/__init__.py @@ -0,0 +1,8 @@ +from .synthetic_ehr_utils import ( + tabular_to_sequences, + sequences_to_tabular, + nested_codes_to_sequences, + sequences_to_nested_codes, + create_flattened_representation, + VISIT_DELIM, +) \ No newline at end of file diff --git a/pyhealth/synthetic_ehr_utils/synthetic_ehr_utils.py b/pyhealth/synthetic_ehr_utils/synthetic_ehr_utils.py new file mode 100644 index 000000000..063395488 --- /dev/null +++ b/pyhealth/synthetic_ehr_utils/synthetic_ehr_utils.py @@ -0,0 +1,375 @@ +""" +Utility functions for synthetic EHR generation. + +This module provides helper functions for converting between different +representations of EHR data, particularly for working with synthetic +data generation baselines. +""" + +import pandas as pd +from typing import List, Dict, Any + + +# Configuration +VISIT_DELIM = "VISIT_DELIM" + + +def tabular_to_sequences(df: pd.DataFrame, visit_delim: str = VISIT_DELIM) -> List[str]: + """Convert tabular EHR DataFrame to text sequences. + + Converts a long-form EHR DataFrame (one row per diagnosis code) into + patient sequences formatted as space-separated text with visit delimiters. 
+ + Order: Group by SUBJECT_ID -> Sort/Group by HADM_ID -> Collect Codes + + Args: + df: DataFrame with columns SUBJECT_ID, HADM_ID, ICD9_CODE + visit_delim: Token to use as visit delimiter. Default is "VISIT_DELIM" + + Returns: + List of text sequences, one per patient. Each sequence contains + codes separated by spaces, with visits separated by the delimiter. + + Example: + >>> df = pd.DataFrame({ + ... 'SUBJECT_ID': [1, 1, 1, 2, 2], + ... 'HADM_ID': [100, 100, 200, 300, 300], + ... 'ICD9_CODE': ['410', '250', '410', '250', '401'] + ... }) + >>> sequences = tabular_to_sequences(df) + >>> print(sequences[0]) + '410 250 VISIT_DELIM 410' + """ + # 1. Clean data: Ensure IDs are integers/strings (handle the .0 issue) + df = df.copy() + df["SUBJECT_ID"] = df["SUBJECT_ID"].astype(int) + df["HADM_ID"] = df["HADM_ID"].astype(int) + df["ICD9_CODE"] = df["ICD9_CODE"].astype(str) + + # 2. Aggregation Helper + # This creates a list of codes for each admission + # Note: In real scenarios, ensure you sort by ADMITTIME before this step! + visits = ( + df.groupby(["SUBJECT_ID", "HADM_ID"])["ICD9_CODE"].apply(list).reset_index() + ) + + # 3. Create Patient Sequences + # Group visits by patient and join them with the delimiter + patient_seqs = [] + + # We iterate by patient to preserve structure + for subject_id, subject_data in visits.groupby("SUBJECT_ID"): + # subject_data is a DataFrame of visits for one patient + patient_history = [] + + for codes_list in subject_data["ICD9_CODE"]: + # Join codes within one visit (e.g., "code1 code2") + visit_str = " ".join(codes_list) + patient_history.append(visit_str) + + # Join all visits with the delimiter + full_seq = f" {visit_delim} ".join(patient_history) + patient_seqs.append(full_seq) + + return patient_seqs + + +def sequences_to_tabular( + sequences: List[str], visit_delim: str = VISIT_DELIM +) -> pd.DataFrame: + """Convert text sequences back to long-form DataFrame. 
+ + Converts list of text sequences (generated by models) back into a + long-form DataFrame with synthetic SUBJECT_ID and HADM_ID. + + Args: + sequences: List of text sequences, one per patient + visit_delim: Token used as visit delimiter. Default is "VISIT_DELIM" + + Returns: + DataFrame with columns SUBJECT_ID, HADM_ID, ICD9_CODE + + Example: + >>> sequences = ['410 250 VISIT_DELIM 410', '250 401'] + >>> df = sequences_to_tabular(sequences) + >>> print(df) + SUBJECT_ID HADM_ID ICD9_CODE + 0 0 0 410 + 1 0 0 250 + 2 0 1 410 + 3 1 0 250 + 4 1 0 401 + """ + data_rows = [] + + for subj_idx, seq in enumerate(sequences): + # 1. Split sequence into visits + # We strip to remove leading/trailing spaces + visits = seq.strip().split(visit_delim) + + for hadm_idx, visit_str in enumerate(visits): + # 2. Split visit into individual codes + codes = visit_str.strip().split() + + # 3. Create a row for each code + for code in codes: + if code: # Check if code is not empty string + data_rows.append( + { + "SUBJECT_ID": subj_idx, # Synthetic Patient ID + "HADM_ID": hadm_idx, # Synthetic Visit ID + "ICD9_CODE": code, + } + ) + + return pd.DataFrame(data_rows) + + +def nested_codes_to_sequences( + nested_codes: List[List[List[str]]], visit_delim: str = VISIT_DELIM +) -> List[str]: + """Convert nested code structure to text sequences. + + Converts the nested structure from PyHealth models (list of patients, + each containing list of visits, each containing list of codes) into + text sequences. + + Args: + nested_codes: List of patients, each containing visits, each containing codes + visit_delim: Token used as visit delimiter. 
Default is "VISIT_DELIM" + + Returns: + List of text sequences, one per patient + + Example: + >>> nested = [[['410', '250'], ['410']], [['250', '401']]] + >>> sequences = nested_codes_to_sequences(nested) + >>> print(sequences[0]) + '410 250 VISIT_DELIM 410' + """ + sequences = [] + + for patient_visits in nested_codes: + visit_strings = [] + for visit_codes in patient_visits: + # Join codes within visit + visit_str = " ".join([str(c) for c in visit_codes if c]) + if visit_str: # Only add non-empty visits + visit_strings.append(visit_str) + + # Join visits with delimiter + full_seq = f" {visit_delim} ".join(visit_strings) + sequences.append(full_seq) + + return sequences + + +def sequences_to_nested_codes( + sequences: List[str], visit_delim: str = VISIT_DELIM +) -> List[List[List[str]]]: + """Convert text sequences to nested code structure. + + Converts text sequences into nested structure (list of patients, + each containing list of visits, each containing list of codes). + + Args: + sequences: List of text sequences, one per patient + visit_delim: Token used as visit delimiter. Default is "VISIT_DELIM" + + Returns: + List of patients, each containing visits, each containing codes + + Example: + >>> sequences = ['410 250 VISIT_DELIM 410', '250 401'] + >>> nested = sequences_to_nested_codes(sequences) + >>> print(nested[0]) + [['410', '250'], ['410']] + """ + nested_codes = [] + + for seq in sequences: + # Split sequence into visits + visits = seq.strip().split(visit_delim) + + patient_visits = [] + for visit_str in visits: + # Split visit into codes + codes = visit_str.strip().split() + if codes: # Only add non-empty visits + patient_visits.append(codes) + + nested_codes.append(patient_visits) + + return nested_codes + + +def create_flattened_representation( + df: pd.DataFrame, drop_subject_id: bool = True +) -> pd.DataFrame: + """Create flattened patient-level representation (crosstab). 
+ + Converts long-form EHR data into a patient-level matrix where each + row is a patient and each column is a diagnosis code count. + + This representation is used by baseline models like GReaT, CTGAN, TVAE. + + Args: + df: DataFrame with columns SUBJECT_ID and ICD9_CODE + drop_subject_id: Whether to drop SUBJECT_ID column. Default is True. + + Returns: + DataFrame where rows are patients and columns are diagnosis codes + + Example: + >>> df = pd.DataFrame({ + ... 'SUBJECT_ID': [1, 1, 2, 2, 2], + ... 'ICD9_CODE': ['410', '250', '410', '410', '401'] + ... }) + >>> flattened = create_flattened_representation(df) + >>> print(flattened) + 250 401 410 + 0 1 0 1 + 1 0 1 2 + """ + df = df.copy() + df["ICD9_CODE"] = df["ICD9_CODE"].astype(str) + + # Remove NaN codes + df = df.dropna(subset=["ICD9_CODE"]) + df = df[df["ICD9_CODE"] != "nan"] + + # Create crosstab + result_df = pd.crosstab(df["SUBJECT_ID"], df["ICD9_CODE"]).reset_index() + result_df.columns.name = None + + # Drop NaN column if exists + if "nan" in result_df.columns: + result_df = result_df.drop(columns=["nan"]) + + if drop_subject_id: + result_df = result_df.drop(columns=["SUBJECT_ID"]) + + return result_df + + +def process_mimic_for_generation( + mimic_data_path: str, + train_patients_path: str = None, + test_patients_path: str = None, +) -> Dict[str, Any]: + """Process MIMIC data for synthetic generation tasks. + + Loads and processes MIMIC data, creating both sequential and flattened + representations suitable for different baseline models. 
+ + Args: + mimic_data_path: Path to MIMIC data directory + train_patients_path: Path to file containing train patient IDs (one per line) + test_patients_path: Path to file containing test patient IDs (one per line) + + Returns: + Dictionary containing: + - train_ehr: Training data (long-form) + - test_ehr: Test data (long-form) + - train_flattened: Training data (patient-level matrix) + - test_flattened: Test data (patient-level matrix) + - train_sequences: Training data (text sequences) + - test_sequences: Test data (text sequences) + """ + import os + + # Load MIMIC data + admissions_df = pd.read_csv(os.path.join(mimic_data_path, "ADMISSIONS.csv")) + patients_df = pd.read_csv(os.path.join(mimic_data_path, "PATIENTS.csv")) + diagnoses_df = pd.read_csv(os.path.join(mimic_data_path, "DIAGNOSES_ICD.csv")) + + print(f"Admissions shape: {admissions_df.shape}") + print(f"Patients shape: {patients_df.shape}") + print(f"Diagnoses shape: {diagnoses_df.shape}") + + # Parse dates + admissions_df["ADMITTIME"] = pd.to_datetime(admissions_df["ADMITTIME"]) + patients_df["DOB"] = pd.to_datetime(patients_df["DOB"]) + + # Calculate age at first admission + first_admissions = admissions_df.loc[ + admissions_df.groupby("SUBJECT_ID")["ADMITTIME"].idxmin() + ][["SUBJECT_ID", "ADMITTIME"]] + + demo_df = pd.merge( + patients_df[["SUBJECT_ID", "GENDER", "DOB"]], + first_admissions, + on="SUBJECT_ID", + how="inner", + ) + + demo_df["AGE"] = demo_df["ADMITTIME"].dt.year - demo_df["DOB"].dt.year + demo_df["AGE"] = demo_df["AGE"].apply(lambda x: 90 if x > 89 else x) + + # Merge admissions with diagnoses + admissions_info = admissions_df[["SUBJECT_ID", "HADM_ID", "ADMITTIME"]] + merged_df = pd.merge( + admissions_info, + diagnoses_df[["SUBJECT_ID", "HADM_ID", "ICD9_CODE"]], + on=["SUBJECT_ID", "HADM_ID"], + how="inner", + ) + + # Merge with demographics + final_df = pd.merge( + merged_df, + demo_df[["SUBJECT_ID", "AGE", "GENDER"]], + on="SUBJECT_ID", + how="left", + ) + + # Sort 
chronologically + final_df.sort_values(by=["SUBJECT_ID", "ADMITTIME"], inplace=True) + + # Map HADM_ID to sequential visit IDs per patient + final_df["VISIT_ID"] = final_df.groupby("SUBJECT_ID")["HADM_ID"].transform( + lambda x: pd.factorize(x)[0] + ) + + final_df["SUBJECT_ID"] = final_df["SUBJECT_ID"].astype(str) + final_df["HADM_ID"] = final_df["HADM_ID"].astype(float) + final_df["ICD9_CODE"] = final_df["ICD9_CODE"].astype(str) + + # Keep only essential columns + final_df = final_df[["SUBJECT_ID", "HADM_ID", "ICD9_CODE"]] + final_df = final_df.dropna() + + result = {} + + # Split by train and test if provided + if train_patients_path is not None and test_patients_path is not None: + train_patient_ids = ( + pd.read_csv(train_patients_path, header=None)[0].astype(str) + ) + test_patient_ids = ( + pd.read_csv(test_patients_path, header=None)[0].astype(str) + ) + + train_ehr = final_df[final_df["SUBJECT_ID"].isin(train_patient_ids)].reset_index( + drop=True + ) + test_ehr = final_df[final_df["SUBJECT_ID"].isin(test_patient_ids)].reset_index( + drop=True + ) + + result["train_ehr"] = train_ehr + result["test_ehr"] = test_ehr + + # Create flattened representations + result["train_flattened"] = create_flattened_representation(train_ehr) + result["test_flattened"] = create_flattened_representation(test_ehr) + + # Create sequences + result["train_sequences"] = tabular_to_sequences(train_ehr) + result["test_sequences"] = tabular_to_sequences(test_ehr) + else: + result["full_ehr"] = final_df + result["full_flattened"] = create_flattened_representation(final_df) + result["full_sequences"] = tabular_to_sequences(final_df) + + return result diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..cde5f334a 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -68,3 +68,7 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .synthetic_ehr_generation import ( + 
"""
Tasks for generating synthetic EHR data.

This module contains tasks for training generative models on Electronic
Health Records. The tasks process patient visit sequences into samples
suitable for training autoregressive models that can generate synthetic
patient histories.
"""

from typing import Any, Dict, List, Optional

import polars as pl

from .base_task import BaseTask


class _SyntheticEHRGenerationBase(BaseTask):
    """Shared implementation for the MIMIC-III / MIMIC-IV generation tasks.

    The two dataset-specific tasks are identical except for ``task_name``
    and the name of the ICD-code column exposed by the dataset's
    ``diagnoses_icd`` event table, so the whole sample-building logic lives
    here and subclasses only override those two class attributes.

    Each produced sample represents one patient's complete history of
    diagnoses across multiple visits, formatted as a nested sequence
    (list of visits, each a list of codes) for sequence-to-sequence
    modeling with teacher forcing:
    - ``visit_codes``: input visit sequences
    - ``future_codes``: identical to ``visit_codes`` (generation target)

    Args:
        min_visits: Minimum number of visits (with at least one non-empty
            diagnosis code) required per patient. Default is 2.
        max_visits: Maximum number of visits to include per patient.
            If None, all visits are included. Default is None.
    """

    # Fully-qualified column holding the diagnosis code in the flattened
    # event frame; overridden per dataset (MIMIC-III: icd9_code,
    # MIMIC-IV: icd_code).
    code_column: str = "diagnoses_icd/icd9_code"

    input_schema: Dict[str, str] = {
        "visit_codes": "nested_sequence",
    }
    output_schema: Dict[str, str] = {
        "future_codes": "nested_sequence",
    }

    def __init__(self, min_visits: int = 2, max_visits: Optional[int] = None):
        """Initialize the synthetic EHR generation task.

        Args:
            min_visits: Minimum number of visits required per patient.
            max_visits: Maximum number of visits to include (None = no limit).
        """
        self.min_visits = min_visits
        self.max_visits = max_visits

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a patient into synthetic EHR generation samples.

        Creates at most one sample per patient containing the full visit
        history; visits without any diagnosis code are dropped.

        Args:
            patient: Patient object with a ``get_events`` method.

        Returns:
            A list with a single sample dict (``patient_id``,
            ``visit_codes``, ``future_codes``), or an empty list when the
            patient has fewer than ``min_visits`` usable visits.
        """
        # Admissions are assumed to come back chronologically ordered
        # from the dataset — TODO(review): confirm get_events ordering.
        admissions = patient.get_events(event_type="admissions")

        # Cheap early exit before any per-visit queries.
        if len(admissions) < self.min_visits:
            return []

        if self.max_visits is not None:
            admissions = admissions[: self.max_visits]

        # One inner list of diagnosis codes per visit that has any.
        visit_sequences: List[List[str]] = []
        for admission in admissions:
            diagnoses_icd = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            if diagnoses_icd is None or len(diagnoses_icd) == 0:
                continue

            codes = (
                diagnoses_icd.select(pl.col(self.code_column))
                .to_series()
                .to_list()
            )
            # Drop null / empty codes before deciding the visit is usable.
            codes = [c for c in codes if c]
            if codes:
                visit_sequences.append(codes)

        # Re-check after filtering: empty visits do not count.
        if len(visit_sequences) < self.min_visits:
            return []

        return [
            {
                "patient_id": patient.patient_id,
                "visit_codes": visit_sequences,
                # Same sequence as the input: teacher forcing target for
                # autoregressive generation.
                "future_codes": visit_sequences,
            }
        ]


class SyntheticEHRGenerationMIMIC3(_SyntheticEHRGenerationBase):
    """Synthetic EHR generation task for the MIMIC-III dataset.

    See :class:`_SyntheticEHRGenerationBase` for the sample format and the
    ``min_visits`` / ``max_visits`` constructor arguments.
    """

    task_name: str = "SyntheticEHRGenerationMIMIC3"
    # MIMIC-III diagnoses are ICD-9 codes stored under ``icd9_code``.
    code_column: str = "diagnoses_icd/icd9_code"


class SyntheticEHRGenerationMIMIC4(_SyntheticEHRGenerationBase):
    """Synthetic EHR generation task for the MIMIC-IV dataset.

    See :class:`_SyntheticEHRGenerationBase` for the sample format and the
    ``min_visits`` / ``max_visits`` constructor arguments.
    """

    task_name: str = "SyntheticEHRGenerationMIMIC4"
    # Bug fix: MIMIC-IV's diagnoses_icd table stores codes in ``icd_code``
    # (with a separate ``icd_version``); the previous ``icd9_code`` column
    # does not exist in MIMIC-IV and would yield a column-not-found error.
    code_column: str = "diagnoses_icd/icd_code"