@manager.route("/providers/<provider_name>/instances/<instance_name>/models", methods=["PUT"])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def update_instance_models(tenant_id: str, provider_name: str, instance_name: str):
    """
    Batch update model_type for models in instance.
    ---
    tags:
      - Providers
    security:
      - ApiKeyAuth: []
    parameters:
      - in: path
        name: provider_name
        type: string
        required: true
        description: Provider name.
      - in: path
        name: instance_name
        type: string
        required: true
        description: Instance name.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Model details.
        required: true
        schema:
          type: object
          required:
            - model_name
            - model_type
          properties:
            model_name:
              type: array
              items:
                type: string
              description: Names of the models to update.
            model_type:
              type: array
              items:
                type: string
              description: Model types to apply to every listed model.
    """
    data = await request.get_json()
    # A JSON body that is not an object (e.g. a list) would pass the "in"
    # checks below and then fail on data["model_name"], so reject it up front.
    if not isinstance(data, dict) or "model_name" not in data or "model_type" not in data:
        return get_error_argument_result(message="model_name and model_type are required")

    model_name = data["model_name"]
    model_type = data["model_type"]
    # The service builds sets from both values; a bare string would be split
    # into single characters, so insist on real lists of non-empty strings.
    for field, value in (("model_name", model_name), ("model_type", model_type)):
        is_string_list = isinstance(value, list) and all(isinstance(item, str) and item for item in value)
        if not value or not is_string_list:
            return get_error_argument_result(message=f"{field} must be a non-empty list of strings")

    try:
        success, msg = provider_api_service.update_instance_models(
            tenant_id, provider_name, instance_name, model_name, model_type
        )
    except Exception as e:
        logging.exception(e)
        return get_error_data_result(message="Internal server error")

    if success:
        return get_result(message=msg)
    return get_error_data_result(message=msg)
def _hybrid_get_instance_models(provider_name: str, instance_id: str):
    """Merge the factory model catalogue with the per-instance model records.

    Args:
        provider_name: Name used to look the provider up in FACTORY_LLM_INFOS.
        instance_id: Instance whose records are read via
            TenantModelService.get_models_by_instance_id.

    Returns:
        (True, models) where models is a list of dicts with the keys "name",
        "model_type" (a list), "max_tokens" and "status"; or (False, message)
        when the provider is not present in FACTORY_LLM_INFOS.
    """
    # List all models from the LLM dictionary for this provider
    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_name}' not found"

    # Get model records for this instance from tenant_model table
    model_records = TenantModelService.get_models_by_instance_id(instance_id)

    # model_name -> {"status", "model_type", "extra"} for usable records, and
    # model_name -> [model_type, ...] for records flagged as unsupported.
    model_info_map: dict = {}
    model_unsupported_type_map: dict = {}
    for record in model_records:
        # status is stored as a plain string, so compare against the enum's
        # value; comparing against the Enum member itself never matches.
        if record.status == ActiveStatusEnum.UNSUPPORTED.value:
            model_unsupported_type_map.setdefault(record.model_name, []).append(record.model_type)
            continue
        known = model_info_map.get(record.model_name)
        if known is None:
            model_info_map[record.model_name] = {
                "status": record.status,
                "model_type": [record.model_type],
                "extra": record.extra,
            }
        else:
            known["model_type"].append(record.model_type)

    def _effective_types(model_name, candidate_types):
        # De-duplicate while keeping first-seen order, then drop the types
        # explicitly marked unsupported. Always a list so it serialises to JSON.
        blocked = set(model_unsupported_type_map.get(model_name, []))
        return [t for t in dict.fromkeys(candidate_types) if t not in blocked]

    models = []
    for llm in factory_info[0].get("llm", []):
        llm_name = llm["llm_name"]
        recorded = model_info_map.get(llm_name, {})
        models.append({
            "name": llm_name,
            "model_type": _effective_types(llm_name, _factory_model_types(llm) + recorded.get("model_type", [])),
            "max_tokens": llm.get("max_tokens"),
            "status": recorded.get("status", "active"),
        })

    # Models that only exist as instance records (not in the factory list).
    factory_models = {m["name"] for m in models}
    for model_name, model_info_dict in model_info_map.items():
        if model_name in factory_models:
            continue
        extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {}
        models.append({
            "name": model_name,
            "model_type": _effective_types(model_name, model_info_dict["model_type"]),
            "max_tokens": extra_fields.get("max_tokens", 8192),
            "status": model_info_dict["status"],
        })
    return True, models
