diff --git a/.gitignore b/.gitignore index 6316ddb960..58fbda13bb 100644 --- a/.gitignore +++ b/.gitignore @@ -220,7 +220,13 @@ uv-aarch64*.tar.gz uv-aarch64-unknown-linux-gnu.tar.gz docker/launch_backend_service_windows.sh +# C++ build directories +internal/cpp/build/ +internal/cpp/cmake-build-release/ +internal/cpp/cmake-build-debug/ + +# Trae IDE config +.trae/ + # Go server build output bin/ -internal/cpp/cmake-build-release/ -internal/cpp/cmake-build-debug/ \ No newline at end of file diff --git a/cmd/server_main.go b/cmd/server_main.go index 011869145b..7d6ac21e8d 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "os/signal" + "ragflow/internal/init_data" "ragflow/internal/server" "ragflow/internal/utility" "strings" @@ -71,6 +72,13 @@ func main() { logger.Fatal("Failed to initialize database", zap.Error(err)) } + // Initialize LLM factory data models from configuration file + if err := init_data.InitLLMFactory(); err != nil { + logger.Error("Failed to initialize LLM factory", err) + } else { + logger.Info("LLM factory initialized successfully") + } + // Initialize doc engine if err := engine.Init(&cfg.DocEngine); err != nil { logger.Fatal("Failed to initialize doc engine", zap.Error(err)) diff --git a/internal/dao/chat.go b/internal/dao/chat.go index 1500ea540a..91f1b7d1dc 100644 --- a/internal/dao/chat.go +++ b/internal/dao/chat.go @@ -62,8 +62,13 @@ func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pag user.nickname, user.avatar as tenant_avatar `). - Joins("LEFT JOIN user ON dialog.tenant_id = user.id"). - Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1") + Joins("LEFT JOIN user ON dialog.tenant_id = user.id") + + if len(tenantIDs) > 0 { + query = query.Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1") + } else { + query = query.Where("dialog.tenant_id = ? AND dialog.status = ?", userID, "1") + } // Apply keyword filter if keywords != "" { diff --git a/internal/dao/kb.go b/internal/dao/kb.go index cf36e1a7e6..6fb2a6a244 100644 --- a/internal/dao/kb.go +++ b/internal/dao/kb.go @@ -19,6 +19,7 @@ package dao import ( "ragflow/internal/model" "strings" + "time" ) // KnowledgebaseDAO knowledge base data access object @@ -29,15 +30,133 @@ func NewKnowledgebaseDAO() *KnowledgebaseDAO { return &KnowledgebaseDAO{} } -// ListByTenantIDs list knowledge bases by tenant IDs -func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) { +// Create creates a new knowledge base record +func (dao *KnowledgebaseDAO) Create(kb *model.Knowledgebase) error { + return DB.Create(kb).Error +} + +// Update updates a knowledge base record +func (dao *KnowledgebaseDAO) Update(kb *model.Knowledgebase) error { + return DB.Save(kb).Error +} + +// UpdateByID updates a knowledge base by ID with the given fields +func (dao *KnowledgebaseDAO) UpdateByID(id string, updates map[string]interface{}) error { + return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Updates(updates).Error +} + +// Delete soft deletes a knowledge base by setting status to invalid +func (dao *KnowledgebaseDAO) Delete(id string) error { + return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Update("status", string(model.StatusInvalid)).Error +} + +// GetByID retrieves a knowledge base by ID +func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByIDAndTenantID retrieves a knowledge base by ID and tenant ID +func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, string(model.StatusValid)).First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByIDs retrieves multiple knowledge bases by IDs +func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) { var kbs []*model.Knowledgebase + err := DB.Where("id IN ? AND status = ?", ids, string(model.StatusValid)).Find(&kbs).Error + return kbs, err +} + +// GetByName retrieves a knowledge base by name and tenant ID +func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, string(model.StatusValid)).First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByCreatedBy retrieves knowledge bases created by a specific user +func (dao *KnowledgebaseDAO) GetByCreatedBy(createdBy string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + err := DB.Where("created_by = ? AND status = ?", createdBy, string(model.StatusValid)).Find(&kbs).Error + return kbs, err +} + +// Query retrieves knowledge bases with filters +func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + query := DB.Where("status = ?", string(model.StatusValid)) + + for key, value := range filters { + if value != nil && value != "" { + query = query.Where(key+" = ?", value) + } + } + + err := query.Find(&kbs).Error + return kbs, err +} + +// QueryOne retrieves a single knowledge base with filters +func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + query := DB.Where("status = ?", string(model.StatusValid)) + + for key, value := range filters { + if value != nil && value != "" { + query = query.Where(key+" = ?", value) + } + } + + err := query.First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// Count returns the count of knowledge bases matching the filters +func (dao *KnowledgebaseDAO) Count(filters map[string]interface{}) (int64, error) { + var count int64 + query := DB.Model(&model.Knowledgebase{}).Where("status = ?", string(model.StatusValid)) + + for key, value := range filters { + if value != nil && value != "" { + query = query.Where(key+" = ?", value) + } + } + + err := query.Count(&count).Error + return count, err +} + +// GetByTenantIDs retrieves knowledge bases by tenant IDs with pagination +// This matches the Python get_by_tenant_ids method +func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, keywords, parserID string) ([]*model.KnowledgebaseListItem, int64, error) { + var kbs []*model.KnowledgebaseListItem var total int64 query := DB.Model(&model.Knowledgebase{}). + Select(`knowledgebase.id, knowledgebase.avatar, knowledgebase.name, + knowledgebase.language, knowledgebase.description, knowledgebase.tenant_id, + knowledgebase.permission, knowledgebase.doc_num, knowledgebase.token_num, + knowledgebase.chunk_num, knowledgebase.parser_id, knowledgebase.embd_id, + user.nickname, user.avatar as tenant_avatar, knowledgebase.update_time`). Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id"). - Where("(knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?", tenantIDs, "team", userID). - Where("knowledgebase.status = ?", "1") + Where("((knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?) AND knowledgebase.status = ?", + tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid)) if keywords != "" { query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") @@ -47,22 +166,287 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, query = query.Where("knowledgebase.parser_id = ?", parserID) } - // Order + if desc { + query = query.Order("knowledgebase." + orderby + " DESC") + } else { + query = query.Order("knowledgebase." + orderby + " ASC") + } + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + if pageNumber > 0 && itemsPerPage > 0 { + offset := (pageNumber - 1) * itemsPerPage + if err := query.Offset(offset).Limit(itemsPerPage).Scan(&kbs).Error; err != nil { + return nil, 0, err + } + } else { + if err := query.Scan(&kbs).Error; err != nil { + return nil, 0, err + } + } + + return kbs, total, nil +} + +// GetAllByTenantIDs retrieves all permitted knowledge bases by tenant IDs +// This matches the Python get_all_kb_by_tenant_ids method +func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + + err := DB.Where( + "(tenant_id IN ? AND permission = ?) OR tenant_id = ?", + tenantIDs, string(model.TenantPermissionTeam), userID, + ).Order("create_time ASC").Find(&kbs).Error + + return kbs, err +} + +// GetDetail retrieves detailed knowledge base information with joined pipeline data +// This matches the Python get_detail method +func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*model.KnowledgebaseDetail, error) { + var detail model.KnowledgebaseDetail + + err := DB.Table("knowledgebase"). + Select(`knowledgebase.id, knowledgebase.embd_id, knowledgebase.avatar, knowledgebase.name, + knowledgebase.language, knowledgebase.description, knowledgebase.permission, + knowledgebase.doc_num, knowledgebase.token_num, knowledgebase.chunk_num, + knowledgebase.parser_id, knowledgebase.pipeline_id, + user_canvas.title as pipeline_name, user_canvas.avatar as pipeline_avatar, + knowledgebase.parser_config, knowledgebase.pagerank, + knowledgebase.graphrag_task_id, knowledgebase.graphrag_task_finish_at, + knowledgebase.raptor_task_id, knowledgebase.raptor_task_finish_at, + knowledgebase.mindmap_task_id, knowledgebase.mindmap_task_finish_at, + knowledgebase.create_time, knowledgebase.update_time`). + Joins("LEFT JOIN user_canvas ON knowledgebase.pipeline_id = user_canvas.id"). + Where("knowledgebase.id = ? AND knowledgebase.status = ?", kbID, string(model.StatusValid)). + Scan(&detail).Error + + if err != nil { + return nil, err + } + + return &detail, nil +} + +// Accessible checks if a knowledge base is accessible by a user +// This matches the Python accessible method +func (dao *KnowledgebaseDAO) Accessible(kbID, userID string) bool { + var count int64 + err := DB.Table("knowledgebase"). + Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id"). + Where("knowledgebase.id = ? AND user_tenant.user_id = ? AND knowledgebase.status = ?", + kbID, userID, string(model.StatusValid)). + Count(&count).Error + + if err != nil { + return false + } + return count > 0 +} + +// Accessible4Deletion checks if a knowledge base can be deleted by a user +// This matches the Python accessible4deletion method +func (dao *KnowledgebaseDAO) Accessible4Deletion(kbID, userID string) bool { + var count int64 + err := DB.Model(&model.Knowledgebase{}). + Where("id = ? AND created_by = ? AND status = ?", kbID, userID, string(model.StatusValid)). + Count(&count).Error + + if err != nil { + return false + } + return count > 0 +} + +// DuplicateName generates a unique name by appending parentheses if name already exists +// This matches the Python duplicate_name function behavior +func (dao *KnowledgebaseDAO) DuplicateName(name, tenantID string) string { + var existingNames []string + DB.Model(&model.Knowledgebase{}). + Where("name LIKE ? AND tenant_id = ? AND status = ?", name+"%", tenantID, string(model.StatusValid)). + Pluck("name", &existingNames) + + if len(existingNames) == 0 { + return name + } + + nameSet := make(map[string]bool) + for _, n := range existingNames { + nameSet[strings.ToLower(n)] = true + } + + if !nameSet[strings.ToLower(name)] { + return name + } + + for i := 1; ; i++ { + newName := name + " " + strings.Repeat("(", i) + strings.Repeat(")", i) + if !nameSet[strings.ToLower(newName)] { + return newName + } + } +} + +// AtomicIncreaseDocNumByID atomically increments the document count +// This matches the Python atomic_increase_doc_num_by_id method +func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error { + now := time.Now().Unix() + nowDate := time.Now() + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", kbID). + Updates(map[string]interface{}{ + "doc_num": DB.Raw("doc_num + 1"), + "update_time": now, + "update_date": nowDate, + }).Error +} + +// DecreaseDocumentNum decreases document, chunk, and token counts +// This matches the Python decrease_document_num_in_delete method +func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error { + now := time.Now().Unix() + nowDate := time.Now() + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", kbID). + Updates(map[string]interface{}{ + "doc_num": DB.Raw("doc_num - ?", docNum), + "chunk_num": DB.Raw("chunk_num - ?", chunkNum), + "token_num": DB.Raw("token_num - ?", tokenNum), + "update_time": now, + "update_date": nowDate, + }).Error +} + +// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant +// This matches the Python get_kb_ids method +func (dao *KnowledgebaseDAO) GetKBIDsByTenantID(tenantID string) ([]string, error) { + var kbIDs []string + err := DB.Model(&model.Knowledgebase{}). + Where("tenant_id = ? AND status = ?", tenantID, string(model.StatusValid)). + Pluck("id", &kbIDs).Error + return kbIDs, err +} + +// GetAllIDs retrieves all knowledge base IDs +// This matches the Python get_all_ids method +func (dao *KnowledgebaseDAO) GetAllIDs() ([]string, error) { + var kbIDs []string + err := DB.Model(&model.Knowledgebase{}). + Where("status = ?", string(model.StatusValid)). + Pluck("id", &kbIDs).Error + return kbIDs, err +} + +// UpdateParserConfig updates the parser configuration with deep merge +// This matches the Python update_parser_config method +func (dao *KnowledgebaseDAO) UpdateParserConfig(id string, config map[string]interface{}) error { + var kb model.Knowledgebase + if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil { + return err + } + + mergedConfig := mergeConfig(kb.ParserConfig, config) + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", id). + Update("parser_config", mergedConfig).Error +} + +// DeleteFieldMap removes the field_map from parser_config +// This matches the Python delete_field_map method +func (dao *KnowledgebaseDAO) DeleteFieldMap(id string) error { + var kb model.Knowledgebase + if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil { + return err + } + + if kb.ParserConfig != nil { + delete(kb.ParserConfig, "field_map") + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", id). + Update("parser_config", kb.ParserConfig).Error + } + return nil +} + +// GetFieldMap retrieves field mappings from multiple knowledge bases +// This matches the Python get_field_map method +func (dao *KnowledgebaseDAO) GetFieldMap(ids []string) (map[string]interface{}, error) { + conf := make(map[string]interface{}) + kbs, err := dao.GetByIDs(ids) + if err != nil { + return nil, err + } + + for _, kb := range kbs { + if kb.ParserConfig != nil { + if fieldMap, ok := kb.ParserConfig["field_map"]; ok { + if fm, ok := fieldMap.(map[string]interface{}); ok { + for k, v := range fm { + conf[k] = v + } + } + } + } + } + return conf, nil +} + +// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID with tenant join +// This matches the Python get_kb_by_id method +func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + err := DB.Model(&model.Knowledgebase{}). + Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id"). + Where("knowledgebase.id = ? AND user_tenant.user_id = ?", kbID, userID). + Limit(1). + Find(&kbs).Error + return kbs, err +} + +// GetKBByNameAndUserID retrieves a knowledge base by name and user ID with tenant join +// This matches the Python get_kb_by_name method +func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + err := DB.Model(&model.Knowledgebase{}). + Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id"). + Where("knowledgebase.name = ? AND user_tenant.user_id = ?", kbName, userID). + Limit(1). + Find(&kbs).Error + return kbs, err +} + +// GetList retrieves knowledge bases with filtering by ID and name +// This matches the Python get_list method +func (dao *KnowledgebaseDAO) GetList(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, error) { + var kbs []*model.Knowledgebase + var total int64 + + query := DB.Model(&model.Knowledgebase{}). + Where("((tenant_id IN ? AND permission = ?) OR tenant_id = ?) AND status = ?", + tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid)) + + if id != "" { + query = query.Where("id = ?", id) + } + if name != "" { + query = query.Where("name = ?", name) + } + if desc { query = query.Order(orderby + " DESC") } else { query = query.Order(orderby + " ASC") } - // Count if err := query.Count(&total).Error; err != nil { return nil, 0, err } - // Pagination - if page > 0 && pageSize > 0 { - offset := (page - 1) * pageSize - if err := query.Offset(offset).Limit(pageSize).Find(&kbs).Error; err != nil { + if pageNumber > 0 && itemsPerPage > 0 { + offset := (pageNumber - 1) * itemsPerPage + if err := query.Offset(offset).Limit(itemsPerPage).Find(&kbs).Error; err != nil { return nil, 0, err } } else { @@ -74,76 +458,39 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, return kbs, total, nil } -// ListByOwnerIDs list knowledge bases by owner IDs -func (dao *KnowledgebaseDAO) ListByOwnerIDs(ownerIDs []string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) { - var kbs []*model.Knowledgebase - - query := DB.Model(&model.Knowledgebase{}). - Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id"). - Where("knowledgebase.tenant_id IN ?", ownerIDs). - Where("knowledgebase.status = ?", "1") - - if keywords != "" { - query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") +// mergeConfig performs a deep merge of configuration maps +func mergeConfig(old, new map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}) + for k, v := range old { + result[k] = v } - if parserID != "" { - query = query.Where("knowledgebase.parser_id = ?", parserID) - } - - // Order - if desc { - query = query.Order(orderby + " DESC") - } else { - query = query.Order(orderby + " ASC") - } - - if err := query.Find(&kbs).Error; err != nil { - return nil, 0, err - } - - total := int64(len(kbs)) - - // Manual pagination - if page > 0 && pageSize > 0 { - start := (page - 1) * pageSize - end := start + pageSize - if end > int(total) { - end = int(total) - } - if start < end { - kbs = kbs[start:end] - } else { - kbs = []*model.Knowledgebase{} + for k, v := range new { + if existing, ok := result[k]; ok { + if existingMap, ok := existing.(map[string]interface{}); ok { + if newMap, ok := v.(map[string]interface{}); ok { + result[k] = mergeConfig(existingMap, newMap) + continue + } + } + if existingSlice, ok := existing.([]interface{}); ok { + if newSlice, ok := v.([]interface{}); ok { + merged := append(existingSlice, newSlice...) + seen := make(map[interface{}]bool) + unique := make([]interface{}, 0) + for _, item := range merged { + if !seen[item] { + seen[item] = true + unique = append(unique, item) + } + } + result[k] = unique + continue + } + } } + result[k] = v } - return kbs, total, nil -} - -// GetByID gets knowledge base by ID -func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) { - var kb model.Knowledgebase - err := DB.Where("id = ? AND status = ?", id, "1").First(&kb).Error - if err != nil { - return nil, err - } - return &kb, nil -} - -// GetByIDAndTenantID gets knowledge base by ID and tenant ID -func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) { - var kb model.Knowledgebase - err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, "1").First(&kb).Error - if err != nil { - return nil, err - } - return &kb, nil -} - -// GetByIDs gets knowledge bases by IDs -func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) { - var kbs []*model.Knowledgebase - err := DB.Where("id IN ? AND status = ?", ids, "1").Find(&kbs).Error - return kbs, err + return result } diff --git a/internal/dao/llm.go b/internal/dao/llm.go index 44590ca9dc..a522cfb091 100644 --- a/internal/dao/llm.go +++ b/internal/dao/llm.go @@ -67,3 +67,31 @@ func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*model.LLM, error) } return &llm, nil } + +// LLMFactoryDAO LLM factory data access object +type LLMFactoryDAO struct{} + +// NewLLMFactoryDAO create LLM factory DAO +func NewLLMFactoryDAO() *LLMFactoryDAO { + return &LLMFactoryDAO{} +} + +// GetAllValid gets all valid LLM factories +func (dao *LLMFactoryDAO) GetAllValid() ([]*model.LLMFactories, error) { + var factories []*model.LLMFactories + err := DB.Where("status = ?", "1").Find(&factories).Error + if err != nil { + return nil, err + } + return factories, nil +} + +// GetByName gets LLM factory by name +func (dao *LLMFactoryDAO) GetByName(name string) (*model.LLMFactories, error) { + var factory model.LLMFactories + err := DB.Where("name = ?", name).First(&factory).Error + if err != nil { + return nil, err + } + return &factory, nil +} diff --git a/internal/dao/tenant.go b/internal/dao/tenant.go index 781c6c2058..0585c481af 100644 --- a/internal/dao/tenant.go +++ b/internal/dao/tenant.go @@ -75,7 +75,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) { Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id"). Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1"). Scan(&results).Error - + return results, err } @@ -98,3 +98,8 @@ func (dao *TenantDAO) Create(tenant *model.Tenant) error { func (dao *TenantDAO) Delete(id string) error { return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error } + +// Update updates a tenant by ID +func (dao *TenantDAO) Update(id string, updates map[string]interface{}) error { + return DB.Model(&model.Tenant{}).Where("id = ?", id).Updates(updates).Error +} diff --git a/internal/dao/tenant_llm.go b/internal/dao/tenant_llm.go index 8752e041fa..e563ad0e87 100644 --- a/internal/dao/tenant_llm.go +++ b/internal/dao/tenant_llm.go @@ -94,21 +94,14 @@ func (dao *TenantLLMDAO) Delete(tenantID, factory, modelName string) error { } // GetMyLLMs get tenant LLMs with factory details -func (dao *TenantLLMDAO) GetMyLLMs(tenantID string, includeDetails bool) ([]model.MyLLM, error) { +func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]model.MyLLM, error) { var myLLMs []model.MyLLM - // Base query - query := DB.Table("tenant_llm tl"). - Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status"). + err := DB.Table("tenant_llm tl"). + Select("tl.id, tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status"). Joins("JOIN llm_factories lf ON tl.llm_factory = lf.name"). - Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID) - - // Add detailed fields if requested - if includeDetails { - query = query.Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status, tl.api_base, tl.max_tokens") - } - - err := query.Find(&myLLMs).Error + Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID). + Find(&myLLMs).Error if err != nil { return nil, err } diff --git a/internal/handler/kb.go b/internal/handler/kb.go index e4e2a025b4..a7b5f7ac25 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -18,20 +18,21 @@ package handler import ( "net/http" + "ragflow/internal/common" + "ragflow/internal/service" "strconv" + "strings" "github.com/gin-gonic/gin" - - "ragflow/internal/service" ) -// KnowledgebaseHandler knowledge base handler +// KnowledgebaseHandler handles knowledge base HTTP requests type KnowledgebaseHandler struct { kbService *service.KnowledgebaseService userService *service.UserService } -// NewKnowledgebaseHandler create knowledge base handler +// NewKnowledgebaseHandler creates a new knowledge base handler func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userService *service.UserService) *KnowledgebaseHandler { return &KnowledgebaseHandler{ kbService: kbService, @@ -39,35 +40,227 @@ func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userServic } } -// ListKbs list knowledge bases -// @Summary List Knowledge Bases -// @Description Get list of knowledge bases with filtering and pagination +// getUserID extracts user ID from authorization header +// It validates the authorization token and returns the user ID +// Parameters: +// - c: gin.Context - the HTTP request context +// +// Returns: +// - string: the user ID +// - common.ErrorCode: the error code +// - error: any error that occurred +func (h *KnowledgebaseHandler) getUserID(c *gin.Context) (string, common.ErrorCode, error) { + token := c.GetHeader("Authorization") + if token == "" { + return "", common.CodeUnauthorized, ErrMissingAuth + } + + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + return "", code, err + } + + return user.ID, common.CodeSuccess, nil +} + +// jsonResponse sends a JSON response with code and message +func jsonResponse(c *gin.Context, code common.ErrorCode, data interface{}, message string) { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "data": data, + "message": message, + }) +} + +// jsonError sends a JSON error response +func jsonError(c *gin.Context, code common.ErrorCode, message string) { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "data": nil, + "message": message, + }) +} + +// HTTPError represents an HTTP error +type HTTPError struct { + Code common.ErrorCode + Message string +} + +// Error implements the error interface +func (e *HTTPError) Error() string { + return e.Message +} + +var ( + // ErrMissingAuth indicates missing authorization header + ErrMissingAuth = &HTTPError{Code: common.CodeUnauthorized, Message: "Missing Authorization header"} + // ErrInvalidToken indicates invalid access token + ErrInvalidToken = &HTTPError{Code: common.CodeUnauthorized, Message: "Invalid access token"} +) + +// CreateKB handles the create knowledge base request +// @Summary Create Knowledge Base +// @Description Create a new knowledge base (dataset) // @Tags knowledgebase // @Accept json // @Produce json -// @Param keywords query string false "search keywords" -// @Param page query int false "page number" -// @Param page_size query int false "items per page" -// @Param parser_id query string false "parser ID filter" -// @Param orderby query string false "order by field" -// @Param desc query bool false "descending order" -// @Param request body service.ListKbsRequest true "filter options" -// @Success 200 {object} service.ListKbsResponse +// @Security ApiKeyAuth +// @Param request body service.CreateKBRequest true "knowledge base info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/create [post] +func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req service.CreateKBRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.kbService.CreateKB(&req, userID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// UpdateKB handles the update knowledge base request +// @Summary Update Knowledge Base +// @Description Update an existing knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.UpdateKBRequest true "knowledge base update info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/update [post] +func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req service.UpdateKBRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.kbService.UpdateKB(&req, userID) + if err != nil { + if strings.Contains(err.Error(), "authorization") { + jsonError(c, common.CodeAuthenticationError, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// UpdateMetadataSetting handles the update metadata setting request +// @Summary Update Metadata Setting +// @Description Update metadata settings for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.UpdateMetadataSettingRequest true "metadata setting info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/update_metadata_setting [post] +func (h *KnowledgebaseHandler) UpdateMetadataSetting(c *gin.Context) { + _, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req service.UpdateMetadataSettingRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.kbService.UpdateMetadataSetting(&req) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// GetDetail handles the get knowledge base detail request +// @Summary Get Knowledge Base Detail +// @Description Get detailed information about a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id query string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/detail [get] +func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Query("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + result, code, err := h.kbService.GetDetail(kbID, userID) + if err != nil { + if strings.Contains(err.Error(), "authorized") { + jsonError(c, common.CodeOperatingError, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// ListKbs handles the list knowledge bases request +// @Summary List Knowledge Bases +// @Description List knowledge bases with pagination and filtering +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.ListKbsRequest true "list options" +// @Success 200 {object} map[string]interface{} // @Router /v1/kb/list [post] func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { - // Parse request body - allow empty body + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + var req service.ListKbsRequest if c.Request.ContentLength > 0 { if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) + jsonError(c, common.CodeDataError, err.Error()) return } } - // Extract parameters from query or request body with defaults + // Get parameters from request or query string keywords := "" if req.Keywords != nil { keywords = *req.Keywords @@ -111,7 +304,7 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { if req.Desc != nil { desc = *req.Desc } else if descStr := c.Query("desc"); descStr != "" { - desc = descStr == "true" + desc = strings.ToLower(descStr) == "true" } var ownerIDs []string @@ -119,40 +312,327 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { ownerIDs = *req.OwnerIDs } - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) + result, code, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) - return - } - userID := user.ID - - // List knowledge bases - result, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) + jsonError(c, code, err.Error()) return } - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": result, - "message": "success", - }) + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// DeleteKB handles the delete knowledge base request +// @Summary Delete Knowledge Base +// @Description Soft delete a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body object{kb_id string} true "knowledge base id" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/rm [post] +func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req struct { + KBID string `json:"kb_id" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + code, err = h.kbService.DeleteKB(req.KBID, userID) + if err != nil { + if strings.Contains(err.Error(), "authorization") { + jsonError(c, common.CodeAuthenticationError, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// ListTags handles the list tags request for a knowledge base +// @Summary List Tags +// @Description List tags for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/tags [get] +func (h *KnowledgebaseHandler) ListTags(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + jsonResponse(c, common.CodeSuccess, []string{}, "success") +} + +// ListTagsFromKbs handles the list tags from multiple knowledge bases request +// @Summary List Tags from Knowledge Bases +// @Description List tags from multiple knowledge bases +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_ids query string true "Comma-separated Knowledge Base IDs" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/tags [get] +func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbIDsStr := c.Query("kb_ids") + if kbIDsStr == "" { + jsonError(c, common.CodeDataError, "kb_ids is required") + return + } + + kbIDs := strings.Split(kbIDsStr, ",") + for _, kbID := range kbIDs { + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + } + + jsonResponse(c, common.CodeSuccess, []string{}, "success") +} + +// RemoveTags handles the remove tags request +// @Summary Remove Tags +// @Description Remove tags from a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Param request body object{tags []string} true "tags to remove" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/rm_tags [post] +func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + var req struct { + Tags []string `json:"tags" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// RenameTag handles the rename tag request +// @Summary Rename Tag +// @Description Rename a tag in a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Param request body object{from_tag string, to_tag string} true "tag rename info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/rename_tag [post] +func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + var req struct { + FromTag string `json:"from_tag" binding:"required"` + ToTag string `json:"to_tag" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// KnowledgeGraph handles the get knowledge graph request +// @Summary Get Knowledge Graph +// @Description Get knowledge graph for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/knowledge_graph [get] +func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + result := map[string]interface{}{ + "graph": map[string]interface{}{}, + "mind_map": map[string]interface{}{}, + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// DeleteKnowledgeGraph handles the delete knowledge graph request +// @Summary Delete Knowledge Graph +// @Description Delete knowledge graph for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/knowledge_graph [delete] +func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// GetMeta handles the get metadata request +// @Summary Get Metadata +// @Description Get metadata for knowledge bases +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_ids query string true "Comma-separated Knowledge Base IDs" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/get_meta [get] +func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbIDsStr := c.Query("kb_ids") + if kbIDsStr == "" { + jsonError(c, common.CodeDataError, "kb_ids is required") + return + } + + kbIDs := strings.Split(kbIDsStr, ",") + for _, kbID := range kbIDs { + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + } + + jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success") +} + +// GetBasicInfo handles the get basic info request +// @Summary Get Basic Info +// @Description Get basic information for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id query string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/basic_info [get] +func (h *KnowledgebaseHandler) GetBasicInfo(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Query("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success") } diff --git a/internal/handler/llm.go b/internal/handler/llm.go index 6926dfc97d..9582eb37ad 100644 --- a/internal/handler/llm.go +++ b/internal/handler/llm.go @@ -21,6 +21,7 @@ import ( "github.com/gin-gonic/gin" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/service" ) @@ -60,50 +61,112 @@ func NewLLMHandler(llmService *service.LLMService, userService *service.UserServ // @Success 200 {object} map[string]interface{} // @Router /v1/llm/my_llms [get] func (h *LLMHandler) GetMyLLMs(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant ID from user tenantID := user.ID - if tenantID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "User has no tenant ID", - }) - return - } - - // Parse include_details query parameter includeDetailsStr := c.DefaultQuery("include_details", "false") includeDetails := includeDetailsStr == "true" - // Get LLMs for tenant llms, err := h.llmService.GetMyLLMs(tenantID, includeDetails) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get LLMs", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "data": llms, + "code": common.CodeSuccess, + "message": "success", + "data": llms, + }) +} + +// SetAPIKey set API key for a LLM factory +// @Summary Set API Key +// @Description Set API key for a LLM factory and test connectivity +// @Tags llm +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.SetAPIKeyRequest true "API Key configuration" +// @Success 200 {object} map[string]interface{} +// @Router /v1/llm/set_api_key [post] +func (h *LLMHandler) SetAPIKey(c *gin.Context) { + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, + }) + return + } + + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, + }) + return + } + + var req service.SetAPIKeyRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "Invalid request: " + err.Error(), + "data": false, + }) + return + } + + tenantID := user.ID + result, err := h.llmService.SetAPIKey(tenantID, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "message": err.Error(), + "data": false, + }) + return + } + + if req.Verify { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": result, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": true, }) } @@ -198,52 +261,43 @@ func (h *LLMHandler) Factories(c *gin.Context) { // @Success 200 {object} map[string][]service.LLMListItem // @Router /v1/llm/list [get] func (h *LLMHandler) ListApp(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant ID from user tenantID := user.ID - if tenantID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "User has no tenant ID", - }) - return - } - // Parse model_type query parameter modelType := c.Query("model_type") - // Get LLM list llms, err := h.llmService.ListLLMs(tenantID, modelType) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": llms, + "code": common.CodeSuccess, "message": "success", + "data": llms, }) } diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index 02b87a4164..860acc3bbb 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -21,6 +21,7 @@ import ( "github.com/gin-gonic/gin" + "ragflow/internal/common" "ragflow/internal/service" ) @@ -48,44 +49,49 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service // @Success 200 {object} map[string]interface{} // @Router /v1/user/tenant_info [get] func (h *TenantHandler) TenantInfo(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - // Get user by token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, + }) + return + } + + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } - // Get tenant info tenantInfo, err := h.tenantService.GetTenantInfo(user.ID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get tenant information", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": false, }) return } if tenantInfo == nil { - c.JSON(http.StatusNotFound, gin.H{ - "error": "Tenant not found", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "message": "Tenant not found!", + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": tenantInfo, + "code": common.CodeSuccess, + "message": "success", + "data": tenantInfo, }) } @@ -99,38 +105,39 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/tenant/list [get] func (h *TenantHandler) TenantList(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant list tenantList, err := h.tenantService.GetTenantList(user.ID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": "Failed to get tenant list", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": tenantList, + "code": common.CodeSuccess, + "message": "success", + "data": tenantList, }) } diff --git a/internal/handler/user.go b/internal/handler/user.go index 7fb39a5df9..3651c29c14 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -522,3 +522,61 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) { "data": channels, }) } + +// SetTenantInfo update tenant information +// @Summary Set Tenant Info +// @Description Update tenant model configuration +// @Tags users +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.SetTenantInfoRequest true "tenant info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/user/set_tenant_info [post] +func (h *UserHandler) SetTenantInfo(c *gin.Context) { + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, + }) + return + } + + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, + }) + return + } + + var req service.SetTenantInfoRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + "data": false, + }) + return + } + + err = h.userService.SetTenantInfo(user.ID, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "message": err.Error(), + "data": false, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": true, + }) +} diff --git a/internal/init_data/llm_init.go b/internal/init_data/llm_init.go new file mode 100644 index 0000000000..ef67dd6ba3 --- /dev/null +++ b/internal/init_data/llm_init.go @@ -0,0 +1,157 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package init_data + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + + "ragflow/internal/dao" + "ragflow/internal/model" +) + +// LLMFactoryConfig represents a single LLM factory configuration +type LLMFactoryConfig struct { + Name string `json:"name"` + Logo string `json:"logo"` + Tags string `json:"tags"` + Status string `json:"status"` + Rank string `json:"rank"` + LLM []LLMConfig `json:"llm"` +} + +// LLMConfig represents a single LLM model configuration +type LLMConfig struct { + LLMName string `json:"llm_name"` + Tags string `json:"tags"` + MaxTokens int64 `json:"max_tokens"` + ModelType string `json:"model_type"` + IsTools bool `json:"is_tools"` +} + +// LLMFactoriesFile represents the structure of llm_factories.json +type LLMFactoriesFile struct { + FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"` +} + +// InitLLMFactory initializes LLM factories and models from JSON file +func InitLLMFactory() error { + configPath := filepath.Join(getProjectBaseDirectory(), "conf", "llm_factories.json") + + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read llm_factories.json: %w", err) + } + + var fileData LLMFactoriesFile + if err := json.Unmarshal(data, &fileData); err != nil { + return fmt.Errorf("failed to parse llm_factories.json: %w", err) + } + + db := dao.DB + + for _, factory := range fileData.FactoryLLMInfos { + status := factory.Status + if status == "" { + status = "1" + } + + llmFactory := &model.LLMFactories{ + Name: factory.Name, + Logo: stringPtr(factory.Logo), + Tags: factory.Tags, + Rank: parseInt64(factory.Rank), + Status: &status, + } + + var existingFactory model.LLMFactories + result := db.Where("name = ?", factory.Name).First(&existingFactory) + if result.Error != nil { + if err := db.Create(llmFactory).Error; err != nil { + log.Printf("Failed to create LLM factory %s: %v", factory.Name, err) + continue + } + } else { + if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{ + "logo": llmFactory.Logo, + "tags": llmFactory.Tags, + "rank": llmFactory.Rank, + "status": llmFactory.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM factory %s: %v", factory.Name, err) + } + } + + for _, llm := range factory.LLM { + llmStatus := "1" + llmModel := &model.LLM{ + LLMName: llm.LLMName, + ModelType: llm.ModelType, + FID: factory.Name, + MaxTokens: llm.MaxTokens, + Tags: llm.Tags, + IsTools: llm.IsTools, + Status: &llmStatus, + } + + var existingLLM model.LLM + result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM) + if result.Error != nil { + if err := db.Create(llmModel).Error; err != nil { + log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } else { + if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{ + "model_type": llmModel.ModelType, + "max_tokens": llmModel.MaxTokens, + "tags": llmModel.Tags, + "is_tools": llmModel.IsTools, + "status": llmModel.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } + } + } + + log.Println("LLM factories initialized successfully") + return nil +} + +func getProjectBaseDirectory() string { + cwd, err := os.Getwd() + if err != nil { + return "." + } + return cwd +} + +func stringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + +func parseInt64(s string) int64 { + var result int64 + fmt.Sscanf(s, "%d", &result) + return result +} diff --git a/internal/model/kb.go b/internal/model/kb.go index 78cc643721..da5f817e29 100644 --- a/internal/model/kb.go +++ b/internal/model/kb.go @@ -18,7 +18,84 @@ package model import "time" -// Knowledgebase knowledge base model +// DatasetNameLimit is the maximum length for dataset name +const DatasetNameLimit = 128 + +// Status represents the status enum values +type Status string + +const ( + // StatusValid indicates a valid/active record + StatusValid Status = "1" + // StatusInvalid indicates a deleted/inactive record + StatusInvalid Status = "0" +) + +// TenantPermission represents the permission level for tenant access +type TenantPermission string + +const ( + // TenantPermissionMe indicates only the creator can access + TenantPermissionMe TenantPermission = "me" + // TenantPermissionTeam indicates all team members can access + TenantPermissionTeam TenantPermission = "team" +) + +// ParserType represents the document parser type +type ParserType string + +const ( + ParserTypePresentation ParserType = "presentation" + ParserTypeLaws ParserType = "laws" + ParserTypeManual ParserType = "manual" + ParserTypePaper ParserType = "paper" + ParserTypeResume ParserType = "resume" + ParserTypeBook ParserType = "book" + ParserTypeQA ParserType = "qa" + ParserTypeTable ParserType = "table" + ParserTypeNaive ParserType = "naive" + ParserTypePicture ParserType = "picture" + ParserTypeOne ParserType = "one" + ParserTypeAudio ParserType = "audio" + ParserTypeEmail ParserType = "email" + ParserTypeKG ParserType = "knowledge_graph" + ParserTypeTag ParserType = "tag" +) + +// TaskStatus represents the status of a processing task +type TaskStatus string + +const ( + TaskStatusUnstart TaskStatus = "0" + TaskStatusRunning TaskStatus = "1" + TaskStatusCancel TaskStatus = "2" + TaskStatusDone TaskStatus = "3" + TaskStatusFail TaskStatus = "4" + TaskStatusSchedule TaskStatus = "5" +) + +// PipelineTaskType represents the type of pipeline task +type PipelineTaskType string + +const ( + PipelineTaskTypeParse PipelineTaskType = "Parse" + PipelineTaskTypeDownload PipelineTaskType = "Download" + PipelineTaskTypeRAPTOR PipelineTaskType = "RAPTOR" + PipelineTaskTypeGraphRAG PipelineTaskType = "GraphRAG" + PipelineTaskTypeMindmap PipelineTaskType = "Mindmap" + PipelineTaskTypeMemory PipelineTaskType = "Memory" +) + +// FileSource represents the source of a file +type FileSource string + +const ( + FileSourceLocal FileSource = "" + FileSourceKnowledgebase FileSource = "knowledgebase" + FileSourceS3 FileSource = "s3" +) + +// Knowledgebase represents the knowledge base model type Knowledgebase struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` @@ -27,7 +104,6 @@ type Knowledgebase struct { Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` - TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"` Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"` CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` DocNum int64 `gorm:"column:doc_num;default:0;index" json:"doc_num"` @@ -37,7 +113,7 @@ type Knowledgebase struct { VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3;index" json:"vector_similarity_weight"` ParserID string `gorm:"column:parser_id;size:32;not null;default:naive;index" json:"parser_id"` PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` - ParserConfig JSONMap `gorm:"column:parser_config;type:longtext;not null" json:"parser_config"` + ParserConfig JSONMap `gorm:"column:parser_config;type:json" json:"parser_config"` Pagerank int64 `gorm:"column:pagerank;default:0" json:"pagerank"` GraphragTaskID *string `gorm:"column:graphrag_task_id;size:32;index" json:"graphrag_task_id,omitempty"` GraphragTaskFinishAt *time.Time `gorm:"column:graphrag_task_finish_at" json:"graphrag_task_finish_at,omitempty"` @@ -49,12 +125,118 @@ type Knowledgebase struct { BaseModel } -// TableName specify table name +// TableName returns the table name for Knowledgebase model func (Knowledgebase) TableName() string { return "knowledgebase" } -// InvitationCode invitation code model +// ToMap converts Knowledgebase to a map for JSON response +func (kb *Knowledgebase) ToMap() map[string]interface{} { + result := map[string]interface{}{ + "id": kb.ID, + "tenant_id": kb.TenantID, + "name": kb.Name, + "embd_id": kb.EmbdID, + "permission": kb.Permission, + "created_by": kb.CreatedBy, + "doc_num": kb.DocNum, + "token_num": kb.TokenNum, + "chunk_num": kb.ChunkNum, + "similarity_threshold": kb.SimilarityThreshold, + "vector_similarity_weight": kb.VectorSimilarityWeight, + "parser_id": kb.ParserID, + "parser_config": kb.ParserConfig, + "pagerank": kb.Pagerank, + "create_time": kb.CreateTime, + } + + if kb.Avatar != nil { + result["avatar"] = *kb.Avatar + } + if kb.Language != nil { + result["language"] = *kb.Language + } + if kb.Description != nil { + result["description"] = *kb.Description + } + if kb.PipelineID != nil { + result["pipeline_id"] = *kb.PipelineID + } + if kb.GraphragTaskID != nil { + result["graphrag_task_id"] = *kb.GraphragTaskID + } + if kb.GraphragTaskFinishAt != nil { + result["graphrag_task_finish_at"] = kb.GraphragTaskFinishAt.Format("2006-01-02 15:04:05") + } + if kb.RaptorTaskID != nil { + result["raptor_task_id"] = *kb.RaptorTaskID + } + if kb.RaptorTaskFinishAt != nil { + result["raptor_task_finish_at"] = kb.RaptorTaskFinishAt.Format("2006-01-02 15:04:05") + } + if kb.MindmapTaskID != nil { + result["mindmap_task_id"] = *kb.MindmapTaskID + } + if kb.MindmapTaskFinishAt != nil { + result["mindmap_task_finish_at"] = kb.MindmapTaskFinishAt.Format("2006-01-02 15:04:05") + } + if kb.UpdateTime != nil { + result["update_time"] = *kb.UpdateTime + } + + return result +} + +// KnowledgebaseDetail represents detailed knowledge base information with joined data +type KnowledgebaseDetail struct { + ID string `json:"id"` + EmbdID string `json:"embd_id"` + Avatar *string `json:"avatar,omitempty"` + Name string `json:"name"` + Language *string `json:"language,omitempty"` + Description *string `json:"description,omitempty"` + Permission string `json:"permission"` + DocNum int64 `json:"doc_num"` + TokenNum int64 `json:"token_num"` + ChunkNum int64 `json:"chunk_num"` + ParserID string `json:"parser_id"` + PipelineID *string `json:"pipeline_id,omitempty"` + PipelineName *string `json:"pipeline_name,omitempty"` + PipelineAvatar *string `json:"pipeline_avatar,omitempty"` + ParserConfig JSONMap `json:"parser_config"` + Pagerank int64 `json:"pagerank"` + GraphragTaskID *string `json:"graphrag_task_id,omitempty"` + GraphragTaskFinishAt *string `json:"graphrag_task_finish_at,omitempty"` + RaptorTaskID *string `json:"raptor_task_id,omitempty"` + RaptorTaskFinishAt *string `json:"raptor_task_finish_at,omitempty"` + MindmapTaskID *string `json:"mindmap_task_id,omitempty"` + MindmapTaskFinishAt *string `json:"mindmap_task_finish_at,omitempty"` + CreateTime *int64 `json:"create_time,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` + Size int64 `json:"size"` + Connectors []string `json:"connectors"` +} + +// KnowledgebaseListItem represents a knowledge base item in list responses +type KnowledgebaseListItem struct { + ID string `json:"id"` + Avatar *string `json:"avatar,omitempty"` + Name string `json:"name"` + Language *string `json:"language,omitempty"` + Description *string `json:"description,omitempty"` + TenantID string `json:"tenant_id"` + Permission string `json:"permission"` + DocNum int64 `json:"doc_num"` + TokenNum int64 `json:"token_num"` + ChunkNum int64 `json:"chunk_num"` + ParserID string `json:"parser_id"` + EmbdID string `json:"embd_id"` + Nickname string `json:"nickname"` + TenantAvatar *string `json:"tenant_avatar,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` +} + +// InvitationCode represents the invitation code model type InvitationCode struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` Code string `gorm:"column:code;size:32;not null;index" json:"code"` @@ -65,7 +247,7 @@ type InvitationCode struct { BaseModel } -// TableName specify table name +// TableName returns the table name for InvitationCode model func (InvitationCode) TableName() string { return "invitation_code" } diff --git a/internal/model/llm.go b/internal/model/llm.go index 9b9054e7e6..665dee7867 100644 --- a/internal/model/llm.go +++ b/internal/model/llm.go @@ -64,13 +64,14 @@ func (TenantLangfuse) TableName() string { // MyLLM represents LLM information for a tenant with factory details type MyLLM struct { + ID string `gorm:"column:id" json:"id"` LLMFactory string `gorm:"column:llm_factory" json:"llm_factory"` Logo *string `gorm:"column:logo" json:"logo,omitempty"` - Tags string `gorm:"column:tags" json:"tags"` - ModelType string `gorm:"column:model_type" json:"model_type"` - LLMName string `gorm:"column:llm_name" json:"llm_name"` - UsedTokens int64 `gorm:"column:used_tokens" json:"used_tokens"` - Status string `gorm:"column:status" json:"status"` - APIBase string `gorm:"column:api_base" json:"api_base,omitempty"` - MaxTokens int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"` + Tags *string `gorm:"column:tags" json:"tags"` + ModelType *string `gorm:"column:model_type" json:"model_type"` + LLMName *string `gorm:"column:llm_name" json:"llm_name"` + UsedTokens *int64 `gorm:"column:used_tokens" json:"used_tokens"` + Status *string `gorm:"column:status" json:"status"` + APIBase *string `gorm:"column:api_base" json:"api_base,omitempty"` + MaxTokens *int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"` } diff --git a/internal/router/router.go b/internal/router/router.go index 5f41765d60..cebd6b97ac 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -101,6 +101,8 @@ func (r *Router) Setup(engine *gin.Engine) { engine.POST("/v1/user/setting", r.userHandler.Setting) // User change password endpoint engine.POST("/v1/user/setting/password", r.userHandler.ChangePassword) + // User set tenant info endpoint + engine.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo) // API v1 route group v1 := engine.Group("/api/v1") @@ -134,7 +136,25 @@ func (r *Router) Setup(engine *gin.Engine) { // Knowledge base routes kb := engine.Group("/v1/kb") { + kb.POST("/create", r.knowledgebaseHandler.CreateKB) + kb.POST("/update", r.knowledgebaseHandler.UpdateKB) + kb.POST("/update_metadata_setting", r.knowledgebaseHandler.UpdateMetadataSetting) + kb.GET("/detail", r.knowledgebaseHandler.GetDetail) kb.POST("/list", r.knowledgebaseHandler.ListKbs) + kb.POST("/rm", r.knowledgebaseHandler.DeleteKB) + kb.GET("/tags", r.knowledgebaseHandler.ListTagsFromKbs) + kb.GET("/get_meta", r.knowledgebaseHandler.GetMeta) + kb.GET("/basic_info", r.knowledgebaseHandler.GetBasicInfo) + + // KB ID specific routes + kbByID := kb.Group("/:kb_id") + { + kbByID.GET("/tags", r.knowledgebaseHandler.ListTags) + kbByID.POST("/rm_tags", r.knowledgebaseHandler.RemoveTags) + kbByID.POST("/rename_tag", r.knowledgebaseHandler.RenameTag) + kbByID.GET("/knowledge_graph", r.knowledgebaseHandler.KnowledgeGraph) + kbByID.DELETE("/knowledge_graph", r.knowledgebaseHandler.DeleteKnowledgeGraph) + } } // Chunk routes @@ -149,6 +169,7 @@ func (r *Router) Setup(engine *gin.Engine) { llm.GET("/my_llms", r.llmHandler.GetMyLLMs) llm.GET("/factories", r.llmHandler.Factories) llm.GET("/list", r.llmHandler.ListApp) + llm.POST("/set_api_key", r.llmHandler.SetAPIKey) } // Chat routes diff --git a/internal/server/config.go b/internal/server/config.go index 5a8fbf1e1d..fe9cdea48e 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -40,6 +40,29 @@ type Config struct { DocEngine DocEngineConfig `mapstructure:"doc_engine"` RegisterEnabled int `mapstructure:"register_enabled"` OAuth map[string]OAuthConfig `mapstructure:"oauth"` + UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"` +} + +// UserDefaultLLMConfig user default LLM configuration +type UserDefaultLLMConfig struct { + DefaultModels DefaultModelsConfig `mapstructure:"default_models"` +} + +// DefaultModelsConfig default models configuration +type DefaultModelsConfig struct { + ChatModel ModelConfig `mapstructure:"chat_model"` + EmbeddingModel ModelConfig `mapstructure:"embedding_model"` + RerankModel ModelConfig `mapstructure:"rerank_model"` + ASRModel ModelConfig `mapstructure:"asr_model"` + Image2TextModel ModelConfig `mapstructure:"image2text_model"` +} + +// ModelConfig model configuration +type ModelConfig struct { + Name string `mapstructure:"name"` + APIKey string `mapstructure:"api_key"` + BaseURL string `mapstructure:"base_url"` + Factory string `mapstructure:"factory"` } // OAuthConfig OAuth configuration for a channel @@ -414,6 +437,45 @@ func Init(configPath string) error { } } + // Map user_default_llm section to UserDefaultLLMConfig + if v.IsSet("user_default_llm") { + userDefaultLLMConfig := v.Sub("user_default_llm") + if userDefaultLLMConfig != nil { + if defaultModels := userDefaultLLMConfig.Sub("default_models"); defaultModels != nil { + globalConfig.UserDefaultLLM.DefaultModels.ChatModel = ModelConfig{ + Name: defaultModels.GetString("chat_model.name"), + APIKey: defaultModels.GetString("chat_model.api_key"), + BaseURL: defaultModels.GetString("chat_model.base_url"), + Factory: defaultModels.GetString("chat_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.EmbeddingModel = ModelConfig{ + Name: defaultModels.GetString("embedding_model.name"), + APIKey: defaultModels.GetString("embedding_model.api_key"), + BaseURL: defaultModels.GetString("embedding_model.base_url"), + Factory: defaultModels.GetString("embedding_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.RerankModel = ModelConfig{ + Name: defaultModels.GetString("rerank_model.name"), + APIKey: defaultModels.GetString("rerank_model.api_key"), + BaseURL: defaultModels.GetString("rerank_model.base_url"), + Factory: defaultModels.GetString("rerank_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.ASRModel = ModelConfig{ + Name: defaultModels.GetString("asr_model.name"), + APIKey: defaultModels.GetString("asr_model.api_key"), + BaseURL: defaultModels.GetString("asr_model.base_url"), + Factory: defaultModels.GetString("asr_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.Image2TextModel = ModelConfig{ + Name: defaultModels.GetString("image2text_model.name"), + APIKey: defaultModels.GetString("image2text_model.api_key"), + BaseURL: defaultModels.GetString("image2text_model.base_url"), + Factory: defaultModels.GetString("image2text_model.factory"), + } + } + } + } + return nil } diff --git a/internal/service/chat.go b/internal/service/chat.go index b53706d180..c1a7915ac7 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -82,7 +82,7 @@ func (s *ChatService) ListChats(userID string, status string) (*ListChatsRespons } // Enrich with knowledge base names - var chatsWithKBNames []*ChatWithKBNames + chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { kbNames := s.getKBNames(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ @@ -148,7 +148,7 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi } // Enrich with knowledge base names - var chatsWithKBNames []*ChatWithKBNames + chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { kbNames := s.getKBNames(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ diff --git a/internal/service/kb.go b/internal/service/kb.go index 8b982ebe6f..143c0e9a40 100644 --- a/internal/service/kb.go +++ b/internal/service/kb.go @@ -17,25 +17,76 @@ package service import ( + "errors" + "fmt" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/model" + "ragflow/internal/utility" + "strings" + "time" + + "github.com/google/uuid" ) -// KnowledgebaseService knowledge base service +// KnowledgebaseService service class for managing dataset operations type KnowledgebaseService struct { kbDAO *dao.KnowledgebaseDAO userTenantDAO *dao.UserTenantDAO + userDAO *dao.UserDAO + tenantDAO *dao.TenantDAO + connectorDAO *dao.ConnectorDAO } -// NewKnowledgebaseService create knowledge base service +// NewKnowledgebaseService creates a new knowledge base service func NewKnowledgebaseService() *KnowledgebaseService { return &KnowledgebaseService{ kbDAO: dao.NewKnowledgebaseDAO(), userTenantDAO: dao.NewUserTenantDAO(), + userDAO: dao.NewUserDAO(), + tenantDAO: dao.NewTenantDAO(), + connectorDAO: dao.NewConnectorDAO(), } } -// ListKbsRequest list knowledge bases request +// CreateKBRequest represents the request for creating a knowledge base +type CreateKBRequest struct { + Name string `json:"name" binding:"required"` + ParserID *string `json:"parser_id,omitempty"` + Description *string `json:"description,omitempty"` + Language *string `json:"language,omitempty"` + Permission *string `json:"permission,omitempty"` + Avatar *string `json:"avatar,omitempty"` + ParserConfig map[string]interface{} `json:"parser_config,omitempty"` +} + +// CreateKBResponse represents the response for creating a knowledge base +type CreateKBResponse struct { + KBID string `json:"kb_id"` +} + +// UpdateKBRequest represents the request for updating a knowledge base +type UpdateKBRequest struct { + KBID string `json:"kb_id" binding:"required"` + Name string `json:"name" binding:"required"` + Description *string `json:"description"` + ParserID string `json:"parser_id" binding:"required"` + Permission *string `json:"permission,omitempty"` + Language *string `json:"language,omitempty"` + Avatar *string `json:"avatar,omitempty"` + Pagerank *int64 `json:"pagerank,omitempty"` + ParserConfig map[string]interface{} `json:"parser_config,omitempty"` + Connectors []string `json:"connectors,omitempty"` +} + +// UpdateMetadataSettingRequest represents the request for updating metadata settings +type UpdateMetadataSettingRequest struct { + KBID string `json:"kb_id" binding:"required"` + Metadata map[string]interface{} `json:"metadata" binding:"required"` + EnableMetadata *bool `json:"enable_metadata,omitempty"` +} + +// ListKbsRequest represents the request for listing knowledge bases type ListKbsRequest struct { Keywords *string `json:"keywords,omitempty"` Page *int `json:"page,omitempty"` @@ -46,37 +97,461 @@ type ListKbsRequest struct { OwnerIDs *[]string `json:"owner_ids,omitempty"` } -// ListKbsResponse list knowledge bases response +// ListKbsResponse represents the response for listing knowledge bases type ListKbsResponse struct { - KBs []*model.Knowledgebase `json:"kbs"` - Total int64 `json:"total"` + KBs []map[string]interface{} `json:"kbs"` + Total int64 `json:"total"` } -// ListKbs list knowledge bases -func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, error) { - var kbs []*model.Knowledgebase +// CreateKB creates a new knowledge base +// This matches the Python create endpoint in kb_app.py +func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (*CreateKBResponse, common.ErrorCode, error) { + // Validate name is a string + if !isValidString(req.Name) { + return nil, common.CodeDataError, errors.New("Dataset name must be string.") + } + + // Trim and validate name + name := strings.TrimSpace(req.Name) + if name == "" { + return nil, common.CodeDataError, errors.New("Dataset name can't be empty.") + } + + // Check name length (using UTF-8 byte length like Python) + if len(name) > model.DatasetNameLimit { + return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit) + } + + // Verify tenant exists + tenant, err := s.tenantDAO.GetByID(tenantID) + if err != nil { + return nil, common.CodeDataError, errors.New("Tenant not found.") + } + + // Deduplicate name within tenant + duplicateName := s.kbDAO.DuplicateName(name, tenantID) + + // Get parser ID (default to "naive") + parserID := "naive" + if req.ParserID != nil && *req.ParserID != "" { + parserID = *req.ParserID + } + + // Get parser config with defaults + parserConfig := getParserConfig(parserID, req.ParserConfig) + parserConfig["llm_id"] = tenant.LLMID + + // Generate KB ID + kbID := strings.ReplaceAll(uuid.New().String(), "-", "") + + // Create knowledge base model + now := time.Now().Unix() + nowDate := time.Now() + kb := &model.Knowledgebase{ + ID: kbID, + Name: duplicateName, + TenantID: tenantID, + CreatedBy: tenantID, + ParserID: parserID, + ParserConfig: parserConfig, + Permission: "me", + EmbdID: "", + } + kb.CreateTime = &now + kb.UpdateTime = &now + kb.CreateDate = &nowDate + kb.UpdateDate = &nowDate + status := string(model.StatusValid) + kb.Status = &status + + // Set optional fields + if req.Description != nil { + kb.Description = req.Description + } + if req.Language != nil { + kb.Language = req.Language + } + if req.Permission != nil { + kb.Permission = *req.Permission + } + if req.Avatar != nil { + kb.Avatar = req.Avatar + } + + // Create in database + if err := s.kbDAO.Create(kb); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to create knowledge base: %w", err) + } + + return &CreateKBResponse{KBID: kbID}, common.CodeSuccess, nil +} + +// UpdateKB updates an existing knowledge base +// This matches the Python update endpoint in kb_app.py +func (s *KnowledgebaseService) UpdateKB(req *UpdateKBRequest, userID string) (map[string]interface{}, common.ErrorCode, error) { + // Validate name is a string + if !isValidString(req.Name) { + return nil, common.CodeDataError, errors.New("Dataset name must be string.") + } + + // Trim and validate name + name := strings.TrimSpace(req.Name) + if name == "" { + return nil, common.CodeDataError, errors.New("Dataset name can't be empty.") + } + + // Check name length + if len(name) > model.DatasetNameLimit { + return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit) + } + + // Check authorization + if !s.kbDAO.Accessible4Deletion(req.KBID, userID) { + return nil, common.CodeAuthenticationError, errors.New("No authorization.") + } + + // Verify ownership + kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": req.KBID}) + if err != nil || len(kbs) == 0 { + return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation") + } + + // Get existing KB + kb, err := s.kbDAO.GetByID(req.KBID) + if err != nil { + return nil, common.CodeDataError, errors.New("can't find this dataset") + } + + // Check for duplicate name + if strings.ToLower(name) != strings.ToLower(kb.Name) { + existingKB, _ := s.kbDAO.GetByName(name, userID) + if existingKB != nil { + return nil, common.CodeDataError, errors.New("duplicated dataset name") + } + } + + // Build updates + updates := map[string]interface{}{ + "name": name, + "parser_id": req.ParserID, + } + + if req.Description != nil { + updates["description"] = *req.Description + } + if req.Permission != nil { + updates["permission"] = *req.Permission + } + if req.Language != nil { + updates["language"] = *req.Language + } + if req.Avatar != nil { + updates["avatar"] = *req.Avatar + } + if req.Pagerank != nil { + updates["pagerank"] = *req.Pagerank + } + if req.ParserConfig != nil { + updates["parser_config"] = req.ParserConfig + } + + now := time.Now().Unix() + nowDate := time.Now() + updates["update_time"] = now + updates["update_date"] = nowDate + + // Update in database + if err := s.kbDAO.UpdateByID(req.KBID, updates); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to update knowledge base: %w", err) + } + + // Get updated KB + updatedKB, err := s.kbDAO.GetByID(req.KBID) + if err != nil { + return nil, common.CodeDataError, errors.New("database error (knowledgebase rename)") + } + + result := updatedKB.ToMap() + result["connectors"] = req.Connectors + + return result, common.CodeSuccess, nil +} + +// UpdateMetadataSetting updates the metadata settings for a knowledge base +func (s *KnowledgebaseService) UpdateMetadataSetting(req *UpdateMetadataSettingRequest) (map[string]interface{}, common.ErrorCode, error) { + kb, err := s.kbDAO.GetByID(req.KBID) + if err != nil { + return nil, common.CodeDataError, errors.New("database error (knowledgebase not found)") + } + + parserConfig := kb.ParserConfig + if parserConfig == nil { + parserConfig = make(map[string]interface{}) + } + + parserConfig["metadata"] = req.Metadata + enableMetadata := true + if req.EnableMetadata != nil { + enableMetadata = *req.EnableMetadata + } + parserConfig["enable_metadata"] = enableMetadata + + if err := s.kbDAO.UpdateParserConfig(req.KBID, parserConfig); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to update metadata setting: %w", err) + } + + result := kb.ToMap() + result["parser_config"] = parserConfig + + return result, common.CodeSuccess, nil +} + +// GetDetail retrieves detailed information about a knowledge base +// This matches the Python kb_detail endpoint in kb_app.py +func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*model.KnowledgebaseDetail, common.ErrorCode, error) { + // Check authorization + if !s.kbDAO.Accessible(kbID, userID) { + return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation") + } + + // Get detail + detail, err := s.kbDAO.GetDetail(kbID) + if err != nil { + return nil, common.CodeDataError, errors.New("can't find this dataset") + } + + // Set connectors (empty for now) + detail.Connectors = []string{} + + return detail, common.CodeSuccess, nil +} + +// ListKbs lists knowledge bases with pagination and filtering +// This matches the Python list endpoint in kb_app.py +func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, common.ErrorCode, error) { + var kbs []*model.KnowledgebaseListItem var total int64 var err error - // If owner IDs are provided, filter by them - if ownerIDs != nil && len(ownerIDs) > 0 { - kbs, total, err = s.kbDAO.ListByOwnerIDs(ownerIDs, page, pageSize, orderby, desc, keywords, parserID) + if len(ownerIDs) > 0 { + // List by owner IDs + kbs, total, err = s.kbDAO.GetByTenantIDs(ownerIDs, userID, page, pageSize, orderby, desc, keywords, parserID) } else { - // Get tenant IDs by user ID + // Get tenant IDs for user tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { - return nil, err + return nil, common.CodeServerError, err } - kbs, total, err = s.kbDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID) + kbs, total, err = s.kbDAO.GetByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID) } if err != nil { - return nil, err + return nil, common.CodeServerError, err + } + + // Convert to map slice + kbMaps := make([]map[string]interface{}, len(kbs)) + for i, kb := range kbs { + kbMaps[i] = map[string]interface{}{ + "id": kb.ID, + "avatar": kb.Avatar, + "name": kb.Name, + "language": kb.Language, + "description": kb.Description, + "tenant_id": kb.TenantID, + "permission": kb.Permission, + "doc_num": kb.DocNum, + "token_num": kb.TokenNum, + "chunk_num": kb.ChunkNum, + "parser_id": kb.ParserID, + "embd_id": kb.EmbdID, + "nickname": kb.Nickname, + "tenant_avatar": kb.TenantAvatar, + "update_time": kb.UpdateTime, + } } return &ListKbsResponse{ - KBs: kbs, + KBs: kbMaps, Total: total, - }, nil + }, common.CodeSuccess, nil +} + +// DeleteKB soft deletes a knowledge base +// This matches the Python rm endpoint in kb_app.py +func (s *KnowledgebaseService) DeleteKB(kbID, userID string) (common.ErrorCode, error) { + // Check authorization + if !s.kbDAO.Accessible4Deletion(kbID, userID) { + return common.CodeAuthenticationError, errors.New("No authorization.") + } + + // Verify ownership + kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": kbID}) + if err != nil || len(kbs) == 0 { + return common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation") + } + + // Soft delete + if err := s.kbDAO.Delete(kbID); err != nil { + return common.CodeServerError, fmt.Errorf("database error (knowledgebase removal): %w", err) + } + + return common.CodeSuccess, nil +} + +// Accessible checks if a knowledge base is accessible by a user +func (s *KnowledgebaseService) Accessible(kbID, userID string) bool { + return s.kbDAO.Accessible(kbID, userID) +} + +// GetByID retrieves a knowledge base by ID +func (s *KnowledgebaseService) GetByID(kbID string) (*model.Knowledgebase, error) { + return s.kbDAO.GetByID(kbID) +} + +// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant +func (s *KnowledgebaseService) GetKBIDsByTenantID(tenantID string) ([]string, error) { + return s.kbDAO.GetKBIDsByTenantID(tenantID) +} + +// isValidString checks if a value is a non-empty string +func isValidString(v interface{}) bool { + str, ok := v.(string) + return ok && str != "" +} + +// getParserConfig returns the parser configuration with defaults +// This matches the Python get_parser_config function +func getParserConfig(parserID string, customConfig map[string]interface{}) map[string]interface{} { + config := map[string]interface{}{ + "pages": [][]int{{1, 1000000}}, + "table_context_size": 0, + "image_context_size": 0, + } + + switch parserID { + case "table": + config["layout_recognize"] = false + config["chunk_token_num"] = 128 + config["delimiter"] = "\n!?;。;!?" + config["html4excel"] = false + case "naive": + config["chunk_token_num"] = 128 + config["delimiter"] = "\n!?;。;!?" + config["html4excel"] = false + default: + config["raptor"] = map[string]interface{}{ + "use_raptor": false, + } + } + + // Merge custom config if provided + if customConfig != nil { + config = mergeParserConfig(config, customConfig) + } + + return config +} + +// mergeParserConfig merges two parser configurations +func mergeParserConfig(base, override map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}) + for k, v := range base { + result[k] = v + } + + for k, v := range override { + if existing, ok := result[k]; ok { + if existingMap, ok := existing.(map[string]interface{}); ok { + if newMap, ok := v.(map[string]interface{}); ok { + result[k] = mergeParserConfig(existingMap, newMap) + continue + } + } + } + result[k] = v + } + + return result +} + +// GenerateUUID generates a UUID string without dashes +func GenerateUUID() string { + return strings.ReplaceAll(uuid.New().String(), "-", "") +} + +// GetUserByToken gets user by authorization token +func (s *KnowledgebaseService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) { + userService := NewUserService() + return userService.GetUserByToken(authorization) +} + +// GetUserByID gets user by ID +func (s *KnowledgebaseService) GetUserByID(id string) (*model.User, error) { + return s.userDAO.GetByAccessToken(id) +} + +// GetTenantIDsByUserID gets tenant IDs for a user +func (s *KnowledgebaseService) GetTenantIDsByUserID(userID string) ([]string, error) { + return s.userTenantDAO.GetTenantIDsByUserID(userID) +} + +// GetConnectorsByTenantID gets connectors for a tenant +func (s *KnowledgebaseService) GetConnectorsByTenantID(tenantID string) ([]*dao.ConnectorListItem, error) { + return s.connectorDAO.ListByTenantID(tenantID) +} + +// GetKBList retrieves knowledge bases with ID and name filtering +func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, common.ErrorCode, error) { + kbs, total, err := s.kbDAO.GetList(tenantIDs, userID, page, pageSize, orderby, desc, id, name) + if err != nil { + return nil, 0, common.CodeServerError, err + } + return kbs, total, common.CodeSuccess, nil +} + +// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID +func (s *KnowledgebaseService) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) { + return s.kbDAO.GetKBByIDAndUserID(kbID, userID) +} + +// GetKBByNameAndUserID retrieves a knowledge base by name and user ID +func (s *KnowledgebaseService) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) { + return s.kbDAO.GetKBByNameAndUserID(kbName, userID) +} + +// AtomicIncreaseDocNumByID atomically increments the document count +func (s *KnowledgebaseService) AtomicIncreaseDocNumByID(kbID string) error { + return s.kbDAO.AtomicIncreaseDocNumByID(kbID) +} + +// DecreaseDocumentNum decreases document, chunk, and token counts +func (s *KnowledgebaseService) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error { + return s.kbDAO.DecreaseDocumentNum(kbID, docNum, chunkNum, tokenNum) +} + +// UpdateParserConfig updates the parser configuration +func (s *KnowledgebaseService) UpdateParserConfig(id string, config map[string]interface{}) error { + return s.kbDAO.UpdateParserConfig(id, config) +} + +// DeleteFieldMap removes the field_map from parser_config +func (s *KnowledgebaseService) DeleteFieldMap(id string) error { + return s.kbDAO.DeleteFieldMap(id) +} + +// GetFieldMap retrieves field mappings from multiple knowledge bases +func (s *KnowledgebaseService) GetFieldMap(ids []string) (map[string]interface{}, error) { + return s.kbDAO.GetFieldMap(ids) +} + +// GetAllIDs retrieves all knowledge base IDs +func (s *KnowledgebaseService) GetAllIDs() ([]string, error) { + return s.kbDAO.GetAllIDs() +} + +// ExtractAccessToken extracts access token from authorization header +func ExtractAccessToken(authorization, secretKey string) (string, error) { + return utility.ExtractAccessToken(authorization, secretKey) } diff --git a/internal/service/llm.go b/internal/service/llm.go index 85b1cd99f8..a284f4d2c6 100644 --- a/internal/service/llm.go +++ b/internal/service/llm.go @@ -17,11 +17,16 @@ package service import ( + "fmt" + "strconv" "strings" "ragflow/internal/dao" + "ragflow/internal/model" ) +var DB = dao.DB + // LLMService LLM service type LLMService struct { tenantLLMDAO *dao.TenantLLMDAO @@ -38,6 +43,7 @@ func NewLLMService() *LLMService { // MyLLMItem represents a single LLM item in the response type MyLLMItem struct { + ID string `json:"id"` Type string `json:"type"` Name string `json:"name"` UsedToken int64 `json:"used_token"` @@ -46,67 +52,89 @@ type MyLLMItem struct { MaxTokens int64 `json:"max_tokens,omitempty"` } -// MyLLMResponse represents the response structure for my LLMs -type MyLLMResponse struct { +// MyLLMFactory represents the response structure for a factory in my LLMs +type MyLLMFactory struct { Tags string `json:"tags"` LLM []MyLLMItem `json:"llm"` } // GetMyLLMs get my LLMs for a tenant -func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMResponse, error) { - // Get LLM list from database - myLLMs, err := s.tenantLLMDAO.GetMyLLMs(tenantID, includeDetails) - if err != nil { - return nil, err - } +func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMFactory, error) { + result := make(map[string]MyLLMFactory) - // Group by factory - result := make(map[string]MyLLMResponse) - providerDAO := dao.NewModelProviderDAO() - for _, llm := range myLLMs { - // Get or create factory entry - resp, exists := result[llm.LLMFactory] - if !exists { - resp = MyLLMResponse{ - Tags: llm.Tags, - LLM: []MyLLMItem{}, + if includeDetails { + objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) + if err != nil { + return nil, err + } + + factoryDAO := dao.NewLLMFactoryDAO() + factories, err := factoryDAO.GetAllValid() + if err != nil { + return nil, err + } + + factoryTagsMap := make(map[string]string) + for _, f := range factories { + if f.Tags != "" { + factoryTagsMap[f.Name] = f.Tags } } - // Create LLM item - item := MyLLMItem{ - Type: llm.ModelType, - Name: llm.LLMName, - UsedToken: llm.UsedTokens, - Status: llm.Status, - } - - // Add detailed fields if requested - if includeDetails { - item.APIBase = llm.APIBase - item.MaxTokens = llm.MaxTokens - - // If APIBase is empty, try to get from model provider configuration - if item.APIBase == "" { - provider := providerDAO.GetProviderByName(llm.LLMFactory) - if provider != nil { - // Determine appropriate API base URL based on model type - switch llm.ModelType { - case "embedding": - if provider.DefaultEmbeddingURL != "" { - item.APIBase = provider.DefaultEmbeddingURL - } - // Add other model types here if needed - // case "chat": - // case "rerank": - // etc. - } + for _, o := range objs { + llmFactory := o.LLMFactory + if _, exists := result[llmFactory]; !exists { + tags := factoryTagsMap[llmFactory] + result[llmFactory] = MyLLMFactory{ + Tags: tags, + LLM: []MyLLMItem{}, } } + + item := MyLLMItem{ + ID: int64ToString(o.ID), + Type: getStringValue(o.ModelType), + Name: getStringValue(o.LLMName), + UsedToken: o.UsedTokens, + Status: getValidStatus(o.Status), + } + + if includeDetails { + item.APIBase = getStringValueDefault(o.APIBase, "") + item.MaxTokens = o.MaxTokens + } + + factory := result[llmFactory] + factory.LLM = append(factory.LLM, item) + result[llmFactory] = factory + } + } else { + objs, err := s.tenantLLMDAO.GetMyLLMs(tenantID) + if err != nil { + return nil, err } - resp.LLM = append(resp.LLM, item) - result[llm.LLMFactory] = resp + for _, o := range objs { + llmFactory := o.LLMFactory + if _, exists := result[llmFactory]; !exists { + result[llmFactory] = MyLLMFactory{ + Tags: getStringValue(o.Tags), + LLM: []MyLLMItem{}, + } + } + + item := MyLLMItem{ + ID: o.ID, + Type: getStringValue(o.ModelType), + Name: getStringValue(o.LLMName), + UsedToken: getInt64Value(o.UsedTokens), + Status: getStringValueDefault(o.Status, "1"), + } + + factory := result[llmFactory] + factory.LLM = append(factory.LLM, item) + result[llmFactory] = factory + } } return result, nil @@ -114,6 +142,7 @@ func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string // LLMListItem represents a single LLM item in the list response type LLMListItem struct { + ID string `json:"id"` LLMName string `json:"llm_name"` ModelType string `json:"model_type"` FID string `json:"fid"` @@ -142,37 +171,32 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon "GPUStack": true, } - // Get tenant LLMs - tenantLLMs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) + objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) if err != nil { return nil, err } - // Build set of factories with valid API keys facts := make(map[string]bool) - // Build set of valid LLM names@factories status := make(map[string]bool) - for _, tl := range tenantLLMs { - if tl.APIKey != nil && *tl.APIKey != "" && tl.Status == "1" { - facts[tl.LLMFactory] = true + tenantLLMMapping := make(map[string]string) + + for _, o := range objs { + if o.APIKey != nil && *o.APIKey != "" && getValidStatus(o.Status) == "1" { + facts[o.LLMFactory] = true } - llmName := "" - if tl.LLMName != nil { - llmName = *tl.LLMName - } - key := llmName + "@" + tl.LLMFactory - if tl.Status == "1" { + llmName := getStringValue(o.LLMName) + key := llmName + "@" + o.LLMFactory + if getValidStatus(o.Status) == "1" { status[key] = true } + tenantLLMMapping[key] = int64ToString(o.ID) } - // Get all valid LLMs allLLMs, err := s.llmDAO.GetAllValid() if err != nil { return nil, err } - // Filter and build result llmSet := make(map[string]bool) result := make(ListLLMsResponse) @@ -183,20 +207,18 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon key := llm.LLMName + "@" + llm.FID - // Check if valid (Builtin factory or in status set) if llm.FID != "Builtin" && !status[key] { continue } - // Filter by model type if specified if modelType != "" && !strings.Contains(llm.ModelType, modelType) { continue } - // Determine availability - available := facts[llm.FID] || selfDeployed[llm.FID] || llm.LLMName == "flag-embedding" + available := facts[llm.FID] || selfDeployed[llm.FID] || strings.ToLower(llm.LLMName) == "flag-embedding" item := LLMListItem{ + ID: tenantLLMMapping[key], LLMName: llm.LLMName, ModelType: llm.ModelType, FID: llm.FID, @@ -207,7 +229,6 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon Tags: llm.Tags, } - // Add BaseModel fields if llm.CreateDate != nil { createDateStr := llm.CreateDate.Format("2006-01-02T15:04:05") item.CreateDate = &createDateStr @@ -225,36 +246,160 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon llmSet[key] = true } - // Add tenant LLMs that are not in the global list - for _, tl := range tenantLLMs { - llmName := "" - if tl.LLMName != nil { - llmName = *tl.LLMName - } - key := llmName + "@" + tl.LLMFactory + for _, o := range objs { + llmName := getStringValue(o.LLMName) + key := llmName + "@" + o.LLMFactory if llmSet[key] { continue } - // Filter by model type if specified - modelTypeValue := "" - if tl.ModelType != nil { - modelTypeValue = *tl.ModelType - } + modelTypeValue := getStringValue(o.ModelType) if modelType != "" && !strings.Contains(modelTypeValue, modelType) { continue } item := LLMListItem{ + ID: int64ToString(o.ID), LLMName: llmName, ModelType: modelTypeValue, - FID: tl.LLMFactory, + FID: o.LLMFactory, Available: true, - Status: tl.Status, + Status: getValidStatus(o.Status), } - result[tl.LLMFactory] = append(result[tl.LLMFactory], item) + result[o.LLMFactory] = append(result[o.LLMFactory], item) } return result, nil } + +func getStringValue(s *string) string { + if s == nil { + return "" + } + return *s +} + +func getStringValueDefault(s *string, defaultVal string) string { + if s == nil || *s == "" { + return defaultVal + } + return *s +} + +func getValidStatus(status string) string { + if status == "" { + return "1" + } + return status +} + +func getInt64Value(i *int64) int64 { + if i == nil { + return 0 + } + return *i +} + +func getInt64ValueDefault(i *int64, defaultVal int64) int64 { + if i == nil || *i == 0 { + return defaultVal + } + return *i +} + +func getBoolValue(b *bool) bool { + if b == nil { + return false + } + return *b +} + +func int64ToString(n int64) string { + return strconv.FormatInt(n, 10) +} + +// SetAPIKeyRequest represents the request for setting API key +type SetAPIKeyRequest struct { + LLMFactory string `json:"llm_factory"` + APIKey string `json:"api_key"` + BaseURL string `json:"base_url"` + SourceFID string `json:"source_fid"` + ModelType string `json:"model_type"` + LLMName string `json:"llm_name"` + Verify bool `json:"verify"` + MaxTokens int64 `json:"max_tokens"` +} + +// SetAPIKeyResult represents the result of setting API key +type SetAPIKeyResult struct { + Message string `json:"message"` + Success bool `json:"success"` +} + +// SetAPIKey sets API key for a LLM factory +func (s *LLMService) SetAPIKey(tenantID string, req *SetAPIKeyRequest) (*SetAPIKeyResult, error) { + factory := req.LLMFactory + baseURL := req.BaseURL + sourceFactory := req.SourceFID + if sourceFactory == "" { + sourceFactory = factory + } + + sourceLLMs, err := s.llmDAO.GetByFactory(sourceFactory) + if err != nil || len(sourceLLMs) == 0 { + msg := "No models configured for " + factory + " (source: " + sourceFactory + ")." + if req.Verify { + return &SetAPIKeyResult{Message: msg, Success: false}, nil + } + return nil, fmt.Errorf(msg) + } + + llmConfig := map[string]interface{}{ + "api_key": req.APIKey, + "api_base": baseURL, + } + + if req.ModelType != "" { + llmConfig["model_type"] = req.ModelType + } + if req.LLMName != "" { + llmConfig["llm_name"] = req.LLMName + } + + for _, llm := range sourceLLMs { + maxTokens := llm.MaxTokens + if maxTokens == 0 { + maxTokens = 8192 + } + llmConfig["max_tokens"] = maxTokens + + existingLLM, _ := s.tenantLLMDAO.GetByTenantFactoryAndModelName(tenantID, factory, llm.LLMName) + if existingLLM != nil { + updates := map[string]interface{}{ + "api_key": req.APIKey, + "api_base": baseURL, + "max_tokens": maxTokens, + } + DB.Model(&model.TenantLLM{}). + Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, llm.LLMName). + Updates(updates) + } else { + modelType := llm.ModelType + llmName := llm.LLMName + tenantLLM := &model.TenantLLM{ + TenantID: tenantID, + LLMFactory: factory, + ModelType: &modelType, + LLMName: &llmName, + APIKey: &req.APIKey, + APIBase: &baseURL, + MaxTokens: maxTokens, + Status: "1", + } + s.tenantLLMDAO.Create(tenantLLM) + } + } + + return &SetAPIKeyResult{Message: "", Success: true}, nil +} diff --git a/internal/service/user.go b/internal/service/user.go index ccf737e3e3..eb8b2e6f1e 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -151,15 +151,38 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC user.LastLoginTime = &now_date tenantName := req.Nickname + "'s Kingdom" + + llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name + if llmID == "" { + llmID = "" + } + embdID := cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name + if embdID == "" { + embdID = "" + } + asrID := cfg.UserDefaultLLM.DefaultModels.ASRModel.Name + if asrID == "" { + asrID = "" + } + img2txtID := cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name + if img2txtID == "" { + img2txtID = "" + } + rerankID := cfg.UserDefaultLLM.DefaultModels.RerankModel.Name + if rerankID == "" { + rerankID = "" + } + tenant := &model.Tenant{ ID: userID, Name: &tenantName, - LLMID: cfg.Server.Mode, - EmbdID: cfg.Server.Mode, - ASRID: cfg.Server.Mode, - Img2TxtID: cfg.Server.Mode, - RerankID: cfg.Server.Mode, + LLMID: llmID, + EmbdID: embdID, + ASRID: asrID, + Img2TxtID: img2txtID, + RerankID: rerankID, ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", + Status: &status, } tenant.CreateTime = &now tenant.UpdateTime = &now @@ -753,3 +776,52 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, err return channels, common.CodeSuccess, nil } + +// SetTenantInfoRequest represents the request for setting tenant info +type SetTenantInfoRequest struct { + TenantID string `json:"tenant_id"` + ASRID string `json:"asr_id"` + EmbdID string `json:"embd_id"` + Img2TxtID string `json:"img2txt_id"` + LLMID string `json:"llm_id"` + RerankID string `json:"rerank_id"` + TTSID string `json:"tts_id"` +} + +// SetTenantInfo updates tenant model configuration +func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error { + tenantDAO := dao.NewTenantDAO() + + _, err := tenantDAO.GetByID(req.TenantID) + if err != nil { + return fmt.Errorf("tenant not found: %w", err) + } + + updates := make(map[string]interface{}) + if req.LLMID != "" { + updates["llm_id"] = req.LLMID + } + if req.EmbdID != "" { + updates["embd_id"] = req.EmbdID + } + if req.ASRID != "" { + updates["asr_id"] = req.ASRID + } + if req.Img2TxtID != "" { + updates["img2txt_id"] = req.Img2TxtID + } + if req.RerankID != "" { + updates["rerank_id"] = req.RerankID + } + if req.TTSID != "" { + updates["tts_id"] = req.TTSID + } + + if len(updates) > 0 { + if err := tenantDAO.Update(req.TenantID, updates); err != nil { + return fmt.Errorf("failed to update tenant: %w", err) + } + } + + return nil +}