diff --git a/examples/generate_synthetic_mimic3_promptehr.py b/examples/generate_synthetic_mimic3_promptehr.py
new file mode 100644
index 000000000..5eefb7ff7
--- /dev/null
+++ b/examples/generate_synthetic_mimic3_promptehr.py
@@ -0,0 +1,47 @@
+"""PromptEHR: Synthetic MIMIC-III Patient Generation.
+
+Load a trained PromptEHR checkpoint and generate synthetic patients.
+
+Reference:
+ Wang et al. "PromptEHR: Conditional Electronic Healthcare Records
+ Generation with Prompt Learning." EMNLP 2022.
+ https://arxiv.org/abs/2211.01761
+"""
+
+import json
+
+from pyhealth.datasets import MIMIC3Dataset
+from pyhealth.models import PromptEHR
+from pyhealth.tasks import promptehr_generation_mimic3_fn
+
+MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"
+CHECKPOINT_PATH = "./save/promptehr/checkpoint.pt"
+OUTPUT_PATH = "./save/promptehr/synthetic_patients.json"
+NUM_SAMPLES = 10_000
+
+# 1. Load dataset + apply task (needed for processor/vocab reconstruction)
+dataset = MIMIC3Dataset(
+ root=MIMIC3_ROOT,
+ tables=["patients", "admissions", "diagnoses_icd"],
+ code_mapping={},
+)
+sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)
+
+# 2. Load checkpoint
+model = PromptEHR(dataset=sample_dataset)
+model.load_model(CHECKPOINT_PATH)
+print(f"Loaded checkpoint from {CHECKPOINT_PATH}")
+
+# 3. Generate
+print(f"Generating {NUM_SAMPLES} synthetic patients...")
+synthetic = model.synthesize_dataset(num_samples=NUM_SAMPLES)
+print(f"Generated {len(synthetic)} patients")
+
+# 4. Save
+with open(OUTPUT_PATH, "w") as f:
+ json.dump(synthetic, f, indent=2)
+print(f"Saved to {OUTPUT_PATH}")
+
+# Summary stats
+avg_visits = sum(len(p["visits"]) for p in synthetic) / len(synthetic)
+print(f"Average visits per patient: {avg_visits:.2f}")
diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb
new file mode 100644
index 000000000..dd6fc96e1
--- /dev/null
+++ b/examples/promptehr_mimic3_colab.ipynb
@@ -0,0 +1,192 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 5,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ },
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "preamble",
+ "metadata": {},
+ "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-05_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 100 samples): ~20–30 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV file with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 100 samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s1-header",
+ "metadata": {},
+ "source": "---\n# 1. Setup & Installation"
+ },
+ {
+ "cell_type": "code",
+ "id": "s1-install",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps, which in Colab's system environment\n# creates mixed-version states (old .so + new .py from different versions).\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s1-imports",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "import os\nos.environ[\"WANDB_DISABLED\"] = \"true\" # prevent wandb login prompt during training\n\nimport random\nimport shutil\nimport numpy as np\nimport torch\nimport pandas as pd\nfrom IPython.display import display\nfrom google.colab import drive, files\n\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s1-drive",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Mount Google Drive for persistent storage\nprint(\"Mounting Google Drive...\")\nif not os.path.ismount('/content/drive'):\n drive.mount('/content/drive', force_remount=True)\nelse:\n print(\"Drive already mounted\")\nprint(\"✓ Google Drive mounted\")\n\n# Create directory structure in Drive\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\nfor d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"\\nDirectory structure created:\")\nprint(f\" Base: {BASE_DIR}\")\nprint(f\" Data: {DATA_DIR}\")\nprint(f\" Checkpoints: {CHECKPOINT_DIR}\")\nprint(f\" Output: {OUTPUT_DIR}\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s2-header",
+ "metadata": {},
+ "source": "---\n# 2. Configuration"
+ },
+ {
+ "cell_type": "code",
+ "id": "s2-config",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# ============================================================\n# CONFIGURATION — All modifiable parameters in one place\n# ============================================================\n\n# --- Training parameters ---\nEPOCHS = 5 # Demo: 5, Production: 20\nBATCH_SIZE = 16 # 16 for both\nN_SYNTHETIC_SAMPLES = 100 # Demo: 100, Production: 10000\nWARMUP_STEPS = 100 # Demo: 100, Production: 1000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch —\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# Display configuration\nprint(\"=\" * 60)\nprint(\"PROMPTEHR CONFIGURATION\")\nprint(\"=\" * 60)\nprint(f\"Training:\")\nprint(f\" Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\" Warmup steps: {WARMUP_STEPS}\")\nprint(f\"\\nGeneration:\")\nprint(f\" Synthetic samples: {N_SYNTHETIC_SAMPLES:,}\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\")\nprint(f\"\\nPaths:\")\nprint(f\" Base directory: {BASE_DIR}\")\nprint(\"=\" * 60)"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s3-header",
+ "metadata": {},
+ "source": "---\n# 3. Data Upload"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s3-desc",
+ "metadata": {},
+ "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO — `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` — date of birth and gender\n2. `ADMISSIONS.csv` — admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive."
+ },
+ {
+ "cell_type": "code",
+ "id": "s3-upload",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+   "source": "# Check which files exist in the Drive-backed DATA_DIR\nrequired_files = {\n    'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n    'ADMISSIONS.csv': 'Admission records (timestamps)',\n    'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nif not missing:\n    # All files already in Drive — no upload needed\n    print(\"✓ All MIMIC-III files found in Drive (no upload needed):\")\n    for fname in required_files:\n        size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n        print(f\"  {fname} ({size_mb:.1f} MB)\")\n    print(f\"\\nFiles are reused from: {DATA_DIR}\")\n    print(\"To force re-upload, delete files from that folder and re-run this cell.\")\nelse:\n    print(\"MIMIC-III file status:\")\n    for fname, desc in required_files.items():\n        mark = \"✓\" if existing[fname] else \"✗ MISSING\"\n        print(f\"  {mark} {fname} — {desc}\")\n\n    print(f\"\\nUploading {len(missing)} missing file(s)...\")\n    uploaded = files.upload()\n\n    # Normalize filenames — Colab renames duplicates as \"ADMISSIONS (1).csv\".\n    # Match each upload to the required file it belongs to, then copy with\n    # the canonical name so subsequent runs find the file in Drive.\n    for uploaded_name, data in uploaded.items():\n        matched = None\n        for req in required_files:\n            base = req.replace('.csv', '')\n            if base in uploaded_name and uploaded_name.endswith('.csv'):\n                matched = req\n                break\n        if matched:\n            tmp = f'/content/{uploaded_name}'\n            with open(tmp, 'wb') as f:\n                f.write(data)\n            dest = f'{DATA_DIR}/{matched}'\n            shutil.copy(tmp, dest)\n            size_mb = os.path.getsize(dest) / 1024 / 1024\n            print(f\"  ✓ Saved {matched} ({size_mb:.1f} MB) → {dest}\")\n        else:\n            print(f\"  ⚠ Unrecognised file: {uploaded_name} (skipped)\")\n\n    missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n    if missing:\n        raise FileNotFoundError(\n            f\"Still missing: {missing}. Please re-run this cell to upload them.\"\n        )\n    print(\"\\n✓ All 3 MIMIC-III files present.\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s3-validate",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "print(\"Validating MIMIC-III files...\")\n\n_patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\nassert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\nassert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\nassert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\nprint(f\"✓ PATIENTS.csv: {len(_patients):>8,} rows\")\n\n_admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\nassert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\nassert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\nprint(f\"✓ ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n_diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\nassert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\nprint(f\"✓ DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\ndel _patients, _admissions, _diagnoses # free memory\nprint(\"\\n✓ All files validated successfully\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s4-header",
+ "metadata": {},
+ "source": "---\n# 4. Training"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s4-desc",
+ "metadata": {},
+ "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input — steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities."
+ },
+ {
+ "cell_type": "code",
+ "id": "s4-dataset",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"✓ Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import MIMIC3Dataset, split_by_patient\nfrom pyhealth.tasks import promptehr_generation_mimic3_fn\nfrom pyhealth.models import PromptEHR\n\nprint(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\ndataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n)\nprint(f\"Loaded {len(dataset.unique_patient_ids):,} patients\")\n\nprint(\"Applying PromptEHR generation task...\")\nsample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\nprint(f\"Eligible patients (≥2 visits with ICD-9 codes): {len(sample_dataset):,}\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s4-train",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Initialize model\nprint(\"Initializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"✓ PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")\n\n# Train\nprint(\"\\nStarting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"✓ Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s5-header",
+ "metadata": {},
+ "source": "---\n# 5. Generation"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s5-desc",
+ "metadata": {},
+ "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` — the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling — diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding — deterministic, may repeat common patterns."
+ },
+ {
+ "cell_type": "code",
+ "id": "s5-generate",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n✓ Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))"
+ },
+ {
+ "cell_type": "code",
+ "id": "s5-save",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE — matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"✓ {len(df_synthetic):,} records → {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s6-header",
+ "metadata": {},
+ "source": "---\n# 6. Results & Download"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-stats",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Validate generated data\nprint(\"=\" * 60)\nprint(\"DATA QUALITY CHECKS\")\nprint(\"=\" * 60)\n\n# Check 1: Patient IDs\nunique_patients = df_synthetic['SUBJECT_ID'].nunique()\nprint(f\"\\nUnique patients: {unique_patients} out of {N_SYNTHETIC_SAMPLES} requested\")\n\n# Check 2: No empty values\nempty_subjects = df_synthetic['SUBJECT_ID'].isna().sum()\nempty_visits = df_synthetic['VISIT_NUM'].isna().sum()\nempty_codes = df_synthetic['ICD9_CODE'].isna().sum()\n\nprint(f\"\\nEmpty values check:\")\nprint(f\" Subject IDs: {empty_subjects} (should be 0)\")\nprint(f\" Visit numbers: {empty_visits} (should be 0)\")\nprint(f\" ICD9 codes: {empty_codes} (should be 0)\")\nassert empty_subjects == 0 and empty_visits == 0 and empty_codes == 0, \"Found empty values!\"\n\n# Check 3: Distribution statistics\ncodes_per_patient = df_synthetic.groupby('SUBJECT_ID').size()\nprint(f\"\\nCodes per patient:\")\nprint(f\" Min: {codes_per_patient.min()}\")\nprint(f\" Max: {codes_per_patient.max()}\")\nprint(f\" Mean: {codes_per_patient.mean():.2f}\")\nprint(f\" Median: {codes_per_patient.median():.2f}\")\n\nvisits_per_patient = df_synthetic.groupby('SUBJECT_ID')['VISIT_NUM'].max()\nprint(f\"\\nVisits per patient:\")\nprint(f\" Min: {visits_per_patient.min()}\")\nprint(f\" Max: {visits_per_patient.max()}\")\nprint(f\" Mean: {visits_per_patient.mean():.2f}\")\nprint(f\" Median: {visits_per_patient.median():.2f}\")\n\nprint(\"\\n\" + \"=\" * 60)\nprint(\"ALL QUALITY CHECKS PASSED\")\nprint(\"=\" * 60)"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-download",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Download CSV file\nprint(\"=\" * 60)\nprint(\"DOWNLOAD SYNTHETIC DATA\")\nprint(\"=\" * 60)\n\nprint(f\"\\nYour synthetic data is ready:\")\nprint(f\" File: synthetic_patients.csv\")\nprint(f\" Patients: {unique_patients:,}\")\nprint(f\" Total records: {len(df_synthetic):,}\")\nprint(f\" Size: {os.path.getsize(csv_path) / (1024*1024):.2f} MB\")\n\nprint(f\"\\nDownloading file to your computer...\")\nfiles.download(csv_path)\n\nprint(f\"\\nDownload started!\")\nprint(f\"\\nFile also saved in Google Drive:\")\nprint(f\" {csv_path}\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s7-resume",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# ─────────────────────────────────────────────────────────────────────────────\n# CHECKPOINT RESUME — Run this cell instead of Section 4 if you already trained\n# ─────────────────────────────────────────────────────────────────────────────\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"✓ Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template — uncomment the lines above to use)\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s7-congrats",
+ "metadata": {},
+ "source": "---\n## Congratulations!\n\nYou've successfully:\n1. Trained a PromptEHR model conditioned on patient demographics\n2. Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. Validated the synthetic data quality\n4. Downloaded the CSV file\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Generate more samples:**\n- Change `N_SYNTHETIC_SAMPLES` and re-run Section 5\n- No need to retrain — the checkpoint is saved!\n\n**Production training:**\n- For publication-quality results, set `EPOCHS = 20` and `N_SYNTHETIC_SAMPLES = 10000`\n- Consider using Colab Pro or a dedicated GPU cluster\n- See `examples/slurm/` for cluster usage\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime → Change runtime type → T4 GPU |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/examples/promptehr_mimic3_training.py b/examples/promptehr_mimic3_training.py
new file mode 100644
index 000000000..8387208db
--- /dev/null
+++ b/examples/promptehr_mimic3_training.py
@@ -0,0 +1,47 @@
+"""PromptEHR: Training on MIMIC-III.
+
+Train PromptEHR for synthetic EHR generation using PyHealth 2.0 API.
+
+Reference:
+ Wang et al. "PromptEHR: Conditional Electronic Healthcare Records Generation
+ with Prompt Learning." EMNLP 2022.
+"""
+
+from pyhealth.datasets import MIMIC3Dataset, split_by_patient
+from pyhealth.models import PromptEHR
+from pyhealth.tasks import promptehr_generation_mimic3_fn
+
+MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"
+
+# 1. Load MIMIC-III
+dataset = MIMIC3Dataset(
+ root=MIMIC3_ROOT,
+ tables=["patients", "admissions", "diagnoses_icd"],
+ code_mapping={},
+)
+
+# 2. Apply generation task
+sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)
+print(f"Patients: {len(sample_dataset)}")
+sample_dataset.stat()
+
+# 3. Split
+train, val, test = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
+
+# 4. Initialize model
+model = PromptEHR(
+ dataset=sample_dataset,
+ n_num_features=1,
+ cat_cardinalities=[2],
+ d_hidden=128,
+ prompt_length=1,
+ epochs=20,
+ batch_size=16,
+ lr=1e-5,
+ warmup_steps=1000,
+ save_dir="./save/promptehr/",
+)
+
+# 5. Train
+model.train_model(train, val)
+print("Training complete. Checkpoint saved to ./save/promptehr/")
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index b38c575c2..00a5e1884 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -48,10 +48,16 @@ def __init__(self, *args, **kwargs):
from .base_dataset import BaseDataset
from .cardiology import CardiologyDataset
-from .chestxray14 import ChestXray14Dataset
+try:
+ from .chestxray14 import ChestXray14Dataset
+except ImportError:
+ pass # PIL/torchvision unavailable
from .clinvar import ClinVarDataset
from .cosmic import COSMICDataset
-from .covid19_cxr import COVID19CXRDataset
+try:
+ from .covid19_cxr import COVID19CXRDataset
+except ImportError:
+ pass # PIL/torchvision unavailable
from .dreamt import DREAMTDataset
from .ehrshot import EHRShotDataset
from .eicu import eICUDataset
@@ -63,7 +69,10 @@ def __init__(self, *args, **kwargs):
from .omop import OMOPDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
from .shhs import SHHSDataset
-from .sleepedf import SleepEDFDataset
+try:
+ from .sleepedf import SleepEDFDataset
+except ImportError:
+ pass # mne unavailable
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
from .tcga_prad import TCGAPRADDataset
@@ -76,8 +85,14 @@ def __init__(self, *args, **kwargs):
split_by_visit,
split_by_visit_conformal,
)
-from .tuab import TUABDataset
-from .tuev import TUEVDataset
+try:
+ from .tuab import TUABDataset
+except ImportError:
+ pass # mne unavailable; TUABDataset not registered
+try:
+ from .tuev import TUEVDataset
+except ImportError:
+ pass # mne unavailable; TUEVDataset not registered
from .utils import (
collate_fn_dict,
collate_fn_dict_with_padding,
diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py
index 7e569d2f3..bdc00f904 100644
--- a/pyhealth/datasets/mimic3.py
+++ b/pyhealth/datasets/mimic3.py
@@ -53,7 +53,7 @@ def __init__(
if config_path is None:
logger.info("No config path provided, using default config")
config_path = Path(__file__).parent / "configs" / "mimic3.yaml"
- default_tables = ["patients", "admissions", "icustays"]
+ default_tables = ["patients", "admissions"]
tables = default_tables + tables
if "prescriptions" in tables:
warnings.warn(
diff --git a/pyhealth/datasets/tuab.py b/pyhealth/datasets/tuab.py
index e2a3fc69c..1ba6cc3c8 100644
--- a/pyhealth/datasets/tuab.py
+++ b/pyhealth/datasets/tuab.py
@@ -5,7 +5,10 @@
from typing import Optional
from .base_dataset import BaseDataset
-from pyhealth.tasks import EEGAbnormalTUAB
+try:
+ from pyhealth.tasks import EEGAbnormalTUAB
+except ImportError:
+ EEGAbnormalTUAB = None # mne unavailable; TUABDataset.default_task will raise if called
logger = logging.getLogger(__name__)
diff --git a/pyhealth/datasets/tuev.py b/pyhealth/datasets/tuev.py
index 7e8dacf98..7dd30fd58 100644
--- a/pyhealth/datasets/tuev.py
+++ b/pyhealth/datasets/tuev.py
@@ -5,7 +5,10 @@
from typing import Optional
from .base_dataset import BaseDataset
-from pyhealth.tasks import EEGEventsTUEV
+try:
+ from pyhealth.tasks import EEGEventsTUEV
+except ImportError:
+ EEGEventsTUEV = None # mne unavailable; TUEVDataset.default_task will raise if called
logger = logging.getLogger(__name__)
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 14f0bf209..c8dcbe282 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -1,8 +1,14 @@
from .adacare import AdaCare, AdaCareLayer
from .agent import Agent, AgentLayer
from .base_model import BaseModel
-from .biot import BIOT
-from .cnn import CNN, CNNLayer
+try:
+ from .biot import BIOT
+except ImportError:
+ pass # einops unavailable
+try:
+ from .cnn import CNN, CNNLayer
+except ImportError:
+ pass # PIL/torchvision unavailable
from .concare import ConCare, ConCareLayer
from .contrawr import ContraWR, ResBlock2D
from .deepr import Deepr, DeeprLayer
@@ -12,33 +18,63 @@
from .logistic_regression import LogisticRegression
from .gan import GAN
from .gnn import GAT, GCN
-from .graph_torchvision_model import Graph_TorchvisionModel
-from .grasp import GRASP, GRASPLayer
+try:
+ from .graph_torchvision_model import Graph_TorchvisionModel
+except ImportError:
+ pass # torchvision unavailable
+try:
+ from .grasp import GRASP, GRASPLayer
+except ImportError:
+ pass # sklearn unavailable
from .medlink import MedLink
from .micron import MICRON, MICRONLayer
from .mlp import MLP
-from .molerec import MoleRec, MoleRecLayer
+try:
+ from .molerec import MoleRec, MoleRecLayer
+except ImportError:
+ pass # rdkit unavailable
+from .promptehr import PromptEHR
from .retain import RETAIN, RETAINLayer
from .rnn import MultimodalRNN, RNN, RNNLayer
-from .safedrug import SafeDrug, SafeDrugLayer
+try:
+ from .safedrug import SafeDrug, SafeDrugLayer
+except ImportError:
+ pass # rdkit unavailable
from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer
from .stagenet import StageNet, StageNetLayer
from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer
from .tcn import TCN, TCNLayer
-from .tfm_tokenizer import (
- TFMTokenizer,
- TFM_VQVAE2_deep,
- TFM_TOKEN_Classifier,
- get_tfm_tokenizer_2x2x8,
- get_tfm_token_classifier_64x4,
- load_embedding_weights,
-)
-from .torchvision_model import TorchvisionModel
+try:
+ from .tfm_tokenizer import (
+ TFMTokenizer,
+ TFM_VQVAE2_deep,
+ TFM_TOKEN_Classifier,
+ get_tfm_tokenizer_2x2x8,
+ get_tfm_token_classifier_64x4,
+ load_embedding_weights,
+ )
+except ImportError:
+ pass # einops unavailable
+try:
+ from .torchvision_model import TorchvisionModel
+except ImportError:
+ pass # torchvision unavailable
from .transformer import Transformer, TransformerLayer
-from .transformers_model import TransformersModel
+try:
+ from .transformers_model import TransformersModel
+except ImportError:
+ pass # transformers unavailable
from .ehrmamba import EHRMamba, MambaBlock
from .vae import VAE
-from .vision_embedding import VisionEmbeddingModel
-from .text_embedding import TextEmbedding
-from .sdoh import SdohClassifier
-from .medlink import MedLink
+try:
+ from .vision_embedding import VisionEmbeddingModel
+except ImportError:
+ pass # PIL/torchvision unavailable
+try:
+ from .text_embedding import TextEmbedding
+except ImportError:
+ pass # transformers unavailable
+try:
+ from .sdoh import SdohClassifier
+except ImportError:
+ pass # transformers/peft unavailable
diff --git a/pyhealth/models/promptehr/__init__.py b/pyhealth/models/promptehr/__init__.py
new file mode 100644
index 000000000..fdf1327a3
--- /dev/null
+++ b/pyhealth/models/promptehr/__init__.py
@@ -0,0 +1,41 @@
+"""PromptEHR: Prompt-based BART model for synthetic EHR generation.
+
+This module provides a demographic-conditioned sequence-to-sequence model
+for generating realistic synthetic electronic health records.
+
+Main components:
+ - PromptEHR: Main model class (inherits from BaseModel)
+ - ConditionalPromptEncoder: Demographic conditioning with reparameterization
+ - PromptBartEncoder: Modified BART encoder with prompt injection
+ - PromptBartDecoder: Modified BART decoder with prompt injection
+ - VisitStructureSampler: Utility for structure-constrained generation
+ - Generation functions: sample_demographics, parse_sequence_to_visits, etc.
+"""
+
+from .model import PromptEHR
+from .conditional_prompt import ConditionalPromptEncoder
+from .bart_encoder import PromptBartEncoder
+from .bart_decoder import PromptBartDecoder
+from .visit_sampler import VisitStructureSampler
+from .generation import (
+ DemographicSampler,
+ sample_demographics,
+ decode_patient_demographics,
+ parse_sequence_to_visits,
+ generate_patient_sequence_conditional,
+ generate_patient_with_structure_constraints
+)
+
+__all__ = [
+ "PromptEHR",
+ "ConditionalPromptEncoder",
+ "PromptBartEncoder",
+ "PromptBartDecoder",
+ "VisitStructureSampler",
+ "DemographicSampler",
+ "sample_demographics",
+ "decode_patient_demographics",
+ "parse_sequence_to_visits",
+ "generate_patient_sequence_conditional",
+ "generate_patient_with_structure_constraints",
+]
diff --git a/pyhealth/models/promptehr/bart_decoder.py b/pyhealth/models/promptehr/bart_decoder.py
new file mode 100644
index 000000000..e6d01a70b
--- /dev/null
+++ b/pyhealth/models/promptehr/bart_decoder.py
@@ -0,0 +1,325 @@
+"""BART decoder with prompt injection for demographic conditioning.
+
+This module provides a modified BART decoder that accepts demographic prompt
+embeddings and prepends them to decoder input sequences for conditioning.
+
+Ported from pehr_scratch/prompt_bart_decoder.py (lines 1-207).
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+from transformers.models.bart.modeling_bart import BartDecoder
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
+
+class PromptBartDecoder(BartDecoder):
+ """BART decoder modified to accept and prepend demographic prompt embeddings.
+
+ Extends the standard BART decoder to support prompt-based conditioning by:
+ 1. Accepting optional prompt embeddings as input
+ 2. Prepending prompts to decoder input token embeddings
+ 3. Extending attention masks to cover prepended prompts
+ 4. Creating causal masks for autoregressive generation
+ 5. Processing through standard BART decoder layers with cross-attention
+
+ This enables demographic conditioning (age + gender) by injecting learned
+ prompt vectors at the decoder input, maintaining demographic alignment
+ during generation (dual prompt injection with encoder).
+
+ Args:
+ config: BartConfig from transformers
+ embed_tokens: Token embedding layer (optional)
+
+ Example:
+ >>> from transformers import BartConfig
+ >>> config = BartConfig.from_pretrained("facebook/bart-base")
+ >>> decoder = PromptBartDecoder(config)
+ >>> # Decode with prompts
+ >>> prompt_embeds = torch.randn(16, 2, 768) # [batch, n_prompts, hidden]
+ >>> input_ids = torch.randint(0, 1000, (16, 50)) # [batch, tgt_len]
+ >>> encoder_outputs = torch.randn(16, 100, 768) # [batch, src_len, hidden]
+ >>> outputs = decoder(
+ ... input_ids,
+ ... encoder_hidden_states=encoder_outputs,
+ ... inputs_prompt_embeds=prompt_embeds
+ ... )
+ """
+
+ def __init__(self, config, embed_tokens=None):
+ """Initialize prompt-aware BART decoder.
+
+ Args:
+ config: BartConfig from transformers
+ embed_tokens: Optional token embedding layer
+ """
+ super().__init__(config, embed_tokens)
+
+ # Initialize embedding scale factor (BART uses sqrt(d_model) scaling)
+ self.embed_scale = None
+ if config.scale_embedding:
+ self.embed_scale = (config.d_model ** 0.5)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ inputs_prompt_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BaseModelOutputWithPastAndCrossAttentions:
+ """Forward pass with optional demographic prompt embeddings.
+
+ Args:
+ input_ids: [batch, tgt_seq_len] decoder token IDs
+ attention_mask: [batch, tgt_seq_len] decoder attention mask (1=attend, 0=ignore)
+ encoder_hidden_states: [batch, src_seq_len, hidden_dim] encoder outputs
+ encoder_attention_mask: [batch, src_seq_len] encoder attention mask
+ head_mask: [num_layers, num_heads] mask for self-attention heads
+ cross_attn_head_mask: [num_layers, num_heads] mask for cross-attention heads
+ past_key_values: Cached key-value states for efficient generation
+ inputs_embeds: [batch, tgt_seq_len, hidden_dim] pre-computed embeddings (optional)
+ inputs_prompt_embeds: [batch, n_prompts, hidden_dim] demographic prompts (optional)
+ use_cache: Whether to return key-value cache for generation
+ output_attentions: Whether to return attention weights
+ output_hidden_states: Whether to return all hidden states
+ return_dict: Whether to return BaseModelOutputWithPastAndCrossAttentions or tuple
+
+ Returns:
+ BaseModelOutputWithPastAndCrossAttentions with:
+ - last_hidden_state: [batch, n_prompts + tgt_len, hidden_dim]
+ - past_key_values: Cached key-value states (if use_cache=True)
+ - hidden_states: Tuple of all layer outputs (if output_hidden_states=True)
+ - attentions: Tuple of self-attention weights (if output_attentions=True)
+ - cross_attentions: Tuple of cross-attention weights (if output_attentions=True)
+ """
+ # Set output flags from config defaults
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Get decoder input embeddings from token IDs
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # Apply embedding scaling if configured
+ if self.embed_scale is not None:
+ inputs_embeds = inputs_embeds * self.embed_scale
+
+ # Store original sequence length before prepending prompts
+ original_seq_len = inputs_embeds.shape[1]
+
+ # Prepend prompt embeddings if provided
+ if inputs_prompt_embeds is not None:
+ # Concatenate prompts before decoder input embeddings
+ # inputs_prompt_embeds: [batch, n_prompts, hidden_dim]
+ # inputs_embeds: [batch, tgt_len, hidden_dim]
+ # Result: [batch, n_prompts + tgt_len, hidden_dim]
+ inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
+
+ # Extend attention mask for prepended prompts
+ batch_size, n_prompts = inputs_prompt_embeds.shape[:2]
+
+ # Create attention mask for prompts (all 1s - always attend to prompts)
+ prompt_attention_mask = torch.ones(
+ batch_size, n_prompts,
+ dtype=attention_mask.dtype if attention_mask is not None else torch.long,
+ device=inputs_embeds.device
+ )
+
+ if attention_mask is not None:
+ # Concatenate prompt mask with decoder attention mask
+ attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+ else:
+ # Create attention mask for all tokens (prompts + decoder input)
+ total_seq_len = inputs_embeds.shape[1]
+ attention_mask = torch.ones(
+ batch_size, total_seq_len,
+ dtype=torch.long,
+ device=inputs_embeds.device
+ )
+
+ # Get positional embeddings for full sequence (prompts + decoder tokens)
+ past_key_values_length = 0
+ if past_key_values is not None:
+ # Handle Cache object (new transformers API) or tuple (old API)
+ if hasattr(past_key_values, 'get_seq_length'):
+ past_key_values_length = past_key_values.get_seq_length()
+ elif isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
+ # Defensive: handle unexpected cache structures gracefully
+            # Defaulting to 0 is a safe fallback: positional embeddings restart from
+            # position 0, which slightly degrades quality but avoids a crash.
+ try:
+ if past_key_values[0] is not None and isinstance(past_key_values[0], (tuple, list)):
+ if len(past_key_values[0]) > 0 and past_key_values[0][0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ except (IndexError, TypeError, AttributeError):
+ # Safe fallback: slightly degrades quality but prevents crash
+ # Positional embeddings will be calculated from position 0
+ past_key_values_length = 0
+
+ # Get positional embeddings (BART uses learned positional embeddings)
+ positions = self.embed_positions(inputs_embeds, past_key_values_length)
+
+ # Combine input embeddings + positional embeddings
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # Create combined attention mask (causal + padding)
+ if attention_mask is not None:
+ # Create causal mask for decoder self-attention
+ combined_attention_mask = _make_causal_mask(
+ inputs_embeds.shape[:2],
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+ # Expand padding mask and combine with causal mask
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1])
+ combined_attention_mask = combined_attention_mask + expanded_attn_mask
+ else:
+ # Create causal mask only (no padding)
+ combined_attention_mask = _make_causal_mask(
+ inputs_embeds.shape[:2],
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ # Expand encoder attention mask for cross-attention
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [batch, src_len] → [batch, 1, tgt_len, src_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1])
+
+ # Initialize output containers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ # Pass through decoder layers
+ for idx, decoder_layer in enumerate(self.layers):
+ # Save hidden state before layer if requested
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Forward through decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=combined_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ # Update hidden states
+ hidden_states = layer_outputs[0]
+
+ # Save attention weights if requested
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # Save final hidden state if requested
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Cache is handled by past_key_values object, not returned in tuple
+ next_cache = past_key_values if use_cache else None
+
+ # Return tuple format if not using return_dict
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+
+ # Return BaseModelOutputWithPastAndCrossAttentions
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+def _make_causal_mask(
+ input_shape: Tuple[int, int],
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0
+) -> torch.Tensor:
+ """Create causal mask for decoder self-attention.
+
+ Creates a lower-triangular mask that prevents attending to future positions.
+ This is essential for autoregressive generation where each position can only
+ attend to earlier positions.
+
+ Args:
+ input_shape: (batch_size, tgt_len) shape of decoder input
+ dtype: Data type for mask tensor
+ device: Device to create mask on
+ past_key_values_length: Length of cached key-values from previous steps
+
+ Returns:
+ [batch, 1, tgt_len, tgt_len + past_len] causal mask with -inf for future positions
+ """
+ batch_size, tgt_len = input_shape
+
+ # Initialize mask with -inf (prevents attention)
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+
+ # Create lower triangular mask (0 for allowed positions, -inf for future)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ # If using cached key-values, allow attending to all past positions
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+ # Expand to [batch, 1, tgt_len, tgt_len + past_len]
+ return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
+ """Expand attention mask from [batch, src_len] to [batch, 1, tgt_len, src_len].
+
+ Inverts the mask (1→0, 0→1) and fills masked positions with -inf to prevent attention.
+
+ Args:
+ mask: [batch, src_len] attention mask (1=attend, 0=ignore)
+ dtype: Target data type for the expanded mask
+ tgt_len: Target sequence length (defaults to src_len)
+
+ Returns:
+ [batch, 1, tgt_len, src_len] expanded mask with -inf for masked positions
+ """
+ batch_size, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ # Expand dimensions: [batch, src_len] → [batch, 1, tgt_len, src_len]
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
+
+ # Invert mask: 1 (attend) → 0, 0 (ignore) → 1
+ inverted_mask = 1.0 - expanded_mask
+
+ # Fill masked positions with -inf (prevents attention)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
diff --git a/pyhealth/models/promptehr/bart_encoder.py b/pyhealth/models/promptehr/bart_encoder.py
new file mode 100644
index 000000000..726f34cb9
--- /dev/null
+++ b/pyhealth/models/promptehr/bart_encoder.py
@@ -0,0 +1,214 @@
+"""BART encoder with prompt injection for demographic conditioning.
+
+This module provides a modified BART encoder that accepts demographic prompt
+embeddings and prepends them to input sequences for conditioning.
+
+Ported from pehr_scratch/prompt_bart_encoder.py (lines 1-149).
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional
+from transformers.models.bart.modeling_bart import BartEncoder
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class PromptBartEncoder(BartEncoder):
+ """BART encoder modified to accept and prepend demographic prompt embeddings.
+
+ Extends the standard BART encoder to support prompt-based conditioning by:
+ 1. Accepting optional prompt embeddings as input
+ 2. Prepending prompts to input token embeddings
+ 3. Extending attention masks to cover prepended prompts
+ 4. Processing through standard BART encoder layers
+
+ This enables demographic conditioning (age + gender) by injecting learned
+ prompt vectors at the encoder input.
+
+ Args:
+ config: BartConfig from transformers
+ embed_tokens: Token embedding layer (optional)
+
+ Example:
+ >>> from transformers import BartConfig
+ >>> config = BartConfig.from_pretrained("facebook/bart-base")
+ >>> encoder = PromptBartEncoder(config)
+ >>> # Encode with prompts
+ >>> prompt_embeds = torch.randn(16, 2, 768) # [batch, n_prompts, hidden]
+ >>> input_ids = torch.randint(0, 1000, (16, 100)) # [batch, seq_len]
+ >>> outputs = encoder(input_ids, inputs_prompt_embeds=prompt_embeds)
+ """
+
+ def __init__(self, config, embed_tokens=None):
+ """Initialize prompt-aware BART encoder.
+
+ Args:
+ config: BartConfig from transformers
+ embed_tokens: Optional token embedding layer
+ """
+ super().__init__(config, embed_tokens)
+
+ # Initialize embedding scale factor (BART uses sqrt(d_model) scaling)
+ self.embed_scale = None
+ if config.scale_embedding:
+ self.embed_scale = (config.d_model ** 0.5)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ inputs_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BaseModelOutput:
+ """Forward pass with optional demographic prompt embeddings.
+
+ Args:
+ input_ids: [batch, seq_len] token IDs
+ attention_mask: [batch, seq_len] attention mask (1=attend, 0=ignore)
+ head_mask: [num_layers, num_heads] mask for attention heads
+ inputs_embeds: [batch, seq_len, hidden_dim] pre-computed embeddings (optional)
+ inputs_prompt_embeds: [batch, n_prompts, hidden_dim] demographic prompts (optional)
+ output_attentions: Whether to return attention weights
+ output_hidden_states: Whether to return all hidden states
+ return_dict: Whether to return BaseModelOutput or tuple
+
+ Returns:
+ BaseModelOutput with:
+ - last_hidden_state: [batch, n_prompts + seq_len, hidden_dim]
+ - hidden_states: Tuple of all layer outputs (if output_hidden_states=True)
+ - attentions: Tuple of attention weights (if output_attentions=True)
+ """
+ # Set output flags from config defaults
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Get input embeddings from token IDs
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # Apply embedding scaling if configured
+ if self.embed_scale is not None:
+ inputs_embeds = inputs_embeds * self.embed_scale
+
+ # Prepend prompt embeddings if provided
+ if inputs_prompt_embeds is not None:
+ # Concatenate prompts before input embeddings
+ # inputs_prompt_embeds: [batch, n_prompts, hidden_dim]
+ # inputs_embeds: [batch, seq_len, hidden_dim]
+ # Result: [batch, n_prompts + seq_len, hidden_dim]
+ inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
+
+ # Extend attention mask to account for prepended prompts
+ batch_size, n_prompts = inputs_prompt_embeds.shape[:2]
+
+ if attention_mask is not None:
+ # Create attention mask for prompts matching existing mask dtype/device
+ prompt_attention_mask = torch.ones(
+ batch_size, n_prompts,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device
+ )
+ # Concatenate prompt mask with original mask
+ attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+ else:
+ # Create full attention mask for prompts + sequence
+ seq_len = inputs_embeds.shape[1] # Total length including prompts already prepended
+ attention_mask = torch.ones(
+ batch_size, seq_len,
+ dtype=torch.long,
+ device=inputs_embeds.device
+ )
+
+ # Get positional embeddings (BART uses learned positional embeddings)
+ embed_pos = self.embed_positions(inputs_embeds)
+
+ # Combine input embeddings + positional embeddings
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # Expand attention mask from [batch, seq_len] to [batch, 1, tgt_len, src_len]
+ if attention_mask is not None:
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ # Initialize output containers
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # Validate head_mask dimensionality
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"head_mask should have {len(self.layers)} layers, but has {head_mask.size()[0]}"
+ )
+
+ # Pass through encoder layers
+ for idx, encoder_layer in enumerate(self.layers):
+ # Save hidden state before layer if requested
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ # Get layer-specific head mask
+ layer_head_mask = head_mask[idx] if head_mask is not None else None
+
+ # Forward through encoder layer
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # Update hidden states
+ hidden_states = layer_outputs[0]
+
+ # Save attention weights if requested
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ # Save final hidden state if requested
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ # Return tuple format if not using return_dict
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+
+ # Return BaseModelOutput
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
+ """Expand attention mask from [batch, src_len] to [batch, 1, tgt_len, src_len].
+
+ Inverts the mask (1→0, 0→1) and fills masked positions with -inf to prevent attention.
+
+ Args:
+ mask: [batch, src_len] attention mask (1=attend, 0=ignore)
+ dtype: Target data type for the expanded mask
+ tgt_len: Target sequence length (defaults to src_len for encoder self-attention)
+
+ Returns:
+ [batch, 1, tgt_len, src_len] expanded mask with -inf for masked positions
+ """
+ batch_size, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ # Expand dimensions: [batch, src_len] → [batch, 1, tgt_len, src_len]
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
+
+ # Invert mask: 1 (attend) → 0, 0 (ignore) → 1
+ inverted_mask = 1.0 - expanded_mask
+
+ # Fill masked positions with -inf (prevents attention)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
diff --git a/pyhealth/models/promptehr/conditional_prompt.py b/pyhealth/models/promptehr/conditional_prompt.py
new file mode 100644
index 000000000..4122a5d31
--- /dev/null
+++ b/pyhealth/models/promptehr/conditional_prompt.py
@@ -0,0 +1,251 @@
+"""Conditional prompt encoder for demographic conditioning.
+
+This module provides demographic conditioning through prompt-based learning
+with reparameterization to prevent overfitting.
+
+Ported from pehr_scratch/conditional_prompt.py (lines 1-219).
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional
+
+
+class NumericalConditionalPrompt(nn.Module):
+ """Embeds continuous numerical features (e.g., age) with reparameterization.
+
+ Uses intermediate d_hidden=128 dimension for better gradient flow and
+ regularization, following PromptEHR's architecture.
+ """
+
+ def __init__(
+ self,
+ n_num_features: int,
+ hidden_dim: int,
+ d_hidden: int = 128,
+ prompt_length: int = 1
+ ):
+ """Initialize numerical prompt encoder with reparameterization.
+
+ Args:
+ n_num_features: Number of continuous features (1 for age only)
+ hidden_dim: Output dimension size (768 for BART-base)
+ d_hidden: Intermediate reparameterization dimension (default: 128)
+ prompt_length: Number of prompt vectors per feature (default: 1)
+ """
+ super().__init__()
+ self.n_num_features = n_num_features
+ self.hidden_dim = hidden_dim
+ self.d_hidden = d_hidden
+ self.prompt_length = prompt_length
+
+ # Reparameterization: learned weight and bias in d_hidden space
+ self.weight = nn.Parameter(torch.Tensor(n_num_features, d_hidden))
+ self.bias = nn.Parameter(torch.Tensor(n_num_features, d_hidden))
+ nn.init.xavier_uniform_(self.weight)
+ nn.init.xavier_uniform_(self.bias)
+
+ # Project from d_hidden to output dimension
+ self.proj = nn.Linear(d_hidden, hidden_dim, bias=False)
+
+ def forward(self, x_num: torch.Tensor) -> torch.Tensor:
+ """Embed numerical features with reparameterization.
+
+ Args:
+ x_num: [batch, n_num_features] continuous values
+
+ Returns:
+            [batch, n_num_features, hidden_dim] embeddings (prompt_length is not expanded here)
+ """
+ # Reparameterization: weight * value + bias
+ # x_num: [batch, n_num_features]
+ # weight: [n_num_features, d_hidden]
+ # Result: [batch, n_num_features, d_hidden]
+ x = self.weight[None] * x_num[..., None]
+ x = x + self.bias[None]
+
+ # Project to output dimension
+ # x: [batch, n_num_features, d_hidden] → [batch, n_num_features, hidden_dim]
+ x = self.proj(x)
+
+        # Output: [batch, n_num_features, hidden_dim] (prompt_length > 1 is not applied in this path)
+ return x
+
+
+class CategoricalConditionalPrompt(nn.Module):
+ """Embeds categorical features with offset-based indexing and reparameterization.
+
+ Uses single embedding table with offset-based indexing to prevent category
+ collision, following PromptEHR's architecture.
+ """
+
+ def __init__(
+ self,
+ cat_cardinalities: list,
+ hidden_dim: int,
+ d_hidden: int = 128,
+ prompt_length: int = 1
+ ):
+ """Initialize categorical prompt encoder with reparameterization.
+
+ Args:
+ cat_cardinalities: List of category counts for each feature
+ [2] for gender (M/F) - ethnicity removed
+ hidden_dim: Output dimension size (768 for BART-base)
+ d_hidden: Intermediate reparameterization dimension (default: 128)
+ prompt_length: Number of prompt vectors per feature (default: 1)
+ """
+ super().__init__()
+ assert cat_cardinalities, 'cat_cardinalities must be non-empty'
+ self.cat_cardinalities = cat_cardinalities
+ self.hidden_dim = hidden_dim
+ self.d_hidden = d_hidden
+ self.prompt_length = prompt_length
+
+ # Compute offset indices to prevent category collision
+ # Example: [2] → offsets = [0]
+ # Gender 0 (M) → index 0, Gender 1 (F) → index 1
+ category_offsets = torch.tensor([0] + cat_cardinalities[:-1]).cumsum(0)
+ self.register_buffer('category_offsets', category_offsets, persistent=False)
+
+ # Single embedding table for all categories
+ total_categories = sum(cat_cardinalities)
+ self.embeddings = nn.Embedding(total_categories, d_hidden)
+
+ # Learned bias per feature (not per category)
+ self.bias = nn.Parameter(torch.Tensor(len(cat_cardinalities), d_hidden))
+ nn.init.xavier_uniform_(self.bias)
+
+ # Project from d_hidden to output dimension
+ self.proj = nn.Linear(d_hidden, hidden_dim, bias=False)
+
+ def forward(self, x_cat: torch.Tensor) -> torch.Tensor:
+ """Embed categorical features with offset-based indexing.
+
+ Args:
+ x_cat: [batch, n_cat_features] categorical IDs
+
+ Returns:
+            [batch, n_cat_features, hidden_dim] embeddings (prompt_length is not expanded here)
+ """
+ # Add offsets to prevent category collision
+ # x_cat: [batch, n_cat_features]
+ # category_offsets: [n_cat_features]
+ x = self.embeddings(x_cat + self.category_offsets[None])
+
+ # Add learned bias per feature
+ # x: [batch, n_cat_features, d_hidden]
+ # bias: [n_cat_features, d_hidden]
+ x = x + self.bias[None]
+
+ # Project to output dimension
+ # x: [batch, n_cat_features, d_hidden] → [batch, n_cat_features, hidden_dim]
+ x = self.proj(x)
+
+        # Output: [batch, n_cat_features, hidden_dim] (prompt_length > 1 is not applied in this path)
+ return x
+
+
+class ConditionalPromptEncoder(nn.Module):
+ """Combined prompt encoder for both numerical and categorical features.
+
+ Encodes patient demographics (age + gender) into prompt vectors that
+ condition the BART encoder and decoder.
+
+ Example:
+ >>> # For PromptEHR: age (continuous) + gender (categorical)
+ >>> encoder = ConditionalPromptEncoder(
+ ... n_num_features=1, # age
+ ... cat_cardinalities=[2], # gender (M/F)
+ ... hidden_dim=768, # BART dimension
+ ... d_hidden=128 # reparameterization
+ ... )
+ >>> # Batch of 16 patients
+ >>> age = torch.randn(16, 1) # Normalized ages
+ >>> gender = torch.randint(0, 2, (16, 1)) # 0=M, 1=F
+ >>> prompts = encoder(x_num=age, x_cat=gender)
+ >>> prompts.shape # [16, 2, 768] - 2 prompts (age + gender)
+ """
+
+ def __init__(
+ self,
+ n_num_features: Optional[int] = None,
+ cat_cardinalities: Optional[list] = None,
+ hidden_dim: int = 768,
+ d_hidden: int = 128,
+ prompt_length: int = 1
+ ):
+ """Initialize combined prompt encoder.
+
+ Args:
+ n_num_features: Number of continuous features (None to disable)
+ cat_cardinalities: Category counts for each categorical feature (None to disable)
+ hidden_dim: Hidden dimension size (768 for BART-base)
+ d_hidden: Intermediate reparameterization dimension (default: 128)
+ prompt_length: Number of prompt vectors per feature (default: 1)
+ """
+ super().__init__()
+ self.n_num_features = n_num_features
+ self.cat_cardinalities = cat_cardinalities
+ self.hidden_dim = hidden_dim
+ self.d_hidden = d_hidden
+ self.prompt_length = prompt_length
+
+ # Initialize numerical prompt encoder (age)
+ if n_num_features is not None and n_num_features > 0:
+ self.num_prompt = NumericalConditionalPrompt(
+ n_num_features, hidden_dim, d_hidden, prompt_length
+ )
+ else:
+ self.num_prompt = None
+
+ # Initialize categorical prompt encoder (gender)
+ if cat_cardinalities is not None and len(cat_cardinalities) > 0:
+ self.cat_prompt = CategoricalConditionalPrompt(
+ cat_cardinalities, hidden_dim, d_hidden, prompt_length
+ )
+ else:
+ self.cat_prompt = None
+
+ def forward(
+ self,
+ x_num: Optional[torch.Tensor] = None,
+ x_cat: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Encode demographics to prompt embeddings.
+
+ Args:
+ x_num: [batch, n_num_features] continuous values (optional)
+ x_cat: [batch, n_cat_features] categorical IDs (optional)
+
+ Returns:
+ [batch, total_prompts, hidden_dim] combined prompt embeddings
+ """
+ prompts = []
+
+ if x_num is not None and self.num_prompt is not None:
+ num_embeds = self.num_prompt(x_num)
+ prompts.append(num_embeds)
+
+ if x_cat is not None and self.cat_prompt is not None:
+ cat_embeds = self.cat_prompt(x_cat)
+ prompts.append(cat_embeds)
+
+ if len(prompts) == 0:
+ raise ValueError("No prompt embeddings generated. Provide x_num or x_cat.")
+
+ # Concatenate along prompt dimension
+ combined_prompts = torch.cat(prompts, dim=1)
+ return combined_prompts
+
+ def get_num_prompts(self) -> int:
+ """Calculate total number of prompt tokens."""
+ num_prompts = 0
+
+ if self.num_prompt is not None:
+ num_prompts += self.n_num_features * self.prompt_length
+
+ if self.cat_prompt is not None:
+ num_prompts += len(self.cat_cardinalities) * self.prompt_length
+
+ return num_prompts
diff --git a/pyhealth/models/promptehr/generation.py b/pyhealth/models/promptehr/generation.py
new file mode 100644
index 000000000..3d674d1d1
--- /dev/null
+++ b/pyhealth/models/promptehr/generation.py
@@ -0,0 +1,1070 @@
+"""
+Generate synthetic patient sequences using trained PromptEHR model.
+
+This module provides functions for generating realistic synthetic EHR data
+using various conditioning strategies (demographics, visit structures, etc.).
+"""
+import json
+import math
+import numpy as np
+import torch
+from pathlib import Path
+from typing import Optional, List, Union, Dict
+
+
class DemographicSampler:
    """Sample patient demographics from empirical training distribution.

    Samples age and gender by directly drawing from the observed distribution
    in training data, ensuring synthetic patients match real population.
    """

    def __init__(self, patient_records: List, seed: int = 42):
        """Initialize sampler with empirical demographics from training data.

        Args:
            patient_records: List of patient records from training set.
                Each record should expose 'age' and 'gender' either as
                attributes or as dict keys; records with neither are skipped.
            seed: Random seed for reproducibility.

        Raises:
            ValueError: If no usable (age, gender) pairs are found. Previously
                this produced NaN statistics (np.mean of an empty array) and a
                later crash inside sample(); now it fails fast.
        """
        self.rng = np.random.RandomState(seed)

        ages = []
        genders = []

        for patient in patient_records:
            # Handle both object-like and dict-like patient records.
            if hasattr(patient, 'age') and hasattr(patient, 'gender'):
                age = patient.age
                gender = patient.gender
            elif isinstance(patient, dict) and 'age' in patient and 'gender' in patient:
                age = patient['age']
                gender = patient['gender']
            else:
                continue

            ages.append(float(age))
            # Normalize gender to int: M=0, F=1 (any non-'M' string maps to 1).
            if isinstance(gender, str):
                genders.append(0 if gender == 'M' else 1)
            else:
                genders.append(int(gender))

        # FIX: fail fast on empty demographics instead of silently computing
        # NaN statistics and crashing later in sample().
        if not ages:
            raise ValueError(
                "DemographicSampler received no records with usable "
                "'age' and 'gender' fields."
            )

        self.ages = np.array(ages)
        self.genders = np.array(genders)

        # Summary statistics of the empirical distribution (for __repr__ /
        # external inspection).
        self.stats = {
            'age_mean': np.mean(self.ages),
            'age_std': np.std(self.ages),
            'age_median': np.median(self.ages),
            'age_min': np.min(self.ages),
            'age_max': np.max(self.ages),
            'male_pct': (self.genders == 0).mean(),
            'female_pct': (self.genders == 1).mean(),
        }

    def sample(self) -> dict:
        """Sample demographics from empirical distribution.

        Returns:
            Dictionary with:
            - 'age': float (sampled from training ages)
            - 'sex': int (0=Male, 1=Female, sampled from training)
            - 'sex_str': str ('M' or 'F')
        """
        # Draw a random training record so age and gender stay jointly
        # consistent (no independence assumption).
        idx = self.rng.randint(0, len(self.ages))

        age = self.ages[idx]
        sex = self.genders[idx]
        sex_str = 'M' if sex == 0 else 'F'

        return {
            'age': float(age),
            'sex': int(sex),
            'sex_str': sex_str
        }

    def __repr__(self):
        return (
            f"DemographicSampler(\n"
            f"  Age: mean={self.stats['age_mean']:.1f}, "
            f"std={self.stats['age_std']:.1f}, "
            f"range=[{self.stats['age_min']:.0f}, {self.stats['age_max']:.0f}]\n"
            f"  Gender: {self.stats['male_pct']:.1%} Male, "
            f"{self.stats['female_pct']:.1%} Female\n"
            f")"
        )
+
+
def build_first_code_prior(
    training_data_path: str,
    age_bins: int = 9
) -> Dict:
    """Build empirical P(first_code | age, gender) from training data.

    Reads raw MIMIC-III CSVs (ADMISSIONS, PATIENTS, DIAGNOSES_ICD), computes
    each patient's age at their first admission, and tabulates the
    distribution of the patient's first diagnosis code per (age decade,
    gender) cell.

    Args:
        training_data_path: Path to training data directory with MIMIC-III files
        age_bins: Number of age bins (default: 9 for [0-10), [10-20), ..., [80-90])

    Returns:
        Dictionary mapping (age_bin, gender) -> {code: probability}

    Example:
        >>> prior = build_first_code_prior('/path/to/train_data')
        >>> first_code = sample_first_code(65, 0, prior)
    """
    import pandas as pd

    # Load training data
    admissions = pd.read_csv(f'{training_data_path}/ADMISSIONS.csv')
    patients = pd.read_csv(f'{training_data_path}/PATIENTS.csv')
    diagnoses = pd.read_csv(f'{training_data_path}/DIAGNOSES_ICD.csv')

    # Calculate age at first admission
    admissions['ADMITTIME'] = pd.to_datetime(admissions['ADMITTIME'])
    patients['DOB'] = pd.to_datetime(patients['DOB'])

    # Earliest admission row per subject (idxmin over ADMITTIME).
    first_admissions = admissions.loc[
        admissions.groupby('SUBJECT_ID')['ADMITTIME'].idxmin()
    ][['SUBJECT_ID', 'HADM_ID', 'ADMITTIME']]

    demo = pd.merge(
        patients[['SUBJECT_ID', 'GENDER', 'DOB']],
        first_admissions,
        on='SUBJECT_ID',
        how='inner'
    )
    # Year difference only (ignores month/day). MIMIC shifts DOB for
    # patients older than 89, so ages are clamped into [0, 90].
    demo['AGE'] = (demo['ADMITTIME'].dt.year - demo['DOB'].dt.year)
    demo['AGE'] = demo['AGE'].apply(lambda x: 90 if x > 89 else max(0, x))

    # Get first diagnosis codes
    first_diag = pd.merge(
        demo[['SUBJECT_ID', 'HADM_ID', 'AGE', 'GENDER']],
        diagnoses[['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE']],
        on=['SUBJECT_ID', 'HADM_ID'],
        how='inner'
    )

    # Keep one code per patient. NOTE(review): SEQ_NUM is never loaded, so
    # "first" means first *alphabetically* by ICD9_CODE, not the clinically
    # primary diagnosis — confirm this is intended.
    first_diag = first_diag.sort_values(['SUBJECT_ID', 'ICD9_CODE'])
    first_diag = first_diag.groupby('SUBJECT_ID').first().reset_index()

    # Bin ages into decades; include_lowest puts age 0 in bin 0 and the
    # clamped age 90 lands in the (80, 90] bin (label 8).
    first_diag['age_bin'] = pd.cut(
        first_diag['AGE'],
        bins=list(range(0, 91, 10)),
        labels=list(range(age_bins)),
        include_lowest=True
    )

    # Convert gender to int (0=M, 1=F)
    first_diag['gender_int'] = (first_diag['GENDER'] == 'F').astype(int)

    # Calculate empirical distribution: normalized code counts per cell.
    dist = {}
    for (age_bin, gender), group in first_diag.groupby(['age_bin', 'gender_int']):
        code_counts = group['ICD9_CODE'].value_counts()
        total = code_counts.sum()
        dist[(int(age_bin), int(gender))] = {
            str(code): count / total
            for code, count in code_counts.items()
        }

    return dist
+
+
def sample_first_code(
    age: float,
    gender: int,
    first_code_prior: Dict
) -> str:
    """Draw an initial diagnosis code from the empirical demographic prior.

    Args:
        age: Patient age (0-90); ages of 80+ share the top bin.
        gender: Patient gender (0=Male, 1=Female)
        first_code_prior: Prior from build_first_code_prior(), keyed by
            (age_bin, gender) with {code: probability} values.

    Returns:
        Diagnosis code string (e.g., 'V3000', '41401')

    Example:
        >>> prior = build_first_code_prior('/path/to/train_data')
        >>> code = sample_first_code(65, 0, prior)
        >>> print(code)  # e.g., 'V3000'
    """
    # Decade bins, capped at index 8: [0-9]->0, [10-19]->1, ..., [80+]->8.
    demo_key = (min(int(age // 10), 8), gender)

    if demo_key not in first_code_prior:
        # Fall back to the first key with a matching gender; if the gender
        # is unseen entirely, use the overall first key.
        gender_match = next((k for k in first_code_prior if k[1] == gender), None)
        demo_key = gender_match if gender_match else next(iter(first_code_prior))

    distribution = first_code_prior[demo_key]
    return np.random.choice(list(distribution.keys()), p=list(distribution.values()))
+
+
def build_frequency_prior(
    tokenizer,
    frequency_path: Optional[Union[str, Path]] = None,
    epsilon: float = 1e-10,
    vocab_size: Optional[int] = None
) -> torch.Tensor:
    """Build log-frequency prior over vocabulary for frequency-guided generation.

    Args:
        tokenizer: DiagnosisCodeTokenizer with vocab and code_offset attributes.
        frequency_path: Path to training_frequencies.json. If None, uses uniform prior.
        epsilon: Small constant to avoid log(0) (default: 1e-10).
        vocab_size: Model vocabulary size. If None, inferred from tokenizer (not recommended).
            Should match model's lm_head output dimension.

    Returns:
        torch.Tensor of shape [vocab_size] with log-frequencies.
        Special tokens get 0 (neutral prior), diagnosis codes get log(freq + epsilon).

    Example:
        >>> prior = build_frequency_prior(tokenizer, './promptehr_outputs/training_frequencies.json', vocab_size=6963)
        >>> logits_guided = logits + alpha * prior  # Blend with model logits
    """
    # Use provided vocab size or infer from tokenizer.
    # WARNING: inferred size may not match the model's lm_head dimension!
    if vocab_size is None:
        vocab_size = len(tokenizer.vocab.idx2code)

    log_freqs = torch.zeros(vocab_size)

    if frequency_path is None:
        # Uniform fallback: all codes equally likely; special tokens
        # (indices < code_offset) keep a neutral prior of 0.
        uniform_log_freq = math.log(1.0 / len(tokenizer.vocab.idx2code))
        log_freqs[tokenizer.code_offset:] = uniform_log_freq
        return log_freqs

    # Load training frequencies
    with open(frequency_path, 'r') as f:
        freq_data = json.load(f)

    frequencies = freq_data['frequencies']

    # Track which positions actually received a frequency, so unfilled
    # positions can be identified even if log(freq + epsilon) == 0 exactly.
    filled = torch.zeros(vocab_size, dtype=torch.bool)

    # Fill in log-frequencies for each code.
    # NOTE: code_idx maps directly to token_id without adding code_offset
    # because code2idx already includes the special-token slots.
    for code, freq in frequencies.items():
        if code in tokenizer.vocab.code2idx:
            code_idx = tokenizer.vocab.code2idx[code]
            if code_idx < vocab_size:
                log_freqs[code_idx] = math.log(freq + epsilon)
                filled[code_idx] = True

    # FIX: the previous implementation replaced *every* zero entry with
    # log(epsilon), which clobbered the special-token positions below
    # code_offset and contradicted the documented "special tokens get 0
    # (neutral prior)". Only unseen *code* positions get the floor value.
    min_log_freq = math.log(epsilon)
    is_code_position = torch.arange(vocab_size) >= tokenizer.code_offset
    log_freqs[is_code_position & ~filled] = min_log_freq

    return log_freqs
+
+
def sample_demographics(
    age_mean: float = 60.0,
    age_std: float = 20.0,
    male_prob: float = 0.56
) -> dict:
    """Sample realistic patient demographics.

    Draws from distributions matching the MIMIC-III ICU population:
    age ~ Normal(age_mean, age_std) clipped to [0, 90], and sex ~
    Bernoulli(male_prob).

    Args:
        age_mean: Mean age for normal distribution (default: 60).
        age_std: Standard deviation for age (default: 20).
        male_prob: Probability of male gender (default: 0.56).

    Returns:
        Dictionary with:
        - 'age': float in range [0, 90]
        - 'sex': int (0=Male, 1=Female)
        - 'sex_str': str ('M' or 'F')
    """
    # NOTE: RNG call order (normal first, then rand) is part of the
    # reproducibility contract for seeded callers.
    raw_age = np.random.normal(age_mean, age_std)
    clipped_age = float(np.clip(raw_age, 0, 90))

    is_male = np.random.rand() < male_prob
    gender_id = 0 if is_male else 1

    return {
        'age': clipped_age,
        'sex': gender_id,
        'sex_str': 'M' if gender_id == 0 else 'F'
    }
+
+
def decode_patient_demographics(age: float, gender: int) -> dict:
    """Decode demographics back to readable format.

    Args:
        age: Normalized age value.
        gender: Gender category index (0=M, 1=F, per data_loader.py).

    Returns:
        Dictionary with 'age' rendered to one decimal place and 'gender'
        as 'M'/'F' ('UNKNOWN' for any unrecognized index).
    """
    sex_label = {0: "M", 1: "F"}.get(gender, "UNKNOWN")
    return {"age": f"{age:.1f}", "gender": sex_label}
+
+
def parse_sequence_to_visits(
    token_ids: List[int],
    tokenizer
) -> List[List[str]]:
    """Parse generated token sequence into visit structure.

    Splits the sequence at visit-start/visit-end markers and decodes the
    diagnosis-code tokens inside each visit. A visit left open at the end
    of the sequence is kept if it contains at least one code.

    Args:
        token_ids: List of token IDs from model generation.
        tokenizer: PyHealth Tokenizer instance (must have bos_token_id,
            pad_token_id, code_offset, and vocab attributes).

    Returns:
        List of visits, where each visit is a list of ICD-9 code strings.
    """
    # Marker/special token IDs.
    # NOTE(review): the visit-start and sequence-end lookups both use ""
    # (likely garbled special-token strings) — verify against the
    # tokenizer's actual special tokens.
    visit_open_id = tokenizer.convert_tokens_to_indices([""])[0]
    visit_close_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
    seq_end_id = tokenizer.convert_tokens_to_indices([""])[0]
    skip_ids = (tokenizer.bos_token_id, seq_end_id, tokenizer.pad_token_id)

    parsed_visits = []
    current_codes = []
    inside_visit = False

    for tid in token_ids:
        if tid == visit_open_id:
            # Opening marker resets the running visit (even if one was open).
            inside_visit = True
            current_codes = []
        elif tid == visit_close_id:
            # Closing marker commits the visit only if one was open.
            if inside_visit:
                parsed_visits.append(current_codes)
            inside_visit = False
        elif tid in skip_ids:
            continue
        elif inside_visit and tid >= tokenizer.code_offset:
            # Diagnosis code token; code2idx/idx2code already include the
            # special tokens, so the token id is used as the vocab index.
            if tid < len(tokenizer.vocab.idx2code):
                current_codes.append(tokenizer.vocab.idx2code[tid])

    # A sequence that ends mid-visit still yields its codes.
    if inside_visit and current_codes:
        parsed_visits.append(current_codes)

    return parsed_visits
+
+
def generate_patient_sequence_conditional(
    model,
    tokenizer,
    target_patient,
    device: torch.device,
    temperature: float = 0.3,
    top_k: int = 0,  # Disabled (test with top_p only)
    top_p: float = 0.95,  # Increased for more diversity
    prompt_prob: float = 0.0,
    max_codes_per_visit: int = 20
) -> dict:
    """Generate synthetic patient via conditional reconstruction (PromptEHR approach).

    Given a real patient from test set, randomly masks codes and reconstructs
    the full visit structure. Default prompt_prob=0.0 means zero-code-prompt
    generation (only demographics provided).

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        target_patient: Patient record from test set to reconstruct.
            Must have attributes: age, gender (or sex), visits.
        device: Device to run on.
        temperature: Sampling temperature (default: 0.3).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        prompt_prob: Probability of keeping each code as prompt (default: 0.0 = zero prompts).
        max_codes_per_visit: Cap visit codes at this number (default: 20).

    Returns:
        Dictionary with:
        - 'generated_visits': List[List[str]] of generated code sequences
        - 'target_visits': List[List[str]] of original codes
        - 'prompt_codes': List[List[str]] of codes provided as prompts
        - 'demographics': dict of patient demographics
    """
    model.eval()

    # Extract demographics (handle both 'gender' and 'sex' attributes).
    # NOTE(review): the fallback branches call .get(), which assumes the
    # record is dict-like whenever the attribute is missing — confirm callers
    # never pass plain objects lacking these attributes.
    if hasattr(target_patient, 'age'):
        age = target_patient.age
    else:
        age = target_patient.get('age', 60.0)

    if hasattr(target_patient, 'gender'):
        gender_str = target_patient.gender
    elif hasattr(target_patient, 'sex'):
        gender_str = target_patient.sex
    else:
        gender_str = target_patient.get('gender', 'M')

    # Any value other than 'F' (including unexpected ones) maps to male (0).
    gender = 1 if gender_str == 'F' else 0

    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[gender]], dtype=torch.long).to(device)

    # Get visits
    if hasattr(target_patient, 'visits'):
        patient_visits = target_patient.visits
    else:
        patient_visits = target_patient.get('visits', [])

    # Initialize accumulators
    generated_visits = []
    prompt_codes_per_visit = []

    # Create dummy encoder input (prompts are in decoder)
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    # Special token IDs.
    # NOTE(review): the visit-start lookup uses "" (looks like a garbled
    # special-token string); verify against the tokenizer's actual tokens.
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]

    with torch.no_grad():
        # Process each visit from target patient
        for visit_idx, target_codes in enumerate(patient_visits):
            # Step 1: Cap codes at max_codes_per_visit (uniform subsample
            # via the *global* numpy RNG — not seeded here).
            num_codes = len(target_codes)
            if num_codes > max_codes_per_visit:
                target_codes = list(np.random.choice(target_codes, max_codes_per_visit, replace=False))
                num_codes = max_codes_per_visit

            if num_codes == 0:
                # Empty visit - skip generation but keep positional alignment
                # with target_visits.
                generated_visits.append([])
                prompt_codes_per_visit.append([])
                continue

            # Step 2: Randomly mask codes (binomial sampling)
            keep_mask = np.random.binomial(1, prompt_prob, num_codes).astype(bool)
            prompt_codes = [code for i, code in enumerate(target_codes) if keep_mask[i]]

            # Step 3: Encode prompt codes as decoder input
            prompt_token_ids = [tokenizer.bos_token_id, v_token_id]
            for code in prompt_codes:
                # code2idx already returns the token ID with offset included.
                code_token_id = tokenizer.vocab.code2idx[code]
                prompt_token_ids.append(code_token_id)

            decoder_input_ids = torch.tensor([prompt_token_ids], dtype=torch.long).to(device)

            # Step 4: Generate to reconstruct full visit
            max_new_tokens = num_codes + 2  # Target length (+2 for markers)

            # Use model.generate() for automatic handling
            generated_ids = model.generate(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                decoder_input_ids=decoder_input_ids,
                x_num=x_num,
                x_cat=x_cat,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                num_beams=1,  # Disable beam search, use sampling only
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                no_repeat_ngram_size=1,  # Prevents duplicate codes (unigram level)
                eos_token_id=v_end_token_id,  # Stop at visit-end marker
                pad_token_id=tokenizer.pad_token_id,
                bad_words_ids=[[tokenizer.bos_token_id]]  # Suppress BOS in generation
            )

            # Step 5: Extract generated codes
            visit_token_ids = generated_ids[0].cpu().tolist()

            # Extract code tokens (everything at/above code_offset; markers
            # and specials sit below the offset)
            generated_code_ids = [
                tid for tid in visit_token_ids
                if tid >= tokenizer.code_offset
            ]

            # Decode codes (convert token IDs back to diagnosis codes).
            # code2idx already includes special tokens, so no offset subtraction.
            generated_codes = []
            for tid in generated_code_ids:
                if tid < len(tokenizer.vocab.idx2code):
                    code = tokenizer.vocab.idx2code[tid]
                    generated_codes.append(code)

            # Step 6: Combine with prompt codes and deduplicate.
            # NOTE(review): set() ordering depends on string hashing, so the
            # code order within a visit is not reproducible across runs.
            all_codes = list(set(generated_codes + prompt_codes))

            # Ensure exactly num_codes by sampling if needed
            if len(all_codes) < num_codes:
                # Not enough unique codes generated - resample with replacement
                needed = num_codes - len(all_codes)
                additional = list(np.random.choice(generated_codes, needed, replace=True)) if len(generated_codes) > 0 else []
                all_codes.extend(additional)
            elif len(all_codes) > num_codes:
                # Too many codes - sample exactly num_codes
                all_codes = list(np.random.choice(all_codes, num_codes, replace=False))

            generated_visits.append(all_codes)
            prompt_codes_per_visit.append(prompt_codes)

    return {
        'generated_visits': generated_visits,
        'target_visits': patient_visits,
        'prompt_codes': prompt_codes_per_visit,
        'demographics': {
            'age': age,
            'gender': gender_str
        }
    }
+
+
def generate_patient_with_structure_constraints(
    model,
    tokenizer,
    device: torch.device,
    target_structure: dict,
    age: Optional[float] = None,
    sex: Optional[int] = None,
    first_code: Optional[str] = None,
    temperature: float = 0.7,
    top_k: int = 0,  # Disabled (test with top_p only)
    top_p: float = 0.95,  # Increased for more diversity
    max_codes_per_visit: int = 25
) -> dict:
    """Generate patient with realistic visit structure constraints.

    This function generates patients visit-by-visit with controlled code counts
    sampled from real data distributions, producing more realistic EHR records.

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        device: Device to run on.
        target_structure: Dict with 'num_visits' and 'codes_per_visit' list.
        age: Patient age (if None, sampled from distribution).
        sex: Patient sex ID (0=M, 1=F; if None, sampled).
        first_code: First diagnosis code to condition on (if None, generated by model).
        temperature: Sampling temperature (default: 0.7).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        max_codes_per_visit: Maximum codes per visit safety cap (default: 25).

    Returns:
        Dictionary with:
        - 'generated_visits': List[List[str]] of diagnosis codes
        - 'demographics': dict with 'age' and 'sex'
        - 'num_visits': int
        - 'num_codes': int
        - 'target_structure': dict (the structure we aimed for)
    """
    model.eval()

    # Sample demographics if not provided
    if age is None or sex is None:
        sampled_demo = sample_demographics()
        age = sampled_demo['age'] if age is None else age
        sex = sampled_demo['sex'] if sex is None else sex

    # Prepare demographic tensors
    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[sex]], dtype=torch.long).to(device)

    # Special token IDs.
    # NOTE(review): visit-start and sequence-end lookups both use "" (likely
    # garbled special-token strings) — verify against the tokenizer. Also,
    # bos_token_id and end_token_id are never used below.
    bos_token_id = tokenizer.bos_token_id
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
    end_token_id = tokenizer.convert_tokens_to_indices([""])[0]

    # Extract target structure
    num_visits = target_structure['num_visits']
    codes_per_visit = target_structure['codes_per_visit']

    # Handle case with no visits
    if num_visits == 0 or len(codes_per_visit) == 0:
        return {
            'generated_visits': [],
            'demographics': {'age': age, 'sex': sex},
            'num_visits': 0,
            'num_codes': 0,
            'target_structure': target_structure
        }

    # Initialize generation with empty sequence (shape [1, 0]).
    # HuggingFace will prepend decoder_start_token_id automatically,
    # matching the training pattern.
    decoder_input_ids = torch.tensor([[]], dtype=torch.long).to(device)

    # If first_code provided, prepopulate decoder with visit-start + first_code
    # (no closing marker yet — generation continues the visit).
    first_visit_prepopulated = False
    if first_code is not None and first_code in tokenizer.vocab.code2idx:
        # NOTE(review): duplicate lookup of the same token as v_token_id above.
        v_token_id_temp = tokenizer.convert_tokens_to_indices([""])[0]
        first_code_id = tokenizer.vocab.code2idx[first_code]

        prepop_ids = torch.tensor([[v_token_id_temp, first_code_id]],
                                  dtype=torch.long).to(device)
        decoder_input_ids = torch.cat([decoder_input_ids, prepop_ids], dim=1)
        first_visit_prepopulated = True

    # Create dummy encoder input
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    all_visits = []

    with torch.no_grad():
        for visit_idx in range(num_visits):
            target_codes = min(codes_per_visit[visit_idx], max_codes_per_visit)

            # For visit 0 with prepopulated first_code, reduce target by 1 since we already have 1 code
            if visit_idx == 0 and first_visit_prepopulated:
                target_codes = max(1, target_codes - 1)  # At least 1 more code

            # Skip if target is too small
            if target_codes < 1:
                continue

            # Append visit-start token to open the visit.
            # NOTE(review): if generation below fails or produces no codes,
            # this opening marker stays dangling in decoder_input_ids.
            v_token_tensor = torch.tensor([[v_token_id]], dtype=torch.long).to(device)
            decoder_input_ids = torch.cat([decoder_input_ids, v_token_tensor], dim=1)

            # Calculate max tokens to generate for this visit:
            # ~1 token per code, +1 for the closing marker, +50% buffer.
            max_new_tokens_this_visit = int(target_codes * 1.5) + 1

            try:
                # Generate codes for this visit
                generated_visit_ids = model.generate(
                    input_ids=encoder_input_ids,
                    attention_mask=encoder_attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    x_num=x_num,
                    x_cat=x_cat,
                    max_new_tokens=max_new_tokens_this_visit,
                    do_sample=True,
                    num_beams=1,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    no_repeat_ngram_size=1,
                    eos_token_id=v_end_token_id,  # Stop at visit end
                    pad_token_id=tokenizer.pad_token_id
                    # Note: NOT passing bos_token_id - let BART use decoder_start_token_id automatically
                )

                # Extract only the newly generated tokens (after decoder_input_ids)
                new_tokens = generated_visit_ids[0, decoder_input_ids.shape[1]:]

                # Parse the generated visit codes
                visit_codes = []
                for token_id in new_tokens:
                    token_id_val = token_id.item()
                    if token_id_val == v_end_token_id:
                        break  # End of visit
                    elif token_id_val >= tokenizer.code_offset:
                        # Diagnosis code; code2idx already includes special
                        # tokens, so the token id is the vocab index directly.
                        if token_id_val < len(tokenizer.vocab.idx2code):
                            code = tokenizer.vocab.idx2code[token_id_val]
                            visit_codes.append(code)

                # If we generated codes, add visit
                if len(visit_codes) > 0:
                    # Truncate to target if we over-generated
                    if len(visit_codes) > target_codes:
                        visit_codes = visit_codes[:target_codes]

                    all_visits.append(visit_codes)

                    # Update decoder_input_ids with the full visit, rebuilt
                    # from the (possibly truncated) code list rather than the
                    # raw generated tokens, plus the closing marker.
                    visit_token_ids = [v_token_id]  # visit-start marker
                    for code in visit_codes:
                        if code in tokenizer.vocab.code2idx:
                            # code2idx already returns token ID with offset included
                            code_token_id = tokenizer.vocab.code2idx[code]
                            visit_token_ids.append(code_token_id)
                    visit_token_ids.append(v_end_token_id)  # visit-end marker

                    # Convert to tensor and concatenate (skip first since already added)
                    visit_tensor = torch.tensor([visit_token_ids[1:]], dtype=torch.long).to(device)
                    decoder_input_ids = torch.cat([decoder_input_ids, visit_tensor], dim=1)

            except Exception as e:
                # Best-effort: if generation fails for this visit, skip it
                # (the dangling visit-start marker remains — see note above).
                print(f"Warning: Generation failed for visit {visit_idx + 1}: {e}")
                continue

            # Check if we're approaching context limit (512 for BART)
            if decoder_input_ids.shape[1] > 400:
                break  # Stop generating more visits

    # Compute statistics
    total_codes = sum(len(visit) for visit in all_visits)

    return {
        'generated_visits': all_visits,
        'demographics': {'age': age, 'sex': sex},
        'num_visits': len(all_visits),
        'num_codes': total_codes,
        'target_structure': target_structure
    }
+
+
def generate_with_frequency_prior(
    model,
    tokenizer,
    device: torch.device,
    target_structure: dict,
    frequency_prior: torch.Tensor,
    alpha: float = 1.0,
    age: Optional[float] = None,
    sex: Optional[int] = None,
    temperature: float = 0.7,
    top_k: int = 0,
    top_p: float = 0.95,
    max_codes_per_visit: int = 25,
    diagnostic_mode: bool = False,
    diagnostic_path: Optional[str] = None
) -> dict:
    """Generate patient with frequency-guided sampling.

    This function is identical to generate_patient_with_structure_constraints,
    but blends model logits with training frequency prior for realistic code
    distributions, using a hand-rolled token-by-token sampling loop instead
    of model.generate().

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        device: Device to run on.
        target_structure: Dict with 'num_visits' and 'codes_per_visit' list.
        frequency_prior: [vocab_size] log-frequency tensor from build_frequency_prior().
        alpha: Blending weight (0=pure model, higher=more frequency guidance).
            Recommended: 0.5-2.0. Start with 1.0.
        age: Patient age (if None, sampled from distribution).
        sex: Patient sex ID (0=M, 1=F; if None, sampled).
        temperature: Sampling temperature (default: 0.7).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        max_codes_per_visit: Maximum codes per visit safety cap (default: 25).
        diagnostic_mode: Enable detailed logging of generation process (default: False).
        diagnostic_path: Path to save diagnostic JSON file (required if diagnostic_mode=True).

    Returns:
        Dictionary with:
        - 'generated_visits': List[List[str]] of diagnosis codes
        - 'demographics': dict with 'age' and 'sex'
        - 'num_visits': int
        - 'num_codes': int
        - 'target_structure': dict (the structure we aimed for)
        - 'alpha': float (frequency prior weight used)
        - 'diagnostics': dict (if diagnostic_mode=True) with detailed generation logs

    Example:
        >>> prior = build_frequency_prior(tokenizer, './promptehr_outputs/training_frequencies.json')
        >>> result = generate_with_frequency_prior(
        ...     model, tokenizer, device,
        ...     target_structure={'num_visits': 3, 'codes_per_visit': [5, 8, 6]},
        ...     frequency_prior=prior,
        ...     alpha=1.0
        ... )
    """
    model.eval()

    # Sample demographics if not provided
    if age is None or sex is None:
        sampled_demo = sample_demographics()
        age = sampled_demo['age'] if age is None else age
        sex = sampled_demo['sex'] if sex is None else sex

    # Prepare demographic tensors
    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[sex]], dtype=torch.long).to(device)

    # Move frequency prior to device
    frequency_prior = frequency_prior.to(device)

    # Special token IDs.
    # NOTE(review): bos_token_id is never used below; the visit-start lookup
    # uses "" (looks like a garbled special-token string) — verify against
    # the tokenizer's actual tokens.
    bos_token_id = tokenizer.bos_token_id
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]

    # Extract target structure
    num_visits = target_structure['num_visits']
    codes_per_visit = target_structure['codes_per_visit']

    # Handle case with no visits
    if num_visits == 0 or len(codes_per_visit) == 0:
        return {
            'generated_visits': [],
            'demographics': {'age': age, 'sex': sex},
            'num_visits': 0,
            'num_codes': 0,
            'target_structure': target_structure,
            'alpha': alpha
        }

    # Initialize generation with empty sequence (shape [1, 0]).
    # HuggingFace will prepend decoder_start_token_id automatically,
    # matching the training pattern.
    decoder_input_ids = torch.tensor([[]], dtype=torch.long).to(device)

    # Create dummy encoder input
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    all_visits = []

    # Initialize diagnostic tracking
    all_diagnostics = {'visits': []} if diagnostic_mode else None

    with torch.no_grad():
        for visit_idx in range(num_visits):
            target_codes = min(codes_per_visit[visit_idx], max_codes_per_visit)

            # Skip if target is too small
            if target_codes < 1:
                continue

            # Append visit-start token to open the visit
            v_token_tensor = torch.tensor([[v_token_id]], dtype=torch.long).to(device)
            decoder_input_ids = torch.cat([decoder_input_ids, v_token_tensor], dim=1)

            # Generate codes for this visit with frequency guidance
            # (~1 token per code, +1 for the closing marker, +50% buffer).
            max_new_tokens_this_visit = int(target_codes * 1.5) + 1
            visit_codes = []

            # Initialize visit diagnostic tracking
            visit_diagnostics = {'visit_idx': visit_idx, 'steps': []} if diagnostic_mode else None

            for step in range(max_new_tokens_this_visit):
                # Forward pass — full re-encode each step (no KV cache),
                # which is O(len^2) over the sequence but simple.
                outputs = model(
                    input_ids=encoder_input_ids,
                    attention_mask=encoder_attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    x_num=x_num,
                    x_cat=x_cat,
                    return_dict=True
                )

                # Get logits for next token (handle both dict and object outputs)
                if hasattr(outputs, 'logits'):
                    logits = outputs.logits[0, -1, :]  # [vocab_size]
                elif isinstance(outputs, dict) and 'logits' in outputs:
                    logits = outputs['logits'][0, -1, :]  # [vocab_size]
                else:
                    raise TypeError(f"Unexpected output type: {type(outputs)}")

                # Diagnostic logging: raw model logits
                if diagnostic_mode:
                    step_diagnostics = {
                        'step': step,
                        'raw_logits': {
                            'max': float(logits.max()),
                            'min': float(logits.min()),
                            'mean': float(logits.mean()),
                            'std': float(logits.std()),
                            'top_5_indices': [int(i) for i in logits.topk(5).indices],
                            'top_5_codes': [tokenizer.vocab.idx2code.get(int(i), f"<{i}>")
                                            for i in logits.topk(5).indices],
                            'top_5_values': [float(v) for v in logits.topk(5).values]
                        }
                    }

                # BLEND with frequency prior (additive in log space).
                logits_guided = logits + alpha * frequency_prior

                # Diagnostic logging: frequency blending
                if diagnostic_mode:
                    step_diagnostics['blending'] = {
                        'alpha': alpha,
                        'prior_contribution': float((alpha * frequency_prior).abs().mean()),
                        'logits_shift': float((logits_guided - logits).abs().mean()),
                        'top_5_after_blend_indices': [int(i) for i in logits_guided.topk(5).indices],
                        'top_5_after_blend_codes': [tokenizer.vocab.idx2code.get(int(i), f"<{i}>")
                                                    for i in logits_guided.topk(5).indices],
                        'top_5_after_blend_values': [float(v) for v in logits_guided.topk(5).values]
                    }

                # Apply temperature
                scaled_logits = logits_guided / temperature

                # Convert to probabilities
                probs = torch.softmax(scaled_logits, dim=0)

                # Diagnostic logging: probabilities after temperature
                if diagnostic_mode:
                    top_probs, top_indices = torch.topk(probs, 20)
                    step_diagnostics['probabilities'] = {
                        'temperature': temperature,
                        'entropy': float(-(probs * torch.log(probs + 1e-10)).sum()),
                        'top_20': [
                            {'code': tokenizer.vocab.idx2code.get(int(idx), f"<{idx}>"),
                             'prob': float(prob),
                             'idx': int(idx)}
                            for idx, prob in zip(top_indices, top_probs)
                        ]
                    }

                # Apply top-k filtering if enabled (renormalized afterwards)
                if top_k > 0:
                    top_k_vals, top_k_indices = torch.topk(probs, min(top_k, probs.size(-1)))
                    probs_filtered = torch.zeros_like(probs)
                    probs_filtered.scatter_(0, top_k_indices, top_k_vals)
                    probs = probs_filtered / probs_filtered.sum()

                # Apply nucleus (top-p) sampling
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum_probs = torch.cumsum(sorted_probs, dim=0)
                    nucleus_mask = cumsum_probs <= top_p
                    nucleus_mask[0] = True  # Always include top token

                    nucleus_indices = sorted_indices[nucleus_mask]
                    nucleus_probs = sorted_probs[nucleus_mask]
                    nucleus_probs = nucleus_probs / nucleus_probs.sum()

                    # Sample from nucleus
                    sampled_idx = torch.multinomial(nucleus_probs, 1)[0]
                    next_token = int(nucleus_indices[sampled_idx])
                else:
                    # Sample directly from filtered probs
                    next_token = int(torch.multinomial(probs, 1)[0])

                # Diagnostic logging: sampling decision.
                # NOTE(review): 'probability' is read from `probs`, which in
                # the top-p path is the *pre-nucleus* distribution.
                if diagnostic_mode:
                    selected_code = tokenizer.vocab.idx2code.get(next_token, f"<{next_token}>")
                    step_diagnostics['selected'] = {
                        'token': next_token,
                        'code': selected_code,
                        'probability': float(probs[next_token]) if next_token < len(probs) else 0.0,
                        'was_top_1': (next_token == int(probs.argmax())),
                        'is_special_token': next_token < tokenizer.code_offset
                    }
                    visit_diagnostics['steps'].append(step_diagnostics)

                # Check if we hit end-of-visit
                if next_token == v_end_token_id:
                    break

                # Extract code if it's a diagnosis code; code2idx already
                # includes special tokens, so no offset subtraction.
                if next_token >= tokenizer.code_offset:
                    if next_token < len(tokenizer.vocab.idx2code):
                        code = tokenizer.vocab.idx2code[next_token]
                        if code not in visit_codes:  # Prevent duplicates in output
                            visit_codes.append(code)

                # Append token to decoder input (duplicates are still fed
                # back even though they are filtered from visit_codes).
                next_token_tensor = torch.tensor([[next_token]], dtype=torch.long).to(device)
                decoder_input_ids = torch.cat([decoder_input_ids, next_token_tensor], dim=1)

                # Stop if we have enough codes
                if len(visit_codes) >= target_codes:
                    break

            # Add visit if we generated codes
            if len(visit_codes) > 0:
                # Truncate to target if over-generated
                if len(visit_codes) > target_codes:
                    visit_codes = visit_codes[:target_codes]

                all_visits.append(visit_codes)

                # Add visit diagnostics
                if diagnostic_mode:
                    visit_diagnostics['generated_codes'] = visit_codes
                    visit_diagnostics['target_codes'] = target_codes
                    all_diagnostics['visits'].append(visit_diagnostics)

                # Append visit-end marker to close the visit.
                # NOTE(review): if no codes were generated, the opening
                # marker remains unclosed in decoder_input_ids.
                v_end_tensor = torch.tensor([[v_end_token_id]], dtype=torch.long).to(device)
                decoder_input_ids = torch.cat([decoder_input_ids, v_end_tensor], dim=1)

            # Check if we're approaching context limit
            if decoder_input_ids.shape[1] > 400:
                break

    # Compute statistics
    total_codes = sum(len(visit) for visit in all_visits)

    # Build result dictionary
    result = {
        'generated_visits': all_visits,
        'demographics': {'age': age, 'sex': sex},
        'num_visits': len(all_visits),
        'num_codes': total_codes,
        'target_structure': target_structure,
        'alpha': alpha
    }

    # Add diagnostics if enabled
    if diagnostic_mode:
        all_diagnostics['demographics'] = {'age': age, 'sex': sex}
        all_diagnostics['params'] = {
            'alpha': alpha,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p
        }
        all_diagnostics['generated_codes'] = all_visits
        result['diagnostics'] = all_diagnostics

        # Save diagnostics to file if path provided
        if diagnostic_path:
            import json
            import os
            os.makedirs(os.path.dirname(diagnostic_path), exist_ok=True)
            with open(diagnostic_path, 'w') as f:
                json.dump(all_diagnostics, f, indent=2)

    return result
diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py
new file mode 100644
index 000000000..d3697994e
--- /dev/null
+++ b/pyhealth/models/promptehr/model.py
@@ -0,0 +1,808 @@
+"""PromptEHR: BART-based generative model for synthetic EHR generation.
+
+This module provides the main PromptEHR model that combines demographic-conditioned
+prompts with BART encoder-decoder architecture for realistic patient record generation.
+
+Ported from pehr_scratch/prompt_bart_model.py (lines 16-276, excluding auxiliary losses).
+"""
+
+import os
+import random
+import sys
+from typing import Dict, List, Optional, Tuple
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pad_sequence
+
+# Temporarily hide torchvision so transformers skips the
+# image_utils → torchvision → PIL import chain (which fails in Colab
+# due to mixed-version Pillow files). PromptEHR only needs BART,
+# not any vision functionality from transformers.
+_tv = sys.modules.pop("torchvision", None)
+try:
+ from transformers import BartConfig, BartForConditionalGeneration
+ from transformers.modeling_outputs import Seq2SeqLMOutput
+finally:
+ if _tv is not None:
+ sys.modules["torchvision"] = _tv
+
+del _tv
+
+from pyhealth.models import BaseModel
+from .conditional_prompt import ConditionalPromptEncoder
+from .bart_encoder import PromptBartEncoder
+from .bart_decoder import PromptBartDecoder
+
+
class PromptBartModel(BartForConditionalGeneration):
    """BART model with demographic prompt conditioning for EHR generation.

    Extends HuggingFace's BartForConditionalGeneration with:
    1. Dual prompt encoders (separate for encoder/decoder)
    2. Demographic conditioning via learned prompt vectors
    3. Label smoothing for diverse generation

    This is the core generative model WITHOUT auxiliary losses (those caused
    mode collapse and are excluded per implementation decision D003).

    Args:
        config: BART configuration from transformers
        n_num_features: Number of continuous features (1 for age)
        cat_cardinalities: Category counts for categorical features ([2] for gender M/F)
        d_hidden: Intermediate reparameterization dimension (default: 128)
        prompt_length: Number of prompt vectors per feature (default: 1)

    Example:
        >>> from transformers import BartConfig
        >>> config = BartConfig.from_pretrained("facebook/bart-base")
        >>> model = PromptBartModel(
        ...     config,
        ...     n_num_features=1,  # age
        ...     cat_cardinalities=[2],  # gender (M/F)
        ...     d_hidden=128,
        ...     prompt_length=1
        ... )
        >>> # Forward pass with demographics
        >>> age = torch.randn(16, 1)  # [batch, 1]
        >>> gender = torch.randint(0, 2, (16, 1))  # [batch, 1]
        >>> input_ids = torch.randint(0, 1000, (16, 100))
        >>> labels = torch.randint(0, 1000, (16, 50))
        >>> output = model(
        ...     input_ids=input_ids,
        ...     labels=labels,
        ...     x_num=age,
        ...     x_cat=gender
        ... )
        >>> loss = output.loss
    """

    def __init__(
        self,
        config: BartConfig,
        n_num_features: Optional[int] = None,
        cat_cardinalities: Optional[list] = None,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Initialize PromptBART model with dual prompt conditioning.

        Args:
            config: BART configuration
            n_num_features: Number of continuous features (e.g., 1 for age)
            cat_cardinalities: Category counts for categorical features [n_genders]
            d_hidden: Intermediate reparameterization dimension (default: 128)
            prompt_length: Number of prompt vectors per feature (default: 1)
        """
        super().__init__(config)

        # Replace encoder and decoder with prompt-aware versions
        # (both share the same token embedding table as stock BART).
        self.model.encoder = PromptBartEncoder(config, self.model.shared)
        self.model.decoder = PromptBartDecoder(config, self.model.shared)

        # Add SEPARATE conditional prompt encoders for encoder and decoder
        # This provides stronger demographic conditioning than shared prompts (dual injection)
        if n_num_features is not None or cat_cardinalities is not None:
            # Encoder prompt encoder
            self.encoder_prompt_encoder = ConditionalPromptEncoder(
                n_num_features=n_num_features,
                cat_cardinalities=cat_cardinalities,
                hidden_dim=config.d_model,
                d_hidden=d_hidden,
                prompt_length=prompt_length
            )
            # Decoder prompt encoder (separate parameters for dual injection)
            self.decoder_prompt_encoder = ConditionalPromptEncoder(
                n_num_features=n_num_features,
                cat_cardinalities=cat_cardinalities,
                hidden_dim=config.d_model,
                d_hidden=d_hidden,
                prompt_length=prompt_length
            )
            self.num_prompts = self.encoder_prompt_encoder.get_num_prompts()
        else:
            # No demographics configured: behaves like plain BART.
            self.encoder_prompt_encoder = None
            self.decoder_prompt_encoder = None
            self.num_prompts = 0

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        x_num: Optional[torch.FloatTensor] = None,
        x_cat: Optional[torch.LongTensor] = None,
    ) -> Seq2SeqLMOutput:
        """Forward pass with demographic conditioning.

        Args:
            input_ids: [batch, seq_len] encoder input token IDs
            attention_mask: [batch, seq_len] encoder attention mask
            decoder_input_ids: [batch, tgt_len] decoder input token IDs
            decoder_attention_mask: [batch, tgt_len] decoder attention mask
            labels: [batch, tgt_len] target labels for loss computation
            x_num: [batch, n_num_features] continuous demographic features (e.g., age)
            x_cat: [batch, n_cat_features] categorical demographic features (e.g., gender)
            Other args: Standard BART arguments

        Returns:
            Seq2SeqLMOutput with:
                - loss: Cross-entropy loss with label smoothing=0.1
                - logits: [batch, tgt_len, vocab_size] prediction logits
                - past_key_values: Cached key-value states (if use_cache=True)
                - decoder_hidden_states: Decoder layer outputs (if output_hidden_states=True)
                - encoder_last_hidden_state: Final encoder output
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode demographic prompts separately for encoder and decoder
        # Only prepend prompts on first step (when no cache exists)
        encoder_prompt_embeds = None
        decoder_prompt_embeds = None
        if (x_num is not None or x_cat is not None) and past_key_values is None:
            if self.encoder_prompt_encoder is not None:
                encoder_prompt_embeds = self.encoder_prompt_encoder(x_num=x_num, x_cat=x_cat)
            if self.decoder_prompt_encoder is not None:
                decoder_prompt_embeds = self.decoder_prompt_encoder(x_num=x_num, x_cat=x_cat)

        # Prepare decoder input IDs (shift labels right for teacher forcing)
        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Encoder forward pass (with encoder prompts)
        if encoder_outputs is None:
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                inputs_prompt_embeds=encoder_prompt_embeds,  # Encoder-specific prompts
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # Extend encoder attention mask for prompts: prompt positions are always
        # attendable (mask value 1) and sit BEFORE the real tokens.
        encoder_attention_mask = attention_mask
        if encoder_prompt_embeds is not None and attention_mask is not None:
            batch_size, n_prompts = encoder_prompt_embeds.shape[:2]
            prompt_mask = torch.ones(batch_size, n_prompts, dtype=attention_mask.dtype, device=attention_mask.device)
            encoder_attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        # Decoder forward pass (with decoder prompts)
        # NOTE(review): decoder_attention_mask is passed through UNextended even
        # when decoder prompts are prepended — presumably PromptBartDecoder
        # accounts for the extra prompt positions internally; confirm.
        decoder_outputs = self.model.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=encoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            inputs_prompt_embeds=decoder_prompt_embeds,  # Decoder-specific prompts
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Language modeling head
        lm_logits = self.lm_head(decoder_outputs[0])

        # If decoder prompts were prepended, slice them off before loss computation
        if decoder_prompt_embeds is not None and labels is not None:
            # decoder_outputs[0] shape: [batch, n_prompts + seq_len, hidden_dim]
            # We only want logits for the actual sequence positions
            n_prompts = decoder_prompt_embeds.shape[1]
            lm_logits = lm_logits[:, n_prompts:, :]  # Remove prompt positions

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Label smoothing = 0.1 to prevent overconfidence and encourage diversity
            # Softens target distributions: 90% on correct token, 10% distributed to alternatives
            # (CrossEntropyLoss keeps its default ignore_index=-100, so padded
            # label positions do not contribute to the loss.)
            loss_fct = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fct(lm_logits.reshape(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        x_num=None,
        x_cat=None,
        **kwargs
    ):
        """Prepare inputs for autoregressive generation.

        Args:
            decoder_input_ids: [batch, cur_len] current decoder input IDs
            past_key_values: Cached key-value states from previous steps
            x_num: [batch, n_num_features] continuous demographics (passed through)
            x_cat: [batch, n_cat_features] categorical demographics (passed through)
            Other args: Standard BART generation arguments

        Returns:
            Dictionary of inputs for next generation step
        """
        # Cut decoder_input_ids if past is used (only need last token)
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            # input_ids is None because the encoder has already run:
            # encoder_outputs is reused on every decoding step.
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
            "x_num": x_num,  # Pass demographics through generation
            "x_cat": x_cat,
        }

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids,
        expand_size=1,
        is_encoder_decoder=True,
        attention_mask=None,
        encoder_outputs=None,
        x_num=None,
        x_cat=None,
        **model_kwargs,
    ):
        """Expand inputs for beam search or multiple samples.

        NOTE(review): this overrides a transformers private hook; its
        signature matched transformers 4.x at time of writing — newer
        transformers releases changed this API, so confirm against the
        pinned version before upgrading.

        Args:
            input_ids: [batch, seq_len] input token IDs
            expand_size: Number of beams/samples per input
            x_num: [batch, n_num_features] continuous demographics
            x_cat: [batch, n_cat_features] categorical demographics
            Other args: Standard expansion arguments

        Returns:
            Expanded input_ids and model_kwargs
        """
        # Row i of the batch is repeated expand_size times, contiguously.
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if encoder_outputs is not None:
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs

        # Expand demographics for beam search
        if x_num is not None:
            model_kwargs["x_num"] = x_num.index_select(0, expanded_return_idx)

        if x_cat is not None:
            model_kwargs["x_cat"] = x_cat.index_select(0, expanded_return_idx)

        return input_ids, model_kwargs