def update_instance_models(tenant_id: str, provider_name: str, instance_name: str, model_names: list, model_types: list):
    """Set the model types of several models of one provider instance.

    For every name in model_names the current types (as reported by
    _hybrid_get_instance_models) are compared with model_types; missing types
    are passed to TenantModelService.upsert_model_type under "add" and
    surplus ones under "delete".

    Args:
        tenant_id: Tenant used to resolve the provider.
        provider_name: Provider name.
        instance_name: Instance name under that provider.
        model_names: Model names to update; a single string is treated as a
            one-element list.
        model_types: Types every listed model should end up with; a single
            string is treated as a one-element list.

    Returns:
        (True, "success") on success, otherwise (False, error message).
    """
    # set("chat") would yield single characters, so wrap bare strings first
    # (add_model_to_instance likewise accepts str | list[str] for model_type).
    if isinstance(model_names, str):
        model_names = [model_names]
    if isinstance(model_types, str):
        model_types = [model_types]
    if not model_names or not model_types:
        return False, "model_name and model_type are required"

    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"
    instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"

    found, models = _hybrid_get_instance_models(provider_name, instance_obj.id)
    if not found:
        # On failure the second element is the error message.
        return False, models

    model_info_map = {model["name"]: model for model in models}
    not_exist_models = set(model_names) - set(model_info_map.keys())
    if not_exist_models:
        return False, f"Models {not_exist_models} not found for provider '{provider_name}' and instance '{instance_name}'"

    requested_types = set(model_types)
    # dict.fromkeys drops repeated names so each model is written only once;
    # model_info_map is a snapshot and would be stale for a second pass.
    for model_name in dict.fromkeys(model_names):
        current_types = set(model_info_map[model_name]["model_type"])
        TenantModelService.upsert_model_type(
            provider_obj.id,
            instance_obj.id,
            model_name,
            {
                "add": list(requested_types - current_types),
                "delete": list(current_types - requested_types),
            },
        )
    return True, "success"
b/api/db/joint_services/tenant_model_service.py index c12f3b764a..2bc9a1c103 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -214,6 +214,8 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum if model_obj: if model_obj.status == ActiveStatusEnum.INACTIVE.value: raise LookupError(f"Model {model_name} is disabled.") + if model_obj.status == ActiveStatusEnum.UNSUPPORTED.value: + raise LookupError(f"Model {model_name} cannot be used as {model_type_val} model.") model_extra = json.loads(model_obj.extra) if model_obj.extra else {} model_config = { @@ -281,6 +283,7 @@ def get_model_type_by_name(tenant_id: str, model_name: str): if not instance_obj: raise LookupError(f"Instance {instance_name} not found for model {model_name}.") model_objs = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, pure_model_name) + types_in_json = [] if not model_objs: extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {} region = extra_fields.get("region", "default") @@ -294,8 +297,8 @@ def get_model_type_by_name(tenant_id: str, model_name: str): llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name] if not llm_list: raise LookupError(f"Model {pure_model_name} not found for model {model_name}.") - return _factory_model_types(llm_list[0]) - return [model_obj.model_type for model_obj in model_objs] + types_in_json = _factory_model_types(llm_list[0]) + return list(set(types_in_json + [model_obj.model_type for model_obj in model_objs if model_obj.status != ActiveStatusEnum.UNSUPPORTED.value]) - {model_obj.model_type for model_obj in model_objs if model_obj.status == ActiveStatusEnum.UNSUPPORTED.value}) def delete_models_by_instance_ids(instance_ids: list[str]): @@ -329,6 +332,7 @@ def get_models_by_tenant_and_provider_and_model_type(tenant_id: str, provider_na results = [] for inst in instances: 
@classmethod
@DB.connection_context()
def upsert_model_type(cls, provider_id: str, instance_id: str, model_name: str, operation: dict):
    """Apply type additions and removals to one model of an instance.

    Args:
        provider_id: Provider the model belongs to.
        instance_id: Instance the model belongs to.
        model_name: Name of the model whose type records are changed.
        operation: Dict with optional "add" and "delete" lists of model types.
            Added types get a record with the model's current status;
            deleted types get a record with the "unsupported" status.

    Returns:
        The number of type operations applied.
    """
    unsupported = ActiveStatusEnum.UNSUPPORTED.value
    types_to_add = operation.get("add", [])
    types_to_delete = operation.get("delete", [])

    # Materialise once so the emptiness test and later passes share one result.
    records = list(
        cls.model.select().where(
            cls.model.provider_id == provider_id,
            cls.model.instance_id == instance_id,
            cls.model.model_name == model_name,
        )
    )

    def _insert(model_type, extra, status=None):
        # status is omitted when None so the insert falls back to whatever
        # default the model defines for the column.
        fields = {
            "model_name": model_name,
            "provider_id": provider_id,
            "instance_id": instance_id,
            "model_type": model_type,
            "extra": extra,
        }
        if status is not None:
            fields["status"] = status
        cls.insert(**fields)

    if not records:
        # No record for this model yet: create one row per requested change.
        for _type in types_to_add:
            _insert(_type, "{}")
        for _type in types_to_delete:
            # Store the enum's string value, not the Enum member.
            _insert(_type, "{}", unsupported)
        return len(types_to_add) + len(types_to_delete)

    # Copy status/extra from a record that is not flagged unsupported;
    # status is a string, so it must be compared against the enum's value.
    template = next((record for record in records if record.status != unsupported), None)
    extra_fields = template.extra if template else "{}"
    model_status = template.status if template else ActiveStatusEnum.ACTIVE.value
    type_record_map = {record.model_type: record for record in records}

    def _apply(model_type, status):
        # Reuse the existing row for this type when there is one.
        existing = type_record_map.get(model_type)
        if existing:
            cls.update_by_id(existing.id, {"status": status})
        else:
            _insert(model_type, extra_fields, status)

    operated_cnt = 0
    for _type in types_to_add:
        _apply(_type, model_status)
        operated_cnt += 1
    for _type in types_to_delete:
        _apply(_type, unsupported)
        operated_cnt += 1
    return operated_cnt