Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions api/src/org/labkey/api/mcp/AbstractAgentAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.genai.errors.ClientException;
import com.google.genai.errors.ServerException;
import jakarta.servlet.http.HttpSession;
import org.apache.commons.lang3.StringUtils;
import org.json.JSONObject;
import org.labkey.api.action.ReadOnlyApiAction;
import org.labkey.api.util.HtmlString;
Expand All @@ -25,26 +26,53 @@ public abstract class AbstractAgentAction<F extends PromptForm> extends ReadOnly

protected abstract String getServicePrompt();

protected ChatClient getChat()
protected ChatClient getChat(boolean create)
{
HttpSession session = getViewContext().getRequest().getSession(true);
ChatClient chatSession = McpService.get().getChat(session, getAgentName(), this::getServicePrompt);
ChatClient chatSession = McpService.get().getChat(session, getAgentName(), this::getServicePrompt, create);
return chatSession;
}

protected String handleEscape(String prompt)
{
prompt = StringUtils.trimToEmpty(prompt);
switch (prompt)
{
case "/clear" ->
{
ChatClient chatSession = getChat(false); // CONSIDER: getChat(boolean ifStarted)
if (null != chatSession)
McpService.get().close(getViewContext().getSession(), chatSession);
return "OK, let's start over.";
}
}
return null;
}

@Override
public Object execute(PromptForm form, BindException errors) throws Exception
{
try (var mcpPush = McpContext.withContext(getViewContext()))
{
ChatClient chatSession = getChat();
String prompt = form.getPrompt();

String escapeResponse = handleEscape(prompt);
if (null != escapeResponse)
{
return new JSONObject(Map.of(
"contentType", "text/plain",
"response", escapeResponse,
"success", Boolean.TRUE));
}

// call getChat() after handleEscape()
ChatClient chatSession = getChat(true);
if (null == chatSession)
return new JSONObject(Map.of(
"contentType", "text/plain",
"response", "Service is not ready yet",
"success", Boolean.FALSE));

String prompt = form.getPrompt();
McpService.MessageResponse response = McpService.get().sendMessage(chatSession, prompt);
var ret = new JSONObject(Map.of("success", Boolean.TRUE));
if (!HtmlString.isBlank(response.html()))
Expand Down
9 changes: 8 additions & 1 deletion api/src/org/labkey/api/mcp/McpService.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,14 @@ default void register(McpProvider mcp)
@Override
ToolCallback @NonNull [] getToolCallbacks();

ChatClient getChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier);
default ChatClient getChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier)
{
return getChat(session, agentName, systemPromptSupplier, true);
}

ChatClient getChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier, boolean createIfNotExists);

void close(HttpSession session, ChatClient chat);

record MessageResponse(String contentType, String text, HtmlString html) {}

Expand Down
1 change: 0 additions & 1 deletion core/src/org/labkey/core/CoreController.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.apache.commons.beanutils.ConversionException;
import org.apache.commons.beanutils.ConvertUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.Logger;
import org.apache.xmlbeans.XmlObject;
Expand Down
100 changes: 76 additions & 24 deletions core/src/org/labkey/core/mpc/McpServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.google.genai.GoogleGenAiChatModel;
import org.springframework.ai.google.genai.GoogleGenAiChatOptions;
Expand Down Expand Up @@ -299,41 +300,93 @@ public void shutdownStarted()
}
}


