Fix: model-provider bugs (#15460)

### What problem does this PR solve?

Fix:
- Use `@` as the key separator to avoid splitting on `_` inside model_name.
- Verify api_key when adding an instance.
- Remove api_key from the list-instances response.
- Remove useless index.
- Sort providers, instances and models by name.
- Get `is_tools` from llm_factories.json

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Lynn
2026-06-02 13:24:53 +08:00
committed by GitHub
parent 0e9eeb7b88
commit 3bc5ed282e
7 changed files with 242 additions and 46 deletions

View File

@@ -335,9 +335,69 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
api_key = data["api_key"]
base_url = data.get("base_url", "")
region = data.get("region", "")
verify = data.get("verify", False)
try:
success, msg = provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region)
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, verify)
if success:
return get_result(message=msg)
else:
return get_error_data_result(message=msg)
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
@manager.route("/providers/<provider_name>/connection", methods=["POST"]) # noqa: F821
@login_required
async def verify_provider_api_key(provider_name: str = None):
"""
Verify api key.
---
tags:
- Providers
security:
- ApiKeyAuth: []
parameters:
- in: path
name: provider_name
type: string
required: true
description: Provider name.
- in: header
name: Authorization
type: string
required: true
description: Bearer token for authentication.
- in: body
name: body
description: Instance creation parameters.
required: true
schema:
type: object
required:
- api_key
properties:
api_key:
type: string
description: API key.
base_url:
type: string
description: Base URL.
responses:
200:
description: Instance created successfully.
schema:
type: object
"""
data = await request.get_json()
if not data or "api_key" not in data:
return get_error_argument_result(message="api_key is required")
base_url = data.get("base_url", "")
api_key = data["api_key"]
try:
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url)
if success:
return get_result(message=msg)
else:

View File

@@ -298,7 +298,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
target_type_records = [record for record in model_records if record.model_type == model_type_filter] if model_type_filter else model_records
model_record_map = {}
for model in target_type_records:
instance_model_key = f"{model.provider_id}_{model.instance_id}_{model.model_name}"
instance_model_key = f"{model.provider_id}@{model.instance_id}@{model.model_name}"
if model_record_map.get(instance_model_key):
model_record_map[instance_model_key].append(model)
else:
@@ -318,7 +318,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
continue
for factory_instance in factory_instances:
model_record_key = f"{factory_instance.provider_id}_{factory_instance.id}_{llm['llm_name']}"
model_record_key = f"{factory_instance.provider_id}@{factory_instance.id}@{llm['llm_name']}"
model_key_in_factory.append(model_record_key)
manual_modified_models = model_record_map.get(model_record_key, [])
active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value]
@@ -343,7 +343,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
model_records = model_record_map.get(model_record_key, [])
if not model_records:
continue
provider_id, instance_id, model_name = model_record_key.split("_")
provider_id, instance_id, model_name = model_record_key.split("@")
model_types = [model.model_type for model in model_records if model.status == ActiveStatusEnum.ACTIVE.value]
if not model_types:
continue
@@ -357,4 +357,6 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
"instance_name": instance_info_map[instance_id].instance_name if instance_info_map.get(instance_id) else ""
})
added_models.sort(key=lambda x: (x["provider_name"], x["instance_name"], x["name"]))
return True, added_models

View File

