# Patch provenance: commit 9e109361c328ae98ba1ae2e4df79639233501b76,
# "[PATCH 1/2] init" by chufangao (Sun, 1 Mar 2026 04:34:14 -0600).
# File: examples/ehr_generation/ehr_generation_mimic3_transformer.py (new file)
"""
EHR Generation with a GPT-2 style Transformer on MIMIC-III (PyHealth)
======================================================================

This example trains a GPT-2 style decoder-only transformer to synthesise
longitudinal patient EHR sequences consisting of ICD-9 diagnosis codes.

The pipeline:

1. Load MIMIC-III via **PyHealth** and apply the ``EHRGenerationMIMIC3`` task
   to obtain per-patient nested visit sequences.
2. Serialise the nested sequences into plain text using ``VISIT_DELIM``
   separators (e.g. ``"250.00 401.9 VISIT_DELIM 272.0 428.0"``).
3. Train a word-level GPT-2 model on the serialised text.
4. Sample synthetic text sequences and deserialise them back to a long-form
   ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` DataFrame for downstream evaluation.

References
----------
- *Accelerating Reproducible Research in Synthetic EHR Generation* (CHIL 2026)

Usage
-----
.. code-block:: bash

    # Full vocabulary (~6,955 ICD-9 codes) – recommended
    python ehr_generation_mimic3_transformer.py \\
        --mimic3_root /path/to/mimic-iii-clinical-database-1.4 \\
        --output_dir ./synthetic_output

    # Optional: replicate the legacy 3-digit vocabulary
    python ehr_generation_mimic3_transformer.py \\
        --mimic3_root /path/to/mimic-iii \\
        --truncate_icd \\
        --output_dir ./synthetic_output_3digit
"""

import argparse
import os

import pandas as pd
import torch
from tokenizers import Tokenizer, models, pre_tokenizers, processors, trainers
from torch.utils.data import Dataset
from tqdm import trange
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.tasks import EHRGenerationMIMIC3

# ── Constants ─────────────────────────────────────────────────────────────────

VISIT_DELIM = "VISIT_DELIM"


# ── 1. Sequence helpers ────────────────────────────────────────────────────────


def samples_to_sequences(samples: list) -> list[str]:
    """Serialise nested visit lists into one text sequence per sample.

    Codes inside a visit are joined with single spaces; consecutive visits
    are joined with ``" VISIT_DELIM "``. A sample whose ``conditions`` list is
    empty becomes the empty string.

    Args:
        samples: Dicts that each carry a ``"conditions"`` entry holding a list
            of visits, where every visit is a list of code strings.

    Returns:
        One string per input sample, in input order, e.g.
        ``"250.00 401.9 VISIT_DELIM 272.0 428.0 VISIT_DELIM 250.00"``.
    """
    visit_separator = f" {VISIT_DELIM} "
    return [
        visit_separator.join(" ".join(codes) for codes in record["conditions"])
        for record in samples
    ]


def sequences_to_dataframe(sequences: list[str]) -> pd.DataFrame:
    """Expand text sequences into long-form ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` rows.

    ``SUBJECT_ID`` is the position of the sequence in ``sequences`` and
    ``HADM_ID`` is the position of the visit segment obtained by splitting the
    sequence on ``VISIT_DELIM``; no real identifiers are recovered.

    Args:
        sequences: Text sequences, typically decoded model output.

    Returns:
        A ``pd.DataFrame`` with one row per code. When no code is found at all
        the frame is built from an empty row list and therefore has no columns.
    """

    def iter_rows():
        # Yield one record per whitespace-separated code in every visit segment.
        for patient_no, text in enumerate(sequences):
            for visit_no, segment in enumerate(text.strip().split(VISIT_DELIM)):
                for token in segment.split():
                    yield {
                        "SUBJECT_ID": patient_no,
                        "HADM_ID": visit_no,
                        "ICD9_CODE": token,
                    }

    return pd.DataFrame(list(iter_rows()))


# ── 2. PyTorch Dataset ─────────────────────────────────────────────────────────


class EHRTextDataset(Dataset):
    """Pre-tokenised EHR text sequences for causal language modelling.

    Every sequence is encoded once at construction time, truncated and padded
    to ``max_length``, and stored as a tensor of token ids.

    Args:
        sequences: Plain-text patient sequences (one string per patient).
        tokenizer: A HuggingFace ``PreTrainedTokenizerFast``.
        max_length: Length every encoded sequence is truncated/padded to.
    """

    def __init__(
        self,
        sequences: list[str],
        tokenizer: PreTrainedTokenizerFast,
        max_length: int = 512,
    ) -> None:
        def encode(text: str) -> torch.Tensor:
            encoded = tokenizer(
                text, truncation=True, max_length=max_length, padding="max_length"
            )
            return torch.tensor(encoded["input_ids"])

        self.input_ids = [encode(text) for text in sequences]

    def __len__(self) -> int:
        return len(self.input_ids)

    def __getitem__(self, idx: int) -> dict:
        # Labels are the very same tensor object as the inputs.
        token_ids = self.input_ids[idx]
        return {"input_ids": token_ids, "labels": token_ids}
# ── 3. Tokeniser builder ───────────────────────────────────────────────────────


def build_tokenizer(text_data: list[str]) -> PreTrainedTokenizerFast:
    """Build and train a word-level tokeniser on the EHR text corpus.

    Special tokens:
        * ``[UNK]`` – unknown token
        * ``[PAD]`` – padding
        * ``[BOS]`` – beginning-of-sequence
        * ``[EOS]`` – end-of-sequence

    The ``VISIT_DELIM`` delimiter token is treated as a regular vocabulary
    word so that its visit-boundary semantics are learned by the model.

    Every whitespace-separated item of the corpus becomes exactly one token,
    so codes that contain punctuation (``"250.00"``, ``"V45.81"``) stay intact.

    Args:
        text_data: List of space-separated code sequences.

    Returns:
        A ``PreTrainedTokenizerFast`` wrapping the trained word-level tokeniser.
        Encoding a single sequence wraps it as ``[BOS] … [EOS]``.
    """
    special_tokens = ["[UNK]", "[PAD]", "[BOS]", "[EOS]"]

    tokenizer_obj = Tokenizer(models.WordLevel(unk_token="[UNK]"))
    # WhitespaceSplit splits on whitespace only. The plain ``Whitespace``
    # pre-tokenizer also splits at punctuation (regex ``\w+|[^\w\s]+``), which
    # would break a dotted code such as "250.00" into "250", "." and "00".
    tokenizer_obj.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

    word_trainer = trainers.WordLevelTrainer(special_tokens=special_tokens)
    tokenizer_obj.train_from_iterator(text_data, trainer=word_trainer)

    # Add BOS/EOS around every encoded sequence.
    tokenizer_obj.post_processor = processors.TemplateProcessing(
        single="[BOS] $A [EOS]",
        special_tokens=[
            ("[BOS]", tokenizer_obj.token_to_id("[BOS]")),
            ("[EOS]", tokenizer_obj.token_to_id("[EOS]")),
        ],
    )

    return PreTrainedTokenizerFast(
        tokenizer_object=tokenizer_obj,
        unk_token="[UNK]",
        pad_token="[PAD]",
        bos_token="[BOS]",
        eos_token="[EOS]",
    )
# ── 4. Main pipeline ───────────────────────────────────────────────────────────


