Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def setup(app):
"sphinx.ext.autosummary",
"sphinxcontrib.mermaid",
"sphinx_copybutton",
"sphinxcontrib.autodoc_pydantic",
]

exclude_patterns = ["_build"]
Expand All @@ -74,9 +75,12 @@ def setup(app):
"members": True,
"member-order": "bysource",
"special-members": "__init__",
"undoc-members": False, # Don't show undocumented members
}

autodoc_pydantic_model_show_json = False
autodoc_pydantic_model_show_field_summary = False
autodoc_pydantic_model_show_config_summary = False

# Generate autosummary even if no references
autosummary_generate = True

Expand Down
6 changes: 6 additions & 0 deletions doc/workloads/vllm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ In this case the proxy script will be mounted from the vLLM repository (cloned l
API Documentation
-----------------

vLLM Serve Arguments
~~~~~~~~~~~~~~~~~~~~

.. autopydantic_model:: cloudai.workloads.vllm.vllm.VllmArgs
:members:

Command Arguments
~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ requires-python = ">=3.10"
"sphinx-design~=0.6",
"sphinxcontrib-mermaid~=2.0",
"sphinx-copybutton~=0.5",
"autodoc-pydantic~=2.2",
]
docs-cms = [
"sphinx~=8.1",
"sphinx-rtd-theme~=3.1",
"sphinx-copybutton~=0.5",
"sphinxcontrib-mermaid~=2.0",
"autodoc-pydantic~=2.2",
]

[build-system]
Expand Down
45 changes: 26 additions & 19 deletions src/cloudai/workloads/vllm/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
import json
from typing import Any, cast

from cloudai.systems.slurm import SlurmCommandGenStrategy

Expand Down Expand Up @@ -57,6 +58,10 @@ def decode_gpu_ids(self) -> list[int]:
mid = len(self.gpu_ids) // 2
return self.gpu_ids[mid:]

@staticmethod
def _to_json_str_arg(config: dict) -> str:
    """Serialize ``config`` to compact JSON wrapped in single quotes.

    The compact separators (no spaces after ``,`` or ``:``) and the surrounding
    single quotes keep the JSON value as a single shell token when it is later
    joined into a command line (e.g. for ``--kv-transfer-config``).
    """
    return "'" + json.dumps(config, separators=(",", ":")) + "'"

def get_vllm_serve_commands(self) -> list[list[str]]:
tdef: VllmTestDefinition = cast(VllmTestDefinition, self.test_run.test)
cmd_args: VllmCmdArgs = tdef.cmd_args
Expand All @@ -68,24 +73,26 @@ def get_vllm_serve_commands(self) -> list[list[str]]:
prefill_port = cmd_args.port + 100
decode_port = cmd_args.port + 200

prefill_extra_args = tdef.cmd_args.prefill.serve_args if tdef.cmd_args.prefill else []
prefill_cmd = [
*base_cmd,
"--port",
str(prefill_port),
"--kv-transfer-config",
'\'{"kv_connector":"NixlConnector","kv_role":"kv_producer"}\'',
*prefill_extra_args,
]
decode_cmd = [
*base_cmd,
"--port",
str(decode_port),
"--kv-transfer-config",
'\'{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}\'',
*tdef.cmd_args.decode.serve_args,
]
return [prefill_cmd, decode_cmd]
commands: list[list[str]] = []
for port, role, args in [
(prefill_port, "kv_producer", tdef.cmd_args.prefill),
(decode_port, "kv_consumer", tdef.cmd_args.decode),
]:
kv_transfer_config: dict[str, Any] = {"kv_connector": "NixlConnector", "kv_role": role}
if args.nixl_threads is not None:
kv_transfer_config["kv_connector_extra_config"] = {"num_threads": cast(int, args.nixl_threads)}
commands.append(
[
*base_cmd,
"--port",
str(port),
"--kv-transfer-config",
self._to_json_str_arg(kv_transfer_config),
*args.serve_args,
]
)

return commands

def get_proxy_command(self) -> list[str]:
prefill_port = self.tdef.cmd_args.port + 100
Expand Down
8 changes: 6 additions & 2 deletions src/cloudai/workloads/vllm/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@ class VllmArgs(CmdArgs):
"""Base command arguments for vLLM instances."""

