diff --git a/tests/unit/vertexai/genai/test_agent_engines.py b/tests/unit/vertexai/genai/test_agent_engines.py index 55e240533b..3c0637626c 100644 --- a/tests/unit/vertexai/genai/test_agent_engines.py +++ b/tests/unit/vertexai/genai/test_agent_engines.py @@ -552,6 +552,12 @@ def register_operations(self) -> Dict[str, List[str]]: _genai_types.IdentityType.SERVICE_ACCOUNT ) _TEST_AGENT_ENGINE_ENCRYPTION_SPEC = {"kms_key_name": "test-kms-key"} +_TEST_AGENT_ENGINE_KEEP_ALIVE_PROBE = { + "http_get": { + "path": "/health", + }, + "max_seconds": 60, +} _TEST_AGENT_ENGINE_SPEC = _genai_types.ReasoningEngineSpecDict( agent_framework=_TEST_AGENT_ENGINE_FRAMEWORK, class_methods=[_TEST_AGENT_ENGINE_CLASS_METHOD_1], @@ -1087,6 +1093,7 @@ def test_create_agent_engine_config_with_source_packages( config["spec"]["identity_type"] == _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT ) + assert "keep_alive_probe" not in config["spec"].get("deployment_spec", {}) def test_create_agent_engine_config_with_developer_connect_source(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -1128,6 +1135,29 @@ def test_create_agent_engine_config_with_developer_connect_source(self): config["spec"]["identity_type"] == _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT ) + assert "keep_alive_probe" not in config["spec"].get("deployment_spec", {}) + + @mock.patch.object( + _agent_engines_utils, + "_create_base64_encoded_tarball", + return_value="test_tarball", + ) + def test_create_agent_engine_config_with_empty_keep_alive_probe( + self, mock_create_base64_encoded_tarball + ): + with tempfile.TemporaryDirectory() as tmpdir: + test_file_path = os.path.join(tmpdir, "test_file.txt") + with open(test_file_path, "w") as f: + f.write("test content") + config = self.client.agent_engines._create_config( + mode="create", + source_packages=[test_file_path], + class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS, + entrypoint_module="main", + entrypoint_object="app", + keep_alive_probe={}, + ) + assert "keep_alive_probe" in config["spec"].get("deployment_spec", {}) def test_create_agent_engine_config_with_agent_config_source_and_requirements_file( self, @@ -1337,6 +1367,33 @@ def test_create_agent_engine_config_with_container_spec(self): config["spec"]["identity_type"] == _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT ) + assert "keep_alive_probe" not in config["spec"].get("deployment_spec", {}) + + def test_create_agent_engine_config_with_container_spec_and_keep_alive_probe( + self, + ): + container_spec = {"image_uri": "gcr.io/test-project/test-image"} + config = self.client.agent_engines._create_config( + mode="create", + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + container_spec=container_spec, + class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS, + identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT, + keep_alive_probe=_TEST_AGENT_ENGINE_KEEP_ALIVE_PROBE, + ) + assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME + assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION + assert config["spec"]["container_spec"] == container_spec + assert config["spec"]["class_methods"] == _TEST_AGENT_ENGINE_CLASS_METHODS + assert ( + config["spec"]["identity_type"] + == _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT + ) + assert ( + config["spec"]["deployment_spec"]["keep_alive_probe"] + == _TEST_AGENT_ENGINE_KEEP_ALIVE_PROBE + ) def test_create_agent_engine_config_with_container_spec_and_others_raises(self): container_spec = {"image_uri": "gcr.io/test-project/test-image"} @@ -2133,6 +2190,7 @@ def test_create_agent_engine_with_env_vars_dict( image_spec=None, agent_config_source=None, container_spec=None, + keep_alive_probe=None, ) request_mock.assert_called_with( "post", @@ -2238,6 +2296,7 @@ def test_create_agent_engine_with_custom_service_account( image_spec=None, agent_config_source=None, container_spec=None, + keep_alive_probe=None, ) request_mock.assert_called_with( "post", @@ -2342,6 +2401,7 @@ def test_create_agent_engine_with_experimental_mode( image_spec=None, agent_config_source=None, container_spec=None, + keep_alive_probe=None, ) request_mock.assert_called_with( "post", @@ -2515,6 +2575,7 @@ def test_create_agent_engine_with_class_methods( image_spec=None, agent_config_source=None, container_spec=None, + keep_alive_probe=None, ) request_mock.assert_called_with( "post", @@ -2614,6 +2675,7 @@ def test_create_agent_engine_with_agent_framework( image_spec=None, agent_config_source=None, container_spec=None, + keep_alive_probe=None, ) request_mock.assert_called_with( "post", @@ -2816,6 +2878,109 @@ def test_update_agent_engine_env_vars( None, ) + @mock.patch.object(_agent_engines_utils, "_prepare") + @mock.patch.object(_agent_engines_utils, "_await_operation") + def test_update_agent_engine_with_empty_keep_alive_probe( + self, mock_await_operation, mock_prepare + ): + mock_await_operation.return_value = _genai_types.AgentEngineOperation( + response=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_TEST_AGENT_ENGINE_SPEC, + ) + ) + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + agent=self.test_agent, + config=_genai_types.AgentEngineConfig( + staging_bucket=_TEST_STAGING_BUCKET, + keep_alive_probe={}, + ), + ) + update_mask = ",".join( + [ + "spec.package_spec.pickle_object_gcs_uri", + "spec.package_spec.requirements_gcs_uri", + "spec.class_methods", + "spec.deployment_spec.keep_alive_probe", + "spec.agent_framework", + ] + ) + query_params = {"updateMask": update_mask} + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?{urlencode(query_params)}", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "spec": { + "agent_framework": _TEST_AGENT_ENGINE_FRAMEWORK, + "class_methods": mock.ANY, + "package_spec": { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + "deployment_spec": {"keep_alive_probe": {}}, + }, + "_query": {"updateMask": update_mask}, + }, + None, + ) + + @mock.patch.object(_agent_engines_utils, "_await_operation") + def test_update_agent_engine_with_container_spec_and_keep_alive_probe( + self, mock_await_operation + ): + mock_await_operation.return_value = _genai_types.AgentEngineOperation( + response=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_TEST_AGENT_ENGINE_SPEC, + ) + ) + container_spec = {"image_uri": "gcr.io/test-project/test-image"} + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + config=_genai_types.AgentEngineConfig( + container_spec=container_spec, + keep_alive_probe=_TEST_AGENT_ENGINE_KEEP_ALIVE_PROBE, + class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS, + ), + ) + update_mask = ",".join( + [ + "spec.class_methods", + "spec.container_spec", + "spec.deployment_spec.keep_alive_probe", + "spec.agent_framework", + ] + ) + query_params = {"updateMask": update_mask} + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?{urlencode(query_params)}", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "spec": { + "agent_framework": "custom", + "container_spec": container_spec, + "deployment_spec": { + "keep_alive_probe": _TEST_AGENT_ENGINE_KEEP_ALIVE_PROBE, + }, + "class_methods": mock.ANY, + }, + "_query": {"updateMask": update_mask}, + }, + None, + ) + @mock.patch.object(_agent_engines_utils, "_await_operation") def test_update_agent_engine_display_name(self, mock_await_operation): mock_await_operation.return_value = _genai_types.AgentEngineOperation( diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py index d75ef16910..6af7a6b59d 100644 --- a/vertexai/_genai/agent_engines.py +++ b/vertexai/_genai/agent_engines.py @@ -1917,6 +1917,11 @@ def create( agent_config_source = config.agent_config_source if agent_config_source is not None: agent_config_source = json.loads(agent_config_source.model_dump_json()) + keep_alive_probe = config.keep_alive_probe + if keep_alive_probe is not None: + keep_alive_probe = json.loads( + keep_alive_probe.model_dump_json(exclude_none=True) + ) if agent and agent_engine: raise ValueError("Please specify only one of `agent` or `agent_engine`.") elif agent_engine: @@ -1958,6 +1963,7 @@ def create( image_spec=config.image_spec, agent_config_source=agent_config_source, container_spec=config.container_spec, + keep_alive_probe=keep_alive_probe, ) operation = self._create(config=api_config) reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( @@ -2269,6 +2275,7 @@ def _create_config( types.ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict ] = None, container_spec: Optional[types.ReasoningEngineSpecContainerSpecDict] = None, + keep_alive_probe: Optional[dict[str, Any]] = None, ) -> types.UpdateAgentEngineConfigDict: import sys @@ -2399,14 +2406,15 @@ def _create_config( or max_instances is not None or resource_limits is not None or container_concurrency is not None + or keep_alive_probe is not None ) if agent_engine_spec is None and is_deployment_spec_updated: raise ValueError( "To update `env_vars`, `psc_interface_config`, `min_instances`, " - "`max_instances`, `resource_limits`, or `container_concurrency`, " - "you must also provide the `agent` variable or the source code " - "options (`source_packages`, `developer_connect_source` or " - "`agent_config_source`)." + "`max_instances`, `resource_limits`, `container_concurrency`, or " + "`keep_alive_probe`, you must also provide the `agent` variable or " + "the source code options (`source_packages`, " + "`developer_connect_source` or `agent_config_source`)." ) if agent_engine_spec is not None: @@ -2422,6 +2430,7 @@ def _create_config( max_instances=max_instances, resource_limits=resource_limits, container_concurrency=container_concurrency, + keep_alive_probe=keep_alive_probe, ) update_masks.extend(deployment_update_masks) agent_engine_spec["deployment_spec"] = deployment_spec @@ -2487,6 +2496,7 @@ def _generate_deployment_spec_or_raise( max_instances: Optional[int] = None, resource_limits: Optional[dict[str, str]] = None, container_concurrency: Optional[int] = None, + keep_alive_probe: Optional[dict[str, Any]] = None, ) -> Tuple[dict[str, Any], Sequence[str]]: deployment_spec: dict[str, Any] = {} update_masks = [] @@ -2537,6 +2547,9 @@ def _generate_deployment_spec_or_raise( if container_concurrency: deployment_spec["container_concurrency"] = container_concurrency update_masks.append("spec.deployment_spec.container_concurrency") + if keep_alive_probe is not None: + deployment_spec["keep_alive_probe"] = keep_alive_probe + update_masks.append("spec.deployment_spec.keep_alive_probe") return deployment_spec, update_masks def _update_deployment_spec_with_env_vars_dict_or_raise( @@ -2678,6 +2691,11 @@ def update( agent_config_source = config.agent_config_source if agent_config_source is not None: agent_config_source = json.loads(agent_config_source.model_dump_json()) + keep_alive_probe = config.keep_alive_probe + if keep_alive_probe is not None: + keep_alive_probe = json.loads( + keep_alive_probe.model_dump_json(exclude_none=True) + ) if agent and agent_engine: raise ValueError("Please specify only one of `agent` or `agent_engine`.") elif agent_engine: @@ -2725,6 +2743,7 @@ def update( image_spec=image_spec, agent_config_source=agent_config_source, container_spec=container_spec, + keep_alive_probe=keep_alive_probe, ) operation = self._update(name=name, config=api_config) reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index fd771b8cb4..56eecc5d2b 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -7317,7 +7317,7 @@ class KeepAliveProbeHttpGet(_common.BaseModel): path: Optional[str] = Field( default=None, - description="""Required. Specifies the path of the HTTP GET request (e.g., `"/is_busy"`).""", + description="""Required. Specifies the path of the HTTP GET request (e.g., "/is_busy").""", ) port: Optional[int] = Field( default=None, @@ -7329,7 +7329,7 @@ class KeepAliveProbeHttpGetDict(TypedDict, total=False): """Specifies the HTTP GET configuration for the probe.""" path: Optional[str] - """Required. Specifies the path of the HTTP GET request (e.g., `"/is_busy"`).""" + """Required. Specifies the path of the HTTP GET request (e.g., "/is_busy").""" port: Optional[int] """Optional. Specifies the port number on the container to which the request is sent.""" @@ -8250,6 +8250,12 @@ class CreateAgentEngineConfig(_common.BaseModel): default=None, description="""Agent Gateway configuration for a Reasoning Engine deployment.""", ) + keep_alive_probe: Optional[KeepAliveProbe] = Field( + default=None, + description="""Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""", + ) class CreateAgentEngineConfigDict(TypedDict, total=False): @@ -8386,6 +8392,11 @@ class CreateAgentEngineConfigDict(TypedDict, total=False): ] """Agent Gateway configuration for a Reasoning Engine deployment.""" + keep_alive_probe: Optional[KeepAliveProbeDict] + """Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""" + CreateAgentEngineConfigOrDict = Union[ CreateAgentEngineConfig, CreateAgentEngineConfigDict @@ -8908,6 +8919,12 @@ class UpdateAgentEngineConfig(_common.BaseModel): default=None, description="""Agent Gateway configuration for a Reasoning Engine deployment.""", ) + keep_alive_probe: Optional[KeepAliveProbe] = Field( + default=None, + description="""Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""", + ) update_mask: Optional[str] = Field( default=None, description="""The update mask to apply. For the `FieldMask` definition, see @@ -9049,6 +9066,11 @@ class UpdateAgentEngineConfigDict(TypedDict, total=False): ] """Agent Gateway configuration for a Reasoning Engine deployment.""" + keep_alive_probe: Optional[KeepAliveProbeDict] + """Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""" + update_mask: Optional[str] """The update mask to apply. For the `FieldMask` definition, see https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""" @@ -17096,6 +17118,12 @@ class AgentEngineConfig(_common.BaseModel): default=None, description="""Agent Gateway configuration for a Reasoning Engine deployment.""", ) + keep_alive_probe: Optional[KeepAliveProbe] = Field( + default=None, + description="""Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""", + ) class AgentEngineConfigDict(TypedDict, total=False): @@ -17275,6 +17303,11 @@ class AgentEngineConfigDict(TypedDict, total=False): ] """Agent Gateway configuration for a Reasoning Engine deployment.""" + keep_alive_probe: Optional[KeepAliveProbeDict] + """Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""" + AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]