diff --git a/pySEQTarget/expansion/_binder.py b/pySEQTarget/expansion/_binder.py index 727e0e6..bba0e45 100644 --- a/pySEQTarget/expansion/_binder.py +++ b/pySEQTarget/expansion/_binder.py @@ -95,4 +95,15 @@ def _binder(self, kept_cols): .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col]) ) + # Truncate each (id, trial) at the first outcome event so that subjects who + # experience the outcome early are not carried forward with subsequent rows. + DT = DT.filter( + pl.col(self.outcome_col) + .fill_null(0) + .cum_max() + .shift(1, fill_value=0) + .over([self.id_col, "trial"]) + == 0 + ) + return DT diff --git a/pyproject.toml b/pyproject.toml index 1e7dac8..3c97a2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.8" +version = "0.12.9" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} @@ -25,7 +25,7 @@ authors = [ {name = "Ryan O'Dea", email = "ryan.odea@psi.ch"}, {name = "Alejandro Szmulewicz", email = "aszmulewicz@hsph.harvard.edu"}, {name = "Tom Palmer", email = "tom.palmer@bristol.ac.uk"}, - {name = "Miguel Hernan", email = "mhernan@hsph.harvard.edu"}, + {name = "Miguel Hernán", email = "mhernan@hsph.harvard.edu"}, ] maintainers = [ diff --git a/tests/test_expansion.py b/tests/test_expansion.py new file mode 100644 index 0000000..e9e81cb --- /dev/null +++ b/tests/test_expansion.py @@ -0,0 +1,92 @@ +import polars as pl + +from pySEQTarget import SEQuential + + +def _make_model(data): + return SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="treatment", + outcome_col="outcome", + time_varying_cols=[], + fixed_cols=[], + ) + + +def test_expansion_truncates_at_first_outcome(): + """Expansion should truncate each (id, trial) at and including the first + outcome=1 row. Rows from later periods must not appear.""" + data = pl.DataFrame( + { + "ID": [1, 1, 1, 1, 1], + "time": [0, 1, 2, 3, 4], + "eligible": [1, 0, 0, 0, 0], + # Both treatment values required by default treatment_level=[0,1] + "treatment": [0, 1, 0, 1, 0], + # outcome=1 at time=2, but original data continues with outcome=0 + "outcome": [0, 0, 1, 0, 0], + } + ) + + model = _make_model(data) + model.expand() + + # Only trial 0 exists (eligible only at time=0). + # Should have followup 0, 1, 2 — the outcome=1 row is included but not beyond. + followups = sorted(model.DT["followup"].to_list()) + assert followups == [0, 1, 2] + + # The outcome=1 row must be present + outcome_row = model.DT.filter(pl.col("outcome") == 1) + assert len(outcome_row) == 1 + assert int(outcome_row["followup"][0]) == 2 + + +def test_expansion_does_not_truncate_without_outcome(): + """Subjects who never experience the outcome should retain all expanded rows.""" + data = pl.DataFrame( + { + "ID": [1, 1, 1, 1, 1], + "time": [0, 1, 2, 3, 4], + "eligible": [1, 0, 0, 0, 0], + "treatment": [0, 1, 0, 1, 0], + "outcome": [0, 0, 0, 0, 0], + } + ) + + model = _make_model(data) + model.expand() + + # All 5 followup periods should be present + followups = sorted(model.DT["followup"].to_list()) + assert followups == [0, 1, 2, 3, 4] + + +def test_expansion_truncates_each_trial_independently(): + """Truncation must apply per (id, trial), not globally. A subject enrolled in + multiple trials should have each trial truncated at its own first outcome.""" + data = pl.DataFrame( + { + "ID": [1, 1, 1, 1, 1], + "time": [0, 1, 2, 3, 4], + "eligible": [1, 1, 0, 0, 0], + "treatment": [0, 1, 0, 1, 0], + # outcome=1 at time=3: trial 0 sees it at followup=3, trial 1 at followup=2 + "outcome": [0, 0, 0, 1, 0], + } + ) + + model = _make_model(data) + model.expand() + + trial_0 = model.DT.filter(pl.col("trial") == 0) + trial_1 = model.DT.filter(pl.col("trial") == 1) + + # Trial 0 starts at time=0, outcome at time=3 → followup 0,1,2,3 + assert sorted(trial_0["followup"].to_list()) == [0, 1, 2, 3] + + # Trial 1 starts at time=1, outcome at time=3 → followup 0,1,2 + assert sorted(trial_1["followup"].to_list()) == [0, 1, 2]