Skip to content
6 changes: 5 additions & 1 deletion pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class SEQopts:
:type bootstrap_sample: float
:param bootstrap_CI: If bootstrapped, confidence interval level
:type bootstrap_CI: float
:param bootstrap_CI_method: If bootstrapped, confidence method generation method ['SE' or 'percentile']
:param bootstrap_CI_method: If bootstrapped, confidence interval method ['SE' or 'percentile']
:type bootstrap_CI_method: str
:param cense_colname: Column name for censoring effect (LTFU, etc.)
:type cense_colname: str
Expand Down Expand Up @@ -109,6 +109,8 @@ class SEQopts:
:type weight_p99: bool
:param weight_preexpansion: Boolean to fit weights on preexpanded data
:type weight_preexpansion: bool
:param verbose: Boolean to print dataset size summaries and bootstrap information
:type verbose: bool
:param weighted: Boolean to weight analysis
:type weighted: bool
"""
Expand Down Expand Up @@ -163,6 +165,7 @@ class SEQopts:
weight_lag_condition: bool = True
weight_p99: bool = False
weight_preexpansion: bool = True
verbose: bool = False
weighted: bool = False

def _validate_bools(self):
Expand All @@ -178,6 +181,7 @@ def _validate_bools(self):
"selection_first_trial",
"selection_random",
"trial_include",
"verbose",
"weight_lag_condition",
"weight_p99",
"weight_preexpansion",
Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/SEQoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class SEQoutput:

def plot(self) -> None:
"""
Displays the kaplan-meier graph
Displays the Kaplan-Meier graph
"""
if self.km_graph is None:
raise ValueError(
Expand Down
56 changes: 54 additions & 2 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def expand(self):
:class:`polars.DataFrame` and skips all subsequent analysis steps.
"""
start = time.perf_counter()

if self.verbose:
n, m = self.data.shape
print(f"Full dataset: {n:,} observations, {m} variables")
n_elig = self.data.filter(pl.col(self.eligible_col) == 1).shape[0]
print(f"Eligible observations: {n_elig:,}")

kept = [
self.cense_colname,
self.cense_eligible_colname,
Expand Down Expand Up @@ -162,14 +169,25 @@ def expand(self):
pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
)

if self.verbose:
n, m = self.DT.shape
print(f"Expanded dataset: {n:,} observations, {m} variables")

if self.method == "dose-response" or (
self.method == "censoring" and not self.expand_only
):
_dynamic(self)
if self.selection_random:
_random_selection(self)
if self.verbose:
n, m = self.DT.shape
print(f"Sampled expanded dataset: {n:,} observations, {m} variables")
_diagnostics(self)

if self.verbose:
n, m = self.DT.shape
print(f"Final analysis dataset: {n:,} observations, {m} variables")

end = time.perf_counter()
self._expansion_time = _format_time(start, end)

Expand Down Expand Up @@ -200,6 +218,16 @@ def bootstrap(self, **kwargs) -> None:
)
NIDs = len(UIDs)

if self.verbose:
n_sample = round(self.bootstrap_sample * NIDs)
n_obs_sample = round(self.bootstrap_sample * len(self.DT))
print(
f"Bootstrapping with {self.bootstrap_sample * 100:.4g}% of "
f"{NIDs:,} subjects "
f"({n_sample:,} subjects, ~{n_obs_sample:,} observations per resample) "
f"{self.bootstrap_nboot} times"
)

self._boot_samples = []
for _ in range(self.bootstrap_nboot):
sampled_IDs = self._rng.choice(
Expand Down Expand Up @@ -244,8 +272,23 @@ def fit(self) -> None:
_weight_bind(self, WDT)
self.weight_stats = _weight_stats(self)

is_boot = boot_idx is not None
start = getattr(self, "_outcome_start_params", None) if is_boot else None

if self.subgroup_colname is not None:
return _subgroup_fit(self)
models_list = _subgroup_fit(self, start_params=start)
if not is_boot:
self._outcome_start_params = {
val: {
key: (m.params.values, list(m.model.exog_names))
for key, m in sg.items()
}
for val, sg in zip(self._unique_subgroups, models_list)
}
return models_list

start_outcome = (start or {}).get("outcome")
start_compevent = (start or {}).get("compevent")

models = {
"outcome": _outcome_fit(
Expand All @@ -255,6 +298,7 @@ def fit(self) -> None:
self.covariates,
self.weighted,
"weight",
start_params=start_outcome,
)
}
if self.compevent_colname is not None:
Expand All @@ -265,7 +309,15 @@ def fit(self) -> None:
self.covariates,
self.weighted,
"weight",
start_params=start_compevent,
)

if not is_boot:
self._outcome_start_params = {
k: (m.params.values, list(m.model.exog_names))
for k, m in models.items()
}

if self.offload:
offloaded_models = {}
for key, model in models.items():
Expand Down Expand Up @@ -332,7 +384,7 @@ def plot(self, **kwargs) -> None:

def collect(self) -> SEQoutput:
"""
Collects all results current created into ``SEQoutput`` class
Collects all results currently created into ``SEQoutput`` class
"""
self._time_collected = datetime.datetime.now()

Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/analysis/_hazard.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
sim_data_pd = sim_data.to_pandas()

try:
# COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
# COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
if ce_model is not None:
cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
Expand Down
21 changes: 20 additions & 1 deletion pySEQTarget/analysis/_outcome_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _outcome_fit(
formula: str,
weighted: bool = False,
weight_col: str = "weight",
start_params=None,
):
if weighted:
df = df.with_columns(
Expand Down Expand Up @@ -102,5 +103,23 @@ def _outcome_fit(
glm_kwargs["var_weights"] = df_pd[weight_col]

model = smf.glm(**glm_kwargs)
model_fit = model.fit()

# Drop warm-start coefs unless the design matrix columns match exactly
# by name — bootstrap resamples can shift categorical reference levels or
# column ordering, in which case the cached coefs are meaningless and
# IRLS can diverge into NaN/Inf and crash LAPACK.
if start_params is not None:
sp_values, sp_names = start_params
if list(model.exog_names) != list(sp_names):
start_params = None
else:
start_params = sp_values

try:
model_fit = model.fit(start_params=start_params)
except Exception:
if start_params is not None:
model_fit = model.fit()
else:
raise
return model_fit
12 changes: 10 additions & 2 deletions pySEQTarget/analysis/_subgroup_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
from ._outcome_fit import _outcome_fit


def _subgroup_fit(self):
def _subgroup_fit(self, start_params=None):
subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list())
self._unique_subgroups = subgroups

models_list = []
for val in subgroups:
subDT = self.DT.filter(pl.col(self.subgroup_colname) == val)
sg_start = (start_params or {}).get(val, {}) or {}

models = {
"outcome": _outcome_fit(
self, subDT, self.outcome_col, self.covariates, self.weighted, "weight"
self,
subDT,
self.outcome_col,
self.covariates,
self.weighted,
"weight",
start_params=sg_start.get("outcome"),
)
}

Expand All @@ -25,6 +32,7 @@ def _subgroup_fit(self):
self.covariates,
self.weighted,
"weight",
start_params=sg_start.get("compevent"),
)
models_list.append(models)
return models_list
66 changes: 64 additions & 2 deletions pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import polars as pl
from patsy import PatsyError, dmatrix

from ..helpers._fix_categories import _fix_categories_for_predict
from ..helpers._predict_model import _safe_predict
from ._outcome_fit import _cast_categories

Expand All @@ -25,20 +28,79 @@ def _store_boot_risks(obj, treatment_val, TxDT, boot_cols, is_survival=False):
)


def _build_design_matrix(design_info, data):
"""
Build a design matrix from a cached design_info, applying the same
category-alignment fallback that _safe_predict uses on mismatch.
"""
try:
return np.asarray(dmatrix(design_info, data))
except PatsyError as e:
if "mismatching levels" not in str(e):
raise

