# --- python/semantic_kernel/connectors/ai/modelslab/__init__.py ---
# Copyright (c) Microsoft. All rights reserved.

"""ModelsLab connector for Semantic Kernel.

Re-exports the ModelsLab chat-completion service and its settings so that
callers only need the package import::

    from semantic_kernel.connectors.ai.modelslab import ModelsLabChatCompletion

The service speaks ModelsLab's OpenAI-compatible chat API through Semantic
Kernel's standard ``ChatCompletionClientBase`` interface.
"""

from semantic_kernel.connectors.ai.modelslab.modelslab_chat_completion import ModelsLabChatCompletion
from semantic_kernel.connectors.ai.modelslab.modelslab_settings import (
    MODELSLAB_CHAT_BASE_URL,
    MODELSLAB_CHAT_MODELS,
    MODELSLAB_DEFAULT_CHAT_MODEL,
    ModelsLabSettings,
)

# Public API of the package, in the same order as the original declaration.
__all__ = [
    "ModelsLabChatCompletion",
    "ModelsLabSettings",
    "MODELSLAB_CHAT_BASE_URL",
    "MODELSLAB_CHAT_MODELS",
    "MODELSLAB_DEFAULT_CHAT_MODEL",
]
# --- python/semantic_kernel/connectors/ai/modelslab/modelslab_chat_completion.py ---
# Copyright (c) Microsoft. All rights reserved.

import logging
import os
from collections.abc import Mapping
from typing import Any

from openai import AsyncOpenAI

from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError

from .modelslab_settings import (
    MODELSLAB_CHAT_BASE_URL,
    MODELSLAB_CHAT_MODELS,
    MODELSLAB_DEFAULT_CHAT_MODEL,
)

logger: logging.Logger = logging.getLogger(__name__)


def _read_env_file(path: str, encoding: str = "utf-8") -> dict[str, str]:
    """Parse ``KEY=VALUE`` lines from a .env file into a dict.

    Blank lines, ``#`` comments, and lines without ``=`` are skipped; values
    may be wrapped in single or double quotes, which are stripped. A missing
    or unreadable file yields an empty dict (logged, not raised), matching
    the best-effort semantics of plain env-var lookups.
    """
    values: dict[str, str] = {}
    try:
        with open(path, encoding=encoding) as fh:
            for raw in fh:
                line = raw.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, _, value = line.partition("=")
                values[key.strip()] = value.strip().strip("'\"")
    except OSError:
        logger.warning("Could not read env file '%s'; ignoring it.", path)
    return values


