Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/voice_agent/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export PYTHONPATH=$NEMO_PATH:$PYTHONPATH
# export HF_TOKEN="hf_..." # Use your own HuggingFace API token if needed, as some models may require.
# export HF_HUB_CACHE="/path/to/your/huggingface/cache" # change where HF cache is stored if you don't want to use the default cache
# export SERVER_CONFIG_PATH="/path/to/your/server/config.yaml" # change to the server config you want to use, otherwise it will use the default config in `server/server_configs/default.yaml`
# export SERVER_PUBLIC_HOST="127.0.0.1" # set this to the host/IP clients should use for the WebSocket server
python ./server/server.py
```

Expand Down Expand Up @@ -308,4 +309,3 @@ For details of available NVIDIA NIM services, please refer to:
## Contributing

We welcome contributions to this project. Please feel free to submit a pull request or open an issue.

11 changes: 6 additions & 5 deletions examples/voice_agent/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request, WebSocket
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from omegaconf import OmegaConf
Expand All @@ -36,6 +36,7 @@
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frameworks.rtvi import RTVIAction, RTVIConfig, RTVIObserverParams, RTVIProcessor
from pipecat.serializers.protobuf import ProtobufFrameSerializer
from websocket_url import build_websocket_url

from nemo.agents.voice_agent.pipecat.processors.frameworks.rtvi import RTVIObserver
from nemo.agents.voice_agent.pipecat.services.nemo.audio_logger import AudioLogger, RTVIAudioLoggerObserver
Expand Down Expand Up @@ -92,6 +93,8 @@ def setup_logging():
RECORD_AUDIO_DATA = server_config.transport.get("record_audio_data", False)
AUDIO_LOG_DIR = server_config.transport.get("audio_log_dir", "./audio_logs")
SERVER_HOST = os.getenv("SERVER_HOST", "0.0.0.0")
SERVER_PUBLIC_HOST = os.getenv("SERVER_PUBLIC_HOST", "127.0.0.1")
WEBSOCKET_SCHEME = os.getenv("WEBSOCKET_SCHEME", "ws")
WEBSOCKET_PORT = int(os.getenv("WEBSOCKET_PORT", 8765))
FASTAPI_PORT = int(os.getenv("FASTAPI_PORT", 7860))

Expand Down Expand Up @@ -424,11 +427,9 @@ async def websocket_endpoint(websocket: WebSocket):


@app.post("/connect")
async def bot_connect(request: Request) -> Dict[Any, Any]:
async def bot_connect() -> Dict[Any, Any]:
print("Received /connect request")
# Use the host that the client connected to (from the request)
server_host = request.url.hostname or request.headers.get("host", "").split(":")[0]
ws_url = f"ws://{server_host}:{WEBSOCKET_PORT}"
ws_url = build_websocket_url(SERVER_PUBLIC_HOST, WEBSOCKET_PORT, WEBSOCKET_SCHEME)
print(f"Returning WebSocket URL: {ws_url}")
return {"ws_url": ws_url}

Expand Down
43 changes: 43 additions & 0 deletions examples/voice_agent/server/websocket_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from urllib.parse import urlsplit


def _normalize_websocket_scheme(scheme: str) -> str:
scheme = scheme.strip().lower()
if scheme not in {"ws", "wss"}:
raise ValueError("WEBSOCKET_SCHEME must be either 'ws' or 'wss'")
return scheme


def _normalize_websocket_host(host: str) -> str:
host = host.strip()
if not host:
raise ValueError("SERVER_PUBLIC_HOST must not be empty")
if "://" in host:
parsed_host = urlsplit(host).hostname
if not parsed_host:
raise ValueError("SERVER_PUBLIC_HOST must include a host name")
host = parsed_host
return host


def build_websocket_url(host: str, port: int, scheme: str = "ws") -> str:
"""Build the client-facing WebSocket URL from trusted server configuration."""
scheme = _normalize_websocket_scheme(scheme)
host = _normalize_websocket_host(host)
if ":" in host and not host.startswith("["):
host = f"[{host}]"
return f"{scheme}://{host}:{port}"
45 changes: 45 additions & 0 deletions examples/voice_agent/tests/test_websocket_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path

import pytest

voice_agent_server_path = Path(__file__).resolve().parents[1] / "server"
sys.path.insert(0, str(voice_agent_server_path))

from websocket_url import build_websocket_url


@pytest.mark.unit
def test_build_websocket_url_uses_configured_host():
assert build_websocket_url("voice-agent.example", 8765) == "ws://voice-agent.example:8765"


@pytest.mark.unit
def test_build_websocket_url_does_not_accept_request_host():
forged_request_host = "evil.example"

ws_url = build_websocket_url("voice-agent.example", 8765)

assert forged_request_host not in ws_url
assert ws_url == "ws://voice-agent.example:8765"


@pytest.mark.unit
@pytest.mark.parametrize("scheme", ["http", "https", ""])
def test_build_websocket_url_rejects_non_websocket_schemes(scheme):
with pytest.raises(ValueError):
build_websocket_url("voice-agent.example", 8765, scheme=scheme)
Loading