Feat: model provider (#16028)

### What problem does this PR solve?

Feat:
- Allow upserting model_type for the models of a provider instance

Fix:
- Allow creating an instance with a duplicate api_key

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Lynn
2026-06-15 19:10:33 +08:00
committed by GitHub
parent ba93ac3bd7
commit 47495c1f6a
6 changed files with 216 additions and 51 deletions

View File

@@ -630,6 +630,66 @@ def list_instance_models(tenant_id: str = None, provider_name: str = None, insta
return get_error_data_result(message="Internal server error")
@manager.route("/providers/<provider_name>/instances/<instance_name>/models", methods=["PUT"])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def update_instance_models(tenant_id: str, provider_name: str, instance_name: str):
    """
    Batch update model_type for models in instance.
    ---
    tags:
      - Providers
    security:
      - ApiKeyAuth: []
    parameters:
      - in: path
        name: provider_name
        type: string
        required: true
        description: Provider name.
      - in: path
        name: instance_name
        type: string
        required: true
        description: Instance name.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Model details.
        required: true
        schema:
          type: object
          required:
            - model_name
            - model_type
          properties:
            model_name:
              type: array
              items:
                type: string
              description: Names of the models to update.
            model_type:
              type: array
              items:
                type: string
              description: Model types every listed model should end up with.
    """
    data = await request.get_json()
    # A JSON body that is not an object (e.g. a list or string) cannot carry the
    # two required keys; reject it here instead of failing on data["..."] below.
    if not isinstance(data, dict) or "model_name" not in data or "model_type" not in data:
        return get_error_argument_result(message="model_name and model_type are required")

    model_name = data["model_name"]
    model_type = data["model_type"]
    # The service turns both values into sets. A bare string such as "chat" would
    # be split into single characters there, so insist on lists of strings.
    for field_name, value in (("model_name", model_name), ("model_type", model_type)):
        if not isinstance(value, list) or not all(isinstance(item, str) for item in value):
            return get_error_argument_result(message=f"{field_name} must be a list of strings")

    try:
        success, msg = provider_api_service.update_instance_models(tenant_id, provider_name, instance_name, model_name, model_type)
        if success:
            return get_result(message=msg)
        return get_error_data_result(message=msg)
    except Exception as e:
        # Top-level request boundary: log the details, return a generic error.
        logging.exception(e)
        return get_error_data_result(message="Internal server error")
@manager.route("/providers/<provider_name>/instances/<instance_name>/models", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs

View File

@@ -127,7 +127,7 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str):
model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name(
provider_obj.id, instance_obj.id, model_type, model_name
)
enable = model_entity is None or model_entity.status != ActiveStatusEnum.INACTIVE.value
enable = model_entity is None or model_entity.status == ActiveStatusEnum.ACTIVE.value
if not enable:
return None
@@ -214,7 +214,7 @@ def _check_model_available(tenant_id: str, provider_name: str, instance_name: st
provider_obj.id, instance_obj.id, model_type, model_name
)
if model_entity:
if model_entity.status == "inactive":
if model_entity.status != ActiveStatusEnum.ACTIVE.value:
return False, f"Model '{model_name}' isn't available"
return True, None
@@ -364,7 +364,8 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
manual_modified_models = model_record_map.get(model_record_key, [])
active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value]
inactive_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.INACTIVE.value]
model_types = list(set(factory_model_types + active_model_types) - set(inactive_model_types))
unsupport_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.UNSUPPORTED.value]
model_types = list(set(factory_model_types + active_model_types) - set(inactive_model_types) - set(unsupport_model_types))
if not model_types:
continue

View File

