diff --git a/config.yaml.full b/config.yaml.full
index df7aac0a..e29ed159 100644
--- a/config.yaml.full
+++ b/config.yaml.full
@@ -15,8 +15,8 @@ model:
     api_key:
   # [optional] for knowledgebase
   embedding:
-    name: doubao-embedding-text-240715
-    dim: 2560
+    name: doubao-embedding-vision-250615
+    dim: 2048
     api_base: https://ark.cn-beijing.volces.com/api/v3/
     api_key:
   video:
diff --git a/veadk/configs/model_configs.py b/veadk/configs/model_configs.py
index de66505f..c6878a5c 100644
--- a/veadk/configs/model_configs.py
+++ b/veadk/configs/model_configs.py
@@ -23,6 +23,9 @@
     DEFAULT_MODEL_AGENT_API_BASE,
     DEFAULT_MODEL_AGENT_NAME,
     DEFAULT_MODEL_AGENT_PROVIDER,
+    DEFAULT_MODEL_EMBEDDING_API_BASE,
+    DEFAULT_MODEL_EMBEDDING_DIM,
+    DEFAULT_MODEL_EMBEDDING_NAME,
 )
@@ -46,13 +49,13 @@ def api_key(self) -> str:
 class EmbeddingModelConfig(BaseSettings):
     model_config = SettingsConfigDict(env_prefix="MODEL_EMBEDDING_")
 
-    name: str = "doubao-embedding-text-240715"
+    name: str = DEFAULT_MODEL_EMBEDDING_NAME
     """Model name for embedding."""
 
-    dim: int = 2560
+    dim: int = DEFAULT_MODEL_EMBEDDING_DIM
     """Embedding dim is different from different models."""
 
-    api_base: str = "https://ark.cn-beijing.volces.com/api/v3/"
+    api_base: str = DEFAULT_MODEL_EMBEDDING_API_BASE
     """The api base of the model for embedding."""
 
     @cached_property
diff --git a/veadk/consts.py b/veadk/consts.py
index 5eaade54..fa5c1aa6 100644
--- a/veadk/consts.py
+++ b/veadk/consts.py
@@ -75,3 +75,7 @@
 DEFAULT_NACOS_GROUP = "VEADK_GROUP"
 
 DEFAULT_NACOS_INSTANCE_NAME = "veadk"
+
+DEFAULT_MODEL_EMBEDDING_NAME = "doubao-embedding-vision-250615"
+DEFAULT_MODEL_EMBEDDING_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
+DEFAULT_MODEL_EMBEDDING_DIM = 2048
diff --git a/veadk/knowledgebase/backends/in_memory_backend.py b/veadk/knowledgebase/backends/in_memory_backend.py
index 33b8b037..83a5f04d 100644
--- a/veadk/knowledgebase/backends/in_memory_backend.py
+++ b/veadk/knowledgebase/backends/in_memory_backend.py
@@ -14,13 +14,13 @@
 
 from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
 from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig
 from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
 from veadk.knowledgebase.backends.utils import get_llama_index_splitter
+from veadk.models.ark_embedding import create_embedding_model
 
 
 class InMemoryKnowledgeBackend(BaseKnowledgebaseBackend):
@@ -39,7 +39,7 @@ class InMemoryKnowledgeBackend(BaseKnowledgebaseBackend):
     )
 
     def model_post_init(self, __context: Any) -> None:
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/knowledgebase/backends/opensearch_backend.py b/veadk/knowledgebase/backends/opensearch_backend.py
index 13abc0ab..598baf5b 100644
--- a/veadk/knowledgebase/backends/opensearch_backend.py
+++ b/veadk/knowledgebase/backends/opensearch_backend.py
@@ -21,7 +21,6 @@
     VectorStoreIndex,
 )
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
@@ -33,6 +32,7 @@
 )
 from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
 from veadk.knowledgebase.backends.utils import get_llama_index_splitter
+from veadk.models.ark_embedding import create_embedding_model
 from veadk.utils.logger import get_logger
 
 try:
