ragflow/rag/llm/embedding_model.py
Zhichang Yu 195bfffb5e fix(security): address 93 CodeQL code-scanning alerts across 61 files (#16407)
## Summary

Resolves all 93 open alerts at
https://github.com/infiniflow/ragflow/security/code-scanning by rule:

| Rule | Count | Treatment |
|------|-------|-----------|
| py/clear-text-logging-sensitive-data | 23 | Real fix: log scrubbing |
| go/path-injection | 15 | Real fix where possible, suppression with rationale |
| go/request-forgery | 8 | Suppression with rationale (operator-controlled URLs) |
| go/clear-text-logging | 10 | Real fix: log scrubbing |
| go/unsafe-quoting | 5 | Real fix: escape or refactor |
| go/sql-injection | 3 | Real fix: `orderby` whitelist + CodeQL comment |
| go/uncontrolled-allocation-size | 2 | Real fix: cap to 1024 |
| go/incorrect-integer-conversion | 3 | Real fix: `ParseInt` + range check |
| go/insecure-hostkeycallback | 1 | Real fix: `known_hosts` file |
| go/disabled-certificate-check | 2 | Suppression with rationale |
| go/command-injection | 1 | Suppression (sanitized via `shq()`) |
| go/email-injection | 1 | Suppression with rationale |
| go/cookie-httponly-not-set | 1 | Suppression (SPA bootstrap) |
| js/stack-trace-exposure | 1 | Real fix: generic client message |
| js/prototype-pollution-utility | 1 | Real fix: reject `__proto__`/`constructor`/`prototype` |
| py/weak-sensitive-data-hashing | 1 | Real fix: MD5 → SHA-256 |
| py/incomplete-url-substring-sanitization | 3 | Real fix: `urlparse(...).hostname` |
| py/paramiko-missing-host-key-validation | 1 | Real fix: `load_system_host_keys` + `RejectPolicy` |
| cpp/integer-multiplication-cast-to-long | 2 | Real fix: cast to `size_t` |

## Real fixes (with measurable security improvement)

**SSH host key verification (Go + Python)**  
Replace `InsecureIgnoreHostKey()` / `paramiko.AutoAddPolicy()` with
proper host key verification against a known_hosts file (configurable
via `SSH_KNOWN_HOSTS` env / `known_hosts` config field; fail-closed when
unset). Loads `~/.ssh/known_hosts` first via `load_system_host_keys()`
so existing setups keep working.
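
A minimal Python sketch of the paramiko side of this change follows. It assumes the override path arrives through the `SSH_KNOWN_HOSTS` environment variable. The helper names `known_hosts_from_env` and `build_verified_ssh_client` are hypothetical and do not come from the PR. "Fail closed" is modeled here as `RejectPolicy`, which refuses any host that is not already known. The paramiko import is deferred so that the path helper runs even where paramiko is not installed.

```python
import os
from typing import Mapping, Optional


def known_hosts_from_env(env: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Return the operator-configured known_hosts path, or None when unset (hypothetical helper)."""
    env = os.environ if env is None else env
    path = (env.get("SSH_KNOWN_HOSTS") or "").strip()
    return path or None


def build_verified_ssh_client(known_hosts_path: Optional[str] = None):
    """Create a paramiko SSHClient that refuses hosts it cannot verify (hypothetical helper)."""
    import paramiko  # deferred so known_hosts_from_env works without paramiko installed

    client = paramiko.SSHClient()
    # ~/.ssh/known_hosts first, so existing setups keep working (a missing file is ignored).
    client.load_system_host_keys()
    if known_hosts_path:
        # Operator-supplied file; paramiko raises if it cannot be read.
        client.load_host_keys(known_hosts_path)
    # Unknown hosts are rejected instead of being auto-added (fail closed).
    client.set_missing_host_key_policy(paramiko.RejectPolicy())
    return client
```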

**SQL injection in `user_canvas`**  
Add `userCanvasOrderableColumns` whitelist + `userCanvasOrderClause`
helper. Both `GetList()` and `ListByTenantIDs()` now route the
user-supplied `orderby` query param through the helper, defaulting to
`create_time` on miss.
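
The Go helper itself is not part of this excerpt. The Python sketch below shows the same whitelist-and-default pattern. `user_canvas_order_clause` and every column other than `create_time` are hypothetical placeholders.

```python
# Hypothetical whitelist: only create_time is named in the PR; the others are placeholders.
USER_CANVAS_ORDERABLE_COLUMNS = frozenset({"create_time", "update_time", "title"})
DEFAULT_ORDER_COLUMN = "create_time"


def user_canvas_order_clause(orderby: str, desc: bool = True) -> str:
    """Map a user-supplied orderby value onto a known column name.

    The raw value is only compared against the whitelist and never spliced
    into SQL, so anything unrecognized falls back to the default column.
    """
    column = orderby if orderby in USER_CANVAS_ORDERABLE_COLUMNS else DEFAULT_ORDER_COLUMN
    direction = "DESC" if desc else "ASC"
    return f"{column} {direction}"
```

Because the returned clause is always built from whitelist entries, an injected payload simply degrades to the default ordering.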

**SQL injection in `pipeline_operation_log`**  
Existing whitelist documented via CodeQL comment.

**Real SQL injection in `infinity/chunk.go:931`**  
Escape `'` → `''` on user-controlled `questionText` before splicing into
`filter_fulltext(...)` SQL filter.
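
Quote doubling can be sketched in Python as follows. The exact shape of the `filter_fulltext(...)` expression is an assumption, as is the premise that `fields` is a server-side constant. Doubling `'` is only sufficient if the filter parser uses standard SQL string literals without backslash escapes.

```python
def escape_sql_string_literal(value: str) -> str:
    """Double every single quote so the value stays inside a '...' SQL literal."""
    return value.replace("'", "''")


def build_fulltext_filter(fields: str, question_text: str) -> str:
    # Hypothetical layout of the filter expression; `fields` is assumed to be a
    # server-side constant, so only the user-controlled text is escaped here.
    return f"filter_fulltext('{fields}', '{escape_sql_string_literal(question_text)}')"
```

With escaping in place, a classic breakout such as `x') OR ('1'='1` stays inside the string literal.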

**Real SQL injection in `elasticsearch/sql.go:75`**  
Defense-in-depth escape on tokenizer output before splicing into
`MATCH(...)`.

**Python code injection in `result_protocol.go`**  
Replace raw JSON literals embedded in Python/JS expressions with a base64 payload that is decoded at runtime via `json.loads` / `JSON.parse(Buffer.from(..., 'base64').toString('utf8'))`. This eliminates both the unsafe-quoting sink and the brittleness of mixing JSON `true`/`false`/`null` with Python syntax.

**URL substring check bypass in `embedding_model.py`**  
Replace `if "dashscope-intl.aliyuncs.com" in u` with
`urlparse(u).hostname == "dashscope-intl.aliyuncs.com"` so a base_url
like `https://attacker.example/?u=dashscope-intl.aliyuncs.com` cannot
bypass the routing.
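
The sketch below compares the two checks side by side. It mirrors the `urlparse` logic in `_dashscope_native_http_api_url`, except that the real function also accepts subdomains. The function names are illustrative only.

```python
from urllib.parse import urlparse

INTL_HOST = "dashscope-intl.aliyuncs.com"


def substring_check(url: str) -> bool:
    # The old, bypassable check: matches the host name anywhere in the URL.
    return INTL_HOST in url


def hostname_check(url: str) -> bool:
    # Parse first; bare hosts get a dummy scheme so urlparse fills in .hostname.
    parsed = urlparse(url if "://" in url else "http://" + url)
    return (parsed.hostname or "").lower() == INTL_HOST
```

A lookalike such as `dashscope-intl.aliyuncs.com.attacker.example` is also rejected, because the comparison is an exact equality on the parsed hostname.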

**Prototype pollution in `setNestedValue` (TS)**  
Reject `__proto__`/`constructor`/`prototype` keys before any assignment.
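
The real fix is in TypeScript. The following is a Python transliteration of the same guard, assuming a dot-separated path format. In a Python dict these keys are harmless, so the sketch only illustrates the control flow: validate every path segment before the first assignment.

```python
from typing import Any

# Keys that reach Object.prototype in JavaScript; rejected up front.
FORBIDDEN_KEYS = frozenset({"__proto__", "constructor", "prototype"})


def set_nested_value(target: dict, path: str, value: Any) -> None:
    """Set target[a][b][c] = value for path "a.b.c", creating dicts as needed."""
    keys = path.split(".")
    # Validate the whole path before touching `target`, so a rejected path
    # leaves no partial writes behind.
    if any(key in FORBIDDEN_KEYS for key in keys):
        raise ValueError(f"refusing to set reserved key in path {path!r}")
    node = target
    for key in keys[:-1]:
        child = node.get(key)
        if not isinstance(child, dict):
            child = {}
            node[key] = child
        node = child
    node[keys[-1]] = value
```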

**Integer overflow**  
- scrypt params via `ParseInt` + non-positive check
(`internal/common/password.go`)
- `topN` and `n` capped at 1024 (`retrieval_service.go`, `dataset.go`)
- `nalloc*statesize` cast to `size_t` (cpp/re2/onepass.cc)
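
The first two patterns (reject non-positive parameters; clamp a caller-supplied count) look like this in Python. The helper names are hypothetical. The Go code also has to reject values outside the target integer width via `ParseInt`'s error. Python's unbounded `int` never hits that case, so only the positivity check carries over.

```python
MAX_TOP_N = 1024  # cap named in the PR for topN / n


def parse_positive_int(raw: str) -> int:
    """Parse a base-10 integer and reject zero or negative values."""
    value = int(raw, 10)  # raises ValueError on non-numeric input
    if value <= 0:
        raise ValueError(f"expected a positive integer, got {value}")
    return value


def cap_top_n(top_n: int, limit: int = MAX_TOP_N) -> int:
    """Bound a caller-supplied result count before it sizes an allocation."""
    return min(top_n, limit)
```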

**Cookie httponly**  
Set explicitly with rationale: this is the OAuth bootstrap cookie
intentionally read by the SPA.

**Stack trace exposure**  
Replace `error.message` in the HTTP 500 response with a generic `"internal error"` message; the full error is still logged server-side via `console.error`.
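
The affected server is JavaScript. This Python sketch of the same split substitutes `logging.exception` for `console.error`, and `run_handler` is a hypothetical wrapper.

```python
import logging
from typing import Any, Callable, Tuple

logger = logging.getLogger("api.sketch")


def run_handler(handler: Callable[[], Any]) -> Tuple[int, Any]:
    """Run a request handler, hiding exception details from the client."""
    try:
        return 200, handler()
    except Exception:
        # Full traceback goes to the server log only.
        logger.exception("request handler failed")
        return 500, {"error": "internal error"}
```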

**Weak hashing**  
MD5 → SHA-256 for deterministic `conv_id` derivation
(`conversation_service.py`).
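
The PR does not list the inputs hashed in `conversation_service.py`, so `derive_conv_id` and its arguments are hypothetical. A SHA-256 hex digest is 64 characters long, twice the 32 of an MD5 digest. That difference matters if the ID is stored in a fixed-width column.

```python
import hashlib
import json


def derive_conv_id(*parts: str) -> str:
    """Deterministically derive an ID from its parts with SHA-256."""
    # JSON-encode the parts so ("a:b", "c") and ("a", "b:c") cannot collide
    # the way a plain ":".join would.
    material = json.dumps(list(parts), ensure_ascii=False).encode("utf-8")
    return hashlib.sha256(material).hexdigest()
```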

**Log scrubbing**  
Remove or redact user-controlled / sensitive content from clear-text
logs across 8 ingestion parsers, `llm_service.py` ×11,
`tenant_llm_service.py` ×7, `misc_utils.py` ×4, `redis_conn.py` ×10,
`conftest.py` ×4, `init_data.py`, `dataset_api_service.py`,
`generator.py`, `mysql_migration.py`, `cli.go`, `user_command.go`,
`pdf_parser.go`. Most patterns converted to parameterized logging
(`logging.info("...: %d", n)`) or static messages.
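
Here is a self-contained illustration of the parameterized style. The logger name and `log_parse_result` are hypothetical. The in-memory handler exists only so the emitted message can be checked.

```python
import logging


class RecordingHandler(logging.Handler):
    """Keeps emitted records in memory so the output can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.records: list[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.records.append(record)


logger = logging.getLogger("ingest.sketch")
logger.setLevel(logging.INFO)
logger.propagate = False
handler = RecordingHandler()
logger.addHandler(handler)


def log_parse_result(chunks: list[str]) -> None:
    # Log how many chunks came out, never their (user-controlled) text.
    # Arguments are passed separately and only formatted if the record is emitted.
    logger.info("parsed %d chunks", len(chunks))
```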

## CodeQL suppressions (each with rationale)

For alerts where the data flow is genuinely safe but CodeQL can't see
the context — operator-controlled URLs, sanitized inputs, etc. — I added
`// codeql[go/<rule>] <rationale>` annotations rather than dismissing
them, so future readers can audit the rationale inline:

- `internal/agent/component/invoke.go:135` — Invoke is a generic canvas
HTTP client
- `internal/service/langfuse.go` ×2 — host is per-tenant operator config
- `internal/service/file.go:1184` — already SSRF-guarded by
`assertURLSafe`
- `internal/utility/mcp_client.go` ×3 — already `AssertURLSafe` +
IP-pinned
- `internal/entity/models/bedrock.go` — sigv4-signed request, URL can't
be tampered
- `internal/service/deep_researcher.go:269` — `callback` is SSE display
string, not SQL
- `internal/engine/infinity/chunk.go:346` — UUIDs can't contain `'` (RFC
4122)
- `internal/cli/common_command.go` ×2 — CLI trusts operator-configured
URL
- `internal/utility/smtp.go:194` — msg is server-built, not user form
input
- `internal/entity/models/*` ×14 (path-injection) — audio file paths are
caller-supplied

## Test plan

- All 13 modified Go packages build cleanly
- 663 tests pass across `internal/agent/sandbox`, `internal/common`, `internal/agent/component`, `internal/engine/infinity`, `internal/dao`
- All 11 modified Python files parse via `ast.parse`
- TypeScript `tsc --noEmit` clean on the modified `use-provider-fields.tsx`
- `node --check` clean on the modified JS file
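
A check like the `ast.parse` step could look like this in Python. `find_syntax_errors` is hypothetical, and it takes source strings so that it runs without touching the filesystem. `ast.parse` only proves that the files are syntactically valid. It does not import them, so missing names or bad imports would still go unnoticed.

```python
import ast


def find_syntax_errors(sources: dict[str, str]) -> dict[str, str]:
    """Return {filename: "line N: message"} for every source that fails to parse."""
    errors = {}
    for name, source in sources.items():
        try:
            ast.parse(source, filename=name)
        except SyntaxError as exc:
            errors[name] = f"line {exc.lineno}: {exc.msg}"
    return errors
```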

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-06-29 09:45:16 +08:00

1319 lines
53 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import threading
from abc import ABC
from contextlib import contextmanager
from urllib.parse import urljoin, urlparse
from json.decoder import JSONDecodeError
import dashscope
import numpy as np
import requests
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI
from common import settings
from common.exceptions import ModelException
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
from rag.llm.key_utils import _normalize_replicate_key
import logging
import base64
logger = logging.getLogger(__name__)
# Standard token ceiling for the common 8K-context embedding models (OpenAI
# text-embedding-*, Mistral, Bedrock Titan, ...). Inputs are truncated to this
# many tokens so boundary-sized chunks are not rejected by the provider.
DEFAULT_MAX_TOKENS = 8192


class EmbeddingError(ModelException):
    """Raised when an embedding provider fails to return usable embeddings.
    A single, deterministic exception type for every provider failure path so
    callers see consistent behaviour regardless of which SDK raised underneath.
    Subclasses ``ModelException`` so the API error handler (and its retry
    semantics) treats embedding failures like any other model failure.
    """


def _sorted_by_index(items):
    """Order OpenAI-style SDK embedding items by their ``.index`` so batched
    results stay aligned with input order even if the provider returns them out
    of order. Stable no-op when items carry no ``index`` attribute."""
    return sorted(items, key=lambda d: getattr(d, "index", 0))


def _raise_model_exception_if_failed(resp):
    status_code = resp.status_code
    if status_code >= 400:
        if status_code < 500 and status_code not in [408, 429]:
            raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=False)
        raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=True)


def _dashscope_base_url_for_log(base_url: str) -> str:
    """Log host/path only (no query string) so secrets in URLs are not printed."""
    return base_url.split("?", 1)[0].strip()[:256]


def _dashscope_native_http_api_url(base_url: str | None) -> str | None:
    """
    Resolve the DashScope *native* HTTP API root for Tongyi-Qianwen (Qwen) text embeddings.
    RAGFlow often stores an OpenAI-compatible base URL (e.g. ``.../compatible-mode/v1``) for
    the same provider. The ``dashscope`` Python SDK used by ``TextEmbedding.call`` does *not*
    use that path; it expects ``https://<host>/api/v1`` instead.
    Users outside mainland China are directed to the international endpoint
    (``dashscope-intl.aliyuncs.com``); domestic traffic uses ``dashscope.aliyuncs.com``.
    When ``base_url`` already points at the native API root (ends with ``/api/v1``), it is
    returned unchanged so custom or regional deployments keep working.
    """
    if not base_url:
        return None
    u = base_url.strip().rstrip("/")
    safe = _dashscope_base_url_for_log(u)
    if u.endswith("/api/v1"):
        logger.debug("DashScope Tongyi-Qianwen embedding: using native API base as configured (%s)", safe)
        return u
    # Compare against the URL's hostname (not a substring of the full URL),
    # so a base_url like https://attacker.example/?u=dashscope-intl.aliyuncs.com
    # doesn't accidentally match. urlparse() requires a scheme; if the
    # configured base_url is bare, treat the whole string as a hostname.
    parsed = urlparse(u if "://" in u else "http://" + u)
    host = (parsed.hostname or "").lower()
    # International (Singapore) DashScope: required for overseas Tongyi-Qianwen accounts.
    if host == "dashscope-intl.aliyuncs.com" or host.endswith(".dashscope-intl.aliyuncs.com"):
        resolved = "https://dashscope-intl.aliyuncs.com/api/v1"
        logger.info(
            "DashScope Tongyi-Qianwen embedding: mapped configured base_url to intl native API (%s -> %s)",
            safe,
            resolved,
        )
        return resolved
    # China mainland DashScope default host.
    if host == "dashscope.aliyuncs.com" or host.endswith(".dashscope.aliyuncs.com"):
        resolved = "https://dashscope.aliyuncs.com/api/v1"
        logger.info(
            "DashScope Tongyi-Qianwen embedding: mapped configured base_url to CN native API (%s -> %s)",
            safe,
            resolved,
        )
        return resolved
    logger.warning(
        "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; using SDK default endpoint (%s)",
        safe,
    )
    return None


@contextmanager
def _dashscope_native_api_url_scope(url: str | None):
    """
    Temporarily set ``dashscope.base_http_api_url`` for the duration of a single SDK call,
    then restore the previous value. Narrows the window where concurrent threads see a mismatch.
    """
    if not url:
        yield
        return
    prev = getattr(dashscope, "base_http_api_url", None)
    dashscope.base_http_api_url = url
    try:
        yield
    finally:
        dashscope.base_http_api_url = prev


class Base(ABC):
    def __init__(self, key, model_name, **kwargs):
        """
        Constructor for abstract base class.
        Parameters are accepted for interface consistency but are not stored.
        Subclasses should implement their own initialization as needed.
        """
        pass

    def encode(self, texts: list):
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        raise NotImplementedError("Please implement encode_queries method!")

    def _batched_encode(self, texts: list, call_fn, *, batch_size: int, truncate_to: int | None = None):
        """Drive an embedding provider over ``texts`` in batches.
        This is the shared template behind the OpenAI-style providers. It owns:
        * optional per-text truncation to ``truncate_to`` tokens (skipped when
          ``None``) so oversized inputs do not get rejected by the provider;
        * the batch loop, issuing ``ceil(len(texts) / batch_size)`` calls;
        * accumulation of the per-text vectors into a single ``np.ndarray``;
        * summation of the per-batch token counts;
        * one deterministic, informative error path.
        ``call_fn`` is a provider-supplied closure ``call_fn(batch) ->
        (embeddings, token_count)``. It performs the SDK/HTTP request *and*
        parses the response (so a malformed/error response surfaces here), and
        must not assume any particular response shape; the helper never touches
        the raw response object. ``embeddings`` is a sequence of per-text
        vectors; ``token_count`` is the real token usage for that batch.
        Any exception raised by ``call_fn`` is wrapped in a single
        :class:`EmbeddingError` that includes the underlying detail. We log and
        raise here directly instead of relying on ``log_exception``'s implicit
        raise (whose surfaced exception varies by SDK response shape).
        """
        if truncate_to is not None:
            texts = [truncate(t, truncate_to) for t in texts]
        vectors = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                embeddings, tokens = call_fn(batch)
            except ModelException:
                # Already a structured (and possibly retryable) model error; keep it.
                raise
            except Exception as e:
                logger.exception("%s embedding request failed", type(self).__name__)
                raise EmbeddingError(f"Embedding request failed for {type(self).__name__}. Error: {e}") from e
            vectors.extend(embeddings)
            token_count += tokens
        return np.array(vectors), token_count

    @staticmethod
    def _openai_http_embeddings(response):
        """Parse an OpenAI-compatible HTTP embeddings ``requests`` response.
        Returns ``(embeddings, token_count)``. Raises a retryable-aware
        :class:`ModelException` on a bad HTTP status, or surfaces the response
        body (via :class:`EmbeddingError`) when the payload is not a successful
        ``{"data": [...]}`` response.
        """
        _raise_model_exception_if_failed(response)
        res = response.json()
        if not isinstance(res, dict) or "data" not in res:
            raise ValueError(f"unexpected embeddings response (status {getattr(response, 'status_code', '?')}): {res}")
        # Keep results aligned with input order: OpenAI-compatible responses carry
        # a per-item `index`; sorting by it is a no-op (stable) when it is absent.
        data = sorted(res["data"], key=lambda d: d.get("index", 0))
        return [d["embedding"] for d in data], total_token_count_from_response(res)


class BuiltinEmbed(Base):
    _FACTORY_NAME = "Builtin"
    MAX_TOKENS = {"Qwen/Qwen3-Embedding-0.6B": 30000, "BAAI/bge-m3": 8000, "BAAI/bge-small-en-v1.5": 500}
    _model = None
    _model_name = ""
    _max_tokens = 500
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        # EMBEDDING_CFG carries the api_key, so log only the model name.
        logging.info("Initialize BuiltinEmbed according to settings.EMBEDDING_CFG (model: %s)", settings.EMBEDDING_MDL)
        embedding_cfg = settings.EMBEDDING_CFG
        if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
            with BuiltinEmbed._model_lock:
                BuiltinEmbed._model_name = settings.EMBEDDING_MDL
                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
        self._model = BuiltinEmbed._model
        self._model_name = BuiltinEmbed._model_name
        self._max_tokens = BuiltinEmbed._max_tokens

    def encode(self, texts: list):
        batch_size = 16
        # TEI is able to auto truncate inputs according to https://github.com/huggingface/text-embeddings-inference.
        token_count = 0
        batches = []
        for i in range(0, len(texts), batch_size):
            embeddings, token_count_delta = self._model.encode(texts[i : i + batch_size])
            token_count += token_count_delta
            batches.append(embeddings)
        ress = np.vstack(batches) if batches else np.array([])
        return ress, token_count

    def encode_queries(self, text: str):
        return self._model.encode_queries(text)


class OpenAIEmbed(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def _call(self, batch):
        res = self.client.embeddings.create(input=batch, model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
        return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res)

    def encode(self, texts: list):
        # OpenAI requires batch size <=16; 8191 is the documented per-input token ceiling.
        return self._batched_encode(texts, self._call, batch_size=16, truncate_to=8191)

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=8191)
        return vectors[0], token_count


class LocalAIEmbed(Base):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]

    def _call(self, batch):
        res = self.client.embeddings.create(input=batch, model=self.model_name)
        # Local servers (LocalAI / LM Studio) usually omit usage data; fall back
        # to a local tiktoken count rather than fabricating a fixed number.
        tokens = total_token_count_from_response(res)
        if not tokens:
            tokens = sum(num_tokens_from_string(t) for t in batch)
        return [d.embedding for d in _sorted_by_index(res.data)], tokens

    def encode(self, texts: list):
        return self._batched_encode(texts, self._call, batch_size=16)

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
        return vectors[0], token_count


def _resolve_azure_credentials(key):
    try:
        key_obj = json.loads(key)
        if isinstance(key_obj, dict):
            return key_obj.get("api_key", ""), key_obj.get("api_version", "2024-02-01")
        logging.warning(
            "Azure credential payload parsed as JSON but is not an object; using raw api_key string"
        )
    except (json.JSONDecodeError, TypeError):
        logging.warning("Azure credential payload is not valid JSON; using raw api_key string")
    return key, "2024-02-01"


class AzureEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        from openai.lib.azure import AzureOpenAI

        api_key, api_version = _resolve_azure_credentials(key)
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name


class AstraflowEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Astraflow"

    def __init__(self, key, model_name, base_url="https://api-us-ca.umodelverse.ai/v1"):
        if not base_url:
            base_url = "https://api-us-ca.umodelverse.ai/v1"
        super().__init__(key, model_name, base_url)


class AstraflowCNEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Astraflow-CN"

    def __init__(self, key, model_name, base_url="https://api.modelverse.cn/v1"):
        if not base_url:
            base_url = "https://api.modelverse.cn/v1"
        super().__init__(key, model_name, base_url)


class FuturMixEmbed(OpenAIEmbed):
    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name="text-embedding-3-small", base_url="https://futurmix.ai/v1"):
        if not base_url:
            base_url = "https://futurmix.ai/v1"
        super().__init__(key, model_name, base_url)
        logging.info("[FuturMix] Embedding initialized with model %s", model_name)


class BaiChuanEmbed(OpenAIEmbed):
    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)


class QWenEmbed(Base):
    """
    Embeddings for Alibaba Tongyi-Qianwen via the DashScope ``TextEmbedding`` API.
    ``base_url`` comes from the user's embedding-model configuration (often the same host
    as the OpenAI-compatible chat endpoint). This class maps known DashScope hosts to the
    native ``/api/v1`` base URL so international and China endpoints both work.
    """

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="text_embedding_v2", base_url=None, **kwargs):
        self.key = key
        self.model_name = model_name
        # Native API root for the SDK; None if base_url is absent or not a known DashScope host.
        self._dashscope_http_api_url = _dashscope_native_http_api_url(base_url)

    def encode(self, texts: list):
        import time
        import dashscope

        batch_size = 4
        res = []
        token_count = 0
        texts = [truncate(t, 2048) for t in texts]
        for i in range(0, len(texts), batch_size):
            retry_max, retry_wait_secs = 5, 10
            for retry in range(retry_max):
                with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
                    resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
                status_code = resp.status_code
                if status_code >= 400 and status_code < 500 and status_code not in [408, 429]:
                    # No need to retry for 4XX error
                    raise ModelException(f"Error, status: {status_code}, response: {resp}")
                if status_code == 200:
                    break
                if retry < retry_max - 1:
                    logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...")
                    time.sleep(retry_wait_secs)
                else:
                    raise ModelException(f"Error after {retry_max} retries, status: {status_code}, response: {resp}")
            try:
                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                for e in resp["output"]["embeddings"]:
                    embds[e["text_index"]] = e["embedding"]
                res.extend(embds)
                token_count += total_token_count_from_response(resp)
            except Exception as _e:
                logger.exception("QWenEmbed: failed to parse embedding response")
                raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e
        return np.array(res), token_count

    def encode_queries(self, text):
        with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
            resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
        status_code = resp.status_code
        if status_code != 200:
            raise ModelException(f"Error: status: {status_code}: code: {resp.get('code')}, message: {resp.get('message')}")
        # No need to retry for 4XX error
        try:
            return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
        except Exception as _e:
            logger.exception("QWenEmbed: failed to parse query embedding response")
            raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e


class ZhipuEmbed(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="embedding-2", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _max_len(self):
        # Per-model input ceilings; fall back to the standard 8K limit for any
        # other model rather than leaving oversized inputs untruncated.
        if self.model_name.lower() == "embedding-2":
            return 512
        if self.model_name.lower() == "embedding-3":
            return 3072
        return DEFAULT_MAX_TOKENS

    def _call(self, batch):
        # Batch like the other OpenAI-style providers: one request per batch
        # instead of one request per text. Sort by index so the batched results
        # stay aligned with input order.
        res = self.client.embeddings.create(input=batch, model=self.model_name)
        return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res)

    def encode(self, texts: list):
        return self._batched_encode(texts, self._call, batch_size=16, truncate_to=self._max_len())

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=self._max_len())
        return vectors[0], token_count


class OllamaEmbed(Base):
    _FACTORY_NAME = "Ollama"

    _special_tokens = ["<|endoftext|>"]

    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))

    @classmethod
    def _strip_special(cls, text: str) -> str:
        for token in cls._special_tokens:
            text = text.replace(token, "")
        return text

    def _call(self, batch):
        # Batch via client.embed (accepts a list `input`) instead of one
        # client.embeddings request per text. `truncate=True` lets Ollama clip
        # oversized inputs to the model's real context length server-side, which
        # is more accurate than a client-side cl100k estimate.
        cleaned = [self._strip_special(t) for t in batch]
        res = self.client.embed(model=self.model_name, input=cleaned, truncate=True, options={"use_mmap": True}, keep_alive=self.keep_alive)
        # Ollama reports real prompt token usage in `prompt_eval_count`; fall
        # back to a local count only if the server omits it (never a fixed 128).
        tokens = res.get("prompt_eval_count") or 0
        if not tokens:
            tokens = sum(num_tokens_from_string(t) for t in cleaned)
        return res["embeddings"], tokens

    def encode(self, texts: list):
        # No client-side truncation: Ollama truncates to the model context above.
        return self._batched_encode(texts, self._call, batch_size=16)

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
        return vectors[0], token_count


class XinferenceEmbed(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", base_url=""):
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def _call(self, batch):
        res = self.client.embeddings.create(input=batch, model=self.model_name)
        return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res)

    def encode(self, texts: list):
        return self._batched_encode(texts, self._call, batch_size=16)

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
        return vectors[0], token_count


class YoudaoEmbed(Base):
    _FACTORY_NAME = "Youdao"
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        pass

    def encode(self, texts: list):
        batch_size = 10
        res = []
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        for i in range(0, len(texts), batch_size):
            embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
            res.extend(embds)
        return np.array(res), token_count

    def encode_queries(self, text):
        embds = YoudaoEmbed._client.encode([text])
        return np.array(embds[0]), num_tokens_from_string(text)


class JinaMultiVecEmbed(Base):
    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-embeddings-v4", base_url="https://api.jina.ai/v1/embeddings"):
        self.base_url = "https://api.jina.ai/v1/embeddings"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    @staticmethod
    def _as_input_item(text):
        if isinstance(text, str):
            return {"text": text}
        # bytes -> base64 encoded image
        try:
            base64.b64decode(text, validate=True)
            return {"image": text.decode("utf8")}
        except Exception:
            return {"image": base64.b64encode(text).decode("utf8")}

    def encode(self, texts: list[str | bytes], task="retrieval.passage"):
        def _call(batch):
            data = {"model": self.model_name, "input": [self._as_input_item(t) for t in batch]}
            if "v4" in self.model_name:
                data["return_multivector"] = True
            if "v3" in self.model_name or "v4" in self.model_name:
                data["task"] = task
                data["truncate"] = True  # let Jina truncate oversized inputs server-side
            response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
            _raise_model_exception_if_failed(response)
            res = response.json()
            embs = []
            for d in res["data"]:
                if data.get("return_multivector", False):  # v4
                    embs.append(np.asarray(d["embeddings"], dtype=np.float32).mean(axis=0))
                else:  # v2/v3
                    embs.append(np.asarray(d["embedding"], dtype=np.float32))
            return embs, total_token_count_from_response(res)

        # Inputs may be image bytes, so token truncation is left to the server.
        return self._batched_encode(texts, _call, batch_size=16)

    def encode_queries(self, text):
        vectors, token_count = self.encode([text], task="retrieval.query")
        return vectors[0], token_count


class MistralEmbed(Base):
    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name="mistral-embed", base_url=None):
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        import time
        import random

        texts = [truncate(t, DEFAULT_MAX_TOKENS) for t in texts]
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            retry_max = 5
            while retry_max > 0:
                try:
                    res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
                    ress.extend([d.embedding for d in res.data])
                    token_count += total_token_count_from_response(res)
                    break
                except Exception as _e:
                    if retry_max == 1:
                        logger.exception("MistralEmbed: embedding request failed after retries")
                        raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e
                    delay = random.uniform(20, 60)
                    time.sleep(delay)
                    retry_max -= 1
        return np.array(ress), token_count

    def encode_queries(self, text):
        import time
        import random

        retry_max = 5
        while retry_max > 0:
            try:
                res = self.client.embeddings(input=[truncate(text, DEFAULT_MAX_TOKENS)], model=self.model_name)
                return np.array(res.data[0].embedding), total_token_count_from_response(res)
            except Exception as _e:
                if retry_max == 1:
                    logger.exception("MistralEmbed: query embedding request failed after retries")
                    raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e
                delay = random.randint(20, 60)
                time.sleep(delay)
                retry_max -= 1
class BedrockEmbed(Base):
_FACTORY_NAME = "Bedrock"
def __init__(self, key, model_name, **kwargs):
import boto3
# `key` protocol (backend stores as JSON string in `api_key`):
# - Must decode into a dict.
        # - Required: `auth_mode`, `bedrock_region`.
        # - Supported auth modes:
        #   - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`.
        #   - "iam_role": requires `aws_role_arn` and assumes the role via STS.
        #   - anything else: treated as "assume_role" (default AWS credential chain).
        key = json.loads(key)
        mode = key.get("auth_mode")
        if not mode:
            logging.error("Bedrock auth_mode is not provided in the key")
            raise ValueError("Bedrock auth_mode must be provided in the key")
        self.bedrock_region = key.get("bedrock_region")
        self.model_name = model_name
        self.is_amazon = self.model_name.split(".")[0] == "amazon"
        self.is_cohere = self.model_name.split(".")[0] == "cohere"
        if mode == "access_key_secret":
            self.bedrock_ak = key.get("bedrock_ak")
            self.bedrock_sk = key.get("bedrock_sk")
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
        elif mode == "iam_role":
            self.aws_role_arn = key.get("aws_role_arn")
            sts_client = boto3.client("sts", region_name=self.bedrock_region)
            resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession")
            creds = resp["Credentials"]
            self.client = boto3.client(
                service_name="bedrock-runtime",
                region_name=self.bedrock_region,
                aws_access_key_id=creds["AccessKeyId"],
                aws_secret_access_key=creds["SecretAccessKey"],
                aws_session_token=creds["SessionToken"],
            )
        else:  # assume_role
            self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)

    def _extract_vector(self, model_response):
        # Titan returns {"embedding": [...]}; Cohere returns {"embeddings": [[...]]}.
        if self.is_cohere:
            return model_response["embeddings"][0]
        return model_response["embedding"]

    def encode(self, texts: list):
        def _call(batch):
            # Titan accepts a single input per call, so batch_size is 1.
            text = batch[0]
            if self.is_amazon:
                body = {"inputText": text}
            elif self.is_cohere:
                body = {"texts": [text], "input_type": "search_document"}
            else:
                raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")
            response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
            model_response = json.loads(response["body"].read())
            # Token usage is not read from the response; count locally.
            return [self._extract_vector(model_response)], num_tokens_from_string(text)

        return self._batched_encode(texts, _call, batch_size=1, truncate_to=DEFAULT_MAX_TOKENS)

    def encode_queries(self, text):
        text = truncate(text, DEFAULT_MAX_TOKENS)
        token_count = num_tokens_from_string(text)
        if self.is_amazon:
            body = {"inputText": text}
        elif self.is_cohere:
            body = {"texts": [text], "input_type": "search_query"}
        else:
            raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")
        try:
            response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
            model_response = json.loads(response["body"].read())
            return np.array(self._extract_vector(model_response)), token_count
        except Exception as _e:
            logger.exception("BedrockEmbed: query embedding request failed")
            raise EmbeddingError(f"Embedding request failed for BedrockEmbed. Error: {_e}") from _e
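

# Sketch (hypothetical helpers, not used by BedrockEmbed): the request body and
# response shape differ per Bedrock model family, as BedrockEmbed handles inline.
# Assumes the family is the first dot-separated segment of the model ID, matching
# the `is_amazon` / `is_cohere` checks; `input_type` only applies to Cohere. The
# required-field check mirrors the auth-mode list in the BedrockEmbed comments.
def _sketch_bedrock_body(model_id: str, text: str, input_type: str = "search_document") -> dict:
    family = model_id.split(".")[0]
    if family == "amazon":
        # Titan: one input string per request.
        return {"inputText": text}
    if family == "cohere":
        # Cohere: a list of texts plus an asymmetric input type.
        return {"texts": [text], "input_type": input_type}
    raise ValueError(f"Unsupported Bedrock embedding model: {model_id}")


def _sketch_bedrock_vector(model_id: str, model_response: dict) -> list:
    # Titan returns {"embedding": [...]}; Cohere returns {"embeddings": [[...]]}.
    if model_id.split(".")[0] == "cohere":
        return model_response["embeddings"][0]
    return model_response["embedding"]


def _sketch_missing_bedrock_fields(key: dict) -> list:
    # Fields every mode needs, plus the extra ones each auth mode requires.
    required = ["auth_mode", "bedrock_region"]
    mode = key.get("auth_mode")
    if mode == "access_key_secret":
        required += ["bedrock_ak", "bedrock_sk"]
    elif mode == "iam_role":
        required.append("aws_role_arn")
    return [name for name in required if not key.get(name)]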


class GeminiEmbed(Base):
    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-embedding-001", **kwargs):
        from google import genai
        from google.genai import types

        self.key = key
        # Accept both "models/<name>" and a bare "<name>".
        self.model_name = model_name[7:] if model_name.startswith("models/") else model_name
        self.client = genai.Client(api_key=self.key)
        self.types = types

    @staticmethod
    def _parse_embedding_vector(embedding):
        # Accept dict payloads and SDK objects, reading `values` or `embedding`.
        if isinstance(embedding, dict):
            values = embedding.get("values")
            if values is None:
                values = embedding.get("embedding")
            if values is not None:
                return values
        values = getattr(embedding, "values", None)
        if values is None:
            values = getattr(embedding, "embedding", None)
        if values is not None:
            return values
        raise TypeError(f"Unsupported embedding payload: {type(embedding)}")

    @classmethod
    def _parse_embedding_response(cls, response):
        if response is None:
            raise ValueError("Embedding response is empty")
        embeddings = getattr(response, "embeddings", None)
        if embeddings is None and isinstance(response, dict):
            embeddings = response.get("embeddings")
        if embeddings is None:
            return [cls._parse_embedding_vector(response)]
        return [cls._parse_embedding_vector(item) for item in embeddings]

    def _build_embedding_config(self):
        task_type = "RETRIEVAL_DOCUMENT"
        if hasattr(self.types, "TaskType"):
            task_type = getattr(self.types.TaskType, "RETRIEVAL_DOCUMENT", task_type)
        try:
            return self.types.EmbedContentConfig(task_type=task_type, title="Embedding of single string")
        except TypeError:
            # Fall back for SDK versions whose EmbedContentConfig does not accept `title`.
            return self.types.EmbedContentConfig(task_type=task_type)

    def encode(self, texts: list):
        config = self._build_embedding_config()

        def _call(batch):
            result = self.client.models.embed_content(model=self.model_name, contents=batch, config=config)
            # Gemini embeddings do not report token usage; count locally.
            return self._parse_embedding_response(result), sum(num_tokens_from_string(t) for t in batch)

        return self._batched_encode(texts, _call, batch_size=16, truncate_to=2048)

    def encode_queries(self, text):
        config = self._build_embedding_config()
        # Truncate before counting so the token count matches what is sent.
        text = truncate(text, 2048)
        token_count = num_tokens_from_string(text)
        try:
            result = self.client.models.embed_content(
                model=self.model_name,
                contents=[text],
                config=config,
            )
            return np.array(self._parse_embedding_response(result)[0]), token_count
        except Exception as _e:
            logger.exception("GeminiEmbed: query embedding request failed")
            raise EmbeddingError(f"Embedding request failed for GeminiEmbed. Error: {_e}") from _e
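

# Sketch (hypothetical helper): GeminiEmbed._build_embedding_config retries
# without `title` when an older SDK rejects it. The same pattern generalizes to
# any set of optional keyword arguments; `factory` stands in for a constructor
# such as `types.EmbedContentConfig`. Trade-off: a TypeError raised for any other
# reason also triggers the fallback.
def _sketch_construct_with_fallback(factory, required: dict, optional: dict):
    try:
        return factory(**required, **optional)
    except TypeError:
        # Older signature: drop the optional arguments and try again.
        return factory(**required)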


class NvidiaEmbed(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1/embeddings"
        self.api_key = key
        self.base_url = base_url
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        if model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def _call(self, batch, input_type="query"):
        payload = {
            "input": batch,
            "input_type": input_type,
            "model": self.model_name,
            "encoding_format": "float",
            "truncate": "END",  # NVIDIA truncates oversized inputs server-side.
        }
        response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
        return self._openai_http_embeddings(response)

    def encode(self, texts: list):
        # NVIDIA NIM expects "passage" for documents (indexing) and "query" for retrieval.
        return self._batched_encode(texts, lambda b: self._call(b, "passage"), batch_size=16)

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], lambda b: self._call(b, "query"), batch_size=16)
        return vectors[0], token_count
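

# Sketch (hypothetical helpers): NvidiaEmbed switches to a model-specific endpoint
# for two hosted retrieval models and tags each request with an input type. This
# restates that logic as pure functions so it can be checked without network
# access; the URLs and model names are copied from NvidiaEmbed.
_SKETCH_NVIDIA_DEFAULT = "https://integrate.api.nvidia.com/v1/embeddings"
_SKETCH_NVIDIA_OVERRIDES = {
    # model name -> (endpoint, model name sent in the payload)
    "nvidia/embed-qa-4": ("https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings", "NV-Embed-QA"),
    "snowflake/arctic-embed-l": (
        "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings",
        "snowflake/arctic-embed-l",
    ),
}


def _sketch_nvidia_target(model_name: str, base_url: str = "") -> tuple:
    # Overrides win over any configured base URL, as in NvidiaEmbed.__init__.
    url = base_url or _SKETCH_NVIDIA_DEFAULT
    return _SKETCH_NVIDIA_OVERRIDES.get(model_name, (url, model_name))


def _sketch_nvidia_payload(batch: list, model: str, for_query: bool) -> dict:
    return {
        "input": batch,
        # "passage" when indexing documents, "query" at retrieval time.
        "input_type": "query" if for_query else "passage",
        "model": model,
        "encoding_format": "float",
        "truncate": "END",
    }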


class LmStudioEmbed(LocalAIEmbed):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name
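

# Sketch (hypothetical helper): LmStudioEmbed, OpenAI_APIEmbed and GPUStackEmbed
# build the API root with `urljoin(base_url, "v1")`. Under RFC 3986 resolution,
# `urljoin` replaces the last path segment unless the base ends with "/", so a
# base such as "http://host:1234/api" resolves to "http://host:1234/v1". A plain
# string join that keeps the existing path is one alternative, assuming callers
# pass either a server root or a path that should be preserved.
def _sketch_append_v1(base_url: str) -> str:
    base = base_url.rstrip("/")
    # Avoid "/v1/v1" when the caller already included the version segment.
    return base if base.endswith("/v1") else f"{base}/v1"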


class OpenAI_APIEmbed(OpenAIEmbed):
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]


class CoHereEmbed(Base):
    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def _call(self, batch):
        res = self.client.embed(
            texts=batch,
            model=self.model_name,
            input_type="search_document",
            embedding_types=["float"],
            truncate="END",  # let Cohere clip oversized inputs server-side instead of hard-failing
        )
        return list(res.embeddings.float), total_token_count_from_response(res)

    def encode(self, texts: list):
        return self._batched_encode(texts, self._call, batch_size=16)

    def encode_queries(self, text):
        try:
            res = self.client.embed(
                texts=[text],
                model=self.model_name,
                input_type="search_query",
                embedding_types=["float"],
                truncate="END",
            )
            return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res))
        except Exception as _e:
            logger.exception("CoHereEmbed: query embedding request failed")
            raise EmbeddingError(f"Embedding request failed for CoHereEmbed. Error: {_e}") from _e
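

# Sketch (hypothetical table): several providers in this module embed documents
# and queries asymmetrically, sending a different input type for each side. The
# values are copied from the requests built in this file; the provider keys are
# illustrative labels, and a lookup like this is one way to keep them together.
_SKETCH_INPUT_TYPES = {
    # provider: (document/indexing side, query/retrieval side)
    "Bedrock-Cohere": ("search_document", "search_query"),
    "Cohere": ("search_document", "search_query"),
    "NVIDIA": ("passage", "query"),
    "Voyage AI": ("document", "query"),
}


def _sketch_input_type(provider: str, for_query: bool) -> str:
    document_type, query_type = _SKETCH_INPUT_TYPES[provider]
    return query_type if for_query else document_type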


class TogetherAIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url=base_url)


class PerfXCloudEmbed(OpenAIEmbed):
    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, base_url)


class UpstageEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, base_url)
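

# Sketch (hypothetical, not how this module is written): TogetherAIEmbed,
# PerfXCloudEmbed and UpstageEmbed differ only in `_FACTORY_NAME` and the default
# base URL used when `base_url` is empty. Such wrappers could be generated from a
# table with `type()`; `parent` stands in for OpenAIEmbed. `parent.__init__` is
# called explicitly because zero-argument `super()` does not work in a function
# defined outside a class body.
def _sketch_make_provider(parent, class_name: str, factory_name: str, default_url: str):
    def __init__(self, key, model_name, base_url=default_url):
        # Same fallback as the hand-written classes: empty or None -> default URL.
        parent.__init__(self, key, model_name, base_url or default_url)

    return type(class_name, (parent,), {"_FACTORY_NAME": factory_name, "__init__": __init__})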


class SILICONFLOWEmbed(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        normalized_base_url = (base_url or "").strip()
        if not normalized_base_url:
            normalized_base_url = "https://api.siliconflow.cn/v1/embeddings"
        if "/embeddings" not in normalized_base_url:
            normalized_base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "embeddings").rstrip("/")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }
        self.base_url = normalized_base_url
        self.model_name = model_name

    def _clean_batch(self, batch):
        if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
            # The model limit is 512 tokens; 340 locally counted tokens proved only
            # "almost safe", so truncate to 256 for margin.
            return [" " if not text.strip() else truncate(text, 256) for text in batch]
        return [" " if not text.strip() else text for text in batch]

    def _call(self, batch):
        payload = {
            "model": self.model_name,
            "input": self._clean_batch(batch),
            "encoding_format": "float",
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
        return self._openai_http_embeddings(response)

    def encode(self, texts: list):
        return self._batched_encode(texts, self._call, batch_size=16)

    def encode_queries(self, text):
        vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
        return vectors[0], token_count
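

# Sketch (hypothetical helpers): SILICONFLOWEmbed accepts either an API root or a
# full endpoint, appending "/embeddings" only when it is missing, and replaces
# blank inputs with a single space before sending. These pure functions restate
# that normalization so it can be checked in isolation.
def _sketch_siliconflow_url(base_url, default="https://api.siliconflow.cn/v1/embeddings"):
    from urllib.parse import urljoin

    url = (base_url or "").strip() or default
    if "/embeddings" not in url:
        # The trailing "/" makes urljoin append instead of replacing the last segment.
        url = urljoin(f"{url.rstrip('/')}/", "embeddings").rstrip("/")
    return url


def _sketch_clean_inputs(batch: list) -> list:
    return [" " if not text.strip() else text for text in batch]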


class ReplicateEmbed(Base):
    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=_normalize_replicate_key(key))

    def encode(self, texts: list):
        batch_size = 16
        token_count = sum(num_tokens_from_string(text) for text in texts)
        ress = []
        for i in range(0, len(texts), batch_size):
            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
            ress.extend(res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        # As in encode(), `run` is expected to return one vector per input text.
        res = self.client.run(self.model_name, input={"texts": [text]})
        return np.array(res[0]), num_tokens_from_string(text)
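

# Sketch (hypothetical helper): ReplicateEmbed and PerplexityEmbed walk their
# input in fixed-size slices with `range(0, len(texts), batch_size)`. A generator
# yields the same slices; the last one may be shorter than `size`.
def _sketch_batches(items: list, size: int):
    if size <= 0:
        raise ValueError("batch size must be positive")
    for start in range(0, len(items), size):
        yield items[start : start + size]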


class BaiduYiyanEmbed(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.Embedding(ak=ak, sk=sk)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=16):
        embeddings = []
        token_count = 0
        try:
            # Send at most `batch_size` texts per request and merge the results in order.
            for i in range(0, len(texts), batch_size):
                res = self.client.do(model=self.model_name, texts=texts[i : i + batch_size]).body
                embeddings.extend(r["embedding"] for r in res["data"])
                token_count += total_token_count_from_response(res)
            return np.array(embeddings), token_count
        except Exception as _e:
            logger.exception("BaiduYiyanEmbed: embedding request failed")
            raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e

    def encode_queries(self, text):
        try:
            res = self.client.do(model=self.model_name, texts=[text]).body
            # Return a single vector, like the other encode_queries implementations.
            return (
                np.array(res["data"][0]["embedding"]),
                total_token_count_from_response(res),
            )
        except Exception as _e:
            logger.exception("BaiduYiyanEmbed: query embedding request failed")
            raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e
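

# Sketch (hypothetical helper): when a provider is called in several batches,
# the per-batch vectors are concatenated in order and the token usage summed.
# This assumes each response body looks like {"data": [{"embedding": [...]}, ...],
# "usage": {"total_tokens": n}} (the `data` shape BaiduYiyanEmbed reads); a
# response without usage counts as 0 tokens.
def _sketch_merge_batches(responses: list) -> tuple:
    vectors, total_tokens = [], 0
    for res in responses:
        vectors.extend(item["embedding"] for item in res["data"])
        total_tokens += res.get("usage", {}).get("total_tokens", 0)
    return vectors, total_tokens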


class VoyageEmbed(Base):
    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def _call(self, batch):
        res = self.client.embed(texts=batch, model=self.model_name, input_type="document")
        # `_batched_encode` accumulates these per-batch vectors and returns a
        # single np.ndarray, so encode() keeps the np.ndarray contract.
        return res.embeddings, res.total_tokens

    def encode(self, texts: list):
        return self._batched_encode(texts, self._call, batch_size=16)

    def encode_queries(self, text):
        try:
            # `embed` takes a list of texts, so wrap the single query.
            res = self.client.embed(texts=[text], model=self.model_name, input_type="query")
            return np.array(res.embeddings)[0], res.total_tokens
        except Exception as _e:
            logger.exception("VoyageEmbed: query embedding request failed")
            raise EmbeddingError(f"Embedding request failed for VoyageEmbed. Error: {_e}") from _e
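

# Sketch (hypothetical helper): the SDK `embed()` calls in this module take a list
# of strings. A bare `str` is itself iterable, so code that iterates over `texts`
# would see individual characters (list("hi") is ["h", "i"]). Normalizing at the
# call boundary avoids that.
def _sketch_as_text_list(texts) -> list:
    if isinstance(texts, str):
        return [texts]
    return list(texts)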


class HuggingFaceEmbed(Base):
    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        if not model_name:
            raise ValueError("Model name cannot be None")
        self.key = key
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url or "http://127.0.0.1:8080"

    def encode(self, texts: list):
        response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30)
        _raise_model_exception_if_failed(response)
        # No client-side truncation: this relies on the TEI server truncating oversized inputs.
        return np.array(response.json()), sum(num_tokens_from_string(text) for text in texts)

    def encode_queries(self, text: str):
        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30)
        _raise_model_exception_if_failed(response)
        return np.array(response.json()[0]), num_tokens_from_string(text)


class VolcEngineEmbed(Base):
    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        self.base_url = base_url
        try:
            cfg = json.loads(key)
            self.ark_api_key = cfg.get("ark_api_key", "")
        except JSONDecodeError:
            self.ark_api_key = key
        self.model_name = model_name

    @staticmethod
    def _extract_embedding(result: dict) -> list[float]:
        if not isinstance(result, dict):
            raise TypeError(f"Unexpected response type: {type(result)}")
        data = result.get("data")
        if data is None:
            raise KeyError("Missing 'data' in response")
        if isinstance(data, list):
            if not data:
                raise ValueError("Empty 'data' in response")
            item = data[0]
        elif isinstance(data, dict):
            item = data
        else:
            raise TypeError(f"Unexpected 'data' type: {type(data)}")
        if not isinstance(item, dict):
            raise TypeError("Unexpected item shape in 'data'")
        if "embedding" not in item:
            raise KeyError("Missing 'embedding' in response item")
        return item["embedding"]

    def _encode_texts(self, texts: list[str]):
        from common.http_client import sync_request

        url = f"{self.base_url}/embeddings/multimodal"
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.ark_api_key}"}
        ress: list[list[float]] = []
        total_tokens = 0
        for text in texts:
            request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]}
            response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60)
            if response.status_code != 200:
                raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {response.status_code} - {response.text}")
            result = response.json()
            try:
                ress.append(self._extract_embedding(result))
                total_tokens += total_token_count_from_response(result)
            except Exception as _e:
                logger.exception("VolcEngineEmbed: failed to parse embedding response")
                raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {_e}; response={result}") from _e
        return np.array(ress), total_tokens

    def encode(self, texts: list):
        return self._encode_texts(texts)

    def encode_queries(self, text: str):
        embeddings, tokens = self._encode_texts([text])
        return embeddings[0], tokens
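

# Sketch (hypothetical helper): VolcEngineEmbed accepts either a JSON object with
# an `ark_api_key` field or the raw key itself. This variant also treats valid
# JSON that is not an object (a quoted string, or an all-digit key that parses as
# a number) as a raw key, cases where calling `.get` on the parsed value would fail.
def _sketch_parse_ark_key(key: str) -> str:
    import json

    try:
        cfg = json.loads(key)
    except json.JSONDecodeError:
        return key  # Not JSON at all: use the raw key.
    if isinstance(cfg, dict):
        return cfg.get("ark_api_key", "")
    # JSON scalar: unwrap a quoted string, otherwise keep the original text.
    return cfg if isinstance(cfg, str) else key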


class GPUStackEmbed(OpenAIEmbed):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name


class NovitaEmbed(SILICONFLOWEmbed):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai/embeddings"
        super().__init__(key, model_name, base_url)


class GiteeEmbed(SILICONFLOWEmbed):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/embeddings"
        super().__init__(key, model_name, base_url)


class DeepInfraEmbed(OpenAIEmbed):
    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
        if not base_url:
            base_url = "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, base_url)
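

# Uses the SILICONFLOW-style HTTP client, like NovitaEmbed and GiteeEmbed, since the
# default URL is a full "/embeddings" endpoint and Base provides no encode().
class Ai302Embed(SILICONFLOWEmbed):
    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
        if not base_url:
            base_url = "https://api.302.ai/v1/embeddings"
        super().__init__(key, model_name, base_url)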


class CometAPIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
        if not base_url:
            base_url = "https://api.cometapi.com/v1"
        super().__init__(key, model_name, base_url)


class DeerAPIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
        if not base_url:
            base_url = "https://api.deerapi.com/v1"
        super().__init__(key, model_name, base_url)


class JiekouAIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Jiekou.AI"

    def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1"):
        if not base_url:
            base_url = "https://api.jiekou.ai/openai/v1"
        # The OpenAI client appends "/embeddings" itself, so strip it from a full endpoint URL.
        base_url = base_url.rstrip("/").removesuffix("/embeddings")
        super().__init__(key, model_name, base_url)
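

# Sketch (hypothetical helper): the OpenAI Python client builds request URLs by
# appending the resource path (here "/embeddings") to `base_url`, so OpenAIEmbed
# subclasses expect an API root. A full endpoint such as
# "https://api.jiekou.ai/openai/v1/embeddings" can be reduced to its root before
# the client is constructed.
def _sketch_openai_root(url: str) -> str:
    root = url.rstrip("/")
    suffix = "/embeddings"
    return root[: -len(suffix)] if root.endswith(suffix) else root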


class RAGconEmbed(OpenAIEmbed):
    """
    RAGcon embedding provider, routed through a LiteLLM proxy.
    Default base URL (used when `base_url` is empty): https://connect.ragcon.com/v1
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name="text-embedding-3-small", base_url=None):
        if not base_url:
            base_url = "https://connect.ragcon.com/v1"
        super().__init__(key, model_name, base_url)


class PerplexityEmbed(Base):
    _FACTORY_NAME = "Perplexity"

    def __init__(self, key, model_name="pplx-embed-v1-0.6b", base_url="https://api.perplexity.ai"):
        if not base_url:
            base_url = "https://api.perplexity.ai"
        self.base_url = base_url.rstrip("/")
        self.api_key = key
        self.model_name = model_name
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _decode_base64_int8(b64_str):
        raw = base64.b64decode(b64_str)
        return np.frombuffer(raw, dtype=np.int8).astype(np.float32)

    def _is_contextualized(self):
        return "context" in self.model_name

    def encode(self, texts: list):
        batch_size = 512
        ress = []
        token_count = 0
        if self._is_contextualized():
            url = f"{self.base_url}/v1/contextualizedembeddings"
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                payload = {
                    "model": self.model_name,
                    "input": [[chunk] for chunk in batch],
                    "encoding_format": "base64_int8",
                }
                response = requests.post(url, headers=self.headers, json=payload, timeout=30)
                _raise_model_exception_if_failed(response)
                try:
                    res = response.json()
                    for doc in res["data"]:
                        for chunk_emb in doc["data"]:
                            ress.append(self._decode_base64_int8(chunk_emb["embedding"]))
                    token_count += res.get("usage", {}).get("total_tokens", 0)
                except Exception as _e:
                    logger.exception("PerplexityEmbed: failed to parse contextualized embedding response")
                    raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e
        else:
            url = f"{self.base_url}/v1/embeddings"
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                payload = {
                    "model": self.model_name,
                    "input": batch,
                    "encoding_format": "base64_int8",
                }
                response = requests.post(url, headers=self.headers, json=payload, timeout=30)
                _raise_model_exception_if_failed(response)
                try:
                    res = response.json()
                    for d in res["data"]:
                        ress.append(self._decode_base64_int8(d["embedding"]))
                    token_count += res.get("usage", {}).get("total_tokens", 0)
                except Exception as _e:
                    logger.exception("PerplexityEmbed: failed to parse embedding response")
                    raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e
        return np.array(ress), token_count

    def encode_queries(self, text):
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
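

# Sketch (hypothetical helpers): PerplexityEmbed requests `base64_int8` vectors
# and decodes them with NumPy. A stdlib equivalent reads the decoded bytes as
# signed 8-bit integers. The contextualized endpoint instead takes a list of
# documents, each a list of chunks, and this module sends one chunk per document.
def _sketch_decode_base64_int8(b64_str: str) -> list:
    import array
    import base64

    # Typecode "b" is a signed char, so byte 0xFE decodes to -2 rather than 254.
    return [float(v) for v in array.array("b", base64.b64decode(b64_str))]


def _sketch_contextualized_input(batch: list) -> list:
    return [[chunk] for chunk in batch]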


class NewAPIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "New API"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]