diff --git a/cmd/server_main.go b/cmd/server_main.go index dbae3efbea..c4919abeca 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -6,8 +6,7 @@ import ( "net/http" "os" "os/signal" - "ragflow/internal/common" - "ragflow/internal/init_data" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/utility" "strings" @@ -67,7 +66,7 @@ func main() { } // Initialize LLM factory data models from configuration file - if err := init_data.InitLLMFactory(); err != nil { + if err := dao.InitLLMFactory(); err != nil { logger.Error("Failed to initialize LLM factory", err) } else { logger.Info("LLM factory initialized successfully") diff --git a/internal/dao/database.go b/internal/dao/database.go index 391759431e..1529088fb4 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -17,21 +17,50 @@ package dao import ( + "encoding/json" "fmt" + "log" + "os" + "path/filepath" + "time" + + "ragflow/internal/logger" "ragflow/internal/model" "ragflow/internal/server" - "time" + "ragflow/internal/utility" gormLogger "gorm.io/gorm/logger" "gorm.io/driver/mysql" "gorm.io/gorm" - - "ragflow/internal/logger" ) var DB *gorm.DB +// LLMFactoryConfig represents a single LLM factory configuration +type LLMFactoryConfig struct { + Name string `json:"name"` + Logo string `json:"logo"` + Tags string `json:"tags"` + Status string `json:"status"` + Rank string `json:"rank"` + LLM []LLMConfig `json:"llm"` +} + +// LLMConfig represents a single LLM model configuration +type LLMConfig struct { + LLMName string `json:"llm_name"` + Tags string `json:"tags"` + MaxTokens int64 `json:"max_tokens"` + ModelType string `json:"model_type"` + IsTools bool `json:"is_tools"` +} + +// LLMFactoriesFile represents the structure of llm_factories.json +type LLMFactoriesFile struct { + FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"` +} + // InitDB initialize database connection func InitDB() error { cfg := server.GetConfig() @@ -132,3 +161,93 @@ func InitDB() error { func GetDB() *gorm.DB { return DB } + +// InitLLMFactory initializes LLM factories and models from JSON file. +// It reads the llm_factories.json configuration file and populates the database +// with LLM factory and model information. If a factory or model already exists, +// it will be updated with the new configuration. +// +// Returns: +// - error: An error if the initialization fails, nil otherwise. +func InitLLMFactory() error { + configPath := filepath.Join(utility.GetProjectBaseDirectory(), "conf", "llm_factories.json") + + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read llm_factories.json: %w", err) + } + + var fileData LLMFactoriesFile + if err := json.Unmarshal(data, &fileData); err != nil { + return fmt.Errorf("failed to parse llm_factories.json: %w", err) + } + + db := DB + + for _, factory := range fileData.FactoryLLMInfos { + status := factory.Status + if status == "" { + status = "1" + } + + llmFactory := &model.LLMFactories{ + Name: factory.Name, + Logo: utility.StringPtr(factory.Logo), + Tags: factory.Tags, + Rank: utility.ParseInt64(factory.Rank), + Status: &status, + } + + var existingFactory model.LLMFactories + result := db.Where("name = ?", factory.Name).First(&existingFactory) + if result.Error != nil { + if err := db.Create(llmFactory).Error; err != nil { + log.Printf("Failed to create LLM factory %s: %v", factory.Name, err) + continue + } + } else { + if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{ + "logo": llmFactory.Logo, + "tags": llmFactory.Tags, + "rank": llmFactory.Rank, + "status": llmFactory.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM factory %s: %v", factory.Name, err) + } + } + + for _, llm := range factory.LLM { + llmStatus := "1" + llmModel := &model.LLM{ + LLMName: llm.LLMName, + ModelType: llm.ModelType, + FID: factory.Name, + MaxTokens: llm.MaxTokens, + Tags: llm.Tags, + IsTools: llm.IsTools, + Status: &llmStatus, + } + + var existingLLM model.LLM + result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM) + if result.Error != nil { + if err := db.Create(llmModel).Error; err != nil { + log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } else { + if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{ + "model_type": llmModel.ModelType, + "max_tokens": llmModel.MaxTokens, + "tags": llmModel.Tags, + "is_tools": llmModel.IsTools, + "status": llmModel.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } + } + } + + log.Println("LLM factories initialized successfully") + return nil +} diff --git a/internal/init_data/llm_init.go b/internal/init_data/llm_init.go deleted file mode 100644 index ef67dd6ba3..0000000000 --- a/internal/init_data/llm_init.go +++ /dev/null @@ -1,157 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package init_data - -import ( - "encoding/json" - "fmt" - "log" - "os" - "path/filepath" - - "ragflow/internal/dao" - "ragflow/internal/model" -) - -// LLMFactoryConfig represents a single LLM factory configuration -type LLMFactoryConfig struct { - Name string `json:"name"` - Logo string `json:"logo"` - Tags string `json:"tags"` - Status string `json:"status"` - Rank string `json:"rank"` - LLM []LLMConfig `json:"llm"` -} - -// LLMConfig represents a single LLM model configuration -type LLMConfig struct { - LLMName string `json:"llm_name"` - Tags string `json:"tags"` - MaxTokens int64 `json:"max_tokens"` - ModelType string `json:"model_type"` - IsTools bool `json:"is_tools"` -} - -// LLMFactoriesFile represents the structure of llm_factories.json -type LLMFactoriesFile struct { - FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"` -} - -// InitLLMFactory initializes LLM factories and models from JSON file -func InitLLMFactory() error { - configPath := filepath.Join(getProjectBaseDirectory(), "conf", "llm_factories.json") - - data, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read llm_factories.json: %w", err) - } - - var fileData LLMFactoriesFile - if err := json.Unmarshal(data, &fileData); err != nil { - return fmt.Errorf("failed to parse llm_factories.json: %w", err) - } - - db := dao.DB - - for _, factory := range fileData.FactoryLLMInfos { - status := factory.Status - if status == "" { - status = "1" - } - - llmFactory := &model.LLMFactories{ - Name: factory.Name, - Logo: stringPtr(factory.Logo), - Tags: factory.Tags, - Rank: parseInt64(factory.Rank), - Status: &status, - } - - var existingFactory model.LLMFactories - result := db.Where("name = ?", factory.Name).First(&existingFactory) - if result.Error != nil { - if err := db.Create(llmFactory).Error; err != nil { - log.Printf("Failed to create LLM factory %s: %v", factory.Name, err) - continue - } - } else { - if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{ - "logo": llmFactory.Logo, - "tags": llmFactory.Tags, - "rank": llmFactory.Rank, - "status": llmFactory.Status, - }).Error; err != nil { - log.Printf("Failed to update LLM factory %s: %v", factory.Name, err) - } - } - - for _, llm := range factory.LLM { - llmStatus := "1" - llmModel := &model.LLM{ - LLMName: llm.LLMName, - ModelType: llm.ModelType, - FID: factory.Name, - MaxTokens: llm.MaxTokens, - Tags: llm.Tags, - IsTools: llm.IsTools, - Status: &llmStatus, - } - - var existingLLM model.LLM - result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM) - if result.Error != nil { - if err := db.Create(llmModel).Error; err != nil { - log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err) - } - } else { - if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{ - "model_type": llmModel.ModelType, - "max_tokens": llmModel.MaxTokens, - "tags": llmModel.Tags, - "is_tools": llmModel.IsTools, - "status": llmModel.Status, - }).Error; err != nil { - log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err) - } - } - } - } - - log.Println("LLM factories initialized successfully") - return nil -} - -func getProjectBaseDirectory() string { - cwd, err := os.Getwd() - if err != nil { - return "." - } - return cwd -} - -func stringPtr(s string) *string { - if s == "" { - return nil - } - return &s -} - -func parseInt64(s string) int64 { - var result int64 - fmt.Sscanf(s, "%d", &result) - return result -} diff --git a/internal/utility/convert.go b/internal/utility/convert.go new file mode 100644 index 0000000000..281b361486 --- /dev/null +++ b/internal/utility/convert.go @@ -0,0 +1,80 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "fmt" + "os" +) + +// GetProjectBaseDirectory returns the current working directory. +// If an error occurs while getting the current directory, it returns ".". +// +// Returns: +// - string: The current working directory path, or "." if an error occurs. +// +// Example: +// +// baseDir := utility.GetProjectBaseDirectory() +// configPath := filepath.Join(baseDir, "conf", "config.json") +func GetProjectBaseDirectory() string { + cwd, err := os.Getwd() + if err != nil { + return "." + } + return cwd +} + +// StringPtr converts a string to a pointer of string. +// If the input string is empty, it returns nil. +// +// Parameters: +// - s: The string to convert to a pointer. +// +// Returns: +// - *string: A pointer to the input string, or nil if the input is empty. +// +// Example: +// +// name := utility.StringPtr("example") // returns &"example" +// empty := utility.StringPtr("") // returns nil +func StringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + +// ParseInt64 parses a string to int64. +// If parsing fails, it returns 0. +// +// Parameters: +// - s: The string to parse. +// +// Returns: +// - int64: The parsed integer value, or 0 if parsing fails. +// +// Example: +// +// val := utility.ParseInt64("123") // returns 123 +// val := utility.ParseInt64("abc") // returns 0 +// val := utility.ParseInt64("") // returns 0 +func ParseInt64(s string) int64 { + var result int64 + fmt.Sscanf(s, "%d", &result) + return result +}