From f497547ab6710d396b1d73da0371584ae97bbfd3 Mon Sep 17 00:00:00 2001 From: Parman Mohammadalizadeh Date: Thu, 28 May 2026 10:46:08 +0330 Subject: [PATCH] chore: apply ruff formatting and lint fixes Add ruff to dev dependencies and run `ruff format .` and `ruff check --fix .` across the codebase. Reformats 57 files to consistent style and removes unused imports (14 auto-fixed lint errors). Closes #320 Signed-off-by: Parman Mohammadalizadeh --- pyproject.toml | 1 + src/agent/_litellm.py | 2 +- src/agent/claude_agent/cli.py | 4 +- src/agent/claude_agent/runner.py | 34 +++- src/agent/claude_agent/tests/test_runner.py | 23 ++- src/agent/deep_agent/cli.py | 4 +- src/agent/deep_agent/runner.py | 8 +- src/agent/deep_agent/tests/test_runner.py | 47 ++++- src/agent/openai_agent/cli.py | 4 +- src/agent/openai_agent/runner.py | 23 ++- src/agent/openai_agent/tests/test_runner.py | 2 +- src/agent/plan_execute/executor.py | 10 +- src/agent/plan_execute/runner.py | 5 +- src/agent/tests/test_planner.py | 9 +- src/agent/tests/test_runner.py | 167 +++++++++++------ src/couchdb/init_asset_data.py | 20 ++- src/couchdb/init_wo.py | 27 ++- src/evaluation/cli.py | 3 +- src/evaluation/evaluator.py | 1 - src/evaluation/metrics.py | 10 +- src/evaluation/scorers/__init__.py | 4 +- src/evaluation/scorers/llm_judge.py | 4 +- src/evaluation/tests/test_loader.py | 12 +- src/evaluation/tests/test_metrics.py | 47 ++++- src/evaluation/tests/test_models.py | 4 +- src/evaluation/tests/test_report.py | 13 +- src/evaluation/tests/test_runner.py | 12 +- src/llm/base.py | 4 +- src/llm/litellm.py | 4 +- src/observability/persistence.py | 4 +- src/observability/runspan.py | 2 - src/observability/tests/test_persistence.py | 8 +- src/observability/tests/test_tracing.py | 4 +- src/observability/tracing.py | 4 +- src/servers/fmsr/main.py | 19 +- src/servers/fmsr/tests/conftest.py | 6 +- src/servers/fmsr/tests/test_tools.py | 18 +- src/servers/iot/main.py | 6 +- src/servers/iot/tests/conftest.py | 1 + src/servers/iot/tests/test_couchdb.py | 4 +- src/servers/tsfm/forecasting.py | 1 - src/servers/tsfm/io.py | 1 - src/servers/tsfm/main.py | 12 +- src/servers/tsfm/metrics.py | 1 - src/servers/tsfm/tests/conftest.py | 3 +- src/servers/tsfm/tests/test_tools.py | 85 ++++++--- src/servers/utilities/main.py | 13 +- src/servers/utilities/tests/conftest.py | 1 - src/servers/vibration/data_store.py | 4 +- src/servers/vibration/dsp/__init__.py | 14 +- src/servers/vibration/dsp/bearing_freqs.py | 6 +- src/servers/vibration/dsp/envelope.py | 4 +- src/servers/vibration/dsp/fault_detection.py | 8 +- src/servers/vibration/main.py | 6 +- .../generate_synthetic_vibration.py | 52 +++--- src/servers/vibration/tests/test_dsp.py | 7 +- src/servers/vibration/tests/test_mcp_e2e.py | 63 +++++-- src/servers/vibration/tests/test_tools.py | 169 +++++++++++------- src/servers/wo/data.py | 37 +++- src/servers/wo/main.py | 9 +- src/servers/wo/tests/conftest.py | 47 ++++- src/servers/wo/tests/test_integration.py | 149 +++++++++++---- src/servers/wo/tests/test_tools.py | 115 +++++++++--- src/servers/wo/tools.py | 75 ++++++-- uv.lock | 27 +++ 65 files changed, 1072 insertions(+), 421 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a215c7b3..b82002a39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "pytest-anyio>=0.0.0", "opentelemetry-api>=1.27.0", "opentelemetry-sdk>=1.27.0", + "ruff>=0.9.0", ] # Optional heavy ML deps for the TSFM server. # tsfm_public must be installed separately: pip install git+https://github.com/ibm-granite/granite-tsfm diff --git a/src/agent/_litellm.py b/src/agent/_litellm.py index fbad90d10..f9f620515 100644 --- a/src/agent/_litellm.py +++ b/src/agent/_litellm.py @@ -29,5 +29,5 @@ def resolve_model(model_id: str) -> str: "anthropic/claude-sonnet-4-6" -> "anthropic/claude-sonnet-4-6" """ if model_id.startswith(LITELLM_PREFIX): - return model_id[len(LITELLM_PREFIX):] + return model_id[len(LITELLM_PREFIX) :] return model_id diff --git a/src/agent/claude_agent/cli.py b/src/agent/claude_agent/cli.py index 42f29f340..5bb1e30cf 100644 --- a/src/agent/claude_agent/cli.py +++ b/src/agent/claude_agent/cli.py @@ -50,7 +50,9 @@ async def _run(args: argparse.Namespace) -> None: runner = ClaudeAgentRunner(model=args.model_id, max_turns=args.max_turns) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: diff --git a/src/agent/claude_agent/runner.py b/src/agent/claude_agent/runner.py index 9a11b3e84..cf75c979e 100644 --- a/src/agent/claude_agent/runner.py +++ b/src/agent/claude_agent/runner.py @@ -21,7 +21,13 @@ import time from pathlib import Path -from claude_agent_sdk import AssistantMessage, ClaudeAgentOptions, HookMatcher, ResultMessage, query +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + HookMatcher, + ResultMessage, + query, +) from claude_agent_sdk import TextBlock, ToolUseBlock from observability import agent_run_span, persist_trajectory @@ -132,8 +138,14 @@ async def run(self, question: str) -> AgentResult: last_turn_start = run_started tool_outputs: dict[str, object] = {} - async def _capture_tool_output(input_data, tool_use_id: str, context) -> dict: - resp = input_data.get("tool_response") if isinstance(input_data, dict) else input_data + async def _capture_tool_output( + input_data, tool_use_id: str, context + ) -> dict: + resp = ( + input_data.get("tool_response") + if isinstance(input_data, dict) + else input_data + ) if isinstance(resp, dict): tool_outputs[tool_use_id] = resp.get("content", resp) else: @@ -145,7 +157,9 @@ async def _capture_tool_output(input_data, tool_use_id: str, context) -> dict: # per-tool duration for claude-agent is therefore not captured # (matches openai-agent / deep-agent). options.hooks = { - "PostToolUse": [HookMatcher(matcher=".*", hooks=[_capture_tool_output])], + "PostToolUse": [ + HookMatcher(matcher=".*", hooks=[_capture_tool_output]) + ], } def _flush_tool_outputs() -> None: @@ -169,7 +183,9 @@ def _flush_tool_outputs() -> None: text += block.text elif isinstance(block, ToolUseBlock): tool_calls.append( - ToolCall(name=block.name, input=block.input, id=block.id) + ToolCall( + name=block.name, input=block.input, id=block.id + ) ) usage = message.usage or {} trajectory.turns.append( @@ -197,8 +213,12 @@ def _flush_tool_outputs() -> None: duration_ms = (time.perf_counter() - run_started) * 1000 span.set_attribute("agent.answer.length", len(answer)) - span.set_attribute("gen_ai.usage.input_tokens", trajectory.total_input_tokens) - span.set_attribute("gen_ai.usage.output_tokens", trajectory.total_output_tokens) + span.set_attribute( + "gen_ai.usage.input_tokens", trajectory.total_input_tokens + ) + span.set_attribute( + "gen_ai.usage.output_tokens", trajectory.total_output_tokens + ) span.set_attribute("agent.turns", len(trajectory.turns)) span.set_attribute("agent.tool_calls", len(trajectory.all_tool_calls)) span.set_attribute("agent.duration_ms", duration_ms) diff --git a/src/agent/claude_agent/tests/test_runner.py b/src/agent/claude_agent/tests/test_runner.py index 3c4ddbc59..fb65c7baf 100644 --- a/src/agent/claude_agent/tests/test_runner.py +++ b/src/agent/claude_agent/tests/test_runner.py @@ -6,7 +6,7 @@ from __future__ import annotations from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -113,7 +113,12 @@ async def fake_query(prompt, options): @pytest.mark.anyio async def test_run_collects_trajectory(): - from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock + from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + TextBlock, + ToolUseBlock, + ) mock_tool = MagicMock(spec=ToolUseBlock) mock_tool.name = "sensors" @@ -157,7 +162,12 @@ async def fake_query(prompt, options): @pytest.mark.anyio async def test_run_tool_output_captured(): """PostToolUse hook output is attached to the matching ToolCall.""" - from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock + from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + TextBlock, + ToolUseBlock, + ) mock_tool = MagicMock(spec=ToolUseBlock) mock_tool.name = "sensors" @@ -206,7 +216,12 @@ async def fake_query(prompt, options): @pytest.mark.anyio async def test_run_tool_output_string_response(): """PostToolUse hook handles string tool_response (no .get).""" - from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock + from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + TextBlock, + ToolUseBlock, + ) mock_tool = MagicMock(spec=ToolUseBlock) mock_tool.name = "sites" diff --git a/src/agent/deep_agent/cli.py b/src/agent/deep_agent/cli.py index c02c24c53..2dc7758b2 100644 --- a/src/agent/deep_agent/cli.py +++ b/src/agent/deep_agent/cli.py @@ -56,7 +56,9 @@ async def _run(args: argparse.Namespace) -> None: recursion_limit=args.recursion_limit, ) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: diff --git a/src/agent/deep_agent/runner.py b/src/agent/deep_agent/runner.py index 3d975d55c..d9ab31d63 100644 --- a/src/agent/deep_agent/runner.py +++ b/src/agent/deep_agent/runner.py @@ -245,8 +245,12 @@ async def run(self, question: str) -> AgentResult: ) span.set_attribute("agent.answer.length", len(answer)) - span.set_attribute("gen_ai.usage.input_tokens", trajectory.total_input_tokens) - span.set_attribute("gen_ai.usage.output_tokens", trajectory.total_output_tokens) + span.set_attribute( + "gen_ai.usage.input_tokens", trajectory.total_input_tokens + ) + span.set_attribute( + "gen_ai.usage.output_tokens", trajectory.total_output_tokens + ) span.set_attribute("agent.turns", len(trajectory.turns)) span.set_attribute("agent.tool_calls", len(trajectory.all_tool_calls)) span.set_attribute( diff --git a/src/agent/deep_agent/tests/test_runner.py b/src/agent/deep_agent/tests/test_runner.py index 2014612d8..834ada5ea 100644 --- a/src/agent/deep_agent/tests/test_runner.py +++ b/src/agent/deep_agent/tests/test_runner.py @@ -122,12 +122,20 @@ def test_build_trajectory_tool_calls_and_outputs(): AIMessage( content="", tool_calls=[{"name": "sensors", "args": {"asset_id": "CH-6"}, "id": "c1"}], - usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, + usage_metadata={ + "input_tokens": 100, + "output_tokens": 20, + "total_tokens": 120, + }, ), ToolMessage(content="5 sensors found", tool_call_id="c1"), AIMessage( content="Chiller 6 has 5 sensors.", - usage_metadata={"input_tokens": 150, "output_tokens": 30, "total_tokens": 180}, + usage_metadata={ + "input_tokens": 150, + "output_tokens": 30, + "total_tokens": 180, + }, ), ] traj = _build_trajectory(messages) @@ -149,7 +157,12 @@ def test_build_trajectory_tool_calls_and_outputs(): def test_build_trajectory_list_content(): messages = [ - AIMessage(content=[{"type": "text", "text": "part one "}, {"type": "text", "text": "part two"}]) + AIMessage( + content=[ + {"type": "text", "text": "part one "}, + {"type": "text", "text": "part two"}, + ] + ) ] traj = _build_trajectory(messages) assert traj.turns[0].text == "part one part two" @@ -172,13 +185,21 @@ def test_build_trajectory_multiple_tool_calls_one_turn(): {"name": "sites", "args": {}, "id": "c1"}, {"name": "assets", "args": {"site_id": "MAIN"}, "id": "c2"}, ], - usage_metadata={"input_tokens": 50, "output_tokens": 10, "total_tokens": 60}, + usage_metadata={ + "input_tokens": 50, + "output_tokens": 10, + "total_tokens": 60, + }, ), ToolMessage(content=["MAIN"], tool_call_id="c1"), ToolMessage(content=["Chiller 6"], tool_call_id="c2"), AIMessage( content="Found Chiller 6 at site MAIN.", - usage_metadata={"input_tokens": 80, "output_tokens": 15, "total_tokens": 95}, + usage_metadata={ + "input_tokens": 80, + "output_tokens": 15, + "total_tokens": 95, + }, ), ] traj = _build_trajectory(messages) @@ -242,13 +263,23 @@ async def test_run_collects_trajectory(): HumanMessage(content="What sensors are on Chiller 6?"), AIMessage( content="", - tool_calls=[{"name": "sensors", "args": {"asset_id": "CH-6"}, "id": "c1"}], - usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, + tool_calls=[ + {"name": "sensors", "args": {"asset_id": "CH-6"}, "id": "c1"} + ], + usage_metadata={ + "input_tokens": 100, + "output_tokens": 20, + "total_tokens": 120, + }, ), ToolMessage(content="sensor data", tool_call_id="c1"), AIMessage( content="Chiller 6 has 5 sensors.", - usage_metadata={"input_tokens": 150, "output_tokens": 30, "total_tokens": 180}, + usage_metadata={ + "input_tokens": 150, + "output_tokens": 30, + "total_tokens": 180, + }, ), ] } diff --git a/src/agent/openai_agent/cli.py b/src/agent/openai_agent/cli.py index 05e66dea5..b0a3599d2 100644 --- a/src/agent/openai_agent/cli.py +++ b/src/agent/openai_agent/cli.py @@ -52,7 +52,9 @@ async def _run(args: argparse.Namespace) -> None: runner = OpenAIAgentRunner(model=args.model_id, max_turns=args.max_turns) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: diff --git a/src/agent/openai_agent/runner.py b/src/agent/openai_agent/runner.py index 8dfccb48d..fc065721b 100644 --- a/src/agent/openai_agent/runner.py +++ b/src/agent/openai_agent/runner.py @@ -25,7 +25,14 @@ from openai import AsyncOpenAI -from agents import Agent, ModelProvider, OpenAIChatCompletionsModel, RunConfig, Runner, set_tracing_disabled +from agents import ( + Agent, + ModelProvider, + OpenAIChatCompletionsModel, + RunConfig, + Runner, + set_tracing_disabled, +) from agents.mcp import MCPServerStdio from observability import agent_run_span, persist_trajectory @@ -144,7 +151,9 @@ def _flush() -> None: tc_id = getattr(raw, "call_id", "") or getattr(raw, "id", "") or "" tc_args = getattr(raw, "arguments", "{}") or "{}" try: - tc_input = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + tc_input = ( + json.loads(tc_args) if isinstance(tc_args, str) else tc_args + ) except (json.JSONDecodeError, TypeError): tc_input = {"raw": tc_args} tool_calls.append(ToolCall(name=tc_name, input=tc_input, id=tc_id)) @@ -258,8 +267,12 @@ async def run(self, question: str) -> AgentResult: ) span.set_attribute("agent.answer.length", len(answer)) - span.set_attribute("gen_ai.usage.input_tokens", trajectory.total_input_tokens) - span.set_attribute("gen_ai.usage.output_tokens", trajectory.total_output_tokens) + span.set_attribute( + "gen_ai.usage.input_tokens", trajectory.total_input_tokens + ) + span.set_attribute( + "gen_ai.usage.output_tokens", trajectory.total_output_tokens + ) span.set_attribute("agent.turns", len(trajectory.turns)) span.set_attribute("agent.tool_calls", len(trajectory.all_tool_calls)) span.set_attribute( @@ -277,5 +290,3 @@ async def run(self, question: str) -> AgentResult: answer=answer, trajectory=trajectory, ) - - diff --git a/src/agent/openai_agent/tests/test_runner.py b/src/agent/openai_agent/tests/test_runner.py index 74468d7d5..030b0a990 100644 --- a/src/agent/openai_agent/tests/test_runner.py +++ b/src/agent/openai_agent/tests/test_runner.py @@ -7,7 +7,7 @@ from pathlib import Path from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest diff --git a/src/agent/plan_execute/executor.py b/src/agent/plan_execute/executor.py index 2e27b5144..d785f97ea 100644 --- a/src/agent/plan_execute/executor.py +++ b/src/agent/plan_execute/executor.py @@ -112,7 +112,9 @@ async def execute_plan(self, plan: Plan, question: str) -> list[StepResult]: ) schema = tool_schemas.get(step.server, {}).get(step.tool, "") step_started = time.perf_counter() - result = await self.execute_step(step, context, question, tool_schema=schema) + result = await self.execute_step( + step, context, question, tool_schema=schema + ) result.duration_ms = (time.perf_counter() - step_started) * 1000 if result.success: _log.info("Step %d OK.", step.step_number) @@ -202,8 +204,7 @@ async def _resolve_args_with_llm( f"Step {n}: {r.response}" for n, r in sorted(context.items()) ) prompt = ( - _ARG_RESOLUTION_PROMPT - .replace("{question}", question) + _ARG_RESOLUTION_PROMPT.replace("{question}", question) .replace("{task}", task) .replace("{tool}", tool) .replace("{tool_schema}", tool_schema or "(unknown)") @@ -214,7 +215,8 @@ async def _resolve_args_with_llm( if resolved is None: _log.warning( "Tool '%s': arg resolution returned no parseable JSON (response: %r…)", - tool, raw[:120], + tool, + raw[:120], ) return {} return resolved diff --git a/src/agent/plan_execute/runner.py b/src/agent/plan_execute/runner.py index ea684a650..445f47b09 100644 --- a/src/agent/plan_execute/runner.py +++ b/src/agent/plan_execute/runner.py @@ -50,9 +50,7 @@ def generate(self, prompt: str, temperature: float = 0.0) -> str: self.output_tokens += result.output_tokens return result.text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: result = self._inner.generate_with_usage(prompt, temperature) self.input_tokens += result.input_tokens self.output_tokens += result.output_tokens @@ -62,6 +60,7 @@ def generate_with_usage( def model_id(self) -> str: return self._inner.model_id + _log = logging.getLogger(__name__) _SUMMARIZE_PROMPT = """\ diff --git a/src/agent/tests/test_planner.py b/src/agent/tests/test_planner.py index 77bc4497d..6ce6e78de 100644 --- a/src/agent/tests/test_planner.py +++ b/src/agent/tests/test_planner.py @@ -144,7 +144,9 @@ def test_generate_plan_uses_llm_output(self, mock_llm): planner = Planner(llm) plan = planner.generate_plan( "List all assets", - {"iot": " - sites(): List sites\n - assets(site_name: string): List assets"}, + { + "iot": " - sites(): List sites\n - assets(site_name: string): List assets" + }, ) assert len(plan.steps) == 2 assert plan.steps[0].server == "iot" @@ -170,7 +172,10 @@ def test_generate_plan_prompt_contains_agent_names(self, mock_llm, monkeypatch): Planner(llm).generate_plan( "Q", - {"iot": " - sites(): List sites", "utilities": " - current_date_time(): Get time"}, + { + "iot": " - sites(): List sites", + "utilities": " - current_date_time(): Get time", + }, ) assert "iot" in captured[0] assert "utilities" in captured[0] diff --git a/src/agent/tests/test_runner.py b/src/agent/tests/test_runner.py index 6bdbac98b..062d78656 100644 --- a/src/agent/tests/test_runner.py +++ b/src/agent/tests/test_runner.py @@ -37,7 +37,11 @@ _MOCK_TOOLS = [ {"name": "sites", "description": "List IoT sites", "parameters": []}, - {"name": "current_date_time", "description": "Get current datetime", "parameters": []}, + { + "name": "current_date_time", + "description": "Get current datetime", + "parameters": [], + }, ] _TOOL_RESPONSE = json.dumps({"sites": ["MAIN"]}) @@ -51,9 +55,13 @@ def _patch_mcp(tool_response: str = _TOOL_RESPONSE): return ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), patch( - "agent.plan_execute.executor._call_tool", new=AsyncMock(return_value=tool_response) + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", + new=AsyncMock(return_value=tool_response), ), ) @@ -93,12 +101,14 @@ def generate(self, prompt: str, **_kw) -> str: @pytest.mark.anyio async def test_orchestrator_run_returns_result(sequential_llm): - llm = sequential_llm([ - _TWO_STEP_PLAN, # planner call - _STEP1_ARGS, # arg resolution for step 1 - _STEP2_ARGS, # arg resolution for step 2 - _FINAL_ANSWER, # summarisation - ]) + llm = sequential_llm( + [ + _TWO_STEP_PLAN, # planner call + _STEP1_ARGS, # arg resolution for step 1 + _STEP2_ARGS, # arg resolution for step 2 + _FINAL_ANSWER, # summarisation + ] + ) with _patch_mcp()[0], _patch_mcp()[1]: result = await PlanExecuteRunner(llm).run("What are the IoT sites?") @@ -145,9 +155,7 @@ def __init__(self, items: list[tuple[str, int, int]]) -> None: def generate(self, prompt: str, temperature: float = 0.0) -> str: return self.generate_with_usage(prompt, temperature).text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: text, in_tok, out_tok = next(self._items, ("", 0, 0)) return LLMResult(text=text, input_tokens=in_tok, output_tokens=out_tok) @@ -155,12 +163,14 @@ def generate_with_usage( @pytest.mark.anyio async def test_orchestrator_accumulates_token_usage_across_llm_calls(): """Plan + 2 arg-resolution + summarise → summed input/output tokens.""" - llm = _UsageReportingLLM([ - (_TWO_STEP_PLAN, 100, 50), # planner - (_STEP1_ARGS, 20, 5), # step 1 arg resolution - (_STEP2_ARGS, 30, 5), # step 2 arg resolution - (_FINAL_ANSWER, 200, 40), # summarise - ]) + llm = _UsageReportingLLM( + [ + (_TWO_STEP_PLAN, 100, 50), # planner + (_STEP1_ARGS, 20, 5), # step 1 arg resolution + (_STEP2_ARGS, 30, 5), # step 2 arg resolution + (_FINAL_ANSWER, 200, 40), # summarise + ] + ) runner = PlanExecuteRunner(llm) with _patch_mcp()[0], _patch_mcp()[1]: await runner.run("Q") @@ -245,8 +255,13 @@ async def test_executor_step_result_carries_resolved_args(sequential_llm): step = _make_step(1, tool="assets") with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), - patch("agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}")), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}") + ), ): result = await executor.execute_step(step, {}, "List assets at MAIN") @@ -258,13 +273,19 @@ async def test_executor_tool_call_exception_recorded_as_error(sequential_llm): """If _call_tool raises, the error is captured in StepResult (no crash).""" from pathlib import Path - llm = sequential_llm(['{}']) + llm = sequential_llm(["{}"]) executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) step = _make_step(1, tool="sites") with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), - patch("agent.plan_execute.executor._call_tool", new=AsyncMock(side_effect=RuntimeError("timeout"))), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", + new=AsyncMock(side_effect=RuntimeError("timeout")), + ), ): result = await executor.execute_step(step, {}, "Q") @@ -277,10 +298,12 @@ async def test_executor_calls_llm_to_generate_args(sequential_llm): """Each tool step triggers exactly one LLM call for arg generation.""" from pathlib import Path - llm = sequential_llm([ - '{}', # step 1: sites (no args) - '{"site_name": "MAIN", "asset_id": "CH-1"}', # step 2: sensors - ]) + llm = sequential_llm( + [ + "{}", # step 1: sites (no args) + '{"site_name": "MAIN", "asset_id": "CH-1"}', # step 2: sensors + ] + ) executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) plan = Plan( @@ -290,12 +313,17 @@ async def test_executor_calls_llm_to_generate_args(sequential_llm): ], raw="", ) - call_mock = AsyncMock(side_effect=[ - json.dumps({"sites": ["MAIN"]}), - json.dumps({"sensors": ["temp"]}), - ]) + call_mock = AsyncMock( + side_effect=[ + json.dumps({"sites": ["MAIN"]}), + json.dumps({"sensors": ["temp"]}), + ] + ) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): results = await executor.execute_plan(plan, "Q") @@ -324,7 +352,10 @@ async def test_executor_prior_step_results_in_llm_prompt(): site_resp = json.dumps({"sites": ["MAIN"]}) call_mock = AsyncMock(side_effect=[site_resp, '{"sensors": []}']) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): await executor.execute_plan(plan, "List sensors for CH-1") @@ -338,13 +369,18 @@ async def test_executor_no_prior_context_shows_none_in_prompt(): """When no prior steps exist the prompt contains the literal '(none)'.""" from pathlib import Path - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) # type: ignore[arg-type] step = _make_step(1, tool="sites") with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), - patch("agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}")), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}") + ), ): await executor.execute_step(step, {}, "Q") @@ -356,7 +392,7 @@ async def test_executor_context_accumulates_across_steps(): """Step 3's LLM prompt contains results from both steps 1 and 2.""" from pathlib import Path - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) # type: ignore[arg-type] plan = Plan( @@ -370,7 +406,10 @@ async def test_executor_context_accumulates_across_steps(): resp1, resp2, resp3 = '{"sites":["MAIN"]}', '{"assets":["CH-1"]}', '{"sensors":[]}' call_mock = AsyncMock(side_effect=[resp1, resp2, resp3]) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): await executor.execute_plan(plan, "Q") @@ -395,16 +434,21 @@ async def test_pipeline_uses_llm_args_for_each_step(sequential_llm): "#Dependency2: #S1\n" "#ExpectedOutput2: List of assets" ) - llm = sequential_llm([ - planner_output, # planner call - '{}', # arg resolution for step 1 (sites needs no args) - '{"site_name": "MAIN"}', # arg resolution for step 2 (uses step 1 result) - "Final answer.", # summarisation - ]) + llm = sequential_llm( + [ + planner_output, # planner call + "{}", # arg resolution for step 1 (sites needs no args) + '{"site_name": "MAIN"}', # arg resolution for step 2 (uses step 1 result) + "Final answer.", # summarisation + ] + ) call_mock = AsyncMock(side_effect=['{"sites": ["MAIN"]}', '{"assets": ["CH-1"]}']) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): result = await PlanExecuteRunner(llm).run("List all assets at site MAIN") @@ -420,10 +464,18 @@ async def test_pipeline_uses_llm_args_for_each_step(sequential_llm): @pytest.mark.anyio async def test_resolve_args_with_llm_uses_context(mock_llm): llm = mock_llm('{"asset_id": "CH-1"}') - ctx = {1: StepResult(step_number=1, task="t", server="a", - response='{"assets": ["CH-1", "CH-2"]}')} + ctx = { + 1: StepResult( + step_number=1, task="t", server="a", response='{"assets": ["CH-1", "CH-2"]}' + ) + } result = await _resolve_args_with_llm( - "What sensors does CH-1 have?", "get sensors", "sensors", "", ctx, llm, + "What sensors does CH-1 have?", + "get sensors", + "sensors", + "", + ctx, + llm, ) assert result["asset_id"] == "CH-1" @@ -440,14 +492,19 @@ async def test_resolve_args_with_llm_fallback_on_bad_json(mock_llm): async def test_resolve_args_with_llm_question_in_prompt(): llm = _CapturingLLM('{"site_name": "MAIN"}') await _resolve_args_with_llm( - "What sites exist?", "List sites", "sites", "", {}, llm # type: ignore[arg-type] + "What sites exist?", + "List sites", + "sites", + "", + {}, + llm, # type: ignore[arg-type] ) assert "What sites exist?" in llm.prompts[0] @pytest.mark.anyio async def test_resolve_args_with_llm_tool_in_prompt(): - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") await _resolve_args_with_llm("Q", "List IoT sites", "sites", "", {}, llm) # type: ignore[arg-type] assert "sites" in llm.prompts[0] @@ -465,7 +522,7 @@ async def test_resolve_args_with_llm_schema_in_prompt(): @pytest.mark.anyio async def test_resolve_args_with_llm_unknown_schema_shows_sentinel(): """Empty schema renders as '(unknown)' in the prompt.""" - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") await _resolve_args_with_llm("Q", "task", "tool", "", {}, llm) # type: ignore[arg-type] assert "(unknown)" in llm.prompts[0] @@ -473,15 +530,17 @@ async def test_resolve_args_with_llm_unknown_schema_shows_sentinel(): @pytest.mark.anyio async def test_resolve_args_with_llm_context_in_prompt(): """Prior step results appear verbatim in the generated prompt.""" - llm = _CapturingLLM('{}') - ctx = {1: StepResult(step_number=1, task="t", server="a", response="step-one-result")} + llm = _CapturingLLM("{}") + ctx = { + 1: StepResult(step_number=1, task="t", server="a", response="step-one-result") + } await _resolve_args_with_llm("Q", "task", "tool", "", ctx, llm) # type: ignore[arg-type] assert "step-one-result" in llm.prompts[0] @pytest.mark.anyio async def test_resolve_args_with_llm_empty_context_shows_none(): - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") await _resolve_args_with_llm("Q", "task", "tool", "", {}, llm) # type: ignore[arg-type] assert "(none)" in llm.prompts[0] diff --git a/src/couchdb/init_asset_data.py b/src/couchdb/init_asset_data.py index 041827068..51da06b01 100644 --- a/src/couchdb/init_asset_data.py +++ b/src/couchdb/init_asset_data.py @@ -31,9 +31,7 @@ # --------------------------------------------------------------------------- _SCRIPT_DIR = os.path.dirname(__file__) -_DEFAULT_DATA_FILE = os.path.join( - _SCRIPT_DIR, "sample_data", "iot", "chiller_6.json" -) +_DEFAULT_DATA_FILE = os.path.join(_SCRIPT_DIR, "sample_data", "iot", "chiller_6.json") COUCHDB_URL = os.environ.get("COUCHDB_URL", "http://localhost:5984") COUCHDB_USERNAME = os.environ.get("COUCHDB_USERNAME", "admin") @@ -92,7 +90,9 @@ def _bulk_insert(db_name: str, docs: list, batch_size: int = 500) -> None: resp.raise_for_status() errors = [r for r in resp.json() if r.get("error")] if errors: - logger.warning("%d bulk-insert errors in batch %d", len(errors), i // batch_size) + logger.warning( + "%d bulk-insert errors in batch %d", len(errors), i // batch_size + ) logger.info( "Inserted batch %d/%d (%d docs)", i // batch_size + 1, @@ -107,10 +107,16 @@ def _bulk_insert(db_name: str, docs: list, batch_size: int = 500) -> None: def main() -> None: - parser = argparse.ArgumentParser(description="Initialize CouchDB IoT asset database from JSON.") - parser.add_argument("--data-file", default=ASSET_DATA_FILE, help="Path to sensor data JSON file") + parser = argparse.ArgumentParser( + description="Initialize CouchDB IoT asset database from JSON." + ) + parser.add_argument( + "--data-file", default=ASSET_DATA_FILE, help="Path to sensor data JSON file" + ) parser.add_argument("--db", default=IOT_DBNAME, help="CouchDB database name") - parser.add_argument("--drop", action="store_true", help="Drop and recreate database if it exists") + parser.add_argument( + "--drop", action="store_true", help="Drop and recreate database if it exists" + ) args = parser.parse_args() logger.info("CouchDB URL: %s", COUCHDB_URL) diff --git a/src/couchdb/init_wo.py b/src/couchdb/init_wo.py index 882aaf6c0..4967bd443 100644 --- a/src/couchdb/init_wo.py +++ b/src/couchdb/init_wo.py @@ -120,8 +120,15 @@ def _bulk_insert(db_name: str, docs: list, batch_size: int = 500) -> None: resp.raise_for_status() errors = [r for r in resp.json() if r.get("error")] if errors: - logger.warning("%d bulk-insert errors in batch %d", len(errors), i // batch_size) - logger.info("Inserted batch %d/%d (%d docs)", i // batch_size + 1, math.ceil(total / batch_size), len(batch)) + logger.warning( + "%d bulk-insert errors in batch %d", len(errors), i // batch_size + ) + logger.info( + "Inserted batch %d/%d (%d docs)", + i // batch_size + 1, + math.ceil(total / batch_size), + len(batch), + ) def _row_to_doc(row: dict, dataset: str, date_cols: dict) -> dict: @@ -148,7 +155,9 @@ def load_dataset(data_dir: str, csv_file: str, dataset: str, date_cols: dict) -> if col in df.columns: df[col] = pd.to_datetime(df[col], format=fmt, errors="coerce") - docs = [_row_to_doc(row, dataset, date_cols) for row in df.to_dict(orient="records")] + docs = [ + _row_to_doc(row, dataset, date_cols) for row in df.to_dict(orient="records") + ] logger.info("Loaded %d rows from '%s' → dataset '%s'", len(docs), csv_file, dataset) return docs @@ -159,10 +168,16 @@ def load_dataset(data_dir: str, csv_file: str, dataset: str, date_cols: dict) -> def main() -> None: - parser = argparse.ArgumentParser(description="Initialize CouchDB work-order database from CSVs.") - parser.add_argument("--data-dir", default=WO_DATA_DIR, help="Directory containing CSVs") + parser = argparse.ArgumentParser( + description="Initialize CouchDB work-order database from CSVs." + ) + parser.add_argument( + "--data-dir", default=WO_DATA_DIR, help="Directory containing CSVs" + ) parser.add_argument("--db", default=WO_DBNAME, help="CouchDB database name") - parser.add_argument("--drop", action="store_true", help="Drop and recreate database if it exists") + parser.add_argument( + "--drop", action="store_true", help="Drop and recreate database if it exists" + ) args = parser.parse_args() logger.info("CouchDB URL: %s", COUCHDB_URL) diff --git a/src/evaluation/cli.py b/src/evaluation/cli.py index faf369652..ebf0bbf13 100644 --- a/src/evaluation/cli.py +++ b/src/evaluation/cli.py @@ -47,8 +47,7 @@ def _build_parser() -> argparse.ArgumentParser: "--scorer-default", dest="scorer_default", default="llm_judge", - help="Scorer name when scenario.scoring_method is unset. " - "Default: llm_judge.", + help="Scorer name when scenario.scoring_method is unset. Default: llm_judge.", ) p.add_argument( "--judge-model", diff --git a/src/evaluation/evaluator.py b/src/evaluation/evaluator.py index dedc82885..3f884bfc6 100644 --- a/src/evaluation/evaluator.py +++ b/src/evaluation/evaluator.py @@ -19,7 +19,6 @@ PersistedTrajectory, Scenario, ScenarioResult, - ScorerResult, ) from .report import build_report from .scorers import Scorer diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py index 325074a7e..0255263c3 100644 --- a/src/evaluation/metrics.py +++ b/src/evaluation/metrics.py @@ -40,7 +40,9 @@ def _from_sdk_trajectory(traj: dict, model: str) -> OpsMetrics: tokens_in = sum(int(t.get("input_tokens") or 0) for t in turns) tokens_out = sum(int(t.get("output_tokens") or 0) for t in turns) - durations_ms = [t.get("duration_ms") for t in turns if t.get("duration_ms") is not None] + durations_ms = [ + t.get("duration_ms") for t in turns if t.get("duration_ms") is not None + ] duration_ms = sum(durations_ms) if durations_ms else None tool_names: list[str] = [] @@ -65,11 +67,7 @@ def _from_plan_execute(steps: list[Any], model: str) -> OpsMetrics: # plan-execute persists ``list[StepResult]``; the dataclass exposes # ``server`` / ``tool`` / ``response`` fields but no per-step token # counts, so we surface what is available and leave the rest at zero. - tool_names = [ - s.get("tool") - for s in steps - if isinstance(s, dict) and s.get("tool") - ] + tool_names = [s.get("tool") for s in steps if isinstance(s, dict) and s.get("tool")] return OpsMetrics( turn_count=len(steps), tool_call_count=len(tool_names), diff --git a/src/evaluation/scorers/__init__.py b/src/evaluation/scorers/__init__.py index a2fa994e6..f1b32fd05 100644 --- a/src/evaluation/scorers/__init__.py +++ b/src/evaluation/scorers/__init__.py @@ -30,9 +30,7 @@ def register(name: str, scorer: Scorer) -> None: def get(name: str) -> Scorer: if name not in _REGISTRY: - raise KeyError( - f"unknown scorer {name!r}; registered: {sorted(_REGISTRY)}" - ) + raise KeyError(f"unknown scorer {name!r}; registered: {sorted(_REGISTRY)}") return _REGISTRY[name] diff --git a/src/evaluation/scorers/llm_judge.py b/src/evaluation/scorers/llm_judge.py index e37ecc219..139744ddb 100644 --- a/src/evaluation/scorers/llm_judge.py +++ b/src/evaluation/scorers/llm_judge.py @@ -140,9 +140,7 @@ def __call__( if review.get("hallucinations") is True: score = max(0.0, score - 0.2) - rationale = str( - review.get("suggestions") or review.get("reason") or "" - )[:500] + rationale = str(review.get("suggestions") or review.get("reason") or "")[:500] return ScorerResult( scorer=self.name, passed=passed, diff --git a/src/evaluation/tests/test_loader.py b/src/evaluation/tests/test_loader.py index 24260b34b..580c136e6 100644 --- a/src/evaluation/tests/test_loader.py +++ b/src/evaluation/tests/test_loader.py @@ -21,7 +21,9 @@ def test_load_trajectories_from_dir(trajectory_dir: Path): def test_load_trajectories_skips_unparseable(tmp_path: Path, make_persisted_record): - (tmp_path / "good.json").write_text(json.dumps(make_persisted_record()), encoding="utf-8") + (tmp_path / "good.json").write_text( + json.dumps(make_persisted_record()), encoding="utf-8" + ) (tmp_path / "bad.json").write_text("{not json", encoding="utf-8") records = load_trajectories(tmp_path) assert len(records) == 1 @@ -30,9 +32,7 @@ def test_load_trajectories_skips_unparseable(tmp_path: Path, make_persisted_reco def test_load_scenarios_json_list(tmp_path: Path): p = tmp_path / "s.json" p.write_text( - json.dumps( - [{"id": 1, "text": "Q1"}, {"id": "2", "text": "Q2"}] - ), + json.dumps([{"id": 1, "text": "Q1"}, {"id": "2", "text": "Q2"}]), encoding="utf-8", ) out = load_scenarios(p) @@ -65,7 +65,9 @@ def test_join_drops_orphans(make_persisted_record): ] trajs = [ PersistedTrajectory.from_raw(make_persisted_record(scenario_id=1)), - PersistedTrajectory.from_raw(make_persisted_record(run_id="r2", scenario_id=99)), + PersistedTrajectory.from_raw( + make_persisted_record(run_id="r2", scenario_id=99) + ), ] pairs = list(join_records(scenarios, trajs)) assert len(pairs) == 1 diff --git a/src/evaluation/tests/test_metrics.py b/src/evaluation/tests/test_metrics.py index 21f097b1c..df096d032 100644 --- a/src/evaluation/tests/test_metrics.py +++ b/src/evaluation/tests/test_metrics.py @@ -47,9 +47,27 @@ def test_plan_execute_list_trajectory(self, make_persisted_record): rec = PersistedTrajectory.from_raw( make_persisted_record( trajectory=[ - {"step_number": 1, "task": "t", "server": "iot", "tool": "sites", "response": "ok"}, - {"step_number": 2, "task": "t2", "server": "iot", "tool": "assets", "response": "ok"}, - {"step_number": 3, "task": "t3", "server": "iot", "tool": "sites", "response": "ok"}, + { + "step_number": 1, + "task": "t", + "server": "iot", + "tool": "sites", + "response": "ok", + }, + { + "step_number": 2, + "task": "t2", + "server": "iot", + "tool": "assets", + "response": "ok", + }, + { + "step_number": 3, + "task": "t3", + "server": "iot", + "tool": "sites", + "response": "ok", + }, ] ) ) @@ -67,9 +85,21 @@ def test_empty(self): def test_sums_and_percentiles(self): results = [ - _result(ops=OpsMetrics(tokens_in=10, tokens_out=5, duration_ms=100.0, tool_call_count=1)), - _result(ops=OpsMetrics(tokens_in=20, tokens_out=10, duration_ms=300.0, tool_call_count=2)), - _result(ops=OpsMetrics(tokens_in=30, tokens_out=15, duration_ms=500.0, tool_call_count=3)), + _result( + ops=OpsMetrics( + tokens_in=10, tokens_out=5, duration_ms=100.0, tool_call_count=1 + ) + ), + _result( + ops=OpsMetrics( + tokens_in=20, tokens_out=10, duration_ms=300.0, tool_call_count=2 + ) + ), + _result( + ops=OpsMetrics( + tokens_in=30, tokens_out=15, duration_ms=500.0, tool_call_count=3 + ) + ), ] agg = aggregate_ops(results) assert agg.tokens_in_total == 60 @@ -90,7 +120,10 @@ def test_cost_only_when_some_present(self): class TestNormalizeModel: def test_strips_provider_prefix(self): - assert _normalize_model("litellm_proxy/anthropic/claude-opus-4-5") == "claude-opus-4-5" + assert ( + _normalize_model("litellm_proxy/anthropic/claude-opus-4-5") + == "claude-opus-4-5" + ) assert _normalize_model("watsonx/ibm/granite-13b") == "granite-13b" def test_strips_long_numeric_suffix(self): diff --git a/src/evaluation/tests/test_models.py b/src/evaluation/tests/test_models.py index 4aca4d551..621107a02 100644 --- a/src/evaluation/tests/test_models.py +++ b/src/evaluation/tests/test_models.py @@ -10,7 +10,9 @@ def test_scenario_from_raw_coerces_int_id_to_str(): def test_scenario_preserves_extra_fields(): - s = Scenario.from_raw({"id": "1", "text": "Q", "characteristic_form": "X", "tolerance": 0.01}) + s = Scenario.from_raw( + {"id": "1", "text": "Q", "characteristic_form": "X", "tolerance": 0.01} + ) extra = s.model_extra or {} assert extra.get("tolerance") == 0.01 diff --git a/src/evaluation/tests/test_report.py b/src/evaluation/tests/test_report.py index 7c71788dc..aabb5042c 100644 --- a/src/evaluation/tests/test_report.py +++ b/src/evaluation/tests/test_report.py @@ -27,7 +27,9 @@ def _result(stype: str, passed: bool, run_id: str = "", **ops_kwargs) -> Scenari model="watsonx/ibm/granite", question="q", answer="a", - score=ScorerResult(scorer="llm_judge", passed=passed, score=1.0 if passed else 0.0), + score=ScorerResult( + scorer="llm_judge", passed=passed, score=1.0 if passed else 0.0 + ), ops=OpsMetrics(**ops_kwargs), ) @@ -98,7 +100,14 @@ def test_write_reports_dir_falls_back_to_scenario_id(tmp_path: Path): def test_render_summary_includes_headlines(): results = [ - _result("iot", True, tokens_in=10, tokens_out=5, duration_ms=100.0, tool_call_count=1), + _result( + "iot", + True, + tokens_in=10, + tokens_out=5, + duration_ms=100.0, + tool_call_count=1, + ), _result("iot", False, tokens_in=8, tokens_out=4, duration_ms=200.0), ] text = render_summary(build_report(results)) diff --git a/src/evaluation/tests/test_runner.py b/src/evaluation/tests/test_runner.py index f8a936db0..b82123f74 100644 --- a/src/evaluation/tests/test_runner.py +++ b/src/evaluation/tests/test_runner.py @@ -10,7 +10,9 @@ from evaluation import scorers as registry -def _always_pass_scorer(scenario: Scenario, answer: str, trajectory_text: str) -> ScorerResult: +def _always_pass_scorer( + scenario: Scenario, answer: str, trajectory_text: str +) -> ScorerResult: return ScorerResult(scorer="stub", passed=True, score=1.0) @@ -46,11 +48,15 @@ def test_evaluate_end_to_end(tmp_path: Path, make_persisted_record): assert report.ops.tokens_in_total > 0 -def _always_fail_scorer(scenario: Scenario, answer: str, trajectory_text: str) -> ScorerResult: +def _always_fail_scorer( + scenario: Scenario, answer: str, trajectory_text: str +) -> ScorerResult: return ScorerResult(scorer="stub-fail", passed=False, score=0.0) -def test_evaluate_uses_per_scenario_scoring_method(tmp_path: Path, make_persisted_record): +def test_evaluate_uses_per_scenario_scoring_method( + tmp_path: Path, make_persisted_record +): rec = make_persisted_record(run_id="run-x", scenario_id=1, answer="A.") (tmp_path / "run-x.json").write_text(json.dumps(rec), encoding="utf-8") diff --git a/src/llm/base.py b/src/llm/base.py index a6b085141..6df322ab5 100644 --- a/src/llm/base.py +++ b/src/llm/base.py @@ -27,9 +27,7 @@ def generate(self, prompt: str, temperature: float = 0.0) -> str: """Generate text given a prompt.""" ... - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: """Generate text and report token usage. Default impl delegates to :meth:`generate` and reports zero usage — diff --git a/src/llm/litellm.py b/src/llm/litellm.py index 85067c7c1..3a1edd1b7 100644 --- a/src/llm/litellm.py +++ b/src/llm/litellm.py @@ -36,9 +36,7 @@ def __init__(self, model_id: str) -> None: def generate(self, prompt: str, temperature: float = 0.0) -> str: return self.generate_with_usage(prompt, temperature).text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: import litellm kwargs: dict = { diff --git a/src/observability/persistence.py b/src/observability/persistence.py index 692e13503..49f7443e9 100644 --- a/src/observability/persistence.py +++ b/src/observability/persistence.py @@ -79,9 +79,7 @@ def persist_trajectory( } try: - out_path.write_text( - json.dumps(record, indent=2, default=str), encoding="utf-8" - ) + out_path.write_text(json.dumps(record, indent=2, default=str), encoding="utf-8") except OSError: _log.exception("persist_trajectory: write failed at %s", out_path) return None diff --git a/src/observability/runspan.py b/src/observability/runspan.py index c22435806..6fe53688e 100644 --- a/src/observability/runspan.py +++ b/src/observability/runspan.py @@ -76,5 +76,3 @@ def agent_run_span( span.record_exception(exc) span.set_status(Status(StatusCode.ERROR, str(exc))) raise - - diff --git a/src/observability/tests/test_persistence.py b/src/observability/tests/test_persistence.py index d881adf5d..555ab465e 100644 --- a/src/observability/tests/test_persistence.py +++ b/src/observability/tests/test_persistence.py @@ -64,7 +64,9 @@ def test_persist_writes_file(monkeypatch, tmp_path: Path): _FakeTurn( index=0, text="hello", - tool_calls=[_FakeToolCall(name="sensors", input={"id": "CH-6"}, output="ok")], + tool_calls=[ + _FakeToolCall(name="sensors", input={"id": "CH-6"}, output="ok") + ], input_tokens=100, output_tokens=20, ), @@ -111,7 +113,9 @@ class _FakeStep: ) record = json.loads(out.read_text()) - assert record["trajectory"] == [{"step_number": 1, "task": "do thing", "success": True}] + assert record["trajectory"] == [ + {"step_number": 1, "task": "do thing", "success": True} + ] def test_persist_skips_when_no_run_id(monkeypatch, tmp_path: Path, caplog): diff --git a/src/observability/tests/test_tracing.py b/src/observability/tests/test_tracing.py index cea33a38a..8be10235f 100644 --- a/src/observability/tests/test_tracing.py +++ b/src/observability/tests/test_tracing.py @@ -88,7 +88,9 @@ def test_agent_run_span_emits_attributes(memory_exporter): assert s.attributes["agent.runner"] == "plan-execute" assert s.attributes["gen_ai.system"] == "anthropic" assert s.attributes["gen_ai.request.model"] == "litellm_proxy/aws/claude-opus-4-6" - assert s.attributes["agent.question.length"] == len("What sensors are on Chiller 6?") + assert s.attributes["agent.question.length"] == len( + "What sensors are on Chiller 6?" + ) assert s.attributes["custom.flag"] is True diff --git a/src/observability/tracing.py b/src/observability/tracing.py index d4b12ce44..7c0eb9afd 100644 --- a/src/observability/tracing.py +++ b/src/observability/tracing.py @@ -73,7 +73,9 @@ def init_tracing(service_name: str) -> None: if _initialized: return - provider = TracerProvider(resource=Resource.create({"service.name": service_name})) + provider = TracerProvider( + resource=Resource.create({"service.name": service_name}) + ) if (path := _traces_file_path()) is not None: from .file_exporter import OTLPJsonFileExporter diff --git a/src/servers/fmsr/main.py b/src/servers/fmsr/main.py index 1638b7629..756147798 100644 --- a/src/servers/fmsr/main.py +++ b/src/servers/fmsr/main.py @@ -30,7 +30,9 @@ load_dotenv() -_log_level = getattr(logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING) +_log_level = getattr( + logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING +) logging.basicConfig(level=_log_level) logger = logging.getLogger("fmsr-mcp-server") @@ -62,6 +64,7 @@ # ── Output parsers ──────────────────────────────────────────────────────────── + def _parse_numbered_list(text: str) -> list[str]: """Parse a numbered list response into a plain list of strings.""" items = [] @@ -97,11 +100,15 @@ def _build_llm(): model_id = os.environ.get("FMSR_MODEL_ID", _DEFAULT_MODEL_ID) if model_id.startswith("watsonx/"): - missing = [v for v in ("WATSONX_APIKEY", "WATSONX_PROJECT_ID") if not os.environ.get(v)] + missing = [ + v for v in ("WATSONX_APIKEY", "WATSONX_PROJECT_ID") if not os.environ.get(v) + ] if missing: raise RuntimeError(f"Missing env vars for WatsonX: {missing}") else: - missing = [v for v in ("LITELLM_API_KEY", "LITELLM_BASE_URL") if not os.environ.get(v)] + missing = [ + v for v in ("LITELLM_API_KEY", "LITELLM_BASE_URL") if not os.environ.get(v) + ] if missing: raise RuntimeError(f"Missing env vars for LiteLLM: {missing}") return LiteLLMBackend(model_id) @@ -155,6 +162,7 @@ def _call_relevancy(asset_name: str, failure_mode: str, sensor: str) -> dict: # ── Result models ───────────────────────────────────────────────────────────── + class ErrorResult(BaseModel): error: str @@ -188,7 +196,10 @@ class FailureModeSensorMappingResult(BaseModel): # ── FastMCP server ──────────────────────────────────────────────────────────── -mcp = FastMCP("fmsr", instructions="Failure mode and sensor reasoning: get failure modes for assets and determine which sensors can detect each failure.") +mcp = FastMCP( + "fmsr", + instructions="Failure mode and sensor reasoning: get failure modes for assets and determine which sensors can detect each failure.", +) @mcp.tool(title="Get Failure Modes") diff --git a/src/servers/fmsr/tests/conftest.py b/src/servers/fmsr/tests/conftest.py index 4a1959c69..b2b10e0db 100644 --- a/src/servers/fmsr/tests/conftest.py +++ b/src/servers/fmsr/tests/conftest.py @@ -27,7 +27,11 @@ def no_llm(): def mock_relevancy_chain(): """Patch _call_relevancy so it always returns 'Yes' without calling the LLM.""" mock = MagicMock( - return_value={"answer": "Yes", "reason": "Relevant sensor", "temporal_behavior": "Increases"} + return_value={ + "answer": "Yes", + "reason": "Relevant sensor", + "temporal_behavior": "Increases", + } ) with patch("servers.fmsr.main._call_relevancy", mock): with patch("servers.fmsr.main._llm_available", True): diff --git a/src/servers/fmsr/tests/test_tools.py b/src/servers/fmsr/tests/test_tools.py index 3bbc3129d..cece28473 100644 --- a/src/servers/fmsr/tests/test_tools.py +++ b/src/servers/fmsr/tests/test_tools.py @@ -75,7 +75,11 @@ async def test_returns_expected_keys(self, mock_relevancy_chain): data = await call_tool( mcp, "get_failure_mode_sensor_mapping", - {"asset_name": "Chiller 6", "failure_modes": _FAILURE_MODES, "sensors": _SENSORS}, + { + "asset_name": "Chiller 6", + "failure_modes": _FAILURE_MODES, + "sensors": _SENSORS, + }, ) assert "fm2sensor" in data assert "sensor2fm" in data @@ -88,7 +92,11 @@ async def test_full_relevancy_count(self, mock_relevancy_chain): data = await call_tool( mcp, "get_failure_mode_sensor_mapping", - {"asset_name": "Chiller 6", "failure_modes": _FAILURE_MODES, "sensors": _SENSORS}, + { + "asset_name": "Chiller 6", + "failure_modes": _FAILURE_MODES, + "sensors": _SENSORS, + }, ) assert len(data["full_relevancy"]) == 4 @@ -115,7 +123,11 @@ async def test_llm_unavailable_returns_error(self, no_llm): data = await call_tool( mcp, "get_failure_mode_sensor_mapping", - {"asset_name": "Chiller 6", "failure_modes": _FAILURE_MODES, "sensors": _SENSORS}, + { + "asset_name": "Chiller 6", + "failure_modes": _FAILURE_MODES, + "sensors": _SENSORS, + }, ) assert "error" in data diff --git a/src/servers/iot/main.py b/src/servers/iot/main.py index 9e4732087..d657b10fd 100644 --- a/src/servers/iot/main.py +++ b/src/servers/iot/main.py @@ -1,7 +1,6 @@ import os import logging from datetime import datetime -from functools import lru_cache from typing import Any, Dict, List, Optional, Union from mcp.server.fastmcp import FastMCP from pydantic import BaseModel @@ -37,7 +36,10 @@ logger.error(f"Failed to connect to CouchDB: {e}") db = None -mcp = FastMCP("iot", instructions="IoT sensor data: browse sites, assets, sensors, and query historical readings from CouchDB.") +mcp = FastMCP( + "iot", + instructions="IoT sensor data: browse sites, assets, sensors, and query historical readings from CouchDB.", +) # Static site as per original requirement SITES = ["MAIN"] diff --git a/src/servers/iot/tests/conftest.py b/src/servers/iot/tests/conftest.py index 83a9ef3df..b99bcd0c2 100644 --- a/src/servers/iot/tests/conftest.py +++ b/src/servers/iot/tests/conftest.py @@ -16,6 +16,7 @@ def _couchdb_reachable() -> bool: return False try: import requests + requests.get(url, timeout=2) return True except Exception: diff --git a/src/servers/iot/tests/test_couchdb.py b/src/servers/iot/tests/test_couchdb.py index 36fea30cc..fc27ae6ce 100644 --- a/src/servers/iot/tests/test_couchdb.py +++ b/src/servers/iot/tests/test_couchdb.py @@ -28,7 +28,9 @@ def couchdb_client(): @requires_couchdb class TestCouchDBInfrastructure: def test_connection(self): - resp = requests.get(f"http://{COUCHDB_HOST}", auth=(COUCHDB_USERNAME, COUCHDB_PASSWORD)) + resp = requests.get( + f"http://{COUCHDB_HOST}", auth=(COUCHDB_USERNAME, COUCHDB_PASSWORD) + ) assert resp.status_code == 200 client = couchdb3.Server(FULL_URL) diff --git a/src/servers/tsfm/forecasting.py b/src/servers/tsfm/forecasting.py index bcd2f2f3c..08ab34460 100644 --- a/src/servers/tsfm/forecasting.py +++ b/src/servers/tsfm/forecasting.py @@ -394,7 +394,6 @@ def _finetune_ttm_hf( TimeSeriesPreprocessor, get_datasets, ) - from tsfm_public.toolkit.util import select_by_index from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, set_seed if training_config_dic is None: diff --git a/src/servers/tsfm/io.py b/src/servers/tsfm/io.py index d0a347a0d..d38728032 100644 --- a/src/servers/tsfm/io.py +++ b/src/servers/tsfm/io.py @@ -7,7 +7,6 @@ import tempfile import uuid from datetime import datetime -from typing import Optional import numpy as np import pandas as pd diff --git a/src/servers/tsfm/main.py b/src/servers/tsfm/main.py index 288a388b2..5df5d34d0 100644 --- a/src/servers/tsfm/main.py +++ b/src/servers/tsfm/main.py @@ -29,7 +29,7 @@ import tempfile import uuid from functools import lru_cache -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import pandas as pd @@ -72,7 +72,6 @@ logger = logging.getLogger("tsfm-mcp-server") - # ── Internal helpers ────────────────────────────────────────────────────────── @@ -115,7 +114,10 @@ def _tsad_output_to_df(output: dict) -> pd.DataFrame: # ── FastMCP server ──────────────────────────────────────────────────────────── -mcp = FastMCP("tsfm", instructions="Time-series foundation models: forecasting, finetuning, and anomaly detection using IBM Granite TinyTimeMixer.") +mcp = FastMCP( + "tsfm", + instructions="Time-series foundation models: forecasting, finetuning, and anomaly detection using IBM Granite TinyTimeMixer.", +) # ── Static tools ────────────────────────────────────────────────────────────── @@ -576,7 +578,9 @@ def run_integrated_tsad( frequency_sampling, autoregressive_modeling, ) - full_data_df = _read_ts_data(dataset_path, dataset_config_dictionary=full_config) + full_data_df = _read_ts_data( + dataset_path, dataset_config_dictionary=full_config + ) for col in target_columns: col_config = _build_dataset_config( diff --git a/src/servers/tsfm/metrics.py b/src/servers/tsfm/metrics.py index 1068c8a5d..491bca4f6 100644 --- a/src/servers/tsfm/metrics.py +++ b/src/servers/tsfm/metrics.py @@ -188,7 +188,6 @@ def _TILDEQ(outputs, targets, axis=1): def _derivatives(inp, device="cpu"): - import torch batch_size, lens = inp.shape[0:2] input2 = inp[:, 2:lens].to(device) diff --git a/src/servers/tsfm/tests/conftest.py b/src/servers/tsfm/tests/conftest.py index 169484aad..7445a47fa 100644 --- a/src/servers/tsfm/tests/conftest.py +++ b/src/servers/tsfm/tests/conftest.py @@ -3,14 +3,15 @@ from __future__ import annotations import json -import os import pytest + # Skip marker for tests that require tsfm_public + its ML dependencies. def _tsfm_available() -> bool: try: import tsfm_public # noqa: F401 + return True except ImportError: return False diff --git a/src/servers/tsfm/tests/test_tools.py b/src/servers/tsfm/tests/test_tools.py index 744b4abc9..78790f34f 100644 --- a/src/servers/tsfm/tests/test_tools.py +++ b/src/servers/tsfm/tests/test_tools.py @@ -15,6 +15,7 @@ # ── get_ai_tasks ────────────────────────────────────────────────────────────── + class TestGetAITasks: @pytest.mark.anyio async def test_returns_tasks_list(self): @@ -40,6 +41,7 @@ async def test_each_task_has_description(self): # ── get_tsfm_models ─────────────────────────────────────────────────────────── + class TestGetTSFMModels: @pytest.mark.anyio async def test_returns_models_list(self): @@ -65,11 +67,13 @@ async def test_each_model_has_checkpoint_and_description(self): # ── run_tsfm_forecasting — input validation ─────────────────────────────────── + class TestRunTSFMForecastingValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_tsfm_forecasting", + mcp, + "run_tsfm_forecasting", {"dataset_path": "", "timestamp_column": "ts", "target_columns": ["val"]}, ) assert "error" in data @@ -78,8 +82,13 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_tsfm_forecasting", - {"dataset_path": "/tmp/data.csv", "timestamp_column": "ts", "target_columns": []}, + mcp, + "run_tsfm_forecasting", + { + "dataset_path": "/tmp/data.csv", + "timestamp_column": "ts", + "target_columns": [], + }, ) assert "error" in data assert "target_columns" in data["error"] @@ -89,7 +98,8 @@ async def test_missing_deps_returns_error(self): # tsfm_public is not expected to be installed in the CI/MCP environment. # If it IS installed this test is a no-op (the import succeeds). data = await call_tool( - mcp, "run_tsfm_forecasting", + mcp, + "run_tsfm_forecasting", { "dataset_path": "/nonexistent/data.csv", "timestamp_column": "Timestamp", @@ -103,11 +113,13 @@ async def test_missing_deps_returns_error(self): # ── run_tsfm_finetuning — input validation ──────────────────────────────────── + class TestRunTSFMFinetuningValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_tsfm_finetuning", + mcp, + "run_tsfm_finetuning", {"dataset_path": "", "timestamp_column": "ts", "target_columns": ["val"]}, ) assert "error" in data @@ -116,8 +128,13 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_tsfm_finetuning", - {"dataset_path": "/tmp/data.csv", "timestamp_column": "ts", "target_columns": []}, + mcp, + "run_tsfm_finetuning", + { + "dataset_path": "/tmp/data.csv", + "timestamp_column": "ts", + "target_columns": [], + }, ) assert "error" in data assert "target_columns" in data["error"] @@ -125,11 +142,13 @@ async def test_empty_target_columns_returns_error(self): # ── run_tsad — input validation ─────────────────────────────────────────────── + class TestRunTSADValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "", "tsfm_output_json": "/tmp/pred.json", @@ -143,7 +162,8 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_tsfm_output_json_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "/tmp/data.csv", "tsfm_output_json": "", @@ -157,7 +177,8 @@ async def test_empty_tsfm_output_json_returns_error(self): @pytest.mark.anyio async def test_invalid_task_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "/tmp/data.csv", "tsfm_output_json": "/tmp/pred.json", @@ -172,7 +193,8 @@ async def test_invalid_task_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "/tmp/data.csv", "tsfm_output_json": "/tmp/pred.json", @@ -186,11 +208,13 @@ async def test_empty_target_columns_returns_error(self): # ── run_integrated_tsad — input validation ──────────────────────────────────── + class TestRunIntegratedTSADValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_integrated_tsad", + mcp, + "run_integrated_tsad", {"dataset_path": "", "timestamp_column": "ts", "target_columns": ["val"]}, ) assert "error" in data @@ -199,8 +223,13 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_integrated_tsad", - {"dataset_path": "/tmp/data.csv", "timestamp_column": "ts", "target_columns": []}, + mcp, + "run_integrated_tsad", + { + "dataset_path": "/tmp/data.csv", + "timestamp_column": "ts", + "target_columns": [], + }, ) assert "error" in data assert "target_columns" in data["error"] @@ -208,6 +237,7 @@ async def test_empty_target_columns_returns_error(self): # ── Integration tests (requires tsfm_public) ───────────────────────────────── + @requires_tsfm class TestTSFMForecastingIntegration: @pytest.mark.anyio @@ -218,15 +248,18 @@ async def test_forecasting_returns_results_file(self, tmp_path): # Create a small synthetic sine-wave CSV n = 200 - df = pd.DataFrame({ - "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), - "sensor_1": np.sin(np.linspace(0, 4 * np.pi, n)), - }) + df = pd.DataFrame( + { + "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), + "sensor_1": np.sin(np.linspace(0, 4 * np.pi, n)), + } + ) csv_path = str(tmp_path / "synthetic.csv") df.to_csv(csv_path, index=False) data = await call_tool( - mcp, "run_tsfm_forecasting", + mcp, + "run_tsfm_forecasting", { "dataset_path": csv_path, "timestamp_column": "Timestamp", @@ -249,15 +282,19 @@ async def test_integrated_tsad_returns_csv(self, tmp_path): import numpy as np n = 300 - df = pd.DataFrame({ - "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), - "sensor_1": np.sin(np.linspace(0, 6 * np.pi, n)) + np.random.randn(n) * 0.05, - }) + df = pd.DataFrame( + { + "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), + "sensor_1": np.sin(np.linspace(0, 6 * np.pi, n)) + + np.random.randn(n) * 0.05, + } + ) csv_path = str(tmp_path / "synthetic_ad.csv") df.to_csv(csv_path, index=False) data = await call_tool( - mcp, "run_integrated_tsad", + mcp, + "run_integrated_tsad", { "dataset_path": csv_path, "timestamp_column": "Timestamp", diff --git a/src/servers/utilities/main.py b/src/servers/utilities/main.py index 48e1858b8..42858783c 100644 --- a/src/servers/utilities/main.py +++ b/src/servers/utilities/main.py @@ -13,11 +13,16 @@ # Setup logging — default WARNING so stderr stays quiet when used as MCP server; # set LOG_LEVEL=INFO (or DEBUG) in the environment to see verbose output. -_log_level = getattr(logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING) +_log_level = getattr( + logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING +) logging.basicConfig(level=_log_level) logger = logging.getLogger("utilities-mcp-server") -mcp = FastMCP("utilities", instructions="General utilities: read JSON files and get current date/time.") +mcp = FastMCP( + "utilities", + instructions="General utilities: read JSON files and get current date/time.", +) class DateTimeResult(BaseModel): @@ -75,7 +80,9 @@ def current_date_time() -> DateTimeResult: description = f"Today's date is {date_part} and time is {time_part}." - return DateTimeResult(currentDateTime=now_iso, currentDateTimeDescription=description) + return DateTimeResult( + currentDateTime=now_iso, currentDateTimeDescription=description + ) @mcp.tool(title="Get Current Time in English") diff --git a/src/servers/utilities/tests/conftest.py b/src/servers/utilities/tests/conftest.py index e25f1a88b..04bc254fc 100644 --- a/src/servers/utilities/tests/conftest.py +++ b/src/servers/utilities/tests/conftest.py @@ -1,6 +1,5 @@ import json -import pytest # Helper functions for tests diff --git a/src/servers/vibration/data_store.py b/src/servers/vibration/data_store.py index e546899aa..b5088691b 100644 --- a/src/servers/vibration/data_store.py +++ b/src/servers/vibration/data_store.py @@ -86,9 +86,7 @@ def summary(self) -> dict: "sample_rate_hz": self.sample_rate, "duration_s": round(self.duration_s, 4), "channel_stats": channel_stats, - "metadata": { - k: v for k, v in self.metadata.items() if k != "axis_labels" - }, + "metadata": {k: v for k, v in self.metadata.items() if k != "axis_labels"}, } diff --git a/src/servers/vibration/dsp/__init__.py b/src/servers/vibration/dsp/__init__.py index 4f7dc3977..3522139c4 100644 --- a/src/servers/vibration/dsp/__init__.py +++ b/src/servers/vibration/dsp/__init__.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/LGDiMaggio/claude-stwinbox-diagnostics/tree/main/mcp-servers/vibration-analysis-mcp -from .fft_analysis import compute_fft, compute_psd, compute_spectrogram, find_peaks_in_spectrum +from .fft_analysis import ( + compute_fft, + compute_psd, + compute_spectrogram, + find_peaks_in_spectrum, +) from .envelope import envelope_spectrum, check_bearing_peaks -from .bearing_freqs import compute_bearing_frequencies, get_bearing, list_bearings, COMMON_BEARINGS +from .bearing_freqs import ( + compute_bearing_frequencies, + get_bearing, + list_bearings, + COMMON_BEARINGS, +) from .fault_detection import ( assess_iso10816, extract_shaft_features, diff --git a/src/servers/vibration/dsp/bearing_freqs.py b/src/servers/vibration/dsp/bearing_freqs.py index 149fb9975..9a22da7de 100644 --- a/src/servers/vibration/dsp/bearing_freqs.py +++ b/src/servers/vibration/dsp/bearing_freqs.py @@ -91,8 +91,10 @@ def compute_bearing_frequencies( ftf = f_shaft * 0.5 * (1.0 - ratio * math.cos(alpha_rad)) bpfo = f_shaft * (n_balls / 2.0) * (1.0 - ratio * math.cos(alpha_rad)) bpfi = f_shaft * (n_balls / 2.0) * (1.0 + ratio * math.cos(alpha_rad)) - bsf = f_shaft * (pitch_dia / (2.0 * ball_dia)) * ( - 1.0 - (ratio * math.cos(alpha_rad)) ** 2 + bsf = ( + f_shaft + * (pitch_dia / (2.0 * ball_dia)) + * (1.0 - (ratio * math.cos(alpha_rad)) ** 2) ) return BearingFrequencies( diff --git a/src/servers/vibration/dsp/envelope.py b/src/servers/vibration/dsp/envelope.py index f60b23ec3..90888f33a 100644 --- a/src/servers/vibration/dsp/envelope.py +++ b/src/servers/vibration/dsp/envelope.py @@ -189,7 +189,9 @@ def check_bearing_peaks( "harmonics_checked": n_harmonics, "harmonics_detected": detected_count, "confidence": ( - "high" if detected_count >= 2 else ("medium" if detected_count == 1 else "none") + "high" + if detected_count >= 2 + else ("medium" if detected_count == 1 else "none") ), "details": results, } diff --git a/src/servers/vibration/dsp/fault_detection.py b/src/servers/vibration/dsp/fault_detection.py index d613a8147..a588a7c36 100644 --- a/src/servers/vibration/dsp/fault_detection.py +++ b/src/servers/vibration/dsp/fault_detection.py @@ -51,9 +51,7 @@ def assess_iso10816( Returns: dict with zone (A/B/C/D), description, and thresholds used. """ - thresholds = ISO_10816_THRESHOLDS.get( - machine_group, ISO_10816_THRESHOLDS["group2"] - ) + thresholds = ISO_10816_THRESHOLDS.get(machine_group, ISO_10816_THRESHOLDS["group2"]) if rms_velocity_mm_s <= thresholds["A_good"]: zone, desc = "A", "Good - newly commissioned machines" @@ -246,9 +244,7 @@ def classify_faults( # --- Mechanical looseness: many harmonics + sub-harmonics --- n_significant = sum( - 1 - for a in [features.amp_1x, features.amp_2x, features.amp_3x] - if a / rms > 1.5 + 1 for a in [features.amp_1x, features.amp_2x, features.amp_3x] if a / rms > 1.5 ) if n_significant >= 3 or (features.amp_half_x / rms > 1.5): evidence = [f"Harmonics above threshold: {n_significant}/3"] diff --git a/src/servers/vibration/main.py b/src/servers/vibration/main.py index 398da8bd6..f2f91e37e 100644 --- a/src/servers/vibration/main.py +++ b/src/servers/vibration/main.py @@ -22,7 +22,6 @@ from .couchdb_client import fetch_vibration_timeseries, list_sensor_fields from .data_store import store from .dsp.bearing_freqs import ( - COMMON_BEARINGS, compute_bearing_frequencies, get_bearing, list_bearings, @@ -43,7 +42,10 @@ logging.basicConfig(level=_log_level) logger = logging.getLogger("vibration-mcp-server") -mcp = FastMCP("vibration", instructions="Vibration signal analysis: FFT, envelope spectrum, bearing fault detection, and ISO 10816 severity assessment.") +mcp = FastMCP( + "vibration", + instructions="Vibration signal analysis: FFT, envelope spectrum, bearing fault detection, and ISO 10816 severity assessment.", +) # --------------------------------------------------------------------------- diff --git a/src/servers/vibration/sample_data/generate_synthetic_vibration.py b/src/servers/vibration/sample_data/generate_synthetic_vibration.py index b6c4d4882..137678a15 100644 --- a/src/servers/vibration/sample_data/generate_synthetic_vibration.py +++ b/src/servers/vibration/sample_data/generate_synthetic_vibration.py @@ -32,6 +32,7 @@ python generate_synthetic_vibration.py # writes JSON to cwd python generate_synthetic_vibration.py --check # writes JSON + prints stats """ + from __future__ import annotations import argparse @@ -44,30 +45,30 @@ # --------------------------------------------------------------------------- # Machine / bearing parameters # --------------------------------------------------------------------------- -FS = 4096 # sampling rate [Hz] -DURATION = 1.0 # seconds -RPM = 1800 # shaft speed +FS = 4096 # sampling rate [Hz] +DURATION = 1.0 # seconds +RPM = 1800 # shaft speed F_SHAFT = RPM / 60 # shaft frequency [Hz] # SKF 6205-2RS (common small motor bearing) N_BALLS = 9 -BD = 7.94 # ball diameter [mm] -PD = 39.04 # pitch diameter [mm] -ALPHA = 0.0 # contact angle [rad] +BD = 7.94 # ball diameter [mm] +PD = 39.04 # pitch diameter [mm] +ALPHA = 0.0 # contact angle [rad] # Derived characteristic frequencies BPFO = N_BALLS / 2 * F_SHAFT * (1 - BD / PD * np.cos(ALPHA)) # ~107.5 Hz # Resonance and damping -F_RESONANCE = 3200.0 # structural resonance [Hz] -DAMPING = 5000.0 # exponential decay rate [1/s] (fast → sharp impulses) -IMPULSE_AMP = 2.0 # peak impulse amplitude [g] -LOAD_MOD = 0.5 # load-zone modulation depth (0 = none, 1 = full) +F_RESONANCE = 3200.0 # structural resonance [Hz] +DAMPING = 5000.0 # exponential decay rate [1/s] (fast → sharp impulses) +IMPULSE_AMP = 2.0 # peak impulse amplitude [g] +LOAD_MOD = 0.5 # load-zone modulation depth (0 = none, 1 = full) # Background -SHAFT_1X = 0.10 # 1× shaft amplitude [g] -SHAFT_2X = 0.04 # 2× shaft amplitude [g] -NOISE_STD = 0.02 # broadband noise σ [g] +SHAFT_1X = 0.10 # 1× shaft amplitude [g] +SHAFT_2X = 0.04 # 2× shaft amplitude [g] +NOISE_STD = 0.02 # broadband noise σ [g] # Time origin (arbitrary) T0 = datetime(2024, 1, 15, 0, 0, 0) @@ -82,8 +83,9 @@ def generate() -> tuple[np.ndarray, np.ndarray]: t = np.arange(n_samples) / FS # Shaft harmonics (healthy background) - shaft = SHAFT_1X * np.sin(2 * np.pi * F_SHAFT * t) + \ - SHAFT_2X * np.sin(2 * np.pi * 2 * F_SHAFT * t) + shaft = SHAFT_1X * np.sin(2 * np.pi * F_SHAFT * t) + SHAFT_2X * np.sin( + 2 * np.pi * 2 * F_SHAFT * t + ) # Bearing fault impulses at BPFO impulse_times = np.arange(0, DURATION, 1.0 / BPFO) @@ -93,8 +95,12 @@ def generate() -> tuple[np.ndarray, np.ndarray]: mask = dt >= 0 # Load-zone amplitude modulation amp = 1.0 + LOAD_MOD * np.cos(2 * np.pi * F_SHAFT * t_imp) - ring = amp * IMPULSE_AMP * np.exp(-DAMPING * dt[mask]) * \ - np.sin(2 * np.pi * F_RESONANCE * dt[mask]) + ring = ( + amp + * IMPULSE_AMP + * np.exp(-DAMPING * dt[mask]) + * np.sin(2 * np.pi * F_RESONANCE * dt[mask]) + ) bearing[mask] += ring noise = NOISE_STD * rng.standard_normal(n_samples) @@ -115,8 +121,9 @@ def to_couchdb_docs(t: np.ndarray, signal: np.ndarray) -> list[dict]: def main() -> None: parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--check", action="store_true", - help="Print signal statistics after generation") + parser.add_argument( + "--check", action="store_true", help="Print signal statistics after generation" + ) args = parser.parse_args() t, signal = generate() @@ -129,11 +136,12 @@ def main() -> None: print(f"Wrote {len(docs)} documents to {out}") if args.check: - rms = float(np.sqrt(np.mean(signal ** 2))) + rms = float(np.sqrt(np.mean(signal**2))) peak = float(np.max(np.abs(signal))) # Excess kurtosis with sample std (ddof=1), consistent with main.py - kurt = float(np.mean((signal - signal.mean()) ** 4) / - np.std(signal, ddof=1) ** 4 - 3) + kurt = float( + np.mean((signal - signal.mean()) ** 4) / np.std(signal, ddof=1) ** 4 - 3 + ) print(f" BPFO: {BPFO:.2f} Hz") print(f" f_shaft: {F_SHAFT:.1f} Hz") print(f" f_resonance: {F_RESONANCE:.1f} Hz") diff --git a/src/servers/vibration/tests/test_dsp.py b/src/servers/vibration/tests/test_dsp.py index 181d93d04..85f33b9b1 100644 --- a/src/servers/vibration/tests/test_dsp.py +++ b/src/servers/vibration/tests/test_dsp.py @@ -1,9 +1,7 @@ """Pure-function unit tests for the DSP layer — no MCP, no CouchDB.""" -import math import numpy as np -import pytest from servers.vibration.dsp.fft_analysis import ( compute_fft, @@ -219,8 +217,9 @@ def test_basic(self): freqs = np.array(fft["frequencies"]) mags = np.array(fft["magnitude"]) shaft_freq = 50.0 # as if rpm=3000 - features = extract_shaft_features(freqs, mags, shaft_freq, - time_signal=COMPOSITE) + features = extract_shaft_features( + freqs, mags, shaft_freq, time_signal=COMPOSITE + ) assert features.f_shaft == 50.0 assert features.amp_1x > 0 diff --git a/src/servers/vibration/tests/test_mcp_e2e.py b/src/servers/vibration/tests/test_mcp_e2e.py index 8892da325..e2ef5079a 100644 --- a/src/servers/vibration/tests/test_mcp_e2e.py +++ b/src/servers/vibration/tests/test_mcp_e2e.py @@ -38,7 +38,11 @@ import anyio import pytest from mcp.client.session import ClientSession -from mcp.client.stdio import StdioServerParameters, get_default_environment, stdio_client +from mcp.client.stdio import ( + StdioServerParameters, + get_default_environment, + stdio_client, +) # --------------------------------------------------------------------------- # Constants @@ -64,12 +68,21 @@ def _find_repo_root(start: Path) -> Path: # LLM credentials that must not reach the test subprocess. # Prevents accidental billable API calls if server-side logic is ever changed. -_SENSITIVE_KEYS: frozenset[str] = frozenset({ - "WATSONX_APIKEY", "WATSONX_PROJECT_ID", "WATSONX_URL", - "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "LITELLM_API_KEY", - "LITELLM_BASE_URL", "COHERE_API_KEY", "AZURE_API_KEY", - "AZURE_API_BASE", "HUGGINGFACE_API_KEY", -}) +_SENSITIVE_KEYS: frozenset[str] = frozenset( + { + "WATSONX_APIKEY", + "WATSONX_PROJECT_ID", + "WATSONX_URL", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "LITELLM_API_KEY", + "LITELLM_BASE_URL", + "COHERE_API_KEY", + "AZURE_API_KEY", + "AZURE_API_BASE", + "HUGGINGFACE_API_KEY", + } +) # --------------------------------------------------------------------------- # Helpers @@ -173,7 +186,9 @@ class TestVibrationMCPProtocol: @pytest.mark.anyio async def test_sc01_tool_listing(self, vibration_session: ClientSession) -> None: """SC-01: Server starts and exposes expected tools over stdio.""" - tools = await asyncio.wait_for(vibration_session.list_tools(), timeout=_DEADLINE) + tools = await asyncio.wait_for( + vibration_session.list_tools(), timeout=_DEADLINE + ) names = {t.name for t in tools.tools} expected = { "get_vibration_data", @@ -188,7 +203,9 @@ async def test_sc01_tool_listing(self, vibration_session: ClientSession) -> None assert expected <= names, f"Missing tools: {expected - names}" @pytest.mark.anyio - async def test_sc02_static_tool_happy_path(self, vibration_session: ClientSession) -> None: + async def test_sc02_static_tool_happy_path( + self, vibration_session: ClientSession + ) -> None: """SC-02: list_known_bearings returns static database without CouchDB.""" result = await asyncio.wait_for( vibration_session.call_tool("list_known_bearings", {}), @@ -201,16 +218,26 @@ async def test_sc02_static_tool_happy_path(self, vibration_session: ClientSessio assert any("6205" in n for n in names), f"6205 not found in {names}" @pytest.mark.anyio - async def test_sc03_iso_severity_zone_classification(self, vibration_session: ClientSession) -> None: + async def test_sc03_iso_severity_zone_classification( + self, vibration_session: ClientSession + ) -> None: """SC-03: assess_vibration_severity classifies ISO 10816 zones correctly.""" - zone_d = _parse_result(await asyncio.wait_for( - vibration_session.call_tool("assess_vibration_severity", {"rms_velocity_mm_s": 50.0}), - timeout=_DEADLINE, - )) - zone_a = _parse_result(await asyncio.wait_for( - vibration_session.call_tool("assess_vibration_severity", {"rms_velocity_mm_s": 0.5}), - timeout=_DEADLINE, - )) + zone_d = _parse_result( + await asyncio.wait_for( + vibration_session.call_tool( + "assess_vibration_severity", {"rms_velocity_mm_s": 50.0} + ), + timeout=_DEADLINE, + ) + ) + zone_a = _parse_result( + await asyncio.wait_for( + vibration_session.call_tool( + "assess_vibration_severity", {"rms_velocity_mm_s": 0.5} + ), + timeout=_DEADLINE, + ) + ) assert zone_d.get("iso_zone") == "D", f"Expected D, got: {zone_d}" assert zone_a.get("iso_zone") == "A", f"Expected A, got: {zone_a}" diff --git a/src/servers/vibration/tests/test_tools.py b/src/servers/vibration/tests/test_tools.py index 8e081df76..fc3298ad6 100644 --- a/src/servers/vibration/tests/test_tools.py +++ b/src/servers/vibration/tests/test_tools.py @@ -16,8 +16,13 @@ # Helpers # --------------------------------------------------------------------------- -def _make_sine(freq_hz: float = 50.0, sr: float = 2048.0, - duration: float = 1.0, amplitude: float = 1.0) -> tuple: + +def _make_sine( + freq_hz: float = 50.0, + sr: float = 2048.0, + duration: float = 1.0, + amplitude: float = 1.0, +) -> tuple: """Generate a pure sine wave and store it; return (data_id, signal, sr).""" t = np.arange(0, duration, 1.0 / sr) sig = amplitude * np.sin(2 * np.pi * freq_hz * t) @@ -26,8 +31,9 @@ def _make_sine(freq_hz: float = 50.0, sr: float = 2048.0, return data_id, sig, sr -def _make_composite(freqs: list[float], sr: float = 4096.0, - duration: float = 2.0) -> str: +def _make_composite( + freqs: list[float], sr: float = 4096.0, duration: float = 2.0 +) -> str: """Composite signal with multiple sine components; returns data_id.""" t = np.arange(0, duration, 1.0 / sr) sig = np.zeros_like(t) @@ -56,16 +62,18 @@ async def test_basic_50hz(self): @pytest.mark.anyio async def test_missing_data_id(self): - result = await call_tool(mcp, "compute_fft_spectrum", - {"data_id": "nonexistent"}) + result = await call_tool( + mcp, "compute_fft_spectrum", {"data_id": "nonexistent"} + ) assert "error" in result @pytest.mark.anyio async def test_window_types(self): data_id, _, _ = _make_sine(100.0) for win in ("hann", "hamming", "blackman", "rectangular"): - result = await call_tool(mcp, "compute_fft_spectrum", - {"data_id": data_id, "window": win}) + result = await call_tool( + mcp, "compute_fft_spectrum", {"data_id": data_id, "window": win} + ) assert "error" not in result assert result["window"] == win @@ -79,16 +87,14 @@ class TestComputeEnvelopeSpectrum: @pytest.mark.anyio async def test_basic_run(self): data_id, _, _ = _make_sine(120.0, sr=4096.0) - result = await call_tool(mcp, "compute_envelope_spectrum", - {"data_id": data_id}) + result = await call_tool(mcp, "compute_envelope_spectrum", {"data_id": data_id}) assert "error" not in result assert "filter_band_hz" in result assert result["sample_rate_hz"] == 4096.0 @pytest.mark.anyio async def test_missing_data_id(self): - result = await call_tool(mcp, "compute_envelope_spectrum", - {"data_id": "nope"}) + result = await call_tool(mcp, "compute_envelope_spectrum", {"data_id": "nope"}) assert "error" in result @@ -100,22 +106,26 @@ async def test_missing_data_id(self): class TestAssessVibrationSeverity: @pytest.mark.anyio async def test_zone_a(self): - result = await call_tool(mcp, "assess_vibration_severity", - {"rms_velocity_mm_s": 0.5}) + result = await call_tool( + mcp, "assess_vibration_severity", {"rms_velocity_mm_s": 0.5} + ) assert result["iso_zone"] == "A" @pytest.mark.anyio async def test_zone_d(self): - result = await call_tool(mcp, "assess_vibration_severity", - {"rms_velocity_mm_s": 50.0}) + result = await call_tool( + mcp, "assess_vibration_severity", {"rms_velocity_mm_s": 50.0} + ) assert result["iso_zone"] == "D" @pytest.mark.anyio async def test_group_param(self): for grp in ("group1", "group2", "group3", "group4"): - result = await call_tool(mcp, "assess_vibration_severity", - {"rms_velocity_mm_s": 4.5, - "machine_group": grp}) + result = await call_tool( + mcp, + "assess_vibration_severity", + {"rms_velocity_mm_s": 4.5, "machine_group": grp}, + ) assert result["iso_zone"] in ("A", "B", "C", "D") @@ -142,13 +152,17 @@ async def test_returns_bearings(self): class TestCalculateBearingFrequencies: @pytest.mark.anyio async def test_basic(self): - result = await call_tool(mcp, "calculate_bearing_frequencies", { - "rpm": 1800, - "n_balls": 9, - "ball_diameter_mm": 7.94, - "pitch_diameter_mm": 39.04, - "contact_angle_deg": 0.0, - }) + result = await call_tool( + mcp, + "calculate_bearing_frequencies", + { + "rpm": 1800, + "n_balls": 9, + "ball_diameter_mm": 7.94, + "pitch_diameter_mm": 39.04, + "contact_angle_deg": 0.0, + }, + ) assert "bpfo_hz" in result assert "bpfi_hz" in result assert "bsf_hz" in result @@ -157,13 +171,17 @@ async def test_basic(self): @pytest.mark.anyio async def test_with_name(self): - result = await call_tool(mcp, "calculate_bearing_frequencies", { - "rpm": 3600, - "n_balls": 8, - "ball_diameter_mm": 10.0, - "pitch_diameter_mm": 46.0, - "bearing_name": "test-bearing", - }) + result = await call_tool( + mcp, + "calculate_bearing_frequencies", + { + "rpm": 3600, + "n_balls": 8, + "ball_diameter_mm": 10.0, + "pitch_diameter_mm": 46.0, + "bearing_name": "test-bearing", + }, + ) assert "bearing" in result assert result["bearing"] == "test-bearing" @@ -178,9 +196,13 @@ class TestDiagnoseVibration: async def test_no_rpm(self): """Without RPM we expect a partial result with a warning.""" data_id, _, _ = _make_sine(120.0, sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + }, + ) assert "error" not in result assert "warning" in result assert result["shaft_features"] is None @@ -188,10 +210,14 @@ async def test_no_rpm(self): @pytest.mark.anyio async def test_with_rpm(self): data_id = _make_composite([30, 60, 90], sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - "rpm": 1800.0, - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + "rpm": 1800.0, + }, + ) assert "error" not in result assert result["shaft_features"] is not None assert result["iso_10816"] is not None @@ -200,11 +226,15 @@ async def test_with_rpm(self): @pytest.mark.anyio async def test_with_bearing_designation(self): data_id = _make_composite([30, 60, 120], sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - "rpm": 1800.0, - "bearing_designation": "6205", - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + "rpm": 1800.0, + "bearing_designation": "6205", + }, + ) assert "error" not in result assert result["bearing_info_source"] is not None assert "database" in result["bearing_info_source"] @@ -212,20 +242,23 @@ async def test_with_bearing_designation(self): @pytest.mark.anyio async def test_with_custom_bearing_geometry(self): data_id = _make_composite([30, 60], sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - "rpm": 1800.0, - "bearing_n_balls": 9, - "bearing_ball_dia_mm": 7.94, - "bearing_pitch_dia_mm": 39.04, - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + "rpm": 1800.0, + "bearing_n_balls": 9, + "bearing_ball_dia_mm": 7.94, + "bearing_pitch_dia_mm": 39.04, + }, + ) assert "error" not in result assert result["bearing_info_source"] == "custom geometry" @pytest.mark.anyio async def test_missing_data_id(self): - result = await call_tool(mcp, "diagnose_vibration", - {"data_id": "ghost"}) + result = await call_tool(mcp, "diagnose_vibration", {"data_id": "ghost"}) assert "error" in result @@ -238,12 +271,16 @@ class TestGetVibrationData: @requires_couchdb @pytest.mark.anyio async def test_fetch_integration(self): - result = await call_tool(mcp, "get_vibration_data", { - "site_name": "MAIN", - "asset_id": "Motor_01", - "sensor_name": "Vibration_X", - "start": "2024-01-15T00:00:00", - }) + result = await call_tool( + mcp, + "get_vibration_data", + { + "site_name": "MAIN", + "asset_id": "Motor_01", + "sensor_name": "Vibration_X", + "start": "2024-01-15T00:00:00", + }, + ) assert "error" not in result assert "data_id" in result @@ -257,8 +294,12 @@ class TestListVibrationSensors: @requires_couchdb @pytest.mark.anyio async def test_list_integration(self): - result = await call_tool(mcp, "list_vibration_sensors", { - "site_name": "MAIN", - "asset_id": "Chiller 6", - }) + result = await call_tool( + mcp, + "list_vibration_sensors", + { + "site_name": "MAIN", + "asset_id": "Chiller 6", + }, + ) assert "sensors" in result or "error" in result diff --git a/src/servers/wo/data.py b/src/servers/wo/data.py index a2c6bf735..7427a421d 100644 --- a/src/servers/wo/data.py +++ b/src/servers/wo/data.py @@ -99,7 +99,10 @@ def load(dataset: str) -> Optional[pd.DataFrame]: df = pd.DataFrame(docs) # Drop internal CouchDB fields - df.drop(columns=[c for c in ("_id", "_rev", "dataset") if c in df.columns], inplace=True) + df.drop( + columns=[c for c in ("_id", "_rev", "dataset") if c in df.columns], + inplace=True, + ) # Parse date columns for col in _DATE_COLS.get(dataset, []): @@ -142,12 +145,16 @@ def parse_date(value: Optional[str]) -> Optional[datetime]: raise ValueError(f"date must be YYYY-MM-DD, got '{value}'") from exc -def date_conditions(equipment_id: str, date_col: str, start: Optional[str], end: Optional[str]) -> dict: +def date_conditions( + equipment_id: str, date_col: str, start: Optional[str], end: Optional[str] +) -> dict: """Build a filter-conditions dict for equipment + optional date range.""" start_dt = parse_date(start) end_dt = parse_date(end) cond: dict = { - "equipment_id": lambda x, eid=equipment_id: isinstance(x, str) and x.strip().lower() == eid.strip().lower() + "equipment_id": lambda x, eid=equipment_id: ( + isinstance(x, str) and x.strip().lower() == eid.strip().lower() + ) } if start_dt or end_dt: cond[date_col] = lambda x, s=start_dt, e=end_dt: ( @@ -184,10 +191,18 @@ def row_to_wo(row: Any) -> WorkOrderItem: equipment_id=str(row.get("equipment_id", "")), equipment_name=str(row.get("equipment_name", "")), preventive=str(row.get("preventive", "")).upper() == "TRUE", - work_priority=int(row["work_priority"]) if pd.notna(row.get("work_priority")) else None, - actual_finish=row["actual_finish"].isoformat() if pd.notna(row.get("actual_finish")) else None, - duration=str(row.get("duration", "")) if pd.notna(row.get("duration")) else None, - actual_labor_hours=str(row.get("actual_labor_hours", "")) if pd.notna(row.get("actual_labor_hours")) else None, + work_priority=int(row["work_priority"]) + if pd.notna(row.get("work_priority")) + else None, + actual_finish=row["actual_finish"].isoformat() + if pd.notna(row.get("actual_finish")) + else None, + duration=str(row.get("duration", "")) + if pd.notna(row.get("duration")) + else None, + actual_labor_hours=str(row.get("actual_labor_hours", "")) + if pd.notna(row.get("actual_labor_hours")) + else None, ) @@ -197,10 +212,14 @@ def row_to_event(row: Any) -> EventItem: event_group=str(row.get("event_group", "")), event_category=str(row.get("event_category", "")), event_type=str(row["event_type"]) if pd.notna(row.get("event_type")) else None, - description=str(row["description"]) if pd.notna(row.get("description")) else None, + description=str(row["description"]) + if pd.notna(row.get("description")) + else None, equipment_id=str(row.get("equipment_id", "")), equipment_name=str(row.get("equipment_name", "")), - event_time=row["event_time"].isoformat() if pd.notna(row.get("event_time")) else "", + event_time=row["event_time"].isoformat() + if pd.notna(row.get("event_time")) + else "", note=str(row["note"]) if pd.notna(row.get("note")) else None, ) diff --git a/src/servers/wo/main.py b/src/servers/wo/main.py index 1dbde8ed5..5b8ef9726 100644 --- a/src/servers/wo/main.py +++ b/src/servers/wo/main.py @@ -13,10 +13,15 @@ load_dotenv() -_log_level = getattr(logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING) +_log_level = getattr( + logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING +) logging.basicConfig(level=_log_level) -mcp = FastMCP("wo", instructions="Work order analytics: query work orders, events, failure codes, and predict maintenance patterns.") +mcp = FastMCP( + "wo", + instructions="Work order analytics: query work orders, events, failure codes, and predict maintenance patterns.", +) # Register tools — imported after mcp is created to avoid circular imports. from . import tools # noqa: E402 diff --git a/src/servers/wo/tests/conftest.py b/src/servers/wo/tests/conftest.py index 8ca2f3c30..e5b9e83c0 100644 --- a/src/servers/wo/tests/conftest.py +++ b/src/servers/wo/tests/conftest.py @@ -18,6 +18,7 @@ def _couchdb_reachable() -> bool: return False try: import requests + requests.get(url, timeout=2) return True except Exception: @@ -36,12 +37,27 @@ def _couchdb_reachable() -> bool: def _make_wo_df() -> pd.DataFrame: data = { "wo_id": ["WO001", "WO002", "WO003", "WO004"], - "wo_description": ["Oil Analysis", "Routine Maintenance", "Corrective Repair", "Emergency Fix"], + "wo_description": [ + "Oil Analysis", + "Routine Maintenance", + "Corrective Repair", + "Emergency Fix", + ], "collection": ["compressor", "compressor", "motor", "motor"], "primary_code": ["MT010", "MT001", "MT013", "MT013"], - "primary_code_description": ["Oil Analysis", "Routine Maintenance", "Corrective", "Corrective"], + "primary_code_description": [ + "Oil Analysis", + "Routine Maintenance", + "Corrective", + "Corrective", + ], "secondary_code": ["MT010b", "MT001a", "MT013a", "MT013b"], - "secondary_code_description": ["Routine Oil Analysis", "Basic Maint", "Repair", "Emergency"], + "secondary_code_description": [ + "Routine Oil Analysis", + "Basic Maint", + "Repair", + "Emergency", + ], "equipment_id": ["CWC04013", "CWC04013", "CWC04013", "CWC04007"], "equipment_name": ["Chiller 13", "Chiller 13", "Chiller 13", "Chiller 7"], "preventive": ["TRUE", "TRUE", "FALSE", "FALSE"], @@ -79,9 +95,17 @@ def _make_events_df() -> pd.DataFrame: def _make_failure_codes_df() -> pd.DataFrame: data = { - "category": ["Maintenance and Routine Checks", "Maintenance and Routine Checks", "Corrective"], + "category": [ + "Maintenance and Routine Checks", + "Maintenance and Routine Checks", + "Corrective", + ], "primary_code": ["MT010", "MT001", "MT013"], - "primary_code_description": ["Oil Analysis", "Routine Maintenance", "Corrective"], + "primary_code_description": [ + "Oil Analysis", + "Routine Maintenance", + "Corrective", + ], "secondary_code": ["MT010b", "MT001a", "MT013a"], "secondary_code_description": ["Routine Oil Analysis", "Basic Maint", "Repair"], } @@ -90,9 +114,17 @@ def _make_failure_codes_df() -> pd.DataFrame: def _make_primary_failure_codes_df() -> pd.DataFrame: data = { - "category": ["Maintenance and Routine Checks", "Maintenance and Routine Checks", "Corrective"], + "category": [ + "Maintenance and Routine Checks", + "Maintenance and Routine Checks", + "Corrective", + ], "primary_code": ["MT010", "MT001", "MT013"], - "primary_code_description": ["Oil Analysis", "Routine Maintenance", "Corrective"], + "primary_code_description": [ + "Oil Analysis", + "Routine Maintenance", + "Corrective", + ], } return pd.DataFrame(data) @@ -132,6 +164,7 @@ def _make_alert_events_df() -> pd.DataFrame: @pytest.fixture def mock_data(): """Patch load() in tools namespace to return fixture DataFrames without CouchDB.""" + def _fake_load(key: str): factory = _FIXTURE_DATA.get(key) return factory() if factory else None diff --git a/src/servers/wo/tests/test_integration.py b/src/servers/wo/tests/test_integration.py index 642b957b8..bf3175a31 100644 --- a/src/servers/wo/tests/test_integration.py +++ b/src/servers/wo/tests/test_integration.py @@ -12,10 +12,10 @@ from .conftest import requires_couchdb, call_tool # Real equipment IDs and rule IDs present in the sample dataset -EQUIPMENT_ID = "CWC04013" # 431 work orders in dataset -EQUIPMENT_RICH = "CWC04014" # 524 work orders — most records -EQUIPMENT_ALERT = "CWC04009" # has alert events with RUL0018 -RULE_ID = "RUL0018" # 183 alert events for CWC04009 +EQUIPMENT_ID = "CWC04013" # 431 work orders in dataset +EQUIPMENT_RICH = "CWC04014" # 524 work orders — most records +EQUIPMENT_ALERT = "CWC04009" # has alert events with RUL0018 +RULE_ID = "RUL0018" # 183 alert events for CWC04009 # --------------------------------------------------------------------------- @@ -34,11 +34,17 @@ async def test_returns_results(self): @pytest.mark.anyio async def test_date_range_narrows_results(self): - all_data = await call_tool(mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID}) + all_data = await call_tool( + mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID} + ) filtered = await call_tool( mcp, "get_work_orders", - {"equipment_id": EQUIPMENT_ID, "start_date": "2015-01-01", "end_date": "2017-12-31"}, + { + "equipment_id": EQUIPMENT_ID, + "start_date": "2015-01-01", + "end_date": "2017-12-31", + }, ) assert filtered["total"] < all_data["total"] assert filtered["total"] > 0 @@ -46,7 +52,14 @@ async def test_date_range_narrows_results(self): @pytest.mark.anyio async def test_each_wo_has_required_fields(self): data = await call_tool(mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID}) - required = {"wo_id", "wo_description", "equipment_id", "primary_code", "preventive", "actual_finish"} + required = { + "wo_id", + "wo_description", + "equipment_id", + "primary_code", + "preventive", + "actual_finish", + } for wo in data["work_orders"]: assert required <= wo.keys() assert wo["equipment_id"].upper() == EQUIPMENT_ID.upper() @@ -59,7 +72,9 @@ async def test_preventive_field_is_bool(self): @pytest.mark.anyio async def test_unknown_equipment_returns_error(self): - data = await call_tool(mcp, "get_work_orders", {"equipment_id": "DOES_NOT_EXIST"}) + data = await call_tool( + mcp, "get_work_orders", {"equipment_id": "DOES_NOT_EXIST"} + ) assert "error" in data @@ -72,7 +87,9 @@ async def test_unknown_equipment_returns_error(self): class TestGetPreventiveWorkOrdersLive: @pytest.mark.anyio async def test_all_results_are_preventive(self): - data = await call_tool(mcp, "get_preventive_work_orders", {"equipment_id": EQUIPMENT_ID}) + data = await call_tool( + mcp, "get_preventive_work_orders", {"equipment_id": EQUIPMENT_ID} + ) assert "work_orders" in data assert data["total"] > 0 for wo in data["work_orders"]: @@ -80,8 +97,12 @@ async def test_all_results_are_preventive(self): @pytest.mark.anyio async def test_count_less_than_all_work_orders(self): - all_data = await call_tool(mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID}) - prev_data = await call_tool(mcp, "get_preventive_work_orders", {"equipment_id": EQUIPMENT_ID}) + all_data = await call_tool( + mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID} + ) + prev_data = await call_tool( + mcp, "get_preventive_work_orders", {"equipment_id": EQUIPMENT_ID} + ) assert prev_data["total"] <= all_data["total"] @@ -94,7 +115,9 @@ async def test_count_less_than_all_work_orders(self): class TestGetCorrectiveWorkOrdersLive: @pytest.mark.anyio async def test_all_results_are_corrective(self): - data = await call_tool(mcp, "get_corrective_work_orders", {"equipment_id": EQUIPMENT_ID}) + data = await call_tool( + mcp, "get_corrective_work_orders", {"equipment_id": EQUIPMENT_ID} + ) assert "work_orders" in data assert data["total"] > 0 for wo in data["work_orders"]: @@ -102,9 +125,15 @@ async def test_all_results_are_corrective(self): @pytest.mark.anyio async def test_preventive_and_corrective_partition_all(self): - all_data = await call_tool(mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID}) - prev_data = await call_tool(mcp, "get_preventive_work_orders", {"equipment_id": EQUIPMENT_ID}) - corr_data = await call_tool(mcp, "get_corrective_work_orders", {"equipment_id": EQUIPMENT_ID}) + all_data = await call_tool( + mcp, "get_work_orders", {"equipment_id": EQUIPMENT_ID} + ) + prev_data = await call_tool( + mcp, "get_preventive_work_orders", {"equipment_id": EQUIPMENT_ID} + ) + corr_data = await call_tool( + mcp, "get_corrective_work_orders", {"equipment_id": EQUIPMENT_ID} + ) assert prev_data["total"] + corr_data["total"] == all_data["total"] @@ -131,7 +160,13 @@ async def test_event_groups_valid(self): @pytest.mark.anyio async def test_each_event_has_required_fields(self): data = await call_tool(mcp, "get_events", {"equipment_id": EQUIPMENT_ID}) - required = {"event_id", "event_group", "event_category", "equipment_id", "event_time"} + required = { + "event_id", + "event_group", + "event_category", + "equipment_id", + "event_time", + } for event in data["events"]: assert required <= event.keys() assert event["equipment_id"].upper() == EQUIPMENT_ID.upper() @@ -141,7 +176,11 @@ async def test_date_range_filters_events(self): data = await call_tool( mcp, "get_events", - {"equipment_id": EQUIPMENT_ID, "start_date": "2015-01-01", "end_date": "2015-12-31"}, + { + "equipment_id": EQUIPMENT_ID, + "start_date": "2015-01-01", + "end_date": "2015-12-31", + }, ) assert "events" in data assert data["total"] > 0 @@ -165,8 +204,13 @@ async def test_returns_codes(self): @pytest.mark.anyio async def test_required_fields_present(self): data = await call_tool(mcp, "get_failure_codes", {}) - required = {"category", "primary_code", "primary_code_description", - "secondary_code", "secondary_code_description"} + required = { + "category", + "primary_code", + "primary_code_description", + "secondary_code", + "secondary_code_description", + } for fc in data["failure_codes"]: assert required <= fc.keys() @@ -188,39 +232,59 @@ async def test_known_code_present(self): class TestGetWorkOrderDistributionLive: @pytest.mark.anyio async def test_returns_distribution(self): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID}) + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID} + ) assert "distribution" in data assert data["total_work_orders"] > 0 assert len(data["distribution"]) > 0 @pytest.mark.anyio async def test_counts_sum_to_total(self): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID}) + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID} + ) total_from_dist = sum(e["count"] for e in data["distribution"]) # distribution only counts entries matched in failure_codes; total_work_orders is the raw filter count assert total_from_dist <= data["total_work_orders"] @pytest.mark.anyio async def test_sorted_descending(self): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID}) + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID} + ) counts = [e["count"] for e in data["distribution"]] assert counts == sorted(counts, reverse=True) @pytest.mark.anyio async def test_distribution_fields_present(self): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID}) - required = {"category", "primary_code", "primary_code_description", - "secondary_code", "secondary_code_description", "count"} + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_ID} + ) + required = { + "category", + "primary_code", + "primary_code_description", + "secondary_code", + "secondary_code_description", + "count", + } for entry in data["distribution"]: assert required <= entry.keys() @pytest.mark.anyio async def test_date_range_reduces_total(self): - all_data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_RICH}) + all_data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": EQUIPMENT_RICH} + ) filtered = await call_tool( mcp, "get_work_order_distribution", - {"equipment_id": EQUIPMENT_RICH, "start_date": "2016-01-01", "end_date": "2016-12-31"}, + { + "equipment_id": EQUIPMENT_RICH, + "start_date": "2016-01-01", + "end_date": "2016-12-31", + }, ) assert filtered["total_work_orders"] < all_data["total_work_orders"] @@ -234,36 +298,51 @@ async def test_date_range_reduces_total(self): class TestPredictNextWorkOrderLive: @pytest.mark.anyio async def test_returns_predictions(self): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH} + ) assert "predictions" in data assert "last_work_order_type" in data assert len(data["predictions"]) > 0 @pytest.mark.anyio async def test_probabilities_sum_to_one(self): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH} + ) if "predictions" in data: total = sum(p["probability"] for p in data["predictions"]) assert abs(total - 1.0) < 1e-6 @pytest.mark.anyio async def test_prediction_fields_present(self): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH} + ) if "predictions" in data: - required = {"category", "primary_code", "primary_code_description", "probability"} + required = { + "category", + "primary_code", + "primary_code_description", + "probability", + } for pred in data["predictions"]: assert required <= pred.keys() @pytest.mark.anyio async def test_probabilities_between_zero_and_one(self): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": EQUIPMENT_RICH} + ) if "predictions" in data: for pred in data["predictions"]: assert 0.0 <= pred["probability"] <= 1.0 @pytest.mark.anyio async def test_unknown_equipment_returns_error(self): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": "DOES_NOT_EXIST"}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": "DOES_NOT_EXIST"} + ) assert "error" in data @@ -316,7 +395,9 @@ async def test_work_order_transition_has_avg_hours(self): {"equipment_id": EQUIPMENT_ALERT, "rule_id": RULE_ID}, ) if "transitions" in data: - wo_transitions = [t for t in data["transitions"] if t["transition"] == "WORK_ORDER"] + wo_transitions = [ + t for t in data["transitions"] if t["transition"] == "WORK_ORDER" + ] for t in wo_transitions: assert t["average_hours_to_maintenance"] is not None assert t["average_hours_to_maintenance"] > 0 diff --git a/src/servers/wo/tests/test_tools.py b/src/servers/wo/tests/test_tools.py index 6528c9bfd..0745c2947 100644 --- a/src/servers/wo/tests/test_tools.py +++ b/src/servers/wo/tests/test_tools.py @@ -32,7 +32,11 @@ async def test_date_range_filter(self, mock_data): data = await call_tool( mcp, "get_work_orders", - {"equipment_id": "CWC04013", "start_date": "2017-01-01", "end_date": "2017-12-31"}, + { + "equipment_id": "CWC04013", + "start_date": "2017-01-01", + "end_date": "2017-12-31", + }, ) assert data["total"] == 3 for wo in data["work_orders"]: @@ -41,7 +45,9 @@ async def test_date_range_filter(self, mock_data): @pytest.mark.anyio async def test_invalid_date(self, mock_data): data = await call_tool( - mcp, "get_work_orders", {"equipment_id": "CWC04013", "start_date": "not-a-date"} + mcp, + "get_work_orders", + {"equipment_id": "CWC04013", "start_date": "not-a-date"}, ) assert "error" in data @@ -49,7 +55,13 @@ async def test_invalid_date(self, mock_data): async def test_work_order_fields_present(self, mock_data): data = await call_tool(mcp, "get_work_orders", {"equipment_id": "CWC04013"}) wo = data["work_orders"][0] - for field in ("wo_id", "wo_description", "primary_code", "preventive", "equipment_id"): + for field in ( + "wo_id", + "wo_description", + "primary_code", + "preventive", + "equipment_id", + ): assert field in wo @requires_couchdb @@ -58,7 +70,11 @@ async def test_integration_cwc04013_2017(self): data = await call_tool( mcp, "get_work_orders", - {"equipment_id": "CWC04013", "start_date": "2017-01-01", "end_date": "2017-12-31"}, + { + "equipment_id": "CWC04013", + "start_date": "2017-01-01", + "end_date": "2017-12-31", + }, ) assert "work_orders" in data assert data["total"] > 0 @@ -72,14 +88,18 @@ async def test_integration_cwc04013_2017(self): class TestGetPreventiveWorkOrders: @pytest.mark.anyio async def test_returns_only_preventive(self, mock_data): - data = await call_tool(mcp, "get_preventive_work_orders", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "get_preventive_work_orders", {"equipment_id": "CWC04013"} + ) assert data["total"] == 2 for wo in data["work_orders"]: assert wo["preventive"] is True @pytest.mark.anyio async def test_unknown_equipment(self, mock_data): - data = await call_tool(mcp, "get_preventive_work_orders", {"equipment_id": "UNKNOWN"}) + data = await call_tool( + mcp, "get_preventive_work_orders", {"equipment_id": "UNKNOWN"} + ) assert "error" in data @requires_couchdb @@ -88,7 +108,11 @@ async def test_integration(self): data = await call_tool( mcp, "get_preventive_work_orders", - {"equipment_id": "CWC04013", "start_date": "2017-01-01", "end_date": "2017-12-31"}, + { + "equipment_id": "CWC04013", + "start_date": "2017-01-01", + "end_date": "2017-12-31", + }, ) assert "work_orders" in data for wo in data["work_orders"]: @@ -103,14 +127,18 @@ async def test_integration(self): class TestGetCorrectiveWorkOrders: @pytest.mark.anyio async def test_returns_only_corrective(self, mock_data): - data = await call_tool(mcp, "get_corrective_work_orders", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "get_corrective_work_orders", {"equipment_id": "CWC04013"} + ) assert data["total"] == 1 for wo in data["work_orders"]: assert wo["preventive"] is False @pytest.mark.anyio async def test_unknown_equipment(self, mock_data): - data = await call_tool(mcp, "get_corrective_work_orders", {"equipment_id": "UNKNOWN"}) + data = await call_tool( + mcp, "get_corrective_work_orders", {"equipment_id": "UNKNOWN"} + ) assert "error" in data @requires_couchdb @@ -119,7 +147,11 @@ async def test_integration(self): data = await call_tool( mcp, "get_corrective_work_orders", - {"equipment_id": "CWC04013", "start_date": "2017-01-01", "end_date": "2017-12-31"}, + { + "equipment_id": "CWC04013", + "start_date": "2017-01-01", + "end_date": "2017-12-31", + }, ) assert "work_orders" in data for wo in data["work_orders"]: @@ -149,7 +181,11 @@ async def test_date_range(self, mock_data): data = await call_tool( mcp, "get_events", - {"equipment_id": "CWC04013", "start_date": "2017-07-01", "end_date": "2017-12-31"}, + { + "equipment_id": "CWC04013", + "start_date": "2017-07-01", + "end_date": "2017-12-31", + }, ) assert data["total"] == 2 @@ -178,7 +214,12 @@ async def test_returns_codes(self, mock_data): async def test_fields_present(self, mock_data): data = await call_tool(mcp, "get_failure_codes", {}) fc = data["failure_codes"][0] - for field in ("category", "primary_code", "primary_code_description", "secondary_code"): + for field in ( + "category", + "primary_code", + "primary_code_description", + "secondary_code", + ): assert field in fc @requires_couchdb @@ -196,12 +237,16 @@ async def test_integration(self): class TestGetWorkOrderDistribution: @pytest.mark.anyio async def test_unknown_equipment(self, mock_data): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": "UNKNOWN"}) + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": "UNKNOWN"} + ) assert "error" in data @pytest.mark.anyio async def test_distribution_counts(self, mock_data): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": "CWC04013"} + ) assert data["total_work_orders"] == 3 codes = {e["primary_code"]: e["count"] for e in data["distribution"]} assert codes.get("MT010") == 1 @@ -210,7 +255,9 @@ async def test_distribution_counts(self, mock_data): @pytest.mark.anyio async def test_sorted_descending(self, mock_data): - data = await call_tool(mcp, "get_work_order_distribution", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "get_work_order_distribution", {"equipment_id": "CWC04013"} + ) counts = [e["count"] for e in data["distribution"]] assert counts == sorted(counts, reverse=True) @@ -220,7 +267,11 @@ async def test_integration(self): data = await call_tool( mcp, "get_work_order_distribution", - {"equipment_id": "CWC04013", "start_date": "2017-01-01", "end_date": "2017-12-31"}, + { + "equipment_id": "CWC04013", + "start_date": "2017-01-01", + "end_date": "2017-12-31", + }, ) assert "distribution" in data assert data["total_work_orders"] > 0 @@ -234,12 +285,16 @@ async def test_integration(self): class TestPredictNextWorkOrder: @pytest.mark.anyio async def test_unknown_equipment(self, mock_data): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": "UNKNOWN"}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": "UNKNOWN"} + ) assert "error" in data @pytest.mark.anyio async def test_returns_predictions(self, mock_data): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": "CWC04013"} + ) # Should either return predictions or an error about transition data assert "predictions" in data or "error" in data if "predictions" in data: @@ -248,7 +303,9 @@ async def test_returns_predictions(self, mock_data): @pytest.mark.anyio async def test_probabilities_sum_to_one(self, mock_data): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": "CWC04013"} + ) if "predictions" in data and data["predictions"]: total = sum(p["probability"] for p in data["predictions"]) assert abs(total - 1.0) < 1e-6 @@ -256,7 +313,9 @@ async def test_probabilities_sum_to_one(self, mock_data): @requires_couchdb @pytest.mark.anyio async def test_integration(self): - data = await call_tool(mcp, "predict_next_work_order", {"equipment_id": "CWC04013"}) + data = await call_tool( + mcp, "predict_next_work_order", {"equipment_id": "CWC04013"} + ) assert "predictions" in data or "error" in data @@ -269,14 +328,18 @@ class TestAnalyzeAlertToFailure: @pytest.mark.anyio async def test_unknown_rule(self, mock_data): data = await call_tool( - mcp, "analyze_alert_to_failure", {"equipment_id": "CWC04013", "rule_id": "UNKNOWN"} + mcp, + "analyze_alert_to_failure", + {"equipment_id": "CWC04013", "rule_id": "UNKNOWN"}, ) assert "error" in data @pytest.mark.anyio async def test_returns_transitions(self, mock_data): data = await call_tool( - mcp, "analyze_alert_to_failure", {"equipment_id": "CWC04013", "rule_id": "CR00002"} + mcp, + "analyze_alert_to_failure", + {"equipment_id": "CWC04013", "rule_id": "CR00002"}, ) # fixture only has 3 rows so transitions may be empty or present assert "transitions" in data or "error" in data @@ -284,7 +347,9 @@ async def test_returns_transitions(self, mock_data): @pytest.mark.anyio async def test_probabilities_valid(self, mock_data): data = await call_tool( - mcp, "analyze_alert_to_failure", {"equipment_id": "CWC04013", "rule_id": "CR00002"} + mcp, + "analyze_alert_to_failure", + {"equipment_id": "CWC04013", "rule_id": "CR00002"}, ) if "transitions" in data and data["transitions"]: total_prob = sum(t["probability"] for t in data["transitions"]) @@ -294,6 +359,8 @@ async def test_probabilities_valid(self, mock_data): @pytest.mark.anyio async def test_integration(self): data = await call_tool( - mcp, "analyze_alert_to_failure", {"equipment_id": "CWC04013", "rule_id": "CR00002"} + mcp, + "analyze_alert_to_failure", + {"equipment_id": "CWC04013", "rule_id": "CR00002"}, ) assert "transitions" in data or "error" in data diff --git a/src/servers/wo/tools.py b/src/servers/wo/tools.py index 0473edc76..4dfa0c029 100644 --- a/src/servers/wo/tools.py +++ b/src/servers/wo/tools.py @@ -8,7 +8,6 @@ from collections import Counter from typing import List, Optional, Union -import pandas as pd from .data import ( date_conditions, @@ -54,7 +53,9 @@ def get_work_orders( except ValueError as exc: return ErrorResult(error=str(exc)) if not wos: - return ErrorResult(error=f"No work orders found for equipment_id '{equipment_id}'") + return ErrorResult( + error=f"No work orders found for equipment_id '{equipment_id}'" + ) return WorkOrdersResult( equipment_id=equipment_id, start_date=start_date, @@ -81,11 +82,15 @@ def get_preventive_work_orders( if df is None: return ErrorResult(error="Work order data not available") try: - wos = fetch_work_orders(df[df["preventive"] == "TRUE"], equipment_id, start_date, end_date) + wos = fetch_work_orders( + df[df["preventive"] == "TRUE"], equipment_id, start_date, end_date + ) except ValueError as exc: return ErrorResult(error=str(exc)) if not wos: - return ErrorResult(error=f"No preventive work orders found for equipment_id '{equipment_id}'") + return ErrorResult( + error=f"No preventive work orders found for equipment_id '{equipment_id}'" + ) return WorkOrdersResult( equipment_id=equipment_id, start_date=start_date, @@ -112,11 +117,15 @@ def get_corrective_work_orders( if df is None: return ErrorResult(error="Work order data not available") try: - wos = fetch_work_orders(df[df["preventive"] == "FALSE"], equipment_id, start_date, end_date) + wos = fetch_work_orders( + df[df["preventive"] == "FALSE"], equipment_id, start_date, end_date + ) except ValueError as exc: return ErrorResult(error=str(exc)) if not wos: - return ErrorResult(error=f"No corrective work orders found for equipment_id '{equipment_id}'") + return ErrorResult( + error=f"No corrective work orders found for equipment_id '{equipment_id}'" + ) return WorkOrdersResult( equipment_id=equipment_id, start_date=start_date, @@ -149,7 +158,9 @@ def get_events( return ErrorResult(error=str(exc)) cond: dict = { - "equipment_id": lambda x, eid=equipment_id: isinstance(x, str) and x.strip().lower() == eid.strip().lower() + "equipment_id": lambda x, eid=equipment_id: ( + isinstance(x, str) and x.strip().lower() == eid.strip().lower() + ) } if start_dt or end_dt: cond["event_time"] = lambda x, s=start_dt, e=end_dt: ( @@ -224,7 +235,9 @@ def get_work_order_distribution( filtered = filtered[filtered["actual_finish"] <= end_dt] if filtered.empty: - return ErrorResult(error=f"No work orders found for equipment_id '{equipment_id}'") + return ErrorResult( + error=f"No work orders found for equipment_id '{equipment_id}'" + ) counts = ( filtered.groupby(["primary_code", "secondary_code"]) @@ -293,20 +306,31 @@ def predict_next_work_order( cond = date_conditions(equipment_id, "actual_finish", start_date, end_date) filtered = filter_df(wo_df, cond) if filtered is None or filtered.empty: - return ErrorResult(error=f"No historical work orders found for equipment_id '{equipment_id}'") + return ErrorResult( + error=f"No historical work orders found for equipment_id '{equipment_id}'" + ) filtered = filtered.sort_values("actual_finish").reset_index(drop=True) transition_matrix = get_transition_matrix(filtered, "primary_code") last_type = filtered.iloc[-1]["primary_code"] if last_type not in transition_matrix.index: - return ErrorResult(error=f"No transition data for last work order type '{last_type}'") + return ErrorResult( + error=f"No transition data for last work order type '{last_type}'" + ) - raw = sorted(transition_matrix.loc[last_type].items(), key=lambda t: t[1], reverse=True) + raw = sorted( + transition_matrix.loc[last_type].items(), key=lambda t: t[1], reverse=True + ) predictions: List[NextWorkOrderEntry] = [] for primary_code, prob in raw: - entry = NextWorkOrderEntry(category="", primary_code=primary_code, primary_code_description="", probability=float(prob)) + entry = NextWorkOrderEntry( + category="", + primary_code=primary_code, + primary_code_description="", + probability=float(prob), + ) if pfc_df is not None: match = pfc_df[pfc_df["primary_code"] == primary_code] if not match.empty: @@ -357,23 +381,34 @@ def analyze_alert_to_failure( return ErrorResult(error=str(exc)) cond: dict = { - "equipment_id": lambda x, eid=equipment_id: isinstance(x, str) and x.strip().lower() == eid.strip().lower(), - "rule_id": lambda x, rid=rule_id: isinstance(x, str) and x.strip().lower() == rid.strip().lower(), + "equipment_id": lambda x, eid=equipment_id: ( + isinstance(x, str) and x.strip().lower() == eid.strip().lower() + ), + "rule_id": lambda x, rid=rule_id: ( + isinstance(x, str) and x.strip().lower() == rid.strip().lower() + ), } filtered = filter_df(alert_df, cond) if filtered is None or filtered.empty: - return ErrorResult(error=f"No alert events found for equipment '{equipment_id}' and rule '{rule_id}'") + return ErrorResult( + error=f"No alert events found for equipment '{equipment_id}' and rule '{rule_id}'" + ) filtered = filtered.sort_values("start_time").reset_index(drop=True) transitions: List[str] = [] time_diffs: List[float] = [] for i in range(len(filtered) - 1): - if str(filtered.iloc[i].get("rule_id", "")).strip().lower() == rule_id.strip().lower(): + if ( + str(filtered.iloc[i].get("rule_id", "")).strip().lower() + == rule_id.strip().lower() + ): for j in range(i + 1, len(filtered)): if str(filtered.iloc[j].get("event_group", "")).upper() == "WORK_ORDER": transitions.append("WORK_ORDER") - diff = filtered.iloc[j]["start_time"] - filtered.iloc[i]["start_time"] + diff = ( + filtered.iloc[j]["start_time"] - filtered.iloc[i]["start_time"] + ) time_diffs.append(diff.total_seconds() / 3600) break else: @@ -387,7 +422,11 @@ def analyze_alert_to_failure( entries: List[AlertToFailureEntry] = [] for transition, count in sorted(counts.items(), key=lambda t: t[1], reverse=True): - avg_hours = sum(time_diffs) / len(time_diffs) if transition == "WORK_ORDER" and time_diffs else None + avg_hours = ( + sum(time_diffs) / len(time_diffs) + if transition == "WORK_ORDER" and time_diffs + else None + ) entries.append( AlertToFailureEntry( transition=transition, diff --git a/uv.lock b/uv.lock index f12494f10..3e7f111c2 100644 --- a/uv.lock +++ b/uv.lock @@ -198,6 +198,7 @@ dev = [ { name = "opentelemetry-sdk" }, { name = "pytest" }, { name = "pytest-anyio" }, + { name = "ruff" }, ] otel = [ { name = "opentelemetry-api" }, @@ -238,6 +239,7 @@ dev = [ { name = "opentelemetry-sdk", specifier = ">=1.27.0" }, { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-anyio", specifier = ">=0.0.0" }, + { name = "ruff", specifier = ">=0.9.0" }, ] otel = [ { name = "opentelemetry-api", specifier = ">=1.27.0" }, @@ -3144,6 +3146,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" }, ] +[[package]] +name = "ruff" +version = "0.15.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/8a/8bce2894573e9dae6ff4d77fe34ad727d79b9e6238ad288c5638990d90f6/ruff-0.15.14.tar.gz", hash = "sha256:48e866b165be4a9bdbf310f7d3c9a07edef2fe8cd63ffeb4e00bb590506ebf9f", size = 4700910, upload-time = "2026-05-21T14:34:55.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/c8/74a92c6ff9fcfb4f1f947126d3ebee8389276e161ecc85de5bda7cda51bd/ruff-0.15.14-py3-none-linux_armv6l.whl", hash = "sha256:8dd2db9416e487c8d4b01fa7056bb02c4d05969d4f8d17a08c229c2f4ff3c108", size = 10739177, upload-time = "2026-05-21T14:34:37.332Z" }, + { url = "https://files.pythonhosted.org/packages/45/91/254a35c20acc38a7223c9d2d594af12e794432464f2cdeb52af1dc4a892d/ruff-0.15.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:be4ff55af755bd71a00ab3dc6bd7ffc467bd76e0df6881e286c2e3d23e8fb43b", size = 11144969, upload-time = "2026-05-21T14:34:43.978Z" }, + { url = "https://files.pythonhosted.org/packages/56/9e/d13e40f83b8d0a94430e6778ce1d94a43b38cf2efe63278bdd2b4c65abbf/ruff-0.15.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:48d5909d7d06276ce7dde6d32bfa4b0d4cb2651145cd8ee4b440722cbc77832f", size = 10478207, upload-time = "2026-05-21T14:34:48.378Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f1/b15a7839fa4f332f8acec78e20564f26bb2d866e3d21710b877fd0263000/ruff-0.15.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca8cbfa94c4f90984a67561978602746d4cd27103568f745fa90eee3f0d4107d", size = 10818459, upload-time = "2026-05-21T14:34:22.318Z" }, + { url = "https://files.pythonhosted.org/packages/45/33/53d651177f84f94b400a0e27f8824eeada3dddc9d5ee8aeb048f4352a520/ruff-0.15.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a6bbc0333f1ab053423bcbf6226477d266ca7cec7738c4c8e3f55647803f3c4", size = 10541800, upload-time = "2026-05-21T14:34:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/b8/a6/868f87e0bf9786ed24b5d0d0ad8676b8a94fd1912f42cddf9cfc7857818a/ruff-0.15.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8a24a4f7605d7003a6674d4387651effd939dead3fddd0f36561eb77a9a2e542", size = 11342149, upload-time = "2026-05-21T14:34:46.365Z" }, + { url = "https://files.pythonhosted.org/packages/a7/8b/38cd5c19faffdcc05a408d2b78edccc69492ab9720eadb49ea15ef80d768/ruff-0.15.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:049b5326e53ed80978f2fc041a280603f69dd6b0c95464342a2bb4572d9d9e2f", size = 12212563, upload-time = "2026-05-21T14:34:28.579Z" }, + { url = "https://files.pythonhosted.org/packages/3e/4d/a3c5b874a556d5731e3e657aaf04311bb76f0a5c3ec220ed43051be6b64b/ruff-0.15.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4ed42e6696c8dfa5f06728e6441993901f548eb92d73bc472cb5a38d1395fbf", size = 11493299, upload-time = "2026-05-21T14:34:41.836Z" }, + { url = "https://files.pythonhosted.org/packages/1e/c0/56472c251d09858a53e51efbd485b09e1995d8731668b76d52e5dd6ee0f1/ruff-0.15.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:715c543cf450c4888251f91c52f1942a800541d9bddd7ac060aa4e6b77ae7cba", size = 11455931, upload-time = "2026-05-21T14:34:57.276Z" }, + { url = "https://files.pythonhosted.org/packages/2c/4a/e2e7b4d8dbf233d4eace59c75bc3435fa6d8bd3bae82d351d4e4300c0fd1/ruff-0.15.14-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:72ebab6013ec887d439d8b7593737a0a4ffb06d45d209d4e4bf2e92813082d3f", size = 11400794, upload-time = "2026-05-21T14:34:39.773Z" }, + { url = "https://files.pythonhosted.org/packages/97/c7/83c0539fe34c3e09136204d1e75d6052492364e0b3cb05e9465423f567d7/ruff-0.15.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:49072d36abdbe97a8dd7f480afe9c675699c0c495d4c84076e2c1203c4550581", size = 10804759, upload-time = "2026-05-21T14:34:31.045Z" }, + { url = "https://files.pythonhosted.org/packages/86/a6/18f2bfc095a2ab4a78745644e428205532ce6653a5d0fa8501572891534d/ruff-0.15.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:958522aee105068640c2c2ceae08f413ae44d922f52a1374ac13d6a96032fc93", size = 10539517, upload-time = "2026-05-21T14:34:53.064Z" }, + { url = "https://files.pythonhosted.org/packages/54/3a/5a8b3b69c654d4e4bf1d246ac5b49cbcdac6eaab6905925f8915f31e3b80/ruff-0.15.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f3707da619a143a2e8830e2abab8224478d69ace2d28cb6c20543ae97c36bf61", size = 11065169, upload-time = "2026-05-21T14:34:24.484Z" }, + { url = "https://files.pythonhosted.org/packages/ed/c5/8864e4e7925b836ea354b31d57641ec03830564e281a8b6f061f8c3e0ec1/ruff-0.15.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:bb01d645694e3ec0102105d07ef2d53703970407d59c04e59d3ba0b7a1d53553", size = 11560214, upload-time = "2026-05-21T14:34:50.975Z" }, + { url = "https://files.pythonhosted.org/packages/36/38/012bf76752e1f89ed50b77b99532d90f3a3e287bc7918e1fc0948ac866ac/ruff-0.15.14-py3-none-win32.whl", hash = "sha256:6d0c1ad2a0ab718d39b6d8fd2217981ce4d625cd96a720095f798fb47d8b13e6", size = 10805548, upload-time = "2026-05-21T14:34:33.453Z" }, + { url = "https://files.pythonhosted.org/packages/d1/b7/4ea2c170f10ad760fff2a5250beb18897719dc8b52b53a24cddbb9dd3f19/ruff-0.15.14-py3-none-win_amd64.whl", hash = "sha256:802342981e056db3851a7836e5b070f8f15f67d4a685ae2a6160939d364b2902", size = 11939523, upload-time = "2026-05-21T14:34:18.077Z" }, + { url = "https://files.pythonhosted.org/packages/62/d5/bc97ff895ec35cf3925d4bd60f3b39d822f377a446906ec9bcc87405e59b/ruff-0.15.14-py3-none-win_arm64.whl", hash = "sha256:ff47b90a9ef6a40c9e2f3b479c1fb78531adf055b94c1eba0a7ba04b31951826", size = 11208607, upload-time = "2026-05-21T14:34:26.525Z" }, +] + [[package]] name = "safetensors" version = "0.7.0"