diff --git a/api/apps/restful_apis/provider_api.py b/api/apps/restful_apis/provider_api.py index 6a9806b47b..d27e6042ec 100644 --- a/api/apps/restful_apis/provider_api.py +++ b/api/apps/restful_apis/provider_api.py @@ -335,9 +335,69 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N api_key = data["api_key"] base_url = data.get("base_url", "") region = data.get("region", "") + verify = data.get("verify", False) try: - success, msg = provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region) + success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, verify) + if success: + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//connection", methods=["POST"]) # noqa: F821 +@login_required +async def verify_provider_api_key(provider_name: str = None): + """ + Verify api key. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Instance creation parameters. + required: true + schema: + type: object + required: + - api_key + properties: + api_key: + type: string + description: API key. + base_url: + type: string + description: Base URL. + responses: + 200: + description: Instance created successfully. + schema: + type: object + """ + data = await request.get_json() + if not data or "api_key" not in data: + return get_error_argument_result(message="api_key is required") + + base_url = data.get("base_url", "") + api_key = data["api_key"] + try: + success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url) if success: return get_result(message=msg) else: diff --git a/api/apps/services/models_api_service.py b/api/apps/services/models_api_service.py index 043957bd8d..bb853825c0 100644 --- a/api/apps/services/models_api_service.py +++ b/api/apps/services/models_api_service.py @@ -298,7 +298,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): target_type_records = [record for record in model_records if record.model_type == model_type_filter] if model_type_filter else model_records model_record_map = {} for model in target_type_records: - instance_model_key = f"{model.provider_id}_{model.instance_id}_{model.model_name}" + instance_model_key = f"{model.provider_id}@{model.instance_id}@{model.model_name}" if model_record_map.get(instance_model_key): model_record_map[instance_model_key].append(model) else: @@ -318,7 +318,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): continue for factory_instance in factory_instances: - model_record_key = f"{factory_instance.provider_id}_{factory_instance.id}_{llm['llm_name']}" + model_record_key = f"{factory_instance.provider_id}@{factory_instance.id}@{llm['llm_name']}" model_key_in_factory.append(model_record_key) manual_modified_models = model_record_map.get(model_record_key, []) active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value] @@ -343,7 +343,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): model_records = model_record_map.get(model_record_key, []) if not model_records: continue - provider_id, instance_id, model_name = model_record_key.split("_") + provider_id, instance_id, model_name = model_record_key.split("@") model_types = [model.model_type for model in model_records if model.status == ActiveStatusEnum.ACTIVE.value] if not model_types: continue @@ -357,4 +357,6 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): "instance_name": instance_info_map[instance_id].instance_name if instance_info_map.get(instance_id) else "" }) + added_models.sort(key=lambda x: (x["provider_name"], x["instance_name"], x["name"])) + return True, added_models diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py index c094674d4d..214757d142 100644 --- a/api/apps/services/provider_api_service.py +++ b/api/apps/services/provider_api_service.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os import json import logging +import asyncio from common.constants import LLMType, ActiveStatusEnum from common.misc_utils import get_uuid @@ -23,6 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro from api.db.services.tenant_model_provider_service import TenantModelProviderService from api.db.services.tenant_model_instance_service import TenantModelInstanceService from api.db.services.tenant_model_service import TenantModelService +from rag.llm import EmbeddingModel, ChatModel, RerankModel def list_providers(tenant_id: str, all_available: bool = False): @@ -54,6 +57,7 @@ def list_providers(tenant_id: str, all_available: bool = False): "default": factory_info.get("url", "") } }) + providers.sort(key=lambda x: x["name"]) return True, providers # List tenant-configured providers @@ -76,7 +80,7 @@ def list_providers(tenant_id: str, all_available: bool = False): "default": factory_info.get("url", "") } }) - + providers.sort(key=lambda x: x["name"]) return True, providers @@ -117,7 +121,7 @@ def delete_provider(tenant_id: str, provider_name: str): provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) if not provider_obj: return False, f"Provider {provider_name} not found" - instance_objs = TenantModelInstanceService.get_by_provider_id(provider_obj.id) + instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_obj.id) if not instance_objs: return False, f"No instances found for provider {provider_name}" instance_ids = [instance_obj.id for instance_obj in instance_objs] @@ -169,6 +173,7 @@ def list_provider_models(provider_name: str): "model_types": [llm["model_type"]], "features": None }) + models.sort(key=lambda x: x["name"]) return True, models @@ -202,7 +207,7 @@ def show_provider_model(provider_name: str, model_name: str): } -def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str): +async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, verify: bool): """ Create a provider instance. @@ -215,6 +220,7 @@ def create_provider_instance(tenant_id: str, provider_name: str, instance_name: :param api_key: API key :param base_url: base url :param region: region + :param verify: verify :return: (success, result_or_error_message) """ if not provider_name: @@ -236,6 +242,10 @@ def create_provider_instance(tenant_id: str, provider_name: str, instance_name: same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key) if same_key_instance: return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}" + if verify: + success, msg = await verify_api_key(provider_name, api_key, base_url) + if not success: + return False, msg import json extra_fields = {} @@ -268,7 +278,6 @@ def list_provider_instances(tenant_id: str, provider_name: str): for instance_obj in instance_objs: extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {} instances.append({ - "api_key": instance_obj.api_key, "id": instance_obj.id, "instance_name": instance_obj.instance_name, "provider_id": provider_id, @@ -276,7 +285,93 @@ def list_provider_instances(tenant_id: str, provider_name: str): "status": instance_obj.status, }) - return True, instances + active_instances = [instance for instance in instances if instance["status"] == ActiveStatusEnum.ACTIVE.value] + inactive_instances = [instance for instance in instances if instance["status"] == ActiveStatusEnum.INACTIVE.value] + active_instances.sort(key=lambda x: x["instance_name"]) + inactive_instances.sort(key=lambda x: x["instance_name"]) + + return True, active_instances + inactive_instances + + +async def verify_api_key(provider_name: str, api_key: str, base_url: str=None): + """ + Verify API key for a provider. + + :param provider_name: provider/factory name + :param api_key: API key + :param base_url: base url + :return: (success, result_or_error_message) + """ + if not provider_name: + return False, "Provider name is required" + + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + + factory_llms = factory_info[0]["llm"] + if not factory_llms: + return False, f"No models found for provider '{provider_name}'" + + # test if api key works + chat_passed, embd_passed, rerank_passed = False, False, False + timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) + extra = {"provider": provider_name} + msg = "" + for llm in factory_llms: + if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value: + assert provider_name in EmbeddingModel, f"Embedding model from {provider_name} is not supported yet." + mdl = EmbeddingModel[provider_name](api_key, llm["llm_name"], base_url=base_url) + try: + arr, tc = asyncio.wait_for( + asyncio.to_thread(mdl.encode, ["Test if the api key is available"]), + timeout=timeout_seconds, + ) + if len(arr[0]) == 0: + raise Exception("Fail") + embd_passed = True + except Exception as e: + msg += f"\nFail to access embedding model({llm['llm_name']}) using this api key." + str(e) + elif not chat_passed and llm["model_type"] == LLMType.CHAT.value: + assert provider_name in ChatModel, f"Chat model from {provider_name} is not supported yet." + mdl = ChatModel[provider_name](api_key, llm["llm_name"], base_url=base_url, **extra) + try: + async def check_streamly(): + async for chunk in mdl.async_chat_streamly( + None, + [{"role": "user", "content": "Hi"}], + {"temperature": 0.9}, + ): + if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0: + return True + return False + + result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds) + if result: + chat_passed = True + else: + raise Exception("No valid response received") + except Exception as e: + msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e) + elif not rerank_passed and llm["model_type"] == LLMType.RERANK.value: + assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet." + mdl = RerankModel[provider_name](api_key, llm["llm_name"], base_url=base_url) + try: + arr, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]), + timeout=timeout_seconds, + ) + if len(arr) == 0 or tc == 0: + raise Exception("Fail") + rerank_passed = True + logging.debug(f"passed model rerank {llm['llm_name']}") + except Exception as e: + msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e) + if any([embd_passed, chat_passed, rerank_passed]): + msg = "" + break + + return any([embd_passed, chat_passed, rerank_passed]), msg or "success" def show_provider_instance(tenant_id: str, provider_name: str, instance_name: str): @@ -362,6 +457,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, return False, f"Provider '{provider_name}' not found" llms = factory_info[0].get("llm", []) models = [{"name": llm["llm_name"]} for llm in llms] + models.sort(key=lambda x: x["name"]) return True, models # Get instance @@ -407,8 +503,11 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, "max_tokens": extra_fields.get("max_tokens", 8192), "status": model_info_dict["status"], }) - - return True, models + active_models = [model for model in models if model["status"] == ActiveStatusEnum.ACTIVE.value] + inactive_models = [model for model in models if model["status"] == ActiveStatusEnum.INACTIVE.value] + active_models.sort(key=lambda x: x["name"]) + inactive_models.sort(key=lambda x: x["name"]) + return True, active_models + inactive_models def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict): @@ -434,7 +533,7 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str extra_fields = {"max_tokens": max_tokens} target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name] if target_model: - extra_fields.update({"is_tool": target_model[0].get("is_tool", False)}) + extra_fields.update({"is_tools": target_model[0].get("is_tools", False)}) extra_fields.update(extra) TenantModelService.insert( model_name=model_name, @@ -506,6 +605,7 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str, provider_id=provider_obj.id, instance_id=instance_obj.id, status=status, + extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)}) ) return True, None diff --git a/api/db/db_models.py b/api/db/db_models.py index e676ae97cd..a81b32c22b 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1372,9 +1372,6 @@ class TenantModelInstance(DataBaseModel): class Meta: db_table = "tenant_model_instance" - indexes = ( - (("api_key", "provider_id"), True), - ) class TenantModel(DataBaseModel): @@ -1731,6 +1728,17 @@ def migrate_db(): alter_db_column_type(migrator, "document", "size", BigIntegerField(default=0, index=True)) alter_db_column_type(migrator, "file", "size", BigIntegerField(default=0, index=True)) alter_db_add_column(migrator, "tenant", "ocr_id", CharField(max_length=128, null=True, help_text="default ocr model ID", index=True)) + for table_name, index_name in [("tenant_model_instance", "idx_api_key_provider_id"), ("tenant_model", "idx_provider_model_instance")]: + try: + migrate(migrator.drop_index(table_name, index_name)) + except (OperationalError, ProgrammingError) as ex: + msg = str(ex) + if "1091" in msg or "can't DROP" in msg.lower() or "does not exist" in msg.lower() or "already exists" in msg.lower(): + pass + else: + logging.critical(f"Failed to drop index {index_name} on {table_name}: {ex}") + except Exception as ex: + logging.critical(f"Failed to drop index {index_name} on {table_name}: {ex}") logging.disable(logging.NOTSET) # this is after re-enabling logging to allow logging changed user emails migrate_add_unique_email(migrator) diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index b52a1377bf..28fc959da6 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -116,7 +116,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum "llm_name": model_obj.model_name, "api_base": extra_fields.get("base_url", ""), "model_type": model_obj.model_type, - "is_tool": extra_fields.get("is_tool", is_tool) + "is_tools": extra_fields.get("is_tools", is_tool) } if api_key_payload is not None: model_config["api_key_payload"] = api_key_payload @@ -136,7 +136,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum "llm_name": llm_info["llm_name"], "api_base": extra_fields.get("base_url", ""), "model_type": llm_info["model_type"], - "is_tool": llm_info.get("is_tool", is_tool) + "is_tools": llm_info.get("is_tools", is_tool) } if api_key_payload is not None: model_config["api_key_payload"] = api_key_payload diff --git a/internal/entity/tenant_model_instance.go b/internal/entity/tenant_model_instance.go index 7563f7bd8d..03e7c78b23 100644 --- a/internal/entity/tenant_model_instance.go +++ b/internal/entity/tenant_model_instance.go @@ -20,8 +20,8 @@ package entity type TenantModelInstance struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` InstanceName string `gorm:"column:instance_name;size:128;not null" json:"instance_name"` - ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"` - APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"` + ProviderID string `gorm:"column:provider_id;size:32;not null" json:"provider_id"` + APIKey string `gorm:"column:api_key;size:512;not null" json:"api_key"` Status string `gorm:"column:status;size:32;default:'active'" json:"status"` Extra string `gorm:"column:extra;size:512;default:'{}'" json:"extra"` BaseModel diff --git a/tools/scripts/mysql_migration.py b/tools/scripts/mysql_migration.py index 5079592523..8fe4390daa 100644 --- a/tools/scripts/mysql_migration.py +++ b/tools/scripts/mysql_migration.py @@ -539,7 +539,6 @@ class TenantModelInstanceStage(MigrationStage): create_date DATETIME, update_time BIGINT, update_date DATETIME, - UNIQUE INDEX idx_api_key_provider_id (api_key, provider_id), INDEX idx_provider_id (provider_id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 """ @@ -551,10 +550,34 @@ class TenantModelStage(MigrationStage): """Migrate tenant_llm to tenant_model""" name = "tenant_model" - description = "Migrate tenant_llm to tenant_model (only status='0' records)" + description = "Migrate tenant_llm to tenant_model (status='0' records, plus status='1' for empty-llm factories)" source_tables = ["tenant_llm", "tenant_model_provider", "tenant_model_instance"] target_tables = ["tenant_model"] + @staticmethod + def _get_empty_llm_factories() -> list[str]: + """Load factory names whose llm field is an empty list from conf/llm_factories.json""" + conf_path = os.path.join(PROJECT_BASE, "conf", "llm_factories.json") + with open(conf_path, "r") as f: + data = json.load(f) + factories = [] + for key, items in data.items(): + if isinstance(items, list): + for item in items: + if isinstance(item, dict): + llm = item.get("llm") + if isinstance(llm, list) and len(llm) == 0: + factories.append(item["name"]) + return factories + + def _build_status_condition(self) -> str: + """Build SQL WHERE condition for status filtering""" + empty_factories = self._get_empty_llm_factories() + if empty_factories: + placeholders = ", ".join(f"'{f}'" for f in empty_factories) + return f"(tl.status = '0' OR (tl.status = '1' AND tl.llm_factory IN ({placeholders})))" + return "tl.status = '0'" + def current_timestamp(self) -> int: return int(time.time()) @@ -598,24 +621,26 @@ class TenantModelStage(MigrationStage): logger.info("Target table 'tenant_model' does not exist, will create") return True - # Check if there's data to migrate (only status='0' records) + status_condition = self._build_status_condition() + + # Check if there's data to migrate cursor = self.db.execute_sql( - "SELECT COUNT(*) FROM (" - " SELECT tl.id " - " FROM tenant_llm tl " - " INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory " - " INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key " - " WHERE tl.status = '0' " - " AND NOT EXISTS (" - " SELECT 1 FROM tenant_model tm " - " WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id" - " )" - ") AS distinct_records" + f"SELECT COUNT(*) FROM (" + f" SELECT tl.id " + f" FROM tenant_llm tl " + f" INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory " + f" INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key " + f" WHERE {status_condition} " + f" AND NOT EXISTS (" + f" SELECT 1 FROM tenant_model tm " + f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id" + f" )" + f") AS distinct_records" ) count = cursor.fetchone()[0] if count == 0: - logger.info("No new data to migrate from tenant_llm to tenant_model (status='0' only)") + logger.info("No new data to migrate from tenant_llm to tenant_model") return False logger.info(f"Found {count} rows to migrate from tenant_llm to tenant_model") @@ -652,19 +677,21 @@ class TenantModelStage(MigrationStage): logger.info("[CREATE TABLE ONLY] Target table created/verified, skipping data migration") return 0, self.target_tables + status_condition = self._build_status_condition() + # Get records from tenant_llm with provider_id and instance_id lookup - # Only migrate records where status='0' + # Migrate status='0' records, plus status='1' for empty-llm factories cursor = self.db.execute_sql( - "SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, " - " tl.model_type, tl.status " - "FROM tenant_llm tl " - "INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory " - "INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key " - "WHERE tl.status = '0' " - "AND NOT EXISTS (" - " SELECT 1 FROM tenant_model tm " - " WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id" - ")" + f"SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, " + f" tl.model_type, tl.status " + f"FROM tenant_llm tl " + f"INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory " + f"INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key " + f"WHERE {status_condition} " + f"AND NOT EXISTS (" + f" SELECT 1 FROM tenant_model tm " + f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id" + f")" ) records = cursor.fetchall() @@ -727,7 +754,6 @@ class TenantModelStage(MigrationStage): update_time BIGINT, update_date DATETIME, INDEX idx_instance_id (instance_id), - UNIQUE INDEX idx_provider_model_instance (provider_id, model_name, instance_id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 """ self.db.execute_sql(create_sql)