diff --git a/Examples/Examples/Chat/ChatCustomGrammarExample.cs b/Examples/Examples/Chat/ChatCustomGrammarExample.cs index ca655806..a7288a6e 100644 --- a/Examples/Examples/Chat/ChatCustomGrammarExample.cs +++ b/Examples/Examples/Chat/ChatCustomGrammarExample.cs @@ -24,7 +24,7 @@ public async Task Start() await AIHub.Chat() .WithModel(Models.Local.Gemma2_2b) .WithMessage("Generate random person") - .WithInferenceParams(new InferenceParams + .WithInferenceParams(new LocalInferenceParams { Grammar = personGrammar }) diff --git a/Examples/Examples/Chat/ChatExampleOpenAi.cs b/Examples/Examples/Chat/ChatExampleOpenAi.cs index d241edc1..da50a17c 100644 --- a/Examples/Examples/Chat/ChatExampleOpenAi.cs +++ b/Examples/Examples/Chat/ChatExampleOpenAi.cs @@ -1,5 +1,6 @@ using Examples.Utils; using MaIN.Core.Hub; +using MaIN.Domain.Configuration.BackendInferenceParams; using MaIN.Domain.Models; namespace Examples.Chat; @@ -15,6 +16,14 @@ public async Task Start() await AIHub.Chat() .WithModel(Models.OpenAi.Gpt5Nano) .WithMessage("What do you consider to be the greatest invention in history?") + .WithInferenceParams(new OpenAiInferenceParams // We could override some inference params + { + ResponseFormat = "text", + AdditionalParams = new Dictionary + { + ["max_completion_tokens"] = 2137 + } + }) .CompleteAsync(interactive: true); } } diff --git a/Examples/Examples/Chat/ChatGrammarExampleGemini.cs b/Examples/Examples/Chat/ChatGrammarExampleGemini.cs index ac9234f9..d0232f2b 100644 --- a/Examples/Examples/Chat/ChatGrammarExampleGemini.cs +++ b/Examples/Examples/Chat/ChatGrammarExampleGemini.cs @@ -40,7 +40,7 @@ public async Task Start() await AIHub.Chat() .WithModel(Models.Gemini.Gemini2_5Flash) .WithMessage("Generate random person") - .WithInferenceParams(new InferenceParams + .WithInferenceParams(new LocalInferenceParams { Grammar = new Grammar(grammarValue, GrammarFormat.JSONSchema) }) diff --git a/MaIN.Core.IntegrationTests/BackendParamsTests.cs b/MaIN.Core.IntegrationTests/BackendParamsTests.cs new file mode 100644 index 00000000..cf240ff6 --- /dev/null +++ b/MaIN.Core.IntegrationTests/BackendParamsTests.cs @@ -0,0 +1,320 @@ +using MaIN.Core.Hub; +using MaIN.Domain.Configuration; +using MaIN.Domain.Entities; +using MaIN.Domain.Configuration.BackendInferenceParams; +using MaIN.Domain.Exceptions; +using MaIN.Domain.Models; +using MaIN.Domain.Models.Concrete; + +namespace MaIN.Core.IntegrationTests; + +public class BackendParamsTests : IntegrationTestBase +{ + private const string TestQuestion = "What is 2+2? Answer with just the number."; + + [SkippableFact] + public async Task OpenAi_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.OpenAi)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.OpenAi.Gpt4oMini) + .WithMessage(TestQuestion) + .WithInferenceParams(new OpenAiInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task Anthropic_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.Anthropic)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.Anthropic.ClaudeSonnet4) + .WithMessage(TestQuestion) + .WithInferenceParams(new AnthropicInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task Gemini_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.Gemini)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.Gemini.Gemini2_0Flash) + .WithMessage(TestQuestion) + .WithInferenceParams(new GeminiInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task DeepSeek_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.DeepSeek)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.DeepSeek.Reasoner) + .WithMessage(TestQuestion) + .WithInferenceParams(new DeepSeekInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task GroqCloud_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.GroqCloud)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.Groq.Llama3_1_8bInstant) + .WithMessage(TestQuestion) + .WithInferenceParams(new GroqCloudInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task Xai_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.Xai)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.Xai.Grok3Beta) + .WithMessage(TestQuestion) + .WithInferenceParams(new XaiInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task Self_Should_RespondWithParams() + { + Skip.If(!File.Exists("C:/Models/gemma2-2b.gguf"), "Local model not found at C:/Models/gemma2-2b.gguf"); + + var result = await AIHub.Chat() + .WithModel(Models.Local.Gemma2_2b) + .WithMessage(TestQuestion) + .WithInferenceParams(new LocalInferenceParams + { + Temperature = 0.3f, + ContextSize = 8192, + MaxTokens = 100, + TopK = 40, + TopP = 0.9f + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task LocalOllama_Should_RespondWithParams() + { + SkipIfOllamaNotRunning(); + + var result = await AIHub.Chat() + .WithModel(Models.Ollama.Gemma3_4b) + .WithMessage(TestQuestion) + .WithInferenceParams(new OllamaInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopK = 40, + TopP = 0.9f, + NumCtx = 2048 + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + [SkippableFact] + public async Task ClaudOllama_Should_RespondWithParams() + { + SkipIfMissingKey(LLMApiRegistry.GetEntry(BackendType.Ollama)?.ApiKeyEnvName!); + + var result = await AIHub.Chat() + .WithModel(Models.Ollama.Gemma3_4b) + .WithMessage(TestQuestion) + .WithInferenceParams(new OllamaInferenceParams + { + Temperature = 0.3f, + MaxTokens = 100, + TopK = 40, + TopP = 0.9f, + NumCtx = 2048 + }) + .CompleteAsync(); + + Assert.True(result.Done); + Assert.NotNull(result.Message); + Assert.NotEmpty(result.Message.Content); + Assert.Contains("4", result.Message.Content); + } + + // --- Params mismatch validation (no API key required) --- + + [Fact] + public async Task Self_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.Local.Gemma2_2b) + .WithMessage(TestQuestion) + .WithInferenceParams(new OpenAiInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task OpenAi_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.OpenAi.Gpt4oMini) + .WithMessage(TestQuestion) + .WithInferenceParams(new DeepSeekInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task Anthropic_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.Anthropic.ClaudeSonnet4) + .WithMessage(TestQuestion) + .WithInferenceParams(new OpenAiInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task Gemini_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.Gemini.Gemini2_0Flash) + .WithMessage(TestQuestion) + .WithInferenceParams(new AnthropicInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task DeepSeek_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.DeepSeek.Reasoner) + .WithMessage(TestQuestion) + .WithInferenceParams(new GeminiInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task GroqCloud_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.Groq.Llama3_1_8bInstant) + .WithMessage(TestQuestion) + .WithInferenceParams(new OpenAiInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task Xai_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.Xai.Grok3Beta) + .WithMessage(TestQuestion) + .WithInferenceParams(new AnthropicInferenceParams()) + .CompleteAsync()); + } + + [Fact] + public async Task Ollama_Should_ThrowWhenGivenWrongParams() + { + await Assert.ThrowsAsync(() => + AIHub.Chat() + .WithModel(Models.Ollama.Gemma3_4b) + .WithMessage(TestQuestion) + .WithInferenceParams(new DeepSeekInferenceParams()) + .CompleteAsync()); + } + + private static void SkipIfMissingKey(string envName) + { + Skip.If(string.IsNullOrEmpty(Environment.GetEnvironmentVariable(envName)), + $"{envName} environment variable not set"); + } + + private static void SkipIfOllamaNotRunning() + { + Skip.If(!Helpers.NetworkHelper.PingHost("127.0.0.1", 11434, 3), + "Ollama is not running on localhost:11434"); + } +} diff --git a/MaIN.Core.IntegrationTests/MaIN.Core.IntegrationTests.csproj b/MaIN.Core.IntegrationTests/MaIN.Core.IntegrationTests.csproj index 94edd801..460c35cf 100644 --- a/MaIN.Core.IntegrationTests/MaIN.Core.IntegrationTests.csproj +++ b/MaIN.Core.IntegrationTests/MaIN.Core.IntegrationTests.csproj @@ -15,6 +15,7 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/Releases/0.10.2.md b/Releases/0.10.2.md new file mode 100644 index 00000000..dd048c8a --- /dev/null +++ b/Releases/0.10.2.md @@ -0,0 +1,3 @@ +# 0.10.2 release + +Inference parameters are now backend-specific — each AI provider has its own typed params class where only explicitly set values are sent to the API, with an AdditionalParams dictionary for custom fields. \ No newline at end of file diff --git a/src/MaIN.Core.UnitTests/AgentContextTests.cs b/src/MaIN.Core.UnitTests/AgentContextTests.cs index a5187b4c..333205c6 100644 --- a/src/MaIN.Core.UnitTests/AgentContextTests.cs +++ b/src/MaIN.Core.UnitTests/AgentContextTests.cs @@ -139,7 +139,7 @@ public async Task CreateAsync_ShouldCallAgentServiceCreateAgent() It.IsAny(), It.IsAny(), It.IsAny(), - It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(agent); @@ -153,7 +153,7 @@ public async Task CreateAsync_ShouldCallAgentServiceCreateAgent() It.IsAny(), It.Is(f => f == true), It.Is(r => r == false), - It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); diff --git a/src/MaIN.Core/.nuspec b/src/MaIN.Core/.nuspec index 91376edf..30970341 100644 --- a/src/MaIN.Core/.nuspec +++ b/src/MaIN.Core/.nuspec @@ -2,7 +2,7 @@ MaIN.NET - 0.10.1 + 0.10.2 Wisedev Wisedev favicon.png diff --git a/src/MaIN.Core/Hub/Contexts/AgentContext.cs b/src/MaIN.Core/Hub/Contexts/AgentContext.cs index 86ce1501..0ac257ab 100644 --- a/src/MaIN.Core/Hub/Contexts/AgentContext.cs +++ b/src/MaIN.Core/Hub/Contexts/AgentContext.cs @@ -17,7 +17,7 @@ namespace MaIN.Core.Hub.Contexts; public sealed class AgentContext : IAgentBuilderEntryPoint, IAgentConfigurationBuilder, IAgentContextExecutor { private readonly IAgentService _agentService; - private InferenceParams? _inferenceParams; + private IBackendInferenceParams? _inferenceParams; private MemoryParams? _memoryParams; private bool _disableCache; private bool _ensureModelDownloaded; @@ -138,7 +138,7 @@ public IAgentConfigurationBuilder WithMcpConfig(Mcp mcpConfig) return this; } - public IAgentConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams) + public IAgentConfigurationBuilder WithInferenceParams(IBackendInferenceParams inferenceParams) { _inferenceParams = inferenceParams; return this; diff --git a/src/MaIN.Core/Hub/Contexts/ChatContext.cs b/src/MaIN.Core/Hub/Contexts/ChatContext.cs index 4eb67d04..a405c888 100644 --- a/src/MaIN.Core/Hub/Contexts/ChatContext.cs +++ b/src/MaIN.Core/Hub/Contexts/ChatContext.cs @@ -57,9 +57,9 @@ public IChatMessageBuilder EnsureModelDownloaded() return this; } - public IChatConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams) + public IChatConfigurationBuilder WithInferenceParams(IBackendInferenceParams inferenceParams) { - _chat.InterferenceParams = inferenceParams; + _chat.BackendParams = inferenceParams; return this; } diff --git a/src/MaIN.Core/Hub/Contexts/Interfaces/AgentContext/IAgentConfigurationBuilder.cs b/src/MaIN.Core/Hub/Contexts/Interfaces/AgentContext/IAgentConfigurationBuilder.cs index 27d81068..a7fe0643 100644 --- a/src/MaIN.Core/Hub/Contexts/Interfaces/AgentContext/IAgentConfigurationBuilder.cs +++ b/src/MaIN.Core/Hub/Contexts/Interfaces/AgentContext/IAgentConfigurationBuilder.cs @@ -71,10 +71,10 @@ public interface IAgentConfigurationBuilder : IAgentActions /// based on specific parameters. Inference parameters can influence various aspects of the chat, such as response length, /// temperature, and other model-specific settings. /// - /// An object that holds the parameters for inference, such as Temperature, MaxTokens, + /// An object that holds the parameters for inference, such as Temperature, MaxTokens, /// TopP, etc. These parameters control the generation behavior of the agent. /// The context instance implementing for method chaining. - IAgentConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams); + IAgentConfigurationBuilder WithInferenceParams(IBackendInferenceParams inferenceParams); /// /// Sets the memory parameters for the chat session, allowing you to customize how the AI accesses and uses its memory diff --git a/src/MaIN.Core/Hub/Contexts/Interfaces/ChatContext/IChatConfigurationBuilder.cs b/src/MaIN.Core/Hub/Contexts/Interfaces/ChatContext/IChatConfigurationBuilder.cs index fcbec9f4..bdef46f9 100644 --- a/src/MaIN.Core/Hub/Contexts/Interfaces/ChatContext/IChatConfigurationBuilder.cs +++ b/src/MaIN.Core/Hub/Contexts/Interfaces/ChatContext/IChatConfigurationBuilder.cs @@ -13,10 +13,10 @@ public interface IChatConfigurationBuilder : IChatActions /// responses based on specific parameters. Inference parameters can influence various aspects of the chat, such as response length, /// temperature, and other model-specific settings. /// - /// An object that holds the parameters for inference, such as Temperature, + /// An object that holds the parameters for inference, such as Temperature, /// MaxTokens, TopP, etc. These parameters control the generation behavior of the chat. /// The context instance implementing for method chaining. - IChatConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams); + IChatConfigurationBuilder WithInferenceParams(IBackendInferenceParams inferenceParams); /// /// Attaches external tools/functions that the model can invoke during the conversation. diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/AnthropicInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/AnthropicInferenceParams.cs new file mode 100644 index 00000000..d649a233 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/AnthropicInferenceParams.cs @@ -0,0 +1,16 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class AnthropicInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.Anthropic; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public int? TopK { get; init; } + public float? TopP { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs new file mode 100644 index 00000000..fa8d729b --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs @@ -0,0 +1,19 @@ +using MaIN.Domain.Entities; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public static class BackendParamsFactory +{ + public static IBackendInferenceParams Create(BackendType backend) => backend switch + { + BackendType.Self => new LocalInferenceParams(), + BackendType.OpenAi => new OpenAiInferenceParams(), + BackendType.DeepSeek => new DeepSeekInferenceParams(), + BackendType.GroqCloud => new GroqCloudInferenceParams(), + BackendType.Xai => new XaiInferenceParams(), + BackendType.Gemini => new GeminiInferenceParams(), + BackendType.Anthropic => new AnthropicInferenceParams(), + BackendType.Ollama => new OllamaInferenceParams(), + _ => new LocalInferenceParams() + }; +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/DeepSeekInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/DeepSeekInferenceParams.cs new file mode 100644 index 00000000..00661c59 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/DeepSeekInferenceParams.cs @@ -0,0 +1,18 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class DeepSeekInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.DeepSeek; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public float? TopP { get; init; } + public float? FrequencyPenalty { get; init; } + public float? PresencePenalty { get; init; } + public string? ResponseFormat { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/GeminiInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/GeminiInferenceParams.cs new file mode 100644 index 00000000..9147b017 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/GeminiInferenceParams.cs @@ -0,0 +1,16 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class GeminiInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.Gemini; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public float? TopP { get; init; } + public string[]? StopSequences { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/GroqCloudInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/GroqCloudInferenceParams.cs new file mode 100644 index 00000000..b7c91c88 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/GroqCloudInferenceParams.cs @@ -0,0 +1,17 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class GroqCloudInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.GroqCloud; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public float? TopP { get; init; } + public float? FrequencyPenalty { get; init; } + public string? ResponseFormat { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/OllamaInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/OllamaInferenceParams.cs new file mode 100644 index 00000000..92138e6b --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/OllamaInferenceParams.cs @@ -0,0 +1,18 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class OllamaInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.Ollama; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public int? TopK { get; init; } + public float? TopP { get; init; } + public int? NumCtx { get; init; } + public int? NumGpu { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/OpenAiInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/OpenAiInferenceParams.cs new file mode 100644 index 00000000..b69d2640 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/OpenAiInferenceParams.cs @@ -0,0 +1,18 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class OpenAiInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.OpenAi; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public float? TopP { get; init; } + public float? FrequencyPenalty { get; init; } + public float? PresencePenalty { get; init; } + public string? ResponseFormat { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/XaiInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/XaiInferenceParams.cs new file mode 100644 index 00000000..b8a7c196 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/XaiInferenceParams.cs @@ -0,0 +1,17 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class XaiInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.Xai; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public float? TopP { get; init; } + public float? FrequencyPenalty { get; init; } + public float? PresencePenalty { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Entities/Chat.cs b/src/MaIN.Domain/Entities/Chat.cs index b55eb9a8..b2cffe5d 100644 --- a/src/MaIN.Domain/Entities/Chat.cs +++ b/src/MaIN.Domain/Entities/Chat.cs @@ -1,5 +1,9 @@ using LLama.Batched; +using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.BackendInferenceParams; using MaIN.Domain.Entities.Tools; +using MaIN.Domain.Models.Abstract; +using Grammar = MaIN.Domain.Models.Grammar; namespace MaIN.Domain.Entities; @@ -11,7 +15,38 @@ public class Chat public List Messages { get; set; } = []; public ChatType Type { get; set; } = ChatType.Conversation; public bool ImageGen { get; set; } - public InferenceParams InterferenceParams { get; set; } = new(); + public IBackendInferenceParams? BackendParams { get; set; } + public LocalInferenceParams? LocalParams => BackendParams as LocalInferenceParams; + + public Grammar? InferenceGrammar + { + get => BackendParams switch + { + LocalInferenceParams p => p.Grammar, + OpenAiInferenceParams p => p.Grammar, + DeepSeekInferenceParams p => p.Grammar, + GroqCloudInferenceParams p => p.Grammar, + XaiInferenceParams p => p.Grammar, + GeminiInferenceParams p => p.Grammar, + AnthropicInferenceParams p => p.Grammar, + OllamaInferenceParams p => p.Grammar, + _ => null + }; + set + { + switch (BackendParams) + { + case LocalInferenceParams p: p.Grammar = value; break; + case OpenAiInferenceParams p: p.Grammar = value; break; + case DeepSeekInferenceParams p: p.Grammar = value; break; + case GroqCloudInferenceParams p: p.Grammar = value; break; + case XaiInferenceParams p: p.Grammar = value; break; + case GeminiInferenceParams p: p.Grammar = value; break; + case AnthropicInferenceParams p: p.Grammar = value; break; + case OllamaInferenceParams p: p.Grammar = value; break; + } + } + } public MemoryParams MemoryParams { get; set; } = new(); public ToolsConfiguration? ToolsConfiguration { get; set; } public TextToSpeechParams? TextToSpeechParams { get; set; } diff --git a/src/MaIN.Domain/Entities/IBackendInferenceParams.cs b/src/MaIN.Domain/Entities/IBackendInferenceParams.cs new file mode 100644 index 00000000..8fe00b17 --- /dev/null +++ b/src/MaIN.Domain/Entities/IBackendInferenceParams.cs @@ -0,0 +1,9 @@ +using MaIN.Domain.Configuration; + +namespace MaIN.Domain.Entities; + +public interface IBackendInferenceParams +{ + BackendType Backend { get; } + Dictionary? AdditionalParams { get; } +} diff --git a/src/MaIN.Domain/Entities/InferenceParams.cs b/src/MaIN.Domain/Entities/LocalInferenceParams.cs similarity index 76% rename from src/MaIN.Domain/Entities/InferenceParams.cs rename to src/MaIN.Domain/Entities/LocalInferenceParams.cs index 13b591f0..89f52621 100644 --- a/src/MaIN.Domain/Entities/InferenceParams.cs +++ b/src/MaIN.Domain/Entities/LocalInferenceParams.cs @@ -1,9 +1,12 @@ +using MaIN.Domain.Configuration; using Grammar = MaIN.Domain.Models.Grammar; namespace MaIN.Domain.Entities; -public class InferenceParams +public class LocalInferenceParams : IBackendInferenceParams { + public BackendType Backend => BackendType.Self; + public float Temperature { get; init; } = 0.8f; public int ContextSize { get; init; } = 1024; public int GpuLayerCount { get; init; } = 30; @@ -13,11 +16,12 @@ public class InferenceParams public bool Embeddings { get; init; } = false; public int TypeK { get; init; } = 0; public int TypeV { get; init; } = 0; - + public int TokensKeep { get; set; } public int MaxTokens { get; set; } = -1; - + public int TopK { get; init; } = 40; public float TopP { get; init; } = 0.9f; public Grammar? Grammar { get; set; } -} \ No newline at end of file + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Exceptions/InvalidBackendParamsException.cs b/src/MaIN.Domain/Exceptions/InvalidBackendParamsException.cs new file mode 100644 index 00000000..e4bae40a --- /dev/null +++ b/src/MaIN.Domain/Exceptions/InvalidBackendParamsException.cs @@ -0,0 +1,10 @@ +using System.Net; + +namespace MaIN.Domain.Exceptions; + +public class InvalidBackendParamsException(string serviceName, string expectedType, string receivedType) + : MaINCustomException($"{serviceName} service requires {expectedType}, but received {receivedType}.") +{ + public override string PublicErrorMessage => Message; + public override HttpStatusCode HttpStatusCode => HttpStatusCode.BadRequest; +} diff --git a/src/MaIN.InferPage/Components/Pages/Home.razor b/src/MaIN.InferPage/Components/Pages/Home.razor index 59397b39..36ec3e99 100644 --- a/src/MaIN.InferPage/Components/Pages/Home.razor +++ b/src/MaIN.InferPage/Components/Pages/Home.razor @@ -9,6 +9,7 @@ @using MaIN.Core.Hub.Contexts.Interfaces.ChatContext @using MaIN.Domain.Configuration @using MaIN.Domain.Entities +@using MaIN.Domain.Configuration.BackendInferenceParams @using MaIN.Domain.Exceptions @using MaIN.Domain.Models @using MaIN.Domain.Models.Abstract @@ -310,7 +311,11 @@ { ModelRegistry.RegisterOrReplace(model); var newCtx = AIHub.Chat().WithModel(model.Id); - // Preserve history on model switch; cast is safe — ChatContext implements both interfaces. + // Set backend params before adding messages. + // Cast is safe — ChatContext implements both IChatMessageBuilder and IChatConfigurationBuilder. + ((IChatConfigurationBuilder)newCtx) + .WithInferenceParams(BackendParamsFactory.Create(Utils.BackendType)); + // Preserve history on model switch. ctx = Chat.Messages.Count > 0 ? (IChatMessageBuilder)newCtx.WithMessages(Chat.Messages) : newCtx; diff --git a/src/MaIN.Infrastructure/Mappers/ChatDocumentMapper.cs b/src/MaIN.Infrastructure/Mappers/ChatDocumentMapper.cs index 2d5f70db..3f7c6d87 100644 --- a/src/MaIN.Infrastructure/Mappers/ChatDocumentMapper.cs +++ b/src/MaIN.Infrastructure/Mappers/ChatDocumentMapper.cs @@ -16,7 +16,7 @@ internal static class ChatDocumentMapper ImageGen = chat.ImageGen, ToolsConfiguration = chat.ToolsConfiguration, MemoryParams = chat.MemoryParams.ToDocument(), - InferenceParams = chat.InterferenceParams.ToDocument(), + InferenceParams = (chat.BackendParams as LocalInferenceParams)?.ToDocument(), ConvState = chat.ConversationState, Properties = chat.Properties, Interactive = chat.Interactive, @@ -35,7 +35,7 @@ internal static class ChatDocumentMapper ToolsConfiguration = chat.ToolsConfiguration, ConversationState = chat.ConvState as Conversation.State, MemoryParams = chat.MemoryParams!.ToDomain(), - InterferenceParams = chat.InferenceParams!.ToDomain(), + BackendParams = chat.InferenceParams?.ToDomain() ?? new LocalInferenceParams(), Interactive = chat.Interactive, Translate = chat.Translate, Type = Enum.Parse(chat.Type.ToString()) @@ -78,7 +78,7 @@ internal static class ChatDocumentMapper Type = llmTokenValue.Type }; - private static InferenceParamsDocument ToDocument(this InferenceParams inferenceParams) => new() + private static InferenceParamsDocument ToDocument(this LocalInferenceParams inferenceParams) => new() { Temperature = inferenceParams.Temperature, ContextSize = inferenceParams.ContextSize, @@ -96,7 +96,7 @@ internal static class ChatDocumentMapper Grammar = inferenceParams.Grammar }; - private static InferenceParams ToDomain(this InferenceParamsDocument inferenceParams) => new() + private static LocalInferenceParams ToDomain(this InferenceParamsDocument inferenceParams) => new() { Temperature = inferenceParams.Temperature, ContextSize = inferenceParams.ContextSize, diff --git a/src/MaIN.Services/Constants/ServiceConstants.cs b/src/MaIN.Services/Constants/ServiceConstants.cs index a38ee412..ead2d5ba 100644 --- a/src/MaIN.Services/Constants/ServiceConstants.cs +++ b/src/MaIN.Services/Constants/ServiceConstants.cs @@ -63,6 +63,9 @@ public static class Properties public const string DisableCacheProperty = "DisableCache"; public const string AgentIdProperty = "AgentId"; public const string MmProjNameProperty = "MmProjName"; + public const string ToolCallsProperty = "ToolCalls"; + public const string ToolCallIdProperty = "ToolCallId"; + public const string ToolNameProperty = "ToolName"; } public static class Defaults diff --git a/src/MaIN.Services/Services/Abstract/IAgentService.cs b/src/MaIN.Services/Services/Abstract/IAgentService.cs index 25c0449e..16cb248a 100644 --- a/src/MaIN.Services/Services/Abstract/IAgentService.cs +++ b/src/MaIN.Services/Services/Abstract/IAgentService.cs @@ -11,7 +11,7 @@ public interface IAgentService Task Process(Chat chat, string agentId, Knowledge? knowledge, bool translatePrompt = false, Func? callbackToken = null, Func? callbackTool = null); Task CreateAgent(Agent agent, bool flow = false, bool interactiveResponse = false, - InferenceParams? inferenceParams = null, MemoryParams? memoryParams = null, bool disableCache = false); + IBackendInferenceParams? inferenceParams = null, MemoryParams? memoryParams = null, bool disableCache = false); Task GetChatByAgent(string agentId); Task Restart(string agentId); Task> GetAgents(); diff --git a/src/MaIN.Services/Services/AgentService.cs b/src/MaIN.Services/Services/AgentService.cs index 80d6af48..9e637fa6 100644 --- a/src/MaIN.Services/Services/AgentService.cs +++ b/src/MaIN.Services/Services/AgentService.cs @@ -94,7 +94,7 @@ await notificationService.DispatchNotification( } public async Task CreateAgent(Agent agent, bool flow = false, bool interactiveResponse = false, - InferenceParams? inferenceParams = null, MemoryParams? memoryParams = null, bool disableCache = false) + IBackendInferenceParams? inferenceParams = null, MemoryParams? memoryParams = null, bool disableCache = false) { var chat = new Chat { @@ -103,7 +103,7 @@ public async Task CreateAgent(Agent agent, bool flow = false, bool intera Name = agent.Name, ImageGen = agent.Model == ImageGenService.LocalImageModels.FLUX, ToolsConfiguration = agent.ToolsConfiguration, - InterferenceParams = inferenceParams ?? new InferenceParams(), + BackendParams = inferenceParams ?? new LocalInferenceParams(), MemoryParams = memoryParams ?? new MemoryParams(), Messages = [], Interactive = interactiveResponse, diff --git a/src/MaIN.Services/Services/ChatService.cs b/src/MaIN.Services/Services/ChatService.cs index 5b1ad06e..21805d21 100644 --- a/src/MaIN.Services/Services/ChatService.cs +++ b/src/MaIN.Services/Services/ChatService.cs @@ -1,4 +1,5 @@ using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.BackendInferenceParams; using MaIN.Domain.Entities; using MaIN.Domain.Exceptions.Chats; using MaIN.Domain.Models; @@ -44,6 +45,7 @@ public async Task Completions( } var backend = model!.Backend; + chat.BackendParams ??= BackendParamsFactory.Create(backend); chat.Messages.Where(x => x.Type == MessageType.NotSet).ToList() .ForEach(x => x.Type = backend != BackendType.Self ? MessageType.CloudLLM : MessageType.LocalLLM); diff --git a/src/MaIN.Services/Services/LLMService/AnthropicService.cs b/src/MaIN.Services/Services/LLMService/AnthropicService.cs index bbf1d879..f2181c67 100644 --- a/src/MaIN.Services/Services/LLMService/AnthropicService.cs +++ b/src/MaIN.Services/Services/LLMService/AnthropicService.cs @@ -14,6 +14,7 @@ using MaIN.Domain.Exceptions; using MaIN.Domain.Models.Concrete; using MaIN.Services.Services.LLMService.Utils; +using MaIN.Domain.Configuration.BackendInferenceParams; namespace MaIN.Services.Services.LLMService; @@ -26,7 +27,6 @@ public sealed class AnthropicService( { private readonly MaINSettings _settings = settings ?? throw new ArgumentNullException(nameof(settings)); - private static readonly HashSet AnthropicImageExtensions = [".jpg", ".jpeg", ".png", ".gif", ".webp"]; private static readonly ConcurrentDictionary> SessionCache = new(); private const string CompletionsUrl = ServiceConstants.ApiUrls.AnthropicChatMessages; @@ -59,21 +59,24 @@ private void ValidateApiKey() public async Task Send(Chat chat, ChatRequestOptions options, CancellationToken cancellationToken = default) { + if (chat.BackendParams is not AnthropicInferenceParams) + { + throw new InvalidBackendParamsException(LLMApiRegistry.Anthropic.ApiName, nameof(AnthropicInferenceParams), chat.BackendParams.GetType().Name); + } + ValidateApiKey(); if (!chat.Messages.Any()) return null; - var apiKey = GetApiKey(); - var lastMessage = chat.Messages.Last(); - await ExtractImageFromFiles(lastMessage); + await ChatHelper.ExtractImageFromFiles(lastMessage); var conversation = GetOrCreateConversation(chat, options.CreateSession); var resultBuilder = new StringBuilder(); var tokens = new List(); - if (HasFiles(lastMessage)) + if (ChatHelper.HasFiles(lastMessage)) { var result = ChatHelper.ExtractMemoryOptions(lastMessage); var memoryResult = await AskMemory(chat, result, options, cancellationToken); @@ -97,7 +100,6 @@ await options.TokenCallback(new LLMTokenValue() return await ProcessWithToolsAsync( chat, conversation, - apiKey, tokens, options, cancellationToken); @@ -108,7 +110,6 @@ await options.TokenCallback(new LLMTokenValue() await ProcessStreamingChatAsync( chat, conversation, - apiKey, tokens, resultBuilder, options.TokenCallback, @@ -120,7 +121,6 @@ await ProcessStreamingChatAsync( await ProcessNonStreamingChatAsync( chat, conversation, - apiKey, resultBuilder, cancellationToken); } @@ -144,7 +144,6 @@ await notificationService.DispatchNotification( private async Task ProcessWithToolsAsync( Chat chat, List conversation, - string apiKey, List tokens, ChatRequestOptions options, CancellationToken cancellationToken) @@ -177,7 +176,6 @@ await notificationService.DispatchNotification( currentToolUses = await ProcessStreamingChatWithToolsAsync( chat, conversation, - apiKey, tokens, resultBuilder, options, @@ -188,7 +186,6 @@ await notificationService.DispatchNotification( currentToolUses = await ProcessNonStreamingChatWithToolsAsync( chat, conversation, - apiKey, resultBuilder, cancellationToken); } @@ -313,7 +310,6 @@ await notificationService.DispatchNotification( private async Task?> ProcessStreamingChatWithToolsAsync( Chat chat, List conversation, - string apiKey, List tokens, StringBuilder resultBuilder, ChatRequestOptions options, @@ -321,7 +317,7 @@ await notificationService.DispatchNotification( { var httpClient = CreateAnthropicHttpClient(); - var requestBody = BuildAnthropicRequestBody(chat, conversation, true); + var requestBody = await BuildAnthropicRequestBody(chat, conversation, true); var requestJson = JsonSerializer.Serialize(requestBody); var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); @@ -458,18 +454,17 @@ private async Task HandleApiError(HttpResponseMessage response, CancellationToke private async Task?> ProcessNonStreamingChatWithToolsAsync( Chat chat, List conversation, - string apiKey, StringBuilder resultBuilder, CancellationToken cancellationToken) { var httpClient = CreateAnthropicHttpClient(); - var requestBody = BuildAnthropicRequestBody(chat, conversation, false); + var requestBody = await BuildAnthropicRequestBody(chat, conversation, false); var requestJson = JsonSerializer.Serialize(requestBody); var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); using var response = await httpClient.PostAsync(CompletionsUrl, content, cancellationToken); - + if (!response.IsSuccessStatusCode) { await HandleApiError(response, cancellationToken); @@ -504,16 +499,31 @@ private async Task HandleApiError(HttpResponseMessage response, CancellationToke return toolUses.Any() ? toolUses : null; } - private object BuildAnthropicRequestBody(Chat chat, List conversation, bool stream) + private async Task> BuildAnthropicRequestBody(Chat chat, List conversation, bool stream) { + var anthParams = chat.BackendParams as AnthropicInferenceParams; + var requestBody = new Dictionary { ["model"] = chat.ModelId, - ["max_tokens"] = chat.InterferenceParams.MaxTokens < 0 ? 4096 : chat.InterferenceParams.MaxTokens, + ["max_tokens"] = anthParams?.MaxTokens ?? 4096, ["stream"] = stream, - ["messages"] = BuildAnthropicMessages(conversation) + ["messages"] = await ChatHelper.BuildMessagesArray(conversation, chat, ImageType.AsBase64) }; + if (anthParams != null) + { + if (anthParams.Temperature.HasValue) requestBody["temperature"] = anthParams.Temperature.Value; + if (anthParams.TopP.HasValue) requestBody["top_p"] = anthParams.TopP.Value; + if (anthParams.TopK.HasValue) requestBody["top_k"] = anthParams.TopK.Value; + } + + if (chat.BackendParams?.AdditionalParams != null) + { + foreach (var (key, value) in chat.BackendParams.AdditionalParams) + requestBody[key] = value; + } + var systemMessage = conversation.FirstOrDefault(m => m.Role.Equals("system", StringComparison.OrdinalIgnoreCase)); @@ -522,10 +532,10 @@ private object BuildAnthropicRequestBody(Chat chat, List conversati requestBody["system"] = systemContent; } - if (chat.InterferenceParams.Grammar is not null) + if (chat.InferenceGrammar is not null) { requestBody["system"] = - $"Respond only using the following grammar format: \n{chat.InterferenceParams.Grammar.Value}\n. Do not add explanations, code tags, or any extra content."; + $"Respond only using the following grammar format: \n{chat.InferenceGrammar.Value}\n. Do not add explanations, code tags, or any extra content."; } if (chat.ToolsConfiguration?.Tools != null && chat.ToolsConfiguration.Tools.Any()) @@ -541,40 +551,6 @@ private object BuildAnthropicRequestBody(Chat chat, List conversati return requestBody; } - private List BuildAnthropicMessages(List conversation) - { - var messages = new List(); - - foreach (var msg in conversation) - { - if (msg.Role.Equals("system", StringComparison.OrdinalIgnoreCase)) - continue; - - object content; - - if (msg.Content is string textContent) - { - content = textContent; - } - else if (msg.Content is List contentBlocks) - { - content = contentBlocks; - } - else - { - content = msg.Content; - } - - messages.Add(new - { - role = msg.Role, - content = content - }); - } - - return messages; - } - public async Task AskMemory(Chat chat, ChatMemoryOptions memoryOptions, ChatRequestOptions requestOptions, CancellationToken cancellationToken = default) { @@ -618,7 +594,7 @@ private List GetOrCreateConversation(Chat chat, bool createSession) conversation = new List(); } - OpenAiCompatibleService.MergeMessages(conversation, chat.Messages); + ChatHelper.MergeMessages(conversation, chat.Messages); return conversation; } @@ -630,51 +606,9 @@ private void UpdateSessionCache(string chatId, string assistantResponse, bool cr } } - private static bool HasFiles(Message message) - { - return message.Files != null && message.Files.Count > 0; - } - - private static async Task ExtractImageFromFiles(Message message) - { - if (message.Files == null || message.Files.Count == 0) - return; - - var imageFiles = message.Files - .Where(f => AnthropicImageExtensions.Contains(f.Extension.ToLowerInvariant())) - .ToList(); - - if (imageFiles.Count == 0) - return; - - var imageBytesList = new List(); - foreach (var imageFile in imageFiles) - { - if (imageFile.StreamContent != null) - { - using var ms = new MemoryStream(); - imageFile.StreamContent.Position = 0; - await imageFile.StreamContent.CopyToAsync(ms); - imageBytesList.Add(ms.ToArray()); - } - else if (imageFile.Path != null) - { - imageBytesList.Add(await File.ReadAllBytesAsync(imageFile.Path)); - } - - message.Files.Remove(imageFile); - } - - message.Images = imageBytesList; - - if (message.Files.Count == 0) - message.Files = null; - } - private async Task ProcessStreamingChatAsync( Chat chat, List conversation, - string apiKey, List tokens, StringBuilder resultBuilder, Func? tokenCallback, @@ -683,17 +617,7 @@ private async Task ProcessStreamingChatAsync( { var httpClient = CreateAnthropicHttpClient(); - var requestBody = new - { - model = chat.ModelId, - max_tokens = chat.InterferenceParams.MaxTokens < 0 ? 4096 : chat.InterferenceParams.MaxTokens, - stream = true, - system = chat.InterferenceParams.Grammar is not null - ? $"Respond only using the following grammar format: \n{chat.InterferenceParams.Grammar.Value}\n. Do not add explanations, code tags, or any extra content." - : "", - messages = await OpenAiCompatibleService.BuildMessagesArray(conversation, chat, ImageType.AsBase64) - }; - + var requestBody = await BuildAnthropicRequestBody(chat, conversation, true); var requestJson = JsonSerializer.Serialize(requestBody); var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); @@ -765,23 +689,12 @@ await notificationService.DispatchNotification( private async Task ProcessNonStreamingChatAsync( Chat chat, List conversation, - string apiKey, StringBuilder resultBuilder, CancellationToken cancellationToken) { var httpClient = CreateAnthropicHttpClient(); - var requestBody = new - { - model = chat.ModelId, - max_tokens = chat.InterferenceParams.MaxTokens < 0 ? 4096 : chat.InterferenceParams.MaxTokens, - stream = false, - system = chat.InterferenceParams.Grammar is not null - ? $"Respond only using the following grammar format: \n{chat.InterferenceParams.Grammar.Value}\n. Do not add explanations, code tags, or any extra content." - : "", - messages = await OpenAiCompatibleService.BuildMessagesArray(conversation, chat, ImageType.AsBase64) - }; - + var requestBody = await BuildAnthropicRequestBody(chat, conversation, false); var requestJson = JsonSerializer.Serialize(requestBody); var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); diff --git a/src/MaIN.Services/Services/LLMService/DeepSeekService.cs b/src/MaIN.Services/Services/LLMService/DeepSeekService.cs index 4c5f303a..2c70d92d 100644 --- a/src/MaIN.Services/Services/LLMService/DeepSeekService.cs +++ b/src/MaIN.Services/Services/LLMService/DeepSeekService.cs @@ -11,6 +11,7 @@ using System.Text.Json.Serialization; using MaIN.Domain.Exceptions; using MaIN.Domain.Models.Concrete; +using MaIN.Domain.Configuration.BackendInferenceParams; namespace MaIN.Services.Services.LLMService; @@ -31,6 +32,7 @@ public sealed class DeepSeekService( protected override string HttpClientName => ServiceConstants.HttpClients.DeepSeekClient; protected override string ChatCompletionsUrl => ServiceConstants.ApiUrls.DeepSeekOpenAiChatCompletions; protected override string ModelsUrl => ServiceConstants.ApiUrls.DeepSeekModels; + protected override Type ExpectedParamsType => typeof(DeepSeekInferenceParams); protected override string GetApiKey() { @@ -49,6 +51,17 @@ protected override void ValidateApiKey() } } + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams is not DeepSeekInferenceParams p) return; + if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value; + if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value; + if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value; + if (p.FrequencyPenalty.HasValue) requestBody["frequency_penalty"] = p.FrequencyPenalty.Value; + if (p.PresencePenalty.HasValue) requestBody["presence_penalty"] = p.PresencePenalty.Value; + if (p.ResponseFormat != null) requestBody["response_format"] = new { type = p.ResponseFormat }; + } + public override async Task AskMemory( Chat chat, ChatMemoryOptions memoryOptions, diff --git a/src/MaIN.Services/Services/LLMService/GeminiService.cs b/src/MaIN.Services/Services/LLMService/GeminiService.cs index 630d76ed..b8cb6ce5 100644 --- a/src/MaIN.Services/Services/LLMService/GeminiService.cs +++ b/src/MaIN.Services/Services/LLMService/GeminiService.cs @@ -13,6 +13,7 @@ using MaIN.Domain.Models; using MaIN.Domain.Models.Concrete; using MaIN.Services.Utils; +using MaIN.Domain.Configuration.BackendInferenceParams; namespace MaIN.Services.Services.LLMService; @@ -23,9 +24,7 @@ public sealed class GeminiService( IMemoryFactory memoryFactory, IMemoryService memoryService, ILogger? logger = null) -#pragma warning disable CS9107 // Parameter is captured into the state of the enclosing type and its value is also passed to the base constructor. The value might be captured by the base class as well. : OpenAiCompatibleService(notificationService, httpClientFactory, memoryFactory, memoryService, logger) -#pragma warning restore CS9107 // Parameter is captured into the state of the enclosing type and its value is also passed to the base constructor. The value might be captured by the base class as well. { private readonly MaINSettings _settings = settings ?? throw new ArgumentNullException(nameof(settings)); @@ -37,6 +36,7 @@ public sealed class GeminiService( protected override string HttpClientName => ServiceConstants.HttpClients.GeminiClient; protected override string ChatCompletionsUrl => ServiceConstants.ApiUrls.GeminiOpenAiChatCompletions; + protected override Type ExpectedParamsType => typeof(GeminiInferenceParams); public override async Task GetCurrentModels() { @@ -75,6 +75,15 @@ protected override void ValidateApiKey() } } + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams is not GeminiInferenceParams p) return; + if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value; + if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value; + if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value; + if (p.StopSequences is { Length: > 0 }) requestBody["stop"] = p.StopSequences; + } + public override async Task AskMemory( Chat chat, ChatMemoryOptions memoryOptions, diff --git a/src/MaIN.Services/Services/LLMService/GroqCloudService.cs b/src/MaIN.Services/Services/LLMService/GroqCloudService.cs index bab7221e..a19b8116 100644 --- a/src/MaIN.Services/Services/LLMService/GroqCloudService.cs +++ b/src/MaIN.Services/Services/LLMService/GroqCloudService.cs @@ -8,6 +8,7 @@ using MaIN.Services.Services.LLMService.Memory; using MaIN.Services.Constants; using MaIN.Services.Services.Models; +using MaIN.Domain.Configuration.BackendInferenceParams; namespace MaIN.Services.Services.LLMService; @@ -25,6 +26,7 @@ public sealed class GroqCloudService( protected override string HttpClientName => ServiceConstants.HttpClients.GroqCloudClient; protected override string ChatCompletionsUrl => ServiceConstants.ApiUrls.GroqCloudOpenAiChatCompletions; protected override string ModelsUrl => ServiceConstants.ApiUrls.GroqCloudModels; + protected override Type ExpectedParamsType => typeof(GroqCloudInferenceParams); protected override string GetApiKey() { @@ -43,6 +45,16 @@ protected override void ValidateApiKey() } } + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams is not GroqCloudInferenceParams p) return; + if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value; + if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value; + if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value; + if (p.FrequencyPenalty.HasValue) requestBody["frequency_penalty"] = p.FrequencyPenalty.Value; + if (p.ResponseFormat != null) requestBody["response_format"] = new { type = p.ResponseFormat }; + } + public override async Task AskMemory( Chat chat, ChatMemoryOptions memoryOptions, diff --git a/src/MaIN.Services/Services/LLMService/LLMService.cs b/src/MaIN.Services/Services/LLMService/LLMService.cs index 4164d5cf..791df6b2 100644 --- a/src/MaIN.Services/Services/LLMService/LLMService.cs +++ b/src/MaIN.Services/Services/LLMService/LLMService.cs @@ -5,8 +5,9 @@ using LLama.Sampling; using MaIN.Domain.Configuration; using MaIN.Domain.Entities; -using MaIN.Domain.Entities.Tools; +using MaIN.Domain.Exceptions; using MaIN.Domain.Exceptions.Models; +using MaIN.Domain.Entities.Tools; using MaIN.Domain.Models; using MaIN.Domain.Models.Abstract; using MaIN.Services.Constants; @@ -20,7 +21,7 @@ using System.Text; using System.Text.Json; using Grammar = LLama.Sampling.Grammar; -using InferenceParams = MaIN.Domain.Entities.InferenceParams; +using LocalInferenceParams = MaIN.Domain.Entities.LocalInferenceParams; #pragma warning disable KMEXP00 namespace MaIN.Services.Services.LLMService; @@ -55,6 +56,11 @@ public LLMService( ChatRequestOptions requestOptions, CancellationToken cancellationToken = default) { + if (chat.BackendParams is not LocalInferenceParams) + { + throw new InvalidBackendParamsException("Local LLM", nameof(LocalInferenceParams), chat.BackendParams.GetType().Name); + } + if (chat.Messages.Count == 0) { return null; @@ -322,14 +328,14 @@ private ModelParams CreateModelParameters(Chat chat, string modelKey, string? cu { return new ModelParams(ResolvePath(customPath, modelKey)) { - ContextSize = (uint?)chat.InterferenceParams.ContextSize, - GpuLayerCount = chat.InterferenceParams.GpuLayerCount, - SeqMax = chat.InterferenceParams.SeqMax, - BatchSize = chat.InterferenceParams.BatchSize, - UBatchSize = chat.InterferenceParams.UBatchSize, - Embeddings = chat.InterferenceParams.Embeddings, - TypeK = (GGMLType)chat.InterferenceParams.TypeK, - TypeV = (GGMLType)chat.InterferenceParams.TypeV, + ContextSize = (uint?)chat.LocalParams!.ContextSize, + GpuLayerCount = chat.LocalParams!.GpuLayerCount, + SeqMax = chat.LocalParams!.SeqMax, + BatchSize = chat.LocalParams!.BatchSize, + UBatchSize = chat.LocalParams!.UBatchSize, + Embeddings = chat.LocalParams!.Embeddings, + TypeK = (GGMLType)chat.LocalParams!.TypeK, + TypeV = (GGMLType)chat.LocalParams!.TypeV, }; } @@ -388,7 +394,7 @@ private static void ProcessTextMessage(Conversation conversation, bool isNewConversation) { var template = new LLamaTemplate(llmModel); - var finalPrompt = ChatHelper.GetFinalPrompt(lastMsg, model, isNewConversation); + var finalPrompt = GetFinalPrompt(lastMsg, model, isNewConversation); var hasTools = chat.ToolsConfiguration?.Tools is not null && chat.ToolsConfiguration.Tools.Count != 0; @@ -464,10 +470,22 @@ private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig) var isComplete = false; var hasFailed = false; - using var sampler = LLMService.CreateSampler(chat.InterferenceParams); + using var sampler = LLMService.CreateSampler(chat.LocalParams!); var decoder = new StreamingTokenDecoder(executor.Context); - var inferenceParams = ChatHelper.CreateInferenceParams(chat, llmModel); + var inferenceParams = new InferenceParams + { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = chat.LocalParams!.Temperature, + TopK = chat.LocalParams!.TopK, + TopP = chat.LocalParams!.TopP + }, + AntiPrompts = [llmModel.Vocab.EOT?.ToString() ?? "User:"], + TokensKeep = chat.LocalParams!.TokensKeep, + MaxTokens = chat.LocalParams!.MaxTokens + }; + var maxTokens = inferenceParams.MaxTokens == -1 ? int.MaxValue : inferenceParams.MaxTokens; var reasoningModel = model as IReasoningModel; @@ -528,22 +546,22 @@ private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig) return (tokens, isComplete, hasFailed); } - private static BaseSamplingPipeline CreateSampler(InferenceParams interferenceParams) + private static BaseSamplingPipeline CreateSampler(LocalInferenceParams inferenceParams) { - return interferenceParams.Temperature == 0 + return inferenceParams.Temperature == 0 ? new GreedySamplingPipeline() { - Grammar = interferenceParams.Grammar is not null - ? new Grammar(interferenceParams.Grammar.Value, "root") + Grammar = inferenceParams.Grammar is not null + ? new Grammar(inferenceParams.Grammar.Value, "root") : null } : new DefaultSamplingPipeline() { - Temperature = interferenceParams.Temperature, - TopP = interferenceParams.TopP, - TopK = interferenceParams.TopK, - Grammar = interferenceParams.Grammar is not null - ? new Grammar(interferenceParams.Grammar.Value, "root") + Temperature = inferenceParams.Temperature, + TopP = inferenceParams.TopP, + TopK = inferenceParams.TopK, + Grammar = inferenceParams.Grammar is not null + ? new Grammar(inferenceParams.Grammar.Value, "root") : null }; } @@ -699,7 +717,7 @@ await SendNotification( } var toolCalls = parseResult.ToolCalls!; - responseMessage.Properties[ToolCallsProperty] = JsonSerializer.Serialize(toolCalls); + responseMessage.Properties[ServiceConstants.Properties.ToolCallsProperty] = JsonSerializer.Serialize(toolCalls); foreach (var toolCall in toolCalls) { @@ -752,8 +770,8 @@ await requestOptions.ToolCallback.Invoke(new ToolInvocation Type = MessageType.LocalLLM, Tool = true }; - toolMessage.Properties[ToolCallIdProperty] = toolCall.Id; - toolMessage.Properties[ToolNameProperty] = toolCall.Function.Name; + toolMessage.Properties[ServiceConstants.Properties.ToolCallIdProperty] = toolCall.Id; + toolMessage.Properties[ServiceConstants.Properties.ToolNameProperty] = toolCall.Function.Name; chat.Messages.Add(toolMessage.MarkProcessed()); } catch (Exception ex) @@ -766,8 +784,8 @@ await requestOptions.ToolCallback.Invoke(new ToolInvocation Type = MessageType.LocalLLM, Tool = true }; - toolMessage.Properties[ToolCallIdProperty] = toolCall.Id; - toolMessage.Properties[ToolNameProperty] = toolCall.Function.Name; + toolMessage.Properties[ServiceConstants.Properties.ToolCallIdProperty] = toolCall.Id; + toolMessage.Properties[ServiceConstants.Properties.ToolNameProperty] = toolCall.Function.Name; chat.Messages.Add(toolMessage.MarkProcessed()); } } @@ -805,7 +823,12 @@ await SendNotification( }; } - private const string ToolCallsProperty = "ToolCalls"; - private const string ToolCallIdProperty = "ToolCallId"; - private const string ToolNameProperty = "ToolName"; + private static string GetFinalPrompt(Message message, AIModel model, bool startSession) + { + var additionalPrompt = (model as IReasoningModel)?.AdditionalPrompt; + return startSession && additionalPrompt != null + ? $"{message.Content}{additionalPrompt}" + : message.Content; + } + } diff --git a/src/MaIN.Services/Services/LLMService/OllamaService.cs b/src/MaIN.Services/Services/LLMService/OllamaService.cs index cd59661a..978ff6c8 100644 --- a/src/MaIN.Services/Services/LLMService/OllamaService.cs +++ b/src/MaIN.Services/Services/LLMService/OllamaService.cs @@ -1,5 +1,6 @@ using System.Text; using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.BackendInferenceParams; using MaIN.Domain.Entities; using MaIN.Domain.Models.Concrete; using MaIN.Services.Constants; @@ -26,6 +27,7 @@ public sealed class OllamaService( protected override string HttpClientName => HasApiKey ? ServiceConstants.HttpClients.OllamaClient : ServiceConstants.HttpClients.OllamaLocalClient; protected override string ChatCompletionsUrl => HasApiKey ? ServiceConstants.ApiUrls.OllamaOpenAiChatCompletions : ServiceConstants.ApiUrls.OllamaLocalOpenAiChatCompletions; protected override string ModelsUrl => HasApiKey ? ServiceConstants.ApiUrls.OllamaModels : ServiceConstants.ApiUrls.OllamaLocalModels; + protected override Type ExpectedParamsType => typeof(OllamaInferenceParams); protected override string GetApiKey() { @@ -40,6 +42,22 @@ protected override void ValidateApiKey() // Cloud Ollama will fail at runtime if the key is missing } + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams is not OllamaInferenceParams p) return; + if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value; + if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value; + if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value; + if (p.TopK.HasValue) requestBody["top_k"] = p.TopK.Value; + if (p.NumCtx.HasValue || p.NumGpu.HasValue) + { + var options = new Dictionary(); + if (p.NumCtx.HasValue) options["num_ctx"] = p.NumCtx.Value; + if (p.NumGpu.HasValue) options["num_gpu"] = p.NumGpu.Value; + requestBody["options"] = options; + } + } + public override async Task AskMemory( Chat chat, ChatMemoryOptions memoryOptions, diff --git a/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs b/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs index 355e8391..ff4b641f 100644 --- a/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs +++ b/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs @@ -36,13 +36,11 @@ public abstract class OpenAiCompatibleService( private static readonly JsonSerializerOptions DefaultJsonSerializerOptions = new() { PropertyNameCaseInsensitive = true }; - private const string ToolCallsProperty = "ToolCalls"; - private const string ToolCallIdProperty = "ToolCallId"; - private const string ToolNameProperty = "ToolName"; protected abstract string GetApiKey(); protected abstract string GetApiName(); protected abstract void ValidateApiKey(); + protected abstract Type ExpectedParamsType { get; } protected virtual string HttpClientName => ServiceConstants.HttpClients.OpenAiClient; protected virtual string ChatCompletionsUrl => ServiceConstants.ApiUrls.OpenAiChatCompletions; protected virtual string ModelsUrl => ServiceConstants.ApiUrls.OpenAiModels; @@ -53,6 +51,11 @@ public abstract class OpenAiCompatibleService( ChatRequestOptions options, CancellationToken cancellationToken = default) { + if (chat.BackendParams.GetType() != ExpectedParamsType) + { + throw new InvalidBackendParamsException(GetApiName(), ExpectedParamsType.Name, chat.BackendParams.GetType().Name); + } + ValidateApiKey(); if (!chat.Messages.Any()) return null; @@ -457,7 +460,7 @@ await _notificationService.DispatchNotification( } // If there are images, use SearchAsync + regular chat with images - if (HasImages(lastMessage)) + if (ChatHelper.HasImages(lastMessage)) { var searchResult = await kernel.SearchAsync(userQuery, cancellationToken: cancellationToken); await kernel.DeleteIndexAsync(cancellationToken: cancellationToken); @@ -622,7 +625,7 @@ private List GetOrCreateConversation(Chat chat, bool createSession) conversation = new List(); } - MergeMessages(conversation, chat.Messages); + ChatHelper.MergeMessages(conversation, chat.Messages); return conversation; } @@ -682,7 +685,6 @@ private static void SetAuthorizationIfNeeded(HttpClient client, string apiKey) client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); } } - private async Task ProcessStreamingChatAsync( Chat chat, List conversation, @@ -852,10 +854,13 @@ private object BuildRequestBody(Chat chat, List conversation, bool var requestBody = new Dictionary { ["model"] = chat.ModelId, - ["messages"] = BuildMessagesArray(conversation, chat, ImageType.AsUrl).Result, + ["messages"] = ChatHelper.BuildMessagesArray(conversation, chat, ImageType.AsUrl).Result, ["stream"] = stream }; + ApplyBackendParams(requestBody, chat); + ApplyAdditionalParams(requestBody, chat); + if (chat.ToolsConfiguration?.Tools != null && chat.ToolsConfiguration.Tools.Any()) { requestBody["tools"] = chat.ToolsConfiguration.Tools.Select(t => new @@ -868,7 +873,7 @@ private object BuildRequestBody(Chat chat, List conversation, bool parameters = t.Function.Parameters } : null }).ToList(); - + if (!string.IsNullOrEmpty(chat.ToolsConfiguration.ToolChoice)) { requestBody["tool_choice"] = chat.ToolsConfiguration.ToolChoice; @@ -878,54 +883,20 @@ private object BuildRequestBody(Chat chat, List conversation, bool return requestBody; } - internal static void MergeMessages(List conversation, List messages) + protected virtual void ApplyBackendParams(Dictionary requestBody, Chat chat) { - var existing = new HashSet<(string, object)>(conversation.Select(m => (m.Role, m.Content))); - foreach (var msg in messages) + } + + private static void ApplyAdditionalParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams?.AdditionalParams == null) return; + foreach (var (key, value) in chat.BackendParams.AdditionalParams) { - var role = msg.Role.ToLowerInvariant(); - - if (HasImages(msg)) - { - var simplifiedContent = $"{msg.Content} [Contains image]"; - if (!existing.Contains((role, simplifiedContent))) - { - var chatMessage = new ChatMessage(role, msg.Content); - chatMessage.OriginalMessage = msg; - conversation.Add(chatMessage); - existing.Add((role, simplifiedContent)); - } - } - else - { - if (!existing.Contains((role, msg.Content))) - { - var chatMessage = new ChatMessage(role, msg.Content); - - // Extract tool-related data from Properties - if (msg.Tool && msg.Properties.ContainsKey(ToolCallsProperty)) - { - var toolCallsJson = msg.Properties[ToolCallsProperty]; - chatMessage.ToolCalls = JsonSerializer.Deserialize>(toolCallsJson); - } - - if (msg.Properties.ContainsKey(ToolCallIdProperty)) - { - chatMessage.ToolCallId = msg.Properties[ToolCallIdProperty]; - } - - if (msg.Properties.ContainsKey(ToolNameProperty)) - { - chatMessage.Name = msg.Properties[ToolNameProperty]; - } - - conversation.Add(chatMessage); - existing.Add((role, msg.Content)); - } - } + requestBody[key] = value; } } + protected static ChatResult CreateChatResult(Chat chat, string content, List tokens) { return new ChatResult @@ -943,59 +914,6 @@ protected static ChatResult CreateChatResult(Chat chat, string content, List BuildMessagesArray(List conversation, Chat chat, ImageType imageType) - { - var messages = new List(); - - foreach (var msg in conversation) - { - var content = msg.OriginalMessage != null ? BuildMessageContent(msg.OriginalMessage, imageType) : msg.Content; - if (chat.InterferenceParams.Grammar != null && msg.Role == "user") - { - var jsonGrammarConverter = new GrammarToJsonConverter(); - string jsonGrammar = jsonGrammarConverter.ConvertToJson(chat.InterferenceParams.Grammar); - - var grammarInstruction = $" | Respond only using the following JSON format: \n{jsonGrammar}\n. Do not add explanations, code tags, or any extra content."; - - if (content is string textContent) - { - content = textContent + grammarInstruction; - } - else if (content is List contentParts) - { - var modifiedParts = contentParts.ToList(); - modifiedParts.Add(new { type = "text", text = grammarInstruction }); - content = modifiedParts; - } - } - - var messageObj = new Dictionary - { - ["role"] = msg.Role, - ["content"] = content ?? string.Empty - }; - - if (msg.ToolCalls != null && msg.ToolCalls.Any()) - { - messageObj["tool_calls"] = msg.ToolCalls; - } - - if (!string.IsNullOrEmpty(msg.ToolCallId)) - { - messageObj["tool_call_id"] = msg.ToolCallId; - - if (!string.IsNullOrEmpty(msg.Name)) - { - messageObj["name"] = msg.Name; - } - } - - messages.Add(messageObj); - } - - return messages.ToArray(); - } - private static async Task InvokeTokenCallbackAsync(Func? callback, LLMTokenValue token) { if (callback != null) @@ -1003,112 +921,6 @@ private static async Task InvokeTokenCallbackAsync(Func? ca await callback.Invoke(token); } } - - private static bool HasImages(Message message) - { - return message.Images?.Count > 0; - } - - private static object BuildMessageContent(Message message, ImageType imageType) - { - if (!HasImages(message)) - { - return message.Content; - } - - var contentParts = new List(); - - if (!string.IsNullOrEmpty(message.Content)) - { - contentParts.Add(new - { - type = "text", - text = message.Content - }); - } - - foreach (var imageBytes in message.Images!) - { - var base64Data = Convert.ToBase64String(imageBytes); - var mimeType = DetectImageMimeType(imageBytes); - - switch (imageType) - { - case ImageType.AsUrl: - contentParts.Add(new - { - type = "image_url", - image_url = new - { - url = $"data:{mimeType};base64,{base64Data}", - detail = "auto" - } - }); - break; - case ImageType.AsBase64: - contentParts.Add(new - { - type = "image", - source = new - { - data = base64Data, - media_type = mimeType, - type = "base64" - } - }); - break; - } - } - - return contentParts; - } - - private static string DetectImageMimeType(byte[] imageBytes) - { - if (imageBytes.Length < 4) - return "image/jpeg"; - - if (imageBytes[0] == 0xFF && imageBytes[1] == 0xD8) - return "image/jpeg"; - - if (imageBytes.Length >= 8 && - imageBytes[0] == 0x89 && imageBytes[1] == 0x50 && - imageBytes[2] == 0x4E && imageBytes[3] == 0x47) - return "image/png"; - - if (imageBytes.Length >= 6 && - imageBytes[0] == 0x47 && imageBytes[1] == 0x49 && - imageBytes[2] == 0x46 && imageBytes[3] == 0x38) - return "image/gif"; - - if (imageBytes.Length >= 12 && - imageBytes[0] == 0x52 && imageBytes[1] == 0x49 && - imageBytes[2] == 0x46 && imageBytes[3] == 0x46 && - imageBytes[8] == 0x57 && imageBytes[9] == 0x45 && - imageBytes[10] == 0x42 && imageBytes[11] == 0x50) - return "image/webp"; - - // HEIC/HEIF format (iPhone photos) - if (imageBytes.Length >= 12 && - imageBytes[4] == 0x66 && imageBytes[5] == 0x74 && - imageBytes[6] == 0x79 && imageBytes[7] == 0x70) - { - // Check for heic/heif brands - if ((imageBytes[8] == 0x68 && imageBytes[9] == 0x65 && imageBytes[10] == 0x69 && imageBytes[11] == 0x63) || - (imageBytes[8] == 0x68 && imageBytes[9] == 0x65 && imageBytes[10] == 0x69 && imageBytes[11] == 0x66)) - return "image/heic"; - } - - // AVIF format - if (imageBytes.Length >= 12 && - imageBytes[4] == 0x66 && imageBytes[5] == 0x74 && - imageBytes[6] == 0x79 && imageBytes[7] == 0x70 && - imageBytes[8] == 0x61 && imageBytes[9] == 0x76 && - imageBytes[10] == 0x69 && imageBytes[11] == 0x66) - return "image/avif"; - - return "image/jpeg"; - } } internal class ChatMessage diff --git a/src/MaIN.Services/Services/LLMService/OpenAiService.cs b/src/MaIN.Services/Services/LLMService/OpenAiService.cs index 203d52b1..1848d58c 100644 --- a/src/MaIN.Services/Services/LLMService/OpenAiService.cs +++ b/src/MaIN.Services/Services/LLMService/OpenAiService.cs @@ -1,9 +1,11 @@ using MaIN.Domain.Configuration; +using MaIN.Domain.Entities; using MaIN.Domain.Exceptions; using MaIN.Domain.Models.Concrete; using MaIN.Services.Services.Abstract; using Microsoft.Extensions.Logging; using MaIN.Services.Services.LLMService.Memory; +using MaIN.Domain.Configuration.BackendInferenceParams; namespace MaIN.Services.Services.LLMService; @@ -17,7 +19,9 @@ public sealed class OpenAiService( : OpenAiCompatibleService(notificationService, httpClientFactory, memoryFactory, memoryService, logger) { private readonly MaINSettings _settings = settings ?? throw new ArgumentNullException(nameof(settings)); - + + protected override Type ExpectedParamsType => typeof(OpenAiInferenceParams); + protected override string GetApiKey() { return _settings.OpenAiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.OpenAi.ApiKeyEnvName) ?? @@ -34,6 +38,17 @@ protected override void ValidateApiKey() } } + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams is not OpenAiInferenceParams p) return; + if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value; + if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value; + if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value; + if (p.FrequencyPenalty.HasValue) requestBody["frequency_penalty"] = p.FrequencyPenalty.Value; + if (p.PresencePenalty.HasValue) requestBody["presence_penalty"] = p.PresencePenalty.Value; + if (p.ResponseFormat != null) requestBody["response_format"] = new { type = p.ResponseFormat }; + } + public override async Task GetCurrentModels() { var allModels = await base.GetCurrentModels(); diff --git a/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs b/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs index 7c33858c..d9c08bd9 100644 --- a/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs +++ b/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs @@ -1,9 +1,8 @@ -using LLama; -using LLama.Sampling; +using System.Text.Json; using MaIN.Domain.Entities; -using MaIN.Domain.Models.Abstract; +using MaIN.Domain.Entities.Tools; using MaIN.Services.Constants; -using InferenceParams = LLama.Common.InferenceParams; +using MaIN.Services.Utils; namespace MaIN.Services.Services.LLMService.Utils; @@ -59,36 +58,6 @@ public static async Task ExtractImageFromFiles(Message message) } - /// - /// Generates final prompt including additional prompt if needed - /// - public static string GetFinalPrompt(Message message, AIModel model, bool startSession) - { - var additionalPrompt = (model as IReasoningModel)?.AdditionalPrompt; - return startSession && additionalPrompt != null - ? $"{message.Content}{additionalPrompt}" - : message.Content; - } - - /// - /// Creates inference parameters for a chat - /// - public static InferenceParams CreateInferenceParams(Chat chat, LLamaWeights model) - { - return new InferenceParams - { - SamplingPipeline = new DefaultSamplingPipeline - { - Temperature = chat.InterferenceParams.Temperature, - TopK = chat.InterferenceParams.TopK, - TopP = chat.InterferenceParams.TopP - }, - AntiPrompts = [model.Vocab.EOT?.ToString() ?? "User:"], - TokensKeep = chat.InterferenceParams.TokensKeep, - MaxTokens = chat.InterferenceParams.MaxTokens - }; - } - /// /// Checks if a message contains files /// @@ -97,6 +66,11 @@ public static bool HasFiles(Message message) return message.Files?.Any() ?? false; } + public static bool HasImages(Message message) + { + return message.Images?.Count > 0; + } + /// /// Extracts memory options from a message with files /// @@ -127,4 +101,200 @@ public static ChatMemoryOptions ExtractMemoryOptions(Message message) PreProcess = preProcess }; } + + /// + /// Builds an array of message objects for API requests, handling images and grammar injection. + /// + internal static void MergeMessages(List conversation, List messages) + { + var existing = new HashSet<(string, object)>(conversation.Select(m => (m.Role, m.Content))); + foreach (var msg in messages) + { + var role = msg.Role.ToLowerInvariant(); + + if (HasImages(msg)) + { + var simplifiedContent = $"{msg.Content} [Contains image]"; + if (!existing.Contains((role, simplifiedContent))) + { + var chatMessage = new ChatMessage(role, msg.Content) { OriginalMessage = msg }; + conversation.Add(chatMessage); + existing.Add((role, simplifiedContent)); + } + } + else + { + if (!existing.Contains((role, msg.Content))) + { + var chatMessage = new ChatMessage(role, msg.Content); + + if (msg.Tool && msg.Properties.ContainsKey(ServiceConstants.Properties.ToolCallsProperty)) + { + var toolCallsJson = msg.Properties[ServiceConstants.Properties.ToolCallsProperty]; + chatMessage.ToolCalls = JsonSerializer.Deserialize>(toolCallsJson); + } + + if (msg.Properties.ContainsKey(ServiceConstants.Properties.ToolCallIdProperty)) + chatMessage.ToolCallId = msg.Properties[ServiceConstants.Properties.ToolCallIdProperty]; + + if (msg.Properties.ContainsKey(ServiceConstants.Properties.ToolNameProperty)) + chatMessage.Name = msg.Properties[ServiceConstants.Properties.ToolNameProperty]; + + conversation.Add(chatMessage); + existing.Add((role, msg.Content)); + } + } + } + } + + internal static async Task BuildMessagesArray(List conversation, Chat chat, ImageType imageType) + { + var messages = new List(); + + foreach (var msg in conversation) + { + var content = msg.OriginalMessage != null ? BuildMessageContent(msg.OriginalMessage, imageType) : msg.Content; + if (chat.InferenceGrammar != null && msg.Role == "user") + { + var jsonGrammarConverter = new GrammarToJsonConverter(); + string jsonGrammar = jsonGrammarConverter.ConvertToJson(chat.InferenceGrammar); + + var grammarInstruction = $" | Respond only using the following JSON format: \n{jsonGrammar}\n. Do not add explanations, code tags, or any extra content."; + + if (content is string textContent) + { + content = textContent + grammarInstruction; + } + else if (content is List contentParts) + { + var modifiedParts = contentParts.ToList(); + modifiedParts.Add(new { type = "text", text = grammarInstruction }); + content = modifiedParts; + } + } + + var messageObj = new Dictionary + { + ["role"] = msg.Role, + ["content"] = content ?? string.Empty + }; + + if (msg.ToolCalls != null && msg.ToolCalls.Any()) + { + messageObj["tool_calls"] = msg.ToolCalls; + } + + if (!string.IsNullOrEmpty(msg.ToolCallId)) + { + messageObj["tool_call_id"] = msg.ToolCallId; + + if (!string.IsNullOrEmpty(msg.Name)) + { + messageObj["name"] = msg.Name; + } + } + + messages.Add(messageObj); + } + + return messages.ToArray(); + } + + private static object BuildMessageContent(Message message, ImageType imageType) + { + if (message.Images == null || message.Images.Count == 0) + { + return message.Content; + } + + var contentParts = new List(); + + if (!string.IsNullOrEmpty(message.Content)) + { + contentParts.Add(new + { + type = "text", + text = message.Content + }); + } + + foreach (var imageBytes in message.Images) + { + var base64Data = Convert.ToBase64String(imageBytes); + var mimeType = DetectImageMimeType(imageBytes); + + switch (imageType) + { + case ImageType.AsUrl: + contentParts.Add(new + { + type = "image_url", + image_url = new + { + url = $"data:{mimeType};base64,{base64Data}", + detail = "auto" + } + }); + break; + case ImageType.AsBase64: + contentParts.Add(new + { + type = "image", + source = new + { + data = base64Data, + media_type = mimeType, + type = "base64" + } + }); + break; + } + } + + return contentParts; + } + + private static string DetectImageMimeType(byte[] imageBytes) + { + if (imageBytes.Length < 4) + return "image/jpeg"; + + if (imageBytes[0] == 0xFF && imageBytes[1] == 0xD8) + return "image/jpeg"; + + if (imageBytes.Length >= 8 && + imageBytes[0] == 0x89 && imageBytes[1] == 0x50 && + imageBytes[2] == 0x4E && imageBytes[3] == 0x47) + return "image/png"; + + if (imageBytes.Length >= 6 && + imageBytes[0] == 0x47 && imageBytes[1] == 0x49 && + imageBytes[2] == 0x46 && imageBytes[3] == 0x38) + return "image/gif"; + + if (imageBytes.Length >= 12 && + imageBytes[0] == 0x52 && imageBytes[1] == 0x49 && + imageBytes[2] == 0x46 && imageBytes[3] == 0x46 && + imageBytes[8] == 0x57 && imageBytes[9] == 0x45 && + imageBytes[10] == 0x42 && imageBytes[11] == 0x50) + return "image/webp"; + + if (imageBytes.Length >= 12 && + imageBytes[4] == 0x66 && imageBytes[5] == 0x74 && + imageBytes[6] == 0x79 && imageBytes[7] == 0x70) + { + if ((imageBytes[8] == 0x68 && imageBytes[9] == 0x65 && imageBytes[10] == 0x69 && imageBytes[11] == 0x63) || + (imageBytes[8] == 0x68 && imageBytes[9] == 0x65 && imageBytes[10] == 0x69 && imageBytes[11] == 0x66)) + return "image/heic"; + } + + if (imageBytes.Length >= 12 && + imageBytes[4] == 0x66 && imageBytes[5] == 0x74 && + imageBytes[6] == 0x79 && imageBytes[7] == 0x70 && + imageBytes[8] == 0x61 && imageBytes[9] == 0x76 && + imageBytes[10] == 0x69 && imageBytes[11] == 0x66) + return "image/avif"; + + return "image/jpeg"; + } } \ No newline at end of file diff --git a/src/MaIN.Services/Services/LLMService/XaiService.cs b/src/MaIN.Services/Services/LLMService/XaiService.cs index 632b3e55..b1231005 100644 --- a/src/MaIN.Services/Services/LLMService/XaiService.cs +++ b/src/MaIN.Services/Services/LLMService/XaiService.cs @@ -8,6 +8,7 @@ using System.Text; using MaIN.Domain.Exceptions; using MaIN.Domain.Models.Concrete; +using MaIN.Domain.Configuration.BackendInferenceParams; namespace MaIN.Services.Services.LLMService; @@ -25,6 +26,7 @@ public sealed class XaiService( protected override string HttpClientName => ServiceConstants.HttpClients.XaiClient; protected override string ChatCompletionsUrl => ServiceConstants.ApiUrls.XaiOpenAiChatCompletions; protected override string ModelsUrl => ServiceConstants.ApiUrls.XaiModels; + protected override Type ExpectedParamsType => typeof(XaiInferenceParams); protected override string GetApiKey() { @@ -42,6 +44,16 @@ protected override void ValidateApiKey() } } + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { + if (chat.BackendParams is not XaiInferenceParams p) return; + if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value; + if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value; + if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value; + if (p.FrequencyPenalty.HasValue) requestBody["frequency_penalty"] = p.FrequencyPenalty.Value; + if (p.PresencePenalty.HasValue) requestBody["presence_penalty"] = p.PresencePenalty.Value; + } + public override async Task AskMemory( Chat chat, ChatMemoryOptions memoryOptions, diff --git a/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs b/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs index 4ea777f0..8a88fc83 100644 --- a/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs +++ b/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs @@ -78,7 +78,7 @@ private async Task ShouldUseKnowledge(Knowledge? knowledge, Chat chat, Bac var indexAsKnowledge = knowledge?.Index.Items.ToDictionary(x => x.Name, x => x.Tags); var index = JsonSerializer.Serialize(indexAsKnowledge, _jsonOptions); - chat.InterferenceParams.Grammar = new Grammar(ServiceConstants.Grammars.DecisionGrammar, GrammarFormat.GBNF); + chat.InferenceGrammar = new Grammar(ServiceConstants.Grammars.DecisionGrammar, GrammarFormat.GBNF); chat.Messages.Last().Content = $""" KNOWLEDGE: @@ -98,7 +98,7 @@ private async Task ShouldUseKnowledge(Knowledge? knowledge, Chat chat, Bac }); var decision = JsonSerializer.Deserialize(result!.Message.Content, _jsonOptions); var decisionValue = decision.GetProperty("decision").GetRawText(); - chat.InterferenceParams.Grammar = null; + chat.InferenceGrammar = null; var shouldUseKnowledge = bool.Parse(decisionValue.Trim('"')); chat.Messages.Last().Content = originalContent; return shouldUseKnowledge; @@ -110,7 +110,7 @@ private async Task ShouldUseKnowledge(Knowledge? knowledge, Chat chat, Bac var indexAsKnowledge = knowledge?.Index.Items.ToDictionary(x => x.Name, x => x.Tags); var index = JsonSerializer.Serialize(indexAsKnowledge, _jsonOptions); - chat.InterferenceParams.Grammar = new Grammar(ServiceConstants.Grammars.KnowledgeGrammar, GrammarFormat.GBNF); + chat.InferenceGrammar = new Grammar(ServiceConstants.Grammars.KnowledgeGrammar, GrammarFormat.GBNF); chat.Messages.Last().Content = $""" KNOWLEDGE: diff --git a/src/MaIN.Services/Services/Steps/FechDataStepHandler.cs b/src/MaIN.Services/Services/Steps/FechDataStepHandler.cs index bd74e401..bd5cca41 100644 --- a/src/MaIN.Services/Services/Steps/FechDataStepHandler.cs +++ b/src/MaIN.Services/Services/Steps/FechDataStepHandler.cs @@ -73,7 +73,7 @@ private static Chat CreateMemoryChat(StepContext context, string? filterVal) ModelId = context.Chat.ModelId, Properties = context.Chat.Properties, MemoryParams = context.Chat.MemoryParams, - InterferenceParams = context.Chat.InterferenceParams, + BackendParams = context.Chat.BackendParams, Name = "Memory Chat", Id = Guid.NewGuid().ToString() };