Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
46cb952
feat(hpc): add HPCDaskManager for SLURM cluster orchestration
harryswift01 May 12, 2026
8edd7d8
feat(hpc): include `dask-jobqueue` within `pyproject.toml`
harryswift01 May 12, 2026
0935f93
feat(hpc): include range for `dask-jobqueue` within `pyproject.toml`
harryswift01 May 12, 2026
4052e45
feat(hpc): fix dask importing warning
harryswift01 May 12, 2026
cd9241e
feat(parallel): add parallel helper functions
harryswift01 May 18, 2026
6135273
Merge remote-tracking branch 'origin/main' into 306-feature-dask-base…
harryswift01 May 18, 2026
98d95e1
Merge remote-tracking branch 'origin/main' into 306-feature-dask-base…
harryswift01 May 19, 2026
c2ac3e5
Merge remote-tracking branch 'origin/main' into 306-dask-parallel-imp…
harryswift01 May 22, 2026
e689e00
feat(parallel): wire frame-level Dask execution into workflow
harryswift01 May 27, 2026
1ee825c
tests(unit): tests for dask HPC integration
harryswift01 May 29, 2026
cbd9dca
tests(unit): additional unit tests for parallel frame execution
harryswift01 May 29, 2026
f8223b6
tests(unit): add additional tests for parallel and combine level dag …
harryswift01 May 29, 2026
b7cb020
tests(unit): add additional tests to argsparse for dask introduction
harryswift01 May 29, 2026
ee0b464
feat(parallel): add configurable Dask frame execution
harryswift01 May 29, 2026
89eb097
feat(parallel): add configurable Dask frame execution
harryswift01 May 29, 2026
978a57e
feat(parallel): add configurable Dask frame execution
harryswift01 May 29, 2026
5055b68
feat(parallel): add Dask frame execution and SLURM support
harryswift01 May 29, 2026
b3ef78f
tests(unit): ensure unit tests reference correct CodeEntropy SLURM
harryswift01 May 29, 2026
bf372a3
docs(parallel): document Dask and SLURM arguments
harryswift01 May 29, 2026
7baaaff
fix(hpc): use explicit Dask scheduler interface
harryswift01 Jun 2, 2026
e342aa4
tests(unit): update tests to use explicit Dask scheduler interface
harryswift01 Jun 2, 2026
da08234
fix(hpc): avoid invalid Dask worker interfaces
harryswift01 Jun 2, 2026
ebcc343
fix(parallel): remove all inherited SLURM environments
harryswift01 Jun 9, 2026
318969e
feat(hpc): auto-detect conda settings for SLURM Dask submission
harryswift01 Jun 10, 2026
6ace042
feat(hpc): guard SLURM submit mode inside submitted jobs
harryswift01 Jun 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 232 additions & 66 deletions CodeEntropy/config/argparse.py

Large diffs are not rendered by default.

15 changes: 14 additions & 1 deletion CodeEntropy/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from rich.text import Text

from CodeEntropy.config.argparse import ConfigResolver
from CodeEntropy.core.dask_clusters import HPCDaskManager
from CodeEntropy.core.logging import LoggingConfig
from CodeEntropy.entropy.workflow import EntropyWorkflow
from CodeEntropy.levels.dihedrals import ConformationStateBuilder
Expand Down Expand Up @@ -223,8 +224,9 @@ def run_entropy_workflow(self) -> None:

This method:
- Sets up logging and prints the splash screen
- Loads YAML config from CWD and parses CLI args
- Loads YAML configuration from CWD and parses CLI args
- Merges args with YAML per-run config
- Optionally submits a master SLURM job and exits
- Builds the MDAnalysis Universe (with optional force merging)
- Validates user parameters
- Constructs dependencies and executes EntropyWorkflow
Expand Down Expand Up @@ -256,6 +258,16 @@ def run_entropy_workflow(self) -> None:

args = self._config_manager.resolve(args, run_config)

