mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
### What problem does this PR solve? 1. Add modelID to delete_model and update_status. 2. Fix the bug where update-status deleted the model. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality)
143 lines
5.3 KiB
Go
143 lines
5.3 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package dao
|
|
|
|
import (
|
|
"ragflow/internal/entity"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// TenantModelDAO is the data-access object for TenantModel rows
// (the tenant_model table). It holds no state; the methods declared in this
// file all go through the package-level DB handle.
type TenantModelDAO struct{}

// NewTenantModelDAO returns a pointer to a fresh TenantModelDAO. Because the
// type has no fields, the zero value is already fully usable.
func NewTenantModelDAO() *TenantModelDAO {
	return new(TenantModelDAO)
}
|
|
|
|
func (dao *TenantModelDAO) Create(instance *entity.TenantModel) error {
|
|
return DB.Create(instance).Error
|
|
}
|
|
|
|
func (dao *TenantModelDAO) CreateBatch(models []*entity.TenantModel) error {
|
|
if len(models) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return DB.Transaction(func(tx *gorm.DB) error {
|
|
for _, model := range models {
|
|
if err := tx.Create(model).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (dao *TenantModelDAO) DeleteByModelID(modelID string) (int64, error) {
|
|
result := DB.Unscoped().Where("id = ?", modelID).Delete(&entity.TenantModel{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
func (dao *TenantModelDAO) DeleteByModelIDAndProviderIDAndInstanceID(modelID, providerID, instanceID string) (int64, error) {
|
|
result := DB.Unscoped().Where("id = ? AND provider_id = ? AND instance_id = ?", modelID, providerID, instanceID).Delete(&entity.TenantModel{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceID(provideID, instanceID string) (int64, error) {
|
|
result := DB.Unscoped().Where("provider_id = ? AND instance_id = ?", provideID, instanceID).Delete(&entity.TenantModel{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceIDAndModelName(provideID, instanceID, modelName string) (int64, error) {
|
|
result := DB.Unscoped().Where("provider_id = ? AND instance_id = ? AND model_name = ?", provideID, instanceID, modelName).Delete(&entity.TenantModel{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
func (dao *TenantModelDAO) UpdateStatusByIDAndScope(modelID, providerID, instanceID, status string) (int64, error) {
|
|
result := DB.Model(&entity.TenantModel{}).Where("id = ? AND provider_id = ? AND instance_id = ?", modelID, providerID, instanceID).Update("status", status)
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
// GetByID get tenant model by primary key (id)
|
|
func (dao *TenantModelDAO) GetByID(id string) (*entity.TenantModel, error) {
|
|
var model entity.TenantModel
|
|
err := DB.Where("id = ?", id).First(&model).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &model, nil
|
|
}
|
|
|
|
func (dao *TenantModelDAO) GetModelByProviderIDAndInstanceIDAndModelName(providerID, instanceID, modelName string) (*entity.TenantModel, error) {
|
|
var model entity.TenantModel
|
|
err := DB.Where("provider_id = ? AND instance_id = ? AND model_name = ?", providerID, instanceID, modelName).First(&model).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &model, nil
|
|
}
|
|
|
|
func (dao *TenantModelDAO) GetModelsByProviderIDAndInstanceIDAndModelName(providerID, instanceID, modelName string) ([]*entity.TenantModel, error) {
|
|
var models []*entity.TenantModel
|
|
err := DB.Where("provider_id = ? AND instance_id = ? AND model_name = ?", providerID, instanceID, modelName).Find(&models).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
func (dao *TenantModelDAO) GetByProviderIDAndInstanceIDAndModelTypeAndModelName(providerID, instanceID, modelType, modelName string) (*entity.TenantModel, error) {
|
|
var model entity.TenantModel
|
|
err := DB.Where("provider_id = ? AND instance_id = ? AND model_type = ? AND model_name = ?", providerID, instanceID, modelType, modelName).First(&model).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &model, nil
|
|
}
|
|
|
|
// GetModelsByInstanceID get all models by instance ID
|
|
func (dao *TenantModelDAO) GetModelsByInstanceID(instanceID string) ([]*entity.TenantModel, error) {
|
|
var models []*entity.TenantModel
|
|
err := DB.Where("instance_id = ?", instanceID).Find(&models).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
// GetModelsByProviderIDsAndInstanceIDs returns TenantModel rows whose
|
|
// provider_id is in providerIDs and instance_id is in instanceIDs.
|
|
// Mirrors Python's
|
|
// TenantModelService.get_models_by_provider_ids_and_instance_ids and is
|
|
// used to fetch per-tenant enable/disable overrides in bulk during
|
|
// /api/v1/models response assembly. The Go port never WRITES to
|
|
// tenant_model, so callers must treat an empty result as "use factory
|
|
// defaults" — see ModelProviderService.ListTenantAddedModels.
|
|
func (dao *TenantModelDAO) GetModelsByProviderIDsAndInstanceIDs(providerIDs, instanceIDs []string) ([]*entity.TenantModel, error) {
|
|
models := make([]*entity.TenantModel, 0)
|
|
if len(providerIDs) == 0 || len(instanceIDs) == 0 {
|
|
return models, nil
|
|
}
|
|
err := DB.Where("provider_id IN ? AND instance_id IN ?", providerIDs, instanceIDs).Find(&models).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return models, nil
|
|
}
|