Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2
github.com/joho/godotenv v1.5.1
github.com/openai/openai-go/v2 v2.7.1
github.com/openai/openai-go/v3 v3.12.0
github.com/samber/lo v1.51.0
github.com/stretchr/testify v1.10.0
go.mongodb.org/mongo-driver/v2 v2.3.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/openai/openai-go/v2 v2.7.1 h1:/tfvTJhfv7hTSL8mWwc5VL4WLLSDL5yn9VqVykdu9r8=
github.com/openai/openai-go/v2 v2.7.1/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE=
github.com/openai/openai-go/v3 v3.12.0 h1:NkrImaglFQeDycc/n/fEmpFV8kKr8snl9/8X2x4eHOg=
github.com/openai/openai-go/v3 v3.12.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
25 changes: 20 additions & 5 deletions internal/api/chat/create_conversation_message_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,23 @@ func (s *ChatServer) CreateConversationMessageStream(
) error {
ctx := stream.Context()

languageModel := models.LanguageModel(req.GetLanguageModel())
// Handle oneof model field: prefer ModelSlug, fallback to LanguageModel enum
modelSlug := req.GetModelSlug()
if modelSlug == "" {
var err error
modelSlug, err = models.LanguageModel(req.GetLanguageModel()).Name()
if err != nil {
return s.sendStreamError(stream, err)
}
}

ctx, conversation, settings, err := s.prepare(
ctx,
req.GetProjectId(),
req.GetConversationId(),
req.GetUserMessage(),
req.GetUserSelectedText(),
languageModel,
modelSlug,
req.GetConversationType(),
)
if err != nil {
Expand All @@ -41,11 +50,17 @@ func (s *ChatServer) CreateConversationMessageStream(

// Usage mirrors ChatCompletion; the only difference is that the stream parameter is passed through.
llmProvider := &models.LLMProviderConfig{
Endpoint: s.cfg.OpenAIBaseURL,
Endpoint: "",
APIKey: settings.OpenAIAPIKey,
}

openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletionStream(ctx, stream, conversation.ID.Hex(), languageModel, conversation.OpenaiChatHistory, llmProvider)
var legacyLanguageModel *chatv1.LanguageModel
if req.GetModelSlug() == "" {
m := req.GetLanguageModel()
legacyLanguageModel = &m
}

openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletionStream(ctx, stream, conversation.ID.Hex(), modelSlug, legacyLanguageModel, conversation.OpenaiChatHistoryCompletion, llmProvider)
if err != nil {
return s.sendStreamError(stream, err)
}
Expand All @@ -60,7 +75,7 @@ func (s *ChatServer) CreateConversationMessageStream(
bsonMessages[i] = bsonMsg
}
conversation.InappChatHistory = append(conversation.InappChatHistory, bsonMessages...)
conversation.OpenaiChatHistory = openaiChatHistory
conversation.OpenaiChatHistoryCompletion = openaiChatHistory
if err := s.chatService.UpdateConversation(conversation); err != nil {
return s.sendStreamError(stream, err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ package chat
import (
"context"

"paperdebugger/internal/api/mapper"
"paperdebugger/internal/libs/contextutil"
"paperdebugger/internal/libs/shared"
"paperdebugger/internal/models"
chatv1 "paperdebugger/pkg/gen/api/chat/v1"

"github.com/google/uuid"
"github.com/openai/openai-go/v2/responses"
"github.com/openai/openai-go/v3"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"google.golang.org/protobuf/encoding/protojson"
Expand All @@ -21,10 +20,10 @@ import (
// What we send to GPT is exactly the content (InputItemList) of the Conversation object fetched from the database.

// buildUserMessage constructs both the user-facing message and the OpenAI input message
func (s *ChatServer) buildUserMessage(ctx context.Context, userMessage, userSelectedText string, conversationType chatv1.ConversationType) (*chatv1.Message, *responses.ResponseInputItemUnionParam, error) {
func (s *ChatServer) buildUserMessage(ctx context.Context, userMessage, userSelectedText string, conversationType chatv1.ConversationType) (*chatv1.Message, openai.ChatCompletionMessageParamUnion, error) {
userPrompt, err := s.chatService.GetPrompt(ctx, userMessage, userSelectedText, conversationType)
if err != nil {
return nil, nil, err
return nil, openai.ChatCompletionMessageParamUnion{}, err
}

var inappMessage *chatv1.Message
Expand Down Expand Up @@ -54,20 +53,12 @@ func (s *ChatServer) buildUserMessage(ctx context.Context, userMessage, userSele
}
}

openaiMessage := &responses.ResponseInputItemUnionParam{
OfInputMessage: &responses.ResponseInputItemMessageParam{
Role: "user",
Content: responses.ResponseInputMessageContentListParam{
responses.ResponseInputContentParamOfInputText(userPrompt),
},
},
}

openaiMessage := openai.UserMessage(userPrompt)
return inappMessage, openaiMessage, nil
}

// buildSystemMessage constructs both the user-facing system message and the OpenAI input message
func (s *ChatServer) buildSystemMessage(systemPrompt string) (*chatv1.Message, *responses.ResponseInputItemUnionParam) {
func (s *ChatServer) buildSystemMessage(systemPrompt string) (*chatv1.Message, openai.ChatCompletionMessageParamUnion) {
inappMessage := &chatv1.Message{
MessageId: "pd_msg_system_" + uuid.New().String(),
Payload: &chatv1.MessagePayload{
Expand All @@ -79,14 +70,7 @@ func (s *ChatServer) buildSystemMessage(systemPrompt string) (*chatv1.Message, *
},
}

openaiMessage := &responses.ResponseInputItemUnionParam{
OfInputMessage: &responses.ResponseInputItemMessageParam{
Role: "system",
Content: responses.ResponseInputMessageContentListParam{
responses.ResponseInputContentParamOfInputText(systemPrompt),
},
},
}
openaiMessage := openai.SystemMessage(systemPrompt)

return inappMessage, openaiMessage
}
Expand Down Expand Up @@ -115,7 +99,7 @@ func (s *ChatServer) createConversation(
userInstructions string,
userMessage string,
userSelectedText string,
languageModel models.LanguageModel,
modelSlug string,
conversationType chatv1.ConversationType,
) (*models.Conversation, error) {
systemPrompt, err := s.chatService.GetSystemPrompt(ctx, latexFullSource, projectInstructions, userInstructions, conversationType)
Expand All @@ -130,12 +114,13 @@ func (s *ChatServer) createConversation(
}

messages := []*chatv1.Message{inappUserMsg}
oaiHistory := responses.ResponseNewParamsInputUnion{
OfInputItemList: responses.ResponseInputParam{*openaiSystemMsg, *openaiUserMsg},
oaiHistory := []openai.ChatCompletionMessageParamUnion{
openaiSystemMsg,
openaiUserMsg,
}

return s.chatService.InsertConversationToDB(
ctx, userId, projectId, languageModel, messages, oaiHistory.OfInputItemList,
ctx, userId, projectId, modelSlug, messages, oaiHistory,
)
}

Expand Down Expand Up @@ -169,8 +154,7 @@ func (s *ChatServer) appendConversationMessage(
return nil, err
}
conversation.InappChatHistory = append(conversation.InappChatHistory, bsonMsg)
conversation.OpenaiChatHistory = append(conversation.OpenaiChatHistory, *userOaiMsg)

conversation.OpenaiChatHistoryCompletion = append(conversation.OpenaiChatHistoryCompletion, userOaiMsg)
if err := s.chatService.UpdateConversation(conversation); err != nil {
return nil, err
}
Expand All @@ -180,7 +164,7 @@ func (s *ChatServer) appendConversationMessage(

// If conversationId is "", a new conversation is created; otherwise the message is appended to the existing conversation.
// conversationType may switch multiple times within a single conversation.
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, languageModel models.LanguageModel, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) {
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, modelSlug string, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) {
actor, err := contextutil.GetActor(ctx)
if err != nil {
return ctx, nil, nil, err
Expand Down Expand Up @@ -223,7 +207,7 @@ func (s *ChatServer) prepare(ctx context.Context, projectId string, conversation
userInstructions,
userMessage,
userSelectedText,
languageModel,
modelSlug,
conversationType,
)
} else {
Expand Down Expand Up @@ -251,68 +235,3 @@ func (s *ChatServer) prepare(ctx context.Context, projectId string, conversation

return ctx, conversation, settings, nil
}

// Deprecated: Use CreateConversationMessageStream instead.
func (s *ChatServer) CreateConversationMessage(
ctx context.Context,
req *chatv1.CreateConversationMessageRequest,
) (*chatv1.CreateConversationMessageResponse, error) {
languageModel := models.LanguageModel(req.GetLanguageModel())
ctx, conversation, settings, err := s.prepare(
ctx,
req.GetProjectId(),
req.GetConversationId(),
req.GetUserMessage(),
req.GetUserSelectedText(),
languageModel,
req.GetConversationType(),
)
if err != nil {
return nil, err
}

llmProvider := &models.LLMProviderConfig{
Endpoint: s.cfg.OpenAIBaseURL,
APIKey: settings.OpenAIAPIKey,
}
openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletion(ctx, languageModel, conversation.OpenaiChatHistory, llmProvider)
if err != nil {
return nil, err
}

bsonMessages := make([]bson.M, len(inappChatHistory))
for i := range inappChatHistory {
bsonMsg, err := convertToBSON(&inappChatHistory[i])
if err != nil {
return nil, err
}
bsonMessages[i] = bsonMsg
}
conversation.InappChatHistory = append(conversation.InappChatHistory, bsonMessages...)
conversation.OpenaiChatHistory = openaiChatHistory

if err := s.chatService.UpdateConversation(conversation); err != nil {
return nil, err
}

go func() {
protoMessages := make([]*chatv1.Message, len(conversation.InappChatHistory))
for i, bsonMsg := range conversation.InappChatHistory {
protoMessages[i] = mapper.BSONToChatMessage(bsonMsg)
}
title, err := s.aiClient.GetConversationTitle(ctx, protoMessages, llmProvider)
if err != nil {
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
return
}
conversation.Title = title
if err := s.chatService.UpdateConversation(conversation); err != nil {
s.logger.Error("Failed to update conversation with new title", "error", err, "conversationID", conversation.ID.Hex())
return
}
}()

return &chatv1.CreateConversationMessageResponse{
Conversation: mapper.MapModelConversationToProto(conversation),
}, nil
}
20 changes: 16 additions & 4 deletions internal/api/chat/list_supported_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"paperdebugger/internal/libs/contextutil"
chatv1 "paperdebugger/pkg/gen/api/chat/v1"

"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v3"
)

func (s *ChatServer) ListSupportedModels(
Expand All @@ -30,15 +30,27 @@ func (s *ChatServer) ListSupportedModels(
{

Name: "GPT-4o",
Slug: openai.ChatModelGPT4o,
Slug: "openai/" + openai.ChatModelGPT4o,
},
{
Name: "GPT-4.1",
Slug: openai.ChatModelGPT4_1,
Slug: "openai/" + openai.ChatModelGPT4_1,
},
{
Name: "GPT-4.1-mini",
Slug: openai.ChatModelGPT4_1Mini,
Slug: "openai/" + openai.ChatModelGPT4_1Mini,
},
{
Name: "GPT 5 nano",
Slug: "openai/" + openai.ChatModelGPT5Nano,
},
{
Name: "Qwen Plus",
Slug: "qwen/qwen-plus",
},
{
Name: "Qwen 3 (235B A22B)",
Slug: "qwen/qwen3-235b-a22b:free",
},
}
} else {
Expand Down
25 changes: 18 additions & 7 deletions internal/api/mapper/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,30 @@ func BSONToChatMessage(msg bson.M) *chatv1.Message {
}

func MapModelConversationToProto(conversation *models.Conversation) *chatv1.Conversation {
// Convert BSON messages back to protobuf messages
filteredMessages := lo.Map(conversation.InappChatHistory, func(msg bson.M, _ int) *chatv1.Message {
return BSONToChatMessage(msg)
// Convert BSON messages back to protobuf messages, filtering out system messages
filteredMessages := lo.FilterMap(conversation.InappChatHistory, func(msg bson.M, _ int) (*chatv1.Message, bool) {
m := BSONToChatMessage(msg)
if m == nil {
return nil, false
}
return m, m.GetPayload().GetMessageType() != &chatv1.MessagePayload_System{}
})

filteredMessages = lo.Filter(filteredMessages, func(msg *chatv1.Message, _ int) bool {
return msg.GetPayload().GetMessageType() != &chatv1.MessagePayload_System{}
})
// Get model slug: prefer new ModelSlug field, fallback to legacy LanguageModel
// modelSlug := conversation.ModelSlug
// if modelSlug == "" {
// var err error
// modelSlug, err = conversation.LanguageModel.Name()
// if err != nil {
// return nil
// }
// }

return &chatv1.Conversation{
Id: conversation.ID.Hex(),
Title: conversation.Title,
LanguageModel: chatv1.LanguageModel(conversation.LanguageModel),
Messages: filteredMessages,
// ModelSlug: &modelSlug, // TODO: when new version is ready, enable this line
Messages: filteredMessages,
}
}
27 changes: 13 additions & 14 deletions internal/libs/cfg/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,34 @@ import (
)

type Cfg struct {
OpenAIBaseURL string
OpenAIAPIKey string
JwtSigningKey string

MongoURI string
XtraMCPURI string
PDInferenceBaseURL string
PDInferenceAPIKey string
JwtSigningKey string
MongoURI string
XtraMCPURI string
}

var cfg *Cfg

func GetCfg() *Cfg {
_ = godotenv.Load()
cfg = &Cfg{
OpenAIBaseURL: openAIBaseURL(),
OpenAIAPIKey: os.Getenv("OPENAI_API_KEY"),
JwtSigningKey: os.Getenv("JWT_SIGNING_KEY"),
MongoURI: mongoURI(),
XtraMCPURI: xtraMCPURI(),
PDInferenceBaseURL: pdInferenceBaseURL(),
PDInferenceAPIKey: os.Getenv("PD_INFERENCE_API_KEY"),
JwtSigningKey: os.Getenv("JWT_SIGNING_KEY"),
MongoURI: mongoURI(),
XtraMCPURI: xtraMCPURI(),
}

return cfg
}

func openAIBaseURL() string {
val := os.Getenv("OPENAI_BASE_URL")
func pdInferenceBaseURL() string {
val := os.Getenv("PD_INFERENCE_BASE_URL")
if val != "" {
return val
}
return "https://api.openai.com/v1"
return "https://inference.paperdebugger.workers.dev/"
}

func xtraMCPURI() string {
Expand Down
Loading