mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
## Summary Resolves all 93 open alerts at https://github.com/infiniflow/ragflow/security/code-scanning by rule: | Rule | Count | Treatment | |------|-------|-----------| | py/clear-text-logging-sensitive-data | 23 | Real fix — log scrubbing | | go/path-injection | 15 | Real fix where possible, suppression with rationale | | go/request-forgery | 8 | Suppression with rationale (operator-controlled URLs) | | go/clear-text-logging | 10 | Real fix — log scrubbing | | go/unsafe-quoting | 5 | Real fix — escape or refactor | | go/sql-injection | 3 | Real fix — orderby whitelist + CodeQL comment | | go/uncontrolled-allocation-size | 2 | Real fix — cap to 1024 | | go/incorrect-integer-conversion | 3 | Real fix — ParseInt + range check | | go/insecure-hostkeycallback | 1 | Real fix — known_hosts file | | go/disabled-certificate-check | 2 | Suppression with rationale | | go/command-injection | 1 | Suppression (sanitized via shq()) | | go/email-injection | 1 | Suppression with rationale | | go/cookie-httponly-not-set | 1 | Suppression (SPA bootstrap) | | js/stack-trace-exposure | 1 | Real fix — generic client message | | js/prototype-pollution-utility | 1 | Real fix — reject __proto__/constructor/prototype | | py/weak-sensitive-data-hashing | 1 | Real fix — MD5 → SHA-256 | | py/incomplete-url-substring-sanitization | 3 | Real fix — urlparse(hostname) | | py/paramiko-missing-host-key-validation | 1 | Real fix — load_system_host_keys + RejectPolicy | | cpp/integer-multiplication-cast-to-long | 2 | Real fix — cast to size_t | ## Real fixes (with measurable security improvement) **SSH host key verification (Go + Python)** Replace `InsecureIgnoreHostKey()` / `paramiko.AutoAddPolicy()` with proper host key verification against a known_hosts file (configurable via `SSH_KNOWN_HOSTS` env / `known_hosts` config field; fail-closed when unset). Loads `~/.ssh/known_hosts` first via `load_system_host_keys()` so existing setups keep working. **SQL injection in `user_canvas`** Add `userCanvasOrderableColumns` whitelist + `userCanvasOrderClause` helper. Both `GetList()` and `ListByTenantIDs()` now route the user-supplied `orderby` query param through the helper, defaulting to `create_time` on miss. **SQL injection in `pipeline_operation_log`** Existing whitelist documented via CodeQL comment. **Real SQL injection in `infinity/chunk.go:931`** Escape `'` → `''` on user-controlled `questionText` before splicing into `filter_fulltext(...)` SQL filter. **Real SQL injection in `elasticsearch/sql.go:75`** Defense-in-depth escape on tokenizer output before splicing into `MATCH(...)`. **Python code injection in `result_protocol.go`** Replace raw JSON literal embedding into Python/JS expressions with base64 + `json.loads` / `JSON.parse(Buffer.from(..., 'base64').toString('utf8'))`. Eliminates both the unsafe-quoting sink and the brittleness of mixing JSON true/false/null with Python syntax. **URL substring check bypass in `embedding_model.py`** Replace `if "dashscope-intl.aliyuncs.com" in u` with `urlparse(u).hostname == "dashscope-intl.aliyuncs.com"` so a base_url like `https://attacker.example/?u=dashscope-intl.aliyuncs.com` cannot bypass the routing. **Prototype pollution in `setNestedValue` (TS)** Reject `__proto__`/`constructor`/`prototype` keys before any assignment. **Integer overflow** - scrypt params via `ParseInt` + non-positive check (`internal/common/password.go`) - `topN` and `n` caps to 1024 (retrieval_service.go, dataset.go) - `nalloc*statesize` cast to `size_t` (cpp/re2/onepass.cc) **Cookie httponly** Set explicitly with rationale: this is the OAuth bootstrap cookie intentionally read by the SPA. **Stack trace exposure** Replace `error.message` in HTTP 500 response with generic `"internal error"`; full error still logged server-side via `console.error`. **Weak hashing** MD5 → SHA-256 for deterministic `conv_id` derivation (`conversation_service.py`). **Log scrubbing** Remove or redact user-controlled / sensitive content from clear-text logs across 8 ingestion parsers, `llm_service.py` ×11, `tenant_llm_service.py` ×7, `misc_utils.py` ×4, `redis_conn.py` ×10, `conftest.py` ×4, `init_data.py`, `dataset_api_service.py`, `generator.py`, `mysql_migration.py`, `cli.go`, `user_command.go`, `pdf_parser.go`. Most patterns converted to parameterized logging (`logging.info("...: %d", n)`) or static messages. ## CodeQL suppressions (each with rationale) For alerts where the data flow is genuinely safe but CodeQL can't see the context — operator-controlled URLs, sanitized inputs, etc. — I added `// codeql[go/<rule>] <rationale>` annotations rather than dismissing them, so future readers can audit the rationale inline: - `internal/agent/component/invoke.go:135` — Invoke is a generic canvas HTTP client - `internal/service/langfuse.go` ×2 — host is per-tenant operator config - `internal/service/file.go:1184` — already SSRF-guarded by `assertURLSafe` - `internal/utility/mcp_client.go` ×3 — already `AssertURLSafe` + IP-pinned - `internal/entity/models/bedrock.go` — sigv4-signed request, URL can't be tampered - `internal/service/deep_researcher.go:269` — `callback` is SSE display string, not SQL - `internal/engine/infinity/chunk.go:346` — UUIDs can't contain `'` (RFC 4122) - `internal/cli/common_command.go` ×2 — CLI trusts operator-configured URL - `internal/utility/smtp.go:194` — msg is server-built, not user form input - `internal/entity/models/*` ×14 (path-injection) — audio file paths are caller-supplied ## Test plan - ✅ All 13 modified Go packages build cleanly - ✅ 663 tests pass across `internal/agent/sandbox`, `internal/common`, `internal/agent/component`, `internal/engine/infinity`, `internal/dao` - ✅ All 11 modified Python files parse via `ast.parse` - ✅ TypeScript `tsc --noEmit` clean on the modified `use-provider-fields.tsx` - ✅ `node --check` clean on the modified JS file 🤖 Generated with [Claude Code](https://claude.com/claude-code)
561 lines
23 KiB
Python
561 lines
23 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import os
|
|
import json
|
|
import logging
|
|
from peewee import IntegrityError
|
|
from langfuse import Langfuse
|
|
from common import settings
|
|
from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, OPENDATALOADER_DEFAULT_CONFIG, OPENDATALOADER_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
|
|
from api.db.db_models import DB, LLMFactories, TenantLLM
|
|
from api.db.services.common_service import CommonService
|
|
from api.db.services.langfuse_service import TenantLangfuseService
|
|
from api.db.services.user_service import TenantService
|
|
|
|
|
|
class LLMFactoriesService(CommonService):
|
|
model = LLMFactories
|
|
|
|
|
|
class TenantLLMService(CommonService):
|
|
model = TenantLLM
|
|
|
|
@staticmethod
|
|
def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]:
|
|
if not raw_api_key:
|
|
return raw_api_key, None, None
|
|
|
|
try:
|
|
parsed = json.loads(raw_api_key)
|
|
except Exception:
|
|
return raw_api_key, None, None
|
|
|
|
if not isinstance(parsed, dict):
|
|
return raw_api_key, None, None
|
|
|
|
is_tools = bool(parsed["is_tools"]) if "is_tools" in parsed else None
|
|
if set(parsed.keys()) <= {"api_key", "is_tools"}:
|
|
return parsed.get("api_key", ""), is_tools, None
|
|
|
|
return parsed.get("api_key", raw_api_key), is_tools, raw_api_key
|
|
|
|
@staticmethod
|
|
def _encode_api_key_config(raw_api_key: str, is_tools: bool | None) -> str:
|
|
if is_tools is None:
|
|
return raw_api_key
|
|
|
|
try:
|
|
parsed = json.loads(raw_api_key or "{}")
|
|
except Exception:
|
|
parsed = None
|
|
|
|
if isinstance(parsed, dict):
|
|
payload = dict(parsed)
|
|
payload["is_tools"] = bool(is_tools)
|
|
return json.dumps(payload)
|
|
|
|
return json.dumps({"api_key": raw_api_key or "", "is_tools": bool(is_tools)})
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def get_api_key(cls, tenant_id, model_name, model_type=None):
|
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
|
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
|
|
query_kwargs = {"tenant_id": tenant_id, "llm_name": mdlnm}
|
|
if model_type_val is not None:
|
|
query_kwargs["model_type"] = model_type_val
|
|
if not fid:
|
|
objs = cls.query(**query_kwargs)
|
|
else:
|
|
objs = cls.query(**query_kwargs, llm_factory=fid)
|
|
|
|
if (not objs) and fid:
|
|
if fid == "LocalAI":
|
|
mdlnm += "___LocalAI"
|
|
elif fid == "HuggingFace":
|
|
mdlnm += "___HuggingFace"
|
|
elif fid == "OpenAI-API-Compatible":
|
|
mdlnm += "___OpenAI-API"
|
|
elif fid == "VLLM":
|
|
mdlnm += "___VLLM"
|
|
query_kwargs["llm_name"] = mdlnm
|
|
objs = cls.query(**query_kwargs, llm_factory=fid)
|
|
if not objs:
|
|
return None
|
|
return objs[0]
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def get_my_llms(cls, tenant_id):
|
|
fields = [cls.model.id, cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
|
|
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
|
|
|
return list(objs)
|
|
|
|
@staticmethod
|
|
def split_model_name_and_factory(model_name):
|
|
arr = model_name.split("@")
|
|
if len(arr) < 2:
|
|
return model_name, None
|
|
if len(arr) > 2:
|
|
return "@".join(arr[0:-1]), arr[-1]
|
|
|
|
# model name must be xxx@yyy
|
|
try:
|
|
model_factories = settings.FACTORY_LLM_INFOS
|
|
model_providers = set([f["name"] for f in model_factories])
|
|
if arr[-1] not in model_providers:
|
|
return model_name, None
|
|
return arr[0], arr[-1]
|
|
except Exception as e:
|
|
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
|
return model_name, None
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
|
from api.db.services.llm_service import LLMService
|
|
|
|
e, tenant = TenantService.get_by_id(tenant_id)
|
|
if not e:
|
|
raise LookupError("Tenant not found")
|
|
|
|
if llm_type == LLMType.EMBEDDING.value:
|
|
mdlnm = tenant.embd_id if not llm_name else llm_name
|
|
elif llm_type == LLMType.SPEECH2TEXT.value:
|
|
mdlnm = tenant.asr_id if not llm_name else llm_name
|
|
elif llm_type == LLMType.IMAGE2TEXT.value:
|
|
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
|
elif llm_type == LLMType.CHAT.value:
|
|
mdlnm = tenant.llm_id if not llm_name else llm_name
|
|
elif llm_type == LLMType.RERANK:
|
|
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
|
elif llm_type == LLMType.TTS:
|
|
mdlnm = tenant.tts_id if not llm_name else llm_name
|
|
elif llm_type == LLMType.OCR:
|
|
if not llm_name:
|
|
raise LookupError("OCR model name is required")
|
|
mdlnm = llm_name
|
|
else:
|
|
assert False, "LLM type error"
|
|
|
|
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
|
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
if not model_config: # for some cases seems fid mismatch
|
|
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
|
|
if model_config:
|
|
model_config = model_config.to_dict()
|
|
api_key, is_tools, api_key_payload = cls._decode_api_key_config(model_config.get("api_key", ""))
|
|
model_config["api_key"] = api_key
|
|
if api_key_payload is not None:
|
|
model_config["api_key_payload"] = api_key_payload
|
|
if is_tools is not None:
|
|
model_config["is_tools"] = is_tools
|
|
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
|
|
embedding_cfg = settings.EMBEDDING_CFG
|
|
model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
|
else:
|
|
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
|
|
|
|
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
|
if not llm and fid: # for some cases seems fid mismatch
|
|
llm = LLMService.query(llm_name=mdlnm)
|
|
if "is_tools" not in model_config and llm:
|
|
model_config["is_tools"] = llm[0].is_tools
|
|
return model_config
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
|
|
if not model_config:
|
|
raise LookupError("Model config is required")
|
|
from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel
|
|
|
|
kwargs.update({"provider": model_config["llm_factory"]})
|
|
api_key = model_config.get("api_key_payload", model_config["api_key"])
|
|
if model_config["model_type"] == LLMType.EMBEDDING.value:
|
|
if model_config["llm_factory"] not in EmbeddingModel:
|
|
logging.error("Factory not in embedding model. Supported factories: %s", list(EmbeddingModel.keys()))
|
|
return None
|
|
return EmbeddingModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"])
|
|
|
|
elif model_config["model_type"] == LLMType.RERANK.value:
|
|
if model_config["llm_factory"] not in RerankModel:
|
|
logging.error("Factory not in rerank model. Supported factories: %s", list(RerankModel.keys()))
|
|
return None
|
|
return RerankModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"])
|
|
|
|
elif model_config["model_type"] == LLMType.IMAGE2TEXT.value:
|
|
if model_config["llm_factory"] not in CvModel:
|
|
logging.error("Factory not in cv model. Supported factories: %s", list(CvModel.keys()))
|
|
return None
|
|
return CvModel[model_config["llm_factory"]](api_key, model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
|
|
|
elif model_config["model_type"] == LLMType.CHAT.value:
|
|
if model_config["llm_factory"] not in ChatModel:
|
|
logging.error("Factory not in chat model. Supported factories: %s", list(ChatModel.keys()))
|
|
return None
|
|
return ChatModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
|
|
|
elif model_config["model_type"] == LLMType.SPEECH2TEXT.value:
|
|
if model_config["llm_factory"] not in Seq2txtModel:
|
|
logging.error("Factory not in speech2text model. Supported factories: %s", list(Seq2txtModel.keys()))
|
|
return None
|
|
return Seq2txtModel[model_config["llm_factory"]](key=api_key, model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
|
elif model_config["model_type"] == LLMType.TTS.value:
|
|
if model_config["llm_factory"] not in TTSModel:
|
|
logging.error("Factory not in tts model. Supported factories: %s", list(TTSModel.keys()))
|
|
return None
|
|
return TTSModel[model_config["llm_factory"]](
|
|
api_key,
|
|
model_config["llm_name"],
|
|
base_url=model_config["api_base"],
|
|
)
|
|
|
|
elif model_config["model_type"] == LLMType.OCR.value:
|
|
if model_config["llm_factory"] not in OcrModel:
|
|
logging.error("Factory not in ocr model. Supported factories: %s", list(OcrModel.keys()))
|
|
return None
|
|
return OcrModel[model_config["llm_factory"]](
|
|
key=api_key,
|
|
model_name=model_config["llm_name"],
|
|
base_url=model_config.get("api_base", ""),
|
|
**kwargs,
|
|
)
|
|
|
|
return None
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
|
e, tenant = TenantService.get_by_id(tenant_id)
|
|
if not e:
|
|
logging.error(f"Tenant not found: {tenant_id}")
|
|
return 0
|
|
|
|
llm_map = {
|
|
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
|
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
|
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
|
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
|
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
|
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
|
LLMType.OCR.value: llm_name,
|
|
}
|
|
|
|
mdlnm = llm_map.get(llm_type)
|
|
if mdlnm is None:
|
|
logging.error(f"LLM type error: {llm_type}")
|
|
return 0
|
|
|
|
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
|
|
try:
|
|
num = (
|
|
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
|
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
|
.execute()
|
|
)
|
|
except Exception:
|
|
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
|
return 0
|
|
|
|
return num
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def increase_usage_by_id(cls, tenant_model_id: int, used_tokens: int):
|
|
try:
|
|
update_cnt = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens).where(cls.model.id == tenant_model_id).execute()
|
|
except Exception as e:
|
|
logging.exception(f"TenantLLMService.increase_usage got exception {e}, Failed to update used_tokens for tenant_model_id {tenant_model_id}")
|
|
return 0
|
|
return update_cnt
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def get_openai_models(cls):
|
|
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
|
return list(objs)
|
|
|
|
@classmethod
|
|
def _collect_mineru_env_config(cls) -> dict | None:
|
|
cfg = MINERU_DEFAULT_CONFIG
|
|
found = False
|
|
for key in MINERU_ENV_KEYS:
|
|
val = os.environ.get(key)
|
|
if val:
|
|
found = True
|
|
cfg[key] = val
|
|
return cfg if found else None
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def ensure_mineru_from_env(cls, tenant_id: str) -> str | None:
|
|
"""
|
|
Ensure a MinerU OCR model exists for the tenant if env variables are present.
|
|
Return the existing or newly created llm_name, or None if env not set.
|
|
"""
|
|
cfg = cls._collect_mineru_env_config()
|
|
if not cfg:
|
|
return None
|
|
|
|
saved_mineru_models = cls.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)
|
|
|
|
def _parse_api_key(raw: str) -> dict:
|
|
try:
|
|
return json.loads(raw or "{}")
|
|
except Exception:
|
|
return {}
|
|
|
|
for item in saved_mineru_models:
|
|
api_cfg = _parse_api_key(item.api_key)
|
|
normalized = {k: api_cfg.get(k, MINERU_DEFAULT_CONFIG.get(k)) for k in MINERU_ENV_KEYS}
|
|
if normalized == cfg:
|
|
return item.llm_name
|
|
|
|
used_names = {item.llm_name for item in saved_mineru_models}
|
|
idx = 1
|
|
base_name = "mineru-from-env"
|
|
while True:
|
|
candidate = f"{base_name}-{idx}"
|
|
if candidate in used_names:
|
|
idx += 1
|
|
continue
|
|
|
|
try:
|
|
cls.save(
|
|
tenant_id=tenant_id,
|
|
llm_factory="MinerU",
|
|
llm_name=candidate,
|
|
model_type=LLMType.OCR.value,
|
|
api_key=json.dumps(cfg),
|
|
api_base="",
|
|
max_tokens=0,
|
|
)
|
|
return candidate
|
|
except IntegrityError:
|
|
logging.warning("MinerU env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
|
|
used_names.add(candidate)
|
|
idx += 1
|
|
continue
|
|
|
|
@classmethod
|
|
def _collect_paddleocr_env_config(cls) -> dict | None:
|
|
cfg = PADDLEOCR_DEFAULT_CONFIG
|
|
found = False
|
|
for key in PADDLEOCR_ENV_KEYS:
|
|
val = os.environ.get(key)
|
|
if val:
|
|
found = True
|
|
cfg[key] = val
|
|
return cfg if found else None
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def ensure_paddleocr_from_env(cls, tenant_id: str) -> str | None:
|
|
"""
|
|
Ensure a PaddleOCR model exists for the tenant if env variables are present.
|
|
Return the existing or newly created llm_name, or None if env not set.
|
|
"""
|
|
cfg = cls._collect_paddleocr_env_config()
|
|
if not cfg:
|
|
return None
|
|
|
|
saved_paddleocr_models = cls.query(tenant_id=tenant_id, llm_factory="PaddleOCR", model_type=LLMType.OCR.value)
|
|
|
|
def _parse_api_key(raw: str) -> dict:
|
|
try:
|
|
return json.loads(raw or "{}")
|
|
except Exception:
|
|
return {}
|
|
|
|
for item in saved_paddleocr_models:
|
|
api_cfg = _parse_api_key(item.api_key)
|
|
normalized = {k: api_cfg.get(k, PADDLEOCR_DEFAULT_CONFIG.get(k)) for k in PADDLEOCR_ENV_KEYS}
|
|
if normalized == cfg:
|
|
return item.llm_name
|
|
|
|
used_names = {item.llm_name for item in saved_paddleocr_models}
|
|
idx = 1
|
|
base_name = "paddleocr-from-env"
|
|
while True:
|
|
candidate = f"{base_name}-{idx}"
|
|
if candidate in used_names:
|
|
idx += 1
|
|
continue
|
|
|
|
try:
|
|
cls.save(
|
|
tenant_id=tenant_id,
|
|
llm_factory="PaddleOCR",
|
|
llm_name=candidate,
|
|
model_type=LLMType.OCR.value,
|
|
api_key=json.dumps(cfg),
|
|
api_base="",
|
|
max_tokens=0,
|
|
)
|
|
return candidate
|
|
except IntegrityError:
|
|
logging.warning("PaddleOCR env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
|
|
used_names.add(candidate)
|
|
idx += 1
|
|
continue
|
|
|
|
@classmethod
|
|
def _collect_opendataloader_env_config(cls) -> dict | None:
|
|
cfg = dict(OPENDATALOADER_DEFAULT_CONFIG)
|
|
found = False
|
|
for key in OPENDATALOADER_ENV_KEYS:
|
|
val = os.environ.get(key)
|
|
if val:
|
|
found = True
|
|
cfg[key] = val
|
|
return cfg if found else None
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def ensure_opendataloader_from_env(cls, tenant_id: str) -> str | None:
|
|
"""
|
|
Ensure an OpenDataLoader OCR model exists for the tenant if env variables are present.
|
|
Return the existing or newly created llm_name, or None if env not set.
|
|
"""
|
|
cfg = cls._collect_opendataloader_env_config()
|
|
if not cfg:
|
|
return None
|
|
|
|
saved_models = cls.query(tenant_id=tenant_id, llm_factory="OpenDataLoader", model_type=LLMType.OCR.value)
|
|
|
|
def _parse_api_key(raw: str) -> dict:
|
|
try:
|
|
return json.loads(raw or "{}")
|
|
except Exception:
|
|
return {}
|
|
|
|
for item in saved_models:
|
|
api_cfg = _parse_api_key(item.api_key)
|
|
normalized = {k: api_cfg.get(k, OPENDATALOADER_DEFAULT_CONFIG.get(k)) for k in OPENDATALOADER_ENV_KEYS}
|
|
if normalized == cfg:
|
|
return item.llm_name
|
|
|
|
used_names = {item.llm_name for item in saved_models}
|
|
idx = 1
|
|
base_name = "opendataloader-from-env"
|
|
while True:
|
|
candidate = f"{base_name}-{idx}"
|
|
if candidate in used_names:
|
|
idx += 1
|
|
continue
|
|
try:
|
|
cls.save(
|
|
tenant_id=tenant_id,
|
|
llm_factory="OpenDataLoader",
|
|
llm_name=candidate,
|
|
model_type=LLMType.OCR.value,
|
|
api_key=json.dumps(cfg),
|
|
api_base="",
|
|
max_tokens=0,
|
|
)
|
|
return candidate
|
|
except IntegrityError:
|
|
logging.warning("OpenDataLoader env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
|
|
used_names.add(candidate)
|
|
idx += 1
|
|
continue
|
|
|
|
@classmethod
|
|
@DB.connection_context()
|
|
def delete_by_tenant_id(cls, tenant_id):
|
|
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
|
|
|
@staticmethod
|
|
def llm_id2llm_type(llm_id: str) -> str | None:
|
|
from api.db.services.llm_service import LLMService
|
|
|
|
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
|
llm_factories = settings.FACTORY_LLM_INFOS
|
|
for llm_factory in llm_factories:
|
|
for llm in llm_factory["llm"]:
|
|
if llm_id == llm["llm_name"]:
|
|
return llm["model_type"].split(",")[-1]
|
|
|
|
for llm in LLMService.query(llm_name=llm_id):
|
|
return llm.model_type
|
|
|
|
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
|
if llm:
|
|
return llm.model_type
|
|
for llm in TenantLLMService.query(llm_name=llm_id):
|
|
return llm.model_type
|
|
return None
|
|
|
|
|
|
class LLM4Tenant:
|
|
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
|
self.trace_context = kwargs.pop("trace_context", None) or {}
|
|
self.langfuse_session_id = kwargs.pop("langfuse_session_id", None)
|
|
self.tenant_id = tenant_id
|
|
self.llm_name = model_config["llm_name"]
|
|
self.model_config = model_config
|
|
self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs)
|
|
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["model_type"], model_config["llm_name"])
|
|
self.max_length = model_config.get("max_tokens", 8192)
|
|
|
|
self.is_tools = model_config.get("is_tools", False)
|
|
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
|
|
|
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
|
self.langfuse = None
|
|
if langfuse_keys:
|
|
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
|
try:
|
|
if langfuse.auth_check():
|
|
self.langfuse = langfuse
|
|
if not self.trace_context:
|
|
trace_id = self.langfuse.create_trace_id()
|
|
self.trace_context = {"trace_id": trace_id}
|
|
except Exception:
|
|
# Skip langfuse tracing if connection fails
|
|
pass
|
|
|
|
def close(self):
|
|
"""Release resources held by this LLM4Tenant instance.
|
|
|
|
This method should be called when the instance is no longer needed
|
|
to properly release resources such as:
|
|
- Langfuse tracing client (flush and shutdown)
|
|
- Underlying model instance resources (HTTP sessions, etc.)
|
|
"""
|
|
# Flush and shutdown Langfuse client if it was initialized
|
|
if self.langfuse:
|
|
try:
|
|
self.langfuse.flush()
|
|
if hasattr(self.langfuse, 'shutdown'):
|
|
self.langfuse.shutdown()
|
|
except Exception:
|
|
# Ignore errors during cleanup
|
|
pass
|
|
finally:
|
|
self.langfuse = None
|
|
|
|
# Release underlying model instance if it has a close method
|
|
if self.mdl and hasattr(self.mdl, 'close') and callable(getattr(self.mdl, 'close')):
|
|
try:
|
|
self.mdl.close()
|
|
except Exception:
|
|
# Ignore errors during cleanup
|
|
pass
|