@@ -112,7 +112,7 @@ def model_post_init(self, __context: Any) -> None:
             vector_store=self._vector_store
         )
 
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/knowledgebase/backends/redis_backend.py b/veadk/knowledgebase/backends/redis_backend.py
index 7867562e..cf50aaa3 100644
--- a/veadk/knowledgebase/backends/redis_backend.py
+++ b/veadk/knowledgebase/backends/redis_backend.py
@@ -19,7 +19,6 @@
     VectorStoreIndex,
 )
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
@@ -28,6 +27,7 @@
 from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig
 from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
 from veadk.knowledgebase.backends.utils import get_llama_index_splitter
+from veadk.models.ark_embedding import create_embedding_model
 
 try:
     from llama_index.vector_stores.redis import RedisVectorStore
@@ -92,7 +92,7 @@ def model_post_init(self, __context: Any) -> None:
             password=self.redis_config.password,
         )
 
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/knowledgebase/backends/tos_vector_backend.py b/veadk/knowledgebase/backends/tos_vector_backend.py
index 779187c7..545e3216 100644
--- a/veadk/knowledgebase/backends/tos_vector_backend.py
+++ b/veadk/knowledgebase/backends/tos_vector_backend.py
@@ -20,7 +20,6 @@
     SimpleDirectoryReader,
 )
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
@@ -30,6 +29,7 @@
 
 from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
 from veadk.knowledgebase.backends.utils import get_llama_index_splitter
+from veadk.models.ark_embedding import create_embedding_model
 from veadk.utils.logger import get_logger
 
 logger = get_logger(__name__)
@@ -77,7 +77,7 @@ def model_post_init(self, __context: Any) -> None:
         # create_bucket and index if not exist
         self._create_index()
 
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/memory/long_term_memory_backends/in_memory_backend.py b/veadk/memory/long_term_memory_backends/in_memory_backend.py
index 8d905cd4..cf047a11 100644
--- a/veadk/memory/long_term_memory_backends/in_memory_backend.py
+++ b/veadk/memory/long_term_memory_backends/in_memory_backend.py
@@ -14,7 +14,6 @@
 
 from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
@@ -23,6 +22,7 @@
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.models.ark_embedding import create_embedding_model
 
 
 class InMemoryLTMBackend(BaseLongTermMemoryBackend):
@@ -30,7 +30,7 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
     """Embedding model configs"""
 
     def model_post_init(self, __context: Any) -> None:
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/memory/long_term_memory_backends/opensearch_backend.py b/veadk/memory/long_term_memory_backends/opensearch_backend.py
index 35231674..373d1dcb 100644
--- a/veadk/memory/long_term_memory_backends/opensearch_backend.py
+++ b/veadk/memory/long_term_memory_backends/opensearch_backend.py
@@ -16,7 +16,6 @@
 
 from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
@@ -30,6 +29,7 @@
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.models.ark_embedding import create_embedding_model
 from veadk.utils.logger import get_logger
 
 try:
@@ -55,7 +55,7 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
     """Embedding model configs"""
 
     def model_post_init(self, __context: Any) -> None:
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/memory/long_term_memory_backends/redis_backend.py b/veadk/memory/long_term_memory_backends/redis_backend.py
index c530eea2..c199f3c6 100644
--- a/veadk/memory/long_term_memory_backends/redis_backend.py
+++ b/veadk/memory/long_term_memory_backends/redis_backend.py
@@ -14,7 +14,6 @@
 
 from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
-from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
 from typing_extensions import Any, override
 
@@ -25,6 +24,7 @@
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.models.ark_embedding import create_embedding_model
 from veadk.utils.logger import get_logger
 
 try:
@@ -51,7 +51,7 @@ class RedisLTMBackend(BaseLongTermMemoryBackend):
     """Embedding model configs"""
 
     def model_post_init(self, __context: Any) -> None:
