mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat: support multi-type factory models (#15893)
### What problem does this PR solve? Support factory models with multiple model types, so visual chat models can be exposed as both image2text and chat while preserving the database model-type-per-record design. This also updates the SILICONFLOW model list and adds a helper script to refresh SiliconFlow models from the provider API. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -53,6 +53,13 @@ def _to_int(v, default=500):
|
||||
return default
|
||||
|
||||
|
||||
def _factory_model_types(llm: dict) -> list[str]:
|
||||
model_type = llm.get("model_type")
|
||||
if isinstance(model_type, list):
|
||||
return model_type
|
||||
return [model_type] if model_type else []
|
||||
|
||||
|
||||
def _get_model_info(tenant_id: str, default_model: str, model_type: str):
|
||||
"""
|
||||
Parse a composite model string (modelName@instanceName@providerName or modelName@providerName)
|
||||
@@ -147,7 +154,7 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str):
|
||||
return None
|
||||
|
||||
# Check if the model_type matches
|
||||
if target_llm[0].get("model_type") != model_type:
|
||||
if model_type not in _factory_model_types(target_llm[0]):
|
||||
logging.warning(f"Model '{model_name}' isn't a {model_type} model")
|
||||
return None
|
||||
|
||||
@@ -217,7 +224,7 @@ def _check_model_available(tenant_id: str, provider_name: str, instance_name: st
|
||||
return False, f"Model '{model_name}' not found for provider '{provider_name}'"
|
||||
|
||||
if target_llm:
|
||||
if target_llm[0].get("model_type") != model_type:
|
||||
if model_type not in _factory_model_types(target_llm[0]):
|
||||
return False, f"Model '{model_name}' isn't a {model_type} model"
|
||||
|
||||
return True, None
|
||||
@@ -347,7 +354,8 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
if not factory_instances:
|
||||
continue
|
||||
for llm in factory["llm"]:
|
||||
if model_type_filter and llm["model_type"] != model_type_filter:
|
||||
factory_model_types = _factory_model_types(llm)
|
||||
if model_type_filter and model_type_filter not in factory_model_types:
|
||||
continue
|
||||
|
||||
for factory_instance in factory_instances:
|
||||
@@ -356,7 +364,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
manual_modified_models = model_record_map.get(model_record_key, [])
|
||||
active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value]
|
||||
inactive_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.INACTIVE.value]
|
||||
model_types = list(set([llm["model_type"]] + active_model_types) - set(inactive_model_types))
|
||||
model_types = list(set(factory_model_types + active_model_types) - set(inactive_model_types))
|
||||
if not model_types:
|
||||
continue
|
||||
|
||||
|
||||
@@ -35,6 +35,13 @@ def _to_int(v, default=500):
|
||||
return default
|
||||
|
||||
|
||||
def _factory_model_types(llm: dict) -> list[str]:
|
||||
model_type = llm.get("model_type")
|
||||
if isinstance(model_type, list):
|
||||
return model_type
|
||||
return [model_type] if model_type else []
|
||||
|
||||
|
||||
def _normalize_provider_base_url(provider_name: str, base_url: str | None):
|
||||
if provider_name != "VLLM" or not base_url:
|
||||
return base_url
|
||||
@@ -66,9 +73,9 @@ def list_providers(tenant_id: str, all_available: bool = False):
|
||||
if factory_info["name"] in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"]:
|
||||
continue
|
||||
model_types = sorted(set(
|
||||
llm["model_type"]
|
||||
model_type
|
||||
for llm in factory_info.get("llm", [])
|
||||
if llm.get("model_type")
|
||||
for model_type in _factory_model_types(llm)
|
||||
))
|
||||
provider = {
|
||||
"model_types": model_types,
|
||||
@@ -94,9 +101,9 @@ def list_providers(tenant_id: str, all_available: bool = False):
|
||||
if name not in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"] and factory_info_mapping.get(name):
|
||||
factory_info = factory_info_mapping[name]
|
||||
model_types = sorted(set(
|
||||
llm["model_type"]
|
||||
model_type
|
||||
for llm in factory_info.get("llm", [])
|
||||
if llm.get("model_type")
|
||||
for model_type in _factory_model_types(llm)
|
||||
))
|
||||
provider = {
|
||||
"model_types": model_types,
|
||||
@@ -196,7 +203,7 @@ async def list_provider_models(provider_name: str, api_key: str = None, base_url
|
||||
static_llms = [{
|
||||
"name": llm["name"],
|
||||
"max_tokens": llm["max_tokens"],
|
||||
"model_types": [llm["model_type"]],
|
||||
"model_types": _factory_model_types(llm),
|
||||
"features": (
|
||||
llm.get("features")
|
||||
if llm.get("features") is not None
|
||||
@@ -246,11 +253,9 @@ def show_provider_model(provider_name: str, model_name: str):
|
||||
return True, {
|
||||
"name": llm_info["name"],
|
||||
"max_tokens": llm_info["max_tokens"],
|
||||
"model_types": [llm_info["model_type"]],
|
||||
"model_types": _factory_model_types(llm_info),
|
||||
"thinking": None,
|
||||
"model_type_map": {
|
||||
llm_info["model_type"]: True
|
||||
}
|
||||
"model_type_map": {model_type: True for model_type in _factory_model_types(llm_info)}
|
||||
}
|
||||
|
||||
|
||||
@@ -414,7 +419,8 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
api_key = {"yiyan_ak": api_key, "yiyan_sk": ""}
|
||||
api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
|
||||
for llm in factory_llms:
|
||||
if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
model_types = _factory_model_types(llm)
|
||||
if not embd_passed and LLMType.EMBEDDING.value in model_types:
|
||||
assert provider_name in EmbeddingModel, f"Embedding model from {provider_name} is not supported yet."
|
||||
mdl = EmbeddingModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
@@ -432,7 +438,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access embedding model({llm['llm_name']}) using this api key." + str(e)
|
||||
elif not chat_passed and llm["model_type"] == LLMType.CHAT.value:
|
||||
elif not chat_passed and LLMType.CHAT.value in model_types:
|
||||
assert provider_name in ChatModel, f"Chat model from {provider_name} is not supported yet."
|
||||
mdl = ChatModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url, **extra)
|
||||
try:
|
||||
@@ -458,7 +464,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
|
||||
elif not rerank_passed and llm["model_type"] == LLMType.RERANK.value:
|
||||
elif not rerank_passed and LLMType.RERANK.value in model_types:
|
||||
assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
|
||||
mdl = RerankModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
@@ -477,7 +483,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
|
||||
elif not ocr_passed and llm["model_type"] == LLMType.OCR.value:
|
||||
elif not ocr_passed and LLMType.OCR.value in model_types:
|
||||
assert provider_name in OcrModel, f"OCR model from {provider_name} is not supported yet."
|
||||
mdl = OcrModel[provider_name](key=api_key_str, model_name=llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
@@ -617,7 +623,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
|
||||
for llm in llms:
|
||||
models.append({
|
||||
"name": llm["llm_name"],
|
||||
"model_type": [llm["model_type"]] + model_info_map.get(llm["llm_name"], {}).get("model_type", []),
|
||||
"model_type": _factory_model_types(llm) + model_info_map.get(llm["llm_name"], {}).get("model_type", []),
|
||||
"max_tokens": llm.get("max_tokens"),
|
||||
"status": model_info_map.get(llm["llm_name"], {}).get("status", "active"),
|
||||
})
|
||||
@@ -657,7 +663,7 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str
|
||||
|
||||
for _type in model_type:
|
||||
extra_fields = {"max_tokens": max_tokens}
|
||||
target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
|
||||
target_model = [llm for llm in llms if _type in _factory_model_types(llm) and llm["llm_name"] == model_name]
|
||||
if target_model:
|
||||
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
|
||||
if extra:
|
||||
@@ -725,15 +731,16 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str,
|
||||
if not target_llm:
|
||||
return False, f"provider {provider_name} model {model_name} not found"
|
||||
|
||||
TenantModelService.insert(
|
||||
id=get_uuid(),
|
||||
model_name=model_name,
|
||||
model_type=target_llm[0]["model_type"],
|
||||
provider_id=provider_obj.id,
|
||||
instance_id=instance_obj.id,
|
||||
status=status,
|
||||
extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)})
|
||||
)
|
||||
for model_type in _factory_model_types(target_llm[0]):
|
||||
TenantModelService.insert(
|
||||
id=get_uuid(),
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
provider_id=provider_obj.id,
|
||||
instance_id=instance_obj.id,
|
||||
status=status,
|
||||
extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)})
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
@@ -27,6 +27,11 @@ from api.db.services.tenant_model_service import TenantModelService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _factory_model_types(llm: dict) -> list[str]:
|
||||
model_type = llm.get("model_type")
|
||||
if isinstance(model_type, list):
|
||||
return model_type
|
||||
return [model_type] if model_type else []
|
||||
def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]:
|
||||
if not raw_api_key:
|
||||
return raw_api_key, None, None
|
||||
@@ -235,12 +240,14 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
|
||||
if not llm_list:
|
||||
raise LookupError(f"Model config not found: {model_name}")
|
||||
llm_info = llm_list[0]
|
||||
if model_type_val not in _factory_model_types(llm_info):
|
||||
raise LookupError(f"Model {model_name} is not a {model_type_val} model.")
|
||||
model_config = {
|
||||
"llm_factory": provider_obj.provider_name,
|
||||
"api_key": api_key,
|
||||
"llm_name": llm_info["llm_name"],
|
||||
"api_base": extra_fields.get("base_url", ""),
|
||||
"model_type": llm_info["model_type"],
|
||||
"model_type": model_type_val,
|
||||
"is_tools": llm_info.get("is_tools", is_tool)
|
||||
}
|
||||
if api_key_payload is not None:
|
||||
@@ -284,7 +291,7 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
|
||||
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
|
||||
if not llm_list:
|
||||
raise LookupError(f"Model {pure_model_name} not found for model {model_name}.")
|
||||
return [llm_list[0]["model_type"]]
|
||||
return _factory_model_types(llm_list[0])
|
||||
return [model_obj.model_type for model_obj in model_objs]
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user