@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import logging
import asyncio
from common.constants import LLMType, ActiveStatusEnum
from common.misc_utils import get_uuid
@@ -23,6 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro
from api.db.services.tenant_model_provider_service import TenantModelProviderService
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
from api.db.services.tenant_model_service import TenantModelService
from rag.llm import EmbeddingModel, ChatModel, RerankModel
def list_providers(tenant_id: str, all_available: bool = False):
@@ -54,6 +57,7 @@ def list_providers(tenant_id: str, all_available: bool = False):
"default": factory_info.get("url", "")
}
})
providers.sort(key=lambda x: x["name"])
return True, providers
# List tenant-configured providers
@@ -76,7 +80,7 @@ def list_providers(tenant_id: str, all_available: bool = False):
"default": factory_info.get("url", "")
}
})
providers.sort(key=lambda x: x["name"])
return True, providers
@@ -117,7 +121,7 @@ def delete_provider(tenant_id: str, provider_name: str):
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
if not provider_obj:
return False, f"Provider {provider_name} not found"
instance_objs = TenantModelInstanceService.get_by_provider_id(provider_obj.id)
instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_obj.id)
if not instance_objs:
return False, f"No instances found for provider {provider_name}"
instance_ids = [instance_obj.id for instance_obj in instance_objs]
@@ -169,6 +173,7 @@ def list_provider_models(provider_name: str):
"model_types": [llm["model_type"]],
"features": None
})
models.sort(key=lambda x: x["name"])
return True, models
@@ -202,7 +207,7 @@ def show_provider_model(provider_name: str, model_name: str):
}
def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str):
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, verify: bool):
"""
Create a provider instance.
@@ -215,6 +220,7 @@ def create_provider_instance(tenant_id: str, provider_name: str, instance_name:
:param api_key: API key
:param base_url: base url
:param region: region
:param verify: verify
:return: (success, result_or_error_message)
"""
if not provider_name:
@@ -236,6 +242,10 @@ def create_provider_instance(tenant_id: str, provider_name: str, instance_name:
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key)
if same_key_instance:
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
if verify:
success, msg = await verify_api_key(provider_name, api_key, base_url)
if not success:
return False, msg
import json
extra_fields = {}
@@ -268,7 +278,6 @@ def list_provider_instances(tenant_id: str, provider_name: str):
for instance_obj in instance_objs:
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
instances.append({
"api_key": instance_obj.api_key,
"id": instance_obj.id,
"instance_name": instance_obj.instance_name,
"provider_id": provider_id,
@@ -276,7 +285,93 @@ def list_provider_instances(tenant_id: str, provider_name: str):
"status": instance_obj.status,
})
return True, instances
active_instances = [instance for instance in instances if instance["status"] == ActiveStatusEnum.ACTIVE.value]
inactive_instances = [instance for instance in instances if instance["status"] == ActiveStatusEnum.INACTIVE.value]
active_instances.sort(key=lambda x: x["instance_name"])
inactive_instances.sort(key=lambda x: x["instance_name"])
return True, active_instances + inactive_instances
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None):
    """
    Verify API key for a provider.

    Looks the provider up in FACTORY_LLM_INFOS and probes its embedding, chat
    and rerank models in the listed order, stopping at the first probe that
    succeeds. Models of any other type are skipped. Each probe is bounded by
    the LLM_TIMEOUT_SECONDS environment variable (default 10 seconds).

    :param provider_name: provider/factory name
    :param api_key: API key
    :param base_url: base url
    :return: (success, result_or_error_message); on failure the message
        concatenates the error of every probe that was attempted
    :raises AssertionError: if the provider has a model of a type for which no
        model class is registered under the provider name
    """
    if not provider_name:
        return False, "Provider name is required"

    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_name}' not found"

    factory_llms = factory_info[0]["llm"]
    if not factory_llms:
        return False, f"No models found for provider '{provider_name}'"

    # test if api key works
    chat_passed, embd_passed, rerank_passed = False, False, False
    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
    extra = {"provider": provider_name}
    msg = ""
    for llm in factory_llms:
        if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
            assert provider_name in EmbeddingModel, f"Embedding model from {provider_name} is not supported yet."
            mdl = EmbeddingModel[provider_name](api_key, llm["llm_name"], base_url=base_url)
            try:
                # The blocking encode call runs in a worker thread; the wait_for
                # coroutine must be awaited, otherwise unpacking its result raises
                # TypeError and the embedding probe can never succeed.
                arr, _ = await asyncio.wait_for(
                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
                    timeout=timeout_seconds,
                )
                if len(arr[0]) == 0:
                    raise Exception("Fail")
                embd_passed = True
            except Exception as e:
                msg += f"\nFail to access embedding model({llm['llm_name']}) using this api key." + str(e)
        elif not chat_passed and llm["model_type"] == LLMType.CHAT.value:
            assert provider_name in ChatModel, f"Chat model from {provider_name} is not supported yet."
            mdl = ChatModel[provider_name](api_key, llm["llm_name"], base_url=base_url, **extra)
            try:
                async def check_streamly():
                    # The first non-empty text chunk without an error marker is
                    # enough to show that the key is accepted.
                    async for chunk in mdl.async_chat_streamly(
                        None,
                        [{"role": "user", "content": "Hi"}],
                        {"temperature": 0.9},
                    ):
                        if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
                            return True
                    return False

                result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
                if result:
                    chat_passed = True
                else:
                    raise Exception("No valid response received")
            except Exception as e:
                msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
        elif not rerank_passed and llm["model_type"] == LLMType.RERANK.value:
            assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
            mdl = RerankModel[provider_name](api_key, llm["llm_name"], base_url=base_url)
            try:
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
                    timeout=timeout_seconds,
                )
                if len(arr) == 0 or tc == 0:
                    raise Exception("Fail")
                rerank_passed = True
                logging.debug(f"passed model rerank {llm['llm_name']}")
            except Exception as e:
                msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)

        # One working model is sufficient; discard errors from earlier probes.
        if any((embd_passed, chat_passed, rerank_passed)):
            return True, "success"

    # Nothing passed. If no probe was attempted at all (the provider only lists
    # other model types), say so instead of reporting a failure as "success".
    return False, msg or f"No chat, embedding or rerank model of provider '{provider_name}' could be used to verify this api key."