@Override
public ChatClient getChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier)
public ChatClient getChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier, boolean createIfNotExists)
{
if (!serverReady)
return null;

return SessionHelper.getAttribute(session, ChatClient.class.getName() + "#" + agentName, () ->
String sessionKey = ChatClient.class.getName() + "#" + agentName;
if (createIfNotExists)
{
String systemPrompt = systemPromptSupplier.get();
String conversationId = session.getId() + ":" + agentName;
List<Advisor> advisors = new ArrayList<>();
return SessionHelper.getAttribute(session, sessionKey, () ->
{
var springClient = createSpringChat(session, agentName, systemPromptSupplier);
return new _ChatClient(springClient, sessionKey);
});
}
return SessionHelper.getAttribute(session, sessionKey, null);
}

ChatMemory chatMemory = MessageWindowChatMemory.builder()
.maxMessages(100)
.chatMemoryRepository(chatMemoryRepository)
.build();
private ChatClient createSpringChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier)
{
String systemPrompt = systemPromptSupplier.get();
String conversationId = session.getId() + ":" + agentName;
List<Advisor> advisors = new ArrayList<>();

ChatMemory chatMemory = MessageWindowChatMemory.builder()
.maxMessages(100)
.chatMemoryRepository(chatMemoryRepository)
.build();

MessageChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
.conversationId(conversationId)
.build();
advisors.add(chatMemoryAdvisor);

VectorStore vs = getVectorStore();
if (null != vs)
advisors.add(QuestionAnswerAdvisor.builder(vs).build());

return ChatClient.builder(modelProvider.getChatModel())
.defaultOptions(modelProvider.getChatOptions())
.defaultAdvisors(advisors)
.defaultSystem(systemPrompt)
.build();
}

MessageChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
.conversationId(conversationId)
.build();
advisors.add(chatMemoryAdvisor);
private class _ChatClient implements ChatClient
{
final ChatClient springClient;
final String key;
_ChatClient(ChatClient client, String key)
{
this.springClient = client;
this.key = key;
}

VectorStore vs = getVectorStore();
if (null != vs)
advisors.add(QuestionAnswerAdvisor.builder(vs).build());
@Override
public ChatClientRequestSpec prompt()
{
return springClient.prompt();
}

return ChatClient.builder(modelProvider.getChatModel())
.defaultOptions(modelProvider.getChatOptions())
.defaultAdvisors(advisors)
.defaultSystem(systemPrompt)
.build();
});
@Override
public ChatClientRequestSpec prompt(String content)
{
return springClient.prompt(content);
}

@Override
public ChatClientRequestSpec prompt(Prompt prompt)
{
return springClient.prompt(prompt);
}

@Override
public Builder mutate()
{
throw new UnsupportedOperationException();
}
}

@Override
public void close(HttpSession session, ChatClient chat)
{
if (null == chat)
return;
session.removeAttribute(((_ChatClient)chat).key);
}

@Override
public MessageResponse sendMessage(ChatClient chatSession, String message)
Expand Down Expand Up @@ -477,7 +530,6 @@ interface _ModelProvider

ChatModel getChatModel();

// ChatClient getChat(HttpSession session, String agentName, Supplier<String> systemPromptSupplier);
EmbeddingModel createEmbeddingModel();
}

Expand Down
15 changes: 13 additions & 2 deletions query/src/org/labkey/query/controllers/QueryController.java
Original file line number Diff line number Diff line change
Expand Up @@ -8880,15 +8880,26 @@ public Object execute(SqlPromptForm form, BindException errors) throws Exception

try (var mcpPush = McpContext.withContext(getViewContext()))
{
// TODO when/how to do we reset or isolate different chat sessions, e.g. if two SQL windows are open concurrently?
ChatClient chatSession = getChat();
String prompt = form.getPrompt();

String escapeResponse = handleEscape(prompt);
if (null != escapeResponse)
{
return new JSONObject(Map.of(
"contentType", "text/plain",
"text", escapeResponse,
"success", Boolean.TRUE));
}

// TODO when/how to do we reset or isolate different chat sessions, e.g. if two SQL windows are open concurrently?
ChatClient chatSession = getChat(true);
List<McpService.MessageResponse> responses;
SqlResponse sqlResponse;

if (isBlank(prompt))
{
return new JSONObject(Map.of(
"contentType", "text/plain",
"text", "🤷",
"success", Boolean.TRUE));
}
Expand Down