Files
ragflow/internal/dao/tenant_model.go
Haruko386 6beae949d8 feat[Go]: add modelID for delete_model and update_status (#16025)
### What problem does this PR solve?

1. add modelID for delete_model and update_status
2. fix the bug when update-status delete model

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
2026-06-18 17:56:51 +08:00

143 lines
5.3 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dao
import (
"ragflow/internal/entity"
"gorm.io/gorm"
)
// TenantModelDAO is the data access object for entity.TenantModel rows.
// The struct has no fields; the methods defined on it in this file issue
// their queries through the package-level DB handle.
type TenantModelDAO struct{}
// NewTenantModelDAO returns a pointer to a freshly allocated, zero-value
// TenantModelDAO.
func NewTenantModelDAO() *TenantModelDAO {
	return new(TenantModelDAO)
}
// Create inserts instance as a tenant model row through the package-level
// DB handle. It returns the error recorded by that insert, or nil when none
// was recorded.
func (dao *TenantModelDAO) Create(instance *entity.TenantModel) error {
	result := DB.Create(instance)
	return result.Error
}
// CreateBatch inserts each element of models, one Create call per element,
// inside a single DB.Transaction callback.
//
// An empty (or nil) slice is a no-op: nil is returned without starting a
// transaction. Otherwise the first insert that reports an error stops the
// loop and that error is returned from the callback; per GORM's documented
// Transaction behavior a non-nil callback error rolls the transaction back.
func (dao *TenantModelDAO) CreateBatch(models []*entity.TenantModel) error {
	if len(models) == 0 {
		return nil
	}

	insertAll := func(tx *gorm.DB) error {
		for i := range models {
			if err := tx.Create(models[i]).Error; err != nil {
				return err
			}
		}
		return nil
	}
	return DB.Transaction(insertAll)
}
// DeleteByModelID deletes the tenant model rows whose id equals modelID.
//
// The delete is issued through DB.Unscoped(); in GORM that bypasses
// soft-delete handling (whether entity.TenantModel has a soft-delete column
// is not visible from this file). It returns the number of rows affected and
// any error recorded by the delete.
func (dao *TenantModelDAO) DeleteByModelID(modelID string) (int64, error) {
	query := DB.Unscoped().Where("id = ?", modelID)
	outcome := query.Delete(&entity.TenantModel{})
	return outcome.RowsAffected, outcome.Error
}
// DeleteByModelIDAndProviderIDAndInstanceID deletes the tenant model rows
// that match all three of id = modelID, provider_id = providerID and
// instance_id = instanceID.
//
// The delete is issued through DB.Unscoped(); in GORM that bypasses
// soft-delete handling (whether entity.TenantModel has a soft-delete column
// is not visible from this file). It returns the number of rows affected and
// any error recorded by the delete.
func (dao *TenantModelDAO) DeleteByModelIDAndProviderIDAndInstanceID(modelID, providerID, instanceID string) (int64, error) {
	query := DB.Unscoped().
		Where("id = ? AND provider_id = ? AND instance_id = ?", modelID, providerID, instanceID)
	outcome := query.Delete(&entity.TenantModel{})
	return outcome.RowsAffected, outcome.Error
}
// DeleteByProviderIDAndInstanceID deletes every tenant model row whose
// provider_id equals providerID and whose instance_id equals instanceID.
//
// The delete is issued through DB.Unscoped(); in GORM that bypasses
// soft-delete handling (whether entity.TenantModel has a soft-delete column
// is not visible from this file). It returns the number of rows affected and
// any error recorded by the delete.
func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceID(providerID, instanceID string) (int64, error) {
	result := DB.Unscoped().
		Where("provider_id = ? AND instance_id = ?", providerID, instanceID).
		Delete(&entity.TenantModel{})
	return result.RowsAffected, result.Error
}
// DeleteByProviderIDAndInstanceIDAndModelName deletes every tenant model row
// that matches all three of provider_id = providerID,
// instance_id = instanceID and model_name = modelName.
//
// The delete is issued through DB.Unscoped(); in GORM that bypasses
// soft-delete handling (whether entity.TenantModel has a soft-delete column
// is not visible from this file). It returns the number of rows affected and
// any error recorded by the delete.
func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceIDAndModelName(providerID, instanceID, modelName string) (int64, error) {
	result := DB.Unscoped().
		Where("provider_id = ? AND instance_id = ? AND model_name = ?", providerID, instanceID, modelName).
		Delete(&entity.TenantModel{})
	return result.RowsAffected, result.Error
}
// UpdateStatusByIDAndScope sets the status column to status on the tenant
// model rows that match all three of id = modelID, provider_id = providerID
// and instance_id = instanceID.
//
// It returns the number of rows affected and any error recorded by the
// update. Zero affected rows is not turned into an error here.
func (dao *TenantModelDAO) UpdateStatusByIDAndScope(modelID, providerID, instanceID, status string) (int64, error) {
	scoped := DB.Model(&entity.TenantModel{}).
		Where("id = ? AND provider_id = ? AND instance_id = ?", modelID, providerID, instanceID)
	outcome := scoped.Update("status", status)
	return outcome.RowsAffected, outcome.Error
}
// GetByID looks up a single tenant model row by its id column.
//
// It returns a nil model together with whatever error First records; with
// GORM that includes gorm.ErrRecordNotFound when no row matches. On success
// the loaded model is returned with a nil error.
func (dao *TenantModelDAO) GetByID(id string) (*entity.TenantModel, error) {
	found := new(entity.TenantModel)
	if err := DB.Where("id = ?", id).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetModelByProviderIDAndInstanceIDAndModelName loads one tenant model row
// matching provider_id = providerID, instance_id = instanceID and
// model_name = modelName, using First.
//
// It returns a nil model together with whatever error First records; with
// GORM that includes gorm.ErrRecordNotFound when no row matches. Use
// GetModelsByProviderIDAndInstanceIDAndModelName to fetch every match.
func (dao *TenantModelDAO) GetModelByProviderIDAndInstanceIDAndModelName(providerID, instanceID, modelName string) (*entity.TenantModel, error) {
	found := new(entity.TenantModel)
	query := DB.Where("provider_id = ? AND instance_id = ? AND model_name = ?", providerID, instanceID, modelName)
	if err := query.First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetModelsByProviderIDAndInstanceIDAndModelName loads every tenant model
// row matching provider_id = providerID, instance_id = instanceID and
// model_name = modelName, using Find.
//
// On a query error it returns a nil slice and that error; otherwise it
// returns the rows Find produced and a nil error.
func (dao *TenantModelDAO) GetModelsByProviderIDAndInstanceIDAndModelName(providerID, instanceID, modelName string) ([]*entity.TenantModel, error) {
	var rows []*entity.TenantModel
	query := DB.Where("provider_id = ? AND instance_id = ? AND model_name = ?", providerID, instanceID, modelName)
	if err := query.Find(&rows).Error; err != nil {
		return nil, err
	}
	return rows, nil
}
// GetByProviderIDAndInstanceIDAndModelTypeAndModelName loads one tenant
// model row matching provider_id = providerID, instance_id = instanceID,
// model_type = modelType and model_name = modelName, using First.
//
// It returns a nil model together with whatever error First records; with
// GORM that includes gorm.ErrRecordNotFound when no row matches.
func (dao *TenantModelDAO) GetByProviderIDAndInstanceIDAndModelTypeAndModelName(providerID, instanceID, modelType, modelName string) (*entity.TenantModel, error) {
	found := new(entity.TenantModel)
	query := DB.Where("provider_id = ? AND instance_id = ? AND model_type = ? AND model_name = ?", providerID, instanceID, modelType, modelName)
	if err := query.First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetModelsByInstanceID loads every tenant model row whose instance_id
// equals instanceID, using Find.
//
// On a query error it returns a nil slice and that error; otherwise it
// returns the rows Find produced and a nil error.
func (dao *TenantModelDAO) GetModelsByInstanceID(instanceID string) ([]*entity.TenantModel, error) {
	var rows []*entity.TenantModel
	if err := DB.Where("instance_id = ?", instanceID).Find(&rows).Error; err != nil {
		return nil, err
	}
	return rows, nil
}
// GetModelsByProviderIDsAndInstanceIDs returns the TenantModel rows whose
// provider_id is in providerIDs and whose instance_id is in instanceIDs.
// The two IN filters are applied independently, not as (provider, instance)
// pairs.
//
// It mirrors Python's
// TenantModelService.get_models_by_provider_ids_and_instance_ids and is
// used to fetch per-tenant enable/disable overrides in bulk during
// /api/v1/models response assembly.
//
// When either ID list is empty no query is issued and a non-nil, empty slice
// is returned. On a query error the result is a nil slice plus that error.
//
// NOTE(review): this comment previously stated that the Go port never writes
// to tenant_model and that callers must therefore treat an empty result as
// "use factory defaults" (see ModelProviderService.ListTenantAddedModels).
// This DAO has Create, CreateBatch, UpdateStatusByIDAndScope and Delete*
// methods, so the "never writes" premise no longer holds — confirm whether
// the empty-result convention still applies.
func (dao *TenantModelDAO) GetModelsByProviderIDsAndInstanceIDs(providerIDs, instanceIDs []string) ([]*entity.TenantModel, error) {
	matched := []*entity.TenantModel{}
	if len(providerIDs) > 0 && len(instanceIDs) > 0 {
		query := DB.Where("provider_id IN ? AND instance_id IN ?", providerIDs, instanceIDs)
		if err := query.Find(&matched).Error; err != nil {
			return nil, err
		}
	}
	return matched, nil
}