@@ -313,9 +313,6 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
api_key_str = ""
if api_key:
api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key_str)
if same_key_instance:
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
success, msg = await verify_api_key(provider_name, api_key, base_url, region, model_info)
if not success:
return False, msg
@@ -594,6 +591,57 @@ def drop_provider_instances(tenant_id: str, provider_name: str, instance_names:
return True, None
def _hybrid_get_instance_models(provider_name: str, instance_id: str):
    """Merge the provider's factory model list with the instance's tenant_model records.

    Factory models (from FACTORY_LLM_INFOS) come first, followed by models that
    exist only as tenant_model records. For every model, types recorded with the
    "unsupported" status are removed from its type list.

    Args:
        provider_name: Name looked up in FACTORY_LLM_INFOS.
        instance_id: Instance whose tenant_model records are merged in.

    Returns:
        (True, models) where each model is a dict with "name", "model_type"
        (a de-duplicated list in first-seen order), "max_tokens" and "status";
        or (False, message) when the provider is not in FACTORY_LLM_INFOS.
    """
    # List all models from the LLM dictionary for this provider
    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_name}' not found"

    # Stored statuses are plain strings, so compare against the enum's value;
    # comparing against the Enum member itself never matches.
    unsupported_status = ActiveStatusEnum.UNSUPPORTED.value

    # model_name -> {"status", "model_type", "extra"} built from non-unsupported records
    model_info_map: dict = {}
    # model_name -> model types explicitly marked unsupported
    model_unsupported_type_map: dict = {}
    for model_record in TenantModelService.get_models_by_instance_id(instance_id):
        if model_record.status == unsupported_status:
            model_unsupported_type_map.setdefault(model_record.model_name, []).append(model_record.model_type)
            continue
        # status/extra are taken from the first non-unsupported record of a model
        info = model_info_map.setdefault(
            model_record.model_name,
            {"status": model_record.status, "model_type": [], "extra": model_record.extra},
        )
        info["model_type"].append(model_record.model_type)

    def _supported_types(model_name, candidate_types):
        """De-duplicate candidate types in order and drop the unsupported ones."""
        blocked = set(model_unsupported_type_map.get(model_name, []))
        return [t for t in dict.fromkeys(candidate_types) if t not in blocked]

    models = []
    for llm in factory_info[0].get("llm", []):
        recorded = model_info_map.get(llm["llm_name"], {})
        models.append({
            "name": llm["llm_name"],
            "model_type": _supported_types(llm["llm_name"], _factory_model_types(llm) + recorded.get("model_type", [])),
            "max_tokens": llm.get("max_tokens"),
            "status": recorded.get("status", "active"),
        })

    # Append models that only exist as tenant_model records
    factory_models = {m["name"] for m in models}
    for model_name, model_info_dict in model_info_map.items():
        if model_name in factory_models:
            continue
        extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {}
        models.append({
            "name": model_name,
            # A list, matching the factory branch (a set is not JSON-serialisable).
            "model_type": _supported_types(model_name, model_info_dict["model_type"]),
            "max_tokens": extra_fields.get("max_tokens", 8192),
            "status": model_info_dict["status"],
        })
    return True, models
