diff --git a/slack_bolt/agent/agent.py b/slack_bolt/agent/agent.py index db1c78aa9..3663b245b 100644 --- a/slack_bolt/agent/agent.py +++ b/slack_bolt/agent/agent.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import List, Optional from slack_sdk import WebClient +from slack_sdk.web import SlackResponse from slack_sdk.web.chat_stream import ChatStream @@ -71,3 +72,32 @@ def chat_stream( recipient_user_id=recipient_user_id or self._user_id, **kwargs, ) + + def set_status( + self, + *, + status: str, + loading_messages: Optional[List[str]] = None, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + **kwargs, + ) -> SlackResponse: + """Sets the status of an assistant thread. + + Args: + status: The status text to display. + loading_messages: Optional list of loading messages to cycle through. + channel: Channel ID. Defaults to the channel from the event context. + thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. + **kwargs: Additional arguments passed to ``WebClient.assistant_threads_setStatus()``. + + Returns: + ``SlackResponse`` from the API call. + """ + return self._client.assistant_threads_setStatus( + channel_id=channel or self._channel_id, # type: ignore[arg-type] + thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type] + status=status, + loading_messages=loading_messages, + **kwargs, + ) diff --git a/slack_bolt/agent/async_agent.py b/slack_bolt/agent/async_agent.py index 2ee15aa2e..5b86533e6 100644 --- a/slack_bolt/agent/async_agent.py +++ b/slack_bolt/agent/async_agent.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import List, Optional -from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient from slack_sdk.web.async_chat_stream import AsyncChatStream @@ -68,3 +68,32 @@ async def chat_stream( recipient_user_id=recipient_user_id or self._user_id, **kwargs, ) + + async def set_status( + self, + *, + status: str, + loading_messages: Optional[List[str]] = None, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + **kwargs, + ) -> AsyncSlackResponse: + """Sets the status of an assistant thread. + + Args: + status: The status text to display. + loading_messages: Optional list of loading messages to cycle through. + channel: Channel ID. Defaults to the channel from the event context. + thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. + **kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setStatus()``. + + Returns: + ``AsyncSlackResponse`` from the API call. + """ + return await self._client.assistant_threads_setStatus( + channel_id=channel or self._channel_id, # type: ignore[arg-type] + thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type] + status=status, + loading_messages=loading_messages, + **kwargs, + ) diff --git a/tests/slack_bolt/agent/test_agent.py b/tests/slack_bolt/agent/test_agent.py index 00e998379..7dad481b0 100644 --- a/tests/slack_bolt/agent/test_agent.py +++ b/tests/slack_bolt/agent/test_agent.py @@ -92,6 +92,111 @@ def test_chat_stream_passes_extra_kwargs(self): buffer_size=512, ) + def test_set_status_uses_context_defaults(self): + """BoltAgent.set_status() passes context defaults to WebClient.assistant_threads_setStatus().""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setStatus.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_status(status="Thinking...") + + client.assistant_threads_setStatus.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + status="Thinking...", + loading_messages=None, + ) + + def test_set_status_with_loading_messages(self): + """BoltAgent.set_status() forwards loading_messages.""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setStatus.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_status( + status="Thinking...", + loading_messages=["Sitting...", "Waiting..."], + ) + + client.assistant_threads_setStatus.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + status="Thinking...", + loading_messages=["Sitting...", "Waiting..."], + ) + + def test_set_status_overrides_context_defaults(self): + """Explicit channel/thread_ts override context defaults.""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setStatus.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_status( + status="Thinking...", + channel="C999", + thread_ts="9999999999.999999", + ) + + client.assistant_threads_setStatus.assert_called_once_with( + channel_id="C999", + thread_ts="9999999999.999999", + status="Thinking...", + loading_messages=None, + ) + + def test_set_status_passes_extra_kwargs(self): + """Extra kwargs are forwarded to WebClient.assistant_threads_setStatus().""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setStatus.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_status(status="Thinking...", token="xoxb-override") + + client.assistant_threads_setStatus.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + status="Thinking...", + loading_messages=None, + token="xoxb-override", + ) + + def test_set_status_requires_status(self): + """set_status() raises TypeError when status is not provided.""" + client = MagicMock(spec=WebClient) + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(TypeError): + agent.set_status() + def test_import_from_slack_bolt(self): from slack_bolt import BoltAgent as ImportedBoltAgent diff --git a/tests/slack_bolt_async/agent/test_async_agent.py b/tests/slack_bolt_async/agent/test_async_agent.py index 02251fa4b..8e4c4d5c8 100644 --- a/tests/slack_bolt_async/agent/test_async_agent.py +++ b/tests/slack_bolt_async/agent/test_async_agent.py @@ -18,6 +18,17 @@ async def fake_chat_stream(**kwargs): return fake_chat_stream, call_tracker, mock_stream +def _make_async_api_mock(): + mock_response = MagicMock() + call_tracker = MagicMock() + + async def fake_api_call(**kwargs): + call_tracker(**kwargs) + return mock_response + + return fake_api_call, call_tracker, mock_response + + class TestAsyncBoltAgent: @pytest.mark.asyncio async def test_chat_stream_uses_context_defaults(self): @@ -107,6 +118,116 @@ async def test_chat_stream_passes_extra_kwargs(self): buffer_size=512, ) + @pytest.mark.asyncio + async def test_set_status_uses_context_defaults(self): + """AsyncBoltAgent.set_status() passes context defaults to AsyncWebClient.assistant_threads_setStatus().""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_status(status="Thinking...") + + call_tracker.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + status="Thinking...", + loading_messages=None, + ) + + @pytest.mark.asyncio + async def test_set_status_with_loading_messages(self): + """AsyncBoltAgent.set_status() forwards loading_messages.""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_status( + status="Thinking...", + loading_messages=["Sitting...", "Waiting..."], + ) + + call_tracker.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + status="Thinking...", + loading_messages=["Sitting...", "Waiting..."], + ) + + @pytest.mark.asyncio + async def test_set_status_overrides_context_defaults(self): + """Explicit channel/thread_ts override context defaults.""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_status( + status="Thinking...", + channel="C999", + thread_ts="9999999999.999999", + ) + + call_tracker.assert_called_once_with( + channel_id="C999", + thread_ts="9999999999.999999", + status="Thinking...", + loading_messages=None, + ) + + @pytest.mark.asyncio + async def test_set_status_passes_extra_kwargs(self): + """Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setStatus().""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_status(status="Thinking...", token="xoxb-override") + + call_tracker.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + status="Thinking...", + loading_messages=None, + token="xoxb-override", + ) + + @pytest.mark.asyncio + async def test_set_status_requires_status(self): + """set_status() raises TypeError when status is not provided.""" + client = MagicMock(spec=AsyncWebClient) + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(TypeError): + await agent.set_status() + @pytest.mark.asyncio async def test_import_from_agent_module(self): from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent