diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py
index c4fc496..6a6ae88 100644
--- a/pySEQTarget/SEQopts.py
+++ b/pySEQTarget/SEQopts.py
@@ -15,7 +15,7 @@ class SEQopts:
     :type bootstrap_sample: float
     :param bootstrap_CI: If bootstrapped, confidence interval level
     :type bootstrap_CI: float
-    :param bootstrap_CI_method: If bootstrapped, confidence method generation method ['SE' or 'percentile']
+    :param bootstrap_CI_method: If bootstrapped, confidence interval method ['SE' or 'percentile']
     :type bootstrap_CI_method: str
     :param cense_colname: Column name for censoring effect (LTFU, etc.)
     :type cense_colname: str
@@ -109,6 +109,8 @@ class SEQopts:
     :type weight_p99: bool
     :param weight_preexpansion: Boolean to fit weights on preexpanded data
     :type weight_preexpansion: bool
+    :param verbose: Boolean to print dataset size summaries and bootstrap information
+    :type verbose: bool
     :param weighted: Boolean to weight analysis
     :type weighted: bool
     """
@@ -163,6 +165,7 @@ class SEQopts:
     weight_lag_condition: bool = True
     weight_p99: bool = False
     weight_preexpansion: bool = True
+    verbose: bool = False
     weighted: bool = False
 
     def _validate_bools(self):
@@ -178,6 +181,7 @@ def _validate_bools(self):
             "selection_first_trial",
             "selection_random",
             "trial_include",
+            "verbose",
             "weight_lag_condition",
             "weight_p99",
             "weight_preexpansion",
diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py
index cdc11eb..0bff491 100644
--- a/pySEQTarget/SEQoutput.py
+++ b/pySEQTarget/SEQoutput.py
@@ -62,7 +62,7 @@ class SEQoutput:
 
     def plot(self) -> None:
         """
-        Displays the kaplan-meier graph
+        Displays the Kaplan-Meier graph
         """
         if self.km_graph is None:
             raise ValueError(
diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py
index a6ebaf7..c46cb61 100644
--- a/pySEQTarget/SEQuential.py
+++ b/pySEQTarget/SEQuential.py
@@ -114,6 +114,13 @@ def expand(self):
         :class:`polars.DataFrame` and skips all subsequent analysis steps.
         """
         start = time.perf_counter()
+
+        if self.verbose:
+            n, m = self.data.shape
+            print(f"Full dataset: {n:,} observations, {m} variables")
+            n_elig = self.data.filter(pl.col(self.eligible_col) == 1).shape[0]
+            print(f"Eligible observations: {n_elig:,}")
+
         kept = [
             self.cense_colname,
             self.cense_eligible_colname,
@@ -162,14 +169,25 @@ def expand(self):
                 pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
             )
 
+        if self.verbose:
+            n, m = self.DT.shape
+            print(f"Expanded dataset: {n:,} observations, {m} variables")
+
         if self.method == "dose-response" or (
             self.method == "censoring" and not self.expand_only
         ):
             _dynamic(self)
         if self.selection_random:
             _random_selection(self)
+            if self.verbose:
+                n, m = self.DT.shape
+                print(f"Sampled expanded dataset: {n:,} observations, {m} variables")
         _diagnostics(self)
 
+        if self.verbose:
+            n, m = self.DT.shape
+            print(f"Final analysis dataset: {n:,} observations, {m} variables")
+
         end = time.perf_counter()
         self._expansion_time = _format_time(start, end)
 
@@ -200,6 +218,16 @@ def bootstrap(self, **kwargs) -> None:
         )
         NIDs = len(UIDs)
 