+
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Shift target token IDs one position to the right for teacher forcing.

    Produces decoder inputs from labels: ``decoder_start_token_id`` (BOS)
    fills the first slot, the last label of each row is dropped, and any
    ``-100`` ignore-index values inherited from the labels are rewritten to
    ``pad_token_id`` so the decoder never embeds an invalid ID.

    Args:
        input_ids: [batch, seq_len] target token IDs.
        pad_token_id: ID of the padding token.
        decoder_start_token_id: ID of the decoder start token (BOS).

    Returns:
        [batch, seq_len] shifted token IDs with BOS prepended.

    Raises:
        ValueError: If ``pad_token_id`` is None.
    """
    if pad_token_id is None:
        raise ValueError("config.pad_token_id must be defined for sequence generation")

    batch_size = input_ids.shape[0]
    start_column = input_ids.new_full((batch_size, 1), decoder_start_token_id)
    shifted = torch.cat([start_column, input_ids[:, :-1]], dim=1)

    # Labels use -100 as the loss ignore index; replace it with the pad token.
    return shifted.masked_fill(shifted == -100, pad_token_id)
+
+
+class _PromptEHRVocab:
+ """Internal vocabulary bridging NestedSequenceProcessor indices to BART token IDs.
+
+ Token layout (7 special tokens + N diagnosis codes):
+ 0 = (BartConfig.pad_token_id)
+ 1 = (BartConfig.bos_token_id / decoder_start_token_id)
+ 2 = (BartConfig.eos_token_id)
+ 3 =
+ 4 = (visit start)
+ 5 = (visit end)
+ 6 = (sequence terminator)
+ 7+ = diagnosis codes
+
+ NestedSequenceProcessor uses pad=0, unk=1, codes=2+.
+ Mapping: processor_idx i -> BART token i + 5 (for i >= 2).
+ Total BART vocab size = processor.vocab_size() + 5.
+
+ Args:
+ code_vocab (dict): Mapping of code string to processor index, as
+ returned by ``NestedSequenceProcessor.code_vocab``. Must include
+ ``""`` -> 0 and ``""`` -> 1.
+
+ Examples:
+ >>> vocab = _PromptEHRVocab({"": 0, "": 1, "428": 2, "410": 3})
+ >>> isinstance(vocab, _PromptEHRVocab)
+ True
+ >>> vocab.total_size
+ 9
+ """
+
+ PAD = 0
+ BOS = 1
+ EOS = 2
+ UNK = 3
+ VISIT_START = 4
+ VISIT_END = 5
+ SEQ_END = 6
+ CODE_OFFSET = 7
+
+ def __init__(self, code_vocab: dict):
+ """Build vocab from NestedSequenceProcessor.code_vocab dict."""
+ self._bart_to_code: Dict[int, str] = {}
+ for code, pid in code_vocab.items():
+ if pid >= 2: # skip and
+ self._bart_to_code[pid + 5] = code
+ self.total_size = len(code_vocab) + 5 # 7 special - 2 reused + N codes
+
+ def encode_visits(self, visits_tensor: torch.Tensor) -> List[int]:
+ """Encode a processed [n_visits, max_codes] LongTensor to a token ID list.
+
+ Args:
+ visits_tensor (torch.Tensor): LongTensor of shape
+ ``(n_visits, max_codes_per_visit)`` from NestedSequenceProcessor.
+ Values 0 = pad, 1 = unk, 2+ = code index.
+
+ Returns:
+ list of int: Token IDs in format
+ ``[, code, ..., , , ..., , ]``.
+ """
+ tokens = []
+ for visit in visits_tensor:
+ codes_in_visit = [
+ int(c.item()) + 5 # processor idx 2+ → BART idx 7+
+ for c in visit
+ if c.item() >= 2 # skip pad and unk
+ ]
+ if codes_in_visit:
+ tokens.append(self.VISIT_START)
+ tokens.extend(codes_in_visit)
+ tokens.append(self.VISIT_END)
+ tokens.append(self.SEQ_END)
+ return tokens
+
+ def decode_tokens(self, token_ids: List[int]) -> List[List[str]]:
+ """Decode a generated token ID list back to visit structure.
+
+ Args:
+ token_ids (list of int): Raw generated token IDs from BART.
+
+ Returns:
+ list of list of str: Decoded diagnosis code strings per visit.
+ """
+ visits: List[List[str]] = []
+ current_visit: List[str] = []
+ in_visit = False
+ for tid in token_ids:
+ if tid in (self.PAD, self.BOS, self.EOS):
+ continue # skip framing tokens (BOS is first in generate output)
+ if tid == self.SEQ_END:
+ break
+ if tid == self.VISIT_START:
+ in_visit = True
+ current_visit = []
+ elif tid == self.VISIT_END:
+ if in_visit:
+ visits.append(current_visit)
+ in_visit = False
+ elif in_visit and tid >= self.CODE_OFFSET:
+ code = self._bart_to_code.get(tid)
+ if code:
+ current_visit.append(code)
+ if in_visit and current_visit:
+ visits.append(current_visit)
+ return visits
+
+
def _promptehr_collate_fn(batch):
    """Collate PromptEHR training samples, padding token sequences in a batch.

    ``input_ids`` are padded with the vocab PAD token and ``labels`` with
    ``-100`` (the cross-entropy ignore index). The attention mask marks the
    non-padded input positions; demographics are concatenated along the batch
    dimension.

    Args:
        batch (list of dict): Each dict has ``"input_ids"``, ``"labels"``,
            ``"x_num"``, and ``"x_cat"`` tensors.

    Returns:
        dict: Batched tensors ready for ``PromptBartModel.forward()``.
    """
    pad_id = _PromptEHRVocab.PAD
    padded_inputs = pad_sequence(
        [sample["input_ids"] for sample in batch],
        batch_first=True,
        padding_value=pad_id,
    )
    padded_labels = pad_sequence(
        [sample["labels"] for sample in batch],
        batch_first=True,
        padding_value=-100,
    )
    return {
        "input_ids": padded_inputs,
        "attention_mask": padded_inputs.ne(pad_id).long(),
        "labels": padded_labels,
        "x_num": torch.cat([sample["x_num"] for sample in batch], dim=0),
        "x_cat": torch.cat([sample["x_cat"] for sample in batch], dim=0),
    }
+
+
class PromptEHR(BaseModel):
    """PromptEHR: demographic-conditioned BART model for synthetic EHR generation.

    Wraps ``PromptBartModel`` (HuggingFace BART with dual prompt conditioning)
    in a PyHealth ``BaseModel`` interface. Training is handled by a HuggingFace
    ``Trainer`` loop; generation is autoregressive token-by-token decoding.

    Demographics (age as continuous, gender as categorical) are injected via
    learned prompt vectors prepended to both encoder and decoder hidden states.

    Args:
        dataset (SampleDataset): PyHealth sample dataset produced by
            ``set_task(promptehr_generation_mimic3_fn)``. Must have
            ``input_processors["visits"]`` (NestedSequenceProcessor).
        n_num_features (int): Continuous demographic features (1 for age).
            Default: 1.
        cat_cardinalities (list of int): Category counts per categorical
            feature ([2] for binary gender M/F). Default: [2].
        d_hidden (int): Reparameterization dimension for prompt encoder.
            Default: 128.
        prompt_length (int): Number of prompt vectors per feature. Default: 1.
        bart_config_name (str): Pretrained BART config to use.
            Default: ``"facebook/bart-base"``.
        epochs (int): Training epochs. Default: 20.
        batch_size (int): Training batch size. Default: 16.
        lr (float): AdamW learning rate. Default: 1e-5.
        warmup_steps (int): Linear warmup steps. Default: 1000.
        max_seq_length (int): Maximum token sequence length. Default: 512.
        save_dir (str): Directory for checkpoints. Default: ``"./save/"``.

    Examples:
        >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset
        >>> samples = [
        ...     {"patient_id": "p1", "visits": [["428", "427"], ["410"]], "age": 65.0, "gender": 0},
        ...     {"patient_id": "p2", "visits": [["250"], ["401", "272"]], "age": 52.0, "gender": 1},
        ... ]
        >>> dataset = InMemorySampleDataset(
        ...     samples=samples,
        ...     input_schema={"visits": "nested_sequence"},
        ...     output_schema={},
        ... )
        >>> model = PromptEHR(dataset, d_hidden=32, prompt_length=1)
        >>> isinstance(model, PromptEHR)
        True
    """

    def __init__(
        self,
        dataset,
        n_num_features: int = 1,
        cat_cardinalities: Optional[list] = None,
        d_hidden: int = 128,
        prompt_length: int = 1,
        bart_config_name: "Union[str, BartConfig]" = "facebook/bart-base",
        epochs: int = 20,
        batch_size: int = 16,
        lr: float = 1e-5,
        warmup_steps: int = 1000,
        max_seq_length: int = 512,
        save_dir: str = "./save/",
    ):
        """Initialize PromptEHR with vocab derived from the dataset processor."""
        super().__init__(dataset)

        self.mode = None  # skip discriminative evaluation
        self.save_dir = save_dir
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.max_seq_length = max_seq_length
        self._demo_pool: List[tuple] = []  # (age, gender) pairs from training data

        if cat_cardinalities is None:
            cat_cardinalities = [2]

        # Derive vocab from the dataset's NestedSequenceProcessor
        visits_processor = dataset.input_processors["visits"]
        self._vocab = _PromptEHRVocab(visits_processor.code_vocab)
        bart_vocab_size = self._vocab.total_size

        # Configure BART with our custom vocab and special token IDs
        if isinstance(bart_config_name, str):
            bart_config = BartConfig.from_pretrained(bart_config_name)
        else:
            # Accept a BartConfig object directly (useful for tiny test models)
            bart_config = bart_config_name
        bart_config.vocab_size = bart_vocab_size
        bart_config.pad_token_id = _PromptEHRVocab.PAD
        bart_config.bos_token_id = _PromptEHRVocab.BOS
        bart_config.eos_token_id = _PromptEHRVocab.EOS
        bart_config.decoder_start_token_id = _PromptEHRVocab.BOS
        bart_config.forced_eos_token_id = _PromptEHRVocab.SEQ_END
        # NOTE(review): dropout raised to 0.3 on all three sites (BART default
        # is lower) — presumably to regularize toward diverse generation;
        # confirm against the training experiments before changing.
        bart_config.dropout = 0.3
        bart_config.attention_dropout = 0.3
        bart_config.activation_dropout = 0.3

        self.bart_model = PromptBartModel(
            config=bart_config,
            n_num_features=n_num_features,
            cat_cardinalities=cat_cardinalities,
            d_hidden=d_hidden,
            prompt_length=prompt_length,
        )

    def forward(self, **kwargs) -> Dict:
        """Not implemented — PromptEHR is a generative model without a discriminative forward.

        Raises:
            NotImplementedError: Always. Use ``train_model`` and
                ``synthesize_dataset`` instead.
        """
        raise NotImplementedError(
            "PromptEHR is a generative model. Use train_model() and synthesize_dataset()."
        )

    def train_model(self, train_dataset, val_dataset=None) -> None:
        """Train PromptEHR using a HuggingFace Trainer loop.

        Converts PyHealth SampleDataset samples to BART token sequences and
        trains with HuggingFace ``Trainer``. Demographics (age, gender) are
        passed as ``x_num`` / ``x_cat`` via a custom data collator.

        Named ``train_model`` (not ``train``) to avoid shadowing
        ``nn.Module.train()``.

        Args:
            train_dataset (SampleDataset): Training set with ``"visits"``,
                ``"age"``, and ``"gender"`` fields.
            val_dataset (SampleDataset, optional): Validation set for loss
                monitoring. Default: None.
        """
        from torch.utils.data import Dataset as TorchDataset
        from transformers import Trainer, TrainingArguments

        vocab = self._vocab
        max_len = self.max_seq_length

        class _EHRDataset(TorchDataset):
            # Thin adapter: maps each PyHealth sample to the tensor dict that
            # _promptehr_collate_fn expects.
            def __init__(self, samples):
                self._samples = list(samples)

            def __len__(self):
                return len(self._samples)

            def __getitem__(self, idx):
                s = self._samples[idx]
                # s["visits"] is the processed [n_visits, max_codes] tensor
                # from NestedSequenceProcessor (encode_visits iterates rows).
                tokens = vocab.encode_visits(s["visits"])
                if len(tokens) > max_len:
                    # Truncate but keep SEQ_END as the final token.
                    tokens = tokens[:max_len - 1] + [vocab.SEQ_END]
                age = float(s.get("age", 60.0))
                gender = int(s.get("gender", 0))
                return {
                    "input_ids": torch.tensor(tokens, dtype=torch.long),
                    "labels": torch.tensor(tokens, dtype=torch.long),
                    "x_num": torch.tensor([[age]], dtype=torch.float32),
                    "x_cat": torch.tensor([[gender]], dtype=torch.long),
                }

        train_samples = list(train_dataset)
        # Store demographics pool for synthesize_dataset sampling
        self._demo_pool = [
            (float(s.get("age", 60.0)), int(s.get("gender", 0)))
            for s in train_samples
        ]

        os.makedirs(self.save_dir, exist_ok=True)
        training_args = TrainingArguments(
            output_dir=self.save_dir,
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            learning_rate=self.lr,
            warmup_steps=self.warmup_steps,
            save_strategy="epoch",
            logging_steps=50,
            remove_unused_columns=False,  # essential: keeps x_num/x_cat
            use_cpu=not torch.cuda.is_available(),
            report_to="none",
        )

        trainer = Trainer(
            model=self.bart_model,
            args=training_args,
            train_dataset=_EHRDataset(train_samples),
            eval_dataset=_EHRDataset(list(val_dataset)) if val_dataset else None,
            data_collator=_promptehr_collate_fn,
        )
        trainer.train()

        self.save_model(os.path.join(self.save_dir, "checkpoint.pt"))

    def synthesize_dataset(
        self, num_samples: int, random_sampling: bool = True
    ) -> List[Dict]:
        """Generate a synthetic patient dataset.

        Samples demographics from the training data distribution (if available)
        and generates autoregressive token sequences via BART. Each sequence is
        decoded back to a nested list of diagnosis code strings.

        Args:
            num_samples (int): Number of synthetic patients to generate.
            random_sampling (bool): If True, uses multinomial sampling with
                ``temperature=0.7, top_p=0.95``. If False, uses greedy decoding.
                Default: True.

        Returns:
            list of dict: One record per synthetic patient. Each dict has:
                ``"patient_id"`` (str): unique identifier, e.g. ``"synthetic_0"``.
                ``"visits"`` (list of list of str): decoded code strings per visit.
        """
        self.bart_model.eval()
        # Use bart_model's device, not self.device — HuggingFace Trainer
        # moves bart_model to GPU but doesn't move the parent PromptEHR module.
        device = next(self.bart_model.parameters()).device

        results = []
        with torch.no_grad():
            for i in range(num_samples):
                # Sample demographics from training distribution (or defaults)
                if self._demo_pool:
                    age, gender = self._demo_pool[
                        random.randrange(len(self._demo_pool))
                    ]
                else:
                    age, gender = 60.0, 0

                x_num = torch.tensor([[age]], dtype=torch.float32, device=device)
                x_cat = torch.tensor([[gender]], dtype=torch.long, device=device)

                # PAD token as minimal encoder input; prompts carry the signal
                encoder_input = torch.tensor(
                    [[_PromptEHRVocab.PAD]], dtype=torch.long, device=device
                )

                output_ids = self.bart_model.generate(
                    input_ids=encoder_input,
                    attention_mask=torch.ones_like(encoder_input),
                    x_num=x_num,
                    x_cat=x_cat,
                    max_length=self.max_seq_length,
                    num_beams=1,
                    early_stopping=False,
                    do_sample=random_sampling,
                    temperature=0.7 if random_sampling else 1.0,
                    top_p=0.95 if random_sampling else 1.0,
                    pad_token_id=_PromptEHRVocab.PAD,
                    eos_token_id=_PromptEHRVocab.SEQ_END,
                    bos_token_id=_PromptEHRVocab.BOS,
                )

                visits = self._vocab.decode_tokens(output_ids[0].tolist())
                results.append({
                    "patient_id": f"synthetic_{i}",
                    "visits": visits,
                })

        return results

    def save_model(self, path: str) -> None:
        """Save model weights and vocab to a checkpoint file.

        The checkpoint stores the state dict, the pickled ``_PromptEHRVocab``
        object, and the BART config.

        Args:
            path (str): Destination file path (e.g. ``"./save/checkpoint.pt"``).

        Examples:
            >>> import tempfile, os
            >>> tmpdir = tempfile.mkdtemp()
            >>> model.save_model(os.path.join(tmpdir, "ckpt.pt"))
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(
            {
                "model": self.bart_model.state_dict(),
                "vocab": self._vocab,
                "bart_config": self.bart_model.config,
            },
            path,
        )

    def load_model(self, path: str) -> None:
        """Load model weights from a checkpoint saved by ``save_model``.

        Args:
            path (str): Path to checkpoint file produced by ``save_model``.

        Examples:
            >>> model.load_model("./save/checkpoint.pt")
        """
        # weights_only=False is required because the checkpoint pickles the
        # vocab object — this executes arbitrary pickle on load, so only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.bart_model.load_state_dict(checkpoint["model"])
        if "vocab" in checkpoint:
            self._vocab = checkpoint["vocab"]
