From 6cfd033be6223e8c4fc188d913ba0cb36e6c7023 Mon Sep 17 00:00:00 2001
From: Andrei Maslennikov
Date: Fri, 13 Feb 2026 11:25:15 +0100
Subject: [PATCH 1/7] Refactoring, move functions/classes

---
 .../test_vllm_slurm_command_gen_strategy.py   | 94 ++++++++++---------
 1 file changed, 48 insertions(+), 46 deletions(-)

diff --git a/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py
index ff63a12e7..aa9dd9b56 100644
--- a/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py
+++ b/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py
@@ -60,6 +60,18 @@ def vllm_disagg_tr(vllm: VllmTestDefinition, tmp_path: Path) -> TestRun:
     return TestRun(test=vllm, num_nodes=1, nodes=[], output_path=tmp_path, name="vllm-disagg-job")
 
 
+def test_container_mounts(vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
+    assert vllm_cmd_gen_strategy._container_mounts() == [
+        f"{vllm_cmd_gen_strategy.system.hf_home_path.absolute()}:/root/.cache/huggingface"
+    ]
+
+
+def test_sweep_detection(vllm: VllmTestDefinition) -> None:
+    assert vllm.is_dse_job is False
+    vllm.cmd_args.decode.gpu_ids = ["1"]
+    assert vllm.is_dse_job is True
+
+
 class TestGpuDetection:
     """Tests for GPU detection logic."""
 
@@ -168,10 +180,42 @@ def test_prefill_serve_args_with_custom_fields(self) -> None:
         ]
 
 
-def test_container_mounts(vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
-    assert vllm_cmd_gen_strategy._container_mounts() == [
-        f"{vllm_cmd_gen_strategy.system.hf_home_path.absolute()}:/root/.cache/huggingface"
-    ]
+class TestVllmBenchCommand:
+    def test_get_vllm_bench_command(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
+        tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
+        cmd_args = tdef.cmd_args
+        bench_args = tdef.bench_cmd_args
+
+        command = vllm_cmd_gen_strategy.get_vllm_bench_command()
+
+        expected = [
+            "vllm",
+            "bench",
+            "serve",
+            f"--model {cmd_args.model}",
+            f"--base-url http://0.0.0.0:{cmd_args.port}",
+            f"--random-input-len {bench_args.random_input_len}",
+            f"--random-output-len {bench_args.random_output_len}",
+            f"--max-concurrency {bench_args.max_concurrency}",
+            f"--num-prompts {bench_args.num_prompts}",
+            f"--result-dir {vllm_cmd_gen_strategy.test_run.output_path.absolute()}",
+            f"--result-filename {VLLM_BENCH_JSON_FILE}",
+            "--save-result",
+        ]
+        assert command == expected
+
+    def test_get_vllm_bench_command_with_extra_args(
+        self, vllm: VllmTestDefinition, vllm_tr: TestRun, slurm_system: SlurmSystem
+    ) -> None:
+        vllm.bench_cmd_args = VllmBenchCmdArgs.model_validate({"extra1": 1, "extra-2": 2, "extra_3": 3})
+        vllm_tr.test = vllm
+        vllm_cmd_gen_strategy = VllmSlurmCommandGenStrategy(slurm_system, vllm_tr)
+
+        cmd = vllm_cmd_gen_strategy.get_vllm_bench_command()
+
+        assert "--extra1 1" in cmd
+        assert "--extra-2 2" in cmd
+        assert "--extra-3 3" in cmd
 
 
 class TestVllmAggregatedMode:
@@ -211,42 +255,6 @@ def test_generate_wait_for_health_function(self, vllm_cmd_gen_strategy: VllmSlur
 
         assert func == expected
 
-    def test_get_vllm_bench_command(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
-        tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
-        cmd_args = tdef.cmd_args
-        bench_args = tdef.bench_cmd_args
-
-        command = vllm_cmd_gen_strategy.get_vllm_bench_command()
-
-        expected = [
-            "vllm",
-            "bench",
-            "serve",
-            f"--model {cmd_args.model}",
-            f"--base-url http://0.0.0.0:{cmd_args.port}",
-            f"--random-input-len {bench_args.random_input_len}",
f"--random-input-len {bench_args.random_input_len}", - f"--random-output-len {bench_args.random_output_len}", - f"--max-concurrency {bench_args.max_concurrency}", - f"--num-prompts {bench_args.num_prompts}", - f"--result-dir {vllm_cmd_gen_strategy.test_run.output_path.absolute()}", - f"--result-filename {VLLM_BENCH_JSON_FILE}", - "--save-result", - ] - assert command == expected - - def test_get_vllm_bench_command_with_extra_args( - self, vllm: VllmTestDefinition, vllm_tr: TestRun, slurm_system: SlurmSystem - ) -> None: - vllm.bench_cmd_args = VllmBenchCmdArgs.model_validate({"extra1": 1, "extra-2": 2, "extra_3": 3}) - vllm_tr.test = vllm - vllm_cmd_gen_strategy = VllmSlurmCommandGenStrategy(slurm_system, vllm_tr) - - cmd = vllm_cmd_gen_strategy.get_vllm_bench_command() - - assert "--extra1 1" in cmd - assert "--extra-2 2" in cmd - assert "--extra-3 3" in cmd - def test_gen_srun_command_full_flow(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None: tdef = vllm_cmd_gen_strategy.test_run.test cmd_args = tdef.cmd_args @@ -411,9 +419,3 @@ def test_gen_srun_command_disagg_flow(self, vllm_disagg_tr: TestRun, slurm_syste {bench_cmd}""" assert srun_command == expected - - -def test_sweep_detection(vllm: VllmTestDefinition) -> None: - assert vllm.is_dse_job is False - vllm.cmd_args.decode.gpu_ids = ["1"] - assert vllm.is_dse_job is True From 0e356b2df1d050141027612553597501cde5525f Mon Sep 17 00:00:00 2001 From: Andrei Maslennikov Date: Fri, 13 Feb 2026 11:46:17 +0100 Subject: [PATCH 2/7] Unify prefill/decode cmd generation --- .../vllm/slurm_command_gen_strategy.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py index 6c86844d6..5219e427b 100644 --- a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json
 from typing import cast
 
 from cloudai.systems.slurm import SlurmCommandGenStrategy
@@ -57,6 +58,10 @@ def decode_gpu_ids(self) -> list[int]:
         mid = len(self.gpu_ids) // 2
         return self.gpu_ids[mid:]
 
+    @staticmethod
+    def _to_json_str_arg(config: dict) -> str:
+        return "'" + json.dumps(config, separators=(",", ":")) + "'"
+
     def get_vllm_serve_commands(self) -> list[list[str]]:
         tdef: VllmTestDefinition = cast(VllmTestDefinition, self.test_run.test)
         cmd_args: VllmCmdArgs = tdef.cmd_args
@@ -68,24 +73,25 @@ def get_vllm_serve_commands(self) -> list[list[str]]:
         prefill_port = cmd_args.port + 100
         decode_port = cmd_args.port + 200
 
-        prefill_extra_args = tdef.cmd_args.prefill.serve_args if tdef.cmd_args.prefill else []
-        prefill_cmd = [
-            *base_cmd,
-            "--port",
-            str(prefill_port),
-            "--kv-transfer-config",
-            '\'{"kv_connector":"NixlConnector","kv_role":"kv_producer"}\'',
-            *prefill_extra_args,
-        ]
-        decode_cmd = [
-            *base_cmd,
-            "--port",
-            str(decode_port),
-            "--kv-transfer-config",
-            '\'{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}\'',
-            *tdef.cmd_args.decode.serve_args,
-        ]
-        return [prefill_cmd, decode_cmd]
+        kv_transfer_config = {"kv_connector": "NixlConnector"}
+
+        commands: list[list[str]] = []
+        for port, role, args in [
+            (prefill_port, "kv_producer", tdef.cmd_args.prefill),
+            (decode_port, "kv_consumer", tdef.cmd_args.decode),
+        ]:
+            commands.append(
+                [
+                    *base_cmd,
+                    "--port",
+                    str(port),
+                    "--kv-transfer-config",
+                    self._to_json_str_arg(kv_transfer_config | {"kv_role": role}),
+                    *args.serve_args,
+                ]
+            )
+
+        return commands
 
     def get_proxy_command(self) -> list[str]:
         prefill_port = self.tdef.cmd_args.port + 100

From e82c1aa9223a82821aead55bf653dc9e49240ab5 Mon Sep 17 00:00:00 2001
From: Andrei Maslennikov
Date: Fri, 13 Feb 2026 16:06:03 +0100
Subject: [PATCH 3/7] Add support for nixl_threads

---
 .../vllm/slurm_command_gen_strategy.py        |  7 +++-
 src/cloudai/workloads/vllm/vllm.py            |  6 ++-
 .../test_vllm_slurm_command_gen_strategy.py   | 37 +++++++++++++++++++
 3 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
index 5219e427b..e18b43c23 100644
--- a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
+++ b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import json
-from typing import cast
+from typing import Any, cast
 
 from cloudai.systems.slurm import SlurmCommandGenStrategy
 
@@ -73,13 +73,16 @@ def get_vllm_serve_commands(self) -> list[list[str]]:
         prefill_port = cmd_args.port + 100
         decode_port = cmd_args.port + 200
 
-        kv_transfer_config = {"kv_connector": "NixlConnector"}
+        kv_transfer_config: dict[str, Any] = {"kv_connector": "NixlConnector"}
 
         commands: list[list[str]] = []
         for port, role, args in [
             (prefill_port, "kv_producer", tdef.cmd_args.prefill),
             (decode_port, "kv_consumer", tdef.cmd_args.decode),
         ]:
+            if args.nixl_threads is not None:
+                kv_transfer_config.setdefault("kv_connector_extra_config", {})
+                kv_transfer_config["kv_connector_extra_config"]["num_threads"] = cast(int, args.nixl_threads)
             commands.append(
                 [
                     *base_cmd,
diff --git a/src/cloudai/workloads/vllm/vllm.py b/src/cloudai/workloads/vllm/vllm.py
index 96cc84c0a..9f8ddd31a 100644
--- a/src/cloudai/workloads/vllm/vllm.py
+++ b/src/cloudai/workloads/vllm/vllm.py
@@ -32,8 +32,12 @@ class VllmArgs(CmdArgs):
     """Base command arguments for vLLM instances."""
 
     gpu_ids: str | list[str] | None = Field(
+        default=None, description="Comma-separated GPU IDs. If not set, will use all available GPUs."
+    )
+
+    nixl_threads: int | list[int] | None = Field(
         default=None,
-        description="Comma-separated GPU IDs. If not set, will use all available GPUs.",
+        description="Set ``kv_connector_extra_config.num_threads`` for ``--kv-transfer-config`` CLI argument.",
     )
 
     @property
diff --git a/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py
index aa9dd9b56..7b15b9a41 100644
--- a/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py
+++ b/tests/slurm_command_gen_strategy/test_vllm_slurm_command_gen_strategy.py
@@ -180,6 +180,43 @@ def test_prefill_serve_args_with_custom_fields(self) -> None:
         ]
 
 
+class TestVllmServeCommand:
+    @pytest.mark.parametrize("decode_nthreads", [None, 4])
+    def test_decode_nixl_threads(
+        self, decode_nthreads: int | None, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy
+    ) -> None:
+        tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
+        tdef.cmd_args.prefill = VllmArgs()
+        tdef.cmd_args.decode.nixl_threads = decode_nthreads
+
+        commands = vllm_cmd_gen_strategy.get_vllm_serve_commands()
+
+        assert len(commands) == 2
+        dec_cmd = " ".join(commands[1])
+        if decode_nthreads is not None:
+            assert "kv_connector_extra_config" in dec_cmd
+            assert f'"num_threads":{decode_nthreads}' in dec_cmd
+        else:
+            assert all(arg not in dec_cmd for arg in ["num_threads", "kv_connector_extra_config"])
+
+    @pytest.mark.parametrize("prefill_nthreads", [None, 2])
+    def test_prefill_nixl_threads(
+        self, prefill_nthreads: int | None, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy
+    ) -> None:
+        tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
+        tdef.cmd_args.prefill = VllmArgs(nixl_threads=prefill_nthreads)
+
+        commands = vllm_cmd_gen_strategy.get_vllm_serve_commands()
+
+        assert len(commands) == 2
+        pre_cmd = " ".join(commands[0])
+        if prefill_nthreads is not None:
+            assert "kv_connector_extra_config" in pre_cmd
+            assert f'"num_threads":{prefill_nthreads}' in pre_cmd
+        else:
+            assert all(arg not in pre_cmd for arg in ["num_threads", "kv_connector_extra_config"])
+
+
 class TestVllmBenchCommand:
     def test_get_vllm_bench_command(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
         tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)

From 18983f3f86b1acae993572f592e99432d4e82646 Mon Sep 17 00:00:00 2001
From: Andrei Maslennikov
Date: Fri, 13 Feb 2026 16:06:51 +0100
Subject: [PATCH 4/7] Render VllmArgs documentation

---
 doc/conf.py            |  6 +++++-
 doc/workloads/vllm.rst |  6 ++++++
 pyproject.toml         |  2 ++
 uv.lock                | 41 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/doc/conf.py b/doc/conf.py
index ba14c87f8..54055e069 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -65,6 +65,7 @@ def setup(app):
     "sphinx.ext.autosummary",
     "sphinxcontrib.mermaid",
     "sphinx_copybutton",
+    "sphinxcontrib.autodoc_pydantic",
 ]
 
 exclude_patterns = ["_build"]
@@ -74,9 +75,12 @@
     "members": True,
     "member-order": "bysource",
     "special-members": "__init__",
-    "undoc-members": False,  # Don't show undocumented members
 }
 
+autodoc_pydantic_model_show_json = False
+autodoc_pydantic_model_show_field_summary = False
+autodoc_pydantic_model_show_config_summary = False
+
 # Generate autosummary even if no references
 autosummary_generate = True
 
diff --git a/doc/workloads/vllm.rst b/doc/workloads/vllm.rst
index 1e9f82ca4..646e822ad 100644
--- a/doc/workloads/vllm.rst
+++ b/doc/workloads/vllm.rst
@@ -132,6 +132,12 @@ In this case the proxy script will be mounted from the vLLM repository (cloned l
 API Documentation
 -----------------
 
+vLLM Serve Arguments
+~~~~~~~~~~~~~~~~~~~~
+
+.. autopydantic_model:: cloudai.workloads.vllm.vllm.VllmArgs
+   :members:
+
 Command Arguments
 ~~~~~~~~~~~~~~~~~
 
diff --git a/pyproject.toml b/pyproject.toml
index 5b26caa52..1cf4b7368 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,12 +56,14 @@ requires-python = ">=3.10"
     "sphinx-design~=0.6",
     "sphinxcontrib-mermaid~=2.0",
     "sphinx-copybutton~=0.5",
+    "autodoc-pydantic~=2.2",
 ]
 docs-cms = [
     "sphinx~=8.1",
     "sphinx-rtd-theme~=3.1",
     "sphinx-copybutton~=0.5",
     "sphinxcontrib-mermaid~=2.0",
+    "autodoc-pydantic~=2.2",
 ]
 
 [build-system]
diff --git a/uv.lock b/uv.lock
index e8ea02212..5746a254f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -97,6 +97,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" },
 ]
 
+[[package]]
+name = "autodoc-pydantic"
+version = "2.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pydantic" },
+    { name = "pydantic-settings" },
+    { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
+    { name = "sphinx", version = "8.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
+]
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/7b/df/87120e2195f08d760bc5cf8a31cfa2381a6887517aa89453b23f1ae3354f/autodoc_pydantic-2.2.0-py3-none-any.whl", hash = "sha256:8c6a36fbf6ed2700ea9c6d21ea76ad541b621fbdf16b5a80ee04673548af4d95", size = 34001, upload-time = "2024-04-27T10:57:00.542Z" },
+]
+
 [[package]]
 name = "babel"
 version = "2.18.0"
@@ -299,6 +313,7 @@ dev = [
     { name = "vulture" },
 ]
 docs = [
+    { name = "autodoc-pydantic" },
     { name = "nvidia-sphinx-theme" },
     { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
     { name = "sphinx", version = "8.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
= "python_full_version >= '3.11'" }, @@ -309,6 +324,7 @@ docs = [ { name = "sphinxext-opengraph" }, ] docs-cms = [ + { name = "autodoc-pydantic" }, { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "sphinx", version = "8.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "sphinx-copybutton" }, @@ -319,6 +335,8 @@ docs-cms = [ [package.metadata] requires-dist = [ { name = "aiconfigurator", specifier = "~=0.5.0" }, + { name = "autodoc-pydantic", marker = "extra == 'docs'", specifier = "~=2.2" }, + { name = "autodoc-pydantic", marker = "extra == 'docs-cms'", specifier = "~=2.2" }, { name = "bokeh", specifier = "~=3.8" }, { name = "build", marker = "extra == 'dev'", specifier = "~=1.4" }, { name = "click", specifier = "~=8.3" }, @@ -1958,6 +1976,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, +] + [[package]] name = "pydata-sphinx-theme" version = "0.16.1" @@ -2073,6 +2105,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "pytz" version = "2025.2" From 6097bf1c64ff6f02012e744b421cd512fd67e8e6 Mon Sep 17 00:00:00 2001 From: Andrei Maslennikov Date: Fri, 13 Feb 2026 16:45:19 +0100 Subject: [PATCH 5/7] Address review comments --- .../vllm/slurm_command_gen_strategy.py | 5 +- .../vllm/test_command_gen_strategy_slurm.py | 48 ++++++++++--------- 2 files changed, 27 

diff --git a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
index e18b43c23..dea4b343c 100644
--- a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
+++ b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
@@ -73,13 +73,12 @@ def get_vllm_serve_commands(self) -> list[list[str]]:
         prefill_port = cmd_args.port + 100
         decode_port = cmd_args.port + 200
 
-        kv_transfer_config: dict[str, Any] = {"kv_connector": "NixlConnector"}
-
         commands: list[list[str]] = []
         for port, role, args in [
             (prefill_port, "kv_producer", tdef.cmd_args.prefill),
             (decode_port, "kv_consumer", tdef.cmd_args.decode),
         ]:
+            kv_transfer_config: dict[str, Any] = {"kv_connector": "NixlConnector", "kv_role": role}
             if args.nixl_threads is not None:
                 kv_transfer_config.setdefault("kv_connector_extra_config", {})
                 kv_transfer_config["kv_connector_extra_config"]["num_threads"] = cast(int, args.nixl_threads)
@@ -89,7 +88,7 @@ def get_vllm_serve_commands(self) -> list[list[str]]:
                     "--port",
                     str(port),
                     "--kv-transfer-config",
-                    self._to_json_str_arg(kv_transfer_config | {"kv_role": role}),
+                    self._to_json_str_arg(kv_transfer_config),
                     *args.serve_args,
                 ]
             )
diff --git a/tests/workloads/vllm/test_command_gen_strategy_slurm.py b/tests/workloads/vllm/test_command_gen_strategy_slurm.py
index 7b15b9a41..13f160516 100644
--- a/tests/workloads/vllm/test_command_gen_strategy_slurm.py
+++ b/tests/workloads/vllm/test_command_gen_strategy_slurm.py
@@ -181,40 +181,42 @@ def test_prefill_serve_args_with_custom_fields(self) -> None:
 
 
 class TestVllmServeCommand:
-    @pytest.mark.parametrize("decode_nthreads", [None, 4])
+    @pytest.mark.parametrize(
+        "decode_nthreads,prefill_nthreads",
+        [
+            (None, None),
+            (4, 2),
+            (None, 2),
+            (4, None),
+        ],
+    )
     def test_decode_nixl_threads(
-        self, decode_nthreads: int | None, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy
+        self,
+        decode_nthreads: int | None,
+        prefill_nthreads: int | None,
+        vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy,
     ) -> None:
         tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
-        tdef.cmd_args.prefill = VllmArgs()
+        tdef.cmd_args.prefill = VllmArgs(nixl_threads=prefill_nthreads)
         tdef.cmd_args.decode.nixl_threads = decode_nthreads
 
         commands = vllm_cmd_gen_strategy.get_vllm_serve_commands()
 
         assert len(commands) == 2
-        dec_cmd = " ".join(commands[1])
-        if decode_nthreads is not None:
-            assert "kv_connector_extra_config" in dec_cmd
-            assert f'"num_threads":{decode_nthreads}' in dec_cmd
-        else:
-            assert all(arg not in dec_cmd for arg in ["num_threads", "kv_connector_extra_config"])
-
-    @pytest.mark.parametrize("prefill_nthreads", [None, 2])
-    def test_prefill_nixl_threads(
-        self, prefill_nthreads: int | None, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy
-    ) -> None:
-        tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
-        tdef.cmd_args.prefill = VllmArgs(nixl_threads=prefill_nthreads)
-
-        commands = vllm_cmd_gen_strategy.get_vllm_serve_commands()
-
-        assert len(commands) == 2
-        pre_cmd = " ".join(commands[0])
+
+        prefill_cmd = " ".join(commands[0])
         if prefill_nthreads is not None:
-            assert "kv_connector_extra_config" in pre_cmd
-            assert f'"num_threads":{prefill_nthreads}' in pre_cmd
+            assert "kv_connector_extra_config" in prefill_cmd
+            assert f'"num_threads":{prefill_nthreads}' in prefill_cmd
+        else:
+            assert all(arg not in prefill_cmd for arg in ["num_threads", "kv_connector_extra_config"])
+
+        decode_cmd = " ".join(commands[1])
+        if decode_nthreads is not None:
+            assert "kv_connector_extra_config" in decode_cmd
+            assert f'"num_threads":{decode_nthreads}' in decode_cmd
         else:
-            assert all(arg not in pre_cmd for arg in ["num_threads", "kv_connector_extra_config"])
+            assert all(arg not in decode_cmd for arg in ["num_threads", "kv_connector_extra_config"])

From 26499056b0f6d9d53de7f848203db63193c5bcf2 Mon Sep 17 00:00:00 2001
From: Andrei Maslennikov
Date: Fri, 13 Feb 2026 16:55:52 +0100
Subject: [PATCH 6/7] Address review comments

---
 src/cloudai/workloads/vllm/slurm_command_gen_strategy.py | 3 +--
 tests/workloads/vllm/test_command_gen_strategy_slurm.py  | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
index dea4b343c..6334b775b 100644
--- a/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
+++ b/src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
@@ -80,8 +80,7 @@ def get_vllm_serve_commands(self) -> list[list[str]]:
         ]:
             kv_transfer_config: dict[str, Any] = {"kv_connector": "NixlConnector", "kv_role": role}
             if args.nixl_threads is not None:
-                kv_transfer_config.setdefault("kv_connector_extra_config", {})
-                kv_transfer_config["kv_connector_extra_config"]["num_threads"] = cast(int, args.nixl_threads)
+                kv_transfer_config["kv_connector_extra_config"] = {"num_threads": cast(int, args.nixl_threads)}
             commands.append(
                 [
                     *base_cmd,
diff --git a/tests/workloads/vllm/test_command_gen_strategy_slurm.py b/tests/workloads/vllm/test_command_gen_strategy_slurm.py
index 13f160516..380c29ba9 100644
--- a/tests/workloads/vllm/test_command_gen_strategy_slurm.py
+++ b/tests/workloads/vllm/test_command_gen_strategy_slurm.py
@@ -190,7 +190,7 @@ class TestVllmServeCommand:
             (4, None),
         ],
     )
-    def test_decode_nixl_threads(
+    def test_nixl_threads(
         self,
         decode_nthreads: int | None,
         prefill_nthreads: int | None,
         vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy,

From dcd70bd5e9d4d2e0639ae5c4da027b3a032fcac5 Mon Sep 17 00:00:00 2001
From: Andrei Maslennikov
Date: Fri, 13 Feb 2026 17:02:29 +0100
Subject: [PATCH 7/7] Make sure --nixl-threads doesn't appear in CLI

---
 src/cloudai/workloads/vllm/vllm.py                       | 2 +-
 tests/workloads/vllm/test_command_gen_strategy_slurm.py  | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/cloudai/workloads/vllm/vllm.py b/src/cloudai/workloads/vllm/vllm.py
index 9f8ddd31a..ee091f7a8 100644
--- a/src/cloudai/workloads/vllm/vllm.py
+++ b/src/cloudai/workloads/vllm/vllm.py
@@ -44,7 +44,7 @@ def serve_args(self) -> list[str]:
         """Convert cmd_args_dict to command-line arguments list for vllm serve."""
         args = []
-        for k, v in self.model_dump(exclude={"gpu_ids"}, exclude_none=True).items():
+        for k, v in self.model_dump(exclude={"gpu_ids", "nixl_threads"}, exclude_none=True).items():
             opt = f"--{k.replace('_', '-')}"
             if v == "":
                 args.append(opt)
diff --git a/tests/workloads/vllm/test_command_gen_strategy_slurm.py b/tests/workloads/vllm/test_command_gen_strategy_slurm.py
index 380c29ba9..7f78d0b49 100644
--- a/tests/workloads/vllm/test_command_gen_strategy_slurm.py
+++ b/tests/workloads/vllm/test_command_gen_strategy_slurm.py
@@ -122,8 +122,9 @@ class TestServeExtraArgs:
     """Tests for serve_args property."""
 
-    def test_serve_args_empty_by_default(self) -> None:
+    def test_serve_args_empty(self) -> None:
         assert VllmArgs().serve_args == []
+        assert VllmArgs(gpu_ids="0", nixl_threads=1).serve_args == []
 
     def test_empty_string_value_means_flag(self) -> None:
         assert VllmArgs.model_validate(
@@ -205,6 +206,7 @@ def test_nixl_threads(
         assert len(commands) == 2
 
         prefill_cmd = " ".join(commands[0])
+        assert "--nixl-threads" not in prefill_cmd
         if prefill_nthreads is not None:
             assert "kv_connector_extra_config" in prefill_cmd
             assert f'"num_threads":{prefill_nthreads}' in prefill_cmd
@@ -212,6 +214,7 @@ def test_nixl_threads(
             assert all(arg not in prefill_cmd for arg in ["num_threads", "kv_connector_extra_config"])
 
         decode_cmd = " ".join(commands[1])
+        assert "--nixl-threads" not in decode_cmd
         if decode_nthreads is not None:
            assert "kv_connector_extra_config" in decode_cmd
            assert f'"num_threads":{decode_nthreads}' in decode_cmd
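
For reference, a minimal standalone sketch (not part of the patch series; the helper name to_json_str_arg, the kv_consumer role, and the value 4 are illustrative) showing the shape of the --kv-transfer-config argument that the series builds once nixl_threads is set, using the same json.dumps separators as the patches:

import json


def to_json_str_arg(config: dict) -> str:
    # Compact separators mirror json.dumps(config, separators=(",", ":")) in the patches,
    # so tests can match substrings like '"num_threads":4'.
    return "'" + json.dumps(config, separators=(",", ":")) + "'"


kv_transfer_config: dict = {"kv_connector": "NixlConnector", "kv_role": "kv_consumer"}
nixl_threads = 4  # illustrative value; corresponds to VllmArgs.nixl_threads on the decode instance
if nixl_threads is not None:
    kv_transfer_config["kv_connector_extra_config"] = {"num_threads": nixl_threads}

print(to_json_str_arg(kv_transfer_config))
# '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_connector_extra_config":{"num_threads":4}}'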