Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pySEQTarget/expansion/_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,15 @@ def _binder(self, kept_cols):
.drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col])
)

# Truncate each (id, trial) at the first outcome event so that subjects who
# experience the outcome early are not carried forward with subsequent rows.
DT = DT.filter(
pl.col(self.outcome_col)
.fill_null(0)
.cum_max()
.shift(1, fill_value=0)
.over([self.id_col, "trial"])
== 0
)

return DT
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.8"
version = "0.12.9"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand All @@ -25,7 +25,7 @@ authors = [
{name = "Ryan O'Dea", email = "ryan.odea@psi.ch"},
{name = "Alejandro Szmulewicz", email = "aszmulewicz@hsph.harvard.edu"},
{name = "Tom Palmer", email = "tom.palmer@bristol.ac.uk"},
{name = "Miguel Hernan", email = "mhernan@hsph.harvard.edu"},
{name = "Miguel Hernán", email = "mhernan@hsph.harvard.edu"},
]

maintainers = [
Expand Down
92 changes: 92 additions & 0 deletions tests/test_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import polars as pl

from pySEQTarget import SEQuential


def _make_model(data):
return SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="treatment",
outcome_col="outcome",
time_varying_cols=[],
fixed_cols=[],
)


def test_expansion_truncates_at_first_outcome():
"""Expansion should truncate each (id, trial) at and including the first
outcome=1 row. Rows from later periods must not appear."""
data = pl.DataFrame(
{
"ID": [1, 1, 1, 1, 1],
"time": [0, 1, 2, 3, 4],
"eligible": [1, 0, 0, 0, 0],
# Both treatment values required by default treatment_level=[0,1]
"treatment": [0, 1, 0, 1, 0],
# outcome=1 at time=2, but original data continues with outcome=0
"outcome": [0, 0, 1, 0, 0],
}
)

model = _make_model(data)
model.expand()

# Only trial 0 exists (eligible only at time=0).
# Should have followup 0, 1, 2 — the outcome=1 row is included but not beyond.
followups = sorted(model.DT["followup"].to_list())
assert followups == [0, 1, 2]

# The outcome=1 row must be present
outcome_row = model.DT.filter(pl.col("outcome") == 1)
assert len(outcome_row) == 1
assert int(outcome_row["followup"][0]) == 2


def test_expansion_does_not_truncate_without_outcome():
"""Subjects who never experience the outcome should retain all expanded rows."""
data = pl.DataFrame(
{
"ID": [1, 1, 1, 1, 1],
"time": [0, 1, 2, 3, 4],
"eligible": [1, 0, 0, 0, 0],
"treatment": [0, 1, 0, 1, 0],
"outcome": [0, 0, 0, 0, 0],
}
)

model = _make_model(data)
model.expand()

# All 5 followup periods should be present
followups = sorted(model.DT["followup"].to_list())
assert followups == [0, 1, 2, 3, 4]


def test_expansion_truncates_each_trial_independently():
"""Truncation must apply per (id, trial), not globally. A subject enrolled in
multiple trials should have each trial truncated at its own first outcome."""
data = pl.DataFrame(
{
"ID": [1, 1, 1, 1, 1],
"time": [0, 1, 2, 3, 4],
"eligible": [1, 1, 0, 0, 0],
"treatment": [0, 1, 0, 1, 0],
# outcome=1 at time=3: trial 0 sees it at followup=3, trial 1 at followup=2
"outcome": [0, 0, 0, 1, 0],
}
)

model = _make_model(data)
model.expand()

trial_0 = model.DT.filter(pl.col("trial") == 0)
trial_1 = model.DT.filter(pl.col("trial") == 1)

# Trial 0 starts at time=0, outcome at time=3 → followup 0,1,2,3
assert sorted(trial_0["followup"].to_list()) == [0, 1, 2, 3]

# Trial 1 starts at time=1, outcome at time=3 → followup 0,1,2
assert sorted(trial_1["followup"].to_list()) == [0, 1, 2]
Loading