class ModelsLabChatCompletion(OpenAIChatCompletion):
    """ModelsLab Chat Completion connector for Semantic Kernel.

    ModelsLab's chat API is fully compatible with the OpenAI Chat Completions
    spec, so this class simply subclasses :class:`OpenAIChatCompletion` and
    wires a custom ``AsyncOpenAI`` client that points at the ModelsLab
    endpoint — no additional request/response translation is needed.

    Supported models
    ----------------
    - ``llama-3.1-8b-uncensored`` (128 K context, default)
    - ``llama-3.1-70b-uncensored`` (128 K context)

    Environment variables
    ---------------------
    ``MODELSLAB_API_KEY``
        ModelsLab API key (required when ``api_key`` is not passed
        explicitly).
    ``MODELSLAB_CHAT_MODEL_ID``
        Model ID override (optional).
    ``MODELSLAB_CHAT_BASE_URL``
        API base URL override (optional, defaults to the official endpoint).

    A ``.env`` file given via ``env_file_path`` is consulted after process
    environment variables for the same three keys.
    """

    def __init__(
        self,
        ai_model_id: str | None = None,
        service_id: str | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        default_headers: Mapping[str, str] | None = None,
        async_client: AsyncOpenAI | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a ModelsLabChatCompletion service.

        Args:
            ai_model_id: ModelsLab model identifier. Falls back to the
                ``MODELSLAB_CHAT_MODEL_ID`` env var / .env entry, then
                ``MODELSLAB_DEFAULT_CHAT_MODEL``.
            service_id: Optional service ID used when multiple chat services
                are registered with the same Kernel.
            api_key: ModelsLab API key. Falls back to ``MODELSLAB_API_KEY``
                env var / .env entry.
            base_url: Override of the default ModelsLab chat endpoint. Falls
                back to ``MODELSLAB_CHAT_BASE_URL``, then the official URL.
            default_headers: Extra HTTP headers to attach to every request.
            async_client: Bring-your-own ``AsyncOpenAI`` client (skips all
                key/URL resolution when provided).
            env_file_path: Path to a ``.env`` file for reading configuration.
                (Fix: this parameter was previously accepted and documented
                but silently ignored.)
            env_file_encoding: Encoding of the ``.env`` file
                (default: ``"utf-8"``).

        Raises:
            ServiceInitializationError: if no API key can be resolved and no
                ``async_client`` was supplied.
        """
        # .env values rank below explicit arguments and process env vars.
        env_file: dict[str, str] = (
            _read_env_file(env_file_path, env_file_encoding or "utf-8") if env_file_path else {}
        )

        resolved_api_key = (
            api_key
            or os.environ.get("MODELSLAB_API_KEY")
            or env_file.get("MODELSLAB_API_KEY")
        )
        resolved_model = (
            ai_model_id
            or os.environ.get("MODELSLAB_CHAT_MODEL_ID")
            or env_file.get("MODELSLAB_CHAT_MODEL_ID")
            or MODELSLAB_DEFAULT_CHAT_MODEL
        )
        resolved_base_url = (
            base_url
            or os.environ.get("MODELSLAB_CHAT_BASE_URL")
            or env_file.get("MODELSLAB_CHAT_BASE_URL")
            or MODELSLAB_CHAT_BASE_URL
        )

        if not async_client and not resolved_api_key:
            raise ServiceInitializationError(
                "ModelsLab API key is required. Pass it via the `api_key` "
                "argument or set the MODELSLAB_API_KEY environment variable. "
                "Get your key at https://modelslab.com/api-keys"
            )

        if resolved_model not in MODELSLAB_CHAT_MODELS:
            # Soft check only: new models may exist server-side before this
            # list is updated, so warn instead of failing hard.
            logger.warning(
                "Model '%s' is not in the known ModelsLab chat model list %s. "
                "Proceeding anyway — the API will reject unsupported models.",
                resolved_model,
                MODELSLAB_CHAT_MODELS,
            )

        # Build the OpenAI-compatible client pointed at ModelsLab.
        if async_client is None:
            async_client = AsyncOpenAI(
                api_key=resolved_api_key,
                base_url=resolved_base_url,
            )

        # Delegate entirely to OpenAIChatCompletion — no extra wiring needed.
        super().__init__(
            ai_model_id=resolved_model,
            service_id=service_id,
            default_headers=default_headers,
            async_client=async_client,
        )

        logger.info(
            "ModelsLabChatCompletion initialised (model=%s, endpoint=%s)",
            resolved_model,
            resolved_base_url,
        )

    @classmethod
    def from_dict(cls, settings: dict[str, Any]) -> "ModelsLabChatCompletion":
        """Construct a ``ModelsLabChatCompletion`` from a settings dict.

        Recognised keys (all optional): ``ai_model_id``, ``service_id``,
        ``api_key``, ``base_url``, ``default_headers``. Missing keys fall
        back to the same env-var resolution performed by ``__init__``.
        """
        return cls(
            ai_model_id=settings.get("ai_model_id"),
            service_id=settings.get("service_id"),
            api_key=settings.get("api_key"),
            base_url=settings.get("base_url"),
            default_headers=settings.get("default_headers"),
        )


# --- python/semantic_kernel/connectors/ai/modelslab/modelslab_settings.py ---
# Copyright (c) Microsoft. All rights reserved.

import logging
from typing import ClassVar

from pydantic import SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict

logger = logging.getLogger(__name__)

MODELSLAB_CHAT_BASE_URL = "https://modelslab.com/api/uncensored-chat/v1"
MODELSLAB_API_BASE_URL = "https://modelslab.com/api/v6"

# Default models available via ModelsLab uncensored chat.
MODELSLAB_DEFAULT_CHAT_MODEL = "llama-3.1-8b-uncensored"
MODELSLAB_CHAT_MODELS: list[str] = [
    "llama-3.1-8b-uncensored",
    "llama-3.1-70b-uncensored",
]


class ModelsLabSettings(BaseSettings):
    """Settings for the ModelsLab connector.

    The settings are first loaded from environment variables with the prefix
    ``MODELSLAB_``. If they are not found, the optional .env file is loaded
    and the settings are loaded from there.

    Required:
        - api_key: ModelsLab API key (``MODELSLAB_API_KEY``)

    Optional:
        - chat_model_id: Model ID to use for chat completion
          (``MODELSLAB_CHAT_MODEL_ID``, default: "llama-3.1-8b-uncensored")
        - chat_base_url: Base URL for the ModelsLab chat API
          (``MODELSLAB_CHAT_BASE_URL``)
    """

    # pydantic-settings v2 configuration. The inner ``class Config`` used
    # before is the deprecated pydantic-v1 mechanism; SettingsConfigDict is
    # the supported way to set env_prefix/env_file/extra in v2.
    model_config = SettingsConfigDict(
        env_prefix="MODELSLAB_",
        env_file=None,
        extra="ignore",
    )

    # Retained for callers that introspect the prefix directly.
    env_prefix: ClassVar[str] = "MODELSLAB_"

    # api_key is the only value required at call sites; the others fall back
    # to the module-level defaults above.
    api_key: SecretStr | None = None
    chat_model_id: str | None = None
    chat_base_url: str | None = None
# --- python/tests/test_modelslab_chat_completion.py ---
# Copyright (c) Microsoft. All rights reserved.

"""Unit tests for ModelsLabChatCompletion."""

import logging
import os
from unittest.mock import MagicMock, patch

import pytest
from openai import AsyncOpenAI

from semantic_kernel.connectors.ai.modelslab import (
    MODELSLAB_CHAT_BASE_URL,
    MODELSLAB_CHAT_MODELS,
    MODELSLAB_DEFAULT_CHAT_MODEL,
    ModelsLabChatCompletion,
)
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import (
    OpenAIChatCompletion,
)
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_service(**kwargs) -> ModelsLabChatCompletion:
    """Return a ModelsLabChatCompletion backed by a mock AsyncOpenAI client."""
    mock_client = MagicMock(spec=AsyncOpenAI)
    return ModelsLabChatCompletion(async_client=mock_client, **kwargs)


# ---------------------------------------------------------------------------
# Inheritance
# ---------------------------------------------------------------------------

class TestInheritance:
    def test_is_open_ai_chat_completion_subclass(self):
        """ModelsLabChatCompletion must inherit from OpenAIChatCompletion."""
        assert issubclass(ModelsLabChatCompletion, OpenAIChatCompletion)

    def test_instance_is_open_ai_chat_completion(self):
        svc = _make_service()
        assert isinstance(svc, OpenAIChatCompletion)


# ---------------------------------------------------------------------------
# Initialisation — happy paths
# ---------------------------------------------------------------------------

class TestInit:
    def test_default_model(self, monkeypatch):
        # Fix: isolate from the developer's environment — a stray
        # MODELSLAB_CHAT_MODEL_ID would otherwise override the default and
        # make this test flaky across machines.
        monkeypatch.delenv("MODELSLAB_CHAT_MODEL_ID", raising=False)
        svc = _make_service()
        assert svc.ai_model_id == MODELSLAB_DEFAULT_CHAT_MODEL

    def test_custom_model(self):
        svc = _make_service(ai_model_id="llama-3.1-70b-uncensored")
        assert svc.ai_model_id == "llama-3.1-70b-uncensored"

    def test_service_id_stored(self):
        svc = _make_service(service_id="my-modelslab")
        assert svc.service_id == "my-modelslab"

    def test_service_id_defaults_to_model_id(self):
        svc = _make_service()
        # SK behaviour: service_id defaults to ai_model_id when not supplied.
        assert svc.service_id is not None

    def test_api_key_from_argument(self):
        """Passing api_key directly should not raise."""
        svc = ModelsLabChatCompletion(api_key="test-key-123")
        assert svc is not None

    def test_api_key_from_env(self, monkeypatch):
        monkeypatch.setenv("MODELSLAB_API_KEY", "env-key-xyz")
        svc = ModelsLabChatCompletion()
        assert svc is not None

    def test_custom_base_url(self):
        """Service initialises cleanly with a custom base URL."""
        svc = ModelsLabChatCompletion(
            api_key="k",
            base_url="https://custom.endpoint/v1",
        )
        assert svc is not None

    def test_base_url_from_env(self, monkeypatch):
        monkeypatch.setenv("MODELSLAB_CHAT_BASE_URL", "https://my-proxy/v1")
        svc = ModelsLabChatCompletion(api_key="k")
        assert svc is not None

    def test_model_from_env(self, monkeypatch):
        monkeypatch.setenv("MODELSLAB_CHAT_MODEL_ID", "llama-3.1-70b-uncensored")
        svc = _make_service()
        assert svc.ai_model_id == "llama-3.1-70b-uncensored"

    def test_async_client_bypasses_key_check(self):
        """When an async_client is supplied, no API key is required."""
        mock_client = MagicMock(spec=AsyncOpenAI)
        svc = ModelsLabChatCompletion(async_client=mock_client)
        assert svc is not None


# ---------------------------------------------------------------------------
# Initialisation — error paths
# ---------------------------------------------------------------------------

class TestInitErrors:
    def test_missing_api_key_raises(self, monkeypatch):
        monkeypatch.delenv("MODELSLAB_API_KEY", raising=False)
        with pytest.raises(ServiceInitializationError, match="API key"):
            ModelsLabChatCompletion()

    def test_unknown_model_logs_warning(self, caplog):
        # Fix: `logging` is now a module-level import (was a function-local
        # import) and the unused `svc` binding is gone.
        with caplog.at_level(logging.WARNING):
            ModelsLabChatCompletion(
                api_key="k",
                ai_model_id="gpt-4-not-real",
            )
        assert "not in the known" in caplog.text


# ---------------------------------------------------------------------------
# from_dict factory
# ---------------------------------------------------------------------------

class TestFromDict:
    def test_from_dict_creates_instance(self):
        svc = ModelsLabChatCompletion.from_dict({
            "ai_model_id": "llama-3.1-8b-uncensored",
            "service_id": "ml-svc",
            "api_key": "dict-key",
        })
        assert isinstance(svc, ModelsLabChatCompletion)
        assert svc.ai_model_id == "llama-3.1-8b-uncensored"

    def test_from_dict_empty_dict(self, monkeypatch):
        monkeypatch.setenv("MODELSLAB_API_KEY", "env-key")
        # Fix: also isolate the model-id env var so the default is asserted
        # deterministically.
        monkeypatch.delenv("MODELSLAB_CHAT_MODEL_ID", raising=False)
        svc = ModelsLabChatCompletion.from_dict({})
        assert svc.ai_model_id == MODELSLAB_DEFAULT_CHAT_MODEL

    def test_from_dict_custom_base_url(self):
        svc = ModelsLabChatCompletion.from_dict({
            "api_key": "k",
            "base_url": "https://proxy.example.com/v1",
        })
        assert svc is not None
class TestConstants:
    """Sanity checks on the module-level constants."""

    def test_default_model_in_model_list(self):
        assert MODELSLAB_DEFAULT_CHAT_MODEL in MODELSLAB_CHAT_MODELS

    def test_base_url_is_https(self):
        assert MODELSLAB_CHAT_BASE_URL.startswith("https://")

    def test_known_models(self):
        for model in ("llama-3.1-8b-uncensored", "llama-3.1-70b-uncensored"):
            assert model in MODELSLAB_CHAT_MODELS


class TestClientWiring:
    """The AsyncOpenAI client must be wired through to the service."""

    def test_custom_client_is_used(self):
        """The supplied AsyncOpenAI client must be wired into the service."""
        supplied = MagicMock(spec=AsyncOpenAI)
        service = ModelsLabChatCompletion(async_client=supplied)
        # OpenAIChatCompletion exposes the underlying client at `.client`.
        assert service.client is supplied

    def test_default_client_points_to_modelslab(self):
        """When no client is given, the auto-built client must use ModelsLab URL."""
        target = (
            "semantic_kernel.connectors.ai.modelslab.modelslab_chat_completion.AsyncOpenAI"
        )
        with patch(target) as mock_ctor:
            mock_ctor.return_value = MagicMock(spec=AsyncOpenAI)
            ModelsLabChatCompletion(api_key="test-key")

        ctor_kwargs = mock_ctor.call_args.kwargs
        assert ctor_kwargs["base_url"] == MODELSLAB_CHAT_BASE_URL
        assert ctor_kwargs["api_key"] == "test-key"