def main(args: argparse.Namespace) -> None:
    """Train the GPT-2 style model on MIMIC-III and write synthetic EHR rows.

    Steps: load the dataset, apply ``EHRGenerationMIMIC3``, serialise the
    training split to text, build the tokeniser, train with the HuggingFace
    ``Trainer``, sample ``args.num_synthetic`` sequences and save them as
    ``synthetic_ehr.csv`` inside ``args.output_dir``.

    Args:
        args: Parsed command-line options (see the parser under ``__main__``).

    Raises:
        ValueError: If the training split produces no text sequences.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ------------------------------------------------------------------
    # STEP 1: Load MIMIC-III via PyHealth
    # ------------------------------------------------------------------
    print("\nSTEP 1: Loading MIMIC-III dataset …")
    base_dataset = MIMIC3Dataset(
        root=args.mimic3_root,
        tables=["diagnoses_icd"],
    )
    base_dataset.stats()

    # ------------------------------------------------------------------
    # STEP 2: Apply EHRGenerationMIMIC3 task
    # ------------------------------------------------------------------
    print("\nSTEP 2: Applying EHRGenerationMIMIC3 task …")
    task = EHRGenerationMIMIC3(
        min_visits=args.min_visits,
        truncate_icd=args.truncate_icd,
    )
    sample_dataset = base_dataset.set_task(task)
    print(f" Total patients: {len(sample_dataset)}")

    train_dataset, _, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
    print(f" Training patients: {len(train_dataset)}")

    # ------------------------------------------------------------------
    # STEP 3: Serialise to text sequences
    # ------------------------------------------------------------------
    print("\nSTEP 3: Serialising patient sequences …")
    # NOTE(review): assumes iterating the split yields raw code strings under
    # "conditions" (not processor-encoded tensors) — confirm against set_task().
    train_samples = list(train_dataset)
    text_data = samples_to_sequences(train_samples)
    if not text_data:
        # Fail early with a clear message instead of max() on an empty corpus.
        raise ValueError(
            "No training sequences were produced; check --mimic3_root and --min_visits."
        )
    max_len = max(len(seq.split()) for seq in text_data)
    print(f" Max sequence length: {max_len} tokens")

    # ------------------------------------------------------------------
    # STEP 4: Build tokeniser
    # ------------------------------------------------------------------
    print("\nSTEP 4: Building word-level tokeniser …")
    tokenizer = build_tokenizer(text_data)
    print(f" Vocabulary size: {len(tokenizer)}")

    train_torch_dataset = EHRTextDataset(
        text_data, tokenizer, max_length=args.max_seq_len
    )

    # ------------------------------------------------------------------
    # STEP 5: Initialise GPT-2 style decoder model
    # ------------------------------------------------------------------
    print("\nSTEP 5: Initialising GPT-2 model …")
    config = GPT2Config(
        vocab_size=len(tokenizer),
        n_positions=args.max_seq_len,
        n_ctx=args.max_seq_len,
        n_embd=512,
        n_layer=8,
        n_head=8,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    model = GPT2LMHeadModel(config).to(device)
    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" Model parameters: {num_params:.1f}M")

    # ------------------------------------------------------------------
    # STEP 6: Train
    # ------------------------------------------------------------------
    print("\nSTEP 6: Training …")
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=os.path.join(args.output_dir, "checkpoints"),
        overwrite_output_dir=True,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        logging_steps=50,
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        use_cpu=not torch.cuda.is_available(),
        save_strategy="epoch",
    )

    hf_trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_torch_dataset,
    )
    hf_trainer.train()

    model_save_path = os.path.join(args.output_dir, "transformer_ehr_model")
    hf_trainer.save_model(model_save_path)
    print(f" Model saved to: {model_save_path}")

    # ------------------------------------------------------------------
    # STEP 7: Generate synthetic EHRs
    # ------------------------------------------------------------------
    print(f"\nSTEP 7: Generating {args.num_synthetic} synthetic patients …")
    model.eval()

    all_syn: list[pd.DataFrame] = []
    start_subj_id = 0
    for batch_start in trange(0, args.num_synthetic, args.gen_batch_size):
        batch_end = min(batch_start + args.gen_batch_size, args.num_synthetic)
        bsz = batch_end - batch_start

        # Every synthetic patient starts from a lone [BOS] prompt.
        batch_input_ids = torch.tensor(
            [[tokenizer.bos_token_id]] * bsz, device=device
        )
        with torch.no_grad():
            generated = model.generate(
                batch_input_ids,
                # max_length counts the prompt too, so the total sequence
                # never exceeds the n_positions the model was configured with.
                max_length=args.max_seq_len,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        decoded = [
            tokenizer.decode(seq, skip_special_tokens=True) for seq in generated
        ]
        batch_df = sequences_to_dataframe(decoded)
        # A batch that decodes to no codes yields a column-less frame;
        # skip it rather than indexing a missing "SUBJECT_ID" column.
        if not batch_df.empty:
            batch_df["SUBJECT_ID"] += start_subj_id
            all_syn.append(batch_df)
        start_subj_id += bsz

    if all_syn:
        synthetic_df = pd.concat(all_syn, ignore_index=True)
    else:
        # pd.concat raises on an empty list; still write a well-formed CSV.
        synthetic_df = pd.DataFrame(columns=["SUBJECT_ID", "HADM_ID", "ICD9_CODE"])
    print(f" Generated {synthetic_df['SUBJECT_ID'].nunique()} patients, "
          f"{synthetic_df.shape[0]} (patient, visit, code) rows")

    out_csv = os.path.join(args.output_dir, "synthetic_ehr.csv")
    synthetic_df.to_csv(out_csv, index=False)
    print(f" Synthetic data saved to: {out_csv}")


# ── CLI entry point ────────────────────────────────────────────────────────────


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a GPT-2 transformer for synthetic EHR generation (MIMIC-III).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--mimic3_root",
        type=str,
        required=True,
        help="Path to the MIMIC-III root directory containing raw CSV/CSV.GZ files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./ehr_generation_output",
        help="Directory to save the trained model and synthetic data.",
    )
    parser.add_argument(
        "--min_visits",
        type=int,
        default=1,
        help="Minimum number of valid admissions a patient must have.",
    )
    parser.add_argument(
        "--truncate_icd",
        action="store_true",
        default=False,
        help="Truncate ICD-9 codes to 3-digit prefixes (reduces vocab to ~1,071 codes).",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=512,
        help="Maximum token sequence length.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=50,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Training batch size.",
    )
    parser.add_argument(
        "--num_synthetic",
        type=int,
        default=10000,
        help="Number of synthetic patients to generate.",
    )
    parser.add_argument(
        "--gen_batch_size",
        type=int,
        default=512,
        help="Generation batch size.",
    )
    main(parser.parse_args())


# ══════════════════════════════════════════════════════════════════════════════
# File: pyhealth/tasks/__init__.py (patch appends one line after
# ``from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task``):
#
#     from .ehr_generation import EHRGenerationMIMIC3
# ══════════════════════════════════════════════════════════════════════════════


# ══════════════════════════════════════════════════════════════════════════════
# File: pyhealth/tasks/ehr_generation.py (new file)
# ══════════════════════════════════════════════════════════════════════════════

from typing import Any, Dict, List

from pyhealth.tasks.base_task import BaseTask


class EHRGenerationMIMIC3(BaseTask):
    """Task for training synthetic EHR generative models using MIMIC-III.

    Transforms longitudinal patient records into a visit-sequence representation
    suitable for generative modeling. Each sample corresponds to one patient and
    captures the ICD-9 diagnosis codes of every admission that has codes.

    Two downstream representations can be derived from the output:

    * **Sequential** (PromptEHR, HALO, GPT): the nested ``conditions`` list
      retains full visit boundaries and ordering.
    * **Matrix / flattened** (MedGAN, CorGAN): flatten ``conditions`` into a
      single list (binary presence) or count vector per patient.

    For standardised evaluation, every synthetic or real record should be
    converted to a long-format schema of ``(subject_id, visit_id, code)``
    triplets as recommended by the paper *Accelerating Reproducible Research in
    Synthetic EHR Generation*.

    Attributes:
        task_name (str): Identifier for this task.
        input_schema (Dict[str, str]): ``{"conditions": "nested_sequence"}`` –
            tells PyHealth's processor to serialise the variable-length nested
            visit list correctly (same convention as
            ``DrugRecommendationMIMIC3``).
        output_schema (Dict[str, str]): ``{}`` – no supervised label is
            produced.
        min_visits (int): Minimum number of valid visits a patient must have
            to be included. Defaults to ``1``.
        truncate_icd (bool): When ``True``, ICD-9 codes are truncated to the
            first 3 characters (e.g. ``"250.40"`` → ``"250"``), reducing the
            vocabulary from ~6,955 to 1,071 codes. The paper recommends
            keeping ``False`` for full clinical fidelity. Defaults to
            ``False``.

    Note:
        A full end-to-end training example using a GPT-2 style decoder can be
        found at ``examples/ehr_generation/ehr_generation_mimic3_transformer.py``.

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> from pyhealth.tasks import EHRGenerationMIMIC3
        >>> dataset = MIMIC3Dataset(
        ...     root="/path/to/mimic-iii/1.4",
        ...     tables=["diagnoses_icd"],
        ... )
        >>> task = EHRGenerationMIMIC3()
        >>> samples = dataset.set_task(task)
        >>> # Each sample: {patient_id, conditions}
        >>> # conditions is a list of visits; each visit is a list of ICD-9 codes
        >>> # e.g. [["250.00", "401.9"], ["272.0", "428.0"]]
        >>> # the visit count is len(conditions)
    """

    task_name: str = "EHRGenerationMIMIC3"
    input_schema: Dict[str, str] = {"conditions": "nested_sequence"}
    output_schema: Dict[str, str] = {}

    def __init__(
        self,
        min_visits: int = 1,
        truncate_icd: bool = False,
    ) -> None:
        """Initialise the task.

        Args:
            min_visits (int): Patients with fewer than ``min_visits`` valid
                admissions (i.e. admissions that contain at least one ICD-9
                code) are excluded. Defaults to ``1``.
            truncate_icd (bool): Truncate ICD-9 codes to 3-character prefixes.
                Useful for reproducing prior work that caps the vocabulary at
                1,071 codes. Defaults to ``False`` (full 6,955-code vocabulary).
        """
        self.min_visits = min_visits
        self.truncate_icd = truncate_icd

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a single patient and return at most one generation sample.

        Admissions with no ICD-9 codes are silently skipped. Within a visit,
        duplicate codes are removed while keeping first-occurrence order
        (after truncation, when enabled).

        Args:
            patient: A PyHealth ``Patient`` object providing a
                ``get_events(event_type, filters)`` interface and a
                ``patient_id`` attribute.

        Returns:
            An empty list when the patient has fewer than ``self.min_visits``
            non-empty visits; otherwise a list containing a single dict with:

            * ``patient_id``: the value of ``patient.patient_id``.
            * ``conditions`` (List[List[str]]): codes grouped by admission, in
              the order ``get_events("admissions")`` returns them (presumably
              chronological — confirm against the dataset implementation).
              The number of visits is ``len(conditions)``.
        """
        visit_sequences: List[List[str]] = []

        for admission in patient.get_events(event_type="admissions"):
            diagnoses = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )

            # Drop events whose code is missing/empty.
            codes = [event.icd9_code for event in diagnoses if event.icd9_code]

            if self.truncate_icd:
                # NOTE(review): a plain 3-character prefix also shortens ICD-9
                # E-codes ("E8490" -> "E84"); prior work usually keeps 4
                # characters for E-codes — confirm which is intended.
                codes = [code[:3] for code in codes]

            # Deduplicate while preserving order (dict keys keep insertion order).
            codes = list(dict.fromkeys(codes))

            if codes:
                visit_sequences.append(codes)

        if len(visit_sequences) < self.min_visits:
            return []

        return [
            {
                "patient_id": patient.patient_id,
                "conditions": visit_sequences,
            }
        ]


# ══════════════════════════════════════════════════════════════════════════════
# File: tests/core/test_mimic3_ehr_generation.py (new file)
#
# Tests for EHRGenerationMIMIC3 task and the sequence helper utilities.
#
# All tests use a fully synthetic mock dataset so no real MIMIC-III files or
# PyHealth's set_task() / litdata pipeline are required.
# ══════════════════════════════════════════════════════════════════════════════

import unittest

import pandas as pd

from pyhealth.tasks import EHRGenerationMIMIC3

# ── visit-delimiter helpers (mirrored from the example script) ─────────────────

VISIT_DELIM = "VISIT_DELIM"


def samples_to_sequences(samples: list) -> list:
    """Nested visit list → VISIT_DELIM-delimited text string per patient."""
    sequences = []
    for sample in samples:
        visit_texts = [" ".join(visit_codes) for visit_codes in sample["conditions"]]
        sequences.append(f" {VISIT_DELIM} ".join(visit_texts))
    return sequences


def sequences_to_dataframe(sequences: list) -> pd.DataFrame:
    """Text sequences → long-form (SUBJECT_ID, HADM_ID, ICD9_CODE) DataFrame."""
    rows = []
    for subj_idx, seq in enumerate(sequences):
        for hadm_idx, visit_str in enumerate(seq.strip().split(VISIT_DELIM)):
            for code in visit_str.strip().split():
                if code:
                    rows.append(
                        {"SUBJECT_ID": subj_idx, "HADM_ID": hadm_idx, "ICD9_CODE": code}
                    )
    return pd.DataFrame(rows)


# ── minimal mock objects that mimic PyHealth's Patient/Event interface ─────────


class _MockAdmission:
    """Lightweight stand-in for a MIMIC-III admission event."""

    def __init__(self, hadm_id: str) -> None:
        self.hadm_id = hadm_id


class _MockDiagnosis:
    """Lightweight stand-in for a diagnoses_icd event."""

    def __init__(self, hadm_id: str, icd9_code: str) -> None:
        self.hadm_id = hadm_id
        self.icd9_code = icd9_code


class _MockPatient:
    """Mimics BasePatient.get_events() for admissions and diagnoses_icd tables."""

    def __init__(self, patient_id: str, visits: list) -> None:
        """
        Args:
            patient_id: Synthetic subject_id string.
            visits: List of visits; each visit is a list of ICD-9 code strings.
                Duplicates within a visit are intentional to test dedup logic.
        """
        self.patient_id = patient_id
        self._admissions = [
            _MockAdmission(hadm_id=str(100 + i)) for i in range(len(visits))
        ]
        self._diagnoses = []
        for admission, codes in zip(self._admissions, visits):
            for code in codes:
                self._diagnoses.append(_MockDiagnosis(admission.hadm_id, code))

    def get_events(self, event_type: str, filters=None):
        """Return mock events; only ``==`` filters are honoured."""
        if event_type == "admissions":
            return list(self._admissions)
        if event_type == "diagnoses_icd":
            result = list(self._diagnoses)
            if filters:
                for field, op, value in filters:
                    if op == "==":
                        result = [e for e in result if getattr(e, field) == value]
            return result
        return []


# ── synthetic patient corpus ───────────────────────────────────────────────────

_PATIENTS = {
    # 3 visits, no duplicates
    "P001": [
        ["250.00", "401.9", "278.00"],
        ["250.00", "272.0"],
        ["428.0", "401.9", "285.9"],
    ],
    # 2 visits with intentional within-visit duplicates
    "P002": [
        ["410.01", "410.01", "412"],  # 410.01 duplicated intentionally
        ["414.01", "V45.81"],
    ],
    # 1 visit (used to test min_visits filtering)
    "P003": [
        ["486", "518.81"],
    ],
    # 4 visits with long codes (used for truncate_icd tests)
    "P004": [
        ["250.40", "250.00"],
        ["401.10", "401.90"],
        ["428.00"],
        ["272.00", "272.10"],
    ],
    # patient with some empty visits (should be silently skipped)
    "P005": [
        [],  # empty → skipped
        ["V15.82"],
        [],  # empty → skipped
        ["401.9"],
    ],
}

ALL_PATIENTS = [_MockPatient(pid, visits) for pid, visits in _PATIENTS.items()]


# ── test class ─────────────────────────────────────────────────────────────────