+        if self.verbose:
+            n_sample = round(self.bootstrap_sample * NIDs)
+            n_obs_sample = round(self.bootstrap_sample * len(self.DT))
+            print(
+                f"Bootstrapping with {self.bootstrap_sample * 100:.4g}% of "
+                f"{NIDs:,} subjects "
+                f"({n_sample:,} subjects, ~{n_obs_sample:,} observations per resample) "
+                f"{self.bootstrap_nboot} times"
+            )
+
         self._boot_samples = []
         for _ in range(self.bootstrap_nboot):
             sampled_IDs = self._rng.choice(
@@ -244,8 +272,23 @@ def fit(self) -> None:
             _weight_bind(self, WDT)
             self.weight_stats = _weight_stats(self)
 
+        is_boot = boot_idx is not None
+        start = getattr(self, "_outcome_start_params", None) if is_boot else None
+
         if self.subgroup_colname is not None:
-            return _subgroup_fit(self)
+            models_list = _subgroup_fit(self, start_params=start)
+            if not is_boot:
+                self._outcome_start_params = {
+                    val: {
+                        key: (m.params.values, list(m.model.exog_names))
+                        for key, m in sg.items()
+                    }
+                    for val, sg in zip(self._unique_subgroups, models_list)
+                }
+            return models_list
+
+        start_outcome = (start or {}).get("outcome")
+        start_compevent = (start or {}).get("compevent")
 
         models = {
             "outcome": _outcome_fit(
@@ -255,6 +298,7 @@ def fit(self) -> None:
                 self.covariates,
                 self.weighted,
                 "weight",
+                start_params=start_outcome,
             )
         }
         if self.compevent_colname is not None:
@@ -265,7 +309,15 @@ def fit(self) -> None:
                 self.covariates,
                 self.weighted,
                 "weight",
+                start_params=start_compevent,
             )
+
+        if not is_boot:
+            self._outcome_start_params = {
+                k: (m.params.values, list(m.model.exog_names))
+                for k, m in models.items()
+            }
+
         if self.offload:
             offloaded_models = {}
             for key, model in models.items():
@@ -332,7 +384,7 @@ def plot(self, **kwargs) -> None:
 
     def collect(self) -> SEQoutput:
         """
-        Collects all results current created into ``SEQoutput`` class
+        Collects all results currently created into ``SEQoutput`` class
         """
         self._time_collected = datetime.datetime.now()
 
diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py
index 0062fc8..64613fd 100644
--- a/pySEQTarget/analysis/_hazard.py
+++ b/pySEQTarget/analysis/_hazard.py
@@ -182,7 +182,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
     sim_data_pd = sim_data.to_pandas()
 
     try:
-        # COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
+        # COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
         warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
         if ce_model is not None:
             cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py
index a0fdda6..ed049a4 100644
--- a/pySEQTarget/analysis/_outcome_fit.py
+++ b/pySEQTarget/analysis/_outcome_fit.py
@@ -68,6 +68,7 @@ def _outcome_fit(
     formula: str,
     weighted: bool = False,
     weight_col: str = "weight",
+    start_params=None,
 ):
     if weighted:
         df = df.with_columns(
@@ -102,5 +103,23 @@ def _outcome_fit(
         glm_kwargs["var_weights"] = df_pd[weight_col]
 
     model = smf.glm(**glm_kwargs)
-    model_fit = model.fit()
+
+    # Drop warm-start coefs unless the design matrix columns match exactly
+    # by name — bootstrap resamples can shift categorical reference levels or
+    # column ordering, in which case the cached coefs are meaningless and
+    # IRLS can diverge into NaN/Inf and crash LAPACK.
+    if start_params is not None:
+        sp_values, sp_names = start_params
+        if list(model.exog_names) != list(sp_names):
+            start_params = None
+        else:
+            start_params = sp_values
+
+    try:
+        model_fit = model.fit(start_params=start_params)
+    except Exception:
+        if start_params is not None:
+            model_fit = model.fit()
+        else:
+            raise
     return model_fit
diff --git a/pySEQTarget/analysis/_subgroup_fit.py b/pySEQTarget/analysis/_subgroup_fit.py
index fd481cf..e78e19d 100644
--- a/pySEQTarget/analysis/_subgroup_fit.py
+++ b/pySEQTarget/analysis/_subgroup_fit.py
@@ -3,17 +3,24 @@
 from ._outcome_fit import _outcome_fit
 
 
-def _subgroup_fit(self):
+def _subgroup_fit(self, start_params=None):
     subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list())
     self._unique_subgroups = subgroups
 
     models_list = []
     for val in subgroups:
         subDT = self.DT.filter(pl.col(self.subgroup_colname) == val)
+        sg_start = (start_params or {}).get(val, {}) or {}
 
         models = {
             "outcome": _outcome_fit(
-                self, subDT, self.outcome_col, self.covariates, self.weighted, "weight"
+                self,
+                subDT,
+                self.outcome_col,
+                self.covariates,
+                self.weighted,
+                "weight",
+                start_params=sg_start.get("outcome"),
             )
         }
 
@@ -25,6 +32,7 @@ def _subgroup_fit(self):
                 self.covariates,
                 self.weighted,
                 "weight",
+                start_params=sg_start.get("compevent"),
             )
         models_list.append(models)
     return models_list
diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py
index 48b7e56..ac9a4ef 100644
--- a/pySEQTarget/analysis/_survival_pred.py
+++ b/pySEQTarget/analysis/_survival_pred.py
@@ -1,5 +1,8 @@
+import numpy as np
 import polars as pl
+from patsy import PatsyError, dmatrix
 
+from ..helpers._fix_categories import _fix_categories_for_predict
 from ..helpers._predict_model import _safe_predict
 from ._outcome_fit import _cast_categories
 
@@ -25,20 +28,79 @@ def _store_boot_risks(obj, treatment_val, TxDT, boot_cols, is_survival=False):
     )
 
 
+def _build_design_matrix(design_info, data):
+    """
+    Build a design matrix from a cached design_info, applying the same
+    category-alignment fallback that _safe_predict uses on mismatch.
+    """
+    try:
+        return np.asarray(dmatrix(design_info, data))
+    except PatsyError as e:
+        if "mismatching levels" not in str(e):
+            raise
+
+    # Reuse the existing fix by wrapping design_info in a stub object
+    class _Stub:
+        class model:
+            class data:
+                pass
+
+    stub = _Stub()
+    stub.model.data.design_info = design_info
+    fixed = _fix_categories_for_predict(stub, data.copy())
+    return np.asarray(dmatrix(design_info, fixed))
+
+
+def _cached_predict(model, X_cached, ref_column_names, data):
+    """
+    Predict using a pre-built design matrix when the model's design_info
+    column structure matches the reference, falling back to patsy via
+    _safe_predict on mismatch (e.g. a bootstrap resample that dropped a
+    categorical level).
+    """
+    dinfo = model.model.data.design_info
+    if list(dinfo.column_names) == ref_column_names:
+        probs = np.asarray(model.predict(X_cached, transform=False))
+        if not np.any(np.isnan(probs)):
+            return np.clip(probs, 0, 1)
+    return _safe_predict(model, data)
+
+
 def _get_outcome_predictions(self, TxDT, idx=None):
     data = _cast_categories(self, TxDT.to_pandas())
     predictions = {"outcome": []}
     if self.compevent_colname is not None:
         predictions["compevent"] = []
 
+    # Pre-build the design matrix once using the main fit's design_info.
+    # Each bootstrap model that shares the same column structure can then
+    # reuse it, skipping patsy entirely on the predict path.
+    main = self.outcome_model[0]
+    main_dict = main[idx] if idx is not None else main
+    main_outcome = self._offloader.load_model(main_dict["outcome"])
+    outcome_dinfo = main_outcome.model.data.design_info
+    X_outcome = _build_design_matrix(outcome_dinfo, data)
+    outcome_cols = list(outcome_dinfo.column_names)
+
+    X_compevent = compevent_cols = None
+    if self.compevent_colname is not None:
+        main_compevent = self._offloader.load_model(main_dict["compevent"])
+        compevent_dinfo = main_compevent.model.data.design_info
+        X_compevent = _build_design_matrix(compevent_dinfo, data)
+        compevent_cols = list(compevent_dinfo.column_names)
+
     for boot_model in self.outcome_model:
         model_dict = boot_model[idx] if idx is not None else boot_model
         outcome_model = self._offloader.load_model(model_dict["outcome"])
-        predictions["outcome"].append(_safe_predict(outcome_model, data))
+        predictions["outcome"].append(
+            _cached_predict(outcome_model, X_outcome, outcome_cols, data)
+        )
 
         if self.compevent_colname is not None:
             compevent_model = self._offloader.load_model(model_dict["compevent"])
-            predictions["compevent"].append(_safe_predict(compevent_model, data))
+            predictions["compevent"].append(
+                _cached_predict(compevent_model, X_compevent, compevent_cols, data)
+            )
 
     return predictions
 
diff --git a/pySEQTarget/error/_data_checker.py b/pySEQTarget/error/_data_checker.py
index 15c0817..217be1c 100644
--- a/pySEQTarget/error/_data_checker.py
+++ b/pySEQTarget/error/_data_checker.py
@@ -31,8 +31,8 @@ def _data_checker(self):
     invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
     if len(invalid) > 0:
         raise ValueError(
-            f"Data validation failed: {len(invalid)} ID(s) have mismatched "
-            f"This suggests invalid times"
+            f"Data validation failed: {len(invalid)} ID(s) have mismatched row counts. "
+            f"This suggests invalid times. "
             f"Invalid IDs:\n{invalid}"
         )
 
diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py
index 66492cb..fcacad4 100644
--- a/pySEQTarget/error/_param_checker.py
+++ b/pySEQTarget/error/_param_checker.py
@@ -1,3 +1,5 @@
+import warnings
+
 from ..helpers import _pad
 
 
@@ -26,13 +28,13 @@ def _param_checker(self):
 
     if len(self.excused_colnames) == 0 and self.excused:
         self.excused = False
-        raise Warning(
+        warnings.warn(
             "Excused column names not provided but excused is set to True. Automatically set excused to False"
         )
 
     if len(self.excused_colnames) > 0 and not self.excused:
         self.excused = True
-        raise Warning(
+        warnings.warn(
             "Excused column names provided but excused is set to False. Automatically set excused to True"
         )
 
diff --git a/pySEQTarget/expansion/_binder.py b/pySEQTarget/expansion/_binder.py
index 3b5a723..4e10271 100644
--- a/pySEQTarget/expansion/_binder.py
+++ b/pySEQTarget/expansion/_binder.py
@@ -5,7 +5,7 @@
 
 def _binder(self, kept_cols):
     """
-    Internal function to bind data to the map created by __mapper
+    Internal function to bind data to the map created by _mapper
     """
     excluded = {
         "dose",
diff --git a/pySEQTarget/expansion/_dynamic.py b/pySEQTarget/expansion/_dynamic.py
index 8d1918f..52060a8 100644
--- a/pySEQTarget/expansion/_dynamic.py
+++ b/pySEQTarget/expansion/_dynamic.py
@@ -3,7 +3,7 @@
 
 def _dynamic(self):
     """
-    Handles special cases for the data from the __mapper -> __binder pipeline
+    Handles special cases for the data from the _mapper -> _binder pipeline
     """
     if self.method == "dose-response":
         DT = self.DT.with_columns(
diff --git a/pyproject.toml b/pyproject.toml
index 73035e1..0eb1343 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "pySEQTarget"
-version = "0.13.3"
+version = "0.13.4"
 description = "Sequentially Nested Target Trial Emulation"
 readme = "README.md"
 license = {text = "MIT"}
diff --git a/tests/test_expansion.py b/tests/test_expansion.py
index eac5b39..1a27261 100644
--- a/tests/test_expansion.py
+++ b/tests/test_expansion.py
@@ -2,6 +2,7 @@
 from polars.testing import assert_frame_equal
 
 from pySEQTarget import SEQopts, SEQuential
+from pySEQTarget.data import load_data
 
 
 def _make_model(data):
@@ -126,3 +127,59 @@ def test_expand_only_returns_expanded_dataframe():
     model_full.expand()
 
     assert_frame_equal(result, model_full.DT)
+
+
+def _make_verbose_model(verbose, **extra_opts):
+    data = load_data("SEQdata")
+    return SEQuential(
+        data,
+        id_col="ID",
+        time_col="time",
+        eligible_col="eligible",
+        treatment_col="tx_init",
+        outcome_col="outcome",
+        time_varying_cols=["N", "L", "P"],
+        fixed_cols=["sex"],
+        method="ITT",
+        parameters=SEQopts(verbose=verbose, **extra_opts),
+    )
+
+
+def test_verbose_expand(capsys):
+    s = _make_verbose_model(verbose=True)
+    s.expand()
+    out = capsys.readouterr().out
+    assert "Full dataset:" in out
+    assert "Eligible observations:" in out
+    assert "Expanded dataset:" in out
+    assert "Final analysis dataset:" in out
+    assert "Sampled expanded dataset:" not in out
+    assert "observations" in out
+    assert "variables" in out
+
+
+def test_verbose_expand_with_sampling(capsys):
+    s = _make_verbose_model(verbose=True, selection_random=True, selection_sample=0.5)
+    s.expand()
+    out = capsys.readouterr().out
+    assert "Sampled expanded dataset:" in out
+
+
+def test_verbose_bootstrap(capsys):
+    s = _make_verbose_model(verbose=True, bootstrap_nboot=10)
+    s.expand()
+    capsys.readouterr()
+    s.bootstrap()
+    out = capsys.readouterr().out
+    assert "Bootstrapping" in out
+    assert "subjects" in out
+    assert "observations per resample" in out
+    assert "10 times" in out
+
+
+def test_verbose_false_no_output(capsys):
+    s = _make_verbose_model(verbose=False, bootstrap_nboot=5)
+    s.expand()
+    s.bootstrap()
+    out = capsys.readouterr().out
+    assert out == ""