mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Feat: model provider (#16028)
### What problem does this PR solve? Feat: - Allow upsert model_type for instance model Fix: - Allow create instance with duplicate api_key ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -630,6 +630,66 @@ def list_instance_models(tenant_id: str = None, provider_name: str = None, insta
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/providers/<provider_name>/instances/<instance_name>/models", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
async def update_instance_models(tenant_id: str, provider_name: str, instance_name: str):
|
||||
"""
|
||||
Batch update model_type for models in instance.
|
||||
---
|
||||
tags:
|
||||
- Providers
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: provider_name
|
||||
type: string
|
||||
required: true
|
||||
description: Provider name.
|
||||
- in: path
|
||||
name: instance_name
|
||||
type: string
|
||||
required: true
|
||||
description: Instance name.
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
description: Bearer token for authentication.
|
||||
- in: body
|
||||
name: body
|
||||
description: Model details.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- model_name
|
||||
- model_type
|
||||
properties:
|
||||
model_name:
|
||||
type: list of string
|
||||
description: Model name.
|
||||
model_type:
|
||||
type: list of string
|
||||
description: Model type.
|
||||
"""
|
||||
data = await request.get_json()
|
||||
if not data or "model_name" not in data or "model_type" not in data:
|
||||
return get_error_argument_result(message="model_name and model_type are required")
|
||||
model_name = data["model_name"]
|
||||
model_type = data["model_type"]
|
||||
try:
|
||||
success, msg = provider_api_service.update_instance_models(tenant_id, provider_name, instance_name, model_name, model_type)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
return get_error_data_result(message=msg)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/providers/<provider_name>/instances/<instance_name>/models", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
@@ -127,7 +127,7 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str):
|
||||
model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name(
|
||||
provider_obj.id, instance_obj.id, model_type, model_name
|
||||
)
|
||||
enable = model_entity is None or model_entity.status != ActiveStatusEnum.INACTIVE.value
|
||||
enable = model_entity is None or model_entity.status == ActiveStatusEnum.ACTIVE.value
|
||||
|
||||
if not enable:
|
||||
return None
|
||||
@@ -214,7 +214,7 @@ def _check_model_available(tenant_id: str, provider_name: str, instance_name: st
|
||||
provider_obj.id, instance_obj.id, model_type, model_name
|
||||
)
|
||||
if model_entity:
|
||||
if model_entity.status == "inactive":
|
||||
if model_entity.status != ActiveStatusEnum.ACTIVE.value:
|
||||
return False, f"Model '{model_name}' isn't available"
|
||||
return True, None
|
||||
|
||||
@@ -364,7 +364,8 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
manual_modified_models = model_record_map.get(model_record_key, [])
|
||||
active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value]
|
||||
inactive_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.INACTIVE.value]
|
||||
model_types = list(set(factory_model_types + active_model_types) - set(inactive_model_types))
|
||||
unsupport_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.UNSUPPORTED.value]
|
||||
model_types = list(set(factory_model_types + active_model_types) - set(inactive_model_types) - set(unsupport_model_types))
|
||||
if not model_types:
|
||||
continue
|
||||
|
||||
|
||||
@@ -313,9 +313,6 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
api_key_str = ""
|
||||
if api_key:
|
||||
api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
|
||||
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key_str)
|
||||
if same_key_instance:
|
||||
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
|
||||
success, msg = await verify_api_key(provider_name, api_key, base_url, region, model_info)
|
||||
if not success:
|
||||
return False, msg
|
||||
@@ -594,6 +591,57 @@ def drop_provider_instances(tenant_id: str, provider_name: str, instance_names:
|
||||
return True, None
|
||||
|
||||
|
||||
def _hybrid_get_instance_models(provider_name: str, instance_id: str):
|
||||
# List all models from the LLM dictionary for this provider
|
||||
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
|
||||
if not factory_info:
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
|
||||
# Get model records for this instance from tenant_model table
|
||||
model_records = TenantModelService.get_models_by_instance_id(instance_id)
|
||||
# Build a map of model_name -> status, type
|
||||
model_info_map: dict = {}
|
||||
model_unsupported_type_map = {}
|
||||
for model_record in model_records:
|
||||
if model_record.status == ActiveStatusEnum.UNSUPPORTED:
|
||||
if model_unsupported_type_map.get(model_record.model_name):
|
||||
model_unsupported_type_map[model_record.model_name].append(model_record.model_type)
|
||||
else:
|
||||
model_unsupported_type_map[model_record.model_name] = [model_record.model_type]
|
||||
continue
|
||||
if model_info_map.get(model_record.model_name):
|
||||
model_info_map[model_record.model_name]["model_type"].append(model_record.model_type)
|
||||
else:
|
||||
model_info_map[model_record.model_name] = {
|
||||
"status": model_record.status,
|
||||
"model_type": [model_record.model_type],
|
||||
"extra": model_record.extra
|
||||
}
|
||||
|
||||
llms = factory_info[0].get("llm", [])
|
||||
models = []
|
||||
for llm in llms:
|
||||
models.append({
|
||||
"name": llm["llm_name"],
|
||||
"model_type": list(
|
||||
set(_factory_model_types(llm) + model_info_map.get(llm["llm_name"], {}).get("model_type", [])) - set(model_unsupported_type_map.get(llm["llm_name"], []))
|
||||
),
|
||||
"max_tokens": llm.get("max_tokens"),
|
||||
"status": model_info_map.get(llm["llm_name"], {}).get("status", "active"),
|
||||
})
|
||||
factory_models = [m["name"] for m in models]
|
||||
for model_name, model_info_dict in model_info_map.items():
|
||||
if model_name not in factory_models:
|
||||
extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {}
|
||||
models.append({
|
||||
"name": model_name,
|
||||
"model_type": set(model_info_dict["model_type"]) - set(model_unsupported_type_map.get(model_name, [])),
|
||||
"max_tokens": extra_fields.get("max_tokens", 8192),
|
||||
"status": model_info_dict["status"],
|
||||
})
|
||||
return True, models
|
||||
|
||||
|
||||
def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, supported_only: bool = False):
|
||||
"""
|
||||
List models for a provider instance.
|
||||
@@ -628,47 +676,40 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
|
||||
if not instance_obj:
|
||||
return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"
|
||||
|
||||
# Get model records for this instance from tenant_model table
|
||||
model_records = TenantModelService.get_models_by_instance_id(instance_obj.id)
|
||||
# Build a map of model_name -> status, type
|
||||
model_info_map: dict = {}
|
||||
for model_record in model_records:
|
||||
if model_info_map.get(model_record.model_name):
|
||||
model_info_map[model_record.model_name]["model_type"].append(model_record.model_type)
|
||||
else:
|
||||
model_info_map[model_record.model_name] = {
|
||||
"status": model_record.status,
|
||||
"model_type": [model_record.model_type],
|
||||
"extra": model_record.extra
|
||||
return _hybrid_get_instance_models(provider_name, instance_obj.id)
|
||||
|
||||
|
||||
def update_instance_models(tenant_id: str, provider_name: str, instance_name: str, model_names: list, model_types: list):
|
||||
if not model_names or not model_types:
|
||||
return False, "model_name and model_type are required"
|
||||
|
||||
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
|
||||
if not provider_obj:
|
||||
return False, f"No provider found for provider '{provider_name}'"
|
||||
instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name)
|
||||
if not instance_obj:
|
||||
return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"
|
||||
|
||||
found, models = _hybrid_get_instance_models(provider_name, instance_obj.id)
|
||||
if not found:
|
||||
return False, models
|
||||
|
||||
model_info_map = {model["name"]: model for model in models}
|
||||
not_exist_models = set(model_names) - set(model_info_map.keys())
|
||||
if not_exist_models:
|
||||
return False, f"Models {not_exist_models} not found for provider '{provider_name}' and instance '{instance_name}'"
|
||||
for model_name in model_names:
|
||||
model_info = model_info_map.get(model_name, {})
|
||||
TenantModelService.upsert_model_type(
|
||||
provider_obj.id,
|
||||
instance_obj.id,
|
||||
model_name,
|
||||
{
|
||||
"add": list(set(model_types) - set(model_info["model_type"])),
|
||||
"delete": list(set(model_info["model_type"]) - set(model_types))
|
||||
}
|
||||
|
||||
# List all models from the LLM dictionary for this provider
|
||||
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
|
||||
if not factory_info:
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
|
||||
llms = factory_info[0].get("llm", [])
|
||||
models = []
|
||||
for llm in llms:
|
||||
models.append({
|
||||
"name": llm["llm_name"],
|
||||
"model_type": list(
|
||||
dict.fromkeys(_factory_model_types(llm) + model_info_map.get(llm["llm_name"], {}).get("model_type", []))
|
||||
),
|
||||
"max_tokens": llm.get("max_tokens"),
|
||||
"status": model_info_map.get(llm["llm_name"], {}).get("status", "active"),
|
||||
})
|
||||
factory_models = [m["name"] for m in models]
|
||||
for model_name, model_info_dict in model_info_map.items():
|
||||
if model_name not in factory_models:
|
||||
extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {}
|
||||
models.append({
|
||||
"name": model_name,
|
||||
"model_type": model_info_dict["model_type"],
|
||||
"max_tokens": extra_fields.get("max_tokens", 8192),
|
||||
"status": model_info_dict["status"],
|
||||
})
|
||||
return True, models
|
||||
)
|
||||
return True, "success"
|
||||
|
||||
|
||||
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None):
|
||||
@@ -742,7 +783,7 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str,
|
||||
|
||||
if model_obj_list:
|
||||
# Model record exists — update its status
|
||||
TenantModelService.batch_update_model_status([m.id for m in model_obj_list], status)
|
||||
TenantModelService.batch_update_model_status([m.id for m in model_obj_list if m.status != ActiveStatusEnum.UNSUPPORTED.value], status)
|
||||
else:
|
||||
# Model record does not exist
|
||||
if status == ActiveStatusEnum.ACTIVE.value:
|
||||
|
||||
@@ -214,6 +214,8 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
|
||||
if model_obj:
|
||||
if model_obj.status == ActiveStatusEnum.INACTIVE.value:
|
||||
raise LookupError(f"Model {model_name} is disabled.")
|
||||
if model_obj.status == ActiveStatusEnum.UNSUPPORTED.value:
|
||||
raise LookupError(f"Model {model_name} cannot be used as {model_type_val} model.")
|
||||
|
||||
model_extra = json.loads(model_obj.extra) if model_obj.extra else {}
|
||||
model_config = {
|
||||
@@ -281,6 +283,7 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
|
||||
if not instance_obj:
|
||||
raise LookupError(f"Instance {instance_name} not found for model {model_name}.")
|
||||
model_objs = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, pure_model_name)
|
||||
types_in_json = []
|
||||
if not model_objs:
|
||||
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
|
||||
region = extra_fields.get("region", "default")
|
||||
@@ -294,8 +297,8 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
|
||||
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
|
||||
if not llm_list:
|
||||
raise LookupError(f"Model {pure_model_name} not found for model {model_name}.")
|
||||
return _factory_model_types(llm_list[0])
|
||||
return [model_obj.model_type for model_obj in model_objs]
|
||||
types_in_json = _factory_model_types(llm_list[0])
|
||||
return list(set(types_in_json + [model_obj.model_type for model_obj in model_objs if model_obj.status != ActiveStatusEnum.UNSUPPORTED.value]) - {model_obj.model_type for model_obj in model_objs if model_obj.status == ActiveStatusEnum.UNSUPPORTED.value})
|
||||
|
||||
|
||||
def delete_models_by_instance_ids(instance_ids: list[str]):
|
||||
@@ -329,6 +332,7 @@ def get_models_by_tenant_and_provider_and_model_type(tenant_id: str, provider_na
|
||||
results = []
|
||||
for inst in instances:
|
||||
models = TenantModelService.get_by_provider_id_and_instance_id_and_model_type(provider_obj.id, inst.id, model_type)
|
||||
if models:
|
||||
results.extend(models)
|
||||
supported = [model for model in models if model.status != ActiveStatusEnum.UNSUPPORTED.value]
|
||||
if supported:
|
||||
results.extend(supported)
|
||||
return results
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from common.constants import ActiveStatusEnum
|
||||
from api.db.db_models import DB, TenantModel
|
||||
from api.db.services.common_service import CommonService
|
||||
|
||||
@@ -59,6 +60,63 @@ class TenantModelService(CommonService):
|
||||
def batch_update_model_status(cls, model_ids, status):
|
||||
return cls.model.update(status=status).where(cls.model.id.in_(model_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def upsert_model_type(cls, provider_id: str, instance_id: str, model_name: str, operation: dict):
|
||||
model_type_records = cls.model.select().where(cls.model.provider_id == provider_id, cls.model.instance_id == instance_id, cls.model.model_name == model_name)
|
||||
if not model_type_records:
|
||||
for _type in operation.get("add", []):
|
||||
cls.insert(
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
instance_id=instance_id,
|
||||
model_type=_type,
|
||||
extra="{}"
|
||||
)
|
||||
for _type in operation.get("delete", []):
|
||||
cls.insert(
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
instance_id=instance_id,
|
||||
model_type=_type,
|
||||
status=ActiveStatusEnum.UNSUPPORTED,
|
||||
extra="{}"
|
||||
)
|
||||
return len(operation.get("add", [])) + len(operation.get("delete", []))
|
||||
model_record_example = [model_record for model_record in model_type_records if model_record.status != ActiveStatusEnum.UNSUPPORTED]
|
||||
extra_fields = model_record_example[0].extra if model_record_example else "{}"
|
||||
model_status = model_record_example[0].status if model_record_example else ActiveStatusEnum.ACTIVE.value
|
||||
type_record_map = {record.model_type: record for record in model_type_records}
|
||||
operated_cnt = 0
|
||||
for _type in operation.get("add", []):
|
||||
if type_record_map.get(_type):
|
||||
cls.update_by_id(type_record_map[_type].id, {"status": model_status})
|
||||
|
||||
else:
|
||||
cls.insert(
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
instance_id=instance_id,
|
||||
model_type=_type,
|
||||
status=model_status,
|
||||
extra=extra_fields
|
||||
)
|
||||
operated_cnt += 1
|
||||
for _type in operation.get("delete", []):
|
||||
if type_record_map.get(_type):
|
||||
cls.update_by_id(type_record_map[_type].id, {"status": ActiveStatusEnum.UNSUPPORTED.value})
|
||||
else:
|
||||
cls.insert(
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
instance_id=instance_id,
|
||||
model_type=_type,
|
||||
status=ActiveStatusEnum.UNSUPPORTED.value,
|
||||
extra=extra_fields
|
||||
)
|
||||
operated_cnt += 1
|
||||
return operated_cnt
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_id(cls, model_id):
|
||||
|
||||
@@ -69,6 +69,7 @@ class StatusEnum(Enum):
|
||||
class ActiveStatusEnum(Enum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
UNSUPPORTED = "unsupported"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
|
||||
|
||||
Reference in New Issue
Block a user