class TestEHRGenerationMIMIC3Task(unittest.TestCase):
    """Unit tests for EHRGenerationMIMIC3 using synthetic mock patients."""

    def _run_task(self, task, patients=None):
        """Helper: run task over a list of mock patients, flatten results."""
        if patients is None:
            patients = ALL_PATIENTS
        samples = []
        for p in patients:
            samples.extend(task(p))
        return samples

    # ── schema / init ──────────────────────────────────────────────────────────

    def test_task_name(self):
        self.assertEqual(EHRGenerationMIMIC3.task_name, "EHRGenerationMIMIC3")

    def test_input_schema(self):
        # nested_sequence required so PyHealth's processor handles variable-length visits
        self.assertEqual(EHRGenerationMIMIC3.input_schema, {"conditions": "nested_sequence"})

    def test_output_schema(self):
        self.assertEqual(EHRGenerationMIMIC3.output_schema, {})

    def test_default_init(self):
        task = EHRGenerationMIMIC3()
        self.assertEqual(task.min_visits, 1)
        self.assertFalse(task.truncate_icd)

    def test_custom_init(self):
        task = EHRGenerationMIMIC3(min_visits=3, truncate_icd=True)
        self.assertEqual(task.min_visits, 3)
        self.assertTrue(task.truncate_icd)

    # ── per-patient __call__ output ────────────────────────────────────────────

    def test_returns_one_sample_per_patient(self):
        """Each qualifying patient produces exactly one sample dict."""
        task = EHRGenerationMIMIC3()
        for patient in ALL_PATIENTS:
            result = task(patient)
            self.assertIn(len(result), (0, 1))

    def test_sample_keys_present(self):
        """Each sample must have patient_id and conditions keys."""
        task = EHRGenerationMIMIC3()
        samples = self._run_task(task)
        self.assertGreater(len(samples), 0)
        for sample in samples:
            self.assertIn("patient_id", sample)
            self.assertIn("conditions", sample)

    def test_patient_id_matches(self):
        """sample['patient_id'] must equal the originating patient id."""
        task = EHRGenerationMIMIC3()
        for patient in ALL_PATIENTS:
            for sample in task(patient):
                self.assertEqual(sample["patient_id"], patient.patient_id)

    def test_conditions_is_nested_list_of_strings(self):
        """conditions must be List[List[str]] with no empty inner lists."""
        task = EHRGenerationMIMIC3()
        samples = self._run_task(task)
        for sample in samples:
            conds = sample["conditions"]
            self.assertIsInstance(conds, list)
            self.assertGreater(len(conds), 0)
            for visit in conds:
                self.assertIsInstance(visit, list)
                self.assertGreater(len(visit), 0, "Empty visits must be dropped")
                for code in visit:
                    self.assertIsInstance(code, str)
                    self.assertGreater(len(code), 0)

    def test_empty_visits_skipped(self):
        """Admissions with no ICD-9 codes are silently skipped."""
        task = EHRGenerationMIMIC3()
        p005 = _MockPatient("P005", _PATIENTS["P005"])
        result = task(p005)
        self.assertEqual(len(result), 1)
        # 4 admissions, 2 empty → 2 valid visits
        self.assertEqual(len(result[0]["conditions"]), 2)

    def test_within_visit_deduplication(self):
        """Duplicate ICD-9 codes within a single visit are removed."""
        task = EHRGenerationMIMIC3()
        p002 = _MockPatient("P002", _PATIENTS["P002"])
        result = task(p002)
        self.assertEqual(len(result), 1)
        for visit in result[0]["conditions"]:
            self.assertEqual(len(visit), len(set(visit)),
                             f"Duplicate codes in visit: {visit}")

    def test_visit_order_preserved(self):
        """Visits appear in the same order they were supplied."""
        task = EHRGenerationMIMIC3()
        p001 = _MockPatient("P001", _PATIENTS["P001"])
        result = task(p001)
        self.assertIn("250.00", result[0]["conditions"][0])
        self.assertIn("428.0", result[0]["conditions"][2])

    def test_conditions_length_matches_nonempty_visits(self):
        """len(conditions) equals number of non-empty visits."""
        task = EHRGenerationMIMIC3()
        self.assertEqual(len(task(_MockPatient("P001", _PATIENTS["P001"]))[0]["conditions"]), 3)
        self.assertEqual(len(task(_MockPatient("P005", _PATIENTS["P005"]))[0]["conditions"]), 2)

    # ── min_visits filtering ───────────────────────────────────────────────────

    def test_min_visits_1_includes_single_visit_patient(self):
        task = EHRGenerationMIMIC3(min_visits=1)
        self.assertEqual(len(task(_MockPatient("P003", _PATIENTS["P003"]))), 1)

    def test_min_visits_2_excludes_single_visit_patient(self):
        task = EHRGenerationMIMIC3(min_visits=2)
        self.assertEqual(len(task(_MockPatient("P003", _PATIENTS["P003"]))), 0)

    def test_min_visits_2_keeps_multi_visit_patient(self):
        task = EHRGenerationMIMIC3(min_visits=2)
        self.assertEqual(len(task(_MockPatient("P001", _PATIENTS["P001"]))), 1)

    def test_min_visits_too_high_returns_empty_for_all(self):
        task = EHRGenerationMIMIC3(min_visits=10)
        self.assertEqual(self._run_task(task), [])

    # ── truncate_icd ───────────────────────────────────────────────────────────

    def test_truncate_icd_shortens_codes_to_3_chars(self):
        """All codes must be ≤ 3 characters when truncate_icd=True."""
        task = EHRGenerationMIMIC3(truncate_icd=True)
        for sample in self._run_task(task):
            for visit in sample["conditions"]:
                for code in visit:
                    self.assertLessEqual(len(code), 3,
                                         f"Code '{code}' exceeds 3 chars")

    def test_truncate_icd_false_preserves_full_codes(self):
        """Codes longer than 3 chars must survive when truncate_icd=False."""
        task = EHRGenerationMIMIC3(truncate_icd=False)
        result = task(_MockPatient("P004", _PATIENTS["P004"]))
        all_codes = [c for visit in result[0]["conditions"] for c in visit]
        self.assertTrue(any(len(c) > 3 for c in all_codes),
                        "Expected full-length codes like '250.40'")

    def test_truncate_icd_dedup_after_merge(self):
        """After truncation, merged codes are deduplicated within each visit."""
        # visit 0 of P004: "250.40" and "250.00" both → "250" (only one should survive)
        task = EHRGenerationMIMIC3(truncate_icd=True)
        result = task(_MockPatient("P004", _PATIENTS["P004"]))
        visit_0 = result[0]["conditions"][0]
        self.assertEqual(visit_0, ["250"])

    # ── edge cases ─────────────────────────────────────────────────────────────

    def test_all_empty_visits_returns_empty(self):
        task = EHRGenerationMIMIC3()
        self.assertEqual(task(_MockPatient("PEMPTY", [[], [], []])), [])

    def test_no_visits_returns_empty(self):
        task = EHRGenerationMIMIC3()
        self.assertEqual(task(_MockPatient("PNONE", [])), [])

    # ── sequence helper: samples_to_sequences ─────────────────────────────────

    def test_samples_to_sequences_one_string_per_sample(self):
        samples = self._run_task(EHRGenerationMIMIC3())
        seqs = samples_to_sequences(samples)
        self.assertEqual(len(seqs), len(samples))
        for seq in seqs:
            self.assertIsInstance(seq, str)
            self.assertGreater(len(seq.strip()), 0)

    def test_samples_to_sequences_delimiter_present_for_multi_visit(self):
        sample = EHRGenerationMIMIC3()(_MockPatient("P001", _PATIENTS["P001"]))[0]
        self.assertIn(VISIT_DELIM, samples_to_sequences([sample])[0])

    def test_samples_to_sequences_no_delimiter_for_single_visit(self):
        sample = EHRGenerationMIMIC3()(_MockPatient("P003", _PATIENTS["P003"]))[0]
        self.assertNotIn(VISIT_DELIM, samples_to_sequences([sample])[0])

    # ── sequence helper: sequences_to_dataframe ───────────────────────────────

    def test_sequences_to_dataframe_columns(self):
        samples = self._run_task(EHRGenerationMIMIC3())
        df = sequences_to_dataframe(samples_to_sequences(samples))
        for col in ("SUBJECT_ID", "HADM_ID", "ICD9_CODE"):
            self.assertIn(col, df.columns)

    def test_round_trip_all_codes_preserved(self):
        """Every code in the original samples must appear in the recovered DataFrame."""
        samples = self._run_task(EHRGenerationMIMIC3())
        df = sequences_to_dataframe(samples_to_sequences(samples))
        original = {c for s in samples for visit in s["conditions"] for c in visit}
        recovered = set(df["ICD9_CODE"].tolist())
        self.assertEqual(original, recovered)

    def test_round_trip_visit_count_per_patient(self):
        """The DataFrame must reconstruct the correct visit count per patient."""
        samples = self._run_task(EHRGenerationMIMIC3())
        df = sequences_to_dataframe(samples_to_sequences(samples))
        for idx, sample in enumerate(samples):
            syn_visits = df[df["SUBJECT_ID"] == idx]["HADM_ID"].nunique()
            self.assertEqual(syn_visits, len(sample["conditions"]),
                             f"Visit count mismatch at sample index {idx}")


if __name__ == "__main__":
    unittest.main()
