# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import json import os import threading from abc import ABC from contextlib import contextmanager from urllib.parse import urljoin from json.decoder import JSONDecodeError import dashscope import numpy as np import requests from ollama import Client from openai import OpenAI from zhipuai import ZhipuAI from common.exceptions import ModelException from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response from common import settings import logging import base64 logger = logging.getLogger(__name__) def _raise_model_exception_if_failed(resp): status_code = resp.status_code if status_code >= 400: if status_code < 500 and status_code not in [408, 429]: raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=False) raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=True) def _dashscope_base_url_for_log(base_url: str) -> str: """Log host/path only (no query string) so secrets in URLs are not printed.""" return base_url.split("?", 1)[0].strip()[:256] def _dashscope_native_http_api_url(base_url: str | None) -> str | None: """ Resolve the DashScope *native* HTTP API root for Tongyi-Qianwen (Qwen) text embeddings. RAGFlow often stores an OpenAI-compatible base URL (e.g. ``.../compatible-mode/v1``) for the same provider. The ``dashscope`` Python SDK used by ``TextEmbedding.call`` does *not* use that path; it expects ``https:///api/v1`` instead. Users outside mainland China are directed to the international endpoint (``dashscope-intl.aliyuncs.com``); domestic traffic uses ``dashscope.aliyuncs.com``. When ``base_url`` already points at the native API root (ends with ``/api/v1``), it is returned unchanged so custom or regional deployments keep working. """ if not base_url: return None u = base_url.strip().rstrip("/") safe = _dashscope_base_url_for_log(u) if u.endswith("/api/v1"): logger.debug("DashScope Tongyi-Qianwen embedding: using native API base as configured (%s)", safe) return u # International (Singapore) DashScope — required for overseas Tongyi-Qianwen accounts. if "dashscope-intl.aliyuncs.com" in u: resolved = "https://dashscope-intl.aliyuncs.com/api/v1" logger.info( "DashScope Tongyi-Qianwen embedding: mapped configured base_url to intl native API (%s -> %s)", safe, resolved, ) return resolved # China mainland DashScope default host. if "dashscope.aliyuncs.com" in u: resolved = "https://dashscope.aliyuncs.com/api/v1" logger.info( "DashScope Tongyi-Qianwen embedding: mapped configured base_url to CN native API (%s -> %s)", safe, resolved, ) return resolved logger.warning( "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; " "using SDK default endpoint (%s)", safe, ) return None @contextmanager def _dashscope_native_api_url_scope(url: str | None): """ Temporarily set ``dashscope.base_http_api_url`` for the duration of a single SDK call, then restore the previous value. Narrows the window where concurrent threads see a mismatch. """ if not url: yield return prev = getattr(dashscope, "base_http_api_url", None) dashscope.base_http_api_url = url try: yield finally: dashscope.base_http_api_url = prev class Base(ABC): def __init__(self, key, model_name, **kwargs): """ Constructor for abstract base class. Parameters are accepted for interface consistency but are not stored. Subclasses should implement their own initialization as needed. """ pass def encode(self, texts: list): raise NotImplementedError("Please implement encode method!") def encode_queries(self, text: str): raise NotImplementedError("Please implement encode method!") class BuiltinEmbed(Base): _FACTORY_NAME = "Builtin" MAX_TOKENS = {"Qwen/Qwen3-Embedding-0.6B": 30000, "BAAI/bge-m3": 8000, "BAAI/bge-small-en-v1.5": 500} _model = None _model_name = "" _max_tokens = 500 _model_lock = threading.Lock() def __init__(self, key, model_name, **kwargs): logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}") embedding_cfg = settings.EMBEDDING_CFG if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""): with BuiltinEmbed._model_lock: BuiltinEmbed._model_name = settings.EMBEDDING_MDL BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500) BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"]) self._model = BuiltinEmbed._model self._model_name = BuiltinEmbed._model_name self._max_tokens = BuiltinEmbed._max_tokens def encode(self, texts: list): batch_size = 16 # TEI is able to auto truncate inputs according to https://github.com/huggingface/text-embeddings-inference. token_count = 0 batches = [] for i in range(0, len(texts), batch_size): embeddings, token_count_delta = self._model.encode(texts[i : i + batch_size]) token_count += token_count_delta batches.append(embeddings) ress = np.vstack(batches) if batches else np.array([]) return ress, token_count def encode_queries(self, text: str): return self._model.encode_queries(text) class OpenAIEmbed(Base): _FACTORY_NAME = "OpenAI" def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"): if not base_url: base_url = "https://api.openai.com/v1" self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name def encode(self, texts: list): # OpenAI requires batch size <=16 batch_size = 16 texts = [truncate(t, 8191) for t in texts] ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): try: res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) except Exception as _e: raise ModelException(f"Error: {_e}") try: ress.extend([d.embedding for d in res.data]) total_tokens += total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") return np.array(ress), total_tokens def encode_queries(self, text): try: res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) except Exception as _e: raise ModelException(f"Error: {_e}") try: return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class LocalAIEmbed(Base): _FACTORY_NAME = "LocalAI" def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("Local embedding model url cannot be None") base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] def encode(self, texts: list): batch_size = 16 ress = [] for i in range(0, len(texts), batch_size): res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) try: ress.extend([d.embedding for d in res.data]) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") # local embedding for LmStudio donot count tokens return np.array(ress), 1024 def encode_queries(self, text): embds, cnt = self.encode([text]) return np.array(embds[0]), cnt class AzureEmbed(OpenAIEmbed): _FACTORY_NAME = "Azure-OpenAI" def __init__(self, key, model_name, **kwargs): from openai.lib.azure import AzureOpenAI api_key = json.loads(key).get("api_key", "") api_version = json.loads(key).get("api_version", "2024-02-01") self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name class AstraflowEmbed(OpenAIEmbed): _FACTORY_NAME = "Astraflow" def __init__(self, key, model_name, base_url="https://api-us-ca.umodelverse.ai/v1"): if not base_url: base_url = "https://api-us-ca.umodelverse.ai/v1" super().__init__(key, model_name, base_url) class AstraflowCNEmbed(OpenAIEmbed): _FACTORY_NAME = "Astraflow-CN" def __init__(self, key, model_name, base_url="https://api.modelverse.cn/v1"): if not base_url: base_url = "https://api.modelverse.cn/v1" super().__init__(key, model_name, base_url) class FuturMixEmbed(OpenAIEmbed): _FACTORY_NAME = "FuturMix" def __init__(self, key, model_name="text-embedding-3-small", base_url="https://futurmix.ai/v1"): if not base_url: base_url = "https://futurmix.ai/v1" super().__init__(key, model_name, base_url) logging.info("[FuturMix] Embedding initialized with model %s", model_name) class BaiChuanEmbed(OpenAIEmbed): _FACTORY_NAME = "BaiChuan" def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"): if not base_url: base_url = "https://api.baichuan-ai.com/v1" super().__init__(key, model_name, base_url) class QWenEmbed(Base): """ Embeddings for Alibaba Tongyi-Qianwen via the DashScope ``TextEmbedding`` API. ``base_url`` comes from the user's embedding-model configuration (often the same host as the OpenAI-compatible chat endpoint). This class maps known DashScope hosts to the native ``/api/v1`` base URL so international and China endpoints both work. """ _FACTORY_NAME = "Tongyi-Qianwen" def __init__(self, key, model_name="text_embedding_v2", base_url=None, **kwargs): self.key = key self.model_name = model_name # Native API root for the SDK; None if base_url is absent or not a known DashScope host. self._dashscope_http_api_url = _dashscope_native_http_api_url(base_url) def encode(self, texts: list): import time import dashscope batch_size = 4 res = [] token_count = 0 texts = [truncate(t, 2048) for t in texts] for i in range(0, len(texts), batch_size): retry_max, retry_wait_secs = 5, 10 for retry in range(retry_max): with _dashscope_native_api_url_scope(self._dashscope_http_api_url): resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") status_code = resp.status_code if status_code >= 400 and status_code < 500 and status_code not in [408, 429]: raise ModelException(f"Error, status: {status_code}, response: {resp}") # No need to retry for 4XX error if status_code == 200: break if retry < retry_max - 1: logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...") time.sleep(retry_wait_secs) else: raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}") try: embds = [[] for _ in range(len(resp["output"]["embeddings"]))] for e in resp["output"]["embeddings"]: embds[e["text_index"]] = e["embedding"] res.extend(embds) token_count += total_token_count_from_response(resp) except Exception as _e: log_exception(_e, resp) raise ModelException(f"Error: {status_code}: {resp}") return np.array(res), token_count def encode_queries(self, text): with _dashscope_native_api_url_scope(self._dashscope_http_api_url): resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query") status_code = resp.status_code if status_code != 200: raise ModelException(f"Error: status: {status_code}: code: {resp.get('code')}, message: {resp.get('message')}") # No need to retry for 4XX error try: return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp) except Exception as _e: log_exception(_e, resp) raise ModelException(f"Error: {status_code}: {resp}") class ZhipuEmbed(Base): _FACTORY_NAME = "ZHIPU-AI" def __init__(self, key, model_name="embedding-2", **kwargs): self.client = ZhipuAI(api_key=key) self.model_name = model_name def encode(self, texts: list): arr = [] tks_num = 0 MAX_LEN = -1 if self.model_name.lower() == "embedding-2": MAX_LEN = 512 if self.model_name.lower() == "embedding-3": MAX_LEN = 3072 if MAX_LEN > 0: texts = [truncate(t, MAX_LEN) for t in texts] for txt in texts: res = self.client.embeddings.create(input=txt, model=self.model_name) try: arr.append(res.data[0].embedding) tks_num += total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") return np.array(arr), tks_num def encode_queries(self, text): res = self.client.embeddings.create(input=text, model=self.model_name) try: return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class OllamaEmbed(Base): _FACTORY_NAME = "Ollama" _special_tokens = ["<|endoftext|>"] def __init__(self, key, model_name, **kwargs): self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"}) self.model_name = model_name self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))) def encode(self, texts: list): arr = [] tks_num = 0 for txt in texts: # remove special tokens if they exist base on regex in one request for token in OllamaEmbed._special_tokens: txt = txt.replace(token, "") res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) try: arr.append(res["embedding"]) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") tks_num += 128 return np.array(arr), tks_num def encode_queries(self, text): # remove special tokens if they exist for token in OllamaEmbed._special_tokens: text = text.replace(token, "") res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) try: return np.array(res["embedding"]), 128 except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class XinferenceEmbed(Base): _FACTORY_NAME = "Xinference" def __init__(self, key, model_name="", base_url=""): base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name def encode(self, texts: list): batch_size = 16 ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): res = None try: res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) ress.extend([d.embedding for d in res.data]) total_tokens += total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") return np.array(ress), total_tokens def encode_queries(self, text): res = None try: res = self.client.embeddings.create(input=[text], model=self.model_name) return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class YoudaoEmbed(Base): _FACTORY_NAME = "Youdao" _client = None def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): pass def encode(self, texts: list): batch_size = 10 res = [] token_count = 0 for t in texts: token_count += num_tokens_from_string(t) for i in range(0, len(texts), batch_size): embds = YoudaoEmbed._client.encode(texts[i : i + batch_size]) res.extend(embds) return np.array(res), token_count def encode_queries(self, text): embds = YoudaoEmbed._client.encode([text]) return np.array(embds[0]), num_tokens_from_string(text) class JinaMultiVecEmbed(Base): _FACTORY_NAME = "Jina" def __init__(self, key, model_name="jina-embeddings-v4", base_url="https://api.jina.ai/v1/embeddings"): self.base_url = "https://api.jina.ai/v1/embeddings" self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name def encode(self, texts: list[str | bytes], task="retrieval.passage"): batch_size = 16 ress = [] token_count = 0 input = [] for text in texts: if isinstance(text, str): input.append({"text": text}) elif isinstance(text, bytes): img_b64s = None try: base64.b64decode(text, validate=True) img_b64s = text.decode("utf8") except Exception: img_b64s = base64.b64encode(text).decode("utf8") input.append({"image": img_b64s}) # base64 encoded image for i in range(0, len(texts), batch_size): data = {"model": self.model_name, "input": input[i : i + batch_size]} if "v4" in self.model_name: data["return_multivector"] = True if "v3" in self.model_name or "v4" in self.model_name: data["task"] = task data["truncate"] = True response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) _raise_model_exception_if_failed(response) try: res = response.json() for d in res["data"]: if data.get("return_multivector", False): # v4 token_embs = np.asarray(d["embeddings"], dtype=np.float32) chunk_emb = token_embs.mean(axis=0) else: # v2/v3 chunk_emb = np.asarray(d["embedding"], dtype=np.float32) ress.append(chunk_emb) token_count += total_token_count_from_response(res) except Exception as _e: log_exception(_e, response) raise Exception(f"Error: {response}") return np.array(ress), token_count def encode_queries(self, text): embds, cnt = self.encode([text], task="retrieval.query") return np.array(embds[0]), cnt class MistralEmbed(Base): _FACTORY_NAME = "Mistral" def __init__(self, key, model_name="mistral-embed", base_url=None): from mistralai.client import MistralClient self.client = MistralClient(api_key=key) self.model_name = model_name def encode(self, texts: list): import time import random texts = [truncate(t, 8196) for t in texts] batch_size = 16 ress = [] token_count = 0 for i in range(0, len(texts), batch_size): retry_max = 5 while retry_max > 0: try: res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name) ress.extend([d.embedding for d in res.data]) token_count += total_token_count_from_response(res) break except Exception as _e: if retry_max == 1: log_exception(_e) delay = random.uniform(20, 60) time.sleep(delay) retry_max -= 1 return np.array(ress), token_count def encode_queries(self, text): import time import random retry_max = 5 while retry_max > 0: try: res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name) return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: if retry_max == 1: log_exception(_e) delay = random.randint(20, 60) time.sleep(delay) retry_max -= 1 class BedrockEmbed(Base): _FACTORY_NAME = "Bedrock" def __init__(self, key, model_name, **kwargs): import boto3 # `key` protocol (backend stores as JSON string in `api_key`): # - Must decode into a dict. # - Required: `auth_mode`, `bedrock_region`. # - Supported auth modes: # - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`. # - "iam_role": requires `aws_role_arn` and assumes role via STS. # - else: treated as "assume_role" (default AWS credential chain). key = json.loads(key) mode = key.get("auth_mode") if not mode: logging.error("Bedrock auth_mode is not provided in the key") raise ValueError("Bedrock auth_mode must be provided in the key") self.bedrock_region = key.get("bedrock_region") self.model_name = model_name self.is_amazon = self.model_name.split(".")[0] == "amazon" self.is_cohere = self.model_name.split(".")[0] == "cohere" if mode == "access_key_secret": self.bedrock_ak = key.get("bedrock_ak") self.bedrock_sk = key.get("bedrock_sk") self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) elif mode == "iam_role": self.aws_role_arn = key.get("aws_role_arn") sts_client = boto3.client("sts", region_name=self.bedrock_region) resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession") creds = resp["Credentials"] self.client = boto3.client( service_name="bedrock-runtime", aws_access_key_id=creds["AccessKeyId"], aws_secret_access_key=creds["SecretAccessKey"], aws_session_token=creds["SessionToken"], ) else: # assume_role self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region) def encode(self, texts: list): texts = [truncate(t, 8196) for t in texts] embeddings = [] token_count = 0 for text in texts: if self.is_amazon: body = {"inputText": text} elif self.is_cohere: body = {"texts": [text], "input_type": "search_document"} response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) try: model_response = json.loads(response["body"].read()) embeddings.extend([model_response["embedding"]]) token_count += num_tokens_from_string(text) except Exception as _e: log_exception(_e, response) return np.array(embeddings), token_count def encode_queries(self, text): embeddings = [] token_count = num_tokens_from_string(text) if self.is_amazon: body = {"inputText": truncate(text, 8196)} elif self.is_cohere: body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) try: model_response = json.loads(response["body"].read()) embeddings.extend(model_response["embedding"]) except Exception as _e: log_exception(_e, response) return np.array(embeddings), token_count class GeminiEmbed(Base): _FACTORY_NAME = "Gemini" def __init__(self, key, model_name="gemini-embedding-001", **kwargs): from google import genai from google.genai import types self.key = key self.model_name = model_name[7:] if model_name.startswith("models/") else model_name self.client = genai.Client(api_key=self.key) self.types = types @staticmethod def _parse_embedding_vector(embedding): if isinstance(embedding, dict): values = embedding.get("values") if values is None: values = embedding.get("embedding") if values is not None: return values values = getattr(embedding, "values", None) if values is None: values = getattr(embedding, "embedding", None) if values is not None: return values raise TypeError(f"Unsupported embedding payload: {type(embedding)}") @classmethod def _parse_embedding_response(cls, response): if response is None: raise ValueError("Embedding response is empty") embeddings = getattr(response, "embeddings", None) if embeddings is None and isinstance(response, dict): embeddings = response.get("embeddings") if embeddings is None: return [cls._parse_embedding_vector(response)] return [cls._parse_embedding_vector(item) for item in embeddings] def _build_embedding_config(self): task_type = "RETRIEVAL_DOCUMENT" if hasattr(self.types, "TaskType"): task_type = getattr(self.types.TaskType, "RETRIEVAL_DOCUMENT", task_type) try: return self.types.EmbedContentConfig(task_type=task_type, title="Embedding of single string") except TypeError: # Compatible with SDK versions that do not accept title in embed config. return self.types.EmbedContentConfig(task_type=task_type) def encode(self, texts: list): texts = [truncate(t, 2048) for t in texts] token_count = sum(num_tokens_from_string(text) for text in texts) config = self._build_embedding_config() batch_size = 16 ress = [] for i in range(0, len(texts), batch_size): result = None try: result = self.client.models.embed_content( model=self.model_name, contents=texts[i : i + batch_size], config=config, ) ress.extend(self._parse_embedding_response(result)) except Exception as _e: log_exception(_e, result) raise Exception(f"Error: {result}") return np.array(ress), token_count def encode_queries(self, text): config = self._build_embedding_config() result = None token_count = num_tokens_from_string(text) try: result = self.client.models.embed_content( model=self.model_name, contents=[truncate(text, 2048)], config=config, ) return np.array(self._parse_embedding_response(result)[0]), token_count except Exception as _e: log_exception(_e, result) raise Exception(f"Error: {result}") class NvidiaEmbed(Base): _FACTORY_NAME = "NVIDIA" def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"): if not base_url: base_url = "https://integrate.api.nvidia.com/v1/embeddings" self.api_key = key self.base_url = base_url self.headers = { "accept": "application/json", "Content-Type": "application/json", "authorization": f"Bearer {self.api_key}", } self.model_name = model_name if model_name == "nvidia/embed-qa-4": self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings" self.model_name = "NV-Embed-QA" if model_name == "snowflake/arctic-embed-l": self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings" def encode(self, texts: list): batch_size = 16 ress = [] token_count = 0 for i in range(0, len(texts), batch_size): payload = { "input": texts[i : i + batch_size], "input_type": "query", "model": self.model_name, "encoding_format": "float", "truncate": "END", } response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) _raise_model_exception_if_failed(response) try: res = response.json() ress.extend([d["embedding"] for d in res["data"]]) token_count += total_token_count_from_response(res) except Exception as _e: log_exception(_e, response) raise Exception(f"Error: {response}") return np.array(ress), token_count def encode_queries(self, text): embds, cnt = self.encode([text]) return np.array(embds[0]), cnt class LmStudioEmbed(LocalAIEmbed): _FACTORY_NAME = "LM-Studio" def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="lm-studio", base_url=base_url) self.model_name = model_name class OpenAI_APIEmbed(OpenAIEmbed): _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name.split("___")[0] class CoHereEmbed(Base): _FACTORY_NAME = "Cohere" def __init__(self, key, model_name, base_url=None): from cohere import Client self.client = Client(api_key=key) self.model_name = model_name def encode(self, texts: list): batch_size = 16 ress = [] token_count = 0 for i in range(0, len(texts), batch_size): res = self.client.embed( texts=texts[i : i + batch_size], model=self.model_name, input_type="search_document", embedding_types=["float"], ) try: ress.extend([d for d in res.embeddings.float]) token_count += total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") return np.array(ress), token_count def encode_queries(self, text): res = self.client.embed( texts=[text], model=self.model_name, input_type="search_query", embedding_types=["float"], ) try: return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res)) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class TogetherAIEmbed(OpenAIEmbed): _FACTORY_NAME = "TogetherAI" def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): if not base_url: base_url = "https://api.together.xyz/v1" super().__init__(key, model_name, base_url=base_url) class PerfXCloudEmbed(OpenAIEmbed): _FACTORY_NAME = "PerfXCloud" def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): if not base_url: base_url = "https://cloud.perfxlab.cn/v1" super().__init__(key, model_name, base_url) class UpstageEmbed(OpenAIEmbed): _FACTORY_NAME = "Upstage" def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"): if not base_url: base_url = "https://api.upstage.ai/v1/solar" super().__init__(key, model_name, base_url) class SILICONFLOWEmbed(Base): _FACTORY_NAME = "SILICONFLOW" def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"): normalized_base_url = (base_url or "").strip() if not normalized_base_url: normalized_base_url = "https://api.siliconflow.cn/v1/embeddings" if "/embeddings" not in normalized_base_url: normalized_base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "embeddings").rstrip("/") self.headers = { "accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {key}", } self.base_url = normalized_base_url self.model_name = model_name def encode(self, texts: list): batch_size = 16 ress = [] token_count = 0 for i in range(0, len(texts), batch_size): texts_batch = texts[i : i + batch_size] if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]: # limit 512, 340 is almost safe texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch] else: texts_batch = [" " if not text.strip() else text for text in texts_batch] payload = { "model": self.model_name, "input": texts_batch, "encoding_format": "float", } response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) _raise_model_exception_if_failed(response) try: res = response.json() ress.extend([d["embedding"] for d in res["data"]]) token_count += total_token_count_from_response(res) except Exception as _e: log_exception(_e, response) raise Exception(f"Error: {response}") return np.array(ress), token_count def encode_queries(self, text): payload = { "model": self.model_name, "input": text, "encoding_format": "float", } response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) _raise_model_exception_if_failed(response) try: res = response.json() return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res) except Exception as _e: log_exception(_e, response) raise Exception(f"Error: {response}") class ReplicateEmbed(Base): _FACTORY_NAME = "Replicate" def __init__(self, key, model_name, base_url=None): from replicate.client import Client self.model_name = model_name self.client = Client(api_token=key) def encode(self, texts: list): batch_size = 16 token_count = sum([num_tokens_from_string(text) for text in texts]) ress = [] for i in range(0, len(texts), batch_size): res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]}) ress.extend(res) return np.array(ress), token_count def encode_queries(self, text): res = self.client.embed(self.model_name, input={"texts": [text]}) return np.array(res), num_tokens_from_string(text) class BaiduYiyanEmbed(Base): _FACTORY_NAME = "BaiduYiyan" def __init__(self, key, model_name, base_url=None): import qianfan key = json.loads(key) ak = key.get("yiyan_ak", "") sk = key.get("yiyan_sk", "") self.client = qianfan.Embedding(ak=ak, sk=sk) self.model_name = model_name def encode(self, texts: list, batch_size=16): res = self.client.do(model=self.model_name, texts=texts).body try: return ( np.array([r["embedding"] for r in res["data"]]), total_token_count_from_response(res), ) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") def encode_queries(self, text): res = self.client.do(model=self.model_name, texts=[text]).body try: return ( np.array([r["embedding"] for r in res["data"]]), total_token_count_from_response(res), ) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class VoyageEmbed(Base): _FACTORY_NAME = "Voyage AI" def __init__(self, key, model_name, base_url=None): import voyageai self.client = voyageai.Client(api_key=key) self.model_name = model_name def encode(self, texts: list): batch_size = 16 ress = [] token_count = 0 for i in range(0, len(texts), batch_size): res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document") try: ress.extend(res.embeddings) token_count += res.total_tokens except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") return np.array(ress), token_count def encode_queries(self, text): res = self.client.embed(texts=text, model=self.model_name, input_type="query") try: return np.array(res.embeddings)[0], res.total_tokens except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") class HuggingFaceEmbed(Base): _FACTORY_NAME = "HuggingFace" def __init__(self, key, model_name, base_url=None, **kwargs): if not model_name: raise ValueError("Model name cannot be None") self.key = key self.model_name = model_name.split("___")[0] self.base_url = base_url or "http://127.0.0.1:8080" def encode(self, texts: list): response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30) _raise_model_exception_if_failed(response) embeddings = response.json() return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text: str): response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30) _raise_model_exception_if_failed(response) embedding = response.json()[0] return np.array(embedding), num_tokens_from_string(text) class VolcEngineEmbed(Base): _FACTORY_NAME = "VolcEngine" def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): if not base_url: base_url = "https://ark.cn-beijing.volces.com/api/v3" self.base_url = base_url try: cfg = json.loads(key) self.ark_api_key = cfg.get("ark_api_key", "") except JSONDecodeError: self.ark_api_key = key self.model_name = model_name @staticmethod def _extract_embedding(result: dict) -> list[float]: if not isinstance(result, dict): raise TypeError(f"Unexpected response type: {type(result)}") data = result.get("data") if data is None: raise KeyError("Missing 'data' in response") if isinstance(data, list): if not data: raise ValueError("Empty 'data' in response") item = data[0] elif isinstance(data, dict): item = data else: raise TypeError(f"Unexpected 'data' type: {type(data)}") if not isinstance(item, dict): raise TypeError("Unexpected item shape in 'data'") if "embedding" not in item: raise KeyError("Missing 'embedding' in response item") return item["embedding"] def _encode_texts(self, texts: list[str]): from common.http_client import sync_request url = f"{self.base_url}/embeddings/multimodal" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.ark_api_key}"} ress: list[list[float]] = [] total_tokens = 0 for text in texts: request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]} response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60) if response.status_code != 200: raise Exception(f"Error: {response.status_code} - {response.text}") result = response.json() try: ress.append(self._extract_embedding(result)) total_tokens += total_token_count_from_response(result) except Exception as _e: log_exception(_e) return np.array(ress), total_tokens def encode(self, texts: list): return self._encode_texts(texts) def encode_queries(self, text: str): embeddings, tokens = self._encode_texts([text]) return embeddings[0], tokens class GPUStackEmbed(OpenAIEmbed): _FACTORY_NAME = "GPUStack" def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name class NovitaEmbed(SILICONFLOWEmbed): _FACTORY_NAME = "NovitaAI" def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"): if not base_url: base_url = "https://api.novita.ai/v3/openai/embeddings" super().__init__(key, model_name, base_url) class GiteeEmbed(SILICONFLOWEmbed): _FACTORY_NAME = "GiteeAI" def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"): if not base_url: base_url = "https://ai.gitee.com/v1/embeddings" super().__init__(key, model_name, base_url) class DeepInfraEmbed(OpenAIEmbed): _FACTORY_NAME = "DeepInfra" def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"): if not base_url: base_url = "https://api.deepinfra.com/v1/openai" super().__init__(key, model_name, base_url) class Ai302Embed(Base): _FACTORY_NAME = "302.AI" def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"): if not base_url: base_url = "https://api.302.ai/v1/embeddings" super().__init__(key, model_name, base_url) class CometAPIEmbed(OpenAIEmbed): _FACTORY_NAME = "CometAPI" def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"): if not base_url: base_url = "https://api.cometapi.com/v1" super().__init__(key, model_name, base_url) class DeerAPIEmbed(OpenAIEmbed): _FACTORY_NAME = "DeerAPI" def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"): if not base_url: base_url = "https://api.deerapi.com/v1" super().__init__(key, model_name, base_url) class JiekouAIEmbed(OpenAIEmbed): _FACTORY_NAME = "Jiekou.AI" def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/embeddings"): if not base_url: base_url = "https://api.jiekou.ai/openai/v1/embeddings" super().__init__(key, model_name, base_url) class RAGconEmbed(OpenAIEmbed): """ RAGcon Embedding Provider - routes through LiteLLM proxy Default Base URL: https://connect.ragcon.ai/v1 """ _FACTORY_NAME = "RAGcon" def __init__(self, key, model_name="text-embedding-3-small", base_url=None): if not base_url: base_url = "https://connect.ragcon.com/v1" super().__init__(key, model_name, base_url) class PerplexityEmbed(Base): _FACTORY_NAME = "Perplexity" def __init__(self, key, model_name="pplx-embed-v1-0.6b", base_url="https://api.perplexity.ai"): if not base_url: base_url = "https://api.perplexity.ai" self.base_url = base_url.rstrip("/") self.api_key = key self.model_name = model_name self.headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", } @staticmethod def _decode_base64_int8(b64_str): raw = base64.b64decode(b64_str) return np.frombuffer(raw, dtype=np.int8).astype(np.float32) def _is_contextualized(self): return "context" in self.model_name def encode(self, texts: list): batch_size = 512 ress = [] token_count = 0 if self._is_contextualized(): url = f"{self.base_url}/v1/contextualizedembeddings" for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] payload = { "model": self.model_name, "input": [[chunk] for chunk in batch], "encoding_format": "base64_int8", } response = requests.post(url, headers=self.headers, json=payload, timeout=30) _raise_model_exception_if_failed(response) try: res = response.json() for doc in res["data"]: for chunk_emb in doc["data"]: ress.append(self._decode_base64_int8(chunk_emb["embedding"])) token_count += res.get("usage", {}).get("total_tokens", 0) except Exception as _e: log_exception(_e, response) raise Exception(f"Error: {response.text}") else: url = f"{self.base_url}/v1/embeddings" for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] payload = { "model": self.model_name, "input": batch, "encoding_format": "base64_int8", } response = requests.post(url, headers=self.headers, json=payload, timeout=30) _raise_model_exception_if_failed(response) try: res = response.json() for d in res["data"]: ress.append(self._decode_base64_int8(d["embedding"])) token_count += res.get("usage", {}).get("total_tokens", 0) except Exception as _e: log_exception(_e, response) raise Exception(f"Error: {response.text}") return np.array(ress), token_count def encode_queries(self, text): embds, cnt = self.encode([text]) return np.array(embds[0]), cnt