diff --git a/pyhealth/models/promptehr/utils.py b/pyhealth/models/promptehr/utils.py
new file mode 100644
index 000000000..43e13ca83
--- /dev/null
+++ b/pyhealth/models/promptehr/utils.py
@@ -0,0 +1,29 @@
+"""Utility functions and classes for PromptEHR.
+
+This module contains:
+ - VisitStructureSampler: Samples realistic visit structures for generation
+ - Data collation functions
+ - Helper utilities
+"""
+
+import torch
+import torch.nn as nn
+
+
class VisitStructureSampler:
    """Placeholder for the visit-structure sampler (port in progress).

    The finished component samples realistic visit structures from training
    data to curb over-generation (reduced codes/patient from 18.1 to 11.97,
    ~34%). Until the port from ``pehr_scratch/visit_structure_sampler.py``
    lands, every entry point raises ``NotImplementedError``.

    Args:
        **kwargs: Reserved; accepted but unused until the port completes.
    """

    def __init__(self, **kwargs):
        # TODO: Port from ~/pehr_scratch/visit_structure_sampler.py
        raise NotImplementedError("VisitStructureSampler porting in progress")

    def sample(self, **kwargs):
        """Sample a visit structure (not yet implemented)."""
        raise NotImplementedError("VisitStructureSampler porting in progress")
