Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,46 +575,59 @@ def _check_device_memory(device_index=0):
index: int
The index of the GPU device.
"""
import pyopencl as cl

# Get the device.
platforms = cl.get_platforms()
all_devices = []
for platform in platforms:
try:
devices = platform.get_devices(device_type=cl.device_type.GPU)
all_devices.extend(devices)
except:
continue

if device_index >= len(all_devices):
msg = f"Device index {device_index} out of range. Found {len(all_devices)} GPU(s)."
_logger.error(msg)
raise IndexError(msg)
# Try to use pyopencl to detect the GPU vendor.
vendor = None
ocl_device = None
try:
import pyopencl as cl

device = all_devices[device_index]
total = device.global_mem_size
platforms = cl.get_platforms()
all_devices = []
for platform in platforms:
try:
devices = platform.get_devices(device_type=cl.device_type.GPU)
all_devices.extend(devices)
except Exception:
continue

if device_index < len(all_devices):
ocl_device = all_devices[device_index]
vendor = ocl_device.vendor
else:
msg = f"Device index {device_index} out of range. Found {len(all_devices)} GPU(s)."
_logger.error(msg)
raise IndexError(msg)
except IndexError:
raise
except Exception:
_logger.warning(
"Could not query GPU platform via OpenCL; falling back to pynvml for NVIDIA detection."
)

# NVIDIA: Use pynvml
if "NVIDIA" in device.vendor:
# NVIDIA: Use pynvml (also used as fallback when OpenCL is unavailable).
if vendor is None or "NVIDIA" in vendor:
try:
import pynvml

pynvml.nvmlInit()

handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
pynvml.nvmlShutdown()
return (memory.used, memory.free, memory.total)
except Exception as e:
msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}"
if vendor is None:
msg = f"Could not get GPU memory info for device {device_index} via OpenCL or pynvml: {e}"
else:
msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}"
_logger.error(msg)
raise RuntimeError(msg) from e

# AMD: Use OpenCL extension
elif "AMD" in device.vendor or "Advanced Micro Devices" in device.vendor:
# AMD: Use OpenCL extension.
elif "AMD" in vendor or "Advanced Micro Devices" in vendor:
try:
free_memory_info = device.get_info(0x4038)
total = ocl_device.global_mem_size
free_memory_info = ocl_device.get_info(0x4038)
free_kb = (
free_memory_info[0]
if isinstance(free_memory_info, list)
Expand Down