2024-01-15 08:46:22 +08:00
|
|
|
#
|
2024-01-19 19:51:57 +08:00
|
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
2024-01-15 08:46:22 +08:00
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
#
|
2025-07-03 19:05:31 +08:00
|
|
|
import json
|
|
|
|
|
import os
|
2024-09-12 17:51:20 +08:00
|
|
|
import threading
|
2025-07-03 19:05:31 +08:00
|
|
|
from abc import ABC
|
2026-05-15 07:07:48 +05:00
|
|
|
from contextlib import contextmanager
|
2025-06-03 14:18:40 +08:00
|
|
|
from urllib.parse import urljoin
|
2026-06-05 09:45:44 +08:00
|
|
|
from json.decoder import JSONDecodeError
|
2025-06-03 14:18:40 +08:00
|
|
|
|
2025-07-03 19:05:31 +08:00
|
|
|
import dashscope
|
|
|
|
|
import numpy as np
|
2024-05-29 16:50:02 +08:00
|
|
|
import requests
|
2024-04-08 19:20:57 +08:00
|
|
|
from ollama import Client
|
2024-01-15 08:46:22 +08:00
|
|
|
from openai import OpenAI
|
2025-07-03 19:05:31 +08:00
|
|
|
from zhipuai import ZhipuAI
|
2024-09-24 19:22:01 +08:00
|
|
|
|
2026-06-01 19:18:16 +08:00
|
|
|
from common.exceptions import ModelException
|
2025-11-03 20:25:02 +08:00
|
|
|
from common.log_utils import log_exception
|
2025-12-15 11:33:57 +08:00
|
|
|
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
|
2025-11-06 09:36:38 +08:00
|
|
|
from common import settings
|
2025-10-23 23:02:27 +08:00
|
|
|
import logging
|
2025-11-10 18:01:40 +08:00
|
|
|
import base64
|
2024-03-27 11:33:46 +08:00
|
|
|
|
2026-05-15 07:07:48 +05:00
|
|
|
# Module-level logger; the DashScope base-URL helpers below report how they resolved the endpoint through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
2026-06-01 19:18:16 +08:00
|
|
|
def _raise_model_exception_if_failed(resp):
|
|
|
|
|
status_code = resp.status_code
|
|
|
|
|
if status_code >= 400:
|
|
|
|
|
if status_code < 500 and status_code not in [408, 429]:
|
|
|
|
|
raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=False)
|
|
|
|
|
raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=True)
|
|
|
|
|
|
|
|
|
|
|
2026-05-15 07:07:48 +05:00
|
|
|
def _dashscope_base_url_for_log(base_url: str) -> str:
|
|
|
|
|
"""Log host/path only (no query string) so secrets in URLs are not printed."""
|
|
|
|
|
return base_url.split("?", 1)[0].strip()[:256]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dashscope_native_http_api_url(base_url: str | None) -> str | None:
|
|
|
|
|
"""
|
|
|
|
|
Resolve the DashScope *native* HTTP API root for Tongyi-Qianwen (Qwen) text embeddings.
|
|
|
|
|
|
|
|
|
|
RAGFlow often stores an OpenAI-compatible base URL (e.g. ``.../compatible-mode/v1``) for
|
|
|
|
|
the same provider. The ``dashscope`` Python SDK used by ``TextEmbedding.call`` does *not*
|
|
|
|
|
use that path; it expects ``https://<host>/api/v1`` instead.
|
|
|
|
|
|
|
|
|
|
Users outside mainland China are directed to the international endpoint
|
|
|
|
|
(``dashscope-intl.aliyuncs.com``); domestic traffic uses ``dashscope.aliyuncs.com``.
|
|
|
|
|
When ``base_url`` already points at the native API root (ends with ``/api/v1``), it is
|
|
|
|
|
returned unchanged so custom or regional deployments keep working.
|
|
|
|
|
"""
|
|
|
|
|
if not base_url:
|
|
|
|
|
return None
|
|
|
|
|
u = base_url.strip().rstrip("/")
|
|
|
|
|
safe = _dashscope_base_url_for_log(u)
|
|
|
|
|
if u.endswith("/api/v1"):
|
|
|
|
|
logger.debug("DashScope Tongyi-Qianwen embedding: using native API base as configured (%s)", safe)
|
|
|
|
|
return u
|
|
|
|
|
# International (Singapore) DashScope — required for overseas Tongyi-Qianwen accounts.
|
|
|
|
|
if "dashscope-intl.aliyuncs.com" in u:
|
|
|
|
|
resolved = "https://dashscope-intl.aliyuncs.com/api/v1"
|
|
|
|
|
logger.info(
|
|
|
|
|
"DashScope Tongyi-Qianwen embedding: mapped configured base_url to intl native API (%s -> %s)",
|
|
|
|
|
safe,
|
|
|
|
|
resolved,
|
|
|
|
|
)
|
|
|
|
|
return resolved
|
|
|
|
|
# China mainland DashScope default host.
|
|
|
|
|
if "dashscope.aliyuncs.com" in u:
|
|
|
|
|
resolved = "https://dashscope.aliyuncs.com/api/v1"
|
|
|
|
|
logger.info(
|
|
|
|
|
"DashScope Tongyi-Qianwen embedding: mapped configured base_url to CN native API (%s -> %s)",
|
|
|
|
|
safe,
|
|
|
|
|
resolved,
|
|
|
|
|
)
|
|
|
|
|
return resolved
|
|
|
|
|
logger.warning(
|
|
|
|
|
"DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; "
|
|
|
|
|
"using SDK default endpoint (%s)",
|
|
|
|
|
safe,
|
|
|
|
|
)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
def _dashscope_native_api_url_scope(url: str | None):
|
|
|
|
|
"""
|
|
|
|
|
Temporarily set ``dashscope.base_http_api_url`` for the duration of a single SDK call,
|
|
|
|
|
then restore the previous value. Narrows the window where concurrent threads see a mismatch.
|
|
|
|
|
"""
|
|
|
|
|
if not url:
|
|
|
|
|
yield
|
|
|
|
|
return
|
|
|
|
|
prev = getattr(dashscope, "base_http_api_url", None)
|
|
|
|
|
dashscope.base_http_api_url = url
|
|
|
|
|
try:
|
|
|
|
|
yield
|
|
|
|
|
finally:
|
|
|
|
|
dashscope.base_http_api_url = prev
|
|
|
|
|
|
2024-11-25 11:37:56 +08:00
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
class Base(ABC):
    """Common interface for the embedding-model wrappers in this module.

    Subclasses implement ``encode`` for a batch of documents and ``encode_queries``
    for a single query; the implementations in this module return an
    ``(embedding(s), token_count)`` tuple.
    """

    def __init__(self, key, model_name, **kwargs):
        """
        Constructor for abstract base class.

        Parameters are accepted for interface consistency but are not stored.
        Subclasses should implement their own initialization as needed.
        """
        pass

    def encode(self, texts: list):
        """Embed a batch of ``texts``; must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        """Embed a single query ``text``; must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Please implement encode_queries method!")