diff --git a/pyhealth/models/promptehr/visit_sampler.py b/pyhealth/models/promptehr/visit_sampler.py
new file mode 100644
index 000000000..03efbf78f
--- /dev/null
+++ b/pyhealth/models/promptehr/visit_sampler.py
@@ -0,0 +1,121 @@
+"""
+Sample realistic visit structures from real MIMIC-III data distributions.
+
+This module provides functionality to sample the number of visits per patient
+and the number of diagnosis codes per visit, matching the empirical distributions
+observed in real EHR data.
+"""
+import numpy as np
+from typing import List
+
+
class VisitStructureSampler:
    """Sample realistic visit and code count structures from training data."""

    def __init__(self, patient_records: List, seed: int = 42):
        """Initialize sampler with empirical distributions from training data.

        Args:
            patient_records: List of patient records from the training set.
                Each record should expose its visits either as a ``visits``
                attribute or a ``'visits'`` dict key (list of visit code
                lists); records with neither are skipped.
            seed: Random seed for reproducibility.

        Raises:
            ValueError: If no usable records (or no non-empty visits) are
                found. Failing fast here replaces the cryptic numpy errors
                that ``percentile``/``choice`` would otherwise raise later
                on empty arrays.
        """
        self.rng = np.random.RandomState(seed)

        # Extract empirical distributions
        num_visits_per_patient = []
        codes_per_visit_all = []

        for patient in patient_records:
            visits = self._extract_visits(patient)
            if visits is None:
                continue

            num_visits_per_patient.append(len(visits))
            for visit in visits:
                if len(visit) > 0:  # Only include non-empty visits
                    codes_per_visit_all.append(len(visit))

        # Guard against empty distributions before any numpy statistics.
        if not num_visits_per_patient or not codes_per_visit_all:
            raise ValueError(
                "patient_records contained no usable visits; cannot build "
                "empirical visit/code distributions"
            )

        # Convert to numpy arrays
        self.num_visits_per_patient = np.array(num_visits_per_patient)
        self.codes_per_visit_all = np.array(codes_per_visit_all)

        # Compute statistics for logging
        self.stats = {
            'visits_mean': np.mean(self.num_visits_per_patient),
            'visits_median': np.median(self.num_visits_per_patient),
            'visits_90th': np.percentile(self.num_visits_per_patient, 90),
            'codes_mean': np.mean(self.codes_per_visit_all),
            'codes_median': np.median(self.codes_per_visit_all),
            'codes_90th': np.percentile(self.codes_per_visit_all, 90),
            'codes_95th': np.percentile(self.codes_per_visit_all, 95),
        }

    @staticmethod
    def _extract_visits(patient):
        """Return a record's visit list, or None if the record has none.

        Handles both object-like (``.visits`` attribute) and dict-like
        (``'visits'`` key) records.
        """
        if hasattr(patient, 'visits'):
            return patient.visits
        if isinstance(patient, dict) and 'visits' in patient:
            return patient['visits']
        return None

    def sample_num_visits(self) -> int:
        """Sample number of visits from empirical distribution.

        Returns:
            Number of visits (>= 0).
        """
        return int(self.rng.choice(self.num_visits_per_patient))

    def sample_codes_per_visit(self, n_visits: int) -> List[int]:
        """Sample number of codes for each visit from empirical distribution.

        Args:
            n_visits: Number of visits to sample code counts for.

        Returns:
            List of integers representing codes per visit.
        """
        if n_visits == 0:
            return []

        # Sample with replacement from empirical distribution
        codes_counts = self.rng.choice(self.codes_per_visit_all, size=n_visits, replace=True)
        return codes_counts.tolist()

    def sample_structure(self) -> dict:
        """Sample complete visit structure (visits + codes per visit).

        Returns:
            Dictionary with:
                - 'num_visits': int (number of visits)
                - 'codes_per_visit': List[int] (codes for each visit)
        """
        num_visits = self.sample_num_visits()
        codes_per_visit = self.sample_codes_per_visit(num_visits)

        return {
            'num_visits': num_visits,
            'codes_per_visit': codes_per_visit
        }

    def get_statistics(self) -> dict:
        """Get statistics about the underlying distributions.

        Returns:
            Dictionary (shallow copy) with mean/median/percentile statistics.
        """
        return self.stats.copy()

    def __repr__(self) -> str:
        """String representation showing distribution statistics."""
        return (
            f"VisitStructureSampler(\n"
            f"  Visits/patient: mean={self.stats['visits_mean']:.2f}, "
            f"median={self.stats['visits_median']:.0f}, "
            f"90th%={self.stats['visits_90th']:.0f}\n"
            f"  Codes/visit: mean={self.stats['codes_mean']:.2f}, "
            f"median={self.stats['codes_median']:.0f}, "
            f"90th%={self.stats['codes_90th']:.0f}, "
            f"95th%={self.stats['codes_95th']:.0f}\n"
            f")"
        )
diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py
index 283354f80..f24f1fa0c 100644
--- a/pyhealth/processors/__init__.py
+++ b/pyhealth/processors/__init__.py
@@ -18,7 +18,12 @@ def get_processor(name: str):
 # Import all processors so they register themselves
-from .image_processor import ImageProcessor
+from .base_processor import FeatureProcessor
+try:
+    from .image_processor import ImageProcessor
+    _has_image_processor = True
+except (ImportError, RuntimeError):
+    _has_image_processor = False  # imaging deps (PIL/torchvision) missing or broken at import
 from .label_processor import (
     BinaryLabelProcessor,
     MultiClassLabelProcessor,
@@ -44,16 +49,22 @@ def get_processor(name: str):
 from .tensor_processor import TensorProcessor
 from .text_processor import TextProcessor
 from .timeseries_processor import TimeseriesProcessor
-from .time_image_processor import TimeImageProcessor
+try:
+    from .time_image_processor import TimeImageProcessor
+    _has_time_image_processor = True
+except (ImportError, RuntimeError):
+    _has_time_image_processor = False  # imaging deps (PIL/torchvision) missing or broken at import
 from .audio_processor import AudioProcessor
 from .ignore_processor import IgnoreProcessor
 from .tuple_time_text_processor import TupleTimeTextProcessor
-# Expose public API
+# Expose public API — optional processors only listed if successfully imported
 __all__ = [
     "FeatureProcessor",
-    "ImageProcessor",
-    "LabelProcessor",
+    "BinaryLabelProcessor",
+    "MultiClassLabelProcessor",
+    "MultiLabelProcessor",
+    "RegressionLabelProcessor",
     "MultiHotProcessor",
     "NestedFloatsProcessor",
     "NestedSequenceProcessor",
@@ -65,7 +76,10 @@ def get_processor(name: str):
     "TensorProcessor",
     "TextProcessor",
     "TimeseriesProcessor",
-    "TimeImageProcessor",
     "AudioProcessor",
     "TupleTimeTextProcessor",
 ]
+if _has_image_processor:
+    __all__.append("ImageProcessor")
+if _has_time_image_processor:
+    __all__.append("TimeImageProcessor")
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 2f4294a19..f9834777d 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -1,14 +1,23 @@
 from .base_task import BaseTask
 from .benchmark_ehrshot import BenchmarkEHRShot
+from .ehr_generation import (
+    PromptEHRGenerationMIMIC3,
+    PromptEHRGenerationMIMIC4,
+    promptehr_generation_mimic3_fn,
+    promptehr_generation_mimic4_fn,
+)
 from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction
 from .bmd_hs_disease_classification import BMDHSDiseaseClassification
-from .cardiology_detect import (
-    cardiology_isAD_fn,
-    cardiology_isAR_fn,
-    cardiology_isBBBFB_fn,
-    cardiology_isCD_fn,
-    cardiology_isWA_fn,
-)
+try:
+    from .cardiology_detect import (
+        cardiology_isAD_fn,
+        cardiology_isAR_fn,
+        cardiology_isBBBFB_fn,
+        cardiology_isCD_fn,
+        cardiology_isWA_fn,
+    )
+except ImportError:
+    pass  # scipy unavailable; cardiology tasks not registered
 from .chestxray14_binary_classification import ChestXray14BinaryClassification
 from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification
 from .covid19_cxr_classification import COVID19CXRClassification
@@ -21,8 +30,14 @@
     drug_recommendation_mimic4_fn,
     drug_recommendation_omop_fn,
 )
-from .EEG_abnormal import EEG_isAbnormal_fn
-from .EEG_events import EEG_events_fn
+try:
+    from .EEG_abnormal import EEG_isAbnormal_fn
+except ImportError:
+    pass  # mne unavailable; EEG-abnormal task not registered
+try:
+    from .EEG_events import EEG_events_fn
+except ImportError:
+    pass  # mne unavailable; EEG-events task not registered
 from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4
 from .length_of_stay_prediction import (
     LengthOfStayPredictioneICU,
@@ -53,16 +68,25 @@
     ReadmissionPredictionMIMIC4,
     ReadmissionPredictionOMOP,
 )
