From d770217b254fca92bcbc32fb5542e0ca88df100a Mon Sep 17 00:00:00 2001 From: Harsh Kashyap Date: Wed, 1 Jul 2026 13:30:13 +0530 Subject: [PATCH] fix(api): fall back to factory max_tokens for tenant models (#16364) --- api/db/joint_services/tenant_model_service.py | 20 ++- .../test_tenant_model_service_max_tokens.py | 122 ++++++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index 4d1c29d80d..44b6a4aa9f 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -45,6 +45,19 @@ def _factory_model_types(llm: dict) -> list[str]: return [model_type] if model_type else [] +def _lookup_factory_llm_info(provider_name: str, pure_model_name: str, extra_fields: dict) -> dict | None: + region = extra_fields.get("region", "default") + if region == "intl" and provider_name.lower() == "siliconflow": + target_factory_name = "siliconflow_intl" + else: + target_factory_name = provider_name + fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == target_factory_name] + if not fac_list: + return None + llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name] + return llm_list[0] if llm_list else None + + def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]: if not raw_api_key: return raw_api_key, None, None @@ -249,6 +262,11 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str | enum.En raise LookupError(f"Model {model_name} cannot be used as {model_type_val} model.") model_extra = json.loads(model_obj.extra) if model_obj.extra else {} + llm_info = _lookup_factory_llm_info(provider_obj.provider_name, pure_model_name, extra_fields) + if "max_tokens" in model_extra: + max_tokens = model_extra["max_tokens"] + else: + max_tokens = (llm_info or {}).get("max_tokens", 8192) model_config = { "llm_factory": provider_obj.provider_name, "api_key": api_key, @@ -256,7 +274,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str | enum.En "api_base": extra_fields.get("base_url", ""), "model_type": model_obj.model_type, "is_tools": model_extra.get("is_tools", is_tool), - "max_tokens": model_extra.get("max_tokens") or 8192, + "max_tokens": max_tokens, } if provider_name.lower() == "somark": # SoMark/OCR factories read parser config (somark_*, parse_method, ...) diff --git a/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py b/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py new file mode 100644 index 0000000000..b52f625b71 --- /dev/null +++ b/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py @@ -0,0 +1,122 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from types import SimpleNamespace + +import pytest + +from common.constants import ActiveStatusEnum +from api.db.joint_services import tenant_model_service as tms + + +@pytest.mark.p1 +def test_max_tokens_falls_back_to_factory_when_model_extra_empty(monkeypatch): + provider = SimpleNamespace(id="provider-1", provider_name="OpenAI") + instance = SimpleNamespace(id="instance-1", api_key="sk-test", extra='{"base_url": "https://api.example"}') + model = SimpleNamespace( + model_name="gpt-test", + model_type="chat", + status=ActiveStatusEnum.ACTIVE.value, + extra="{}", + ) + + monkeypatch.setattr( + tms.TenantModelProviderService, + "get_by_tenant_id_and_provider_name", + lambda tenant_id, provider_name: provider, + ) + monkeypatch.setattr( + tms.TenantModelInstanceService, + "get_by_provider_id_and_instance_name", + lambda provider_id, instance_name: instance, + ) + monkeypatch.setattr( + tms.TenantModelService, + "get_by_provider_id_and_instance_id_and_model_type_and_model_name", + lambda provider_id, instance_id, model_type, model_name: model, + ) + monkeypatch.setattr( + tms.settings, + "FACTORY_LLM_INFOS", + [ + { + "name": "OpenAI", + "llm": [ + { + "llm_name": "gpt-test", + "model_type": "chat", + "max_tokens": 128000, + } + ], + } + ], + ) + + config = tms.get_model_config_from_provider_instance( + "tenant-1", "chat", "gpt-test@default@OpenAI" + ) + + assert config["max_tokens"] == 128000 + + +@pytest.mark.p1 +def test_max_tokens_prefers_model_extra_over_factory(monkeypatch): + provider = SimpleNamespace(id="provider-1", provider_name="OpenAI") + instance = SimpleNamespace(id="instance-1", api_key="sk-test", extra="{}") + model = SimpleNamespace( + model_name="gpt-test", + model_type="chat", + status=ActiveStatusEnum.ACTIVE.value, + extra='{"max_tokens": 32000}', + ) + + monkeypatch.setattr( + tms.TenantModelProviderService, + "get_by_tenant_id_and_provider_name", + lambda tenant_id, provider_name: provider, + ) + monkeypatch.setattr( + tms.TenantModelInstanceService, + "get_by_provider_id_and_instance_name", + lambda provider_id, instance_name: instance, + ) + monkeypatch.setattr( + tms.TenantModelService, + "get_by_provider_id_and_instance_id_and_model_type_and_model_name", + lambda provider_id, instance_id, model_type, model_name: model, + ) + monkeypatch.setattr( + tms.settings, + "FACTORY_LLM_INFOS", + [ + { + "name": "OpenAI", + "llm": [ + { + "llm_name": "gpt-test", + "model_type": "chat", + "max_tokens": 128000, + } + ], + } + ], + ) + + config = tms.get_model_config_from_provider_instance( + "tenant-1", "chat", "gpt-test@default@OpenAI" + ) + + assert config["max_tokens"] == 32000