|
|
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
|
2025-10-23 23:02:27 +08:00
|
|
|
class BuiltinEmbed(Base):
    """Embeddings from the built-in TEI (text-embeddings-inference) service.

    A single ``HuggingFaceEmbed`` client is created lazily and shared by all instances
    through class attributes. It is only created when the ``COMPOSE_PROFILES``
    environment variable contains ``tei-``; otherwise ``_model`` stays ``None``.
    """

    _FACTORY_NAME = "Builtin"
    # Per-model token limits; models not listed fall back to 500.
    MAX_TOKENS = {"Qwen/Qwen3-Embedding-0.6B": 30000, "BAAI/bge-m3": 8000, "BAAI/bge-small-en-v1.5": 500}
    _model = None
    _model_name = ""
    _max_tokens = 500
    _model_lock = threading.Lock()

    @staticmethod
    def _redact_cfg(cfg):
        """Return ``cfg`` with a non-empty ``api_key`` masked so it can be logged safely."""
        if isinstance(cfg, dict) and cfg.get("api_key"):
            return {**cfg, "api_key": "***"}
        return cfg

    def __init__(self, key, model_name, **kwargs):
        embedding_cfg = settings.EMBEDDING_CFG
        # Never log the raw config: it carries the embedding service API key.
        logging.info("Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: %s", self._redact_cfg(embedding_cfg))
        if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
            with BuiltinEmbed._model_lock:
                # Re-check under the lock: another thread may have built the model between
                # the unlocked check above and acquiring the lock.
                if not BuiltinEmbed._model:
                    BuiltinEmbed._model_name = settings.EMBEDDING_MDL
                    BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
                    BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
        self._model = BuiltinEmbed._model
        self._model_name = BuiltinEmbed._model_name
        self._max_tokens = BuiltinEmbed._max_tokens

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16 with the shared TEI-backed model.

        Returns:
            ``(embeddings, token_count)``; an empty input yields an empty array and 0.
        """
        batch_size = 16
        # TEI is able to auto truncate inputs according to https://github.com/huggingface/text-embeddings-inference.
        token_count = 0
        batches = []
        for i in range(0, len(texts), batch_size):
            embeddings, token_count_delta = self._model.encode(texts[i : i + batch_size])
            token_count += token_count_delta
            batches.append(embeddings)
        ress = np.vstack(batches) if batches else np.array([])
        return ress, token_count

    def encode_queries(self, text: str):
        """Embed a single query by delegating to the shared model."""
        return self._model.encode_queries(text)
|
2024-01-17 20:20:42 +08:00
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
|
|
|
|
|
class OpenAIEmbed(Base):
    """Embeddings through the OpenAI embeddings API or any OpenAI-compatible endpoint.

    Provider classes that only differ by endpoint subclass this and pass their own
    default ``base_url``.
    """

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def _request_embeddings(self, inputs: list):
        """Send ``inputs`` to the embeddings endpoint and return the raw response.

        Raises:
            ModelException: if the request itself fails; the original error is chained.
        """
        try:
            return self.client.embeddings.create(input=inputs, model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
        except Exception as _e:
            raise ModelException(f"Error: {_e}") from _e

    def encode(self, texts: list):
        """Embed ``texts``, truncated to 8191 tokens each, in batches of 16.

        Returns:
            ``(embeddings, total_tokens)``.

        Raises:
            ModelException: when a request fails.
            Exception: when a response cannot be parsed.
        """
        # OpenAI requires batch size <=16
        batch_size = 16
        texts = [truncate(t, 8191) for t in texts]
        ress = []
        total_tokens = 0
        for i in range(0, len(texts), batch_size):
            res = self._request_embeddings(texts[i : i + batch_size])
            try:
                ress.extend([d.embedding for d in res.data])
                total_tokens += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}") from _e
        return np.array(ress), total_tokens

    def encode_queries(self, text):
        """Embed a single query (truncated to 8191 tokens); returns ``(embedding, tokens)``."""
        res = self._request_embeddings([truncate(text, 8191)])
        try:
            return np.array(res.data[0].embedding), total_token_count_from_response(res)
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}") from _e
|
2024-01-15 08:46:22 +08:00
|
|
|
|
|
|
|
|
|
2024-07-19 15:50:28 +08:00
|
|
|
class LocalAIEmbed(Base):
    """Embeddings from a LocalAI server through its OpenAI-compatible ``v1`` API.

    No real API key is sent, and token usage is not counted: a fixed 1024 is reported.
    """

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
        self.client = OpenAI(api_key="empty", base_url=urljoin(base_url, "v1"))
        # Everything from the first "___" onward is dropped from the model name.
        self.model_name = model_name.partition("___")[0]

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(embeddings, 1024)``."""
        batch_size = 16
        vectors = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start : start + batch_size]
            res = self.client.embeddings.create(input=chunk, model=self.model_name)
            try:
                vectors += [item.embedding for item in res.data]
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
        # local embedding for LmStudio donot count tokens
        return np.array(vectors), 1024

    def encode_queries(self, text):
        """Embed one query via ``encode``; returns ``(embedding, token_count)``."""
        vectors, token_count = self.encode([text])
        return np.array(vectors[0]), token_count
|
2024-07-25 10:23:35 +08:00
|
|
|
|
2024-07-19 15:50:28 +08:00
|
|
|
|
2024-07-19 09:22:59 +08:00
|
|
|
class AzureEmbed(OpenAIEmbed):
    """Embeddings through Azure OpenAI.

    ``key`` is a JSON object string holding ``api_key`` (default ``""``) and
    ``api_version`` (default ``"2024-02-01"``); the Azure endpoint is taken from the
    required ``base_url`` keyword. Encoding is inherited from ``OpenAIEmbed``.
    """

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        from openai.lib.azure import AzureOpenAI

        credentials = json.loads(key)
        self.client = AzureOpenAI(
            api_key=credentials.get("api_key", ""),
            azure_endpoint=kwargs["base_url"],
            api_version=credentials.get("api_version", "2024-02-01"),
        )
        self.model_name = model_name
