From ea357bfd0d70b3cdf21b336a2fd070d43ae8d47d Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 19 May 2026 11:33:05 +0100 Subject: [PATCH 1/7] Reduce the height and padding of the piccolo_theme top navigation bar --- docs/_static/custom.css | 8 ++++++++ docs/conf.py | 1 + 2 files changed, 9 insertions(+) create mode 100644 docs/_static/custom.css diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 0000000..23a7d8b --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,8 @@ +/* Reduce the height of the piccolo_theme top navigation bar */ +:root { + --navbarHeight: 3.25rem; +} + +div#top_nav nav { + padding: 0.7rem 1rem; +} diff --git a/docs/conf.py b/docs/conf.py index eaa3ba4..aa8fa75 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,3 +50,4 @@ html_theme = "piccolo_theme" html_static_path = ["_static"] +html_css_files = ["custom.css"] From 3f4f1293eb9d06c0021106ac4ca50d2badbdbe76 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 21 May 2026 06:41:18 +0100 Subject: [PATCH 2/7] Add risk_times option --- pySEQTarget/SEQopts.py | 13 ++ pySEQTarget/analysis/_risk_estimates.py | 253 ++++++++++++++---------- 2 files changed, 164 insertions(+), 102 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 6a6ae88..4392b7e 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -78,6 +78,10 @@ class SEQopts: :type plot_title: str :param plot_type: Type of plot to show ["risk", "survival" or "incidence" if compevent is specified] :type plot_type: str + :param risk_times: Followup times at which to report risk difference and risk ratio when ``km_curves = True``. + Each requested time is snapped to the latest available followup at or before it, and the maximum + followup is always included. Defaults to ``None`` (report at the maximum followup only). 
+ :type risk_times: Optional[List[float]] :param seed: RNG seed :type seed: int :param selection_first_trial: Boolean to only use first trial for analysis (similar to non-expanded) @@ -150,6 +154,7 @@ class SEQopts: plot_labels: List[str] = field(default_factory=lambda: []) plot_title: str = None plot_type: Literal["risk", "survival", "incidence"] = "survival" + risk_times: Optional[List[float]] = None seed: Optional[int] = None selection_first_trial: bool = False selection_sample: float = 0.8 @@ -210,6 +215,14 @@ def _validate_ranges(self): raise ValueError( f"followup_min ({self.followup_min}) must be less than followup_max ({self.followup_max})." ) + if self.risk_times is not None: + times = ( + self.risk_times + if isinstance(self.risk_times, (list, tuple)) + else [self.risk_times] + ) + if any(not isinstance(t, (int, float)) or t < 0 for t in times): + raise ValueError("risk_times values must be non-negative numbers.") def _validate_choices(self): if self.plot_type not in ["risk", "survival", "incidence"]: diff --git a/pySEQTarget/analysis/_risk_estimates.py b/pySEQTarget/analysis/_risk_estimates.py index 561179c..d8a6a0a 100644 --- a/pySEQTarget/analysis/_risk_estimates.py +++ b/pySEQTarget/analysis/_risk_estimates.py @@ -73,10 +73,42 @@ def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None): return rd_comp, rr_comp +def _resolve_risk_times(grid, risk_times): + """ + Snap each requested risk time to the latest available followup at or before + it, always including the maximum followup. Returns a sorted list of followup + values that exist in ``grid``. Raises ``ValueError`` if a requested time is above the maximum or below the minimum followup in ``grid``. 
+ """ + grid = sorted(set(grid)) + final = grid[-1] + + if risk_times is None: + return [final] + + req = risk_times if isinstance(risk_times, (list, tuple)) else [risk_times] + req = [float(t) for t in req if t is not None] + if not req: + return [final] + + above = [t for t in req if t > final] + if above: + raise ValueError( + f"risk_times value(s) exceed the maximum followup ({final}): {above}" + ) + below = [t for t in req if t < grid[0]] + if below: + raise ValueError( + f"risk_times value(s) below the minimum followup ({grid[0]}): {below}" + ) + + snapped = [max(g for g in grid if g <= t) for t in req] + return sorted(set(snapped + [final])) + + def _risk_estimates(self): - last_followup = self.km_data["followup"].max() - risk = self.km_data.filter( - (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk") + risk_all = self.km_data.filter(pl.col("estimate") == "risk") + report_times = _resolve_risk_times( + risk_all["followup"].unique().to_list(), self.risk_times ) group_cols = [self.subgroup_colname] if self.subgroup_colname else [] @@ -101,115 +133,132 @@ def _risk_estimates(self): z = None alpha = None - risk_by_level = {} - for tx in self.treatment_level: - level_data = risk.filter(pl.col(self.treatment_col) == tx) - risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])} - if has_bootstrap and not use_paired: - risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"]) - rd_comparisons = [] rr_comparisons = [] - for tx_x in self.treatment_level: - for tx_y in self.treatment_level: - if tx_x == tx_y: - continue - - if use_paired: - boot_x = ( - self._boot_risks[tx_x] - .filter(pl.col("followup") == last_followup) - .select(["boot_idx", pl.col("risk").alias("risk_x")]) - ) - boot_y = ( - self._boot_risks[tx_y] - .filter(pl.col("followup") == last_followup) - .select(["boot_idx", pl.col("risk").alias("risk_y")]) - ) - paired = boot_x.join(boot_y, on="boot_idx").with_columns( - (pl.col("risk_x") - 
pl.col("risk_y")).alias("RD") - ) - - risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0]) - risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0]) - rd_point = risk_x_val - risk_y_val - rr_point = risk_x_val / risk_y_val if risk_y_val != 0 else float("inf") - - # Filter degenerate RR bootstrap values (risk_y == 0 or negative) - valid_rr = paired.filter( - (pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0) - ).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR")) - - n_valid_rr = len(valid_rr) - - if self.bootstrap_CI_method == "percentile": - rd_lci = float(paired["RD"].quantile(alpha / 2)) - rd_uci = float(paired["RD"].quantile(1 - alpha / 2)) - if n_valid_rr >= 2: - rr_lci = float(valid_rr["RR"].quantile(alpha / 2)) - rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2)) - else: - rr_lci = float("nan") - rr_uci = float("nan") - else: - rd_se = float(paired["RD"].std()) - rd_lci = rd_point - z * rd_se - rd_uci = rd_point + z * rd_se - if n_valid_rr >= 2 and rr_point > 0: - log_rr_se = float(valid_rr["RR"].log().std()) - rr_lci = math.exp(math.log(rr_point) - z * log_rr_se) - rr_uci = math.exp(math.log(rr_point) + z * log_rr_se) + for followup_t in report_times: + risk = risk_all.filter(pl.col("followup") == followup_t) + + risk_by_level = {} + for tx in self.treatment_level: + level_data = risk.filter(pl.col(self.treatment_col) == tx) + risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])} + if has_bootstrap and not use_paired: + risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"]) + + for tx_x in self.treatment_level: + for tx_y in self.treatment_level: + if tx_x == tx_y: + continue + + if use_paired: + boot_x = ( + self._boot_risks[tx_x] + .filter(pl.col("followup") == followup_t) + .select(["boot_idx", pl.col("risk").alias("risk_x")]) + ) + boot_y = ( + self._boot_risks[tx_y] + .filter(pl.col("followup") == followup_t) + .select(["boot_idx", pl.col("risk").alias("risk_y")]) + ) + paired = boot_x.join(boot_y, 
on="boot_idx").with_columns( + (pl.col("risk_x") - pl.col("risk_y")).alias("RD") + ) + + risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0]) + risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0]) + rd_point = risk_x_val - risk_y_val + rr_point = ( + risk_x_val / risk_y_val if risk_y_val != 0 else float("inf") + ) + + # Filter degenerate RR bootstrap values (risk_y == 0 or negative) + valid_rr = paired.filter( + (pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0) + ).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR")) + + n_valid_rr = len(valid_rr) + + if self.bootstrap_CI_method == "percentile": + rd_lci = float(paired["RD"].quantile(alpha / 2)) + rd_uci = float(paired["RD"].quantile(1 - alpha / 2)) + if n_valid_rr >= 2: + rr_lci = float(valid_rr["RR"].quantile(alpha / 2)) + rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2)) + else: + rr_lci = float("nan") + rr_uci = float("nan") else: - rr_lci = float("nan") - rr_uci = float("nan") - - rd_comp = pl.DataFrame( - { - "A_x": [tx_x], - "A_y": [tx_y], - "Risk Difference": [rd_point], - "RD 95% LCI": [rd_lci], - "RD 95% UCI": [rd_uci], - } - ) - rr_comp = pl.DataFrame( - { - "A_x": [tx_x], - "A_y": [tx_y], - "Risk Ratio": [rr_point], - "RR 95% LCI": [rr_lci], - "RR 95% UCI": [rr_uci], - } - ) - else: - # Fall back to independent delta method - risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"}) - risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"}) - - if group_cols: - comp = risk_x.join(risk_y, on=group_cols, how="left") - else: - comp = risk_x.join(risk_y, how="cross") + rd_se = float(paired["RD"].std()) + rd_lci = rd_point - z * rd_se + rd_uci = rd_point + z * rd_se + if n_valid_rr >= 2 and rr_point > 0: + log_rr_se = float(valid_rr["RR"].log().std()) + rr_lci = math.exp(math.log(rr_point) - z * log_rr_se) + rr_uci = math.exp(math.log(rr_point) + z * log_rr_se) + else: + rr_lci = float("nan") + rr_uci = float("nan") - comp = comp.with_columns( - 
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")] - ) + rd_comp = pl.DataFrame( + { + "Followup": [followup_t], + "A_x": [tx_x], + "A_y": [tx_y], + "Risk Difference": [rd_point], + "RD 95% LCI": [rd_lci], + "RD 95% UCI": [rd_uci], + } + ) + rr_comp = pl.DataFrame( + { + "Followup": [followup_t], + "A_x": [tx_x], + "A_y": [tx_y], + "Risk Ratio": [rr_point], + "RR 95% LCI": [rr_lci], + "RR 95% UCI": [rr_uci], + } + ) + else: + # Fall back to independent delta method + risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"}) + risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"}) - if has_bootstrap: - se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"}) - se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"}) if group_cols: - comp = comp.join(se_x, on=group_cols, how="left") - comp = comp.join(se_y, on=group_cols, how="left") + comp = risk_x.join(risk_y, on=group_cols, how="left") else: - comp = comp.join(se_x, how="cross") - comp = comp.join(se_y, how="cross") + comp = risk_x.join(risk_y, how="cross") + + comp = comp.with_columns( + [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")] + ) + + if has_bootstrap: + se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"}) + se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"}) + if group_cols: + comp = comp.join(se_x, on=group_cols, how="left") + comp = comp.join(se_y, on=group_cols, how="left") + else: + comp = comp.join(se_x, how="cross") + comp = comp.join(se_y, how="cross") - rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols) + rd_comp, rr_comp = _compute_rd_rr( + comp, has_bootstrap, z, group_cols + ) + rd_cols = rd_comp.columns + rr_cols = rr_comp.columns + rd_comp = rd_comp.with_columns( + pl.lit(followup_t).alias("Followup") + ).select(["Followup"] + rd_cols) + rr_comp = rr_comp.with_columns( + pl.lit(followup_t).alias("Followup") + ).select(["Followup"] + rr_cols) - rd_comparisons.append(rd_comp) - rr_comparisons.append(rr_comp) + rd_comparisons.append(rd_comp) + 
rr_comparisons.append(rr_comp) risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame() risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame() From 3b16106727044f80e6f09fb7d990f94151695922 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 21 May 2026 06:41:33 +0100 Subject: [PATCH 3/7] Add risk_times tests --- tests/test_survival.py | 114 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/test_survival.py b/tests/test_survival.py index d9a577f..183dc32 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,11 +1,125 @@ import os +import polars as pl import pytest from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data +def _final_followup(s): + return s.km_data.filter(pl.col("estimate") == "risk")["followup"].max() + + +def test_risk_times_reports_requested_followups(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + km_curves=True, risk_times=[2, 5], bootstrap_nboot=3, seed=42 + ), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + + final = _final_followup(s) + rd = s.risk_estimates["risk_difference"] + rr = s.risk_estimates["risk_ratio"] + + assert "Followup" in rd.columns + assert set(rd["Followup"].to_list()) == {2, 5, final} + assert set(rr["Followup"].to_list()) == {2, 5, final} + for col in ["RD 95% LCI", "RD 95% UCI"]: + assert rd[col].null_count() == 0 + + +def test_risk_times_default_reports_only_final(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True), + ) + 
s.expand() + s.fit() + s.survival() + + assert set(s.risk_estimates["risk_difference"]["Followup"].to_list()) == { + _final_followup(s) + } + + +def test_risk_times_snaps_to_grid(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, risk_times=[2.5]), + ) + s.expand() + s.fit() + s.survival() + + # 2.5 snaps down to 2; final followup is always included + assert set(s.risk_estimates["risk_difference"]["Followup"].to_list()) == { + 2, + _final_followup(s), + } + + +def test_risk_times_exceeding_max_raises(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, risk_times=[1e6]), + ) + s.expand() + s.fit() + with pytest.raises(ValueError, match="maximum followup"): + s.survival() + + +def test_risk_times_negative_rejected(): + with pytest.raises(ValueError, match="non-negative"): + SEQopts(km_curves=True, risk_times=[-1]) + + def test_regular_survival(): data = load_data("SEQdata") From fde1d6c42d77ee510f4776673a4820a2bccbc890 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 21 May 2026 06:41:41 +0100 Subject: [PATCH 4/7] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0eb1343..0f4c31f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.4" +version = "0.13.5" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From a9294da1b1aaab471aaf672b8753ad3f0b06991e Mon Sep 17 
00:00:00 2001 From: Tom Palmer Date: Thu, 21 May 2026 07:35:58 +0100 Subject: [PATCH 5/7] Ensure stable factor encoding for categorical time-varying covariates in the outcome model --- pySEQTarget/analysis/_outcome_fit.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index ed049a4..af91176 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -1,9 +1,11 @@ import re import numpy as np +import pandas as pd import polars as pl import statsmodels.api as sm import statsmodels.formula.api as smf +from pandas.api.types import is_numeric_dtype def _compute_spline_knots(followup_arr, df=3): @@ -40,6 +42,20 @@ def _apply_spline_formula(formula, indicator_squared, spline_knots): return spline +def _categorical_tv_columns(self, df_pd): + """ + Names of the categorical (non-numeric) time-varying covariate columns + present in ``df_pd``, including their baseline (``indicator_baseline``) + versions used by the outcome model. + """ + cols = [] + for col in self.time_varying_cols or []: + for variant in (col, f"{col}{self.indicator_baseline}"): + if variant in df_pd.columns and not is_numeric_dtype(df_pd[variant]): + cols.append(variant) + return cols + + def _cast_categories(self, df_pd): if self.treatment_col in df_pd.columns: df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category") @@ -58,6 +74,22 @@ def _cast_categories(self, df_pd): if col in df_pd.columns: df_pd[col] = df_pd[col].astype("category") + # Stable factor encoding for categorical time-varying covariates: fix the + # level set from the full expanded data (captured on the non-bootstrap + # pass) so a bootstrap resample cannot realise a different set of levels — + # otherwise a level absent from the resample would be unknown to that fit + # and crash counterfactual prediction with NaNs. 
+ tv_cat_cols = _categorical_tv_columns(self, df_pd) + if getattr(self, "_current_boot_idx", None) is None: + cats = getattr(self, "_covariate_categories", {}) + for col in tv_cat_cols: + cats[col] = sorted(df_pd[col].dropna().unique().tolist()) + self._covariate_categories = cats + cats = getattr(self, "_covariate_categories", {}) + for col in tv_cat_cols: + if col in cats: + df_pd[col] = pd.Categorical(df_pd[col], categories=cats[col]) + return df_pd From e54f6091d091574b447b57f83b0fe210b0ca200d Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 21 May 2026 07:36:18 +0100 Subject: [PATCH 6/7] Add test for stable factor encoding in the outcome model --- tests/test_categorical_covariates.py | 72 ++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_categorical_covariates.py diff --git a/tests/test_categorical_covariates.py b/tests/test_categorical_covariates.py new file mode 100644 index 0000000..5aa4019 --- /dev/null +++ b/tests/test_categorical_covariates.py @@ -0,0 +1,72 @@ +import numpy as np +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _model(data, **opts): + return SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P", "grp"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, **opts), + ) + + +def test_string_time_varying_covariate_bootstrap(): + """A categorical (string) time-varying covariate should run through the full + bootstrap pipeline and produce risk estimates.""" + data = load_data("SEQdata") + rng = np.random.RandomState(1) + data = data.with_columns( + pl.Series("grp", rng.choice(["a", "b", "c"], size=data.height)) + ) + + s = _model(data, bootstrap_nboot=3, seed=42) + s.expand() + s.bootstrap() + s.fit() + s.survival() + + assert "grp_bas" in s.DT.columns + rd = s.risk_estimates["risk_difference"] + 
assert rd.height > 0 + assert rd.select(["RD 95% LCI", "RD 95% UCI"]).null_count().to_series().sum() == 0 + + +def test_rare_level_not_dropped_by_bootstrap_resample(): + """A rare categorical level absent from some bootstrap resamples must not + crash counterfactual prediction. The full-data level set is fixed at fit + time so every resample shares a stable factor encoding.""" + data = load_data("SEQdata") + ids = data["ID"].unique().to_list() + rng = np.random.RandomState(0) + # Level "c" appears for a single ID only, so aggressive subsampling will + # produce resamples that omit it entirely. + grp = pl.Series( + "grp", + np.where( + np.isin(data["ID"].to_numpy(), [ids[0]]), + "c", + rng.choice(["a", "b"], size=data.height), + ), + ) + data = data.with_columns(grp) + + s = _model(data, bootstrap_nboot=8, bootstrap_sample=0.5, seed=7) + s.expand() + s.bootstrap() + s.fit() + s.survival() # previously raised ValueError on NaN predictions + + rd = s.risk_estimates["risk_difference"] + assert rd.height > 0 + assert not rd["RD 95% LCI"].is_nan().any() + assert not rd["RD 95% UCI"].is_nan().any() From 0793a95be4f98d40a680c832ee291c576b908f11 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 21 May 2026 19:48:40 +0100 Subject: [PATCH 7/7] Fix off-by-one in survival/risk followup grid _calculate_risk predicted the discrete hazard on followup = 1..followup_max, which silently dropped the first interval's hazard (followup = 0, where an event can already occur in the expanded data) and ended the curve one step short. The first survival step used h(1) instead of h(0), shifting every point and biasing risks downward. Predict on the full grid (0..followup_max) and shift the curve labels +1 after the cumulative product so followup = k means "survival/risk after k elapsed intervals", giving rows 0..followup_max+1. The shift is applied before _store_boot_risks and the bootstrap CI join, so _boot_risks, _resolve_risk_times, and the paired RD/RR computation stay aligned. 
This matches the SEQTaRget (R) survival output. --- pySEQTarget/analysis/_survival_pred.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index ac9a4ef..c9e17a7 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -132,8 +132,13 @@ def _calculate_risk(self, data, idx=None, val=None): lci = a / 2 uci = 1 - lci - # Pre-compute the followup range once (starts at 1, not 0) - followup_range = list(range(1, self.followup_max + 1)) + # Predict the hazard on the full followup grid starting at 0 — the first + # interval of every trial, where an event can already occur. Curve labels + # are shifted +1 after the cumulative product (below) so that followup=k + # means "survival/risk after k elapsed intervals", giving rows + # 0..followup_max+1. This matches SEQTaRget (R); starting the grid at 1 + # silently dropped the first interval's hazard and ended one step short. + followup_range = list(range(0, self.followup_max + 1)) SDT = ( data.with_columns( @@ -223,6 +228,7 @@ def _calculate_risk(self, data, idx=None, val=None): TxDT.group_by("followup") .agg([pl.col(col).mean() for col in surv_names + inc_names]) .sort("followup") + .with_columns(pl.col("followup") + 1) ) main_col = "surv" boot_cols = [col for col in surv_names if col != "surv"] @@ -242,6 +248,7 @@ def _calculate_risk(self, data, idx=None, val=None): .agg([pl.col(col).mean() for col in outcome_names]) .sort("followup") .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names]) + .with_columns(pl.col("followup") + 1) ) main_col = "pred_outcome" boot_cols = [col for col in outcome_names if col != "pred_outcome"]