if getattr(args, "submit", False):
if os.environ.get("CODEENTROPY_SUBMITTED_JOB") == "1":
run_logger.info(
"Already running inside submitted SLURM job; "
"continuing workflow."
)
else:
HPCDaskManager(args).submit_master()
return

log_level = (
logging.DEBUG if getattr(args, "verbose", False) else logging.INFO
)
Expand Down Expand Up @@ -298,6 +310,7 @@ def run_entropy_workflow(self) -> None:
except Exception:
logger.error("Run arguments at failure could not be serialized")

logger.exception("Fatal error during entropy calculation")
raise RuntimeError("CodeEntropyRunner encountered an error") from exc

@staticmethod
Expand Down
279 changes: 279 additions & 0 deletions CodeEntropy/core/dask_clusters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
"""
Helpers for setting up Dask clusters on HPC using SLURM.
"""

from __future__ import annotations

import logging
import os
import shlex
import subprocess
import sys

import psutil
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

logger = logging.getLogger(__name__)


class HPCDaskManager:
    """
    Manage SLURM-backed Dask clusters and submission utilities for HPC environments.

    The manager reads its configuration from the ``args`` namespace it is
    constructed with (``hpc_*`` and ``conda_*`` attributes) and offers two entry
    points:

    - ``configure_cluster`` starts a ``SLURMCluster`` and returns a connected
      Dask ``Client``.
    - ``submit_master`` writes a batch script that re-runs the current
      CodeEntropy command line and submits it with ``sbatch``.

    Both entry points build their shell environment from the same helpers
    (``slurm_directives`` and ``slurm_prologues``) so worker jobs and the
    master job cannot drift apart.
    """

    # Candidate network interfaces, highest priority first.
    HPC_NETWORK_INTERFACES = ("bond0", "ib0", "hsn0", "eth0")

    # SLURM variables that are dropped both from this process and from the
    # generated job scripts. The order is the order of the ``unset`` lines.
    INHERITED_SLURM_VARIABLES = (
        "SLURM_MEM_PER_CPU",
        "SLURM_MEM_PER_GPU",
        "SLURM_MEM_PER_NODE",
        "SLURM_CPU_BIND",
    )

    # File names written into the current working directory.
    WORKER_SCRIPT_NAME = "dask-cluster-submit.sh"
    MASTER_SCRIPT_NAME = "CodeEntropy-master-submit.sh"

    def __init__(self, args):
        """
        Initialise HPCDaskManager with runtime arguments.

        Args:
            args: Parsed CLI arguments containing HPC and conda configuration.
                The object is stored by reference and may be mutated by
                ``resolve_conda_settings``.
        """
        self.args = args

    def _conda_env(self) -> str:
        """
        Determine the activated conda/mamba environment.

        Returns:
            The value of the ``CONDA_DEFAULT_ENV`` environment variable.

        Raises:
            SystemExit: If ``CONDA_DEFAULT_ENV`` is not set.
        """
        try:
            return os.environ["CONDA_DEFAULT_ENV"]
        except KeyError:
            logger.error("Please activate your conda/mamba environment.")
            raise SystemExit(1) from None

    def _conda_exec(self) -> str:
        """
        Determine whether conda or mamba should be used for activation.

        ``MAMBA_EXE`` takes precedence over ``CONDA_EXE``.

        Returns:
            ``"mamba"`` or ``"conda"``.

        Raises:
            SystemExit: If neither ``MAMBA_EXE`` nor ``CONDA_EXE`` is set.
        """
        if os.environ.get("MAMBA_EXE"):
            return "mamba"

        if os.environ.get("CONDA_EXE"):
            return "conda"

        logger.error(
            "Cannot determine your conda executable. "
            "Please make sure conda or mamba has been initialised."
        )
        raise SystemExit(1)

    def _conda_path(self) -> str:
        """
        Determine the path to the conda executable used for shell initialisation.

        Returns:
            The value of the ``CONDA_EXE`` environment variable.

        Raises:
            SystemExit: If ``CONDA_EXE`` is unset or empty.
        """
        conda_exe = os.environ.get("CONDA_EXE")

        if conda_exe:
            return conda_exe

        logger.error("Please make sure conda is set up correctly.")
        raise SystemExit(1)

    def resolve_conda_settings(self) -> None:
        """
        Fill missing conda/mamba settings from the active environment.

        Explicit user-provided values are preserved. Auto-detection is only used
        when values are missing. A ``conda_path`` equal to the bare string
        ``"conda"`` is treated as missing and replaced by the detected path.

        Raises:
            SystemExit: If a missing value cannot be detected from the
                environment.
        """
        args = self.args

        if not getattr(args, "conda_env", None):
            args.conda_env = self._conda_env()

        if not getattr(args, "conda_exec", None):
            args.conda_exec = self._conda_exec()

        if getattr(args, "conda_path", None) in (None, "", "conda"):
            args.conda_path = self._conda_path()

    def check_slurm_env(self) -> None:
        """
        Remove inherited SLURM environment variables that can break nested srun calls.

        This is important when the master CodeEntropy process itself is already
        running inside a SLURM allocation and then launches Dask worker jobs.
        Variables that are not set are ignored.
        """
        for variable in self.INHERITED_SLURM_VARIABLES:
            os.environ.pop(variable, None)

    def system_network_interface(self) -> str:
        """
        Get the best candidate for the HPC network interface.

        This deliberately follows the WaterEntropy-style behaviour and only
        selects from known HPC-safe interfaces. It avoids selecting arbitrary
        interfaces such as eno1, which may exist on the master node but not on
        worker nodes.

        Returns:
            The first entry of ``HPC_NETWORK_INTERFACES`` that is present on
            this machine.

        Raises:
            RuntimeError: If none of the known interfaces is present.
        """
        interfaces = list(psutil.net_if_addrs())

        for iface in self.HPC_NETWORK_INTERFACES:
            if iface in interfaces:
                return iface

        raise RuntimeError(
            "Could not find a known HPC network interface. "
            f"Available interfaces: {interfaces}. "
            f"Expected one of: {', '.join(self.HPC_NETWORK_INTERFACES)}."
        )

    def slurm_directives(self) -> tuple[list[str], list[str]]:
        """
        Process additional SLURM directives and directives to skip.

        Returns:
            Tuple containing extra directives (``--account``, ``--qos`` and
            ``--constraint`` for every option that is set) and skipped
            directives (always ``["--mem"]``).
        """
        args = self.args
        options = (
            ("account", getattr(args, "hpc_account", None)),
            ("qos", getattr(args, "hpc_qos", None)),
            ("constraint", getattr(args, "hpc_constraint", None)),
        )
        extra = [f"--{flag}={value}" for flag, value in options if value]
        skip = ["--mem"]

        return extra, skip

    def slurm_prologues(self) -> list[str]:
        """
        Build environment setup commands for a SLURM job script.

        The same list is used for Dask worker jobs and for the master job so
        both run in an identical shell environment. ``conda_path``,
        ``conda_exec`` and ``conda_env`` must already be set on ``args`` (see
        ``resolve_conda_settings``).

        Returns:
            List of shell commands executed before the job payload starts.
        """
        args = self.args

        prologue = [
            f"module load {module_name}"
            for module_name in getattr(args, "hpc_modules", None) or []
        ]
        prologue.extend(
            f"unset {variable}" for variable in self.INHERITED_SLURM_VARIABLES
        )
        prologue.append(f'eval "$({args.conda_path} shell.bash hook)"')

        if args.conda_exec == "mamba":
            # mamba needs its own shell hook on top of the conda one.
            prologue.append(f'eval "$({args.conda_exec} shell hook --shell bash)"')

        prologue.append(f"{args.conda_exec} activate {args.conda_env}")
        prologue.append("export SLURM_CPU_FREQ_REQ=2250000")

        return prologue

    def configure_cluster(self) -> Client:
        """
        Configure a SLURM-backed Dask cluster.

        The generated worker job script is also written to
        ``dask-cluster-submit.sh`` in the current working directory.

        Returns:
            Dask distributed client connected to the SLURMCluster.

        Raises:
            SystemExit: If conda settings cannot be resolved.
            RuntimeError: If no known HPC network interface is available.
            Exception: Any error raised while scaling, connecting or writing the
                job script is re-raised after the cluster has been closed.
        """
        args = self.args

        self.resolve_conda_settings()

        extra, skip = self.slurm_directives()
        prologue = self.slurm_prologues()
        iface = self.system_network_interface()

        self.check_slurm_env()

        cluster = SLURMCluster(
            cores=args.hpc_cores,
            processes=args.hpc_processes,
            memory=args.hpc_memory,
            queue=args.hpc_queue,
            job_directives_skip=skip,
            job_extra_directives=extra,
            python="srun python",
            walltime=args.hpc_walltime,
            shebang="#!/bin/bash --login",
            local_directory="$PWD",
            interface=iface,
            job_script_prologue=prologue,
        )

        client = None
        try:
            cluster.scale(jobs=args.hpc_nodes)
            client = Client(cluster)

            with open(self.WORKER_SCRIPT_NAME, "w", encoding="utf-8") as f:
                f.write(cluster.job_script())
        except Exception:
            # Do not leave a half-configured cluster (and its queued worker
            # jobs) behind when setup fails part-way through.
            if client is not None:
                client.close()
            cluster.close()
            raise

        return client

    @staticmethod
    def _strip_submit_flag(cli: list[str]) -> list[str]:
        """
        Return ``cli`` without any ``--submit`` option.

        Handles ``--submit``, ``--submit true|false`` and ``--submit=<value>``.
        A token following a bare ``--submit`` is only dropped when it is a
        boolean literal, so an unrelated next option is preserved.

        Args:
            cli: Command-line tokens, typically ``sys.argv[1:]``.

        Returns:
            A new list; ``cli`` itself is not modified.
        """
        kept: list[str] = []
        previous_was_submit = False

        for token in cli:
            text = str(token)

            if previous_was_submit:
                previous_was_submit = False
                if text.lower() in {"true", "false"}:
                    continue

            if text == "--submit":
                previous_was_submit = True
                continue

            if text.startswith("--submit="):
                continue

            kept.append(token)

        return kept

    def _master_script(self, cli: list[str]) -> str:
        """
        Render the batch script that runs the master CodeEntropy process.

        Args:
            cli: Arguments passed to ``CodeEntropy`` inside the job. They are
                shell-quoted with ``shlex.join``.

        Returns:
            Full script text, terminated by a newline.
        """
        extra, _ = self.slurm_directives()

        lines = [
            "#!/bin/bash --login",
            "",
            "#SBATCH --job-name=codeentropy-master",
            "#SBATCH --nodes=1",
            "#SBATCH --ntasks=1",
            "#SBATCH --cpus-per-task=2",
            f"#SBATCH --time={self.args.hpc_walltime}",
            f"#SBATCH --partition={self.args.hpc_queue}",
            "#SBATCH --output=CodeEntropy-master-%j.out",
            "#SBATCH --error=CodeEntropy-master-%j.err",
        ]
        lines.extend(f"#SBATCH {directive}" for directive in extra)
        lines.append("")
        lines.extend(self.slurm_prologues())
        # Marker read back by the runner so the submitted job does not try to
        # submit itself again.
        lines.append("export CODEENTROPY_SUBMITTED_JOB=1")
        lines.append("")
        lines.append(shlex.join(["srun", "CodeEntropy", *map(str, cli)]))

        return "\n".join(lines) + "\n"

    def submit_master(self) -> None:
        """
        Submit a SLURM job that runs the master CodeEntropy process.

        This generates ``CodeEntropy-master-submit.sh`` in the current working
        directory from ``sys.argv`` (minus the ``--submit`` option) and submits
        it via sbatch. The output of sbatch is printed to stdout.

        Raises:
            SystemExit: If conda settings cannot be resolved.
            FileNotFoundError: If the ``sbatch`` executable cannot be found.
            subprocess.CalledProcessError: If sbatch exits with a non-zero
                status; its captured output is printed first.
        """
        self.resolve_conda_settings()

        cli = self._strip_submit_flag(list(sys.argv[1:]))
        script_name = self.MASTER_SCRIPT_NAME

        with open(script_name, "w", encoding="utf-8") as f:
            f.write(self._master_script(cli))

        self.check_slurm_env()

        try:
            result = subprocess.check_output(["sbatch", script_name])
        except FileNotFoundError:
            logger.error("sbatch was not found; cannot submit %s.", script_name)
            raise
        except subprocess.CalledProcessError as exc:
            print((exc.output or b"").decode("utf-8", errors="replace"))
            raise

        print(result.decode("utf-8", errors="replace"))