def show_provider_instance(tenant_id: str, provider_name: str, instance_name: str):
@@ -362,6 +457,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
return False, f"Provider '{provider_name}' not found"
llms = factory_info[0].get("llm", [])
models = [{"name": llm["llm_name"]} for llm in llms]
models.sort(key=lambda x: x["name"])
return True, models
# Get instance
@@ -407,8 +503,11 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
"max_tokens": extra_fields.get("max_tokens", 8192),
"status": model_info_dict["status"],
})
return True, models
active_models = [model for model in models if model["status"] == ActiveStatusEnum.ACTIVE.value]
inactive_models = [model for model in models if model["status"] == ActiveStatusEnum.INACTIVE.value]
active_models.sort(key=lambda x: x["name"])
inactive_models.sort(key=lambda x: x["name"])
return True, active_models + inactive_models
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict):
@@ -434,7 +533,7 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str
extra_fields = {"max_tokens": max_tokens}
target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
if target_model:
extra_fields.update({"is_tool": target_model[0].get("is_tool", False)})
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
extra_fields.update(extra)
TenantModelService.insert(
model_name=model_name,
@@ -506,6 +605,7 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str,
provider_id=provider_obj.id,
instance_id=instance_obj.id,
status=status,
extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)})
)
return True, None

View File

@@ -1372,9 +1372,6 @@ class TenantModelInstance(DataBaseModel):
class Meta:
db_table = "tenant_model_instance"
indexes = (
(("api_key", "provider_id"), True),
)
class TenantModel(DataBaseModel):
@@ -1731,6 +1728,17 @@ def migrate_db():
alter_db_column_type(migrator, "document", "size", BigIntegerField(default=0, index=True))
alter_db_column_type(migrator, "file", "size", BigIntegerField(default=0, index=True))
alter_db_add_column(migrator, "tenant", "ocr_id", CharField(max_length=128, null=True, help_text="default ocr model ID", index=True))
for table_name, index_name in [("tenant_model_instance", "idx_api_key_provider_id"), ("tenant_model", "idx_provider_model_instance")]:
try:
migrate(migrator.drop_index(table_name, index_name))
except (OperationalError, ProgrammingError) as ex:
msg = str(ex)
if "1091" in msg or "can't DROP" in msg.lower() or "does not exist" in msg.lower() or "already exists" in msg.lower():
pass
else:
logging.critical(f"Failed to drop index {index_name} on {table_name}: {ex}")
except Exception as ex:
logging.critical(f"Failed to drop index {index_name} on {table_name}: {ex}")
logging.disable(logging.NOTSET)
# this is after re-enabling logging to allow logging changed user emails
migrate_add_unique_email(migrator)

View File

@@ -116,7 +116,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
"llm_name": model_obj.model_name,
"api_base": extra_fields.get("base_url", ""),
"model_type": model_obj.model_type,
"is_tool": extra_fields.get("is_tool", is_tool)
"is_tools": extra_fields.get("is_tools", is_tool)
}
if api_key_payload is not None:
model_config["api_key_payload"] = api_key_payload
@@ -136,7 +136,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
"llm_name": llm_info["llm_name"],
"api_base": extra_fields.get("base_url", ""),
"model_type": llm_info["model_type"],
"is_tool": llm_info.get("is_tool", is_tool)
"is_tools": llm_info.get("is_tools", is_tool)
}
if api_key_payload is not None:
model_config["api_key_payload"] = api_key_payload

View File