|
|
|
|
|
|
2024-07-19 09:22:59 +08:00
|
|
|
|
feat: Add Astraflow provider support (global + China endpoints) (#14270)
## Add Astraflow Provider Support
This PR integrates [Astraflow](https://astraflow.ucloud.cn/) (by UCloud
/ 优刻得) as a new AI model provider in RAGFlow, with support for both
global and China endpoints.
### About Astraflow
Astraflow is an OpenAI-compatible AI model aggregation platform
supporting 200+ models from major providers including DeepSeek, Qwen,
GPT, Claude, Gemini, Llama, Mistral, and more.
| Variant | Factory Name | Endpoint | Env Var |
|---------|-------------|----------|---------|
| Global | `Astraflow` | `https://api-us-ca.umodelverse.ai/v1` | `ASTRAFLOW_API_KEY` |
| China | `Astraflow-CN` | `https://api.modelverse.cn/v1` | `ASTRAFLOW_CN_API_KEY` |
- **API key signup**: https://astraflow.ucloud.cn/
---
### Files Changed
| File | Change |
|------|--------|
| `rag/llm/__init__.py` | Register `Astraflow` and `Astraflow-CN` in `SupportedLiteLLMProvider` enum, `FACTORY_DEFAULT_BASE_URL`, and `LITELLM_PROVIDER_PREFIX` |
| `rag/llm/chat_model.py` | Add `AstraflowChat` and `AstraflowCNChat` (OpenAI-compatible `Base` subclass) |
| `rag/llm/embedding_model.py` | Add `AstraflowEmbed` and `AstraflowCNEmbed` (subclasses of `OpenAIEmbed`) |
| `rag/llm/rerank_model.py` | Add `AstraflowRerank` and `AstraflowCNRerank` (subclasses of `OpenAI_APIRerank`) |
| `rag/llm/cv_model.py` | Add `AstraflowCV` and `AstraflowCNCV` (subclasses of `GptV4`) |
| `rag/llm/tts_model.py` | Add `AstraflowTTS` and `AstraflowCNTTS` (subclasses of `OpenAITTS`) |
| `rag/llm/sequence2txt_model.py` | Add `AstraflowSeq2txt` and `AstraflowCNSeq2txt` (subclasses of `GPTSeq2txt`) |
| `conf/llm_factories.json` | Register `Astraflow` and `Astraflow-CN` factories with a curated list of popular models |
---
### Supported Model Types
- ✅ **Chat / LLM** — DeepSeek-V3/R1, Qwen3, GPT-4o/4.1, Claude 3.5/3.7,
Gemini 2.0/2.5 Flash, Llama 3.3/4, Mistral, and 200+ more
- ✅ **Text Embedding** — text-embedding-3-small/large
- ✅ **Image / Vision (IMAGE2TEXT)** — GPT-4o, GPT-4.1, Claude, Gemini,
Llama-4, etc.
- ✅ **Text Re-Rank**
- ✅ **TTS** — tts-1
- ✅ **Speech-to-Text (SPEECH2TEXT)** — whisper-1
### Implementation Notes
- Uses the `openai/` LiteLLM prefix — consistent with other
OpenAI-compatible aggregation platforms (SILICONFLOW, DeerAPI, CometAPI,
OpenRouter, n1n, Avian, etc.)
- `Astraflow` (global, rank 250) and `Astraflow-CN` (China, rank 249)
are separate factory entries, allowing users to choose the optimal
endpoint based on their region.
- All model classes cleanly subclass existing base classes (`Base`,
`OpenAIEmbed`, `OpenAI_APIRerank`, `GptV4`, `OpenAITTS`, `GPTSeq2txt`)
with no custom logic needed — the provider is fully OpenAI-compatible.
---------
Co-authored-by: user <user@xzaaaMacBook-Air.local>
2026-04-22 15:38:34 +08:00
|
|
|
class AstraflowEmbed(OpenAIEmbed):
    """OpenAI-compatible embeddings from Astraflow's global endpoint.

    An empty ``base_url`` falls back to ``https://api-us-ca.umodelverse.ai/v1``.
    """

    _FACTORY_NAME = "Astraflow"

    def __init__(self, key, model_name, base_url="https://api-us-ca.umodelverse.ai/v1"):
        super().__init__(key, model_name, base_url or "https://api-us-ca.umodelverse.ai/v1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AstraflowCNEmbed(OpenAIEmbed):
    """OpenAI-compatible embeddings from Astraflow's China endpoint.

    An empty ``base_url`` falls back to ``https://api.modelverse.cn/v1``.
    """

    _FACTORY_NAME = "Astraflow-CN"

    def __init__(self, key, model_name, base_url="https://api.modelverse.cn/v1"):
        super().__init__(key, model_name, base_url or "https://api.modelverse.cn/v1")
|
|
|
|
|
|
|
|
|
|
|
feat: add FuturMix as model provider (#14419)
## Summary
Add [FuturMix](https://futurmix.ai) as a new model provider. FuturMix is
an OpenAI-compatible unified AI gateway that provides access to 22+
models (GPT, Claude, Gemini, DeepSeek, and more) through a single API
endpoint and key.
- **API Base**: `https://futurmix.ai/v1` (OpenAI-compatible)
- **Supported capabilities**: Chat, Embedding, Image2Text, TTS,
Speech2Text, Rerank
### Changes
| File | Change |
|------|--------|
| `rag/llm/__init__.py` | Add `FuturMix` to `SupportedLiteLLMProvider` enum, `FACTORY_DEFAULT_BASE_URL`, and `LITELLM_PROVIDER_PREFIX` |
| `rag/llm/chat_model.py` | Add `FuturMixChat(Base)` — follows Astraflow/Avian pattern |
| `rag/llm/embedding_model.py` | Add `FuturMixEmbed(OpenAIEmbed)` — follows Astraflow pattern |
| `rag/llm/cv_model.py` | Add `FuturMixCV(GptV4)` — follows SILICONFLOW/OpenRouter pattern |
| `rag/llm/tts_model.py` | Add `FuturMixTTS(OpenAITTS)` — follows CometAPI/DeerAPI pattern |
| `rag/llm/sequence2txt_model.py` | Add `FuturMixSeq2txt(GPTSeq2txt)` — follows StepFun pattern |
| `rag/llm/rerank_model.py` | Add `FuturMixRerank(OpenAI_APIRerank)` |
| `conf/llm_factories.json` | Add factory config with 8 chat, 2 embedding, 1 image2text, 2 TTS, 1 speech2text models |
| `docs/guides/models/supported_models.mdx` | Add FuturMix to supported models table |
### Models included
- **Chat**: claude-sonnet-4-20250514, claude-3.5-haiku, gpt-4o,
gpt-4o-mini, gemini-2.5-flash, gemini-2.0-flash, deepseek-chat,
deepseek-reasoner
- **Embedding**: text-embedding-3-small, text-embedding-3-large
- **Image2Text**: gpt-4o
- **TTS**: tts-1, tts-1-hd
- **Speech2Text**: whisper-1
## Test plan
- [ ] Verify FuturMix appears in the model provider list in RAGFlow UI
- [ ] Configure FuturMix with API key and test chat completion
- [ ] Test embedding model with document indexing
- [ ] Test image2text with a sample image
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-30 10:59:37 +08:00
|
|
|
class FuturMixEmbed(OpenAIEmbed):
    """OpenAI-compatible embeddings from the FuturMix gateway.

    An empty ``base_url`` falls back to ``https://futurmix.ai/v1``; the chosen model
    is logged once the client is built.
    """

    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name="text-embedding-3-small", base_url="https://futurmix.ai/v1"):
        resolved_url = base_url if base_url else "https://futurmix.ai/v1"
        super().__init__(key, model_name, resolved_url)
        logging.info("[FuturMix] Embedding initialized with model %s", model_name)
|
|
|
|
|
|
|
|
|
|
|
2024-05-28 09:09:37 +08:00
|
|
|
class BaiChuanEmbed(OpenAIEmbed):
    """OpenAI-compatible embeddings from the Baichuan API.

    An empty ``base_url`` falls back to ``https://api.baichuan-ai.com/v1``.
    """

    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        endpoint = base_url or "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, endpoint)
|
|
|
|
|
|
|
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
class QWenEmbed(Base):
    """
    Embeddings for Alibaba Tongyi-Qianwen via the DashScope ``TextEmbedding`` API.

    ``base_url`` comes from the user's embedding-model configuration (often the same host
    as the OpenAI-compatible chat endpoint). This class maps known DashScope hosts to the
    native ``/api/v1`` base URL so international and China endpoints both work.

    Documents and queries are both truncated to ``_MAX_INPUT_TOKENS`` tokens.
    """

    _FACTORY_NAME = "Tongyi-Qianwen"

    # Token budget applied to every input, documents and queries alike.
    _MAX_INPUT_TOKENS = 2048

    def __init__(self, key, model_name="text_embedding_v2", base_url=None, **kwargs):
        self.key = key
        self.model_name = model_name
        # Native API root for the SDK; None if base_url is absent or not a known DashScope host.
        self._dashscope_http_api_url = _dashscope_native_http_api_url(base_url)

    def encode(self, texts: list):
        """Embed ``texts`` as documents, 4 per request.

        DashScope call-rate limits are usually temporary, so instead of aborting the
        whole parsing job, a batch that gets 408, 429, a 5xx, or any other non-200
        status outside 4xx is retried (5 attempts, 10 s apart). Other 4xx statuses are
        raised immediately.

        Returns:
            ``(embeddings, token_count)`` with embeddings in input order.

        Raises:
            ModelException: on a non-retryable 4xx status, when retries are exhausted,
                or when the response cannot be parsed.
        """
        import time

        batch_size = 4
        retry_max, retry_wait_secs = 5, 10
        res = []
        token_count = 0
        texts = [truncate(t, self._MAX_INPUT_TOKENS) for t in texts]
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            for retry in range(retry_max):
                with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
                    resp = dashscope.TextEmbedding.call(model=self.model_name, input=batch, api_key=self.key, text_type="document")
                status_code = resp.status_code
                if status_code == 200:
                    break
                # No need to retry for 4XX errors, except request timeout (408) and rate limiting (429).
                if 400 <= status_code < 500 and status_code not in (408, 429):
                    raise ModelException(f"Error, status: {status_code}, response: {resp}")
                if retry < retry_max - 1:
                    logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...")
                    time.sleep(retry_wait_secs)
                else:
                    raise ModelException(f"Error after {retry_max} retries, status: {status_code}, response: {resp}")

            try:
                # Each returned item carries a ``text_index``; use it to restore input order.
                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                for e in resp["output"]["embeddings"]:
                    embds[e["text_index"]] = e["embedding"]
                res.extend(embds)
                token_count += total_token_count_from_response(resp)
            except Exception as _e:
                log_exception(_e, resp)
                raise ModelException(f"Error: {status_code}: {resp}") from _e
        return np.array(res), token_count

    def encode_queries(self, text):
        """Embed a single query ``text`` (no retries).

        Returns:
            ``(embedding, token_count)``.

        Raises:
            ModelException: on any non-200 status or an unparsable response.
        """
        # Truncate by tokens, matching ``encode``.
        query = truncate(text, self._MAX_INPUT_TOKENS)
        with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
            resp = dashscope.TextEmbedding.call(model=self.model_name, input=query, api_key=self.key, text_type="query")
        status_code = resp.status_code
        if status_code != 200:
            raise ModelException(f"Error: status: {status_code}: code: {resp.get('code')}, message: {resp.get('message')}")
        try:
            return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
        except Exception as _e:
            log_exception(_e, resp)
            raise ModelException(f"Error: {status_code}: {resp}") from _e
|
2024-02-08 17:01:01 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZhipuEmbed(Base):
    """Embeddings through the ZHIPU-AI SDK, one request per text.

    Documents and queries are truncated to the model's token limit when the model is
    listed in ``_MAX_INPUT_TOKENS``; other models receive the text unchanged.
    """

    _FACTORY_NAME = "ZHIPU-AI"
    # Token limits keyed by lower-cased model name.
    _MAX_INPUT_TOKENS = {"embedding-2": 512, "embedding-3": 3072}

    def __init__(self, key, model_name="embedding-2", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _truncate(self, text):
        """Truncate ``text`` to the model's token limit, if one is known."""
        max_len = self._MAX_INPUT_TOKENS.get(self.model_name.lower(), -1)
        return truncate(text, max_len) if max_len > 0 else text

    def encode(self, texts: list):
        """Embed ``texts`` one request at a time; returns ``(embeddings, token_count)``."""
        arr = []
        tks_num = 0
        for txt in texts:
            res = self.client.embeddings.create(input=self._truncate(txt), model=self.model_name)
            try:
                arr.append(res.data[0].embedding)
                tks_num += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
        return np.array(arr), tks_num

    def encode_queries(self, text):
        """Embed a single query, truncated like ``encode`` inputs; returns ``(embedding, tokens)``."""
        res = self.client.embeddings.create(input=self._truncate(text), model=self.model_name)
        try:
            return np.array(res.data[0].embedding), total_token_count_from_response(res)
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}")
|
2024-04-08 19:20:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class OllamaEmbed(Base):
    """Embeddings from an Ollama server through the ``ollama`` Python client.

    Token usage is not computed; a fixed 128 tokens per text is reported.
    """

    _FACTORY_NAME = "Ollama"

    # Tokens removed from every input before it is embedded.
    _special_tokens = ["<|endoftext|>"]

    def __init__(self, key, model_name, **kwargs):
        base_url = kwargs["base_url"]
        # No key (or the "x" placeholder) means the server is reached without an auth header.
        if not key or key == "x":
            self.client = Client(host=base_url)
        else:
            self.client = Client(host=base_url, headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
        # An explicit ``ollama_keep_alive`` kwarg wins; the environment is only read without it,
        # so a malformed env value can no longer break construction when the kwarg is given.
        if "ollama_keep_alive" in kwargs:
            self.keep_alive = kwargs["ollama_keep_alive"]
        else:
            self.keep_alive = self._keep_alive_from_env()

    @staticmethod
    def _keep_alive_from_env():
        """Read ``OLLAMA_KEEP_ALIVE`` from the environment (default -1).

        Numeric values are returned as ``int`` (or ``float``). Any other non-blank value,
        e.g. a duration string such as ``"5m"``, is passed through unchanged instead of
        raising ``ValueError``. An unset or blank variable yields -1.
        """
        raw = os.environ.get("OLLAMA_KEEP_ALIVE", "").strip()
        if not raw:
            return -1
        for convert in (int, float):
            try:
                return convert(raw)
            except ValueError:
                pass
        return raw

    @staticmethod
    def _strip_special_tokens(text):
        """Remove every entry of ``OllamaEmbed._special_tokens`` from ``text``."""
        for token in OllamaEmbed._special_tokens:
            text = text.replace(token, "")
        return text

    def _embed_one(self, text):
        """Request the embedding of a single (special-token-stripped) text."""
        return self.client.embeddings(prompt=self._strip_special_tokens(text), model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)

    def encode(self, texts: list):
        """Embed ``texts`` one request at a time; returns ``(embeddings, 128 * len(texts))``."""
        arr = []
        tks_num = 0
        for txt in texts:
            res = self._embed_one(txt)
            try:
                arr.append(res["embedding"])
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
            tks_num += 128
        return np.array(arr), tks_num

    def encode_queries(self, text):
        """Embed a single query; returns ``(embedding, 128)``."""
        res = self._embed_one(text)
        try:
            return np.array(res["embedding"]), 128
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}")
|
2024-04-11 18:22:25 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class XinferenceEmbed(Base):
    """Embeddings from an Xinference server through its OpenAI-compatible ``v1`` API."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", base_url=""):
        self.client = OpenAI(api_key=key, base_url=urljoin(base_url, "v1"))
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(embeddings, total_tokens)``."""
        batch_size = 16
        vectors = []
        token_total = 0
        for offset in range(0, len(texts), batch_size):
            res = None
            try:
                res = self.client.embeddings.create(input=texts[offset : offset + batch_size], model=self.model_name)
                vectors.extend(item.embedding for item in res.data)
                token_total += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
        return np.array(vectors), token_total

    def encode_queries(self, text):
        """Embed a single query; returns ``(embedding, tokens)``."""
        res = None
        try:
            res = self.client.embeddings.create(input=[text], model=self.model_name)
            vector = np.array(res.data[0].embedding)
            return vector, total_token_count_from_response(res)
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}")
|
2024-04-15 13:28:06 +05:30
|
|
|
|
2024-04-16 16:42:19 +08:00
|
|
|
|
2024-04-25 14:14:28 +08:00
|
|
|
class YoudaoEmbed(Base):
    """Youdao BCE embeddings served through the shared class-level ``_client``.

    ``__init__`` does not create a client, and nothing in this class assigns
    ``_client``. Unless ``YoudaoEmbed._client`` is set elsewhere, encoding raises a
    descriptive ``RuntimeError`` instead of an opaque ``AttributeError`` on ``None``.
    """

    _FACTORY_NAME = "Youdao"
    # Shared encoder; presumably exposes ``encode(list_of_texts)`` returning one vector per text.
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        pass

    @staticmethod
    def _require_client():
        """Return ``YoudaoEmbed._client``, raising ``RuntimeError`` when it is unset."""
        client = YoudaoEmbed._client
        if client is None:
            raise RuntimeError("Youdao embedding model is not loaded (YoudaoEmbed._client is None)")
        return client

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 10; returns ``(embeddings, token_count)``.

        An empty input returns an empty array without touching the client.
        """
        batch_size = 10
        res = []
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        for i in range(0, len(texts), batch_size):
            embds = self._require_client().encode(texts[i : i + batch_size])
            res.extend(embds)
        return np.array(res), token_count

    def encode_queries(self, text):
        """Embed a single query; returns ``(embedding, token_count)``."""
        embds = self._require_client().encode([text])
        return np.array(embds[0]), num_tokens_from_string(text)
|
2024-05-29 16:50:02 +08:00
|
|
|
|
|
|
|
|
|
2025-11-10 18:01:40 +08:00
|
|
|
class JinaMultiVecEmbed(Base):
    """Jina embeddings client accepting text (``str``) and image (``bytes``) inputs.

    For model names containing "v4" the API is asked for multi-vector output and
    the per-token vectors of each input are mean-pooled into a single vector.
    """

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-embeddings-v4", base_url="https://api.jina.ai/v1/embeddings"):
        # NOTE(review): ``base_url`` is accepted for interface compatibility but
        # ignored; requests always go to the public Jina endpoint.
        self.base_url = "https://api.jina.ai/v1/embeddings"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    @staticmethod
    def _to_input_item(item):
        """Convert one ``str``/``bytes`` input into a Jina request item.

        Raises:
            TypeError: for any other input type. Skipping it silently would
                misalign the output with the inputs.
        """
        if isinstance(item, str):
            return {"text": item}
        if isinstance(item, bytes):
            try:
                # Bytes that already hold base64 text are forwarded unchanged.
                base64.b64decode(item, validate=True)
                img_b64 = item.decode("utf8")
            except Exception:
                # Otherwise treat them as raw image bytes and encode them.
                img_b64 = base64.b64encode(item).decode("utf8")
            return {"image": img_b64}  # base64 encoded image
        raise TypeError(f"Unsupported input type for Jina embedding: {type(item)}")

    def encode(self, texts: list[str | bytes], task="retrieval.passage"):
        """Embed ``texts`` in batches of 16.

        Args:
            texts: strings and/or image bytes.
            task: Jina task name, sent only for v3/v4 models.

        Returns:
            ``(np.ndarray of float32 vectors, token_count)``.
        """
        batch_size = 16
        items = [self._to_input_item(t) for t in texts]
        multivector = "v4" in self.model_name
        ress = []
        token_count = 0
        for i in range(0, len(items), batch_size):
            data = {"model": self.model_name, "input": items[i : i + batch_size]}
            if multivector:
                data["return_multivector"] = True
            if "v3" in self.model_name or "v4" in self.model_name:
                data["task"] = task
                data["truncate"] = True

            response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
            _raise_model_exception_if_failed(response)
            try:
                res = response.json()
                for d in res["data"]:
                    if multivector:  # v4: one vector per token, mean-pooled
                        chunk_emb = np.asarray(d["embeddings"], dtype=np.float32).mean(axis=0)
                    else:  # v2/v3: a single vector per input
                        chunk_emb = np.asarray(d["embedding"], dtype=np.float32)
                    ress.append(chunk_emb)
                token_count += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, response)
                raise Exception(f"Error: {response}")
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed a single query with the ``retrieval.query`` task."""
        embds, cnt = self.encode([text], task="retrieval.query")
        return np.array(embds[0]), cnt
|
|
|
|
|
|
|
|
|
|
|
2024-06-14 11:32:58 +08:00
|
|
|
class MistralEmbed(Base):
    """Mistral embeddings via ``mistralai.client.MistralClient``.

    Each API call is attempted up to ``_MAX_RETRIES`` times with a random
    20-60 s pause between attempts. When every attempt fails, an exception is
    raised instead of silently returning fewer vectors (or ``None``).
    """

    _FACTORY_NAME = "Mistral"
    _MAX_RETRIES = 5

    def __init__(self, key, model_name="mistral-embed", base_url=None):
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def _embed_with_retry(self, inputs):
        """Embed ``inputs`` (list of str), retrying failures.

        Returns ``(list_of_vectors, token_count)``.
        Raises ``Exception`` after the final failed attempt.
        """
        import random
        import time

        for attempt in range(1, self._MAX_RETRIES + 1):
            try:
                res = self.client.embeddings(input=inputs, model=self.model_name)
                return [d.embedding for d in res.data], total_token_count_from_response(res)
            except Exception as _e:
                if attempt == self._MAX_RETRIES:
                    log_exception(_e)
                    raise Exception(f"Error: {_e}") from _e
                # Back off before the next attempt; no sleep after the last one.
                time.sleep(random.uniform(20, 60))

    def encode(self, texts: list):
        """Embed ``texts`` (each truncated) in batches of 16."""
        texts = [truncate(t, 8196) for t in texts]
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            vectors, used = self._embed_with_retry(texts[i : i + batch_size])
            ress.extend(vectors)
            token_count += used
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one (truncated) query string; returns ``(vector, token_count)``."""
        vectors, used = self._embed_with_retry([truncate(text, 8196)])
        return np.array(vectors[0]), used
|
2024-07-08 09:37:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class BedrockEmbed(Base):
    """Amazon Bedrock embeddings for Amazon (Titan) and Cohere model families.

    Texts are embedded one request at a time through ``invoke_model``. Token
    counts are estimated locally with ``num_tokens_from_string``.
    """

    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, **kwargs):
        import boto3

        # `key` protocol (backend stores as JSON string in `api_key`):
        # - Must decode into a dict.
        # - Required: `auth_mode`, `bedrock_region`.
        # - Supported auth modes:
        #   - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`.
        #   - "iam_role": requires `aws_role_arn` and assumes role via STS.
        #   - else: treated as "assume_role" (default AWS credential chain).
        key = json.loads(key)
        mode = key.get("auth_mode")
        if not mode:
            logging.error("Bedrock auth_mode is not provided in the key")
            raise ValueError("Bedrock auth_mode must be provided in the key")

        self.bedrock_region = key.get("bedrock_region")
        self.model_name = model_name
        self.is_amazon = self.model_name.split(".")[0] == "amazon"
        self.is_cohere = self.model_name.split(".")[0] == "cohere"

        if mode == "access_key_secret":
            self.bedrock_ak = key.get("bedrock_ak")
            self.bedrock_sk = key.get("bedrock_sk")
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
        elif mode == "iam_role":
            self.aws_role_arn = key.get("aws_role_arn")
            sts_client = boto3.client("sts", region_name=self.bedrock_region)
            resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession")
            creds = resp["Credentials"]
            # The runtime client needs the configured region too; without it
            # boto3 falls back to the ambient default region (or fails).
            self.client = boto3.client(
                service_name="bedrock-runtime",
                region_name=self.bedrock_region,
                aws_access_key_id=creds["AccessKeyId"],
                aws_secret_access_key=creds["SecretAccessKey"],
                aws_session_token=creds["SessionToken"],
            )
        else:  # assume_role
            self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)

    def _request_body(self, text, input_type):
        """Build the provider-specific request body for one text."""
        if self.is_amazon:
            return {"inputText": text}
        if self.is_cohere:
            return {"texts": [text], "input_type": input_type}
        raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")

    @staticmethod
    def _extract_embedding(model_response):
        """Return the single embedding vector from a decoded response body.

        Titan responses carry ``"embedding"``. Cohere responses carry
        ``"embeddings"`` as a list of vectors, or as a dict keyed by type
        (e.g. ``"float"``) when typed output is returned.
        """
        if "embedding" in model_response:
            return model_response["embedding"]
        embeddings = model_response["embeddings"]
        if isinstance(embeddings, dict):
            embeddings = embeddings["float"]
        return embeddings[0]

    def _embed_one(self, text, input_type):
        """Invoke the model for one text and return its vector; raises on failure."""
        body = self._request_body(text, input_type)
        response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
        try:
            model_response = json.loads(response["body"].read())
            return self._extract_embedding(model_response)
        except Exception as _e:
            # Raise rather than skip: skipping would return fewer vectors than inputs.
            log_exception(_e, response)
            raise Exception(f"Error: {response}")

    def encode(self, texts: list):
        """Embed each (truncated) text as a document; returns ``(vectors, token_count)``."""
        texts = [truncate(t, 8196) for t in texts]
        embeddings = []
        token_count = 0
        for text in texts:
            embeddings.append(self._embed_one(text, "search_document"))
            token_count += num_tokens_from_string(text)
        return np.array(embeddings), token_count

    def encode_queries(self, text):
        """Embed one (truncated) query; returns ``(vector, token_count)``."""
        token_count = num_tokens_from_string(text)
        embedding = self._embed_one(truncate(text, 8196), "search_query")
        return np.array(embedding), token_count
|
|
|
|
|
|
2025-01-06 14:41:29 +08:00
|
|
|
|
2024-07-11 15:41:00 +08:00
|
|
|
class GeminiEmbed(Base):
    """Google Gemini embeddings through the ``google.genai`` SDK.

    Both ``encode`` and ``encode_queries`` use the config from
    ``_build_embedding_config`` (RETRIEVAL_DOCUMENT task type). Texts are
    truncated before sending; token counts are estimated locally.
    """

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-embedding-001", **kwargs):
        from google import genai
        from google.genai import types

        self.key = key
        # Accept both "models/<name>" and the bare "<name>" form.
        self.model_name = model_name[7:] if model_name.startswith("models/") else model_name
        self.client = genai.Client(api_key=self.key)
        self.types = types

    @staticmethod
    def _parse_embedding_vector(embedding):
        """Return the values of one embedding item.

        Accepts a dict or an SDK object exposing ``values`` (preferred) or
        ``embedding``. Raises TypeError when neither is present.
        """
        if isinstance(embedding, dict):
            values = embedding.get("values")
            if values is None:
                values = embedding.get("embedding")
            if values is not None:
                return values
            # Must not fall through to getattr(): ``dict.values`` is a bound
            # method, so getattr(embedding, "values") would return it instead
            # of reporting the malformed payload.
            raise TypeError(f"Unsupported embedding payload: dict without 'values'/'embedding' keys: {list(embedding)}")

        values = getattr(embedding, "values", None)
        if values is None:
            values = getattr(embedding, "embedding", None)
        if values is not None:
            return values

        raise TypeError(f"Unsupported embedding payload: {type(embedding)}")

    @classmethod
    def _parse_embedding_response(cls, response):
        """Return a list of vectors from an ``embed_content`` response.

        A response without an ``embeddings`` collection is parsed as a single vector.
        """
        if response is None:
            raise ValueError("Embedding response is empty")

        embeddings = getattr(response, "embeddings", None)
        if embeddings is None and isinstance(response, dict):
            embeddings = response.get("embeddings")

        if embeddings is None:
            return [cls._parse_embedding_vector(response)]

        return [cls._parse_embedding_vector(item) for item in embeddings]

    def _build_embedding_config(self):
        """Build an ``EmbedContentConfig`` with the RETRIEVAL_DOCUMENT task type."""
        task_type = "RETRIEVAL_DOCUMENT"
        if hasattr(self.types, "TaskType"):
            task_type = getattr(self.types.TaskType, "RETRIEVAL_DOCUMENT", task_type)
        try:
            return self.types.EmbedContentConfig(task_type=task_type, title="Embedding of single string")
        except TypeError:
            # Compatible with SDK versions that do not accept title in embed config.
            return self.types.EmbedContentConfig(task_type=task_type)

    def encode(self, texts: list):
        """Embed ``texts`` (truncated) in batches of 16; returns ``(vectors, token_count)``."""
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(text) for text in texts)
        config = self._build_embedding_config()
        batch_size = 16
        ress = []
        for i in range(0, len(texts), batch_size):
            result = None
            try:
                result = self.client.models.embed_content(
                    model=self.model_name,
                    contents=texts[i : i + batch_size],
                    config=config,
                )
                ress.extend(self._parse_embedding_response(result))
            except Exception as _e:
                log_exception(_e, result)
                raise Exception(f"Error: {result}")
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one (truncated) query; returns ``(vector, token_count)``."""
        config = self._build_embedding_config()
        result = None
        token_count = num_tokens_from_string(text)
        try:
            result = self.client.models.embed_content(
                model=self.model_name,
                contents=[truncate(text, 2048)],
                config=config,
            )
            return np.array(self._parse_embedding_response(result)[0]), token_count
        except Exception as _e:
            log_exception(_e, result)
            raise Exception(f"Error: {result}")
|
2025-01-06 14:41:29 +08:00
|
|
|
|
2024-07-23 10:43:09 +08:00
|
|
|
|
|
|
|
|
class NvidiaEmbed(Base):
    """NVIDIA hosted embeddings over an HTTP ``/embeddings`` endpoint.

    Two model names are rerouted to dedicated retrieval endpoints in ``__init__``.
    """

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
        self.api_key = key
        self.base_url = base_url or "https://integrate.api.nvidia.com/v1/embeddings"
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name

        # Model-specific endpoint overrides.
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        if model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(vectors, token_count)``."""
        # NOTE(review): documents are sent with input_type="query" as well;
        # confirm this is intended for asymmetric retrieval models.
        batch_size = 16
        vectors = []
        used_tokens = 0
        for start in range(0, len(texts), batch_size):
            payload = {
                "input": texts[start : start + batch_size],
                "input_type": "query",
                "model": self.model_name,
                "encoding_format": "float",
                "truncate": "END",
            }
            response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
            _raise_model_exception_if_failed(response)
            try:
                body = response.json()
                batch_vectors = [item["embedding"] for item in body["data"]]
                vectors.extend(batch_vectors)
                used_tokens += total_token_count_from_response(body)
            except Exception as _e:
                log_exception(_e, response)
                raise Exception(f"Error: {response}")
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """Embed one query via ``encode``; returns ``(vector, token_count)``."""
        batch_vectors, used_tokens = self.encode([text])
        first_vector = batch_vectors[0]
        return np.array(first_vector), used_tokens
|
2024-07-24 12:46:43 +08:00
|
|
|
|
|
|
|
|
|
2024-07-25 10:23:35 +08:00
|
|
|
class LmStudioEmbed(LocalAIEmbed):
    """LM Studio embeddings through its OpenAI-compatible HTTP API."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        if base_url:
            # urljoin replaces the last path segment unless base_url ends with "/".
            api_root = urljoin(base_url, "v1")
        else:
            raise ValueError("Local llm url cannot be None")
        # The API key is a fixed placeholder; `key` is not used.
        self.client = OpenAI(api_key="lm-studio", base_url=api_root)
        self.model_name = model_name
|
2024-08-06 16:20:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAI_APIEmbed(OpenAIEmbed):
    """Generic OpenAI-compatible embedding endpoint (vLLM and custom servers)."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        # urljoin replaces the last path segment unless base_url ends with "/".
        api_root = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=api_root)
        # Keep only the part of the configured name before the first "___".
        self.model_name, *_ = model_name.split("___")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CoHereEmbed(Base):
    """Cohere embeddings through the official ``cohere`` client (float vectors)."""

    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed documents in batches of 16; returns ``(vectors, token_count)``."""
        batch_size = 16
        vectors = []
        used_tokens = 0
        for offset in range(0, len(texts), batch_size):
            res = self.client.embed(
                texts=texts[offset : offset + batch_size],
                model=self.model_name,
                input_type="search_document",
                embedding_types=["float"],
            )
            try:
                vectors.extend(list(res.embeddings.float))
                used_tokens += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """Embed one search query; returns ``(vector, token_count)``."""
        res = self.client.embed(
            texts=[text],
            model=self.model_name,
            input_type="search_query",
            embedding_types=["float"],
        )
        try:
            query_vector = np.array(res.embeddings.float[0])
            return query_vector, int(total_token_count_from_response(res))
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}")
|
2024-08-12 10:11:50 +08:00
|
|
|
|
|
|
|
|
|
2025-01-24 10:29:30 +08:00
|
|
|
class TogetherAIEmbed(OpenAIEmbed):
    """TogetherAI embeddings via its OpenAI-compatible API."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        # An empty or None URL falls back to the public endpoint.
        effective_url = base_url if base_url else "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url=effective_url)
|
2024-08-12 10:15:21 +08:00
|
|
|
|
2024-08-19 10:36:57 +08:00
|
|
|
|
2024-08-12 10:11:50 +08:00
|
|
|
class PerfXCloudEmbed(OpenAIEmbed):
    """PerfXCloud embeddings via its OpenAI-compatible API."""

    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        # An empty or None URL falls back to the public endpoint.
        effective_url = base_url if base_url else "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, effective_url)
|
2024-08-12 11:06:25 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpstageEmbed(OpenAIEmbed):
    """Upstage (Solar) embeddings via its OpenAI-compatible API."""

    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        # An empty or None URL falls back to the public endpoint.
        effective_url = base_url if base_url else "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, effective_url)
|
2024-08-13 16:09:10 +08:00
|
|
|
|
|
|
|
|
|
2024-09-11 12:17:44 +08:00
|
|
|
class SILICONFLOWEmbed(Base):
    """SiliconFlow embeddings over its HTTP ``/embeddings`` endpoint.

    Inputs go through ``_prepare_text`` for both documents and queries, so
    blank strings and the 512-token bge-large limit are handled consistently.
    """

    _FACTORY_NAME = "SILICONFLOW"
    # Models whose inputs are truncated (limit 512, 340 is almost safe).
    _TRUNCATED_MODELS = ("BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5")

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        normalized_base_url = (base_url or "").strip()
        if not normalized_base_url:
            normalized_base_url = "https://api.siliconflow.cn/v1/embeddings"
        # Append "/embeddings" when the caller passed only the API root.
        if "/embeddings" not in normalized_base_url:
            normalized_base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "embeddings").rstrip("/")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }
        self.base_url = normalized_base_url
        self.model_name = model_name

    def _prepare_text(self, text):
        """Normalize one input: blank -> single space; truncate for bge-large models."""
        # Blank strings are replaced with " " (presumably the API rejects empty input).
        if not text.strip():
            return " "
        if self.model_name in self._TRUNCATED_MODELS:
            return truncate(text, 256)
        return text

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(vectors, token_count)``."""
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            payload = {
                "model": self.model_name,
                "input": [self._prepare_text(text) for text in texts[i : i + batch_size]],
                "encoding_format": "float",
            }
            response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
            _raise_model_exception_if_failed(response)
            try:
                res = response.json()
                ress.extend([d["embedding"] for d in res["data"]])
                token_count += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, response)
                raise Exception(f"Error: {response}")
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query (normalized like documents); returns ``(vector, token_count)``."""
        payload = {
            "model": self.model_name,
            "input": self._prepare_text(text),
            "encoding_format": "float",
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
        _raise_model_exception_if_failed(response)
        try:
            res = response.json()
            return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
        except Exception as _e:
            log_exception(_e, response)
            raise Exception(f"Error: {response}")
|
2024-08-19 10:36:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReplicateEmbed(Base):
    """Embeddings from a model hosted on Replicate.

    The model is invoked with ``{"texts": [...]}``. Its output is assumed to be
    one vector per input text (as ``encode`` relies on). Token counts are
    estimated locally with ``num_tokens_from_string``.
    """

    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(vectors, token_count)``."""
        batch_size = 16
        token_count = sum(num_tokens_from_string(text) for text in texts)
        ress = []
        for i in range(0, len(texts), batch_size):
            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
            ress.extend(res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, token_count)``."""
        # Same ``run`` call as ``encode`` (the client has no ``embed`` method).
        # Unwrap the single result so callers get one vector, not a 1-row matrix.
        res = self.client.run(self.model_name, input={"texts": [text]})
        return np.array(res[0]), num_tokens_from_string(text)
|
2024-08-22 16:45:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaiduYiyanEmbed(Base):
    """Baidu Qianfan (Yiyan) embeddings via ``qianfan.Embedding``.

    ``key`` is a JSON object carrying ``yiyan_ak`` and ``yiyan_sk``.
    """

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.Embedding(ak=ak, sk=sk)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=16):
        """Embed ``texts`` in chunks of ``batch_size``; returns ``(vectors, token_count)``.

        Requests are split so that each call carries at most ``batch_size`` texts.
        """
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.do(model=self.model_name, texts=texts[i : i + batch_size]).body
            try:
                ress.extend([r["embedding"] for r in res["data"]])
                token_count += total_token_count_from_response(res)
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, token_count)``.

        The vector is 1-D, consistent with the other embedding classes.
        """
        res = self.client.do(model=self.model_name, texts=[text]).body
        try:
            return (
                np.array(res["data"][0]["embedding"]),
                total_token_count_from_response(res),
            )
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}")
|
2024-08-29 16:14:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class VoyageEmbed(Base):
    """Voyage AI embeddings via ``voyageai.Client``."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed documents in batches of 16; returns ``(vectors, token_count)``."""
        batch_size = 16
        vectors = []
        used_tokens = 0
        for offset in range(0, len(texts), batch_size):
            chunk = texts[offset : offset + batch_size]
            res = self.client.embed(texts=chunk, model=self.model_name, input_type="document")
            try:
                vectors.extend(res.embeddings)
                used_tokens += res.total_tokens
            except Exception as _e:
                log_exception(_e, res)
                raise Exception(f"Error: {res}")
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(first vector, total_tokens)``."""
        res = self.client.embed(texts=text, model=self.model_name, input_type="query")
        try:
            all_vectors = np.array(res.embeddings)
            return all_vectors[0], res.total_tokens
        except Exception as _e:
            log_exception(_e, res)
            raise Exception(f"Error: {res}")
|
2024-09-27 19:15:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Client for an HTTP embedding server exposing POST {base_url}/embed
# (default base_url http://127.0.0.1:8080).
# NOTE(review): presumably a HuggingFace text-embeddings-inference server — confirm.
class HuggingFaceEmbed(Base):
|
2025-07-03 19:05:31 +08:00
|
|
|
_FACTORY_NAME = "HuggingFace"
|
|
|
|
|
|
2025-08-07 08:45:37 +07:00
|
|
|
# Store connection settings. Raises ValueError if model_name is empty.
# `key` is kept on the instance but encode/encode_queries do not send it
# (their only header is Content-Type). Any "___"-delimited suffix is
# dropped from model_name, and model_name itself is not sent either.
def __init__(self, key, model_name, base_url=None, **kwargs):
|
2024-09-27 19:15:38 +08:00
|
|
|
if not model_name:
|
|
|
|
|
raise ValueError("Model name cannot be None")
|
|
|
|
|
self.key = key
|
2024-12-05 13:28:42 +08:00
|
|
|
self.model_name = model_name.split("___")[0]
|
2024-09-27 19:15:38 +08:00
|
|
|
self.base_url = base_url or "http://127.0.0.1:8080"
|
|
|
|
|
|
2024-12-03 16:22:39 +08:00
|
|
|
# Embed all texts in a single POST (no batching or truncation here).
# Returns (np.array of the server's vectors, locally estimated token count);
# the token count comes from num_tokens_from_string, not the server.
def encode(self, texts: list):
|
2026-05-11 11:19:07 +08:00
|
|
|
response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30)
|
2026-06-01 19:18:16 +08:00
|
|
|
# Helper defined elsewhere; presumably raises on a non-success HTTP status.
_raise_model_exception_if_failed(response)
|
|
|
|
|
# Assumes a JSON list with one vector per input, in input order — TODO confirm.
embeddings = response.json()
|
fix a bug when using huggingface embedding api (#8432)
### What problem does this PR solve?
image_version: v0.19.1
This PR fixes a bug in the HuggingFaceEmBedding API method that was
causing AssertionError: assert len(vects) == len(docs) during the
document embedding process.
#### Problem
The HuggingFaceEmbed.encode() method had an early return statement
inside the for loop, causing it to return after processing only the
first text input instead of processing all texts in the input list.
**Error Messenge**
```python
AssertionError: assert len(vects) == len(docs) # input chunks != embedded vectors from embedding api
File "/ragflow/rag/svr/task_executor.py", line 442, in embedding
```
**Buggy code(/ragflow/rag/llm/embedding_model.py)**
```python
class HuggingFaceEmbed(Base):
def __init__(self, key, model_name, base_url=None):
if not model_name:
raise ValueError("Model name cannot be None")
self.key = key
self.model_name = model_name.split("___")[0]
self.base_url = base_url or "http://127.0.0.1:8080"
def encode(self, texts: list):
embeddings = []
for text in texts:
response = requests.post(...)
if response.status_code == 200:
try:
embedding = response.json()
embeddings.append(embedding[0])
# ❌ Early return
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
except Exception as _e:
log_exception(_e, response)
else:
raise Exception(...)
```
**Fixed Code(I just Rollback this function to the v0.19.0 version)**
```python
Class HuggingFaceEmbed(Base):
def __init__(self, key, model_name, base_url=None):
if not model_name:
raise ValueError("Model name cannot be None")
self.key = key
self.model_name = model_name.split("___")[0]
self.base_url = base_url or "http://127.0.0.1:8080"
def encode(self, texts: list):
embeddings = []
for text in texts:
response = requests.post(...)
if response.status_code == 200:
embedding = response.json()
embeddings.append(embedding[0]) # ✅ Only append, no return
else:
raise Exception(...)
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) # ✅ Return after processing all
```
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
2025-06-24 09:35:02 +08:00
|
|
|
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
|
2024-09-27 19:15:38 +08:00
|
|
|
|
2025-10-23 23:02:27 +08:00
|
|
|
# Embed one query string. Returns (first element of the server's JSON
# response as an np.array, locally estimated token count).
def encode_queries(self, text: str):
|
2026-05-11 11:19:07 +08:00
|
|
|
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30)
|
2026-06-01 19:18:16 +08:00
|
|
|
_raise_model_exception_if_failed(response)
|
|
|
|
|
# Assumes the server wraps the single vector in a list — TODO confirm.
embedding = response.json()[0]
|
|
|
|
|
return np.array(embedding), num_tokens_from_string(text)
|
2024-09-27 19:15:38 +08:00
|
|
|
|
2024-12-05 13:28:42 +08:00
|
|
|
|
2026-02-05 09:49:46 +08:00
|
|
|
class VolcEngineEmbed(Base):
|
2025-07-03 19:05:31 +08:00
|
|
|
_FACTORY_NAME = "VolcEngine"
|
|
|
|
|
|
2024-11-27 09:30:49 +08:00
|
|
|
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
    """Store the Ark endpoint, API key and model name.

    ``key`` is either a JSON object carrying ``ark_api_key`` or the raw API
    key itself. JSON that is not an object (e.g. an all-digit key that parses
    as a number) is treated as a raw key instead of crashing on ``.get``.
    """
    if not base_url:
        base_url = "https://ark.cn-beijing.volces.com/api/v3"
    self.base_url = base_url

    try:
        cfg = json.loads(key)
    except JSONDecodeError:
        cfg = None
    if isinstance(cfg, dict):
        self.ark_api_key = cfg.get("ark_api_key", "")
    else:
        self.ark_api_key = key
    self.model_name = model_name
|
|
|
|
|
|
|
|
|
|
@staticmethod
def _extract_embedding(result: dict) -> list[float]:
    """Return the embedding vector from a decoded embeddings response.

    ``result["data"]`` may be a list (its first item is used) or a single dict.
    Raises TypeError, KeyError or ValueError when the payload is malformed.
    """
    if not isinstance(result, dict):
        raise TypeError(f"Unexpected response type: {type(result)}")

    data = result.get("data")
    if data is None:
        raise KeyError("Missing 'data' in response")

    if isinstance(data, dict):
        item = data
    elif isinstance(data, list):
        if len(data) == 0:
            raise ValueError("Empty 'data' in response")
        item = data[0]
    else:
        raise TypeError(f"Unexpected 'data' type: {type(data)}")

    if not isinstance(item, dict):
        raise TypeError("Unexpected item shape in 'data'")
    if "embedding" not in item:
        raise KeyError("Missing 'embedding' in response item")
    return item["embedding"]
|
|
|
|
|
|
|
|
|
|
def _encode_texts(self, texts: list[str]):
|
|
|
|
|
from common.http_client import sync_request
|
|
|
|
|
|
|
|
|
|
url = f"{self.base_url}/embeddings/multimodal"
|
|
|
|
|
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.ark_api_key}"}
|
|
|
|
|
|
|
|
|
|
ress: list[list[float]] = []
|
|
|
|
|
total_tokens = 0
|
|
|
|
|
for text in texts:
|
|
|
|
|
request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]}
|
|
|
|
|
response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60)
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
raise Exception(f"Error: {response.status_code} - {response.text}")
|
|
|
|
|
result = response.json()
|
|
|
|
|
try:
|
|
|
|
|
ress.append(self._extract_embedding(result))
|
|
|
|
|
total_tokens += total_token_count_from_response(result)
|
|
|
|
|
except Exception as _e:
|
|
|
|
|
log_exception(_e)
|
|
|
|
|
|
|
|
|
|
return np.array(ress), total_tokens
|
|
|
|
|
|
|
|
|
|
def encode(self, texts: list):
|
|
|
|
|
return self._encode_texts(texts)
|
|
|
|
|
|
|
|
|
|
def encode_queries(self, text: str):
|
|
|
|
|
embeddings, tokens = self._encode_texts([text])
|
|
|
|
|
return embeddings[0], tokens
|
2025-01-15 14:15:58 +08:00
|
|
|
|
2025-06-12 17:53:59 +08:00
|
|
|
|
2025-01-15 14:15:58 +08:00
|
|
|
class GPUStackEmbed(OpenAIEmbed):
    """GPUStack embeddings through its OpenAI-compatible ``v1`` API."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """Create an OpenAI client for ``urljoin(base_url, "v1")``.

        Raises:
            ValueError: if ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        # urljoin replaces the last path segment unless base_url ends with "/":
        # "http://h:80" -> "http://h:80/v1", "http://h/x/" -> "http://h/x/v1",
        # but "http://h/x" -> "http://h/v1".
        api_root = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=api_root)
        self.model_name = model_name
