Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions tools/mcp/modelopt_mcp/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@

import os
import re
import shlex
import subprocess # nosec B404
import time
from dataclasses import dataclass
from pathlib import Path
from pathlib import Path, PurePosixPath

import yaml

Expand All @@ -52,6 +53,7 @@
_STATUS_FAILURE_WORDS: frozenset[str] = frozenset(
{"failed", "error", "errored", "cancelled", "canceled"}
)
_SAFE_REMOTE_ARTIFACT_PATH = re.compile(r"^[A-Za-z0-9._/@+=:,-]+$")


def _find_launcher_examples_dir() -> Path | None:
Expand Down Expand Up @@ -995,6 +997,19 @@ def provision_passwordless_ssh_dry_run_impl(
}


def _validate_remote_artifact_path(path: str) -> str | None:
"""Return an error reason when a remote artifact path is unsafe."""
parts = path.split("/")
posix_path = PurePosixPath(path)
if posix_path.is_absolute():
return "absolute paths are not supported"
if any(part in {"", ".", ".."} for part in parts):
return "empty, '.', and '..' path segments are not supported"
if not _SAFE_REMOTE_ARTIFACT_PATH.fullmatch(path):
return "only letters, digits, '/', '.', '_', '-', '@', '+', '=', ':', and ',' are allowed"
return None


# ---------------------------------------------------------------------------
# read_cluster_artifact — wraps nemo_run's tunnel
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1079,9 +1094,22 @@ def read_cluster_artifact_impl(
"bytes": len(content),
}

invalid_path_reason = _validate_remote_artifact_path(path)
if invalid_path_reason:
return {
"ok": False,
"experiment_id": experiment_id,
"path": path,
"reason": "invalid_artifact_path",
"diagnostic": (
"Artifact path must be a relative POSIX path inside the "
f"experiment dir: {invalid_path_reason}."
),
}

# Mode 2: arbitrary path via the experiment's tunnel. nemo_run's
# Experiment loads the executor + tunnel from disk; we rsync the
# named relative path into a local tmp dir, then read it back.
# Experiment loads the executor + tunnel from disk, then reads the
# named relative path.
try:
from nemo_run.run.experiment import Experiment
except ImportError:
Expand Down Expand Up @@ -1149,14 +1177,14 @@ def read_cluster_artifact_impl(
"reason": "no_remote_job_dir",
"diagnostic": (
"Executor / tunnel metadata didn't carry a remote "
"job_dir. Pass the full remote path as `path` instead "
"of a relative one."
"job_dir. Re-submit the experiment with executor metadata "
"that includes job_dir, then pass a relative artifact path."
),
}
remote_path = path if path.startswith("/") else f"{remote_dir}/{experiment_id}/{path}"
remote_path = f"{str(remote_dir).rstrip('/')}/{experiment_id}/{path}"

try:
result = tunnel.run(f"cat {remote_path}", warn=True)
result = tunnel.run(f"cat -- {shlex.quote(remote_path)}", warn=True)
except Exception as e:
return {
"ok": False,
Expand Down
81 changes: 81 additions & 0 deletions tools/mcp/tests/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from __future__ import annotations

import subprocess
import sys
import types

import pytest

Expand Down Expand Up @@ -602,6 +604,85 @@ def fake_run(argv, **kwargs):
assert result["reason"] == "logs_fetch_timeout"


def _install_fake_experiment(monkeypatch, experiment):
nemo_run_mod = types.ModuleType("nemo_run")
run_mod = types.ModuleType("nemo_run.run")
experiment_mod = types.ModuleType("nemo_run.run.experiment")

class FakeExperiment:
@classmethod
def from_id(cls, experiment_id):
return experiment

experiment_mod.Experiment = FakeExperiment
nemo_run_mod.run = run_mod
run_mod.experiment = experiment_mod
monkeypatch.setitem(sys.modules, "nemo_run", nemo_run_mod)
monkeypatch.setitem(sys.modules, "nemo_run.run", run_mod)
monkeypatch.setitem(sys.modules, "nemo_run.run.experiment", experiment_mod)


def test_read_cluster_artifact_path_mode_quotes_remote_path(monkeypatch):
"""Relative path mode quotes the remote cat target before invoking the tunnel."""
captured = {}

class FakeTunnel:
def run(self, command, **kwargs):
captured["command"] = command
captured["kwargs"] = kwargs
return types.SimpleNamespace(stdout="artifact content")

executor = types.SimpleNamespace(
job_dir="/lustre/job dir/$(not_a_command)",
tunnel=FakeTunnel(),
)
experiment = types.SimpleNamespace(tasks=[types.SimpleNamespace(executor=executor)])
_install_fake_experiment(monkeypatch, experiment)

result = bridge.read_cluster_artifact_impl(
experiment_id="cicd_42",
path="results/specbench_results.json",
job_idx=0,
)

assert result["ok"] is True
assert result["mode"] == "arbitrary_path"
assert result["content"] == "artifact content"
assert result["remote_path"] == (
"/lustre/job dir/$(not_a_command)/cicd_42/results/specbench_results.json"
)
assert captured["command"] == (
"cat -- '/lustre/job dir/$(not_a_command)/cicd_42/results/specbench_results.json'"
)
assert captured["kwargs"] == {"warn": True}


@pytest.mark.parametrize(
"path",
[
"/etc/passwd",
"../secret.txt",
"results/../secret.txt",
"results//secret.txt",
"results/./secret.txt",
"results/secret; touch pwned",
"results/$(touch pwned)",
"results/`touch pwned`",
"results/secret file.txt",
],
)
def test_read_cluster_artifact_rejects_unsafe_paths(path):
"""Unsafe artifact paths fail before nemo_run import or tunnel execution."""
result = bridge.read_cluster_artifact_impl(
experiment_id="cicd_42",
path=path,
job_idx=0,
)

assert result["ok"] is False
assert result["reason"] == "invalid_artifact_path"


# ---------------------------------------------------------------------------
# open_draft_pr — subprocess mocked
# ---------------------------------------------------------------------------
Expand Down