mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: model-provider bugs (#15460)
### What problem does this PR solve? Fix: - Use @ to avoid split by `_` in model_name. - Verify api_key when add instance. - Pop api_key in list intances response. - Remove useless index. - Sort providers, instances and models by name. - Get `is_tools` from llm_factories.json ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -335,9 +335,69 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
|
||||
api_key = data["api_key"]
|
||||
base_url = data.get("base_url", "")
|
||||
region = data.get("region", "")
|
||||
verify = data.get("verify", False)
|
||||
|
||||
try:
|
||||
success, msg = provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region)
|
||||
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, verify)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
return get_error_data_result(message=msg)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/providers/<provider_name>/connection", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def verify_provider_api_key(provider_name: str = None):
|
||||
"""
|
||||
Verify api key.
|
||||
---
|
||||
tags:
|
||||
- Providers
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: provider_name
|
||||
type: string
|
||||
required: true
|
||||
description: Provider name.
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
description: Bearer token for authentication.
|
||||
- in: body
|
||||
name: body
|
||||
description: Instance creation parameters.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- api_key
|
||||
properties:
|
||||
api_key:
|
||||
type: string
|
||||
description: API key.
|
||||
base_url:
|
||||
type: string
|
||||
description: Base URL.
|
||||
responses:
|
||||
200:
|
||||
description: Instance created successfully.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
data = await request.get_json()
|
||||
if not data or "api_key" not in data:
|
||||
return get_error_argument_result(message="api_key is required")
|
||||
|
||||
base_url = data.get("base_url", "")
|
||||
api_key = data["api_key"]
|
||||
try:
|
||||
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
|
||||
@@ -298,7 +298,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
target_type_records = [record for record in model_records if record.model_type == model_type_filter] if model_type_filter else model_records
|
||||
model_record_map = {}
|
||||
for model in target_type_records:
|
||||
instance_model_key = f"{model.provider_id}_{model.instance_id}_{model.model_name}"
|
||||
instance_model_key = f"{model.provider_id}@{model.instance_id}@{model.model_name}"
|
||||
if model_record_map.get(instance_model_key):
|
||||
model_record_map[instance_model_key].append(model)
|
||||
else:
|
||||
@@ -318,7 +318,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
continue
|
||||
|
||||
for factory_instance in factory_instances:
|
||||
model_record_key = f"{factory_instance.provider_id}_{factory_instance.id}_{llm['llm_name']}"
|
||||
model_record_key = f"{factory_instance.provider_id}@{factory_instance.id}@{llm['llm_name']}"
|
||||
model_key_in_factory.append(model_record_key)
|
||||
manual_modified_models = model_record_map.get(model_record_key, [])
|
||||
active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value]
|
||||
@@ -343,7 +343,7 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
model_records = model_record_map.get(model_record_key, [])
|
||||
if not model_records:
|
||||
continue
|
||||
provider_id, instance_id, model_name = model_record_key.split("_")
|
||||
provider_id, instance_id, model_name = model_record_key.split("@")
|
||||
model_types = [model.model_type for model in model_records if model.status == ActiveStatusEnum.ACTIVE.value]
|
||||
if not model_types:
|
||||
continue
|
||||
@@ -357,4 +357,6 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None):
|
||||
"instance_name": instance_info_map[instance_id].instance_name if instance_info_map.get(instance_id) else ""
|
||||
})
|
||||
|
||||
added_models.sort(key=lambda x: (x["provider_name"], x["instance_name"], x["name"]))
|
||||
|
||||
return True, added_models
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from common.constants import LLMType, ActiveStatusEnum
|
||||
from common.misc_utils import get_uuid
|
||||
@@ -23,6 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro
|
||||
from api.db.services.tenant_model_provider_service import TenantModelProviderService
|
||||
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
|
||||
from api.db.services.tenant_model_service import TenantModelService
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
||||
|
||||
|
||||
def list_providers(tenant_id: str, all_available: bool = False):
|
||||
@@ -54,6 +57,7 @@ def list_providers(tenant_id: str, all_available: bool = False):
|
||||
"default": factory_info.get("url", "")
|
||||
}
|
||||
})
|
||||
providers.sort(key=lambda x: x["name"])
|
||||
return True, providers
|
||||
|
||||
# List tenant-configured providers
|
||||
@@ -76,7 +80,7 @@ def list_providers(tenant_id: str, all_available: bool = False):
|
||||
"default": factory_info.get("url", "")
|
||||
}
|
||||
})
|
||||
|
||||
providers.sort(key=lambda x: x["name"])
|
||||
return True, providers
|
||||
|
||||
|
||||
@@ -117,7 +121,7 @@ def delete_provider(tenant_id: str, provider_name: str):
|
||||
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
|
||||
if not provider_obj:
|
||||
return False, f"Provider {provider_name} not found"
|
||||
instance_objs = TenantModelInstanceService.get_by_provider_id(provider_obj.id)
|
||||
instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_obj.id)
|
||||
if not instance_objs:
|
||||
return False, f"No instances found for provider {provider_name}"
|
||||
instance_ids = [instance_obj.id for instance_obj in instance_objs]
|
||||
@@ -169,6 +173,7 @@ def list_provider_models(provider_name: str):
|
||||
"model_types": [llm["model_type"]],
|
||||
"features": None
|
||||
})
|
||||
models.sort(key=lambda x: x["name"])
|
||||
return True, models
|
||||
|
||||
|
||||
@@ -202,7 +207,7 @@ def show_provider_model(provider_name: str, model_name: str):
|
||||
}
|
||||
|
||||
|
||||
def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str):
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, verify: bool):
|
||||
"""
|
||||
Create a provider instance.
|
||||
|
||||
@@ -215,6 +220,7 @@ def create_provider_instance(tenant_id: str, provider_name: str, instance_name:
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:param verify: verify
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
@@ -236,6 +242,10 @@ def create_provider_instance(tenant_id: str, provider_name: str, instance_name:
|
||||
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key)
|
||||
if same_key_instance:
|
||||
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
|
||||
if verify:
|
||||
success, msg = await verify_api_key(provider_name, api_key, base_url)
|
||||
if not success:
|
||||
return False, msg
|
||||
|
||||
import json
|
||||
extra_fields = {}
|
||||
@@ -268,7 +278,6 @@ def list_provider_instances(tenant_id: str, provider_name: str):
|
||||
for instance_obj in instance_objs:
|
||||
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
|
||||
instances.append({
|
||||
"api_key": instance_obj.api_key,
|
||||
"id": instance_obj.id,
|
||||
"instance_name": instance_obj.instance_name,
|
||||
"provider_id": provider_id,
|
||||
@@ -276,7 +285,93 @@ def list_provider_instances(tenant_id: str, provider_name: str):
|
||||
"status": instance_obj.status,
|
||||
})
|
||||
|
||||
return True, instances
|
||||
active_instances = [instance for instance in instances if instance["status"] == ActiveStatusEnum.ACTIVE.value]
|
||||
inactive_instances = [instance for instance in instances if instance["status"] == ActiveStatusEnum.INACTIVE.value]
|
||||
active_instances.sort(key=lambda x: x["instance_name"])
|
||||
inactive_instances.sort(key=lambda x: x["instance_name"])
|
||||
|
||||
return True, active_instances + inactive_instances
|
||||
|
||||
|
||||
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None):
|
||||
"""
|
||||
Verify API key for a provider.
|
||||
|
||||
:param provider_name: provider/factory name
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
return False, "Provider name is required"
|
||||
|
||||
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
|
||||
if not factory_info:
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
|
||||
factory_llms = factory_info[0]["llm"]
|
||||
if not factory_llms:
|
||||
return False, f"No models found for provider '{provider_name}'"
|
||||
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
|
||||
extra = {"provider": provider_name}
|
||||
msg = ""
|
||||
for llm in factory_llms:
|
||||
if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
assert provider_name in EmbeddingModel, f"Embedding model from {provider_name} is not supported yet."
|
||||
mdl = EmbeddingModel[provider_name](api_key, llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
arr, tc = asyncio.wait_for(
|
||||
asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if len(arr[0]) == 0:
|
||||
raise Exception("Fail")
|
||||
embd_passed = True
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({llm['llm_name']}) using this api key." + str(e)
|
||||
elif not chat_passed and llm["model_type"] == LLMType.CHAT.value:
|
||||
assert provider_name in ChatModel, f"Chat model from {provider_name} is not supported yet."
|
||||
mdl = ChatModel[provider_name](api_key, llm["llm_name"], base_url=base_url, **extra)
|
||||
try:
|
||||
async def check_streamly():
|
||||
async for chunk in mdl.async_chat_streamly(
|
||||
None,
|
||||
[{"role": "user", "content": "Hi"}],
|
||||
{"temperature": 0.9},
|
||||
):
|
||||
if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
|
||||
if result:
|
||||
chat_passed = True
|
||||
else:
|
||||
raise Exception("No valid response received")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
|
||||
elif not rerank_passed and llm["model_type"] == LLMType.RERANK.value:
|
||||
assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
|
||||
mdl = RerankModel[provider_name](api_key, llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
arr, tc = await asyncio.wait_for(
|
||||
asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if len(arr) == 0 or tc == 0:
|
||||
raise Exception("Fail")
|
||||
rerank_passed = True
|
||||
logging.debug(f"passed model rerank {llm['llm_name']}")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
|
||||
if any([embd_passed, chat_passed, rerank_passed]):
|
||||
msg = ""
|
||||
break
|
||||
|
||||
return any([embd_passed, chat_passed, rerank_passed]), msg or "success"
|
||||
|
||||
|
||||
def show_provider_instance(tenant_id: str, provider_name: str, instance_name: str):
|
||||
@@ -362,6 +457,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
llms = factory_info[0].get("llm", [])
|
||||
models = [{"name": llm["llm_name"]} for llm in llms]
|
||||
models.sort(key=lambda x: x["name"])
|
||||
return True, models
|
||||
|
||||
# Get instance
|
||||
@@ -407,8 +503,11 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
|
||||
"max_tokens": extra_fields.get("max_tokens", 8192),
|
||||
"status": model_info_dict["status"],
|
||||
})
|
||||
|
||||
return True, models
|
||||
active_models = [model for model in models if model["status"] == ActiveStatusEnum.ACTIVE.value]
|
||||
inactive_models = [model for model in models if model["status"] == ActiveStatusEnum.INACTIVE.value]
|
||||
active_models.sort(key=lambda x: x["name"])
|
||||
inactive_models.sort(key=lambda x: x["name"])
|
||||
return True, active_models + inactive_models
|
||||
|
||||
|
||||
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict):
|
||||
@@ -434,7 +533,7 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str
|
||||
extra_fields = {"max_tokens": max_tokens}
|
||||
target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
|
||||
if target_model:
|
||||
extra_fields.update({"is_tool": target_model[0].get("is_tool", False)})
|
||||
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
|
||||
extra_fields.update(extra)
|
||||
TenantModelService.insert(
|
||||
model_name=model_name,
|
||||
@@ -506,6 +605,7 @@ def update_model_status(tenant_id: str, provider_name: str, instance_name: str,
|
||||
provider_id=provider_obj.id,
|
||||
instance_id=instance_obj.id,
|
||||
status=status,
|
||||
extra=json.dumps({"max_tokens": target_llm[0].get("max_tokens", 8192), "is_tools": target_llm[0].get("is_tools", False)})
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@@ -1372,9 +1372,6 @@ class TenantModelInstance(DataBaseModel):
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant_model_instance"
|
||||
indexes = (
|
||||
(("api_key", "provider_id"), True),
|
||||
)
|
||||
|
||||
|
||||
class TenantModel(DataBaseModel):
|
||||
@@ -1731,6 +1728,17 @@ def migrate_db():
|
||||
alter_db_column_type(migrator, "document", "size", BigIntegerField(default=0, index=True))
|
||||
alter_db_column_type(migrator, "file", "size", BigIntegerField(default=0, index=True))
|
||||
alter_db_add_column(migrator, "tenant", "ocr_id", CharField(max_length=128, null=True, help_text="default ocr model ID", index=True))
|
||||
for table_name, index_name in [("tenant_model_instance", "idx_api_key_provider_id"), ("tenant_model", "idx_provider_model_instance")]:
|
||||
try:
|
||||
migrate(migrator.drop_index(table_name, index_name))
|
||||
except (OperationalError, ProgrammingError) as ex:
|
||||
msg = str(ex)
|
||||
if "1091" in msg or "can't DROP" in msg.lower() or "does not exist" in msg.lower() or "already exists" in msg.lower():
|
||||
pass
|
||||
else:
|
||||
logging.critical(f"Failed to drop index {index_name} on {table_name}: {ex}")
|
||||
except Exception as ex:
|
||||
logging.critical(f"Failed to drop index {index_name} on {table_name}: {ex}")
|
||||
logging.disable(logging.NOTSET)
|
||||
# this is after re-enabling logging to allow logging changed user emails
|
||||
migrate_add_unique_email(migrator)
|
||||
|
||||
@@ -116,7 +116,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
|
||||
"llm_name": model_obj.model_name,
|
||||
"api_base": extra_fields.get("base_url", ""),
|
||||
"model_type": model_obj.model_type,
|
||||
"is_tool": extra_fields.get("is_tool", is_tool)
|
||||
"is_tools": extra_fields.get("is_tools", is_tool)
|
||||
}
|
||||
if api_key_payload is not None:
|
||||
model_config["api_key_payload"] = api_key_payload
|
||||
@@ -136,7 +136,7 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
|
||||
"llm_name": llm_info["llm_name"],
|
||||
"api_base": extra_fields.get("base_url", ""),
|
||||
"model_type": llm_info["model_type"],
|
||||
"is_tool": llm_info.get("is_tool", is_tool)
|
||||
"is_tools": llm_info.get("is_tools", is_tool)
|
||||
}
|
||||
if api_key_payload is not None:
|
||||
model_config["api_key_payload"] = api_key_payload
|
||||
|
||||
@@ -20,8 +20,8 @@ package entity
|
||||
type TenantModelInstance struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
InstanceName string `gorm:"column:instance_name;size:128;not null" json:"instance_name"`
|
||||
ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"`
|
||||
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"`
|
||||
ProviderID string `gorm:"column:provider_id;size:32;not null" json:"provider_id"`
|
||||
APIKey string `gorm:"column:api_key;size:512;not null" json:"api_key"`
|
||||
Status string `gorm:"column:status;size:32;default:'active'" json:"status"`
|
||||
Extra string `gorm:"column:extra;size:512;default:'{}'" json:"extra"`
|
||||
BaseModel
|
||||
|
||||
@@ -539,7 +539,6 @@ class TenantModelInstanceStage(MigrationStage):
|
||||
create_date DATETIME,
|
||||
update_time BIGINT,
|
||||
update_date DATETIME,
|
||||
UNIQUE INDEX idx_api_key_provider_id (api_key, provider_id),
|
||||
INDEX idx_provider_id (provider_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
"""
|
||||
@@ -551,10 +550,34 @@ class TenantModelStage(MigrationStage):
|
||||
"""Migrate tenant_llm to tenant_model"""
|
||||
|
||||
name = "tenant_model"
|
||||
description = "Migrate tenant_llm to tenant_model (only status='0' records)"
|
||||
description = "Migrate tenant_llm to tenant_model (status='0' records, plus status='1' for empty-llm factories)"
|
||||
source_tables = ["tenant_llm", "tenant_model_provider", "tenant_model_instance"]
|
||||
target_tables = ["tenant_model"]
|
||||
|
||||
@staticmethod
|
||||
def _get_empty_llm_factories() -> list[str]:
|
||||
"""Load factory names whose llm field is an empty list from conf/llm_factories.json"""
|
||||
conf_path = os.path.join(PROJECT_BASE, "conf", "llm_factories.json")
|
||||
with open(conf_path, "r") as f:
|
||||
data = json.load(f)
|
||||
factories = []
|
||||
for key, items in data.items():
|
||||
if isinstance(items, list):
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
llm = item.get("llm")
|
||||
if isinstance(llm, list) and len(llm) == 0:
|
||||
factories.append(item["name"])
|
||||
return factories
|
||||
|
||||
def _build_status_condition(self) -> str:
|
||||
"""Build SQL WHERE condition for status filtering"""
|
||||
empty_factories = self._get_empty_llm_factories()
|
||||
if empty_factories:
|
||||
placeholders = ", ".join(f"'{f}'" for f in empty_factories)
|
||||
return f"(tl.status = '0' OR (tl.status = '1' AND tl.llm_factory IN ({placeholders})))"
|
||||
return "tl.status = '0'"
|
||||
|
||||
def current_timestamp(self) -> int:
|
||||
return int(time.time())
|
||||
|
||||
@@ -598,24 +621,26 @@ class TenantModelStage(MigrationStage):
|
||||
logger.info("Target table 'tenant_model' does not exist, will create")
|
||||
return True
|
||||
|
||||
# Check if there's data to migrate (only status='0' records)
|
||||
status_condition = self._build_status_condition()
|
||||
|
||||
# Check if there's data to migrate
|
||||
cursor = self.db.execute_sql(
|
||||
"SELECT COUNT(*) FROM ("
|
||||
" SELECT tl.id "
|
||||
" FROM tenant_llm tl "
|
||||
" INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
|
||||
" INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
|
||||
" WHERE tl.status = '0' "
|
||||
" AND NOT EXISTS ("
|
||||
" SELECT 1 FROM tenant_model tm "
|
||||
" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
|
||||
" )"
|
||||
") AS distinct_records"
|
||||
f"SELECT COUNT(*) FROM ("
|
||||
f" SELECT tl.id "
|
||||
f" FROM tenant_llm tl "
|
||||
f" INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
|
||||
f" INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
|
||||
f" WHERE {status_condition} "
|
||||
f" AND NOT EXISTS ("
|
||||
f" SELECT 1 FROM tenant_model tm "
|
||||
f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
|
||||
f" )"
|
||||
f") AS distinct_records"
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
if count == 0:
|
||||
logger.info("No new data to migrate from tenant_llm to tenant_model (status='0' only)")
|
||||
logger.info("No new data to migrate from tenant_llm to tenant_model")
|
||||
return False
|
||||
|
||||
logger.info(f"Found {count} rows to migrate from tenant_llm to tenant_model")
|
||||
@@ -652,19 +677,21 @@ class TenantModelStage(MigrationStage):
|
||||
logger.info("[CREATE TABLE ONLY] Target table created/verified, skipping data migration")
|
||||
return 0, self.target_tables
|
||||
|
||||
status_condition = self._build_status_condition()
|
||||
|
||||
# Get records from tenant_llm with provider_id and instance_id lookup
|
||||
# Only migrate records where status='0'
|
||||
# Migrate status='0' records, plus status='1' for empty-llm factories
|
||||
cursor = self.db.execute_sql(
|
||||
"SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, "
|
||||
" tl.model_type, tl.status "
|
||||
"FROM tenant_llm tl "
|
||||
"INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
|
||||
"INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
|
||||
"WHERE tl.status = '0' "
|
||||
"AND NOT EXISTS ("
|
||||
" SELECT 1 FROM tenant_model tm "
|
||||
" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
|
||||
")"
|
||||
f"SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, "
|
||||
f" tl.model_type, tl.status "
|
||||
f"FROM tenant_llm tl "
|
||||
f"INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory "
|
||||
f"INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key "
|
||||
f"WHERE {status_condition} "
|
||||
f"AND NOT EXISTS ("
|
||||
f" SELECT 1 FROM tenant_model tm "
|
||||
f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id"
|
||||
f")"
|
||||
)
|
||||
|
||||
records = cursor.fetchall()
|
||||
@@ -727,7 +754,6 @@ class TenantModelStage(MigrationStage):
|
||||
update_time BIGINT,
|
||||
update_date DATETIME,
|
||||
INDEX idx_instance_id (instance_id),
|
||||
UNIQUE INDEX idx_provider_model_instance (provider_id, model_name, instance_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
"""
|
||||
self.db.execute_sql(create_sql)
|
||||
|
||||
Reference in New Issue
Block a user