|
2025-06-13 15:42:17 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class NovitaEmbed(SILICONFLOWEmbed):
    """Novita AI embeddings, sent through the ``SILICONFLOWEmbed`` implementation."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        # A None/empty base_url falls back to the public Novita endpoint.
        super().__init__(key, model_name, base_url or "https://api.novita.ai/v3/openai/embeddings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GiteeEmbed(SILICONFLOWEmbed):
    """Gitee AI embeddings, sent through the ``SILICONFLOWEmbed`` implementation."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        # A None/empty base_url falls back to the public Gitee AI endpoint.
        super().__init__(key, model_name, base_url or "https://ai.gitee.com/v1/embeddings")
|
2025-09-10 13:02:53 +08:00
|
|
|
|
2026-03-20 02:47:48 +00:00
|
|
|
|
2025-07-23 18:10:35 +08:00
|
|
|
class DeepInfraEmbed(OpenAIEmbed):
    """DeepInfra embeddings through its OpenAI-compatible API root."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
        # A None/empty base_url falls back to DeepInfra's public API root.
        super().__init__(key, model_name, base_url or "https://api.deepinfra.com/v1/openai")
|
2025-07-31 14:48:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class Ai302Embed(Base):
    """302.AI embeddings provider.

    NOTE(review): this class subclasses ``Base`` directly and defines no
    ``encode``; its default URL is a full ``/embeddings`` endpoint, the same
    shape NovitaEmbed/GiteeEmbed hand to ``SILICONFLOWEmbed``. Confirm that
    ``Base`` accepts ``base_url`` positionally and that encoding works, or
    whether ``SILICONFLOWEmbed`` was the intended parent.
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
        # A None/empty base_url falls back to the public 302.AI endpoint.
        super().__init__(key, model_name, base_url or "https://api.302.ai/v1/embeddings")
