From 3436457ec5c8b04f5d62e9a350861632df875489 Mon Sep 17 00:00:00 2001 From: "hanzhi.421" Date: Wed, 14 Jan 2026 11:32:48 +0800 Subject: [PATCH 1/4] fix: vision embedding --- veadk/configs/model_configs.py | 3 +- veadk/consts.py | 2 + veadk/models/ark_embedding.py | 115 +++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 veadk/models/ark_embedding.py diff --git a/veadk/configs/model_configs.py b/veadk/configs/model_configs.py index de66505f..38e66378 100644 --- a/veadk/configs/model_configs.py +++ b/veadk/configs/model_configs.py @@ -23,6 +23,7 @@ DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_MODEL_AGENT_NAME, DEFAULT_MODEL_AGENT_PROVIDER, + DEFAULT_MODEL_EMBEDDING_NAME, ) @@ -46,7 +47,7 @@ def api_key(self) -> str: class EmbeddingModelConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="MODEL_EMBEDDING_") - name: str = "doubao-embedding-text-240715" + name: str = DEFAULT_MODEL_EMBEDDING_NAME """Model name for embedding.""" dim: int = 2560 diff --git a/veadk/consts.py b/veadk/consts.py index 5eaade54..dfde9b58 100644 --- a/veadk/consts.py +++ b/veadk/consts.py @@ -75,3 +75,5 @@ DEFAULT_NACOS_GROUP = "VEADK_GROUP" DEFAULT_NACOS_INSTANCE_NAME = "veadk" + +DEFAULT_MODEL_EMBEDDING_NAME = "doubao-embedding-vision-250615" diff --git a/veadk/models/ark_embedding.py b/veadk/models/ark_embedding.py new file mode 100644 index 00000000..51ffd28c --- /dev/null +++ b/veadk/models/ark_embedding.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, Optional, List + +import httpx +from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding +from llama_index.core.callbacks.base import CallbackManager +# from llama_index.embeddings.openai import OpenAIEmbedding + + +class ArkEmbedding(BaseEmbedding): + """ + OpenAI-Like class for embeddings. + + Args: + model_name (str): + Model for embedding. + api_key (str): + The API key (if any) to use for the embedding API. + api_base (str): + The base URL for the embedding API. + api_version (str): + The version for the embedding API. + max_retries (int): + The maximum number of retries for the embedding API. + timeout (float): + The timeout for the embedding API. + reuse_client (bool): + Whether to reuse the client for the embedding API. + callback_manager (CallbackManager): + The callback manager for the embedding API. + default_headers (Dict[str, str]): + The default headers for the embedding API. + additional_kwargs (Dict[str, Any]): + Additional kwargs for the embedding API. + dimensions (int): + The number of dimensions for the embedding API. + + Example: + ```bash + pip install llama-index-embeddings-openai-like + ``` + + ```python + from llama_index.embeddings.openai_like import OpenAILikeEmbedding + + embedding = ArkEmbedding( + model_name="my-model-name", + api_base="http://localhost:1234/v1", + api_key="fake", + embed_batch_size=10, + ) + ``` + + """ + + def _get_query_embedding(self, query: str) -> Embedding: + # client = self._get_client() + # retry_decorator = self._create_retry_decorator() + + pass + + async def _aget_query_embedding(self, query: str) -> Embedding: + pass + + def _get_text_embedding(self, text: str) -> Embedding: + pass + + def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: ... + + def __init__( + self, + model_name: str, + embed_batch_size: int = 10, + dimensions: Optional[int] = None, + additional_kwargs: Optional[Dict[str, Any]] = None, + api_key: str = "fake", + api_base: Optional[str] = None, + api_version: Optional[str] = None, + max_retries: int = 10, + timeout: float = 60.0, + reuse_client: bool = True, + callback_manager: Optional[CallbackManager] = None, + default_headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.Client] = None, + async_http_client: Optional[httpx.AsyncClient] = None, + num_workers: Optional[int] = None, + **kwargs: Any, + ) -> None: + # ensure model is not passed in kwargs, will cause error in parent class + if "model" in kwargs: + raise ValueError( + "Use `model_name` instead of `model` to initialize OpenAILikeEmbedding" + ) + + super().__init__( + model_name=model_name, + embed_batch_size=embed_batch_size, + dimensions=dimensions, + callback_manager=callback_manager, + additional_kwargs=additional_kwargs, + api_key=api_key, + api_base=api_base, + api_version=api_version, + max_retries=max_retries, + reuse_client=reuse_client, + timeout=timeout, + default_headers=default_headers, + http_client=http_client, + async_http_client=async_http_client, + num_workers=num_workers, + **kwargs, + ) + + @classmethod + def class_name(cls) -> str: + return "ArkEmbedding" From 43914dadafbe525312093801e27a2bb7d63b8dbb Mon Sep 17 00:00:00 2001 From: "hanzhi.421" Date: Thu, 15 Jan 2026 15:50:19 +0800 Subject: [PATCH 2/4] feat: openai like embedding to ark embedding factory --- veadk/consts.py | 1 + .../backends/in_memory_backend.py | 4 +- .../backends/opensearch_backend.py | 4 +- veadk/knowledgebase/backends/redis_backend.py | 4 +- .../backends/tos_vector_backend.py | 4 +- .../in_memory_backend.py | 4 +- .../opensearch_backend.py | 4 +- .../redis_backend.py | 4 +- veadk/models/ark_embedding.py | 320 ++++++++++++++---- 9 files changed, 265 insertions(+), 84 deletions(-) diff --git a/veadk/consts.py b/veadk/consts.py index dfde9b58..69a98e59 100644 --- a/veadk/consts.py +++ b/veadk/consts.py @@ -77,3 +77,4 @@ DEFAULT_NACOS_INSTANCE_NAME = "veadk" DEFAULT_MODEL_EMBEDDING_NAME = "doubao-embedding-vision-250615" +DEFAULT_MODEL_EMBEDDING_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" diff --git a/veadk/knowledgebase/backends/in_memory_backend.py b/veadk/knowledgebase/backends/in_memory_backend.py index 33b8b037..83a5f04d 100644 --- a/veadk/knowledgebase/backends/in_memory_backend.py +++ b/veadk/knowledgebase/backends/in_memory_backend.py @@ -14,13 +14,13 @@ from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.models.ark_embedding import create_embedding_model class InMemoryKnowledgeBackend(BaseKnowledgebaseBackend): @@ -39,7 +39,7 @@ class InMemoryKnowledgeBackend(BaseKnowledgebaseBackend): ) def model_post_init(self, __context: Any) -> None: - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/knowledgebase/backends/opensearch_backend.py b/veadk/knowledgebase/backends/opensearch_backend.py index 13abc0ab..598baf5b 100644 --- a/veadk/knowledgebase/backends/opensearch_backend.py +++ b/veadk/knowledgebase/backends/opensearch_backend.py @@ -21,7 +21,6 @@ VectorStoreIndex, ) from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override @@ -33,6 +32,7 @@ ) from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.models.ark_embedding import create_embedding_model from veadk.utils.logger import get_logger try: @@ -112,7 +112,7 @@ def model_post_init(self, __context: Any) -> None: vector_store=self._vector_store ) - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/knowledgebase/backends/redis_backend.py b/veadk/knowledgebase/backends/redis_backend.py index 7867562e..cf50aaa3 100644 --- a/veadk/knowledgebase/backends/redis_backend.py +++ b/veadk/knowledgebase/backends/redis_backend.py @@ -19,7 +19,6 @@ VectorStoreIndex, ) from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override @@ -28,6 +27,7 @@ from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.models.ark_embedding import create_embedding_model try: from llama_index.vector_stores.redis import RedisVectorStore @@ -92,7 +92,7 @@ def model_post_init(self, __context: Any) -> None: password=self.redis_config.password, ) - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/knowledgebase/backends/tos_vector_backend.py b/veadk/knowledgebase/backends/tos_vector_backend.py index 779187c7..545e3216 100644 --- a/veadk/knowledgebase/backends/tos_vector_backend.py +++ b/veadk/knowledgebase/backends/tos_vector_backend.py @@ -20,7 +20,6 @@ SimpleDirectoryReader, ) from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override @@ -30,6 +29,7 @@ from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.models.ark_embedding import create_embedding_model from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -77,7 +77,7 @@ def model_post_init(self, __context: Any) -> None: # create_bucket and index if not exist self._create_index() - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/memory/long_term_memory_backends/in_memory_backend.py b/veadk/memory/long_term_memory_backends/in_memory_backend.py index 8d905cd4..cf047a11 100644 --- a/veadk/memory/long_term_memory_backends/in_memory_backend.py +++ b/veadk/memory/long_term_memory_backends/in_memory_backend.py @@ -14,7 +14,6 @@ from llama_index.core import Document, VectorStoreIndex from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override @@ -23,6 +22,7 @@ from veadk.memory.long_term_memory_backends.base_backend import ( BaseLongTermMemoryBackend, ) +from veadk.models.ark_embedding import create_embedding_model class InMemoryLTMBackend(BaseLongTermMemoryBackend): @@ -30,7 +30,7 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend): """Embedding model configs""" def model_post_init(self, __context: Any) -> None: - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/memory/long_term_memory_backends/opensearch_backend.py b/veadk/memory/long_term_memory_backends/opensearch_backend.py index 35231674..373d1dcb 100644 --- a/veadk/memory/long_term_memory_backends/opensearch_backend.py +++ b/veadk/memory/long_term_memory_backends/opensearch_backend.py @@ -16,7 +16,6 @@ from llama_index.core import Document, VectorStoreIndex from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override @@ -30,6 +29,7 @@ from veadk.memory.long_term_memory_backends.base_backend import ( BaseLongTermMemoryBackend, ) +from veadk.models.ark_embedding import create_embedding_model from veadk.utils.logger import get_logger try: @@ -55,7 +55,7 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend): """Embedding model configs""" def model_post_init(self, __context: Any) -> None: - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/memory/long_term_memory_backends/redis_backend.py b/veadk/memory/long_term_memory_backends/redis_backend.py index c530eea2..c199f3c6 100644 --- a/veadk/memory/long_term_memory_backends/redis_backend.py +++ b/veadk/memory/long_term_memory_backends/redis_backend.py @@ -14,7 +14,6 @@ from llama_index.core import Document, VectorStoreIndex from llama_index.core.schema import BaseNode -from llama_index.embeddings.openai_like import OpenAILikeEmbedding from pydantic import Field from typing_extensions import Any, override @@ -25,6 +24,7 @@ from veadk.memory.long_term_memory_backends.base_backend import ( BaseLongTermMemoryBackend, ) +from veadk.models.ark_embedding import create_embedding_model from veadk.utils.logger import get_logger try: @@ -51,7 +51,7 @@ class RedisLTMBackend(BaseLongTermMemoryBackend): """Embedding model configs""" def model_post_init(self, __context: Any) -> None: - self._embed_model = OpenAILikeEmbedding( + self._embed_model = create_embedding_model( model_name=self.embedding_config.name, api_key=self.embedding_config.api_key, api_base=self.embedding_config.api_base, diff --git a/veadk/models/ark_embedding.py b/veadk/models/ark_embedding.py index 51ffd28c..6752cd06 100644 --- a/veadk/models/ark_embedding.py +++ b/veadk/models/ark_embedding.py @@ -1,80 +1,64 @@ -from typing import Any, Dict, Optional, List +import os +from typing import Any, Dict, Optional, List, Union, Tuple +from enum import Enum import httpx -from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding +from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.callbacks.base import CallbackManager -# from llama_index.embeddings.openai import OpenAIEmbedding +from pydantic import PrivateAttr, Field +from volcenginesdkarkruntime import Ark, AsyncArk +from veadk.consts import DEFAULT_MODEL_EMBEDDING_NAME, DEFAULT_MODEL_EMBEDDING_API_BASE +from llama_index.embeddings.openai_like import OpenAILikeEmbedding -class ArkEmbedding(BaseEmbedding): - """ - OpenAI-Like class for embeddings. - - Args: - model_name (str): - Model for embedding. - api_key (str): - The API key (if any) to use for the embedding API. - api_base (str): - The base URL for the embedding API. - api_version (str): - The version for the embedding API. - max_retries (int): - The maximum number of retries for the embedding API. - timeout (float): - The timeout for the embedding API. - reuse_client (bool): - Whether to reuse the client for the embedding API. - callback_manager (CallbackManager): - The callback manager for the embedding API. - default_headers (Dict[str, str]): - The default headers for the embedding API. - additional_kwargs (Dict[str, Any]): - Additional kwargs for the embedding API. - dimensions (int): - The number of dimensions for the embedding API. - - Example: - ```bash - pip install llama-index-embeddings-openai-like - ``` - - ```python - from llama_index.embeddings.openai_like import OpenAILikeEmbedding - embedding = ArkEmbedding( - model_name="my-model-name", - api_base="http://localhost:1234/v1", - api_key="fake", - embed_batch_size=10, - ) - ``` - - """ +class ArkEmbeddingModel(str, Enum): + DOUBAO_EMBEDDING_VISION_251215 = "doubao-embedding-vision-251215" + DOUBAO_EMBEDDING_VISION_250615 = "doubao-embedding-vision-250615" - def _get_query_embedding(self, query: str) -> Embedding: - # client = self._get_client() - # retry_decorator = self._create_retry_decorator() - pass - - async def _aget_query_embedding(self, query: str) -> Embedding: - pass +class ArkEmbedding(BaseEmbedding): + additional_kwargs: Dict[str, Any] = Field( + default_factory=dict, description="Additional kwargs for the Ark API." + ) - def _get_text_embedding(self, text: str) -> Embedding: - pass + api_key: str = Field(description="The Ark API key.") + api_base: Optional[str] = Field( + default=None, description="The base URL for Ark API." + ) + max_retries: int = Field(default=10, description="Maximum number of retries.", ge=0) + timeout: float = Field(default=60.0, description="Timeout for each request.", ge=0) + default_headers: Optional[Dict[str, str]] = Field( + default=None, description="The default headers for API requests." + ) + reuse_client: bool = Field( + default=True, + description=( + "Reuse the Ark client between requests. When doing anything with large " + "volumes of async API calls, setting this to false can improve stability." + ), + ) + dimensions: Optional[int] = Field( + default=None, + description=( + "The number of dimensions on the output embedding vectors. " + "Works only with v3 embedding models." + ), + ) - def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: ... + _client: Optional[Ark] = PrivateAttr() + _aclient: Optional[AsyncArk] = PrivateAttr() + _http_client: Optional[httpx.Client] = PrivateAttr() + _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr() def __init__( self, - model_name: str, - embed_batch_size: int = 10, + model_name: str = DEFAULT_MODEL_EMBEDDING_NAME, + embed_batch_size: int = 100, dimensions: Optional[int] = None, additional_kwargs: Optional[Dict[str, Any]] = None, - api_key: str = "fake", + api_key: Optional[str] = None, api_base: Optional[str] = None, - api_version: Optional[str] = None, max_retries: int = 10, timeout: float = 60.0, reuse_client: bool = True, @@ -85,31 +69,227 @@ def __init__( num_workers: Optional[int] = None, **kwargs: Any, ) -> None: - # ensure model is not passed in kwargs, will cause error in parent class - if "model" in kwargs: - raise ValueError( - "Use `model_name` instead of `model` to initialize OpenAILikeEmbedding" - ) + additional_kwargs = additional_kwargs or {} + if dimensions is not None: + additional_kwargs["dimensions"] = dimensions + + api_key, api_base = self._resolve_credentials( + api_key=api_key, + api_base=api_base, + ) super().__init__( - model_name=model_name, embed_batch_size=embed_batch_size, dimensions=dimensions, callback_manager=callback_manager, + model_name=model_name, additional_kwargs=additional_kwargs, api_key=api_key, api_base=api_base, - api_version=api_version, max_retries=max_retries, reuse_client=reuse_client, timeout=timeout, default_headers=default_headers, - http_client=http_client, - async_http_client=async_http_client, num_workers=num_workers, **kwargs, ) + # 设置默认值 + if self.api_base is None: + self.api_base = DEFAULT_MODEL_EMBEDDING_API_BASE + + self._client = None + self._aclient = None + self._http_client = http_client + self._async_http_client = async_http_client + + def _resolve_credentials( + self, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> Tuple[Optional[str], Optional[str]]: + if api_key is None: + api_key = os.getenv("MODEL_EMBEDDING_API_KEY") + if api_key is None: + raise ValueError( + "API key must be provided or set as MODEL_EMBEDDING_API_KEY environment variable" + ) + + return api_key, api_base + + def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]: + return { + "api_key": self.api_key, + "base_url": self.api_base, + "timeout": self.timeout, + "max_retries": self.max_retries, + "http_client": self._async_http_client if is_async else self._http_client, + } + + def _get_client(self) -> Ark: + if not self.reuse_client: + return Ark(**self._get_credential_kwargs()) + if self._client is None: + self._client = Ark(**self._get_credential_kwargs()) + return self._client + + def _get_aclient(self) -> AsyncArk: + if not self.reuse_client: + return AsyncArk(**self._get_credential_kwargs(is_async=True)) + if self._aclient is None: + self._aclient = AsyncArk(**self._get_credential_kwargs(is_async=True)) + return self._aclient + + def _get_query_embedding(self, query: str) -> List[float]: + """Get query embedding.""" + client = self._get_client() + + input_data = [{"type": "text", "text": query}] + + response = client.multimodal_embeddings.create( + model=self.model_name, input=input_data, **self.additional_kwargs + ) + + return response.data.embedding + + async def _aget_query_embedding(self, query: str) -> List[float]: + """The asynchronous version of _get_query_embedding.""" + aclient = self._get_aclient() + + input_data = [{"type": "text", "text": query}] + + response = await aclient.multimodal_embeddings.create( + model=self.model_name, input=input_data, **self.additional_kwargs + ) + + return response.data.embedding + + def _get_text_embedding(self, text: str) -> List[float]: + """Get text embedding.""" + client = self._get_client() + + input_data = [{"type": "text", "text": text}] + + response = client.multimodal_embeddings.create( + model=self.model_name, input=input_data, **self.additional_kwargs + ) + + return response.data.embedding + + async def _aget_text_embedding(self, text: str) -> List[float]: + """Asynchronously get text embedding.""" + aclient = self._get_aclient() + + input_data = [{"type": "text", "text": text}] + + response = await aclient.multimodal_embeddings.create( + model=self.model_name, input=input_data, **self.additional_kwargs + ) + + return response.data.embedding + + def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """ + Get text embeddings for multiple texts. + + Simple loop implementation - Ark API requires one request per text. + """ + if not texts: + return [] + + client = self._get_client() + results = [] + + for text in texts: + single_input = [{"type": "text", "text": text}] + response = client.multimodal_embeddings.create( + model=self.model_name, input=single_input, **self.additional_kwargs + ) + results.append(response.data.embedding) + + return results + + async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """ + Asynchronously get text embeddings for multiple texts. + + Simple async loop implementation - Ark API requires one request per text. + """ + if not texts: + return [] + + aclient = self._get_aclient() + results = [] + + for text in texts: + single_input = [{"type": "text", "text": text}] + response = await aclient.multimodal_embeddings.create( + model=self.model_name, input=single_input, **self.additional_kwargs + ) + results.append(response.data.embedding) + + return results + + def get_text_embedding(self, text: str) -> List[float]: + """公共接口:获取文本嵌入""" + return self._get_text_embedding(text) + + async def aget_text_embedding(self, text: str) -> List[float]: + """公共接口:异步获取文本嵌入""" + return await self._aget_text_embedding(text) + + def get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Public interface: batch text embedding""" + return self._get_text_embeddings(texts) + + async def aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Public interface: async batch text embedding""" + return await self._aget_text_embeddings(texts) + + def get_query_embedding(self, query: str) -> List[float]: + """Public interface: get query embedding""" + return self._get_query_embedding(query) + + async def aget_query_embedding(self, query: str) -> List[float]: + """Public interface: async get query embedding""" + return await self._aget_query_embedding(query) + + +# Independent factory function +def create_embedding_model( + model_name: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + **kwargs: Any, +) -> Union["ArkEmbedding", "OpenAILikeEmbedding"]: + """ + Factory function: smart embedding model creation by model name + + Args: + model_name: Model name + api_key: API key + api_base: API base URL + **kwargs: Other parameters + + Returns: + Suitable embedding model instance (ArkEmbedding or OpenAILikeEmbedding) + """ + # Ark supported model list + ark_models = {"doubao-embedding-vision-250615", "doubao-embedding-vision-251215"} + + # Check if it's Ark supported model + if model_name in ark_models: + return ArkEmbedding( + model_name=model_name, api_key=api_key, api_base=api_base, **kwargs + ) + else: + # Use OpenAILikeEmbedding + from llama_index.embeddings.openai_like import OpenAILikeEmbedding + + return OpenAILikeEmbedding( + model_name=model_name, api_key=api_key, api_base=api_base, **kwargs + ) + @classmethod def class_name(cls) -> str: return "ArkEmbedding" From 87acc44d961de6b94dac429dc17c6230cc111c0d Mon Sep 17 00:00:00 2001 From: "hanzhi.421" Date: Thu, 15 Jan 2026 16:46:05 +0800 Subject: [PATCH 3/4] feat: header and consts --- config.yaml.full | 4 ++-- veadk/configs/model_configs.py | 5 +++-- veadk/consts.py | 1 + veadk/models/ark_embedding.py | 22 ++++++++++++++++++---- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/config.yaml.full b/config.yaml.full index df7aac0a..e29ed159 100644 --- a/config.yaml.full +++ b/config.yaml.full @@ -15,8 +15,8 @@ model: api_key: # [optional] for knowledgebase embedding: - name: doubao-embedding-text-240715 - dim: 2560 + name: doubao-embedding-vision-250615 + dim: 2048 api_base: https://ark.cn-beijing.volces.com/api/v3/ api_key: video: diff --git a/veadk/configs/model_configs.py b/veadk/configs/model_configs.py index 38e66378..c6878a5c 100644 --- a/veadk/configs/model_configs.py +++ b/veadk/configs/model_configs.py @@ -24,6 +24,7 @@ DEFAULT_MODEL_AGENT_NAME, DEFAULT_MODEL_AGENT_PROVIDER, DEFAULT_MODEL_EMBEDDING_NAME, + DEFAULT_MODEL_EMBEDDING_DIM, ) @@ -50,10 +51,10 @@ class EmbeddingModelConfig(BaseSettings): name: str = DEFAULT_MODEL_EMBEDDING_NAME """Model name for embedding.""" - dim: int = 2560 + dim: int = DEFAULT_MODEL_EMBEDDING_DIM """Embedding dim is different from different models.""" - api_base: str = "https://ark.cn-beijing.volces.com/api/v3/" + api_base: str = DEFAULT_MODEL_AGENT_API_BASE """The api base of the model for embedding.""" @cached_property diff --git a/veadk/consts.py b/veadk/consts.py index 69a98e59..fa5c1aa6 100644 --- a/veadk/consts.py +++ b/veadk/consts.py @@ -78,3 +78,4 @@ DEFAULT_MODEL_EMBEDDING_NAME = "doubao-embedding-vision-250615" DEFAULT_MODEL_EMBEDDING_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" +DEFAULT_MODEL_EMBEDDING_DIM = 2048 diff --git a/veadk/models/ark_embedding.py b/veadk/models/ark_embedding.py index 6752cd06..ec6d6e63 100644 --- a/veadk/models/ark_embedding.py +++ b/veadk/models/ark_embedding.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from typing import Any, Dict, Optional, List, Union, Tuple from enum import Enum @@ -254,6 +268,10 @@ async def aget_query_embedding(self, query: str) -> List[float]: """Public interface: async get query embedding""" return await self._aget_query_embedding(query) + @classmethod + def class_name(cls) -> str: + return "ArkEmbedding" + # Independent factory function def create_embedding_model( @@ -289,7 +307,3 @@ def create_embedding_model( return OpenAILikeEmbedding( model_name=model_name, api_key=api_key, api_base=api_base, **kwargs ) - - @classmethod - def class_name(cls) -> str: - return "ArkEmbedding" From bc51827b2d3540f331d9b994627db77d9d7c8db9 Mon Sep 17 00:00:00 2001 From: "hanzhi.421" Date: Thu, 15 Jan 2026 16:51:45 +0800 Subject: [PATCH 4/4] fix: comment --- veadk/models/ark_embedding.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/veadk/models/ark_embedding.py b/veadk/models/ark_embedding.py index ec6d6e63..5bcd8f68 100644 --- a/veadk/models/ark_embedding.py +++ b/veadk/models/ark_embedding.py @@ -108,7 +108,6 @@ def __init__( **kwargs, ) - # 设置默认值 if self.api_base is None: self.api_base = DEFAULT_MODEL_EMBEDDING_API_BASE @@ -245,27 +244,21 @@ async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: return results def get_text_embedding(self, text: str) -> List[float]: - """公共接口:获取文本嵌入""" return self._get_text_embedding(text) async def aget_text_embedding(self, text: str) -> List[float]: - """公共接口:异步获取文本嵌入""" return await self._aget_text_embedding(text) def get_text_embeddings(self, texts: List[str]) -> List[List[float]]: - """Public interface: batch text embedding""" return self._get_text_embeddings(texts) async def aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: - """Public interface: async batch text embedding""" return await self._aget_text_embeddings(texts) def get_query_embedding(self, query: str) -> List[float]: - """Public interface: get query embedding""" return self._get_query_embedding(query) async def aget_query_embedding(self, query: str) -> List[float]: - """Public interface: async get query embedding""" return await self._aget_query_embedding(query) @classmethod