-from .sleep_staging import (
-    sleep_staging_isruc_fn,
-    sleep_staging_shhs_fn,
-    sleep_staging_sleepedf_fn,
-)
-from .sleep_staging_v2 import SleepStagingSleepEDF
-from .temple_university_EEG_tasks import (
-    EEGEventsTUEV,
-    EEGAbnormalTUAB
-)
+try:
+    from .sleep_staging import (
+        sleep_staging_isruc_fn,
+        sleep_staging_shhs_fn,
+        sleep_staging_sleepedf_fn,
+    )
+except ImportError:
+    pass  # mne unavailable; sleep-staging tasks not registered
+try:
+    from .sleep_staging_v2 import SleepStagingSleepEDF
+except ImportError:
+    pass  # mne unavailable; sleep-staging-v2 task not registered
+try:
+    from .temple_university_EEG_tasks import (
+        EEGEventsTUEV,
+        EEGAbnormalTUAB
+    )
+except ImportError:
+    pass  # mne unavailable; TUEV/TUAB tasks not registered
 from .variant_classification import (
     MutationPathogenicityPrediction,
     VariantClassificationClinVar,
diff --git a/pyhealth/tasks/ehr_generation.py b/pyhealth/tasks/ehr_generation.py
new file mode 100644
index 000000000..788dd0351
--- /dev/null
+++ b/pyhealth/tasks/ehr_generation.py
@@ -0,0 +1,142 @@
+"""Task function for PromptEHR synthetic EHR generation.
+
+Provides task classes for training PromptEHR on MIMIC-III and MIMIC-IV datasets.
+Demographics (age, gender) are extracted alongside visit codes because PromptEHR
+conditions generation on patient-level continuous and categorical features.
+"""
+
+from datetime import datetime
+from typing import Dict, List
+
+import polars as pl
+
+from pyhealth.tasks.base_task import BaseTask
+
+
+class PromptEHRGenerationMIMIC3(BaseTask):
+ """Task for PromptEHR synthetic data generation using MIMIC-III.
+
+ PromptEHR is a BART-based seq2seq model that conditions generation on
+ patient demographics (age, gender) via learned prompt vectors. This task
+ extracts per-admission ICD-9 diagnosis codes grouped into a nested visit
+ list, along with patient demographics for conditioning.
+
+ Patients with fewer than 2 admissions containing diagnosis codes are
+ excluded.
+
+ Attributes:
+ task_name (str): Unique task identifier.
+ input_schema (dict): ``"visits"`` uses ``"nested_sequence"`` encoding
+ (list of lists of code strings).
+ output_schema (dict): Empty — generative task, no conditioning label.
+ _icd_col (str): Polars column path for ICD codes in MIMIC-III.
+
+ Examples:
+ >>> fn = PromptEHRGenerationMIMIC3()
+ >>> fn.task_name
+ 'PromptEHRGenerationMIMIC3'
+ """
+
+ task_name = "PromptEHRGenerationMIMIC3"
+ input_schema = {"visits": "nested_sequence"}
+ output_schema = {}
+ _icd_col = "diagnoses_icd/icd9_code"
+
+ def __call__(self, patient) -> List[Dict]:
+ """Extract visit sequences and demographics for a single patient.
+
+ Diagnosis codes are grouped per admission into a nested list. Age is
+ computed as years between date-of-birth and the first admission date.
+ Gender is encoded as 0 (male) or 1 (female). Defaults of
+ ``age=60.0, gender=0`` are used when demographics are unavailable.
+
+ Args:
+ patient: A PyHealth Patient object with admissions and
+ diagnoses_icd event data.
+
+ Returns:
+ list of dict: A single-element list, or empty list if fewer
+ than 2 visits have diagnosis codes. Each dict contains:
+ ``"patient_id"`` (str): patient identifier.
+ ``"visits"`` (list of list of str): ICD codes per visit.
+ ``"age"`` (float): patient age at first admission in years.
+ ``"gender"`` (int): 0 for male, 1 for female.
+ """
+ admissions = list(patient.get_events(event_type="admissions"))
+ if len(admissions) < 2:
+ return []
+
+ # --- Demographics ---
+ age = 60.0
+ gender = 0
+ patients_df = patient.get_events(event_type="patients", return_df=True)
+ if len(patients_df) > 0:
+ if "patients/gender" in patients_df.columns:
+ gender_val = patients_df["patients/gender"][0]
+ if gender_val == "F":
+ gender = 1
+ if "patients/dob" in patients_df.columns and admissions:
+ dob_val = patients_df["patients/dob"][0]
+ first_admit_ts = admissions[0].timestamp
+ if dob_val is not None and first_admit_ts is not None:
+ # dob_val may be a date/datetime or a string
+ if hasattr(dob_val, "year"):
+ dob_dt = datetime(dob_val.year, dob_val.month, dob_val.day)
+ else:
+ dob_dt = datetime.strptime(str(dob_val)[:10], "%Y-%m-%d")
+ raw_age = (first_admit_ts - dob_dt).days / 365.25
+ # Clamp: MIMIC-III shifts >89-year-old DOBs far into the
+ # past; treat those as 90.
+ age = float(min(90.0, max(0.0, raw_age)))
+
+ # --- Visit codes ---
+ visits = []
+ for adm in admissions:
+ codes = (
+ patient.get_events(
+ event_type="diagnoses_icd",
+ filters=[("hadm_id", "==", adm.hadm_id)],
+ return_df=True,
+ )
+ .select(pl.col(self._icd_col))
+ .to_series()
+ .drop_nulls()
+ .to_list()
+ )
+ if codes:
+ visits.append(codes)
+
+ if len(visits) < 2:
+ return []
+
+ return [{
+ "patient_id": patient.patient_id,
+ "visits": visits,
+ "age": age,
+ "gender": gender,
+ }]
+
+
+class PromptEHRGenerationMIMIC4(PromptEHRGenerationMIMIC3):
+ """Task for PromptEHR synthetic data generation using MIMIC-IV.
+
+ Inherits all logic from :class:`PromptEHRGenerationMIMIC3`. Overrides only
+ the task name and ICD code column to match the MIMIC-IV schema, where the
+ column is ``icd_code`` (unversioned) rather than ``icd9_code``.
+
+ Attributes:
+ task_name (str): Unique task identifier.
+ _icd_col (str): Polars column path for ICD codes in MIMIC-IV.
+
+ Examples:
+ >>> fn = PromptEHRGenerationMIMIC4()
+ >>> fn.task_name
+ 'PromptEHRGenerationMIMIC4'
+ """
+
+ task_name = "PromptEHRGenerationMIMIC4"
+ _icd_col = "diagnoses_icd/icd_code"
+
+
+# Ready-to-use callable task instances, following PyHealth's ``*_fn``
+# convention so they can be passed directly to ``dataset.set_task``.
+promptehr_generation_mimic3_fn = PromptEHRGenerationMIMIC3()
+promptehr_generation_mimic4_fn = PromptEHRGenerationMIMIC4()
diff --git a/pyproject.toml b/pyproject.toml
index 308e6b114..c9fd4626d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ dependencies = [
     "networkx",
     "mne~=1.10.0",
     "urllib3~=2.5.0",
-    "numpy~=2.2.0",
+    "numpy>=2.0.0",  # NOTE(review): upper bound now open — consider capping (<3) before a numpy major
     "tqdm",
     "polars~=1.35.2",
     "pandas~=2.3.1",
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/integration/test_promptehr_end_to_end.py b/tests/integration/test_promptehr_end_to_end.py
new file mode 100644
index 000000000..a3c0bdac6
--- /dev/null
+++ b/tests/integration/test_promptehr_end_to_end.py
@@ -0,0 +1,431 @@
+"""End-to-end integration tests for the PromptEHR synthetic EHR generation pipeline.
+
+Category A tests use InMemorySampleDataset with synthetic data — no external
+data required and must always pass.
+
+Category B tests require actual MIMIC-III data and are skipped gracefully when
+the data is unavailable.
+
+The bootstrap pattern mirrors test_corgan_end_to_end.py: load PromptEHR and
+InMemorySampleDataset via importlib while stubbing out heavy optional
+dependencies (litdata, pyarrow) that are not yet in the venv. transformers IS
+available in the venv and is loaded normally.
+"""
+
+import importlib.util
+import os
+import sys
+import tempfile
+import unittest
+from unittest.mock import MagicMock
+
+
+# ---------------------------------------------------------------------------
+# Bootstrap: load PromptEHR, BaseModel, and InMemorySampleDataset without
+# triggering pyhealth.models.__init__ (many models have unavailable deps) or
+# pyhealth.datasets.__init__ (requires litdata, pyarrow, ...).
+# ---------------------------------------------------------------------------
+
+
+def _bootstrap():
+    """Load PromptEHR, BaseModel, and InMemorySampleDataset via importlib.
+
+    Returns:
+        (BaseModel, PromptEHR, InMemorySampleDataset)
+    """
+    import pyhealth  # noqa: F401 — top-level __init__ has no heavy deps
+
+    # NOTE(review): everything below mutates sys.modules process-wide; the
+    # MIMIC-III integration class pops these stubs before importing the real
+    # packages — confirm test-ordering assumptions if new tests are added.
+    # Stub pyhealth.datasets so that base_model.py's
+    # "from ..datasets import SampleDataset" resolves cleanly.
+    if "pyhealth.datasets" not in sys.modules:
+        ds_stub = MagicMock()
+
+        class _FakeSampleDataset:  # noqa: N801
+            pass
+
+        ds_stub.SampleDataset = _FakeSampleDataset
+        sys.modules["pyhealth.datasets"] = ds_stub
+
+    # Stub pyhealth.models so we can control loading without the real __init__.
+    if "pyhealth.models" not in sys.modules or isinstance(
+        sys.modules["pyhealth.models"], MagicMock
+    ):
+        models_stub = MagicMock()
+        sys.modules["pyhealth.models"] = models_stub
+    else:
+        models_stub = sys.modules["pyhealth.models"]
+
+    # Processors are safe to import normally.
+    from pyhealth.processors import PROCESSOR_REGISTRY  # noqa: F401
+
+    def _load_file(mod_name, filepath):
+        # Standard importlib recipe: register the module in sys.modules
+        # BEFORE exec so relative imports inside it can resolve.
+        spec = importlib.util.spec_from_file_location(mod_name, filepath)
+        mod = importlib.util.module_from_spec(spec)
+        sys.modules[mod_name] = mod
+        spec.loader.exec_module(mod)
+        return mod
+
+    # Repository root: three levels up from this test file.
+    root = os.path.dirname(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    )
+    models_dir = os.path.join(root, "pyhealth", "models")
+    promptehr_dir = os.path.join(models_dir, "promptehr")
+
+    # Load base_model and expose via stub.
+    bm_mod = _load_file(
+        "pyhealth.models.base_model", os.path.join(models_dir, "base_model.py")
+    )
+    BaseModel = bm_mod.BaseModel
+    models_stub.BaseModel = BaseModel
+
+    # Create a package stub for pyhealth.models.promptehr so that
+    # model.py's relative imports (from .conditional_prompt import ...) work.
+    promptehr_pkg_stub = MagicMock()
+    sys.modules.setdefault("pyhealth.models.promptehr", promptehr_pkg_stub)
+
+    # Load each PromptEHR submodule in dependency order.
+    # Each is standalone (only torch + transformers, no cross-module imports).
+    for mod_name in (
+        "conditional_prompt",
+        "bart_encoder",
+        "bart_decoder",
+        "visit_sampler",
+        "generation",
+    ):
+        _load_file(
+            f"pyhealth.models.promptehr.{mod_name}",
+            os.path.join(promptehr_dir, f"{mod_name}.py"),
+        )
+
+    # Load model.py last (depends on the submodules loaded above + BaseModel).
+    model_mod = _load_file(
+        "pyhealth.models.promptehr.model",
+        os.path.join(promptehr_dir, "model.py"),
+    )
+    PromptEHR = model_mod.PromptEHR
+
+    # Stub litdata so sample_dataset.py can be loaded without the full package.
+    if "litdata" not in sys.modules:
+        litdata_pkg = MagicMock()
+        litdata_pkg.StreamingDataset = type(
+            "StreamingDataset", (), {"__init__": lambda self, *a, **kw: None}
+        )
+        litdata_utilities = MagicMock()
+        litdata_utilities_train_test = MagicMock()
+        litdata_utilities_train_test.deepcopy_dataset = lambda x: x
+        litdata_utilities.train_test_split = litdata_utilities_train_test
+        litdata_pkg.utilities = litdata_utilities
+        sys.modules["litdata"] = litdata_pkg
+        sys.modules["litdata.utilities"] = litdata_utilities
+        sys.modules["litdata.utilities.train_test_split"] = (
+            litdata_utilities_train_test
+        )
+
+    # Load sample_dataset.py directly (bypasses datasets/__init__.py).
+    ds_file_mod = _load_file(
+        "pyhealth.datasets.sample_dataset",
+        os.path.join(root, "pyhealth", "datasets", "sample_dataset.py"),
+    )
+    InMemorySampleDataset = ds_file_mod.InMemorySampleDataset
+
+    return BaseModel, PromptEHR, InMemorySampleDataset
+
+
+# Run the bootstrap once at import time; all tests share these classes.
+BaseModel, PromptEHR, InMemorySampleDataset = _bootstrap()
+
+import torch # noqa: E402
+from transformers import BartConfig # noqa: E402
+
+
+# ---------------------------------------------------------------------------
+# Shared helpers
+# ---------------------------------------------------------------------------
+
+# Nested lists of code strings — PromptEHR uses nested_sequence schema.
+# 8 samples with ≥2 visits each, plus demographics.
+_SMALL_SAMPLES = [
+    {"patient_id": "p1", "visits": [["A", "B"], ["C", "D"]], "age": 65.0, "gender": 0},
+    {"patient_id": "p2", "visits": [["E"], ["F", "G"]], "age": 45.0, "gender": 1},
+    {"patient_id": "p3", "visits": [["A", "C"], ["B", "E"]], "age": 55.0, "gender": 0},
+    {"patient_id": "p4", "visits": [["D"], ["A"]], "age": 70.0, "gender": 1},
+    {"patient_id": "p5", "visits": [["B", "F"], ["C", "G"]], "age": 40.0, "gender": 0},
+    {"patient_id": "p6", "visits": [["E", "A"], ["D"]], "age": 60.0, "gender": 1},
+    {"patient_id": "p7", "visits": [["G", "B"], ["F", "A"]], "age": 50.0, "gender": 0},
+    {"patient_id": "p8", "visits": [["C"], ["D", "E"]], "age": 35.0, "gender": 1},
+]
+
+# Tiny BART config to keep tests fast (avoids downloading/using 768-dim bart-base).
+_TINY_BART_CONFIG = BartConfig(
+    d_model=32,
+    encoder_layers=1,
+    decoder_layers=1,
+    encoder_ffn_dim=64,
+    decoder_ffn_dim=64,
+    encoder_attention_heads=4,
+    decoder_attention_heads=4,
+    max_position_embeddings=128,
+)
+
+# Minimal model kwargs — tiny architecture and 1 epoch to keep tests fast.
+# n_num_features=1 matches the single numeric feature (age) and
+# cat_cardinalities=[2] the binary categorical feature (gender) above.
+_SMALL_MODEL_KWARGS = dict(
+    n_num_features=1,
+    cat_cardinalities=[2],
+    d_hidden=32,
+    prompt_length=1,
+    bart_config_name=_TINY_BART_CONFIG,
+    epochs=1,
+    batch_size=4,
+    warmup_steps=0,
+    max_seq_length=64,
+)
+
+
+def _make_dataset(samples=None):
+ if samples is None:
+ samples = _SMALL_SAMPLES
+ return InMemorySampleDataset(
+ samples=samples,
+ input_schema={"visits": "nested_sequence"},
+ output_schema={},
+ )
+
+
+def _make_trained_model():
+ """Return a PromptEHR model trained for 1 epoch on _SMALL_SAMPLES."""
+ dataset = _make_dataset()
+ tmpdir = tempfile.mkdtemp()
+ model = PromptEHR(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS)
+ model.train_model(dataset)
+ return model, tmpdir
+
+
+# ---------------------------------------------------------------------------
+# Category A: In-Memory Integration Tests (must always pass)
+# ---------------------------------------------------------------------------
+
+
+class TestPromptEHRIsBaseModelInstance(unittest.TestCase):
+ """PromptEHR model is an instance of BaseModel."""
+
+ def test_model_is_basemodel_instance(self):
+ dataset = _make_dataset()
+ model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS)
+ self.assertIsInstance(model, BaseModel)
+
+
+class TestPromptEHRFeatureKeys(unittest.TestCase):
+ """model.feature_keys equals ['visits']."""
+
+ def test_feature_keys(self):
+ dataset = _make_dataset()
+ model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS)
+ self.assertEqual(model.feature_keys, ["visits"])
+
+
+class TestPromptEHRVocabSize(unittest.TestCase):
+    """_vocab.total_size equals processor.vocab_size() + 5."""
+
+    def test_vocab_size_matches_processor(self):
+        dataset = _make_dataset()
+        processor = dataset.input_processors["visits"]
+        model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS)
+        # +5 presumably accounts for special tokens the model adds on top of
+        # the processor vocabulary (pad/bos/eos/etc.) — TODO(review): confirm
+        # against the PromptEHR vocab implementation.
+        expected = processor.vocab_size() + 5
+        self.assertEqual(model._vocab.total_size, expected)
+
+
+class TestPromptEHRForwardRaisesNotImplementedError(unittest.TestCase):
+ """Calling forward() raises NotImplementedError.
+
+ PromptEHR is a generative model; the discriminative forward pass is not
+ applicable.
+ """
+
+ def test_forward_not_implemented(self):
+ dataset = _make_dataset()
+ model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS)
+ with self.assertRaises(NotImplementedError):
+ model.forward()
+
+
+class TestPromptEHRTrainModelRuns(unittest.TestCase):
+ """train_model completes one epoch without error."""
+
+ def test_train_model_runs_one_epoch(self):
+ dataset = _make_dataset()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model = PromptEHR(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS)
+ try:
+ model.train_model(dataset, val_dataset=None)
+ except Exception as exc: # noqa: BLE001
+ self.fail(f"train_model raised an unexpected exception: {exc}")
+ # A checkpoint must be saved after training
+ ckpt = os.path.join(tmpdir, "checkpoint.pt")
+ self.assertTrue(os.path.exists(ckpt), f"Expected checkpoint at {ckpt}")
+
+
+class TestPromptEHRSynthesizeCount(unittest.TestCase):
+ """synthesize_dataset(num_samples=3) returns exactly 3 dicts."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model, cls.tmpdir = _make_trained_model()
+
+ def test_synthesize_returns_correct_count(self):
+ result = self.model.synthesize_dataset(num_samples=3)
+ self.assertIsInstance(result, list)
+ self.assertEqual(len(result), 3)
+
+
+class TestPromptEHRSynthesizeOutputStructure(unittest.TestCase):
+    """Each synthesized dict has patient_id (str) and visits (nested list of str).
+
+    PromptEHR outputs nested visit lists — each patient is a list of visits,
+    each visit is a list of diagnosis code strings.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        # Train once and reuse across the class; training is the expensive
+        # step even with the tiny config.
+        cls.model, cls.tmpdir = _make_trained_model()
+
+    def test_synthesize_output_structure(self):
+        result = self.model.synthesize_dataset(num_samples=3)
+        for i, item in enumerate(result):
+            self.assertIsInstance(item, dict, f"Item {i} is not a dict")
+            self.assertIn("patient_id", item, f"Item {i} missing 'patient_id'")
+            self.assertIn("visits", item, f"Item {i} missing 'visits'")
+            self.assertIsInstance(
+                item["patient_id"], str, f"patient_id in item {i} is not a str"
+            )
+            self.assertIsInstance(
+                item["visits"], list, f"visits in item {i} is not a list"
+            )
+            # visits is a nested list: list of visits, each visit a list of strings
+            for visit_idx, visit in enumerate(item["visits"]):
+                self.assertIsInstance(
+                    visit, list,
+                    f"visit {visit_idx} in item {i} is not a list"
+                )
+                for code in visit:
+                    self.assertIsInstance(
+                        code, str,
+                        f"code '{code}' in visit {visit_idx}, item {i} is not str"
+                    )
+
+
+class TestPromptEHRSaveLoadRoundtrip(unittest.TestCase):
+ """save_model then load_model; synthesize_dataset returns correct count."""
+
+ def test_save_load_roundtrip(self):
+ dataset = _make_dataset()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model = PromptEHR(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS)
+ model.train_model(dataset)
+ ckpt_path = os.path.join(tmpdir, "test_ckpt.pt")
+ model.save_model(ckpt_path)
+ self.assertTrue(
+ os.path.exists(ckpt_path),
+ f"Expected checkpoint at {ckpt_path}",
+ )
+ model.load_model(ckpt_path)
+ result = model.synthesize_dataset(num_samples=3)
+ self.assertEqual(len(result), 3)
+
+
+# ---------------------------------------------------------------------------
+# Category B: MIMIC-III Integration Tests (skipped if data unavailable)
+# ---------------------------------------------------------------------------
+
+# MIMIC-III root for Category B tests; override via PYHEALTH_MIMIC3_PATH.
+# The default matches the shared physionet mount used elsewhere in this repo.
+_MIMIC3_PATH = os.environ.get(
+    "PYHEALTH_MIMIC3_PATH",
+    "/srv/local/data/physionet.org/files/mimiciii/1.4",
+)
+
+
+class TestPromptEHRMIMIC3Integration(unittest.TestCase):
+    """End-to-end pipeline test with actual MIMIC-III data.
+
+    Skipped automatically when MIMIC-III is not present on this machine.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        cls.skip_integration = False
+        cls.skip_reason = ""
+        try:
+            # Remove bootstrap stubs so we can attempt a real import.
+            # NOTE(review): on the success path the popped stub is never
+            # restored — Category A tests are unaffected because they bound
+            # their classes at module import time, but confirm nothing else
+            # re-imports pyhealth.datasets expecting the stub.
+            _saved_ds_stub = sys.modules.pop("pyhealth.datasets", None)
+            try:
+                import importlib as _il
+                _il.invalidate_caches()
+                from pyhealth.datasets import MIMIC3Dataset as _MIMIC3Dataset
+                from pyhealth.tasks.ehr_generation import PromptEHRGenerationMIMIC3
+            except (ImportError, ModuleNotFoundError) as exc:
+                # Real import failed: restore the stub so Category A keeps
+                # working, then fall through to the skip path below.
+                if _saved_ds_stub is not None:
+                    sys.modules["pyhealth.datasets"] = _saved_ds_stub
+                raise ImportError(str(exc)) from exc
+
+            cls.dataset = _MIMIC3Dataset(
+                root=_MIMIC3_PATH,
+                tables=["patients", "admissions", "diagnoses_icd"],
+            )
+            task = PromptEHRGenerationMIMIC3()
+            cls.sample_dataset = cls.dataset.set_task(task)
+        except (FileNotFoundError, OSError, ImportError, ValueError) as exc:
+            # Any of these means "data or deps unavailable here": record the
+            # reason and let setUp() skip each test individually.
+            cls.skip_integration = True
+            cls.skip_reason = str(exc)
+
+    def setUp(self):
+        if self.skip_integration:
+            self.skipTest(f"MIMIC-III integration test skipped: {self.skip_reason}")
+
+    def test_mimic3_set_task_returns_nonempty_dataset(self):
+        """set_task produces at least one sample from MIMIC-III."""
+        self.assertGreater(len(self.sample_dataset), 0)
+
+    def test_mimic3_sample_keys(self):
+        """Every sample must contain patient_id, visits, age, and gender keys."""
+        for sample in self.sample_dataset:
+            self.assertIn("patient_id", sample)
+            self.assertIn("visits", sample)
+            self.assertIn("age", sample)
+            self.assertIn("gender", sample)
+
+    def test_mimic3_visits_are_nested_tensors(self):
+        """visits must be a list of 1-D int64 tensors (NestedSequenceProcessor output).
+
+        NestedSequenceProcessor encodes each visit as a 1-D LongTensor of
+        code indices. This verifies the nested_sequence schema round-trips
+        correctly through set_task.
+        """
+        for sample in self.sample_dataset:
+            visits = sample["visits"]
+            self.assertIsInstance(visits, list)
+            self.assertGreater(len(visits), 0)
+            for visit in visits:
+                self.assertIsInstance(visit, torch.Tensor)
+                self.assertEqual(visit.dtype, torch.long)
+
+    def test_mimic3_full_pipeline_train_and_synthesize(self):
+        """Train one epoch on MIMIC-III data and synthesize a small batch."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = PromptEHR(
+                self.sample_dataset,
+                d_hidden=64,
+                prompt_length=1,
+                bart_config_name=_TINY_BART_CONFIG,
+                epochs=1,
+                batch_size=16,
+                warmup_steps=0,
+                save_dir=tmpdir,
+            )
+            model.train_model(self.sample_dataset, val_dataset=None)
+            synthetic = model.synthesize_dataset(num_samples=5)
+            self.assertEqual(len(synthetic), 5)
+            for item in synthetic:
+                self.assertIn("patient_id", item)
+                self.assertIn("visits", item)
+                self.assertIsInstance(item["visits"], list)
+
+
+if __name__ == "__main__":
+    # Allow running this module directly without pytest.
+    unittest.main()