|
2025-09-18 09:51:29 +08:00
|
|
|
|
|
|
|
|
|
2025-09-26 10:50:56 +08:00
|
|
|
class CometAPIEmbed(OpenAIEmbed):
    """CometAPI embeddings through its OpenAI-compatible API root."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
        # A None/empty base_url falls back to CometAPI's public API root.
        super().__init__(key, model_name, base_url or "https://api.cometapi.com/v1")
|
2025-10-09 11:14:49 +08:00
|
|
|
|
2026-03-20 02:47:48 +00:00
|
|
|
|
2025-10-09 11:14:49 +08:00
|
|
|
class DeerAPIEmbed(OpenAIEmbed):
    """DeerAPI embeddings through its OpenAI-compatible API root."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
        # A None/empty base_url falls back to DeerAPI's public API root.
        super().__init__(key, model_name, base_url or "https://api.deerapi.com/v1")
|
2025-11-17 19:47:46 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class JiekouAIEmbed(OpenAIEmbed):
    """Jiekou.AI embeddings provider.

    NOTE(review): the default URL ends in ``/embeddings``, while the other
    ``OpenAIEmbed`` subclasses in this file pass an API root (e.g. ``.../v1``).
    If ``OpenAIEmbed`` hands ``base_url`` to an OpenAI-style client that appends
    ``/embeddings`` itself, requests would hit ``.../embeddings/embeddings`` --
    confirm against ``OpenAIEmbed``.
    """

    _FACTORY_NAME = "Jiekou.AI"

    def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/embeddings"):
        # A None/empty base_url falls back to the default Jiekou.AI URL.
        super().__init__(key, model_name, base_url or "https://api.jiekou.ai/openai/v1/embeddings")