diff --git a/tests/core/test_transformer_ehr_helpers.py b/tests/core/test_transformer_ehr_helpers.py new file mode 100644 index 000000000..33bf870aa --- /dev/null +++ b/tests/core/test_transformer_ehr_helpers.py @@ -0,0 +1,308 @@ +""" +Tests for the utility functions and classes defined in +examples/ehr_generation/ehr_generation_mimic3_transformer.py + +Covered: +* ``samples_to_sequences`` – nested visit lists → text strings +* ``sequences_to_dataframe`` – text strings → long-form DataFrame +* ``build_tokenizer`` – word-level HuggingFace tokenizer +* ``EHRTextDataset`` – PyTorch Dataset wrapping tokenized EHR sequences +""" + +import sys +import os + +import pytest +import torch + +# Allow importing directly from the examples directory without installing it. +sys.path.insert( + 0, + os.path.join(os.path.dirname(__file__), "../../examples/ehr_generation"), +) + +from ehr_generation_mimic3_transformer import ( # noqa: E402 + VISIT_DELIM, + EHRTextDataset, + build_tokenizer, + samples_to_sequences, + sequences_to_dataframe, +) + +# ───────────────────────────────────────────────────────────────────────────── +# Shared fixtures +# ───────────────────────────────────────────────────────────────────────────── + +_SINGLE_VISIT_SAMPLE = {"conditions": [["250.00", "401.9"]]} +_MULTI_VISIT_SAMPLE = {"conditions": [["250.00", "401.9"], ["272.0", "428.0"], ["250.00"]]} +_EMPTY_VISIT_SAMPLE = {"conditions": []} # patient with no visits + + +# ───────────────────────────────────────────────────────────────────────────── +# 1. 
samples_to_sequences +# ───────────────────────────────────────────────────────────────────────────── + + +class TestSamplesToSequences: + def test_returns_one_string_per_sample(self): + samples = [_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE] + result = samples_to_sequences(samples) + assert len(result) == 2 + + def test_empty_input_returns_empty_list(self): + assert samples_to_sequences([]) == [] + + def test_single_visit_no_delimiter(self): + result = samples_to_sequences([_SINGLE_VISIT_SAMPLE]) + assert VISIT_DELIM not in result[0] + + def test_multi_visit_delimiter_count_matches(self): + # 3 visits → 2 VISIT_DELIM occurrences + result = samples_to_sequences([_MULTI_VISIT_SAMPLE]) + assert result[0].count(VISIT_DELIM) == 2 + + def test_codes_present_in_output(self): + result = samples_to_sequences([_MULTI_VISIT_SAMPLE]) + for visit in _MULTI_VISIT_SAMPLE["conditions"]: + for code in visit: + assert code in result[0] + + def test_single_visit_codes_space_separated(self): + result = samples_to_sequences([_SINGLE_VISIT_SAMPLE]) + assert result[0] == "250.00 401.9" + + def test_multi_visit_format(self): + result = samples_to_sequences([_MULTI_VISIT_SAMPLE]) + expected = f"250.00 401.9 {VISIT_DELIM} 272.0 428.0 {VISIT_DELIM} 250.00" + assert result[0] == expected + + def test_empty_conditions_yields_empty_string(self): + result = samples_to_sequences([_EMPTY_VISIT_SAMPLE]) + assert result[0] == "" + + def test_single_code_per_visit(self): + sample = {"conditions": [["A"], ["B"], ["C"]]} + result = samples_to_sequences([sample]) + assert result[0] == f"A {VISIT_DELIM} B {VISIT_DELIM} C" + + def test_multiple_samples_independent(self): + samples = [_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE] + result = samples_to_sequences(samples) + assert result[0] != result[1] + + def test_output_is_list_of_strings(self): + result = samples_to_sequences([_SINGLE_VISIT_SAMPLE]) + assert isinstance(result, list) + assert all(isinstance(s, str) for s in result) + + +# 
───────────────────────────────────────────────────────────────────────────── +# 2. sequences_to_dataframe +# ───────────────────────────────────────────────────────────────────────────── + + +class TestSequencesToDataframe: + _SEQ_SINGLE = "250.00 401.9" + _SEQ_MULTI = f"250.00 401.9 {VISIT_DELIM} 272.0 428.0" + + def test_required_columns_present(self): + df = sequences_to_dataframe([self._SEQ_SINGLE]) + assert set(df.columns) == {"SUBJECT_ID", "HADM_ID", "ICD9_CODE"} + + def test_empty_input_returns_empty_dataframe(self): + df = sequences_to_dataframe([]) + assert df.empty + assert list(df.columns) == [] # pd.concat on empty list → empty DF + + def test_single_visit_produces_correct_codes(self): + df = sequences_to_dataframe([self._SEQ_SINGLE]) + codes = set(df["ICD9_CODE"].tolist()) + assert codes == {"250.00", "401.9"} + + def test_single_visit_single_hadm_id(self): + df = sequences_to_dataframe([self._SEQ_SINGLE]) + assert df["HADM_ID"].nunique() == 1 + assert df["HADM_ID"].iloc[0] == 0 + + def test_multi_visit_hadm_ids(self): + df = sequences_to_dataframe([self._SEQ_MULTI]) + assert set(df["HADM_ID"].tolist()) == {0, 1} + + def test_subject_ids_sequential(self): + df = sequences_to_dataframe([self._SEQ_SINGLE, self._SEQ_SINGLE]) + assert set(df["SUBJECT_ID"].tolist()) == {0, 1} + + def test_multi_patient_subject_id_mapping(self): + df = sequences_to_dataframe([self._SEQ_MULTI, self._SEQ_SINGLE]) + assert df[df["SUBJECT_ID"] == 0]["HADM_ID"].nunique() == 2 + assert df[df["SUBJECT_ID"] == 1]["HADM_ID"].nunique() == 1 + + def test_row_count_matches_codes(self): + # seq has 4 codes across 2 visits + df = sequences_to_dataframe([self._SEQ_MULTI]) + assert len(df) == 4 + + def test_whitespace_only_sequence_returns_empty(self): + df = sequences_to_dataframe([" "]) + assert df.empty + + def test_round_trip_from_samples(self): + samples = [_MULTI_VISIT_SAMPLE] + seqs = samples_to_sequences(samples) + df = sequences_to_dataframe(seqs) + all_codes = set( + code + for 
visit in _MULTI_VISIT_SAMPLE["conditions"] + for code in visit + ) + recovered_codes = set(df["ICD9_CODE"].tolist()) + assert all_codes == recovered_codes + + def test_round_trip_visit_count(self): + samples = [_MULTI_VISIT_SAMPLE] + seqs = samples_to_sequences(samples) + df = sequences_to_dataframe(seqs) + n_visits = df.groupby("SUBJECT_ID")["HADM_ID"].nunique().iloc[0] + assert n_visits == len(_MULTI_VISIT_SAMPLE["conditions"]) + + +# ───────────────────────────────────────────────────────────────────────────── +# 3. build_tokenizer +# ───────────────────────────────────────────────────────────────────────────── + +_CORPUS = [ + "250.00 401.9 VISIT_DELIM 272.0", + "428.0 VISIT_DELIM 250.00", + "401.9 272.0 428.0", +] + + +class TestBuildTokenizer: + @pytest.fixture(scope="class") + def tokenizer(self): + return build_tokenizer(_CORPUS) + + def test_special_tokens_in_vocab(self, tokenizer): + for tok in ("[UNK]", "[PAD]", "[BOS]", "[EOS]"): + assert tok in tokenizer.get_vocab(), f"{tok!r} missing from vocab" + + def test_visit_delim_in_vocab(self, tokenizer): + assert VISIT_DELIM in tokenizer.get_vocab() + + def test_medical_codes_in_vocab(self, tokenizer): + # The Whitespace pre-tokenizer splits on punctuation, so "250.00" becomes + # the sub-tokens ["250", ".", "00"]. Assert each constituent sub-token + # (digits and the dot) appears in the vocabulary instead of the full code. 
+ vocab = tokenizer.get_vocab() + for sub in ["250", "00", "401", "9", "272", "0", "428", "."]: + assert sub in vocab, f"sub-token {sub!r} missing from vocab" + + def test_vocab_size_at_least_corpus_tokens(self, tokenizer): + # 4 special tokens + 5 unique code tokens + VISIT_DELIM = at least 10 + assert len(tokenizer) >= 10 + + def test_bos_eos_token_ids_set(self, tokenizer): + assert tokenizer.bos_token_id is not None + assert tokenizer.eos_token_id is not None + + def test_pad_token_id_set(self, tokenizer): + assert tokenizer.pad_token_id is not None + + def test_encode_includes_bos_eos(self, tokenizer): + ids = tokenizer("250.00 401.9")["input_ids"] + assert ids[0] == tokenizer.bos_token_id + assert ids[-1] == tokenizer.eos_token_id + + def test_encode_decode_roundtrip(self, tokenizer): + text = "250.00 401.9 VISIT_DELIM 272.0" + ids = tokenizer(text, add_special_tokens=True)["input_ids"] + decoded = tokenizer.decode(ids, skip_special_tokens=True) + # The Whitespace pre-tokenizer splits codes on '.', so the round-trip + # produces sub-tokens (e.g. "250 . 00" instead of "250.00"). Verify + # that all digit sub-tokens and the VISIT_DELIM are present. + for sub_token in ["250", "00", "401", "9", VISIT_DELIM, "272", "0"]: + assert sub_token in decoded.split(), f"{sub_token!r} missing from decoded" + + def test_unknown_token_maps_to_unk_id(self, tokenizer): + enc = tokenizer("UNKNOWN_CODE_XYZ")["input_ids"] + # Strip BOS/EOS; the middle token should be [UNK] + inner = enc[1:-1] + assert tokenizer.unk_token_id in inner + + def test_returns_pretrained_tokenizer_fast(self, tokenizer): + from transformers import PreTrainedTokenizerFast + + assert isinstance(tokenizer, PreTrainedTokenizerFast) + + +# ───────────────────────────────────────────────────────────────────────────── +# 4. 
EHRTextDataset +# ───────────────────────────────────────────────────────────────────────────── + +_SEQUENCES = [ + "250.00 401.9 VISIT_DELIM 272.0", + "428.0", + "401.9 272.0 428.0 VISIT_DELIM 250.00 VISIT_DELIM 272.0", +] +_MAX_LENGTH = 16 + + +class TestEHRTextDataset: + @pytest.fixture(scope="class") + def tokenizer(self): + return build_tokenizer(_SEQUENCES) + + @pytest.fixture(scope="class") + def dataset(self, tokenizer): + return EHRTextDataset(_SEQUENCES, tokenizer, max_length=_MAX_LENGTH) + + def test_len_matches_sequences(self, dataset): + assert len(dataset) == len(_SEQUENCES) + + def test_getitem_returns_dict(self, dataset): + item = dataset[0] + assert isinstance(item, dict) + + def test_getitem_has_input_ids_key(self, dataset): + assert "input_ids" in dataset[0] + + def test_getitem_has_labels_key(self, dataset): + assert "labels" in dataset[0] + + def test_input_ids_are_tensors(self, dataset): + assert isinstance(dataset[0]["input_ids"], torch.Tensor) + + def test_labels_are_tensors(self, dataset): + assert isinstance(dataset[0]["labels"], torch.Tensor) + + def test_input_ids_length_equals_max_length(self, dataset): + for i in range(len(dataset)): + assert dataset[i]["input_ids"].shape[0] == _MAX_LENGTH + + def test_labels_equal_input_ids(self, dataset): + item = dataset[0] + assert torch.equal(item["input_ids"], item["labels"]) + + def test_all_items_same_length(self, dataset): + lengths = {dataset[i]["input_ids"].shape[0] for i in range(len(dataset))} + assert len(lengths) == 1 # all padded/truncated to max_length + + def test_empty_sequences_list(self, tokenizer): + ds = EHRTextDataset([], tokenizer, max_length=_MAX_LENGTH) + assert len(ds) == 0 + + def test_single_sequence(self, tokenizer): + ds = EHRTextDataset(["250.00"], tokenizer, max_length=_MAX_LENGTH) + assert len(ds) == 1 + item = ds[0] + assert item["input_ids"].shape[0] == _MAX_LENGTH + + def test_long_sequence_truncated(self, tokenizer): + # Construct a sequence much longer than 
max_length + long_seq = " ".join(["250.00"] * 100) + ds = EHRTextDataset([long_seq], tokenizer, max_length=_MAX_LENGTH) + assert ds[0]["input_ids"].shape[0] == _MAX_LENGTH + + def test_index_out_of_range_raises(self, dataset): + with pytest.raises(IndexError): + _ = dataset[len(_SEQUENCES)] From 69caa1f9ded516f47a107980a5dcf795b84fc900 Mon Sep 17 00:00:00 2001 From: chufangao Date: Sun, 1 Mar 2026 04:50:14 -0600 Subject: [PATCH 2/2] Added EHR generation task and baseline model, with support for nested visit sequences and optional ICD-9 truncation. Updated imports and documentation accordingly. --- .../ehr_generation_mimic3_transformer.py | 323 ++--------- pyhealth/models/__init__.py | 7 + pyhealth/models/generators/__init__.py | 19 + pyhealth/models/generators/gpt_baseline.py | 506 ++++++++++++++++++ tests/core/test_transformer_ehr_helpers.py | 316 +++++------ 5 files changed, 721 insertions(+), 450 deletions(-) create mode 100644 pyhealth/models/generators/__init__.py create mode 100644 pyhealth/models/generators/gpt_baseline.py diff --git a/examples/ehr_generation/ehr_generation_mimic3_transformer.py b/examples/ehr_generation/ehr_generation_mimic3_transformer.py index ab32907f4..dc599551d 100644 --- a/examples/ehr_generation/ehr_generation_mimic3_transformer.py +++ b/examples/ehr_generation/ehr_generation_mimic3_transformer.py @@ -2,8 +2,8 @@ EHR Generation with a GPT-2 style Transformer on MIMIC-III (PyHealth) ====================================================================== -This example trains a GPT-2 style decoder-only transformer to synthesise -longitudinal patient EHR sequences consisting of ICD-9 diagnosis codes. +This example applies the :class:`~pyhealth.models.generators.EHRGPTBaseline` +model to MIMIC-III data and generates synthetic patient EHR sequences. The pipeline: @@ -11,9 +11,9 @@ to obtain per-patient nested visit sequences. 2. Serialise the nested sequences into plain text using ``VISIT_DELIM`` separators (e.g. 
``"250.00 401.9 VISIT_DELIM 272.0 428.0"``). -3. Train a word-level GPT-2 model on the serialised text. -4. Sample synthetic text sequences and deserialise them back to a long-form - ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` DataFrame for downstream evaluation. +3. Train a word-level GPT-2 model via :meth:`EHRGPTBaseline.fit`. +4. Sample synthetic sequences via :meth:`EHRGPTBaseline.generate` and + save the resulting ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` DataFrame. References ---------- @@ -38,166 +38,21 @@ import argparse import os -import pandas as pd import torch -from tokenizers import Tokenizer, models, pre_tokenizers, processors, trainers -from torch.utils.data import Dataset -from tqdm import trange -from transformers import ( - DataCollatorForLanguageModeling, - GPT2Config, - GPT2LMHeadModel, - PreTrainedTokenizerFast, - Trainer, - TrainingArguments, -) from pyhealth.datasets import MIMIC3Dataset, split_by_patient +from pyhealth.models.generators import EHRGPTBaseline, samples_to_sequences from pyhealth.tasks import EHRGenerationMIMIC3 -# ── Constants ───────────────────────────────────────────────────────────────── - -VISIT_DELIM = "VISIT_DELIM" - - -# ── 1. Sequence helpers ──────────────────────────────────────────────────────── - - -def samples_to_sequences(samples: list) -> list[str]: - """Convert PyHealth ``EHRGenerationMIMIC3`` samples to text sequences. - - Each sample's ``conditions`` field is a ``List[List[str]]`` (visits × codes). - Adjacent visits are joined by ``VISIT_DELIM`` so the full patient history - becomes a space-separated string. - - Args: - samples: List of dicts with at least a ``"conditions"`` key. - - Returns: - A list of strings, one per patient, e.g. - ``"250.00 401.9 VISIT_DELIM 272.0 428.0 VISIT_DELIM 250.00"``. 
- """ - sequences = [] - for sample in samples: - visit_texts = [" ".join(visit_codes) for visit_codes in sample["conditions"]] - sequences.append(f" {VISIT_DELIM} ".join(visit_texts)) - return sequences - - -def sequences_to_dataframe(sequences: list[str]) -> pd.DataFrame: - """Deserialise generated text sequences back to long-form ``(SUBJECT_ID, HADM_ID, ICD9_CODE)``. - - Assigns synthetic sequential identifiers; the real MIMIC-III IDs are not - recovered (generation is unconditional). - - Args: - sequences: Generated text sequences from the transformer. - - Returns: - A ``pd.DataFrame`` with columns ``SUBJECT_ID``, ``HADM_ID``, ``ICD9_CODE``. - """ - rows = [] - for subj_idx, seq in enumerate(sequences): - for hadm_idx, visit_str in enumerate(seq.strip().split(VISIT_DELIM)): - for code in visit_str.strip().split(): - if code: - rows.append( - { - "SUBJECT_ID": subj_idx, - "HADM_ID": hadm_idx, - "ICD9_CODE": code, - } - ) - return pd.DataFrame(rows) - - -# ── 2. PyTorch Dataset ───────────────────────────────────────────────────────── - - -class EHRTextDataset(Dataset): - """Tokenises a list of EHR text sequences for causal language modelling. - - Args: - sequences: Plain-text patient sequences (one string per patient). - tokenizer: A HuggingFace ``PreTrainedTokenizerFast``. - max_length: Maximum token length; longer sequences are truncated. - """ - - def __init__( - self, - sequences: list[str], - tokenizer: PreTrainedTokenizerFast, - max_length: int = 512, - ) -> None: - self.input_ids = [] - for txt in sequences: - enc = tokenizer(txt, truncation=True, max_length=max_length, padding="max_length") - self.input_ids.append(torch.tensor(enc["input_ids"])) - - def __len__(self) -> int: - return len(self.input_ids) - - def __getitem__(self, idx: int) -> dict: - return {"input_ids": self.input_ids[idx], "labels": self.input_ids[idx]} - - -# ── 3. 
Tokeniser builder ─────────────────────────────────────────────────────── - - -def build_tokenizer(text_data: list[str]) -> PreTrainedTokenizerFast: - """Build and train a word-level tokeniser on the EHR text corpus. - - Special tokens: - * ``[UNK]`` – unknown token - * ``[PAD]`` – padding - * ``[BOS]`` – beginning-of-sequence - * ``[EOS]`` – end-of-sequence - - The ``VISIT_DELIM`` delimiter token is treated as a regular vocabulary - word so that its visit-boundary semantics are learned by the model. - - Args: - text_data: List of space-separated code sequences. - - Returns: - A ``PreTrainedTokenizerFast`` wrapping the trained word-level tokeniser. - """ - tokenizer_obj = Tokenizer(models.WordLevel(unk_token="[UNK]")) - tokenizer_obj.pre_tokenizer = pre_tokenizers.Whitespace() - - special_tokens = ["[UNK]", "[PAD]", "[BOS]", "[EOS]"] - word_trainer = trainers.WordLevelTrainer(special_tokens=special_tokens) - tokenizer_obj.train_from_iterator(text_data, trainer=word_trainer) - - tokenizer_obj.post_processor = processors.TemplateProcessing( - single="[BOS] $A [EOS]", - special_tokens=[ - ("[BOS]", tokenizer_obj.token_to_id("[BOS]")), - ("[EOS]", tokenizer_obj.token_to_id("[EOS]")), - ], - ) - - return PreTrainedTokenizerFast( - tokenizer_object=tokenizer_obj, - unk_token="[UNK]", - pad_token="[PAD]", - bos_token="[BOS]", - eos_token="[EOS]", - ) - - -# ── 4. 
Main pipeline ─────────────────────────────────────────────────────────── - def main(args: argparse.Namespace) -> None: os.makedirs(args.output_dir, exist_ok=True) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") + print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}") # ------------------------------------------------------------------ # STEP 1: Load MIMIC-III via PyHealth # ------------------------------------------------------------------ - print("\nSTEP 1: Loading MIMIC-III dataset …") + print("\nSTEP 1: Loading MIMIC-III dataset ...") base_dataset = MIMIC3Dataset( root=args.mimic3_root, tables=["diagnoses_icd"], @@ -207,7 +62,7 @@ def main(args: argparse.Namespace) -> None: # ------------------------------------------------------------------ # STEP 2: Apply EHRGenerationMIMIC3 task # ------------------------------------------------------------------ - print("\nSTEP 2: Applying EHRGenerationMIMIC3 task …") + print("\nSTEP 2: Applying EHRGenerationMIMIC3 task ...") task = EHRGenerationMIMIC3( min_visits=args.min_visits, truncate_icd=args.truncate_icd, @@ -221,115 +76,51 @@ def main(args: argparse.Namespace) -> None: # ------------------------------------------------------------------ # STEP 3: Serialise to text sequences # ------------------------------------------------------------------ - print("\nSTEP 3: Serialising patient sequences …") + print("\nSTEP 3: Serialising patient sequences ...") train_samples = list(train_dataset) text_data = samples_to_sequences(train_samples) max_len = max(len(seq.split()) for seq in text_data) - print(f" Max sequence length: {max_len} tokens") - - # ------------------------------------------------------------------ - # STEP 4: Build tokeniser - # ------------------------------------------------------------------ - print("\nSTEP 4: Building word-level tokeniser …") - tokenizer = build_tokenizer(text_data) - print(f" Vocabulary size: {len(tokenizer)}") - 
- train_torch_dataset = EHRTextDataset(text_data, tokenizer, max_length=args.max_seq_len) + print(f" Max sequence length (tokens): {max_len}") # ------------------------------------------------------------------ - # STEP 5: Initialise GPT-2 style decoder model + # STEP 4 - 6: Build tokeniser, initialise GPT-2, train # ------------------------------------------------------------------ - print("\nSTEP 5: Initialising GPT-2 model …") - config = GPT2Config( - vocab_size=len(tokenizer), - n_positions=args.max_seq_len, - n_ctx=args.max_seq_len, - n_embd=512, - n_layer=8, - n_head=8, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - ) - model = GPT2LMHeadModel(config).to(device) - num_params = sum(p.numel() for p in model.parameters()) / 1e6 - print(f" Model parameters: {num_params:.1f}M") + print("\nSTEP 4-6: Building tokeniser and training GPT-2 ...") + model = EHRGPTBaseline( + n_embd=args.n_embd, + n_layer=args.n_layer, + n_head=args.n_head, + max_seq_len=args.max_seq_len, + ) + model.fit( + sequences=text_data, + output_dir=args.output_dir, + epochs=args.epochs, + batch_size=args.batch_size, + ) + n_params = sum(p.numel() for p in model.model.parameters()) / 1e6 + print(f" Vocabulary size : {len(model.tokenizer)}") + print(f" Model parameters: {n_params:.1f}M") # ------------------------------------------------------------------ - # STEP 6: Train + # STEP 7: Generate synthetic EHRs # ------------------------------------------------------------------ - print("\nSTEP 6: Training …") - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) - - training_args = TrainingArguments( - output_dir=os.path.join(args.output_dir, "checkpoints"), - overwrite_output_dir=True, - num_train_epochs=args.epochs, - per_device_train_batch_size=args.batch_size, - logging_steps=50, - learning_rate=1e-4, - lr_scheduler_type="cosine", - warmup_steps=100, - use_cpu=not torch.cuda.is_available(), - save_strategy="epoch", + print(f"\nSTEP 7: 
Generating {args.num_synthetic} synthetic patients ...") + synthetic_df = model.generate( + n_patients=args.num_synthetic, + batch_size=args.gen_batch_size, ) - - hf_trainer = Trainer( - model=model, - args=training_args, - data_collator=data_collator, - train_dataset=train_torch_dataset, + print( + f" Generated {synthetic_df['SUBJECT_ID'].nunique()} patients, " + f"{synthetic_df.shape[0]} (patient, visit, code) rows" ) - hf_trainer.train() - - model_save_path = os.path.join(args.output_dir, "transformer_ehr_model") - hf_trainer.save_model(model_save_path) - print(f" Model saved to: {model_save_path}") - - # ------------------------------------------------------------------ - # STEP 7: Generate synthetic EHRs - # ------------------------------------------------------------------ - print(f"\nSTEP 7: Generating {args.num_synthetic} synthetic patients …") - model.eval() - - all_syn: list[pd.DataFrame] = [] - start_subj_id = 0 - for batch_start in trange(0, args.num_synthetic, args.gen_batch_size): - batch_end = min(batch_start + args.gen_batch_size, args.num_synthetic) - bsz = batch_end - batch_start - - batch_input_ids = torch.tensor( - [[tokenizer.bos_token_id]] * bsz, device=device - ) - with torch.no_grad(): - generated = model.generate( - batch_input_ids, - max_new_tokens=args.max_seq_len, - do_sample=True, - top_k=50, - top_p=0.95, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - ) - - decoded = [ - tokenizer.decode(seq, skip_special_tokens=True) for seq in generated - ] - batch_df = sequences_to_dataframe(decoded) - batch_df["SUBJECT_ID"] += start_subj_id - start_subj_id += bsz - all_syn.append(batch_df) - - synthetic_df = pd.concat(all_syn, ignore_index=True) - print(f" Generated {synthetic_df['SUBJECT_ID'].nunique()} patients, " - f"{synthetic_df.shape[0]} (patient, visit, code) rows") out_csv = os.path.join(args.output_dir, "synthetic_ehr.csv") synthetic_df.to_csv(out_csv, index=False) print(f" Synthetic data saved to: 
{out_csv}") -# ── CLI entry point ──────────────────────────────────────────────────────────── - +# -- CLI entry point ----------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -340,54 +131,38 @@ def main(args: argparse.Namespace) -> None: "--mimic3_root", type=str, required=True, - help="Path to the MIMIC-III root directory containing raw CSV/CSV.GZ files.", + help="Path to the MIMIC-III root directory (raw CSV/CSV.GZ files).", ) parser.add_argument( "--output_dir", type=str, default="./ehr_generation_output", - help="Directory to save the trained model and synthetic data.", + help="Directory to save checkpoints and synthetic data.", ) parser.add_argument( "--min_visits", type=int, default=1, - help="Minimum number of valid admissions a patient must have.", + help="Minimum valid admissions a patient must have.", ) parser.add_argument( "--truncate_icd", action="store_true", default=False, - help="Truncate ICD-9 codes to 3-digit prefixes (reduces vocab to ~1,071 codes).", - ) - parser.add_argument( - "--max_seq_len", - type=int, - default=512, - help="Maximum token sequence length.", - ) - parser.add_argument( - "--epochs", - type=int, - default=50, - help="Number of training epochs.", + help="Truncate ICD-9 codes to 3-digit prefixes.", ) + parser.add_argument("--n_embd", type=int, default=512, help="Embedding dimension.") + parser.add_argument("--n_layer", type=int, default=8, help="Number of transformer layers.") + parser.add_argument("--n_head", type=int, default=8, help="Number of attention heads.") parser.add_argument( - "--batch_size", - type=int, - default=64, - help="Training batch size.", + "--max_seq_len", type=int, default=512, help="Maximum token sequence length." 
) + parser.add_argument("--epochs", type=int, default=50, help="Training epochs.") + parser.add_argument("--batch_size", type=int, default=64, help="Training batch size.") parser.add_argument( - "--num_synthetic", - type=int, - default=10000, - help="Number of synthetic patients to generate.", + "--num_synthetic", type=int, default=10000, help="Synthetic patients to generate." ) parser.add_argument( - "--gen_batch_size", - type=int, - default=512, - help="Generation batch size.", + "--gen_batch_size", type=int, default=512, help="Generation batch size." ) main(parser.parse_args()) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..b29cd3b36 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -43,3 +43,10 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .generators import ( + EHRGPTBaseline, + EHRTextDataset, + build_tokenizer, + samples_to_sequences, + sequences_to_dataframe, +) diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py new file mode 100644 index 000000000..b2e5c1ca5 --- /dev/null +++ b/pyhealth/models/generators/__init__.py @@ -0,0 +1,19 @@ +"""Generative models for synthetic EHR generation.""" + +from .gpt_baseline import ( + VISIT_DELIM, + EHRGPTBaseline, + EHRTextDataset, + build_tokenizer, + samples_to_sequences, + sequences_to_dataframe, +) + +__all__ = [ + "VISIT_DELIM", + "EHRGPTBaseline", + "EHRTextDataset", + "build_tokenizer", + "samples_to_sequences", + "sequences_to_dataframe", +] diff --git a/pyhealth/models/generators/gpt_baseline.py b/pyhealth/models/generators/gpt_baseline.py new file mode 100644 index 000000000..4daf91d69 --- /dev/null +++ b/pyhealth/models/generators/gpt_baseline.py @@ -0,0 +1,506 @@ +""" +GPT-2 Baseline for Synthetic EHR Generation +============================================ + +This module provides a 
self-contained GPT-2 decoder-only language model for +generating synthetic longitudinal EHR sequences composed of ICD-9 diagnosis +codes. + +Design +------ +Patient histories are first serialised as space-separated code sequences where +consecutive visits are separated by the special ``VISIT_DELIM`` token: + + ``"250.00 401.9 VISIT_DELIM 272.0 428.0 VISIT_DELIM 250.00"`` + +A word-level HuggingFace tokeniser is then trained on this corpus, and a +GPT-2 causal language model is trained from scratch on the token IDs. At +inference time, sequences are sampled autoregressively and deserialised back +to a long-form ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` DataFrame. + +References +---------- +- *Accelerating Reproducible Research in Synthetic EHR Generation* (CHIL 2026) + +Typical usage +------------- +.. code-block:: python + + from pyhealth.models.generators import EHRGPTBaseline + from pyhealth.models.generators import samples_to_sequences + + model = EHRGPTBaseline(n_embd=256, n_layer=4, n_head=4) + model.fit(text_sequences, output_dir="./checkpoints", epochs=20) + synthetic_df = model.generate(n_patients=1000) +""" + +import os +from typing import Optional + +import pandas as pd +import torch +import torch.nn as nn +from tokenizers import Tokenizer, models, pre_tokenizers, processors, trainers +from torch.utils.data import Dataset +from tqdm import trange +from transformers import ( + DataCollatorForLanguageModeling, + GPT2Config, + GPT2LMHeadModel, + PreTrainedTokenizerFast, + Trainer, + TrainingArguments, +) + +__all__ = [ + "VISIT_DELIM", + "samples_to_sequences", + "sequences_to_dataframe", + "build_tokenizer", + "EHRTextDataset", + "EHRGPTBaseline", +] + +# ── Constants ────────────────────────────────────────────────────────────────── + +VISIT_DELIM = "VISIT_DELIM" + + +# ── Sequence helpers ─────────────────────────────────────────────────────────── + + +def samples_to_sequences(samples: list) -> list[str]: + """Convert ``EHRGenerationMIMIC3`` samples to 
VISIT_DELIM-delimited text. + + Each sample's ``conditions`` field is a ``List[List[str]]`` (visits × codes). + Adjacent visits are joined by ``VISIT_DELIM`` so the full patient history + becomes a single space-separated string. + + Args: + samples: List of dicts with at least a ``"conditions"`` key. + + Returns: + One string per patient, e.g. + ``"250.00 401.9 VISIT_DELIM 272.0 428.0 VISIT_DELIM 250.00"``. + + Examples: + >>> samples = [{"conditions": [["250.00", "401.9"], ["272.0"]]}] + >>> samples_to_sequences(samples) + ['250.00 401.9 VISIT_DELIM 272.0'] + """ + sequences: list[str] = [] + for sample in samples: + visit_texts = [" ".join(visit_codes) for visit_codes in sample["conditions"]] + sequences.append(f" {VISIT_DELIM} ".join(visit_texts)) + return sequences + + +def sequences_to_dataframe(sequences: list[str]) -> pd.DataFrame: + """Deserialise generated text sequences to long-form EHR rows. + + Assigns synthetic sequential identifiers; original MIMIC-III IDs are not + preserved (generation is unconditional). + + Args: + sequences: Generated text sequences from :meth:`EHRGPTBaseline.generate`. + + Returns: + A ``pd.DataFrame`` with columns ``SUBJECT_ID``, ``HADM_ID``, + ``ICD9_CODE``. + + Examples: + >>> sequences_to_dataframe(["250.00 VISIT_DELIM 401.9"]) + SUBJECT_ID HADM_ID ICD9_CODE + 0 0 0 250.00 + 1 0 1 401.9 + """ + rows: list[dict] = [] + for subj_idx, seq in enumerate(sequences): + for hadm_idx, visit_str in enumerate(seq.strip().split(VISIT_DELIM)): + for code in visit_str.strip().split(): + if code: + rows.append( + { + "SUBJECT_ID": subj_idx, + "HADM_ID": hadm_idx, + "ICD9_CODE": code, + } + ) + return pd.DataFrame(rows) + + +# ── Tokeniser ────────────────────────────────────────────────────────────────── + + +def build_tokenizer(text_data: list[str]) -> PreTrainedTokenizerFast: + """Build and train a word-level tokeniser on an EHR text corpus. + + Uses the HuggingFace ``tokenizers`` library. 
Special tokens: + + * ``[UNK]`` – unknown token + * ``[PAD]`` – padding + * ``[BOS]`` – beginning-of-sequence + * ``[EOS]`` – end-of-sequence + + ``VISIT_DELIM`` is treated as a regular vocabulary word so the model + learns its visit-boundary semantics. + + Args: + text_data: List of space-separated code sequences (one per patient). + + Returns: + A ``PreTrainedTokenizerFast`` wrapping the trained word-level model. + + Note: + The ``Whitespace`` pre-tokeniser splits on punctuation, so ICD-9 codes + such as ``"250.00"`` are stored as the sub-tokens ``["250", ".", "00"]``. + This is intentional (smaller vocabulary), but decoded text then reads + ``"250 . 00"`` and :func:`sequences_to_dataframe` emits a row per piece. + """ + tokenizer_obj = Tokenizer(models.WordLevel(unk_token="[UNK]")) + tokenizer_obj.pre_tokenizer = pre_tokenizers.Whitespace() + + special_tokens = ["[UNK]", "[PAD]", "[BOS]", "[EOS]"] + word_trainer = trainers.WordLevelTrainer(special_tokens=special_tokens) + tokenizer_obj.train_from_iterator(text_data, trainer=word_trainer) + + tokenizer_obj.post_processor = processors.TemplateProcessing( + single="[BOS] $A [EOS]", + special_tokens=[ + ("[BOS]", tokenizer_obj.token_to_id("[BOS]")), + ("[EOS]", tokenizer_obj.token_to_id("[EOS]")), + ], + ) + + return PreTrainedTokenizerFast( + tokenizer_object=tokenizer_obj, + unk_token="[UNK]", + pad_token="[PAD]", + bos_token="[BOS]", + eos_token="[EOS]", + ) + + +# ── PyTorch Dataset ──────────────────────────────────────────────────────────── + + +class EHRTextDataset(Dataset): + """Tokenises EHR text sequences for causal language-model training. + + Each sequence is tokenised, truncated/padded to ``max_length``, and stored + as a fixed-length ``LongTensor``. The ``labels`` field mirrors + ``input_ids`` so the HuggingFace ``Trainer`` can compute the standard + next-token prediction loss. + + Args: + sequences: Plain-text patient sequences (one string per patient). + tokenizer: A trained :class:`~transformers.PreTrainedTokenizerFast`. 
+ max_length: Token budget; longer sequences are right-truncated. + + Examples: + >>> from pyhealth.models.generators import build_tokenizer, EHRTextDataset + >>> tok = build_tokenizer(["250.00 VISIT_DELIM 401.9"]) + >>> ds = EHRTextDataset(["250.00 VISIT_DELIM 401.9"], tok, max_length=16) + >>> len(ds) + 1 + >>> ds[0]["input_ids"].shape + torch.Size([16]) + """ + + def __init__( + self, + sequences: list[str], + tokenizer: PreTrainedTokenizerFast, + max_length: int = 512, + ) -> None: + self.input_ids: list[torch.Tensor] = [] + for txt in sequences: + enc = tokenizer( + txt, + truncation=True, + max_length=max_length, + padding="max_length", + ) + self.input_ids.append(torch.tensor(enc["input_ids"])) + + def __len__(self) -> int: + return len(self.input_ids) + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + ids = self.input_ids[idx] + return {"input_ids": ids, "labels": ids} + + +# ── Main model class ─────────────────────────────────────────────────────────── + + +class EHRGPTBaseline(nn.Module): + """GPT-2 decoder-only language model for synthetic EHR generation. + + Wraps a HuggingFace ``GPT2LMHeadModel`` and exposes a high-level API + (:meth:`fit`, :meth:`generate`) that matches the training pipeline + described in *Accelerating Reproducible Research in Synthetic EHR + Generation* (CHIL 2026). + + Architecture + ------------ + * Word-level ICD-9 tokeniser (``VISIT_DELIM`` as vocabulary entry) + * GPT-2 transformer decoder with configurable depth and width + * Autoregressive next-token prediction objective + + Args: + n_embd: Embedding and hidden dimension. Default: 512. + n_layer: Number of transformer decoder layers. Default: 8. + n_head: Number of self-attention heads. Default: 8. + max_seq_len: Maximum token sequence length. Default: 512. + + Attributes: + tokenizer: The fitted :class:`~transformers.PreTrainedTokenizerFast` + (``None`` until :meth:`fit` is called). 
+ model: The underlying :class:`~transformers.GPT2LMHeadModel` + (``None`` until :meth:`fit` is called). + + Examples: + .. code-block:: python + + from pyhealth.models.generators import EHRGPTBaseline, samples_to_sequences + + gpt = EHRGPTBaseline(n_embd=256, n_layer=4, n_head=4) + gpt.fit(text_sequences, output_dir="./ckpt", epochs=10, batch_size=32) + df = gpt.generate(n_patients=500) + """ + + def __init__( + self, + n_embd: int = 512, + n_layer: int = 8, + n_head: int = 8, + max_seq_len: int = 512, + ) -> None: + super().__init__() + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.max_seq_len = max_seq_len + + # Populated by fit() + self.tokenizer: Optional[PreTrainedTokenizerFast] = None + self.model: Optional[GPT2LMHeadModel] = None + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward(self, input_ids: torch.Tensor, **kwargs) -> dict: + """Run a forward pass through the GPT-2 model. + + Delegates directly to :class:`~transformers.GPT2LMHeadModel`. + + Args: + input_ids: Token ID tensor of shape ``(batch, seq_len)``. + **kwargs: Additional arguments forwarded to GPT2LMHeadModel. + + Returns: + The ``CausalLMOutputWithCrossAttentions`` dict-like object from + HuggingFace (contains ``logits``, ``loss`` when ``labels`` are + supplied, etc.). + + Raises: + RuntimeError: If called before :meth:`fit`. 
+ """ + if self.model is None: + raise RuntimeError("Call fit() before forward().") + return self.model(input_ids, **kwargs) + + # ------------------------------------------------------------------ + # fit + # ------------------------------------------------------------------ + + def fit( + self, + sequences: list[str], + output_dir: str = "./ehr_gpt_output", + epochs: int = 50, + batch_size: int = 64, + learning_rate: float = 1e-4, + warmup_steps: int = 100, + ) -> "EHRGPTBaseline": + """Build the tokeniser, initialise GPT-2, and train on ``sequences``. + + This method is idempotent: calling it again re-initialises the + tokeniser and model from scratch. + + Args: + sequences: List of VISIT_DELIM-delimited patient text sequences + produced by :func:`samples_to_sequences`. + output_dir: Directory for HuggingFace ``Trainer`` checkpoints. + epochs: Training epochs. + batch_size: Per-device training batch size. + learning_rate: Peak learning rate (cosine schedule). + warmup_steps: Linear warm-up steps. + + Returns: + ``self`` (fluent API). 
+ """ + os.makedirs(output_dir, exist_ok=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # ── tokeniser ───────────────────────────────────────────────────────── + self.tokenizer = build_tokenizer(sequences) + + # ── dataset ─────────────────────────────────────────────────────────── + train_ds = EHRTextDataset(sequences, self.tokenizer, max_length=self.max_seq_len) + + # ── model ───────────────────────────────────────────────────────────── + config = GPT2Config( + vocab_size=len(self.tokenizer), + n_positions=self.max_seq_len, + n_ctx=self.max_seq_len, + n_embd=self.n_embd, + n_layer=self.n_layer, + n_head=self.n_head, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + self.model = GPT2LMHeadModel(config).to(device) + + # ── training ────────────────────────────────────────────────────────── + data_collator = DataCollatorForLanguageModeling( + tokenizer=self.tokenizer, mlm=False + ) + + training_args = TrainingArguments( + output_dir=os.path.join(output_dir, "checkpoints"), + overwrite_output_dir=True, + num_train_epochs=epochs, + per_device_train_batch_size=batch_size, + logging_steps=50, + learning_rate=learning_rate, + lr_scheduler_type="cosine", + warmup_steps=warmup_steps, + use_cpu=not torch.cuda.is_available(), + save_strategy="epoch", + ) + + hf_trainer = Trainer( + model=self.model, + args=training_args, + data_collator=data_collator, + train_dataset=train_ds, + ) + hf_trainer.train() + + # Persist model and tokeniser side-by-side + model_dir = os.path.join(output_dir, "gpt_ehr_model") + hf_trainer.save_model(model_dir) + self.tokenizer.save_pretrained(model_dir) + + return self + + # ------------------------------------------------------------------ + # generate + # ------------------------------------------------------------------ + + def generate( + self, + n_patients: int = 1000, + batch_size: int = 512, + top_k: int = 50, + top_p: float = 0.95, + ) -> pd.DataFrame: + 
"""Sample synthetic EHR sequences and return a long-form DataFrame. + + Args: + n_patients: Number of synthetic patients to generate. + batch_size: Generation batch size (GPU memory permitting). + top_k: Top-k sampling parameter. + top_p: Nucleus sampling probability threshold. + + Returns: + A ``pd.DataFrame`` with columns ``SUBJECT_ID``, ``HADM_ID``, + ``ICD9_CODE``. + + Raises: + RuntimeError: If called before :meth:`fit`. + """ + if self.model is None or self.tokenizer is None: + raise RuntimeError("Call fit() before generate().") + + device = next(self.model.parameters()).device + self.model.eval() + + all_dfs: list[pd.DataFrame] = [] + start_subj = 0 + + for batch_start in trange(0, n_patients, batch_size, desc="Generating"): + bsz = min(batch_size, n_patients - batch_start) + prompt = torch.tensor( + [[self.tokenizer.bos_token_id]] * bsz, device=device + ) + with torch.no_grad(): + generated = self.model.generate( + prompt, + max_new_tokens=self.max_seq_len, + do_sample=True, + top_k=top_k, + top_p=top_p, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + decoded = [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in generated + ] + batch_df = sequences_to_dataframe(decoded) + batch_df["SUBJECT_ID"] += start_subj + start_subj += bsz + all_dfs.append(batch_df) + + return pd.concat(all_dfs, ignore_index=True) + + # ------------------------------------------------------------------ + # Persistence helpers + # ------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Save the GPT-2 model weights and tokeniser to ``path``. + + Args: + path: Directory to save into (created if absent). 
+ """ + if self.model is None or self.tokenizer is None: + raise RuntimeError("Nothing to save – call fit() first.") + os.makedirs(path, exist_ok=True) + self.model.save_pretrained(path) + self.tokenizer.save_pretrained(path) + + @classmethod + def load(cls, path: str, **init_kwargs) -> "EHRGPTBaseline": + """Load a previously saved :class:`EHRGPTBaseline` from ``path``. + + Args: + path: Directory created by :meth:`save`. + **init_kwargs: Forwarded to ``__init__`` (overrides defaults for + ``n_embd``, ``n_layer``, ``n_head``, ``max_seq_len``). + + Returns: + A fully initialised :class:`EHRGPTBaseline` ready for + :meth:`generate`. + """ + instance = cls(**init_kwargs) + instance.tokenizer = PreTrainedTokenizerFast.from_pretrained(path) + instance.model = GPT2LMHeadModel.from_pretrained(path) + return instance + + # ------------------------------------------------------------------ + # repr + # ------------------------------------------------------------------ + + def __repr__(self) -> str: # pragma: no cover + fitted = self.model is not None + return ( + f"EHRGPTBaseline(" + f"n_embd={self.n_embd}, n_layer={self.n_layer}, " + f"n_head={self.n_head}, max_seq_len={self.max_seq_len}, " + f"fitted={fitted})" + ) diff --git a/tests/core/test_transformer_ehr_helpers.py b/tests/core/test_transformer_ehr_helpers.py index 33bf870aa..e93e11dc3 100644 --- a/tests/core/test_transformer_ehr_helpers.py +++ b/tests/core/test_transformer_ehr_helpers.py @@ -1,27 +1,20 @@ """ Tests for the utility functions and classes defined in -examples/ehr_generation/ehr_generation_mimic3_transformer.py +pyhealth/models/generators/gpt_baseline.py Covered: -* ``samples_to_sequences`` – nested visit lists → text strings -* ``sequences_to_dataframe`` – text strings → long-form DataFrame -* ``build_tokenizer`` – word-level HuggingFace tokenizer -* ``EHRTextDataset`` – PyTorch Dataset wrapping tokenized EHR sequences +* ``samples_to_sequences`` – nested visit lists -> text strings +* 
``sequences_to_dataframe`` – text strings -> long-form DataFrame +* ``build_tokenizer`` – word-level HuggingFace tokenizer +* ``EHRTextDataset`` – PyTorch Dataset wrapping tokenized EHR sequences """ -import sys -import os +import unittest -import pytest import torch +from transformers import PreTrainedTokenizerFast -# Allow importing directly from the examples directory without installing it. -sys.path.insert( - 0, - os.path.join(os.path.dirname(__file__), "../../examples/ehr_generation"), -) - -from ehr_generation_mimic3_transformer import ( # noqa: E402 +from pyhealth.models.generators import ( VISIT_DELIM, EHRTextDataset, build_tokenizer, @@ -29,280 +22,251 @@ sequences_to_dataframe, ) -# ───────────────────────────────────────────────────────────────────────────── -# Shared fixtures -# ───────────────────────────────────────────────────────────────────────────── +# ── shared test data ─────────────────────────────────────────────────────────── _SINGLE_VISIT_SAMPLE = {"conditions": [["250.00", "401.9"]]} _MULTI_VISIT_SAMPLE = {"conditions": [["250.00", "401.9"], ["272.0", "428.0"], ["250.00"]]} -_EMPTY_VISIT_SAMPLE = {"conditions": []} # patient with no visits +_EMPTY_VISIT_SAMPLE = {"conditions": []} + +_CORPUS = [ + "250.00 401.9 VISIT_DELIM 272.0", + "428.0 VISIT_DELIM 250.00", + "401.9 272.0 428.0", +] + +_SEQUENCES = [ + "250.00 401.9 VISIT_DELIM 272.0", + "428.0", + "401.9 272.0 428.0 VISIT_DELIM 250.00 VISIT_DELIM 272.0", +] +_MAX_LENGTH = 16 -# ───────────────────────────────────────────────────────────────────────────── -# 1. samples_to_sequences -# ───────────────────────────────────────────────────────────────────────────── +# ── 1. 
samples_to_sequences ──────────────────────────────────────────────────── -class TestSamplesToSequences: +class TestSamplesToSequences(unittest.TestCase): def test_returns_one_string_per_sample(self): - samples = [_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE] - result = samples_to_sequences(samples) - assert len(result) == 2 + result = samples_to_sequences([_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE]) + self.assertEqual(len(result), 2) def test_empty_input_returns_empty_list(self): - assert samples_to_sequences([]) == [] + self.assertEqual(samples_to_sequences([]), []) def test_single_visit_no_delimiter(self): result = samples_to_sequences([_SINGLE_VISIT_SAMPLE]) - assert VISIT_DELIM not in result[0] + self.assertNotIn(VISIT_DELIM, result[0]) def test_multi_visit_delimiter_count_matches(self): - # 3 visits → 2 VISIT_DELIM occurrences + # 3 visits -> 2 VISIT_DELIM occurrences result = samples_to_sequences([_MULTI_VISIT_SAMPLE]) - assert result[0].count(VISIT_DELIM) == 2 + self.assertEqual(result[0].count(VISIT_DELIM), 2) def test_codes_present_in_output(self): result = samples_to_sequences([_MULTI_VISIT_SAMPLE]) for visit in _MULTI_VISIT_SAMPLE["conditions"]: for code in visit: - assert code in result[0] + self.assertIn(code, result[0]) def test_single_visit_codes_space_separated(self): result = samples_to_sequences([_SINGLE_VISIT_SAMPLE]) - assert result[0] == "250.00 401.9" + self.assertEqual(result[0], "250.00 401.9") def test_multi_visit_format(self): result = samples_to_sequences([_MULTI_VISIT_SAMPLE]) expected = f"250.00 401.9 {VISIT_DELIM} 272.0 428.0 {VISIT_DELIM} 250.00" - assert result[0] == expected + self.assertEqual(result[0], expected) def test_empty_conditions_yields_empty_string(self): result = samples_to_sequences([_EMPTY_VISIT_SAMPLE]) - assert result[0] == "" + self.assertEqual(result[0], "") def test_single_code_per_visit(self): sample = {"conditions": [["A"], ["B"], ["C"]]} result = samples_to_sequences([sample]) - assert result[0] == f"A 
{VISIT_DELIM} B {VISIT_DELIM} C" + self.assertEqual(result[0], f"A {VISIT_DELIM} B {VISIT_DELIM} C") def test_multiple_samples_independent(self): - samples = [_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE] - result = samples_to_sequences(samples) - assert result[0] != result[1] + result = samples_to_sequences([_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE]) + self.assertNotEqual(result[0], result[1]) def test_output_is_list_of_strings(self): result = samples_to_sequences([_SINGLE_VISIT_SAMPLE]) - assert isinstance(result, list) - assert all(isinstance(s, str) for s in result) + self.assertIsInstance(result, list) + for s in result: + self.assertIsInstance(s, str) -# ───────────────────────────────────────────────────────────────────────────── -# 2. sequences_to_dataframe -# ───────────────────────────────────────────────────────────────────────────── +# ── 2. sequences_to_dataframe ───────────────────────────────────────────────── -class TestSequencesToDataframe: +class TestSequencesToDataframe(unittest.TestCase): _SEQ_SINGLE = "250.00 401.9" _SEQ_MULTI = f"250.00 401.9 {VISIT_DELIM} 272.0 428.0" def test_required_columns_present(self): df = sequences_to_dataframe([self._SEQ_SINGLE]) - assert set(df.columns) == {"SUBJECT_ID", "HADM_ID", "ICD9_CODE"} + self.assertEqual(set(df.columns), {"SUBJECT_ID", "HADM_ID", "ICD9_CODE"}) def test_empty_input_returns_empty_dataframe(self): df = sequences_to_dataframe([]) - assert df.empty - assert list(df.columns) == [] # pd.concat on empty list → empty DF + self.assertTrue(df.empty) + self.assertEqual(list(df.columns), []) def test_single_visit_produces_correct_codes(self): df = sequences_to_dataframe([self._SEQ_SINGLE]) - codes = set(df["ICD9_CODE"].tolist()) - assert codes == {"250.00", "401.9"} + self.assertEqual(set(df["ICD9_CODE"].tolist()), {"250.00", "401.9"}) def test_single_visit_single_hadm_id(self): df = sequences_to_dataframe([self._SEQ_SINGLE]) - assert df["HADM_ID"].nunique() == 1 - assert df["HADM_ID"].iloc[0] == 0 + 
self.assertEqual(df["HADM_ID"].nunique(), 1) + self.assertEqual(df["HADM_ID"].iloc[0], 0) def test_multi_visit_hadm_ids(self): df = sequences_to_dataframe([self._SEQ_MULTI]) - assert set(df["HADM_ID"].tolist()) == {0, 1} + self.assertEqual(set(df["HADM_ID"].tolist()), {0, 1}) def test_subject_ids_sequential(self): df = sequences_to_dataframe([self._SEQ_SINGLE, self._SEQ_SINGLE]) - assert set(df["SUBJECT_ID"].tolist()) == {0, 1} + self.assertEqual(set(df["SUBJECT_ID"].tolist()), {0, 1}) def test_multi_patient_subject_id_mapping(self): df = sequences_to_dataframe([self._SEQ_MULTI, self._SEQ_SINGLE]) - assert df[df["SUBJECT_ID"] == 0]["HADM_ID"].nunique() == 2 - assert df[df["SUBJECT_ID"] == 1]["HADM_ID"].nunique() == 1 + self.assertEqual(df[df["SUBJECT_ID"] == 0]["HADM_ID"].nunique(), 2) + self.assertEqual(df[df["SUBJECT_ID"] == 1]["HADM_ID"].nunique(), 1) def test_row_count_matches_codes(self): - # seq has 4 codes across 2 visits df = sequences_to_dataframe([self._SEQ_MULTI]) - assert len(df) == 4 + self.assertEqual(len(df), 4) def test_whitespace_only_sequence_returns_empty(self): df = sequences_to_dataframe([" "]) - assert df.empty + self.assertTrue(df.empty) def test_round_trip_from_samples(self): - samples = [_MULTI_VISIT_SAMPLE] - seqs = samples_to_sequences(samples) + seqs = samples_to_sequences([_MULTI_VISIT_SAMPLE]) df = sequences_to_dataframe(seqs) - all_codes = set( - code - for visit in _MULTI_VISIT_SAMPLE["conditions"] - for code in visit - ) - recovered_codes = set(df["ICD9_CODE"].tolist()) - assert all_codes == recovered_codes + all_codes = {c for visit in _MULTI_VISIT_SAMPLE["conditions"] for c in visit} + self.assertEqual(all_codes, set(df["ICD9_CODE"].tolist())) def test_round_trip_visit_count(self): - samples = [_MULTI_VISIT_SAMPLE] - seqs = samples_to_sequences(samples) + seqs = samples_to_sequences([_MULTI_VISIT_SAMPLE]) df = sequences_to_dataframe(seqs) n_visits = df.groupby("SUBJECT_ID")["HADM_ID"].nunique().iloc[0] - assert n_visits == 
len(_MULTI_VISIT_SAMPLE["conditions"]) + self.assertEqual(n_visits, len(_MULTI_VISIT_SAMPLE["conditions"])) -# ───────────────────────────────────────────────────────────────────────────── -# 3. build_tokenizer -# ───────────────────────────────────────────────────────────────────────────── - -_CORPUS = [ - "250.00 401.9 VISIT_DELIM 272.0", - "428.0 VISIT_DELIM 250.00", - "401.9 272.0 428.0", -] +# ── 3. build_tokenizer ──────────────────────────────────────────────────────── -class TestBuildTokenizer: - @pytest.fixture(scope="class") - def tokenizer(self): - return build_tokenizer(_CORPUS) +class TestBuildTokenizer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.tokenizer = build_tokenizer(_CORPUS) - def test_special_tokens_in_vocab(self, tokenizer): + def test_special_tokens_in_vocab(self): + vocab = self.tokenizer.get_vocab() for tok in ("[UNK]", "[PAD]", "[BOS]", "[EOS]"): - assert tok in tokenizer.get_vocab(), f"{tok!r} missing from vocab" + self.assertIn(tok, vocab, f"{tok!r} missing from vocab") - def test_visit_delim_in_vocab(self, tokenizer): - assert VISIT_DELIM in tokenizer.get_vocab() + def test_visit_delim_in_vocab(self): + self.assertIn(VISIT_DELIM, self.tokenizer.get_vocab()) - def test_medical_codes_in_vocab(self, tokenizer): - # The Whitespace pre-tokenizer splits on punctuation, so "250.00" becomes - # the sub-tokens ["250", ".", "00"]. Assert each constituent sub-token - # (digits and the dot) appears in the vocabulary instead of the full code. 
- vocab = tokenizer.get_vocab() + def test_medical_codes_in_vocab(self): + # Whitespace pre-tokenizer splits "250.00" -> ["250", ".", "00"] + vocab = self.tokenizer.get_vocab() for sub in ["250", "00", "401", "9", "272", "0", "428", "."]: - assert sub in vocab, f"sub-token {sub!r} missing from vocab" + self.assertIn(sub, vocab, f"sub-token {sub!r} missing from vocab") - def test_vocab_size_at_least_corpus_tokens(self, tokenizer): - # 4 special tokens + 5 unique code tokens + VISIT_DELIM = at least 10 - assert len(tokenizer) >= 10 + def test_vocab_size_at_least_corpus_tokens(self): + self.assertGreaterEqual(len(self.tokenizer), 10) - def test_bos_eos_token_ids_set(self, tokenizer): - assert tokenizer.bos_token_id is not None - assert tokenizer.eos_token_id is not None + def test_bos_eos_token_ids_set(self): + self.assertIsNotNone(self.tokenizer.bos_token_id) + self.assertIsNotNone(self.tokenizer.eos_token_id) - def test_pad_token_id_set(self, tokenizer): - assert tokenizer.pad_token_id is not None + def test_pad_token_id_set(self): + self.assertIsNotNone(self.tokenizer.pad_token_id) - def test_encode_includes_bos_eos(self, tokenizer): - ids = tokenizer("250.00 401.9")["input_ids"] - assert ids[0] == tokenizer.bos_token_id - assert ids[-1] == tokenizer.eos_token_id + def test_encode_includes_bos_eos(self): + ids = self.tokenizer("250.00 401.9")["input_ids"] + self.assertEqual(ids[0], self.tokenizer.bos_token_id) + self.assertEqual(ids[-1], self.tokenizer.eos_token_id) - def test_encode_decode_roundtrip(self, tokenizer): + def test_encode_decode_roundtrip(self): text = "250.00 401.9 VISIT_DELIM 272.0" - ids = tokenizer(text, add_special_tokens=True)["input_ids"] - decoded = tokenizer.decode(ids, skip_special_tokens=True) - # The Whitespace pre-tokenizer splits codes on '.', so the round-trip - # produces sub-tokens (e.g. "250 . 00" instead of "250.00"). Verify - # that all digit sub-tokens and the VISIT_DELIM are present. 
- for sub_token in ["250", "00", "401", "9", VISIT_DELIM, "272", "0"]: - assert sub_token in decoded.split(), f"{sub_token!r} missing from decoded" + ids = self.tokenizer(text, add_special_tokens=True)["input_ids"] + decoded = self.tokenizer.decode(ids, skip_special_tokens=True) + # Whitespace splits on '.', so check sub-tokens + for sub in ["250", "00", "401", "9", VISIT_DELIM, "272", "0"]: + self.assertIn(sub, decoded.split(), f"{sub!r} missing from decoded") - def test_unknown_token_maps_to_unk_id(self, tokenizer): - enc = tokenizer("UNKNOWN_CODE_XYZ")["input_ids"] - # Strip BOS/EOS; the middle token should be [UNK] - inner = enc[1:-1] - assert tokenizer.unk_token_id in inner + def test_unknown_token_maps_to_unk_id(self): + enc = self.tokenizer("UNKNOWN_CODE_XYZ")["input_ids"] + inner = enc[1:-1] # strip BOS/EOS + self.assertIn(self.tokenizer.unk_token_id, inner) - def test_returns_pretrained_tokenizer_fast(self, tokenizer): - from transformers import PreTrainedTokenizerFast + def test_returns_pretrained_tokenizer_fast(self): + self.assertIsInstance(self.tokenizer, PreTrainedTokenizerFast) - assert isinstance(tokenizer, PreTrainedTokenizerFast) +# ── 4. EHRTextDataset ───────────────────────────────────────────────────────── -# ───────────────────────────────────────────────────────────────────────────── -# 4. 
EHRTextDataset -# ───────────────────────────────────────────────────────────────────────────── -_SEQUENCES = [ - "250.00 401.9 VISIT_DELIM 272.0", - "428.0", - "401.9 272.0 428.0 VISIT_DELIM 250.00 VISIT_DELIM 272.0", -] -_MAX_LENGTH = 16 +class TestEHRTextDataset(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.tokenizer = build_tokenizer(_SEQUENCES) + cls.dataset = EHRTextDataset(_SEQUENCES, cls.tokenizer, max_length=_MAX_LENGTH) + def test_len_matches_sequences(self): + self.assertEqual(len(self.dataset), len(_SEQUENCES)) -class TestEHRTextDataset: - @pytest.fixture(scope="class") - def tokenizer(self): - return build_tokenizer(_SEQUENCES) + def test_getitem_returns_dict(self): + self.assertIsInstance(self.dataset[0], dict) - @pytest.fixture(scope="class") - def dataset(self, tokenizer): - return EHRTextDataset(_SEQUENCES, tokenizer, max_length=_MAX_LENGTH) + def test_getitem_has_input_ids_key(self): + self.assertIn("input_ids", self.dataset[0]) - def test_len_matches_sequences(self, dataset): - assert len(dataset) == len(_SEQUENCES) + def test_getitem_has_labels_key(self): + self.assertIn("labels", self.dataset[0]) - def test_getitem_returns_dict(self, dataset): - item = dataset[0] - assert isinstance(item, dict) + def test_input_ids_are_tensors(self): + self.assertIsInstance(self.dataset[0]["input_ids"], torch.Tensor) - def test_getitem_has_input_ids_key(self, dataset): - assert "input_ids" in dataset[0] + def test_labels_are_tensors(self): + self.assertIsInstance(self.dataset[0]["labels"], torch.Tensor) - def test_getitem_has_labels_key(self, dataset): - assert "labels" in dataset[0] + def test_input_ids_length_equals_max_length(self): + for i in range(len(self.dataset)): + self.assertEqual(self.dataset[i]["input_ids"].shape[0], _MAX_LENGTH) - def test_input_ids_are_tensors(self, dataset): - assert isinstance(dataset[0]["input_ids"], torch.Tensor) + def test_labels_equal_input_ids(self): + item = self.dataset[0] + 
self.assertTrue(torch.equal(item["input_ids"], item["labels"])) - def test_labels_are_tensors(self, dataset): - assert isinstance(dataset[0]["labels"], torch.Tensor) + def test_all_items_same_length(self): + lengths = {self.dataset[i]["input_ids"].shape[0] for i in range(len(self.dataset))} + self.assertEqual(len(lengths), 1) - def test_input_ids_length_equals_max_length(self, dataset): - for i in range(len(dataset)): - assert dataset[i]["input_ids"].shape[0] == _MAX_LENGTH + def test_empty_sequences_list(self): + ds = EHRTextDataset([], self.tokenizer, max_length=_MAX_LENGTH) + self.assertEqual(len(ds), 0) - def test_labels_equal_input_ids(self, dataset): - item = dataset[0] - assert torch.equal(item["input_ids"], item["labels"]) + def test_single_sequence(self): + ds = EHRTextDataset(["250.00"], self.tokenizer, max_length=_MAX_LENGTH) + self.assertEqual(len(ds), 1) + self.assertEqual(ds[0]["input_ids"].shape[0], _MAX_LENGTH) - def test_all_items_same_length(self, dataset): - lengths = {dataset[i]["input_ids"].shape[0] for i in range(len(dataset))} - assert len(lengths) == 1 # all padded/truncated to max_length - - def test_empty_sequences_list(self, tokenizer): - ds = EHRTextDataset([], tokenizer, max_length=_MAX_LENGTH) - assert len(ds) == 0 + def test_long_sequence_truncated(self): + long_seq = " ".join(["250.00"] * 100) + ds = EHRTextDataset([long_seq], self.tokenizer, max_length=_MAX_LENGTH) + self.assertEqual(ds[0]["input_ids"].shape[0], _MAX_LENGTH) - def test_single_sequence(self, tokenizer): - ds = EHRTextDataset(["250.00"], tokenizer, max_length=_MAX_LENGTH) - assert len(ds) == 1 - item = ds[0] - assert item["input_ids"].shape[0] == _MAX_LENGTH + def test_index_out_of_range_raises(self): + with self.assertRaises(IndexError): + _ = self.dataset[len(_SEQUENCES)] - def test_long_sequence_truncated(self, tokenizer): - # Construct a sequence much longer than max_length - long_seq = " ".join(["250.00"] * 100) - ds = EHRTextDataset([long_seq], tokenizer, 
max_length=_MAX_LENGTH) - assert ds[0]["input_ids"].shape[0] == _MAX_LENGTH - def test_index_out_of_range_raises(self, dataset): - with pytest.raises(IndexError): - _ = dataset[len(_SEQUENCES)] +if __name__ == "__main__": + unittest.main()