feat: support multi-type factory models (#15893)

### What problem does this PR solve?
Support factory models that declare multiple model types, so that visual
chat models can be exposed as both `image2text` and `chat` while preserving
the database's one-model-type-per-record design.

This also updates the SILICONFLOW model list and adds a helper script to
refresh SiliconFlow models from the provider API.

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
buua436
2026-06-10 15:35:21 +08:00
committed by GitHub
parent 478c9846a1
commit dcf623d60d
4 changed files with 897 additions and 514 deletions

View File

@@ -53,6 +53,13 @@ def _to_int(v, default=500):
return default
def _factory_model_types(llm: dict) -> list[str]:
model_type = llm.get("model_type")
if isinstance(model_type, list):
return model_type
return [model_type] if model_type else []
def _get_model_info(tenant_id: str, default_model: str, model_type: str):
"""
Parse a composite model string (modelName@instanceName@providerName or modelName@providerName)
@@ -147,7 +154,7 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str):
return None
# Check if the model_type matches
if target_llm[0].get("model_type") != model_type:
if model_type not in _factory_model_types(target_llm[0]):
logging.warning(f"Model '{model_name}' isn't a {model_type} model")
return None
@@ -217,7 +224,7 @@ def _check_model_available(tenant_id: str, provider_name: str, instance_name: st
return False, f"Model '{model_name}' not found for provider '{provider_name}'"
if target_llm:
if target_llm[0].get("model_type") != model_type:
if model_type not in _factory_model_types(target_llm[0]):
return False, f"Model '{model_name}' isn't a {model_type} model"
return True, None
@@ -347,7 +354,8 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
if not factory_instances:
continue
for llm in factory["llm"]:
if model_type_filter and llm["model_type"] != model_type_filter:
factory_model_types = _factory_model_types(llm)
if model_type_filter and model_type_filter not in factory_model_types:
continue
for factory_instance in factory_instances:
@@ -356,7 +364,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
manual_modified_models = model_record_map.get(model_record_key, [])
active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value]
inactive_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.INACTIVE.value]
model_types = list(set([llm["model_type"]] + active_model_types) - set(inactive_model_types))
model_types = list(set(factory_model_types + active_model_types) - set(inactive_model_types))
if not model_types:
continue

View File

@@ -35,6 +35,13 @@ def _to_int(v, default=500):
return default
def _factory_model_types(llm: dict) -> list[str]:
model_type = llm.get("model_type")
if isinstance(model_type, list):
return model_type
return [model_type] if model_type else []
def _normalize_provider_base_url(provider_name: str, base_url: str | None):
if provider_name != "VLLM" or not base_url:
return base_url
@@ -66,9 +73,9 @@ def list_providers(tenant_id: str, all_available: bool = False):
if factory_info["name"] in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"]:
continue
model_types = sorted(set(
llm["model_type"]
model_type
for llm in factory_info.get("llm", [])
if llm.get("model_type")
for model_type in _factory_model_types(llm)
))
provider = {
"model_types": model_types,
@@ -94,9 +101,9 @@ def list_providers(tenant_id: str, all_available: bool = False):
if name not in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"] and factory_info_mapping.get(name):
factory_info = factory_info_mapping[name]
model_types = sorted(set(
llm["model_type"]
model_type
for llm in factory_info.get("llm", [])
if llm.get("model_type")
for model_type in _factory_model_types(llm)
))
provider = {
"model_types": model_types,
@@ -196,7 +203,7 @@ async def list_provider_models(provider_name: str, api_key: str = None, base_url
static_llms = [{
"name": llm["name"],
"max_tokens": llm["max_tokens"],
"model_types": [llm["model_type"]],
"model_types": _factory_model_types(llm),
"features": (
llm.get("features")
if llm.get("features") is not None
@@ -246,11 +253,9 @@ def show_provider_model(provider_name: str, model_name: str):
return True, {
"name": llm_info["name"],
"max_tokens": llm_info["max_tokens"],
"model_types": [llm_info["model_type"]],
"model_types": _factory_model_types(llm_info),
"thinking": None,
"model_type_map": {
llm_info["model_type"]: True
}
"model_type_map": {model_type: True for model_type in _factory_model_types(llm_info)}
}
@@ -414,7 +419,8 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
api_key = {"yiyan_ak": api_key, "yiyan_sk": ""}
api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
for llm in factory_llms:
if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
model_types = _factory_model_types(llm)
if not embd_passed and LLMType.EMBEDDING.value in model_types:
assert provider_name in EmbeddingModel, f"Embedding model from {provider_name} is not supported yet."
mdl = EmbeddingModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
try:
@@ -432,7 +438,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
llm["llm_name"],
)
msg += f"\nFail to access embedding model({llm['llm_name']}) using this api key." + str(e)
elif not chat_passed and llm["model_type"] == LLMType.CHAT.value:
elif not chat_passed and LLMType.CHAT.value in model_types:
assert provider_name in ChatModel, f"Chat model from {provider_name} is not supported yet."
mdl = ChatModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url, **extra)
try:
@@ -458,7 +464,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
llm["llm_name"],
)
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
elif not rerank_passed and llm["model_type"] == LLMType.RERANK.value:
elif not rerank_passed and LLMType.RERANK.value in model_types:
assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
mdl = RerankModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
try:
@@ -477,7 +483,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
llm["llm_name"],
)
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
elif not ocr_passed and llm["model_type"] == LLMType.OCR.value:
elif not ocr_passed and LLMType.OCR.value in model_types:
assert provider_name in OcrModel, f"OCR model from {provider_name} is not supported yet."
mdl = OcrModel[provider_name](key=api_key_str, model_name=llm["llm_name"], base_url=base_url)
try:
@@ -617,7 +623,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
for llm in llms:
models.append({
"name": llm["llm_name"],
"model_type": [llm["model_type"]] + model_info_map.get(llm["llm_name"], {}).get("model_type", []),
"model_type": _factory_model_types(llm) + model_info_map.get(llm["llm_name"], {}).get("model_type", []),
"max_tokens": llm.get("max_tokens"),
"status": model_info_map.get(llm["llm_name"], {}).get("status", "active"),
})
@@ -657,7 +663,7 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str
for _type in model_type:
extra_fields = {"max_tokens": max_tokens}
target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
target_model = [llm for llm in llms if _type in _factory_model_types(llm) and llm["llm_name"] == model_name]
if target_model:
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
if extra:
@@ -725,15 +731,16 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str,
if not target_llm:
return False, f"provider {provider_name} model {model_name} not found"
TenantModelService.insert(
id=get_uuid(),
model_name=model_name,
model_type=target_llm[0]["model_type"],
provider_id=provider_obj.id,
instance_id=instance_obj.id,
status=status,
extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)})
)
for model_type in _factory_model_types(target_llm[0]):
TenantModelService.insert(
id=get_uuid(),
model_name=model_name,
model_type=model_type,
provider_id=provider_obj.id,
instance_id=instance_obj.id,
status=status,
extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)})
)
return True, None

View File

@@ -27,6 +27,11 @@ from api.db.services.tenant_model_service import TenantModelService
logger = logging.getLogger(__name__)
def _factory_model_types(llm: dict) -> list[str]:
model_type = llm.get("model_type")
if isinstance(model_type, list):
return model_type
return [model_type] if model_type else []
def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]:
if not raw_api_key:
return raw_api_key, None, None
@@ -235,12 +240,14 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
if not llm_list:
raise LookupError(f"Model config not found: {model_name}")
llm_info = llm_list[0]
if model_type_val not in _factory_model_types(llm_info):
raise LookupError(f"Model {model_name} is not a {model_type_val} model.")
model_config = {
"llm_factory": provider_obj.provider_name,
"api_key": api_key,
"llm_name": llm_info["llm_name"],
"api_base": extra_fields.get("base_url", ""),
"model_type": llm_info["model_type"],
"model_type": model_type_val,
"is_tools": llm_info.get("is_tools", is_tool)
}
if api_key_payload is not None:
@@ -284,7 +291,7 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
if not llm_list:
raise LookupError(f"Model {pure_model_name} not found for model {model_name}.")
return [llm_list[0]["model_type"]]
return _factory_model_types(llm_list[0])
return [model_obj.model_type for model_obj in model_objs]

File diff suppressed because it is too large Load Diff