def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, supported_only: bool = False):
"""
List models for a provider instance.
@@ -628,47 +676,40 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
if not instance_obj:
return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"
# Get model records for this instance from tenant_model table
model_records = TenantModelService.get_models_by_instance_id(instance_obj.id)
# Build a map of model_name -> status, type
model_info_map: dict = {}
for model_record in model_records:
if model_info_map.get(model_record.model_name):
model_info_map[model_record.model_name]["model_type"].append(model_record.model_type)
else:
model_info_map[model_record.model_name] = {
"status": model_record.status,
"model_type": [model_record.model_type],
"extra": model_record.extra
return _hybrid_get_instance_models(provider_name, instance_obj.id)
# Reconcile the model types of every model in `model_names` on one provider
# instance with the requested `model_types`.
# Returns a (success, message) tuple; on failure the message explains why.
def update_instance_models(tenant_id: str, provider_name: str, instance_name: str, model_names: list, model_types: list):
if not model_names or not model_types:
return False, "model_name and model_type are required"
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
if not provider_obj:
return False, f"No provider found for provider '{provider_name}'"
instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name)
if not instance_obj:
return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"
# Current models of the instance; on failure `models` holds the error message.
found, models = _hybrid_get_instance_models(provider_name, instance_obj.id)
if not found:
return False, models
model_info_map = {model["name"]: model for model in models}
# Reject the whole request if any requested model is unknown to the instance.
not_exist_models = set(model_names) - set(model_info_map.keys())
if not_exist_models:
return False, f"Models {not_exist_models} not found for provider '{provider_name}' and instance '{instance_name}'"
# The same `model_types` list is applied to every model in `model_names`.
for model_name in model_names:
model_info = model_info_map.get(model_name, {})
TenantModelService.upsert_model_type(
provider_obj.id,
instance_obj.id,
model_name,
{
# "add": requested types the model does not currently have.
"add": list(set(model_types) - set(model_info["model_type"])),
# "delete": current types that were not requested.
"delete": list(set(model_info["model_type"]) - set(model_types))
}
# NOTE(review): the lines from here through `return True, models` repeat the
# logic of _hybrid_get_instance_models and sit inside the still-open
# upsert_model_type(...) call; presumably residue of the previous
# list_instance_models body. Confirm against the real file.
# List all models from the LLM dictionary for this provider
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
if not factory_info:
return False, f"Provider '{provider_name}' not found"
llms = factory_info[0].get("llm", [])
models = []
for llm in llms:
models.append({
"name": llm["llm_name"],
"model_type": list(
dict.fromkeys(_factory_model_types(llm) + model_info_map.get(llm["llm_name"], {}).get("model_type", []))
),
"max_tokens": llm.get("max_tokens"),
"status": model_info_map.get(llm["llm_name"], {}).get("status", "active"),
})
factory_models = [m["name"] for m in models]
for model_name, model_info_dict in model_info_map.items():
if model_name not in factory_models:
extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {}
models.append({
"name": model_name,
"model_type": model_info_dict["model_type"],
"max_tokens": extra_fields.get("max_tokens", 8192),
"status": model_info_dict["status"],
})
return True, models
)
return True, "success"
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None):
@@ -742,7 +783,7 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str,
if model_obj_list:
# Model record exists — update its status
TenantModelService.batch_update_model_status([m.id for m in model_obj_list], status)
TenantModelService.batch_update_model_status([m.id for m in model_obj_list if m.status != ActiveStatusEnum.UNSUPPORTED.value], status)
else:
# Model record does not exist
if status == ActiveStatusEnum.ACTIVE.value:

View File

@@ -214,6 +214,8 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
if model_obj:
if model_obj.status == ActiveStatusEnum.INACTIVE.value:
raise LookupError(f"Model {model_name} is disabled.")
if model_obj.status == ActiveStatusEnum.UNSUPPORTED.value:
raise LookupError(f"Model {model_name} cannot be used as {model_type_val} model.")
model_extra = json.loads(model_obj.extra) if model_obj.extra else {}
model_config = {
@@ -281,6 +283,7 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
if not instance_obj:
raise LookupError(f"Instance {instance_name} not found for model {model_name}.")
model_objs = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, pure_model_name)
types_in_json = []
if not model_objs:
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
region = extra_fields.get("region", "default")
@@ -294,8 +297,8 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
if not llm_list:
raise LookupError(f"Model {pure_model_name} not found for model {model_name}.")
return _factory_model_types(llm_list[0])
return [model_obj.model_type for model_obj in model_objs]
types_in_json = _factory_model_types(llm_list[0])
return list(set(types_in_json + [model_obj.model_type for model_obj in model_objs if model_obj.status != ActiveStatusEnum.UNSUPPORTED.value]) - {model_obj.model_type for model_obj in model_objs if model_obj.status == ActiveStatusEnum.UNSUPPORTED.value})
def delete_models_by_instance_ids(instance_ids: list[str]):
@@ -329,6 +332,7 @@ def get_models_by_tenant_and_provider_and_model_type(tenant_id: str, provider_na
results = []
for inst in instances:
models = TenantModelService.get_by_provider_id_and_instance_id_and_model_type(provider_obj.id, inst.id, model_type)
if models:
results.extend(models)
supported = [model for model in models if model.status != ActiveStatusEnum.UNSUPPORTED.value]
if supported:
results.extend(supported)
return results

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from common.constants import ActiveStatusEnum
from api.db.db_models import DB, TenantModel
from api.db.services.common_service import CommonService
@@ -59,6 +60,63 @@ class TenantModelService(CommonService):
def batch_update_model_status(cls, model_ids, status):
return cls.model.update(status=status).where(cls.model.id.in_(model_ids)).execute()
@classmethod
@DB.connection_context()
def upsert_model_type(cls, provider_id: str, instance_id: str, model_name: str, operation: dict):
    """Apply model-type additions and removals for one model of an instance.

    Args:
        provider_id: Provider the model rows belong to.
        instance_id: Instance the model rows belong to.
        model_name: Name of the model whose type rows are changed.
        operation: Dict with optional "add" and "delete" lists of model types.
            A type under "add" gets a new row, or its existing row is set back
            to the model's current status. A type under "delete" gets a new or
            updated row with the "unsupported" status.

    Returns:
        The number of types processed: len(add) + len(delete).
    """
    # Statuses are stored as plain strings; always use the enum *value* when
    # writing or comparing. The Enum member itself never equals the stored string.
    unsupported = ActiveStatusEnum.UNSUPPORTED.value
    types_to_add = operation.get("add", [])
    types_to_delete = operation.get("delete", [])
    identity = {"model_name": model_name, "provider_id": provider_id, "instance_id": instance_id}

    existing_records = list(
        cls.model.select().where(
            cls.model.provider_id == provider_id,
            cls.model.instance_id == instance_id,
            cls.model.model_name == model_name,
        )
    )

    if not existing_records:
        # First rows for this model.
        for _type in types_to_add:
            # No explicit status, as before: the column default applies.
            cls.insert(model_type=_type, extra="{}", **identity)
        for _type in types_to_delete:
            cls.insert(model_type=_type, status=unsupported, extra="{}", **identity)
        return len(types_to_add) + len(types_to_delete)

    # New or re-enabled types inherit status/extra from a row that is still in
    # use, never from an "unsupported" one.
    template = next((record for record in existing_records if record.status != unsupported), None)
    extra_fields = template.extra if template is not None else "{}"
    model_status = template.status if template is not None else ActiveStatusEnum.ACTIVE.value
    type_record_map = {record.model_type: record for record in existing_records}

    def _set_type_status(_type, status):
        """Update the existing row for `_type`, or insert one, with `status`."""
        record = type_record_map.get(_type)
        if record is not None:
            cls.update_by_id(record.id, {"status": status})
        else:
            cls.insert(model_type=_type, status=status, extra=extra_fields, **identity)

    for _type in types_to_add:
        _set_type_status(_type, model_status)
    for _type in types_to_delete:
        _set_type_status(_type, unsupported)
    return len(types_to_add) + len(types_to_delete)
@classmethod
@DB.connection_context()
def delete_by_id(cls, model_id):

View File

@@ -69,6 +69,7 @@ class StatusEnum(Enum):
# Status values for tenant model records. Other code in this change compares
# the stored status against the members' `.value` strings.
class ActiveStatusEnum(Enum):
# The model is available for use.
ACTIVE = "active"
# The model is disabled; lookups report "Model ... is disabled."
INACTIVE = "inactive"
# The model cannot be used as this model type; such types are removed from the
# model's type list.
UNSUPPORTED = "unsupported"
class ActiveEnum(Enum):