-        self._embed_model = OpenAILikeEmbedding(
+        self._embed_model = create_embedding_model(
             model_name=self.embedding_config.name,
             api_key=self.embedding_config.api_key,
             api_base=self.embedding_config.api_base,
diff --git a/veadk/models/ark_embedding.py b/veadk/models/ark_embedding.py
new file mode 100644
index 00000000..5bcd8f68
--- /dev/null
+++ b/veadk/models/ark_embedding.py
@@ -0,0 +1,302 @@
+# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import httpx
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.callbacks.base import CallbackManager
+from llama_index.embeddings.openai_like import OpenAILikeEmbedding
+from pydantic import Field, PrivateAttr
+from volcenginesdkarkruntime import Ark, AsyncArk
+
+from veadk.consts import DEFAULT_MODEL_EMBEDDING_API_BASE, DEFAULT_MODEL_EMBEDDING_NAME
+
+
+class ArkEmbeddingModel(str, Enum):
+    DOUBAO_EMBEDDING_VISION_251215 = "doubao-embedding-vision-251215"
+    DOUBAO_EMBEDDING_VISION_250615 = "doubao-embedding-vision-250615"
+
+
+class ArkEmbedding(BaseEmbedding):
+    additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional kwargs for the Ark API."
+    )
+
+    api_key: str = Field(description="The Ark API key.")
+    api_base: Optional[str] = Field(
+        default=None, description="The base URL for the Ark API."
+    )
+    max_retries: int = Field(default=10, description="Maximum number of retries.", ge=0)
+    timeout: float = Field(default=60.0, description="Timeout for each request.", ge=0)
+    default_headers: Optional[Dict[str, str]] = Field(
+        default=None, description="The default headers for API requests."
+    )
+    reuse_client: bool = Field(
+        default=True,
+        description=(
+            "Reuse the Ark client between requests. When doing anything with large "
+            "volumes of async API calls, setting this to false can improve stability."
+        ),
+    )
+    dimensions: Optional[int] = Field(
+        default=None,
+        description=(
+            "The number of dimensions of the output embedding vectors. "
+            "Works only with v3 embedding models."
+        ),
+    )
+
+    _client: Optional[Ark] = PrivateAttr()
+    _aclient: Optional[AsyncArk] = PrivateAttr()
+    _http_client: Optional[httpx.Client] = PrivateAttr()
+    _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()
+
+    def __init__(
+        self,
+        model_name: str = DEFAULT_MODEL_EMBEDDING_NAME,
+        embed_batch_size: int = 100,
+        dimensions: Optional[int] = None,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        max_retries: int = 10,
+        timeout: float = 60.0,
+        reuse_client: bool = True,
+        callback_manager: Optional[CallbackManager] = None,
+        default_headers: Optional[Dict[str, str]] = None,
+        http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
+        num_workers: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
+        additional_kwargs = additional_kwargs or {}
+        if dimensions is not None:
+            additional_kwargs["dimensions"] = dimensions
+
+        api_key, api_base = self._resolve_credentials(
+            api_key=api_key,
+            api_base=api_base,
+        )
+
+        super().__init__(
+            embed_batch_size=embed_batch_size,
+            dimensions=dimensions,
+            callback_manager=callback_manager,
+            model_name=model_name,
+            additional_kwargs=additional_kwargs,
+            api_key=api_key,
+            api_base=api_base,
+            max_retries=max_retries,
+            reuse_client=reuse_client,
+            timeout=timeout,
+            default_headers=default_headers,
+            num_workers=num_workers,
+            **kwargs,
+        )
+
+        if self.api_base is None:
+            self.api_base = DEFAULT_MODEL_EMBEDDING_API_BASE
+
+        self._client = None
+        self._aclient = None
+        self._http_client = http_client
+        self._async_http_client = async_http_client
+
+    def _resolve_credentials(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Tuple[Optional[str], Optional[str]]:
+        if api_key is None:
+            api_key = os.getenv("MODEL_EMBEDDING_API_KEY")
+        if api_key is None:
+            raise ValueError(
+                "API key must be provided or set as MODEL_EMBEDDING_API_KEY environment variable"
+            )
+
+        return api_key, api_base
+
+    def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
+        return {
+            "api_key": self.api_key,
+            "base_url": self.api_base,
+            "timeout": self.timeout,
+            "max_retries": self.max_retries,
+            "http_client": self._async_http_client if is_async else self._http_client,
+        }
+
+    def _get_client(self) -> Ark:
+        if not self.reuse_client:
+            return Ark(**self._get_credential_kwargs())
+        if self._client is None:
+            self._client = Ark(**self._get_credential_kwargs())
+        return self._client
+
+    def _get_aclient(self) -> AsyncArk:
+        if not self.reuse_client:
+            return AsyncArk(**self._get_credential_kwargs(is_async=True))
+        if self._aclient is None:
+            self._aclient = AsyncArk(**self._get_credential_kwargs(is_async=True))
+        return self._aclient
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding."""
+        client = self._get_client()
+
+        input_data = [{"type": "text", "text": query}]
+
+        response = client.multimodal_embeddings.create(
+            model=self.model_name, input=input_data, **self.additional_kwargs
+        )
+
+        return response.data.embedding
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """The asynchronous version of _get_query_embedding."""
+        aclient = self._get_aclient()
+
+        input_data = [{"type": "text", "text": query}]
+
+        response = await aclient.multimodal_embeddings.create(
+            model=self.model_name, input=input_data, **self.additional_kwargs
+        )
+
+        return response.data.embedding
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding."""
+        client = self._get_client()
+
+        input_data = [{"type": "text", "text": text}]
+
+        response = client.multimodal_embeddings.create(
+            model=self.model_name, input=input_data, **self.additional_kwargs
+        )
+
+        return response.data.embedding
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        """Asynchronously get text embedding."""
+        aclient = self._get_aclient()
+
+        input_data = [{"type": "text", "text": text}]
+
+        response = await aclient.multimodal_embeddings.create(
+            model=self.model_name, input=input_data, **self.additional_kwargs
+        )
+
+        return response.data.embedding
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """
+        Get text embeddings for multiple texts.
+
+        Simple loop implementation - Ark API requires one request per text.
+        """
+        if not texts:
+            return []
+
+        client = self._get_client()
+        results = []
+
+        for text in texts:
+            single_input = [{"type": "text", "text": text}]
+            response = client.multimodal_embeddings.create(
+                model=self.model_name, input=single_input, **self.additional_kwargs
+            )
+            results.append(response.data.embedding)
+
+        return results
+
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """
+        Asynchronously get text embeddings for multiple texts.
+
+        Simple async loop implementation - Ark API requires one request per text.
+ """ + if not texts: + return [] + + aclient = self._get_aclient() + results = [] + + for text in texts: + single_input = [{"type": "text", "text": text}] + response = await aclient.multimodal_embeddings.create( + model=self.model_name, input=single_input, **self.additional_kwargs + ) + results.append(response.data.embedding) + + return results + + def get_text_embedding(self, text: str) -> List[float]: + return self._get_text_embedding(text) + + async def aget_text_embedding(self, text: str) -> List[float]: + return await self._aget_text_embedding(text) + + def get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + return self._get_text_embeddings(texts) + + async def aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + return await self._aget_text_embeddings(texts) + + def get_query_embedding(self, query: str) -> List[float]: + return self._get_query_embedding(query) + + async def aget_query_embedding(self, query: str) -> List[float]: + return await self._aget_query_embedding(query) + + @classmethod + def class_name(cls) -> str: + return "ArkEmbedding" + + +# Independent factory function +def create_embedding_model( + model_name: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + **kwargs: Any, +) -> Union["ArkEmbedding", "OpenAILikeEmbedding"]: + """ + Factory function: smart embedding model creation by model name + + Args: + model_name: Model name + api_key: API key + api_base: API base URL + **kwargs: Other parameters + + Returns: + Suitable embedding model instance (ArkEmbedding or OpenAILikeEmbedding) + """ + # Ark supported model list + ark_models = {"doubao-embedding-vision-250615", "doubao-embedding-vision-251215"} + + # Check if it's Ark supported model + if model_name in ark_models: + return ArkEmbedding( + model_name=model_name, api_key=api_key, api_base=api_base, **kwargs + ) + else: + # Use OpenAILikeEmbedding + from llama_index.embeddings.openai_like import OpenAILikeEmbedding + + return OpenAILikeEmbedding( + model_name=model_name, api_key=api_key, api_base=api_base, **kwargs + )