|
2026-03-06 02:37:27 +01:00
|
|
|
|
2026-03-20 02:47:48 +00:00
|
|
|
|
2026-03-06 02:37:27 +01:00
|
|
|
class RAGconEmbed(OpenAIEmbed):
    """
    RAGcon Embedding Provider - routes through LiteLLM proxy

    Delegates to ``OpenAIEmbed`` with the RAGcon endpoint.

    Default Base URL: https://connect.ragcon.com/v1
    (used when ``base_url`` is None or empty).

    NOTE(review): earlier documentation named ``https://connect.ragcon.ai/v1``
    while the code uses the ``.com`` host; confirm which domain RAGcon serves.
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name="text-embedding-3-small", base_url=None):
        if not base_url:
            base_url = "https://connect.ragcon.com/v1"
        super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PerplexityEmbed(Base):
    """Perplexity embeddings API client.

    Models whose name contains ``"context"`` are sent to
    ``/v1/contextualizedembeddings``; all others go to ``/v1/embeddings``.
    Embeddings are requested as ``base64_int8`` and decoded to float32 vectors.
    """

    _FACTORY_NAME = "Perplexity"

    def __init__(self, key, model_name="pplx-embed-v1-0.6b", base_url="https://api.perplexity.ai"):
        if not base_url:
            base_url = "https://api.perplexity.ai"
        # Strip trailing slashes so f"{base_url}/v1/..." never contains "//".
        self.base_url = base_url.rstrip("/")
        self.api_key = key
        self.model_name = model_name
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _decode_base64_int8(b64_str):
        """Decode a base64 string of signed 8-bit integers into a float32 vector."""
        raw = base64.b64decode(b64_str)
        return np.frombuffer(raw, dtype=np.int8).astype(np.float32)

    def _is_contextualized(self):
        """Return True when the model should use the contextualized endpoint."""
        return "context" in self.model_name

    def _request_batch(self, url, batch_input, contextualized):
        """POST one batch and return ``(vectors, total_tokens)`` for it.

        Args:
            url: Endpoint to POST to.
            batch_input: Value for the request's ``"input"`` field.
            contextualized: When True, embeddings are read from
                ``data[i]["data"][j]["embedding"]``; otherwise from
                ``data[i]["embedding"]``.

        Raises:
            Exception: when the response body cannot be parsed into embeddings.
            Failed HTTP responses are first passed to
            ``_raise_model_exception_if_failed``.
        """
        payload = {
            "model": self.model_name,
            "input": batch_input,
            "encoding_format": "base64_int8",
        }
        response = requests.post(url, headers=self.headers, json=payload, timeout=30)
        _raise_model_exception_if_failed(response)
        try:
            res = response.json()
            data = res["data"]
            if contextualized:
                items = [chunk_emb for doc in data for chunk_emb in doc["data"]]
            else:
                items = data
            vectors = [self._decode_base64_int8(item["embedding"]) for item in items]
            # Tolerate a missing or null ``usage`` / ``total_tokens``: token
            # accounting must not discard otherwise valid embeddings.
            usage = res.get("usage") or {}
            tokens = usage.get("total_tokens") or 0
        except Exception as _e:
            log_exception(_e, response)
            raise Exception(f"Error: {response.text}") from _e
        return vectors, tokens

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 512.

        Returns:
            ``(embeddings, token_count)``: a numpy array of the decoded float32
            vectors in response order, and the summed ``usage.total_tokens``.
        """
        batch_size = 512
        contextualized = self._is_contextualized()
        if contextualized:
            url = f"{self.base_url}/v1/contextualizedembeddings"
        else:
            url = f"{self.base_url}/v1/embeddings"

        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            # The contextualized endpoint takes documents as lists of chunks;
            # each text is sent as a single-chunk document.
            batch_input = [[chunk] for chunk in batch] if contextualized else batch
            vectors, tokens = self._request_batch(url, batch_input, contextualized)
            ress.extend(vectors)
            token_count += tokens
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed a single query; returns ``(vector, token_count)``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
|