mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 16:25:44 +08:00
fix(api): fall back to factory max_tokens for tenant models (#16364)
This commit is contained in:
@@ -45,6 +45,19 @@ def _factory_model_types(llm: dict) -> list[str]:
|
||||
return [model_type] if model_type else []
|
||||
|
||||
|
||||
def _lookup_factory_llm_info(provider_name: str, pure_model_name: str, extra_fields: dict) -> dict | None:
|
||||
region = extra_fields.get("region", "default")
|
||||
if region == "intl" and provider_name.lower() == "siliconflow":
|
||||
target_factory_name = "siliconflow_intl"
|
||||
else:
|
||||
target_factory_name = provider_name
|
||||
fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == target_factory_name]
|
||||
if not fac_list:
|
||||
return None
|
||||
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
|
||||
return llm_list[0] if llm_list else None
|
||||
|
||||
|
||||
def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]:
|
||||
if not raw_api_key:
|
||||
return raw_api_key, None, None
|
||||
@@ -249,6 +262,11 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str | enum.En
|
||||
raise LookupError(f"Model {model_name} cannot be used as {model_type_val} model.")
|
||||
|
||||
model_extra = json.loads(model_obj.extra) if model_obj.extra else {}
|
||||
llm_info = _lookup_factory_llm_info(provider_obj.provider_name, pure_model_name, extra_fields)
|
||||
if "max_tokens" in model_extra:
|
||||
max_tokens = model_extra["max_tokens"]
|
||||
else:
|
||||
max_tokens = (llm_info or {}).get("max_tokens", 8192)
|
||||
model_config = {
|
||||
"llm_factory": provider_obj.provider_name,
|
||||
"api_key": api_key,
|
||||
@@ -256,7 +274,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str | enum.En
|
||||
"api_base": extra_fields.get("base_url", ""),
|
||||
"model_type": model_obj.model_type,
|
||||
"is_tools": model_extra.get("is_tools", is_tool),
|
||||
"max_tokens": model_extra.get("max_tokens") or 8192,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if provider_name.lower() == "somark":
|
||||
# SoMark/OCR factories read parser config (somark_*, parse_method, ...)
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from common.constants import ActiveStatusEnum
|
||||
from api.db.joint_services import tenant_model_service as tms
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_max_tokens_falls_back_to_factory_when_model_extra_empty(monkeypatch):
|
||||
provider = SimpleNamespace(id="provider-1", provider_name="OpenAI")
|
||||
instance = SimpleNamespace(id="instance-1", api_key="sk-test", extra='{"base_url": "https://api.example"}')
|
||||
model = SimpleNamespace(
|
||||
model_name="gpt-test",
|
||||
model_type="chat",
|
||||
status=ActiveStatusEnum.ACTIVE.value,
|
||||
extra="{}",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
tms.TenantModelProviderService,
|
||||
"get_by_tenant_id_and_provider_name",
|
||||
lambda tenant_id, provider_name: provider,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tms.TenantModelInstanceService,
|
||||
"get_by_provider_id_and_instance_name",
|
||||
lambda provider_id, instance_name: instance,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tms.TenantModelService,
|
||||
"get_by_provider_id_and_instance_id_and_model_type_and_model_name",
|
||||
lambda provider_id, instance_id, model_type, model_name: model,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tms.settings,
|
||||
"FACTORY_LLM_INFOS",
|
||||
[
|
||||
{
|
||||
"name": "OpenAI",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gpt-test",
|
||||
"model_type": "chat",
|
||||
"max_tokens": 128000,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
config = tms.get_model_config_from_provider_instance(
|
||||
"tenant-1", "chat", "gpt-test@default@OpenAI"
|
||||
)
|
||||
|
||||
assert config["max_tokens"] == 128000
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_max_tokens_prefers_model_extra_over_factory(monkeypatch):
|
||||
provider = SimpleNamespace(id="provider-1", provider_name="OpenAI")
|
||||
instance = SimpleNamespace(id="instance-1", api_key="sk-test", extra="{}")
|
||||
model = SimpleNamespace(
|
||||
model_name="gpt-test",
|
||||
model_type="chat",
|
||||
status=ActiveStatusEnum.ACTIVE.value,
|
||||
extra='{"max_tokens": 32000}',
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
tms.TenantModelProviderService,
|
||||
"get_by_tenant_id_and_provider_name",
|
||||
lambda tenant_id, provider_name: provider,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tms.TenantModelInstanceService,
|
||||
"get_by_provider_id_and_instance_name",
|
||||
lambda provider_id, instance_name: instance,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tms.TenantModelService,
|
||||
"get_by_provider_id_and_instance_id_and_model_type_and_model_name",
|
||||
lambda provider_id, instance_id, model_type, model_name: model,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tms.settings,
|
||||
"FACTORY_LLM_INFOS",
|
||||
[
|
||||
{
|
||||
"name": "OpenAI",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gpt-test",
|
||||
"model_type": "chat",
|
||||
"max_tokens": 128000,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
config = tms.get_model_config_from_provider_instance(
|
||||
"tenant-1", "chat", "gpt-test@default@OpenAI"
|
||||
)
|
||||
|
||||
assert config["max_tokens"] == 32000
|
||||
Reference in New Issue
Block a user