From aecfa95b3b73f086164d8bc01da254ae5bfe08df Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 4 Mar 2026 09:26:13 +0000 Subject: [PATCH] Handle missing OpenCL ICD loader during GPU platform detection. --- src/somd2/runner/_repex.py | 61 +++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index cdc1851..f6ba67b 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -575,46 +575,59 @@ def _check_device_memory(device_index=0): index: int The index of the GPU device. """ - import pyopencl as cl - # Get the device. - platforms = cl.get_platforms() - all_devices = [] - for platform in platforms: - try: - devices = platform.get_devices(device_type=cl.device_type.GPU) - all_devices.extend(devices) - except: - continue - - if device_index >= len(all_devices): - msg = f"Device index {device_index} out of range. Found {len(all_devices)} GPU(s)." - _logger.error(msg) - raise IndexError(msg) + # Try to use pyopencl to detect the GPU vendor. + vendor = None + ocl_device = None + try: + import pyopencl as cl - device = all_devices[device_index] - total = device.global_mem_size + platforms = cl.get_platforms() + all_devices = [] + for platform in platforms: + try: + devices = platform.get_devices(device_type=cl.device_type.GPU) + all_devices.extend(devices) + except Exception: + continue + + if device_index < len(all_devices): + ocl_device = all_devices[device_index] + vendor = ocl_device.vendor + else: + msg = f"Device index {device_index} out of range. Found {len(all_devices)} GPU(s)." + _logger.error(msg) + raise IndexError(msg) + except IndexError: + raise + except Exception: + _logger.warning( + "Could not query GPU platform via OpenCL; falling back to pynvml for NVIDIA detection." + ) - # NVIDIA: Use pynvml - if "NVIDIA" in device.vendor: + # NVIDIA: Use pynvml (also used as fallback when OpenCL is unavailable). + if vendor is None or "NVIDIA" in vendor: try: import pynvml pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) memory = pynvml.nvmlDeviceGetMemoryInfo(handle) pynvml.nvmlShutdown() return (memory.used, memory.free, memory.total) except Exception as e: - msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}" + if vendor is None: + msg = f"Could not get GPU memory info for device {device_index} via OpenCL or pynvml: {e}" + else: + msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}" _logger.error(msg) raise RuntimeError(msg) from e - # AMD: Use OpenCL extension - elif "AMD" in device.vendor or "Advanced Micro Devices" in device.vendor: + # AMD: Use OpenCL extension. + elif "AMD" in vendor or "Advanced Micro Devices" in vendor: try: - free_memory_info = device.get_info(0x4038) + total = ocl_device.global_mem_size + free_memory_info = ocl_device.get_info(0x4038) free_kb = ( free_memory_info[0] if isinstance(free_memory_info, list)