diff --git a/api/src/org/labkey/api/mcp/AbstractAgentAction.java b/api/src/org/labkey/api/mcp/AbstractAgentAction.java index 91dddcc7eaf..159695afe78 100644 --- a/api/src/org/labkey/api/mcp/AbstractAgentAction.java +++ b/api/src/org/labkey/api/mcp/AbstractAgentAction.java @@ -3,6 +3,7 @@ import com.google.genai.errors.ClientException; import com.google.genai.errors.ServerException; import jakarta.servlet.http.HttpSession; +import org.apache.commons.lang3.StringUtils; import org.json.JSONObject; import org.labkey.api.action.ReadOnlyApiAction; import org.labkey.api.util.HtmlString; @@ -25,26 +26,53 @@ public abstract class AbstractAgentAction extends ReadOnly protected abstract String getServicePrompt(); - protected ChatClient getChat() + protected ChatClient getChat(boolean create) { HttpSession session = getViewContext().getRequest().getSession(true); - ChatClient chatSession = McpService.get().getChat(session, getAgentName(), this::getServicePrompt); + ChatClient chatSession = McpService.get().getChat(session, getAgentName(), this::getServicePrompt, create); return chatSession; } + protected String handleEscape(String prompt) + { + prompt = StringUtils.trimToEmpty(prompt); + switch (prompt) + { + case "/clear" -> + { + ChatClient chatSession = getChat(false); // CONSIDER: getChat(boolean ifStarted) + if (null != chatSession) + McpService.get().close(getViewContext().getSession(), chatSession); + return "OK, let's start over."; + } + } + return null; + } + @Override public Object execute(PromptForm form, BindException errors) throws Exception { try (var mcpPush = McpContext.withContext(getViewContext())) { - ChatClient chatSession = getChat(); + String prompt = form.getPrompt(); + + String escapeResponse = handleEscape(prompt); + if (null != escapeResponse) + { + return new JSONObject(Map.of( + "contentType", "text/plain", + "response", escapeResponse, + "success", Boolean.TRUE)); + } + + // call getChat() after handleEscape() + ChatClient chatSession = getChat(true); if (null == chatSession) return new JSONObject(Map.of( "contentType", "text/plain", "response", "Service is not ready yet", "success", Boolean.FALSE)); - String prompt = form.getPrompt(); McpService.MessageResponse response = McpService.get().sendMessage(chatSession, prompt); var ret = new JSONObject(Map.of("success", Boolean.TRUE)); if (!HtmlString.isBlank(response.html())) diff --git a/api/src/org/labkey/api/mcp/McpService.java b/api/src/org/labkey/api/mcp/McpService.java index 164e8b81f0d..67854a58c61 100644 --- a/api/src/org/labkey/api/mcp/McpService.java +++ b/api/src/org/labkey/api/mcp/McpService.java @@ -74,7 +74,14 @@ default void register(McpProvider mcp) @Override ToolCallback @NonNull [] getToolCallbacks(); - ChatClient getChat(HttpSession session, String agentName, Supplier systemPromptSupplier); + default ChatClient getChat(HttpSession session, String agentName, Supplier systemPromptSupplier) + { + return getChat(session, agentName, systemPromptSupplier, true); + } + + ChatClient getChat(HttpSession session, String agentName, Supplier systemPromptSupplier, boolean createIfNotExists); + + void close(HttpSession session, ChatClient chat); record MessageResponse(String contentType, String text, HtmlString html) {} diff --git a/core/src/org/labkey/core/CoreController.java b/core/src/org/labkey/core/CoreController.java index 3f78eb63b95..d80b090fa49 100644 --- a/core/src/org/labkey/core/CoreController.java +++ b/core/src/org/labkey/core/CoreController.java @@ -21,7 +21,6 @@ import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpSession; import org.apache.commons.beanutils.ConversionException; -import org.apache.commons.beanutils.ConvertUtils; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.Logger; import org.apache.xmlbeans.XmlObject; diff --git a/core/src/org/labkey/core/mpc/McpServiceImpl.java b/core/src/org/labkey/core/mpc/McpServiceImpl.java index 422d0140223..11e7b93a3a3 100644 --- a/core/src/org/labkey/core/mpc/McpServiceImpl.java +++ b/core/src/org/labkey/core/mpc/McpServiceImpl.java @@ -46,6 +46,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.google.genai.GoogleGenAiChatModel; import org.springframework.ai.google.genai.GoogleGenAiChatOptions; @@ -299,41 +300,93 @@ public void shutdownStarted() } } - @Override - public ChatClient getChat(HttpSession session, String agentName, Supplier systemPromptSupplier) + public ChatClient getChat(HttpSession session, String agentName, Supplier systemPromptSupplier, boolean createIfNotExists) { if (!serverReady) return null; - return SessionHelper.getAttribute(session, ChatClient.class.getName() + "#" + agentName, () -> + String sessionKey = ChatClient.class.getName() + "#" + agentName; + if (createIfNotExists) { - String systemPrompt = systemPromptSupplier.get(); - String conversationId = session.getId() + ":" + agentName; - List advisors = new ArrayList<>(); + return SessionHelper.getAttribute(session, sessionKey, () -> + { + var springClient = createSpringChat(session, agentName, systemPromptSupplier); + return new _ChatClient(springClient, sessionKey); + }); + } + return SessionHelper.getAttribute(session, sessionKey, null); + } - ChatMemory chatMemory = MessageWindowChatMemory.builder() - .maxMessages(100) - .chatMemoryRepository(chatMemoryRepository) - .build(); + private ChatClient createSpringChat(HttpSession session, String agentName, Supplier systemPromptSupplier) + { + String systemPrompt = systemPromptSupplier.get(); + String conversationId = session.getId() + ":" + agentName; + List advisors = new ArrayList<>(); + + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .maxMessages(100) + .chatMemoryRepository(chatMemoryRepository) + .build(); + + MessageChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory) + .conversationId(conversationId) + .build(); + advisors.add(chatMemoryAdvisor); + + VectorStore vs = getVectorStore(); + if (null != vs) + advisors.add(QuestionAnswerAdvisor.builder(vs).build()); + + return ChatClient.builder(modelProvider.getChatModel()) + .defaultOptions(modelProvider.getChatOptions()) + .defaultAdvisors(advisors) + .defaultSystem(systemPrompt) + .build(); + } - MessageChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory) - .conversationId(conversationId) - .build(); - advisors.add(chatMemoryAdvisor); + private class _ChatClient implements ChatClient + { + final ChatClient springClient; + final String key; + _ChatClient(ChatClient client, String key) + { + this.springClient = client; + this.key = key; + } - VectorStore vs = getVectorStore(); - if (null != vs) - advisors.add(QuestionAnswerAdvisor.builder(vs).build()); + @Override + public ChatClientRequestSpec prompt() + { + return springClient.prompt(); + } - return ChatClient.builder(modelProvider.getChatModel()) - .defaultOptions(modelProvider.getChatOptions()) - .defaultAdvisors(advisors) - .defaultSystem(systemPrompt) - .build(); - }); + @Override + public ChatClientRequestSpec prompt(String content) + { + return springClient.prompt(content); + } + + @Override + public ChatClientRequestSpec prompt(Prompt prompt) + { + return springClient.prompt(prompt); + } + + @Override + public Builder mutate() + { + throw new UnsupportedOperationException(); + } } + @Override + public void close(HttpSession session, ChatClient chat) + { + if (null == chat) + return; + session.removeAttribute(((_ChatClient)chat).key); + } @Override public MessageResponse sendMessage(ChatClient chatSession, String message) @@ -477,7 +530,6 @@ interface _ModelProvider ChatModel getChatModel(); -// ChatClient getChat(HttpSession session, String agentName, Supplier systemPromptSupplier); EmbeddingModel createEmbeddingModel(); } diff --git a/query/src/org/labkey/query/controllers/QueryController.java b/query/src/org/labkey/query/controllers/QueryController.java index 0d14f2a7b27..5a3c88f2edc 100644 --- a/query/src/org/labkey/query/controllers/QueryController.java +++ b/query/src/org/labkey/query/controllers/QueryController.java @@ -8880,15 +8880,26 @@ public Object execute(SqlPromptForm form, BindException errors) throws Exception try (var mcpPush = McpContext.withContext(getViewContext())) { - // TODO when/how to do we reset or isolate different chat sessions, e.g. if two SQL windows are open concurrently? - ChatClient chatSession = getChat(); String prompt = form.getPrompt(); + + String escapeResponse = handleEscape(prompt); + if (null != escapeResponse) + { + return new JSONObject(Map.of( + "contentType", "text/plain", + "text", escapeResponse, + "success", Boolean.TRUE)); + } + + // TODO when/how to do we reset or isolate different chat sessions, e.g. if two SQL windows are open concurrently? + ChatClient chatSession = getChat(true); List responses; SqlResponse sqlResponse; if (isBlank(prompt)) { return new JSONObject(Map.of( + "contentType", "text/plain", "text", "🤷", "success", Boolean.TRUE)); }