# Reuse the existing fix by wrapping design_info in a stub object
class _Stub:
class model:
class data:
pass

stub = _Stub()
stub.model.data.design_info = design_info
fixed = _fix_categories_for_predict(stub, data.copy())
return np.asarray(dmatrix(design_info, fixed))


def _cached_predict(model, X_cached, ref_column_names, data):
"""
Predict using a pre-built design matrix when the model's design_info
column structure matches the reference, falling back to patsy via
_safe_predict on mismatch (e.g. a bootstrap resample that dropped a
categorical level).
"""
dinfo = model.model.data.design_info
if list(dinfo.column_names) == ref_column_names:
probs = np.asarray(model.predict(X_cached, transform=False))
if not np.any(np.isnan(probs)):
return np.clip(probs, 0, 1)
return _safe_predict(model, data)


def _get_outcome_predictions(self, TxDT, idx=None):
data = _cast_categories(self, TxDT.to_pandas())
predictions = {"outcome": []}
if self.compevent_colname is not None:
predictions["compevent"] = []

# Pre-build the design matrix once using the main fit's design_info.
# Each bootstrap model that shares the same column structure can then
# reuse it, skipping patsy entirely on the predict path.
main = self.outcome_model[0]
main_dict = main[idx] if idx is not None else main
main_outcome = self._offloader.load_model(main_dict["outcome"])
outcome_dinfo = main_outcome.model.data.design_info
X_outcome = _build_design_matrix(outcome_dinfo, data)
outcome_cols = list(outcome_dinfo.column_names)

X_compevent = compevent_cols = None
if self.compevent_colname is not None:
main_compevent = self._offloader.load_model(main_dict["compevent"])
compevent_dinfo = main_compevent.model.data.design_info
X_compevent = _build_design_matrix(compevent_dinfo, data)
compevent_cols = list(compevent_dinfo.column_names)

for boot_model in self.outcome_model:
model_dict = boot_model[idx] if idx is not None else boot_model
outcome_model = self._offloader.load_model(model_dict["outcome"])
predictions["outcome"].append(_safe_predict(outcome_model, data))
predictions["outcome"].append(
_cached_predict(outcome_model, X_outcome, outcome_cols, data)
)

if self.compevent_colname is not None:
compevent_model = self._offloader.load_model(model_dict["compevent"])
predictions["compevent"].append(_safe_predict(compevent_model, data))
predictions["compevent"].append(
_cached_predict(compevent_model, X_compevent, compevent_cols, data)
)

return predictions

Expand Down
4 changes: 2 additions & 2 deletions pySEQTarget/error/_data_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def _data_checker(self):
invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
if len(invalid) > 0:
raise ValueError(
f"Data validation failed: {len(invalid)} ID(s) have mismatched "
f"This suggests invalid times"
f"Data validation failed: {len(invalid)} ID(s) have mismatched row counts. "
f"This suggests invalid times. "
f"Invalid IDs:\n{invalid}"
)

Expand Down
6 changes: 4 additions & 2 deletions pySEQTarget/error/_param_checker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from ..helpers import _pad


Expand Down Expand Up @@ -26,13 +28,13 @@ def _param_checker(self):

if len(self.excused_colnames) == 0 and self.excused:
self.excused = False
raise Warning(
warnings.warn(
"Excused column names not provided but excused is set to True. Automatically set excused to False"
)

if len(self.excused_colnames) > 0 and not self.excused:
self.excused = True
raise Warning(
warnings.warn(
"Excused column names provided but excused is set to False. Automatically set excused to True"
)

Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/expansion/_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def _binder(self, kept_cols):
"""
Internal function to bind data to the map created by __mapper
Internal function to bind data to the map created by _mapper
"""
excluded = {
"dose",
Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/expansion/_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def _dynamic(self):
"""
Handles special cases for the data from the __mapper -> __binder pipeline
Handles special cases for the data from the _mapper -> _binder pipeline
"""
if self.method == "dose-response":
DT = self.DT.with_columns(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.13.3"
version = "0.13.4"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
Loading
Loading