diff --git a/docs/user_guide.md b/docs/user_guide.md index db01229..79e8dea 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -63,6 +63,17 @@ To overwrite default inference engine arguments, you can specify the arguments i vec-inf launch Meta-Llama-3.1-8B-Instruct --vllm-args '--max-model-len=65536,--compilation-config=3' ``` +To download models directly from HuggingFace Hub without needing cached local weights, use `--hf-model`: + +```bash +vec-inf launch Qwen2.5-3B-Instruct \ + --hf-model Qwen/Qwen2.5-3B-Instruct \ + --env 'HF_HOME=/path/to/cache' \ + --vllm-args '--max-model-len=4096' +``` + +`--env` passes environment variables to the container. If cached local weights exist, they take priority over `--hf-model`. + For the full list of inference engine arguments, you can find them here: * [vLLM: `vllm serve` Arguments](https://docs.vllm.ai/en/stable/serving/engine_args.html) diff --git a/tests/vec_inf/client/test_engine_selection.py b/tests/vec_inf/client/test_engine_selection.py index 8812e8e..d457108 100644 --- a/tests/vec_inf/client/test_engine_selection.py +++ b/tests/vec_inf/client/test_engine_selection.py @@ -24,6 +24,12 @@ def _set_required_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("VEC_INF_ACCOUNT", "test-account") monkeypatch.setenv("VEC_INF_WORK_DIR", "/tmp") + @pytest.fixture(autouse=True) + def _mock_validate_weights_path(self) -> None: + """Avoid disk checks for fake model paths in fixtures.""" + with patch("vec_inf.client._helper.utils.validate_weights_path"): + yield + @pytest.fixture def model_config_vllm(self) -> ModelConfig: """Fixture providing a vLLM model configuration.""" @@ -187,6 +193,12 @@ def _set_required_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("VEC_INF_ACCOUNT", "test-account") monkeypatch.setenv("VEC_INF_WORK_DIR", "/tmp") + @pytest.fixture(autouse=True) + def _mock_validate_weights_path(self) -> None: + """Avoid disk checks for fake model paths in 
fixtures.""" + with patch("vec_inf.client._helper.utils.validate_weights_path"): + yield + @pytest.fixture def model_config(self) -> ModelConfig: """Fixture providing a basic model configuration.""" diff --git a/tests/vec_inf/client/test_helper.py b/tests/vec_inf/client/test_helper.py index 7f91000..3a65985 100644 --- a/tests/vec_inf/client/test_helper.py +++ b/tests/vec_inf/client/test_helper.py @@ -32,6 +32,12 @@ class TestModelLauncher: """Tests for the ModelLauncher class.""" + @pytest.fixture(autouse=True) + def _mock_validate_weights_path(self) -> None: + """Avoid disk checks for fake model paths in fixtures.""" + with patch("vec_inf.client._helper.utils.validate_weights_path"): + yield + @pytest.fixture def model_config(self) -> ModelConfig: """Fixture providing a basic model configuration for tests.""" @@ -385,6 +391,12 @@ def test_launch_with_sglang_engine( class TestBatchModelLauncher: """Tests for the BatchModelLauncher class.""" + @pytest.fixture(autouse=True) + def _mock_validate_weights_path(self) -> None: + """Avoid disk checks for fake model paths in fixtures.""" + with patch("vec_inf.client._helper.utils.validate_weights_path"): + yield + @pytest.fixture def batch_model_configs(self) -> list[ModelConfig]: """Fixture providing batch model configurations for tests.""" diff --git a/tests/vec_inf/client/test_slurm_script_generator.py b/tests/vec_inf/client/test_slurm_script_generator.py index ba37b08..4e8a364 100644 --- a/tests/vec_inf/client/test_slurm_script_generator.py +++ b/tests/vec_inf/client/test_slurm_script_generator.py @@ -115,7 +115,7 @@ def test_init_singularity(self, singularity_params): assert generator.params == singularity_params assert generator.use_container assert not generator.is_multinode - assert generator.additional_binds == ",/scratch:/scratch,/data:/data" + assert generator.additional_binds == "/scratch:/scratch,/data:/data" assert generator.model_weights_path == "/path/to/model_weights/test-model" assert ( generator.env_str 
@@ -186,6 +186,17 @@ def test_generate_server_setup_singularity(self, singularity_params): "module load " in setup or "apptainer" in setup.lower() ) # Remove module name since it's inconsistent between clusters + def test_generate_server_setup_singularity_no_weights(self, singularity_params): + """Test server setup when using hf_model (no local weights in bind path).""" + params = singularity_params.copy() + params["hf_model"] = "test-org/test-model" + + generator = SlurmScriptGenerator(params) + setup = generator._generate_server_setup() + + assert "module load" in setup or "apptainer" in setup.lower() + assert "/path/to/model_weights/test-model" not in setup + def test_generate_launch_cmd_venv(self, basic_params): """Test launch command generation with virtual environment.""" generator = SlurmScriptGenerator(basic_params) @@ -314,6 +325,16 @@ def test_generate_script_content_sglang(self, basic_params): assert "sglang.launch_server" in content assert "find_available_port" in content + def test_generate_launch_cmd_with_hf_model_override(self, basic_params): + """Test launch command uses hf_model when specified.""" + params = basic_params.copy() + params["hf_model"] = "meta-llama/Meta-Llama-3.1-8B-Instruct" + generator = SlurmScriptGenerator(params) + launch_cmd = generator._generate_launch_cmd() + + assert "vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct" in launch_cmd + assert "vllm serve /path/to/model_weights/test-model" not in launch_cmd + def test_generate_launch_cmd_singularity(self, singularity_params): """Test launch command generation with Singularity.""" generator = SlurmScriptGenerator(singularity_params) @@ -322,6 +343,18 @@ def test_generate_launch_cmd_singularity(self, singularity_params): assert "apptainer exec --nv" in launch_cmd assert "source" not in launch_cmd + def test_generate_launch_cmd_singularity_no_local_weights(self, singularity_params): + """Test container launch when using hf_model instead of local weights.""" + params = 
singularity_params.copy() + params["hf_model"] = "test-org/test-model" + + generator = SlurmScriptGenerator(params) + launch_cmd = generator._generate_launch_cmd() + + assert "exec --nv" in launch_cmd + assert "vllm serve test-org/test-model" in launch_cmd + assert "vllm serve /path/to/model_weights/test-model" not in launch_cmd + def test_generate_launch_cmd_boolean_args(self, basic_params): """Test launch command with boolean vLLM arguments.""" params = basic_params.copy() @@ -468,11 +501,11 @@ def test_init_singularity(self, batch_singularity_params): assert generator.use_container assert ( generator.params["models"]["model1"]["additional_binds"] - == ",/scratch:/scratch,/data:/data" + == "/scratch:/scratch,/data:/data" ) assert ( generator.params["models"]["model2"]["additional_binds"] - == ",/scratch:/scratch,/data:/data" + == "/scratch:/scratch,/data:/data" ) def test_init_singularity_no_bind(self, batch_params): @@ -514,6 +547,22 @@ def test_generate_model_launch_script_basic( mock_touch.assert_called_once() mock_write_text.assert_called_once() + @patch("pathlib.Path.touch") + @patch("pathlib.Path.write_text") + def test_generate_model_launch_script_with_hf_model_override( + self, mock_write_text, mock_touch, batch_params + ): + """Test batch launch script uses hf_model when specified.""" + params = batch_params.copy() + params["models"] = {k: v.copy() for k, v in batch_params["models"].items()} + params["models"]["model1"]["hf_model"] = "meta-llama/Meta-Llama-3.1-8B-Instruct" + + generator = BatchSlurmScriptGenerator(params) + generator._generate_model_launch_script("model1") + + call_args = mock_write_text.call_args[0][0] + assert "vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct" in call_args + @patch("pathlib.Path.touch") @patch("pathlib.Path.write_text") def test_generate_model_launch_script_singularity( @@ -528,6 +577,26 @@ def test_generate_model_launch_script_singularity( mock_touch.assert_called_once() mock_write_text.assert_called_once() + 
@patch("pathlib.Path.touch") + @patch("pathlib.Path.write_text") + def test_generate_model_launch_script_singularity_no_weights( + self, mock_write_text, mock_touch, batch_singularity_params + ): + """Test batch model launch script when using hf_model (no local weights).""" + params = batch_singularity_params.copy() + params["models"] = { + k: v.copy() for k, v in batch_singularity_params["models"].items() + } + params["models"]["model1"]["hf_model"] = "test-org/model1" + + generator = BatchSlurmScriptGenerator(params) + script_path = generator._generate_model_launch_script("model1") + + assert script_path.name == "launch_model1.sh" + call_args = mock_write_text.call_args[0][0] + assert "/path/to/model_weights/model1" not in call_args + assert "vllm serve test-org/model1" in call_args + @patch("vec_inf.client._slurm_script_generator.datetime") @patch("pathlib.Path.touch") @patch("pathlib.Path.write_text") diff --git a/vec_inf/cli/_cli.py b/vec_inf/cli/_cli.py index b21b0fb..f086989 100644 --- a/vec_inf/cli/_cli.py +++ b/vec_inf/cli/_cli.py @@ -132,6 +132,15 @@ def cli() -> None: type=str, help="Path to parent directory containing model weights", ) +@click.option( + "--hf-model", + type=str, + help=( + "Full HuggingFace model id/path to use for vLLM serve (e.g. " + "'meta-llama/Meta-Llama-3.1-8B-Instruct'). " + "Keeps model-name as the short identifier for config/logs/job naming." 
+ ), +) @click.option( "--engine", type=str, @@ -210,6 +219,8 @@ def launch( Path to SLURM log directory - model_weights_parent_dir : str, optional Path to model weights directory + - hf_model : str, optional + Full HuggingFace model id/path to use for vLLM serve - vllm_args : str, optional vllm engine arguments - sglang_args : str, optional diff --git a/vec_inf/client/_helper.py b/vec_inf/client/_helper.py index f9b7eb3..ba1c888 100644 --- a/vec_inf/client/_helper.py +++ b/vec_inf/client/_helper.py @@ -355,6 +355,9 @@ def _get_launch_params(self) -> dict[str, Any]: # Override config defaults with CLI arguments self._apply_cli_overrides(params) + # Validate weights path exists or HF model provided, and check HF cache config + utils.validate_weights_path(params, self.model_name) + # Check for required fields without default vals, will raise an error if missing utils.check_required_fields(params) @@ -606,6 +609,9 @@ def _get_launch_params( for engine in SUPPORTED_ENGINES: del params["models"][model_name][f"{engine}_args"] + # Validate that weights path exists or HF model provided + utils.validate_weights_path(params["models"][model_name], model_name) + # Validate resource allocation and parallelization settings self._validate_resource_and_parallel_settings( config, model_engine_args, model_name diff --git a/vec_inf/client/_slurm_script_generator.py b/vec_inf/client/_slurm_script_generator.py index 83bb369..dcca77d 100644 --- a/vec_inf/client/_slurm_script_generator.py +++ b/vec_inf/client/_slurm_script_generator.py @@ -35,9 +35,7 @@ def __init__(self, params: dict[str, Any]): self.engine = params.get("engine", "vllm") self.is_multinode = int(self.params["num_nodes"]) > 1 self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME - self.additional_binds = ( - f",{self.params['bind']}" if self.params.get("bind") else "" - ) + self.additional_binds = self.params["bind"] if self.params.get("bind") else "" self.model_weights_path = str( 
Path(self.params["model_weights_parent_dir"], self.params["model_name"]) ) @@ -113,7 +111,9 @@ def _generate_server_setup(self) -> str: server_script.append( SLURM_SCRIPT_TEMPLATE["bind_path"].format( work_dir=self.params.get("work_dir", str(Path.home())), - model_weights_path=self.model_weights_path, + model_weights_path=f"{self.model_weights_path}," + if not self.params.get("hf_model") + else "", additional_binds=self.additional_binds, ) ) @@ -185,7 +185,8 @@ def _generate_launch_cmd(self) -> str: launch_cmd.append( "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"][self.engine]).format( # type: ignore[literal-required] - model_weights_path=self.model_weights_path, + model_weights_path=self.params.get("hf_model") + or self.model_weights_path, model_name=self.params["model_name"], ) ) @@ -215,7 +216,7 @@ def _generate_multinode_sglang_launch_cmd(self) -> str: SLURM_SCRIPT_TEMPLATE["launch_cmd"]["sglang_multinode"] ).format( num_nodes=self.params["num_nodes"], - model_weights_path=self.model_weights_path, + model_weights_path=self.params.get("hf_model") or self.model_weights_path, model_name=self.params["model_name"], ) @@ -275,7 +276,7 @@ def __init__(self, params: dict[str, Any]): self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME for model_name in self.params["models"]: self.params["models"][model_name]["additional_binds"] = ( - f",{self.params['models'][model_name]['bind']}" + self.params["models"][model_name]["bind"] if self.params["models"][model_name].get("bind") else "" ) @@ -321,7 +322,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path: script_content.append( BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format( work_dir=self.params.get("work_dir", str(Path.home())), - model_weights_path=model_params["model_weights_path"], + model_weights_path=f"{model_params['model_weights_path']}," + if not model_params.get("hf_model") + else "", additional_binds=model_params["additional_binds"], ) ) @@ -348,7 +351,8 @@ def 
_generate_model_launch_script(self, model_name: str) -> Path: "\n".join( BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"][model_params["engine"]] ).format( - model_weights_path=model_params["model_weights_path"], + model_weights_path=model_params.get("hf_model") + or model_params["model_weights_path"], model_name=model_name, ) ) diff --git a/vec_inf/client/_utils.py b/vec_inf/client/_utils.py index e4a918f..e722f12 100644 --- a/vec_inf/client/_utils.py +++ b/vec_inf/client/_utils.py @@ -459,3 +459,92 @@ def check_required_fields(params: dict[str, Any]) -> dict[str, Any]: f"{arg} is required, please set it in the command arguments or environment variables" ) return env_overrides + + +def validate_weights_path(params: dict[str, Any], model_name: str) -> None: + """Validate that the model weights path exists or a HF model is provided. + + If cached weights exist and ``hf_model`` is also set, a warning is issued + and ``hf_model`` is removed (cached weights take priority). If no cached + weights exist and no ``hf_model`` is provided, a :class:`FileNotFoundError` + is raised. If ``hf_model`` is set without cached weights, + :func:`check_hf_cache_and_bind` is called to verify cache configuration. + + Parameters + ---------- + params : dict[str, Any] + Launch parameters dict; must contain ``model_weights_parent_dir`` and + may contain ``hf_model``, ``env``, and ``bind``. + model_name : str + Name of the model being validated. + + Raises + ------ + FileNotFoundError + If the model weights path does not exist and no ``hf_model`` is provided. + """ + model_weights_path = Path( + params["model_weights_parent_dir"], model_name + ).expanduser() + + if model_weights_path.exists() and params.get("hf_model"): + warnings.warn( + f"Model weights found at '{model_weights_path}' but 'hf_model' " + f"parameter is also set. 
Cached weights take priority, so 'hf_model' " + f"will be ignored.", + UserWarning, + stacklevel=4, + ) + del params["hf_model"] + return + + if not model_weights_path.exists(): + if not params.get("hf_model"): + raise FileNotFoundError( + f"Model weights path '{model_weights_path}' does not exist, and no HF path provided" + ) + check_hf_cache_and_bind(params, model_name) + + +def check_hf_cache_and_bind(params: dict[str, Any], model_name: str) -> None: + """Check HF cache configuration and update bind mounts if needed. + + Inspects the ``env`` dict inside *params* for HuggingFace cache variables + (``HF_HOME``, ``HF_HUB_CACHE``, ``HUGGINGFACE_HUB_CACHE``). If none are + set, a warning is issued. If any are set, their values are added to the + ``bind`` string so the container can access the cache directory. + + Parameters + ---------- + params : dict[str, Any] + Launch parameters dict; may contain ``env`` and ``bind``. + model_name : str + Name of the model (used in the warning message). + """ + hf_cache_vars = ["HF_HOME", "HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"] + env_vars = params.get("env", {}) + set_cache_values = {env_vars[var] for var in hf_cache_vars if var in env_vars} + + if not set_cache_values: + warnings.warn( + f"Model weights for '{model_name}' will be downloaded, but no HuggingFace " + f"cache directory is set (HF_HOME, HF_HUB_CACHE, or HUGGINGFACE_HUB_CACHE). " + f"The model may be downloaded to your home directory, which could consume " + f"your storage quota. 
Consider setting one of these environment variables " + f"to a shared cache location.", + UserWarning, + stacklevel=5, + ) + return + + bind_str = params.get("bind", "") + existing_hosts = ( + {b.split(":")[0] for b in bind_str.split(",") if b.strip()} + if bind_str + else set() + ) + + new_paths = set_cache_values - existing_hosts + if new_paths: + all_binds = [bind_str] + list(new_paths) if bind_str else list(new_paths) + params["bind"] = ",".join(all_binds) diff --git a/vec_inf/client/config.py b/vec_inf/client/config.py index eee9742..6a7c40a 100644 --- a/vec_inf/client/config.py +++ b/vec_inf/client/config.py @@ -66,6 +66,9 @@ class ModelConfig(BaseModel): Directory path for storing logs model_weights_parent_dir : Path, optional Base directory containing model weights + hf_model : str, optional + HuggingFace model id for vLLM to download (e.g. "meta-llama/Llama-3.1-8B"). + Used as model source when local weights don't exist. engine: str, optional Inference engine to be used, supports 'vllm' and 'sglang' vllm_args : dict[str, Any], optional @@ -154,6 +157,13 @@ class ModelConfig(BaseModel): default=Path(DEFAULT_ARGS["model_weights_parent_dir"]), description="Base directory for model weights", ) + hf_model: Optional[str] = Field( + default=None, + description=( + "Full HuggingFace model id/path to use for vLLM serve (e.g. " + "'meta-llama/Meta-Llama-3.1-8B-Instruct')." + ), + ) engine: Optional[str] = Field( default="vllm", description="Inference engine to be used, supports 'vllm' and 'sglang'", diff --git a/vec_inf/client/models.py b/vec_inf/client/models.py index 2d69ce9..76d1b1e 100644 --- a/vec_inf/client/models.py +++ b/vec_inf/client/models.py @@ -222,6 +222,9 @@ class LaunchOptions: Directory for logs model_weights_parent_dir : str, optional Parent directory containing model weights + hf_model : str, optional + HuggingFace model id for vLLM to download (e.g. "meta-llama/Llama-3.1-8B"). + Used as model source when local weights don't exist. 
engine: str, optional Inference engine to use vllm_args : str, optional @@ -254,6 +257,7 @@ class LaunchOptions: venv: Optional[str] = None log_dir: Optional[str] = None model_weights_parent_dir: Optional[str] = None + hf_model: Optional[str] = None engine: Optional[str] = None vllm_args: Optional[str] = None sglang_args: Optional[str] = None