Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fc843ed
Add support to download models automatically if --model specified in …
rohan-uiuc Oct 30, 2025
5f790ff
create model dir if it doesn't exist
rohan-uiuc Oct 30, 2025
0f22bec
Check model weights existence before binding; use HF model name if mi…
rohan-uiuc Nov 12, 2025
9f2fdd2
Remove commented code
rohan-uiuc Nov 12, 2025
38011be
Apply code formatting fixes from pre-commit
rohan-uiuc Nov 12, 2025
4de3563
revert unnecessary test change
rohan-uiuc Nov 12, 2025
eb1e929
Merge branch 'develop' of https://github.com/Center-for-AI-Innovation…
rohan-uiuc Nov 12, 2025
8b6a211
Apply formatting fixes from pre-commit
rohan-uiuc Nov 12, 2025
c68cb35
Add tests for model weights existence coverage
rohan-uiuc Nov 12, 2025
b610891
Remove redundant /dev/infiniband
rohan-uiuc Jan 5, 2026
a7a5deb
Remove unused variable
rohan-uiuc Jan 5, 2026
bb3142b
Add warning if downloading weights and HF cache not set
rohan-uiuc Jan 5, 2026
2db9c6c
format ONLY
rohan-uiuc Jan 5, 2026
079c86a
Add --hf-model CLI option and config field
rohan-uiuc Jan 29, 2026
20163da
Use hf_model as model source when local weights missing
rohan-uiuc Jan 29, 2026
ed82b77
Pass hf_model from CLI to launch params
rohan-uiuc Jan 29, 2026
d3f6772
Add tests for hf_model override in slurm script generation
rohan-uiuc Jan 29, 2026
1c312df
Add documentation for --hf-model option
rohan-uiuc Jan 29, 2026
cd7a9b1
Rebase with main and update hf_model field processing logic
XkunW Mar 29, 2026
03864a5
Fix typos
XkunW Mar 30, 2026
292200f
Fix tests
XkunW Mar 30, 2026
ba9030f
ruff check & format
XkunW Mar 30, 2026
b4037e6
Merge branch 'main' into hf_download
XkunW Mar 30, 2026
8e66e26
ruff format
XkunW Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ To overwrite default inference engine arguments, you can specify the arguments i
vec-inf launch Meta-Llama-3.1-8B-Instruct --vllm-args '--max-model-len=65536,--compilation-config=3'
```

To download models directly from HuggingFace Hub without needing cached local weights, use `--hf-model`:

```bash
vec-inf launch Qwen2.5-3B-Instruct \
--hf-model Qwen/Qwen2.5-3B-Instruct \
--env 'HF_HOME=/path/to/cache' \
--vllm-args '--max-model-len=4096'
```

`--env` passes environment variables through to the container. If cached local weights exist, they take priority over `--hf-model`.

For the full list of inference engine arguments, you can find them here:

* [vLLM: `vllm serve` Arguments](https://docs.vllm.ai/en/stable/serving/engine_args.html)
Expand Down
12 changes: 12 additions & 0 deletions tests/vec_inf/client/test_engine_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def _set_required_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VEC_INF_ACCOUNT", "test-account")
monkeypatch.setenv("VEC_INF_WORK_DIR", "/tmp")

@pytest.fixture(autouse=True)
def _mock_validate_weights_path(self) -> None:
"""Avoid disk checks for fake model paths in fixtures."""
with patch("vec_inf.client._helper.utils.validate_weights_path"):
yield

@pytest.fixture
def model_config_vllm(self) -> ModelConfig:
"""Fixture providing a vLLM model configuration."""
Expand Down Expand Up @@ -187,6 +193,12 @@ def _set_required_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VEC_INF_ACCOUNT", "test-account")
monkeypatch.setenv("VEC_INF_WORK_DIR", "/tmp")

@pytest.fixture(autouse=True)
def _mock_validate_weights_path(self) -> None:
"""Avoid disk checks for fake model paths in fixtures."""
with patch("vec_inf.client._helper.utils.validate_weights_path"):
yield

@pytest.fixture
def model_config(self) -> ModelConfig:
"""Fixture providing a basic model configuration."""
Expand Down
12 changes: 12 additions & 0 deletions tests/vec_inf/client/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
class TestModelLauncher:
"""Tests for the ModelLauncher class."""

@pytest.fixture(autouse=True)
def _mock_validate_weights_path(self) -> None:
"""Avoid disk checks for fake model paths in fixtures."""
with patch("vec_inf.client._helper.utils.validate_weights_path"):
yield

@pytest.fixture
def model_config(self) -> ModelConfig:
"""Fixture providing a basic model configuration for tests."""
Expand Down Expand Up @@ -385,6 +391,12 @@ def test_launch_with_sglang_engine(
class TestBatchModelLauncher:
"""Tests for the BatchModelLauncher class."""

@pytest.fixture(autouse=True)
def _mock_validate_weights_path(self) -> None:
"""Avoid disk checks for fake model paths in fixtures."""
with patch("vec_inf.client._helper.utils.validate_weights_path"):
yield

@pytest.fixture
def batch_model_configs(self) -> list[ModelConfig]:
"""Fixture providing batch model configurations for tests."""
Expand Down
75 changes: 72 additions & 3 deletions tests/vec_inf/client/test_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_init_singularity(self, singularity_params):
assert generator.params == singularity_params
assert generator.use_container
assert not generator.is_multinode
assert generator.additional_binds == ",/scratch:/scratch,/data:/data"
assert generator.additional_binds == "/scratch:/scratch,/data:/data"
assert generator.model_weights_path == "/path/to/model_weights/test-model"
assert (
generator.env_str
Expand Down Expand Up @@ -186,6 +186,17 @@ def test_generate_server_setup_singularity(self, singularity_params):
"module load " in setup or "apptainer" in setup.lower()
) # Remove module name since it's inconsistent between clusters

def test_generate_server_setup_singularity_no_weights(self, singularity_params):
"""Test server setup when using hf_model (no local weights in bind path)."""
params = singularity_params.copy()
params["hf_model"] = "test-org/test-model"

generator = SlurmScriptGenerator(params)
setup = generator._generate_server_setup()

assert "module load" in setup or "apptainer" in setup.lower()
assert "/path/to/model_weights/test-model" not in setup

def test_generate_launch_cmd_venv(self, basic_params):
"""Test launch command generation with virtual environment."""
generator = SlurmScriptGenerator(basic_params)
Expand Down Expand Up @@ -314,6 +325,16 @@ def test_generate_script_content_sglang(self, basic_params):
assert "sglang.launch_server" in content
assert "find_available_port" in content

def test_generate_launch_cmd_with_hf_model_override(self, basic_params):
"""Test launch command uses hf_model when specified."""
params = basic_params.copy()
params["hf_model"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"
generator = SlurmScriptGenerator(params)
launch_cmd = generator._generate_launch_cmd()

assert "vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct" in launch_cmd
assert "vllm serve /path/to/model_weights/test-model" not in launch_cmd

def test_generate_launch_cmd_singularity(self, singularity_params):
"""Test launch command generation with Singularity."""
generator = SlurmScriptGenerator(singularity_params)
Expand All @@ -322,6 +343,18 @@ def test_generate_launch_cmd_singularity(self, singularity_params):
assert "apptainer exec --nv" in launch_cmd
assert "source" not in launch_cmd

def test_generate_launch_cmd_singularity_no_local_weights(self, singularity_params):
"""Test container launch when using hf_model instead of local weights."""
params = singularity_params.copy()
params["hf_model"] = "test-org/test-model"

generator = SlurmScriptGenerator(params)
launch_cmd = generator._generate_launch_cmd()

assert "exec --nv" in launch_cmd
assert "vllm serve test-org/test-model" in launch_cmd
assert "vllm serve /path/to/model_weights/test-model" not in launch_cmd

def test_generate_launch_cmd_boolean_args(self, basic_params):
"""Test launch command with boolean vLLM arguments."""
params = basic_params.copy()
Expand Down Expand Up @@ -468,11 +501,11 @@ def test_init_singularity(self, batch_singularity_params):
assert generator.use_container
assert (
generator.params["models"]["model1"]["additional_binds"]
== ",/scratch:/scratch,/data:/data"
== "/scratch:/scratch,/data:/data"
)
assert (
generator.params["models"]["model2"]["additional_binds"]
== ",/scratch:/scratch,/data:/data"
== "/scratch:/scratch,/data:/data"
)

def test_init_singularity_no_bind(self, batch_params):
Expand Down Expand Up @@ -514,6 +547,22 @@ def test_generate_model_launch_script_basic(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_with_hf_model_override(
self, mock_write_text, mock_touch, batch_params
):
"""Test batch launch script uses hf_model when specified."""
params = batch_params.copy()
params["models"] = {k: v.copy() for k, v in batch_params["models"].items()}
params["models"]["model1"]["hf_model"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"

generator = BatchSlurmScriptGenerator(params)
generator._generate_model_launch_script("model1")

call_args = mock_write_text.call_args[0][0]
assert "vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct" in call_args

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity(
Expand All @@ -528,6 +577,26 @@ def test_generate_model_launch_script_singularity(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity_no_weights(
self, mock_write_text, mock_touch, batch_singularity_params
):
"""Test batch model launch script when using hf_model (no local weights)."""
params = batch_singularity_params.copy()
params["models"] = {
k: v.copy() for k, v in batch_singularity_params["models"].items()
}
params["models"]["model1"]["hf_model"] = "test-org/model1"

generator = BatchSlurmScriptGenerator(params)
script_path = generator._generate_model_launch_script("model1")

assert script_path.name == "launch_model1.sh"
call_args = mock_write_text.call_args[0][0]
assert "/path/to/model_weights/model1" not in call_args
assert "vllm serve test-org/model1" in call_args

@patch("vec_inf.client._slurm_script_generator.datetime")
@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
Expand Down
11 changes: 11 additions & 0 deletions vec_inf/cli/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ def cli() -> None:
type=str,
help="Path to parent directory containing model weights",
)
@click.option(
"--hf-model",
type=str,
help=(
"Full HuggingFace model id/path to use for vLLM serve (e.g. "
"'meta-llama/Meta-Llama-3.1-8B-Instruct'). "
"Keeps model-name as the short identifier for config/logs/job naming."
),
)
@click.option(
"--engine",
type=str,
Expand Down Expand Up @@ -210,6 +219,8 @@ def launch(
Path to SLURM log directory
- model_weights_parent_dir : str, optional
Path to model weights directory
- hf_model : str, optional
Full HuggingFace model id/path to use for vLLM serve
- vllm_args : str, optional
vllm engine arguments
- sglang_args : str, optional
Expand Down
6 changes: 6 additions & 0 deletions vec_inf/client/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def _get_launch_params(self) -> dict[str, Any]:
# Override config defaults with CLI arguments
self._apply_cli_overrides(params)

# Validate weights path exists or HF model provided, and check HF cache config
utils.validate_weights_path(params, self.model_name)

# Check for required fields without default vals, will raise an error if missing
utils.check_required_fields(params)

Expand Down Expand Up @@ -606,6 +609,9 @@ def _get_launch_params(
for engine in SUPPORTED_ENGINES:
del params["models"][model_name][f"{engine}_args"]

# Validate that weights path exists or HF model provided
utils.validate_weights_path(params["models"][model_name], model_name)

# Validate resource allocation and parallelization settings
self._validate_resource_and_parallel_settings(
config, model_engine_args, model_name
Expand Down
22 changes: 13 additions & 9 deletions vec_inf/client/_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def __init__(self, params: dict[str, Any]):
self.engine = params.get("engine", "vllm")
self.is_multinode = int(self.params["num_nodes"]) > 1
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
self.additional_binds = (
f",{self.params['bind']}" if self.params.get("bind") else ""
)
self.additional_binds = self.params["bind"] if self.params.get("bind") else ""
self.model_weights_path = str(
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
)
Expand Down Expand Up @@ -113,7 +111,9 @@ def _generate_server_setup(self) -> str:
server_script.append(
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
work_dir=self.params.get("work_dir", str(Path.home())),
model_weights_path=self.model_weights_path,
model_weights_path=f"{self.model_weights_path},"
if not self.params.get("hf_model")
else "",
additional_binds=self.additional_binds,
)
)
Expand Down Expand Up @@ -185,7 +185,8 @@ def _generate_launch_cmd(self) -> str:

launch_cmd.append(
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"][self.engine]).format( # type: ignore[literal-required]
model_weights_path=self.model_weights_path,
model_weights_path=self.params.get("hf_model")
or self.model_weights_path,
model_name=self.params["model_name"],
)
)
Expand Down Expand Up @@ -215,7 +216,7 @@ def _generate_multinode_sglang_launch_cmd(self) -> str:
SLURM_SCRIPT_TEMPLATE["launch_cmd"]["sglang_multinode"]
).format(
num_nodes=self.params["num_nodes"],
model_weights_path=self.model_weights_path,
model_weights_path=self.params.get("hf_model") or self.model_weights_path,
model_name=self.params["model_name"],
)

Expand Down Expand Up @@ -275,7 +276,7 @@ def __init__(self, params: dict[str, Any]):
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
for model_name in self.params["models"]:
self.params["models"][model_name]["additional_binds"] = (
f",{self.params['models'][model_name]['bind']}"
self.params["models"][model_name]["bind"]
if self.params["models"][model_name].get("bind")
else ""
)
Expand Down Expand Up @@ -321,7 +322,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
work_dir=self.params.get("work_dir", str(Path.home())),
model_weights_path=model_params["model_weights_path"],
model_weights_path=f"{model_params['model_weights_path']},"
if not model_params.get("hf_model")
else "",
additional_binds=model_params["additional_binds"],
)
)
Expand All @@ -348,7 +351,8 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
"\n".join(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"][model_params["engine"]]
).format(
model_weights_path=model_params["model_weights_path"],
model_weights_path=model_params.get("hf_model")
or model_params["model_weights_path"],
model_name=model_name,
)
)
Expand Down
Loading
Loading