55 changes: 52 additions & 3 deletions CodeEntropy/entropy/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pandas as pd

from CodeEntropy.core.dask_clusters import HPCDaskManager
from CodeEntropy.core.logging import LoggingConfig
from CodeEntropy.entropy.graph import EntropyGraph
from CodeEntropy.entropy.water import WaterEntropy
Expand Down Expand Up @@ -116,13 +117,61 @@ def execute(self) -> None:
frame_selection=frame_selection,
)

with self._reporter.progress(transient=False) as p:
self._run_level_dag(shared_data, progress=p)
self._run_entropy_graph(shared_data, progress=p)
self._configure_parallel_frame_execution(shared_data)

try:
with self._reporter.progress(transient=False) as p:
self._run_level_dag(shared_data, progress=p)
self._run_entropy_graph(shared_data, progress=p)
finally:
client = shared_data.get("dask_client")
if client is not None:
client.close()

self._finalize_molecule_results()
self._reporter.log_tables()

def _configure_parallel_frame_execution(self, shared_data: SharedData) -> None:
    """Enable parallel frame execution on ``shared_data`` when requested.

    Parallel execution is requested when any of the ``parallel_frames``,
    ``use_dask`` or ``hpc`` attributes of the run arguments is truthy. In that
    case ``shared_data["parallel_frames"]`` is set to ``True`` and, unless a
    ``"dask_client"`` entry already exists, a client is created:

    - ``hpc`` truthy: the client returned by
      ``HPCDaskManager(...).configure_cluster()``.
    - otherwise: a ``dask.distributed.Client`` using worker processes,
      sized by the ``dask_workers`` and ``dask_threads_per_worker`` arguments.

    When nothing is requested ``shared_data`` is left untouched.

    Raises:
        RuntimeError: If a local client is needed but ``dask.distributed``
            cannot be imported.
    """
    run_args = self._args
    hpc_requested = getattr(run_args, "hpc", False)
    requested_flags = (
        getattr(run_args, "parallel_frames", False),
        getattr(run_args, "use_dask", False),
        hpc_requested,
    )

    if not any(requested_flags):
        return

    # Reuse a client that a caller has already attached.
    if "dask_client" not in shared_data:
        if hpc_requested:
            shared_data["dask_client"] = HPCDaskManager(run_args).configure_cluster()
        else:
            try:
                from dask.distributed import Client
            except ImportError as exc:
                raise RuntimeError(
                    "Parallel frame execution was requested, but dask.distributed "
                    "is not installed."
                ) from exc

            worker_count = getattr(run_args, "dask_workers", None)
            threads = getattr(run_args, "dask_threads_per_worker", 1)
            shared_data["dask_client"] = Client(
                processes=True,
                n_workers=worker_count,
                threads_per_worker=threads,
            )

    shared_data["parallel_frames"] = True

def _build_frame_selection(self) -> FrameSelection:
"""Build the workflow frame selection.

Expand Down
Loading