Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@v8.1.0

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
Expand Down
4 changes: 4 additions & 0 deletions pySEQTarget/error/_data_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def _data_checker(self):

for col in self.weight_eligible_colnames:
if col is not None:
if col not in self.data.columns:
raise ValueError(
f"weight_eligible_colnames entry '{col}' not found in data columns."
)
_check_binary(self.data, col)

check = self.data.group_by(self.id_col).agg(
Expand Down
6 changes: 5 additions & 1 deletion pySEQTarget/expansion/_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ def _binder(self, kept_cols):
for c in baseline_cols
]

to_drop = [f"{self.eligible_col}{self.indicator_baseline}"]
if self.eligible_col not in kept_cols:
to_drop.append(self.eligible_col)

DT = (
DT.with_columns(bas)
.filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1)
.drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col])
.drop(to_drop)
)

# Truncate each (id, trial) at the first outcome event so that subjects who
Expand Down
45 changes: 38 additions & 7 deletions pySEQTarget/helpers/_bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import time
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import wraps

Expand Down Expand Up @@ -103,7 +104,7 @@ def wrapper(self, *args, **kwargs):
self.DT = None

with ProcessPoolExecutor(max_workers=ncores) as executor:
futures = [
futures = {
executor.submit(
_bootstrap_worker,
self,
Expand All @@ -113,13 +114,24 @@ def wrapper(self, *args, **kwargs):
seed,
args,
kwargs,
)
): i
for i in range(nboot)
]
}
skipped = 0
for j in tqdm(
as_completed(futures), total=nboot, desc="Bootstrapping..."
):
results.append(j.result())
boot_idx = futures[j]
try:
results.append(j.result())
except np.linalg.LinAlgError as e:
skipped += 1
warnings.warn(
f"Bootstrap iteration {boot_idx + 1} failed "
f"({e}); skipping replicate.",
UserWarning,
stacklevel=2,
)

self._rng = original_rng
self.DT = self._offloader.load_dataframe(original_DT_ref)
Expand All @@ -131,6 +143,7 @@ def wrapper(self, *args, **kwargs):
else:
original_DT_ref = original_DT

skipped = 0
for i in tqdm(range(nboot), desc="Bootstrapping..."):
self._current_boot_idx = i + 1
if seed is not None:
Expand All @@ -140,12 +153,30 @@ def wrapper(self, *args, **kwargs):
if self._offloader.enabled:
del tmp
self.bootstrap_nboot = 0
boot_fit = method(self, *args, **kwargs)
results.append(boot_fit)
try:
boot_fit = method(self, *args, **kwargs)
results.append(boot_fit)
except np.linalg.LinAlgError as e:
skipped += 1
warnings.warn(
f"Bootstrap iteration {i + 1} failed "
f"({e}); skipping replicate.",
UserWarning,
stacklevel=2,
)

self.bootstrap_nboot = nboot
self.DT = self._offloader.load_dataframe(original_DT_ref)

self.bootstrap_nboot = len(results) - 1
if skipped > 0:
warnings.warn(
f"{skipped} of {nboot} bootstrap replicate(s) skipped due to "
"singular Hessian; effective bootstrap_nboot is "
f"{self.bootstrap_nboot}.",
UserWarning,
stacklevel=2,
)

end = time.perf_counter()
self._model_time = _format_time(start, end)

Expand Down
2 changes: 0 additions & 2 deletions pySEQTarget/helpers/_predict_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import numpy as np

from ._fix_categories import _fix_categories_for_predict
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.9"
version = "0.13.0"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down Expand Up @@ -83,3 +83,6 @@ SEQdata = ["data/*.csv"]
[tool.pytest.ini_options]
pythonpath = ["."]
testpaths = ["tests"]
filterwarnings = [
"ignore:FigureCanvasAgg is non-interactive:UserWarning",
]
4 changes: 2 additions & 2 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

matplotlib.use("Agg") # non-interactive backend — no windows opened

from pySEQTarget import SEQopts, SEQuential
from pySEQTarget.data import load_data
from pySEQTarget import SEQopts, SEQuential # noqa: E402
from pySEQTarget.data import load_data # noqa: E402


@pytest.fixture(autouse=True)
Expand Down