gpu_ids: str | list[str] | None = Field(
default=None, description="Comma-separated GPU IDs. If not set, will use all available GPUs."
)

nixl_threads: int | list[int] | None = Field(
default=None,
description="Comma-separated GPU IDs. If not set, will use all available GPUs.",
description="Set ``kv_connector_extra_config.num_threads`` for ``--kv-transfer-config`` CLI argument.",
)

@property
def serve_args(self) -> list[str]:
"""Convert cmd_args_dict to command-line arguments list for vllm serve."""
args = []
for k, v in self.model_dump(exclude={"gpu_ids"}, exclude_none=True).items():
for k, v in self.model_dump(exclude={"gpu_ids", "nixl_threads"}, exclude_none=True).items():
opt = f"--{k.replace('_', '-')}"
if v == "":
args.append(opt)
Expand Down
130 changes: 87 additions & 43 deletions tests/workloads/vllm/test_command_gen_strategy_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ def vllm_disagg_tr(vllm: VllmTestDefinition, tmp_path: Path) -> TestRun:
return TestRun(test=vllm, num_nodes=1, nodes=[], output_path=tmp_path, name="vllm-disagg-job")


def test_container_mounts(vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
    """The strategy mounts only the system HF cache at the container's HF home."""
    hf_cache = vllm_cmd_gen_strategy.system.hf_home_path.absolute()
    expected_mounts = [f"{hf_cache}:/root/.cache/huggingface"]
    assert vllm_cmd_gen_strategy._container_mounts() == expected_mounts


def test_sweep_detection(vllm: VllmTestDefinition) -> None:
    """Assigning a list-valued cmd arg flips the definition into a DSE (sweep) job."""
    before_sweep = vllm.is_dse_job
    vllm.cmd_args.decode.gpu_ids = ["1"]
    after_sweep = vllm.is_dse_job

    assert before_sweep is False
    assert after_sweep is True


class TestGpuDetection:
"""Tests for GPU detection logic."""

Expand Down Expand Up @@ -110,8 +122,9 @@ def test_decode_nodes_set(self, vllm_tr: TestRun, slurm_system: SlurmSystem) ->
class TestServeExtraArgs:
"""Tests for serve_args property."""

def test_serve_args_empty_by_default(self) -> None:
def test_serve_args_empty(self) -> None:
assert VllmArgs().serve_args == []
assert VllmArgs(gpu_ids="0", nixl_threads=1).serve_args == []

def test_empty_string_value_means_flag(self) -> None:
assert VllmArgs.model_validate(
Expand Down Expand Up @@ -168,49 +181,48 @@ def test_prefill_serve_args_with_custom_fields(self) -> None:
]


def test_container_mounts(vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
assert vllm_cmd_gen_strategy._container_mounts() == [
f"{vllm_cmd_gen_strategy.system.hf_home_path.absolute()}:/root/.cache/huggingface"
]


class TestVllmAggregatedMode:
"""Tests for vLLM non-disaggregated mode with 1 GPU."""

def test_get_vllm_serve_commands_single_gpu(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
cmd_args = vllm_cmd_gen_strategy.test_run.test.cmd_args
class TestVllmServeCommand:
@pytest.mark.parametrize(
"decode_nthreads,prefill_nthreads",
[
(None, None),
(4, 2),
(None, 2),
(4, None),
],
)
def test_nixl_threads(
self,
decode_nthreads: int | None,
prefill_nthreads: int | None,
vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy,
) -> None:
tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
tdef.cmd_args.prefill = VllmArgs(nixl_threads=prefill_nthreads)
tdef.cmd_args.decode.nixl_threads = decode_nthreads

commands = vllm_cmd_gen_strategy.get_vllm_serve_commands()

assert len(commands) == 1
assert commands[0] == ["vllm", "serve", cmd_args.model, "--port", str(cmd_args.port)]

def test_generate_wait_for_health_function(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
cmd_args = vllm_cmd_gen_strategy.test_run.test.cmd_args

func = vllm_cmd_gen_strategy.generate_wait_for_health_function()

expected = f"""\
wait_for_health() {{
local endpoint="$1"
local timeout={cmd_args.vllm_serve_wait_seconds}
local interval=5
local end_time=$(($(date +%s) + timeout))
assert len(commands) == 2

while [ "$(date +%s)" -lt "$end_time" ]; do
if curl -sf "$endpoint" > /dev/null 2>&1; then
echo "Health check passed: $endpoint"
return 0
fi
sleep "$interval"
done
prefill_cmd = " ".join(commands[0])
assert "--nixl-threads" not in prefill_cmd
if prefill_nthreads is not None:
assert "kv_connector_extra_config" in prefill_cmd
assert f'"num_threads":{prefill_nthreads}' in prefill_cmd
else:
assert all(arg not in prefill_cmd for arg in ["num_threads", "kv_connector_extra_config"])

echo "Timeout waiting for: $endpoint"
return 1
}}"""
decode_cmd = " ".join(commands[1])
assert "--nixl-threads" not in decode_cmd
if decode_nthreads is not None:
assert "kv_connector_extra_config" in decode_cmd
assert f'"num_threads":{decode_nthreads}' in decode_cmd
else:
assert all(arg not in decode_cmd for arg in ["num_threads", "kv_connector_extra_config"])

assert func == expected

class TestVllmBenchCommand:
def test_get_vllm_bench_command(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
tdef = cast(VllmTestDefinition, vllm_cmd_gen_strategy.test_run.test)
cmd_args = tdef.cmd_args
Expand Down Expand Up @@ -247,6 +259,44 @@ def test_get_vllm_bench_command_with_extra_args(
assert "--extra-2 2" in cmd
assert "--extra-3 3" in cmd


class TestVllmAggregatedMode:
"""Tests for vLLM non-disaggregated mode with 1 GPU."""

def test_get_vllm_serve_commands_single_gpu(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
    """Aggregated (non-disaggregated) mode emits exactly one ``vllm serve`` command."""
    cmd_args = vllm_cmd_gen_strategy.test_run.test.cmd_args

    commands = vllm_cmd_gen_strategy.get_vllm_serve_commands()

    # Single serve process: model as the positional argument plus the configured port.
    assert len(commands) == 1
    assert commands[0] == ["vllm", "serve", cmd_args.model, "--port", str(cmd_args.port)]

def test_generate_wait_for_health_function(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
    """The generated bash helper polls the endpoint until healthy or timed out.

    The expected script must match byte-for-byte: the timeout is taken from
    ``cmd_args.vllm_serve_wait_seconds`` and the poll interval is fixed at 5s.
    """
    cmd_args = vllm_cmd_gen_strategy.test_run.test.cmd_args

    func = vllm_cmd_gen_strategy.generate_wait_for_health_function()

    # NOTE(review): the bash body below is reproduced from the pasted diff, which
    # stripped leading whitespace — confirm exact indentation against the repo.
    expected = f"""\
wait_for_health() {{
local endpoint="$1"
local timeout={cmd_args.vllm_serve_wait_seconds}
local interval=5
local end_time=$(($(date +%s) + timeout))

while [ "$(date +%s)" -lt "$end_time" ]; do
if curl -sf "$endpoint" > /dev/null 2>&1; then
echo "Health check passed: $endpoint"
return 0
fi
sleep "$interval"
done

echo "Timeout waiting for: $endpoint"
return 1
}}"""

    assert func == expected

def test_gen_srun_command_full_flow(self, vllm_cmd_gen_strategy: VllmSlurmCommandGenStrategy) -> None:
tdef = vllm_cmd_gen_strategy.test_run.test
cmd_args = tdef.cmd_args
Expand Down Expand Up @@ -411,9 +461,3 @@ def test_gen_srun_command_disagg_flow(self, vllm_disagg_tr: TestRun, slurm_syste
{bench_cmd}"""

assert srun_command == expected


def test_sweep_detection(vllm: VllmTestDefinition) -> None:
assert vllm.is_dse_job is False
vllm.cmd_args.decode.gpu_ids = ["1"]
assert vllm.is_dse_job is True
41 changes: 41 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.