@@ -20,8 +20,8 @@ package entity
type TenantModelInstance struct {
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
InstanceName string `gorm:"column:instance_name;size:128;not null" json:"instance_name"`
ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"`
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"`
ProviderID string `gorm:"column:provider_id;size:32;not null" json:"provider_id"`
APIKey string `gorm:"column:api_key;size:512;not null" json:"api_key"`
Status string `gorm:"column:status;size:32;default:'active'" json:"status"`
Extra string `gorm:"column:extra;size:512;default:'{}'" json:"extra"`
BaseModel

View File

@@ -539,7 +539,6 @@ class TenantModelInstanceStage(MigrationStage):
create_date DATETIME,
update_time BIGINT,
update_date DATETIME,
UNIQUE INDEX idx_api_key_provider_id (api_key, provider_id),
INDEX idx_provider_id (provider_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""
@@ -551,10 +550,34 @@ class TenantModelStage(MigrationStage):
"""Migrate tenant_llm to tenant_model"""
name = "tenant_model"
description = "Migrate tenant_llm to tenant_model (only status='0' records)"
description = "Migrate tenant_llm to tenant_model (status='0' records, plus status='1' for empty-llm factories)"
source_tables = ["tenant_llm", "tenant_model_provider", "tenant_model_instance"]
target_tables = ["tenant_model"]
@staticmethod
def _get_empty_llm_factories() -> list[str]:
    """Load factory names whose llm field is an empty list from conf/llm_factories.json.

    Every top-level value of the JSON document that is a list is scanned; each
    dict entry in it whose "llm" value is an empty list contributes its "name".

    :return: factory names in file order (may be empty)
    :raises OSError: if the configuration file cannot be opened
    :raises KeyError: if a matching entry has no "name" field
    """
    conf_path = os.path.join(PROJECT_BASE, "conf", "llm_factories.json")
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding when reading it.
    with open(conf_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    factories = []
    for items in data.values():
        if not isinstance(items, list):
            continue
        for item in items:
            if not isinstance(item, dict):
                continue
            llm = item.get("llm")
            if isinstance(llm, list) and not llm:
                factories.append(item["name"])
    return factories
def _build_status_condition(self) -> str:
    """Build SQL WHERE condition for status filtering.

    Rows with status '0' always match. When some factories have an empty llm
    list, rows with status '1' belonging to those factories match as well.

    :return: a boolean SQL expression over the ``tl`` (tenant_llm) alias
    """
    empty_factories = self._get_empty_llm_factories()
    if not empty_factories:
        return "tl.status = '0'"
    # The names are inlined as SQL string literals rather than bound
    # parameters, so double any embedded single quote to keep the statement
    # well-formed.
    quoted_names = ", ".join("'" + str(name).replace("'", "''") + "'" for name in empty_factories)
    return f"(tl.status = '0' OR (tl.status = '1' AND tl.llm_factory IN ({quoted_names})))"
# Return the current time from time.time() truncated to an int.
def current_timestamp(self) -> int:
return int(time.time())
@@ -598,24 +621,26 @@ class TenantModelStage(MigrationStage):
logger.info("Target table 'tenant_model' does not exist, will create")
return True
# Check if there's data to migrate (only status='0' records)
status_condition = self._build_status_condition()
# Check if there's data to migrate
cursor = self.db.execute_sql(
"SELECT COUNT(*) FROM ("
" SELECT tl.id "
" FROM tenant_llm tl "
" INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
" INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
" WHERE tl.status = '0' "
" AND NOT EXISTS ("
" SELECT 1 FROM tenant_model tm "
" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
" )"
") AS distinct_records"
f"SELECT COUNT(*) FROM ("
f" SELECT tl.id "
f" FROM tenant_llm tl "
f" INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
f" INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
f" WHERE {status_condition} "
f" AND NOT EXISTS ("
f" SELECT 1 FROM tenant_model tm "
f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
f" )"
f") AS distinct_records"
)
count = cursor.fetchone()[0]
if count == 0:
logger.info("No new data to migrate from tenant_llm to tenant_model (status='0' only)")
logger.info("No new data to migrate from tenant_llm to tenant_model")
return False
logger.info(f"Found {count} rows to migrate from tenant_llm to tenant_model")
@@ -652,19 +677,21 @@ class TenantModelStage(MigrationStage):
logger.info("[CREATE TABLE ONLY] Target table created/verified, skipping data migration")
return 0, self.target_tables
status_condition = self._build_status_condition()
# Get records from tenant_llm with provider_id and instance_id lookup
# Only migrate records where status='0'
# Migrate status='0' records, plus status='1' for empty-llm factories
cursor = self.db.execute_sql(
"SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, "
" tl.model_type, tl.status "
"FROM tenant_llm tl "
"INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
"INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
"WHERE tl.status = '0' "
"AND NOT EXISTS ("
" SELECT 1 FROM tenant_model tm "
" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
")"
f"SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, "
f" tl.model_type, tl.status "
f"FROM tenant_llm tl "
f"INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
f"INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
f"WHERE {status_condition} "
f"AND NOT EXISTS ("
f" SELECT 1 FROM tenant_model tm "
f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
f")"
)
records = cursor.fetchall()
@@ -727,7 +754,6 @@ class TenantModelStage(MigrationStage):
update_time BIGINT,
update_date DATETIME,
INDEX idx_instance_id (instance_id),
UNIQUE INDEX idx_provider_model_instance (provider_id, model_name, instance_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""
self.db.execute_sql(create_sql)