From 69ad62984bd3e79968266d9e5d064c2c44f9d59c Mon Sep 17 00:00:00 2001 From: RalfG Date: Tue, 14 Apr 2026 09:28:03 +0200 Subject: [PATCH 1/2] Add num_threads argument to predict controlling Torch CPU parallelization --- im2deep/_model_ops.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/im2deep/_model_ops.py b/im2deep/_model_ops.py index 82507ed..c33c9bc 100644 --- a/im2deep/_model_ops.py +++ b/im2deep/_model_ops.py @@ -92,6 +92,7 @@ def predict( device: str = "cpu", batch_size: int = 512, num_workers: int = 0, + num_threads: int | None = None, ) -> torch.Tensor: """ Predict using a trained model. @@ -108,6 +109,8 @@ def predict( Batch size for prediction. num_workers Number of workers for data loading. + num_threads + Number of threads for model operations on CPU (ignored if using GPU). Returns ------- @@ -119,6 +122,8 @@ def predict( if data is None: raise ValueError("Data must be provided for prediction.") + torch.set_num_threads(num_threads or torch.get_num_threads()) + # TODO: implement custom model inference LOGGER.debug("Loading model for prediction.") From 86e1cad8e3c2769b6e3a6aa16035646a75ef706b Mon Sep 17 00:00:00 2001 From: RalfG Date: Tue, 14 Apr 2026 15:06:53 +0200 Subject: [PATCH 2/2] Fix type checking issues by removing unnecessary type: ignore comments Replace blanket type: ignore directives with proper type narrowing (profiler None check, dict key access) and add ty-specific suppression for the optional wandb import. --- im2deep/__main__.py | 11 ++++++----- im2deep/_architectures/callbacks.py | 2 +- im2deep/_io_helpers.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/im2deep/__main__.py b/im2deep/__main__.py index dc8b8a4..9f1a1ca 100644 --- a/im2deep/__main__.py +++ b/im2deep/__main__.py @@ -178,6 +178,7 @@ def predict(ctx, *args, **kwargs): """ # Check if profiling is enabled from parent context profile_enabled = ctx.obj.get("profile", False) + profiler = None if profile_enabled: # Run with profiling @@ -187,8 +188,8 @@ def predict(ctx, *args, **kwargs): try: _run_predict(*args, **kwargs) finally: - if profile_enabled: - profiler.disable() # type: ignore + if profiler is not None: + profiler.disable() # Get the IM2Deep root directory (two levels up from this file) root_dir = Path(__file__).parent.parent @@ -196,7 +197,7 @@ def predict(ctx, *args, **kwargs): profiles_dir.mkdir(exist_ok=True) profile_output = profiles_dir / ctx.obj.get("profile_name", "im2deep_profile.prof") - profiler.dump_stats(profile_output) # type: ignore + profiler.dump_stats(profile_output) LOGGER.info(f"Profiling data saved to {profile_output}") LOGGER.info(f"View with: snakeviz {profile_output}") @@ -215,13 +216,13 @@ def _run_predict(*args, **kwargs): # Parse input files LOGGER.info("Parsing input files...") - psm_list = parse_input(Path(kwargs.get("precursors"))) # type: ignore[invalid-arg] + psm_list = parse_input(Path(kwargs["precursors"])) # Run prediction LOGGER.info("Running CCS prediction...") if kwargs.get("calibration_precursors"): LOGGER.info("Calibration file provided, performing calibration and prediction...") - psm_list_cal = parse_input(Path(kwargs.get("calibration_precursors"))) # type: ignore[invalid-arg] + psm_list_cal = parse_input(Path(kwargs["calibration_precursors"])) predictions = core.predict_and_calibrate(psm_list, psm_list_cal, *args, **kwargs) else: LOGGER.info( diff --git a/im2deep/_architectures/callbacks.py b/im2deep/_architectures/callbacks.py index 9ba794e..72a325a 100644 --- a/im2deep/_architectures/callbacks.py +++ b/im2deep/_architectures/callbacks.py @@ -3,7 +3,7 @@ import lightning as L try: - import wandb # type: ignore[import] + import wandb # type: ignore[import] # ty: ignore[unresolved-import] except ImportError: wandb = None diff --git a/im2deep/_io_helpers.py b/im2deep/_io_helpers.py index bf40b3d..11bba4e 100644 --- a/im2deep/_io_helpers.py +++ b/im2deep/_io_helpers.py @@ -99,7 +99,7 @@ def parse_input( if "CCS" in row: if precursor.metadata is None: precursor.metadata = {} - precursor.metadata["CCS"] = _normalize_ccs_metadata_value(row["CCS"]) # type: ignore + precursor.metadata["CCS"] = _normalize_ccs_metadata_value(row["CCS"]) list_of_precursors.append(precursor) except Exception as e: LOGGER.warning("Error parsing row %d: %s. Skipping.", idx, e)