From f39d27c1c2274f5b0f818a866597866e53712895 Mon Sep 17 00:00:00 2001 From: Charlie Truong Date: Thu, 18 Jun 2026 10:42:03 -0500 Subject: [PATCH] Fix voice agent websocket host handling Signed-off-by: Charlie Truong --- examples/voice_agent/README.md | 2 +- examples/voice_agent/server/server.py | 11 ++--- examples/voice_agent/server/websocket_url.py | 43 ++++++++++++++++++ .../voice_agent/tests/test_websocket_url.py | 45 +++++++++++++++++++ 4 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 examples/voice_agent/server/websocket_url.py create mode 100644 examples/voice_agent/tests/test_websocket_url.py diff --git a/examples/voice_agent/README.md b/examples/voice_agent/README.md index 2360a19b5e90..248ddfe22754 100644 --- a/examples/voice_agent/README.md +++ b/examples/voice_agent/README.md @@ -108,6 +108,7 @@ export PYTHONPATH=$NEMO_PATH:$PYTHONPATH # export HF_TOKEN="hf_..." # Use your own HuggingFace API token if needed, as some models may require. # export HF_HUB_CACHE="/path/to/your/huggingface/cache" # change where HF cache is stored if you don't want to use the default cache # export SERVER_CONFIG_PATH="/path/to/your/server/config.yaml" # change to the server config you want to use, otherwise it will use the default config in `server/server_configs/default.yaml` +# export SERVER_PUBLIC_HOST="127.0.0.1" # set this to the host/IP clients should use for the WebSocket server python ./server/server.py ``` @@ -308,4 +309,3 @@ For details of available NVIDIA NIM services, please refer to: ## Contributing We welcome contributions to this project. Please feel free to submit a pull request or open an issue. - diff --git a/examples/voice_agent/server/server.py b/examples/voice_agent/server/server.py index 592ebc774586..74db8ca035b7 100644 --- a/examples/voice_agent/server/server.py +++ b/examples/voice_agent/server/server.py @@ -24,7 +24,7 @@ import uvicorn from dotenv import load_dotenv -from fastapi import FastAPI, Request, WebSocket +from fastapi import FastAPI, WebSocket from fastapi.middleware.cors import CORSMiddleware from loguru import logger from omegaconf import OmegaConf @@ -36,6 +36,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frameworks.rtvi import RTVIAction, RTVIConfig, RTVIObserverParams, RTVIProcessor from pipecat.serializers.protobuf import ProtobufFrameSerializer +from websocket_url import build_websocket_url from nemo.agents.voice_agent.pipecat.processors.frameworks.rtvi import RTVIObserver from nemo.agents.voice_agent.pipecat.services.nemo.audio_logger import AudioLogger, RTVIAudioLoggerObserver @@ -92,6 +93,8 @@ def setup_logging(): RECORD_AUDIO_DATA = server_config.transport.get("record_audio_data", False) AUDIO_LOG_DIR = server_config.transport.get("audio_log_dir", "./audio_logs") SERVER_HOST = os.getenv("SERVER_HOST", "0.0.0.0") +SERVER_PUBLIC_HOST = os.getenv("SERVER_PUBLIC_HOST", "127.0.0.1") +WEBSOCKET_SCHEME = os.getenv("WEBSOCKET_SCHEME", "ws") WEBSOCKET_PORT = int(os.getenv("WEBSOCKET_PORT", 8765)) FASTAPI_PORT = int(os.getenv("FASTAPI_PORT", 7860)) @@ -424,11 +427,9 @@ async def websocket_endpoint(websocket: WebSocket): @app.post("/connect") -async def bot_connect(request: Request) -> Dict[Any, Any]: +async def bot_connect() -> Dict[Any, Any]: print("Received /connect request") - # Use the host that the client connected to (from the request) - server_host = request.url.hostname or request.headers.get("host", "").split(":")[0] - ws_url = f"ws://{server_host}:{WEBSOCKET_PORT}" + ws_url = build_websocket_url(SERVER_PUBLIC_HOST, WEBSOCKET_PORT, WEBSOCKET_SCHEME) print(f"Returning WebSocket URL: {ws_url}") return {"ws_url": ws_url} diff --git a/examples/voice_agent/server/websocket_url.py b/examples/voice_agent/server/websocket_url.py new file mode 100644 index 000000000000..cb8279c34905 --- /dev/null +++ b/examples/voice_agent/server/websocket_url.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from urllib.parse import urlsplit + + +def _normalize_websocket_scheme(scheme: str) -> str: + scheme = scheme.strip().lower() + if scheme not in {"ws", "wss"}: + raise ValueError("WEBSOCKET_SCHEME must be either 'ws' or 'wss'") + return scheme + + +def _normalize_websocket_host(host: str) -> str: + host = host.strip() + if not host: + raise ValueError("SERVER_PUBLIC_HOST must not be empty") + if "://" in host: + parsed_host = urlsplit(host).hostname + if not parsed_host: + raise ValueError("SERVER_PUBLIC_HOST must include a host name") + host = parsed_host + return host + + +def build_websocket_url(host: str, port: int, scheme: str = "ws") -> str: + """Build the client-facing WebSocket URL from trusted server configuration.""" + scheme = _normalize_websocket_scheme(scheme) + host = _normalize_websocket_host(host) + if ":" in host and not host.startswith("["): + host = f"[{host}]" + return f"{scheme}://{host}:{port}" diff --git a/examples/voice_agent/tests/test_websocket_url.py b/examples/voice_agent/tests/test_websocket_url.py new file mode 100644 index 000000000000..efdcedf3c768 --- /dev/null +++ b/examples/voice_agent/tests/test_websocket_url.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + +import pytest + +voice_agent_server_path = Path(__file__).resolve().parents[1] / "server" +sys.path.insert(0, str(voice_agent_server_path)) + +from websocket_url import build_websocket_url + + +@pytest.mark.unit +def test_build_websocket_url_uses_configured_host(): + assert build_websocket_url("voice-agent.example", 8765) == "ws://voice-agent.example:8765" + + +@pytest.mark.unit +def test_build_websocket_url_does_not_accept_request_host(): + forged_request_host = "evil.example" + + ws_url = build_websocket_url("voice-agent.example", 8765) + + assert forged_request_host not in ws_url + assert ws_url == "ws://voice-agent.example:8765" + + +@pytest.mark.unit +@pytest.mark.parametrize("scheme", ["http", "https", ""]) +def test_build_websocket_url_rejects_non_websocket_schemes(scheme): + with pytest.raises(ValueError): + build_websocket_url("voice-agent.example", 8765, scheme=scheme)