diff --git a/conf/models/openai.json b/conf/models/openai.json new file mode 100644 index 0000000000..57bbccb41f --- /dev/null +++ b/conf/models/openai.json @@ -0,0 +1,239 @@ +{ + "name": "OpenAI", + "tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION", + "url": "https://api.openai.com/v1", + "models": [ + { + "name": "gpt-5.2-pro", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5.2", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5.1", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5.1-chat-latest", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5-mini", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5-nano", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-5-chat-latest", + "max_tokens": 400000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-4.1", + "max_tokens": 1047576, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-4.1-mini", + "max_tokens": 1047576, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-4.1-nano", + "max_tokens": 1047576, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-4.5-preview", + "max_tokens": 128000, + "model_types": [ + "llm" + ], + "features": {} + }, + { + "name": "o3", + "max_tokens": 200000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "o4-mini", + "max_tokens": 200000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "o4-mini-high", + "max_tokens": 200000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-4o-mini", + "max_tokens": 128000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-4o", + "max_tokens": 128000, + "model_types": [ + "llm", + "vlm" + ], + "features": {} + }, + { + "name": "gpt-3.5-turbo", + "max_tokens": 4096, + "model_types": [ + "llm" + ], + "features": {} + }, + { + "name": "gpt-3.5-turbo-16k-0613", + "max_tokens": 16385, + "model_types": [ + "llm" + ], + "features": {} + }, + { + "name": "text-embedding-ada-002", + "max_tokens": 8191, + "model_types": [ + "embedding" + ], + "features": {} + }, + { + "name": "text-embedding-3-small", + "max_tokens": 8191, + "model_types": [ + "embedding" + ], + "features": {} + }, + { + "name": "text-embedding-3-large", + "max_tokens": 8191, + "model_types": [ + "embedding" + ], + "features": {} + }, + { + "name": "whisper-1", + "max_tokens": 26214400, + "model_types": [ + "speech2text" + ], + "features": {} + }, + { + "name": "gpt-4", + "max_tokens": 8191, + "model_types": [ + "llm" + ], + "features": {} + }, + { + "name": "gpt-4-turbo", + "max_tokens": 8191, + "model_types": [ + "llm" + ], + "features": {} + }, + { + "name": "gpt-4-32k", + "max_tokens": 32768, + "model_types": [ + "llm" + ], + "features": {} + }, + { + "name": "tts-1", + "max_tokens": 2048, + "model_types": [ + "tts" + ], + "features": {} + } + ] +} \ No newline at end of file diff --git a/conf/models/xai.json b/conf/models/xai.json new file mode 100644 index 0000000000..af6905ed9f --- /dev/null +++ b/conf/models/xai.json @@ -0,0 +1,49 @@ +{ + "name": "xAI", + "tags": "LLM", + "url": "https://api.x.ai/v1", + "models": [ + { + "name": "grok-4", + "max_tokens": 256000, + "model_types": ["llm"], + "features": {} + }, + { + "name": "grok-3", + "max_tokens": 131072, + "model_types": ["llm"], + "features": {} + }, + { + "name": "grok-3-fast", + "max_tokens": 131072, + "model_types": ["llm"], + "features": {} + }, + { + "name": "grok-3-mini", + "max_tokens": 131072, + "model_types": ["llm"], + "features": {} + }, + { + "name": "grok-3-mini-mini-fast", + "max_tokens": 131072, + "model_types": ["llm"], + "features": {} + }, + { + "name": "grok-2-vision", + "max_tokens": 32768, + "model_types": ["vlm"], + "features": { + "multimodal": { + "enabled": true, + "input_modalities": ["image"], + "output_modalities": ["text"] + } + } + } + ] +} \ No newline at end of file diff --git a/internal/admin/router.go b/internal/admin/router.go index b1b7a71c2d..377999a8b5 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -17,6 +17,8 @@ package admin import ( + "ragflow/internal/handler" + "github.com/gin-gonic/gin" ) @@ -46,6 +48,15 @@ func (r *Router) Setup(engine *gin.Engine) { admin.POST("/reports", r.handler.Reports) + // provider pool route group + provider := admin.Group("providers") + { + provider.GET("/", handler.ListPoolProviders) + provider.GET("/:provider_name", handler.ShowPoolProvider) + provider.GET("/:provider_name/models", handler.ListPoolModels) + provider.GET("/:provider_name/models/:model_name", handler.ShowPoolModel) + } + // Protected routes protected := admin.Group("") protected.Use(r.handler.AuthMiddleware()) diff --git a/internal/cli/admin_command.go b/internal/cli/admin_command.go index 001ef6758e..c40014555b 100644 --- a/internal/cli/admin_command.go +++ b/internal/cli/admin_command.go @@ -1,3 +1,19 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + package cli import ( diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index 503931dd20..05b8569d83 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -1,3 +1,19 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + package cli import "fmt" @@ -150,6 +166,8 @@ func (p *Parser) parseAdminListCommand() (*Command, error) { return p.parseAdminListModelProviders() case TokenDefault: return p.parseAdminListDefaultModels() + case TokenPool: + return p.parseCommonListPoolModels() case TokenChats: p.nextToken() // Semicolon is optional for SHOW TOKEN @@ -255,6 +273,78 @@ func (p *Parser) parseAdminListDefaultModels() (*Command, error) { return NewCommand("list_user_default_models"), nil } +func (p *Parser) parseCommonListPoolModels() (*Command, error) { + p.nextToken() // consume POOL + if p.curToken.Type == TokenProviders { + return NewCommand("list_pool_providers"), nil + } else if p.curToken.Type == TokenModels { + p.nextToken() + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") + } + p.nextToken() + providerName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd := NewCommand("list_pool_models") + cmd.Params["provider_name"] = providerName + p.nextToken() + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil + } else { + return nil, fmt.Errorf("expected PROVIDERS or MODELS") + } +} + +func (p *Parser) parseCommonShowPoolModel() (*Command, error) { + p.nextToken() // consume POOL + if p.curToken.Type == TokenProvider { + p.nextToken() + providerName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd := NewCommand("show_pool_provider") + cmd.Params["provider_name"] = providerName + p.nextToken() + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil + } else if p.curToken.Type == TokenModel { + p.nextToken() // skip model + modelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() // skip model name + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") + } + p.nextToken() // skip from + providerName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() // skip provider name + cmd := NewCommand("show_pool_model") + cmd.Params["provider_name"] = providerName + cmd.Params["model_name"] = modelName + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil + } else { + return nil, fmt.Errorf("expected PROVIDERS or MODELS") + } +} + func (p *Parser) parseAdminListFiles() (*Command, error) { p.nextToken() // consume FILES if p.curToken.Type != TokenOf { @@ -319,6 +409,8 @@ func (p *Parser) parseAdminShowCommand() (*Command, error) { return p.parseShowVariable() case TokenService: return p.parseShowService() + case TokenPool: + return p.parseCommonShowPoolModel() default: return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value) } diff --git a/internal/cli/client.go b/internal/cli/client.go index 11145211f6..18d7175603 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -17,17 +17,7 @@ package cli import ( - "bufio" - "context" - "encoding/json" "fmt" - "os" - "os/exec" - "strings" - "syscall" - "time" - "unsafe" - ce "ragflow/internal/cli/contextengine" ) @@ -103,264 +93,6 @@ func (a *httpClientAdapter) Request(method, path string, useAPIBase bool, authKi }, nil } -// LoginUserInteractive performs interactive login with username and password -func (c *RAGFlowClient) LoginUserInteractive(username, password string) error { - // First, ping the server to check if it's available - // For admin mode, use /admin/ping with useAPIBase=true - // For user mode, use /system/ping with useAPIBase=false - var pingPath string - var useAPIBase bool - if c.ServerType == "admin" { - pingPath = "/admin/ping" - useAPIBase = true - } else { - pingPath = "/system/ping" - useAPIBase = false - } - - resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil) - if err != nil { - fmt.Printf("Error: %v\n", err) - fmt.Println("Can't access server for login (connection failed)") - return err - } - - if resp.StatusCode != 200 { - fmt.Println("Server is down") - return fmt.Errorf("server is down") - } - - // Check response - admin returns JSON with message "PONG", user returns plain "pong" - resJSON, err := resp.JSON() - if err == nil { - // Admin mode returns {"code":0,"message":"PONG"} - if msg, ok := resJSON["message"].(string); !ok || msg != "pong" { - fmt.Println("Server is down") - return fmt.Errorf("server is down") - } - } else { - // User mode returns plain "pong" - if string(resp.Body) != "pong" { - fmt.Println("Server is down") - return fmt.Errorf("server is down") - } - } - - // If password is not provided, prompt for it - if password == "" { - fmt.Printf("password for %s: ", username) - var err error - password, err = readPassword() - if err != nil { - return fmt.Errorf("failed to read password: %w", err) - } - password = strings.TrimSpace(password) - } - - // Login - token, err := c.loginUser(username, password) - if err != nil { - fmt.Printf("Error: %v\n", err) - fmt.Println("Can't access server for login (connection failed)") - return err - } - - c.HTTPClient.LoginToken = token - fmt.Printf("Login user %s successfully\n", username) - return nil -} - -// LoginUser performs user login -func (c *RAGFlowClient) LoginUser(cmd *Command) error { - // First, ping the server to check if it's available - // For admin mode, use /admin/ping with useAPIBase=true - // For user mode, use /system/ping with useAPIBase=false - var pingPath string - var useAPIBase bool - if c.ServerType == "admin" { - pingPath = "/admin/ping" - useAPIBase = true - } else { - pingPath = "/system/ping" - useAPIBase = false - } - - resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil) - if err != nil { - fmt.Printf("Error: %v\n", err) - fmt.Println("Can't access server for login (connection failed)") - return err - } - - if resp.StatusCode != 200 { - fmt.Println("Server is down") - return fmt.Errorf("server is down") - } - - // Check response - admin returns JSON with message "PONG", user returns plain "pong" - resJSON, err := resp.JSON() - if err == nil { - // Admin mode returns {"code":0,"message":"PONG"} - if msg, ok := resJSON["message"].(string); !ok || msg != "pong" { - fmt.Println("Server is down") - return fmt.Errorf("server is down") - } - } else { - // User mode returns plain "pong" - if string(resp.Body) != "pong" { - fmt.Println("Server is down") - return fmt.Errorf("server is down") - } - } - - email, ok := cmd.Params["email"].(string) - if !ok { - return fmt.Errorf("email not provided") - } - - password, ok := cmd.Params["password"].(string) - if !ok { - // Get password from user input (hidden) - fmt.Printf("password for %s: ", email) - password, err = readPassword() - if err != nil { - return fmt.Errorf("failed to read password: %w", err) - } - password = strings.TrimSpace(password) - } - - // Login - token, err := c.loginUser(email, password) - if err != nil { - fmt.Printf("Error: %v\n", err) - fmt.Println("Can't access server for login (connection failed)") - return err - } - - c.HTTPClient.LoginToken = token - fmt.Printf("Login user %s successfully\n", email) - return nil -} - -// loginUser performs the actual login request -func (c *RAGFlowClient) loginUser(email, password string) (string, error) { - // Encrypt password using scrypt (same as Python implementation) - encryptedPassword, err := EncryptPassword(password) - if err != nil { - return "", fmt.Errorf("failed to encrypt password: %w", err) - } - - payload := map[string]interface{}{ - "email": email, - "password": encryptedPassword, - } - - var path string - if c.ServerType == "admin" { - path = "/admin/login" - } else { - path = "/user/login" - } - - resp, err := c.HTTPClient.Request("POST", path, c.ServerType == "admin", "", nil, payload) - if err != nil { - return "", err - } - - var result SimpleResponse - if err = json.Unmarshal(resp.Body, &result); err != nil { - return "", fmt.Errorf("login failed: invalid JSON (%w)", err) - } - - if result.Code != 0 { - return "", fmt.Errorf("login failed: %s", result.Message) - } - - token := resp.Headers.Get("Authorization") - if token == "" { - return "", fmt.Errorf("login failed: missing Authorization header") - } - - return token, nil -} - -func (c *RAGFlowClient) Logout() (ResponseIf, error) { - if c.HTTPClient.LoginToken == "" { - return nil, fmt.Errorf("not logged in") - } - - var path string - if c.ServerType == "admin" { - path = "/admin/logout" - } else { - path = "/user/logout" - } - - resp, err := c.HTTPClient.Request("GET", path, c.ServerType == "admin", "web", nil, nil) - if err != nil { - return nil, err - } - - var result SimpleResponse - if err = json.Unmarshal(resp.Body, &result); err != nil { - return nil, fmt.Errorf("login failed: invalid JSON (%w)", err) - } - - if result.Code != 0 { - return nil, fmt.Errorf("login failed: %s", result.Message) - } - - return &result, nil -} - -// readPassword reads password from terminal without echoing -func readPassword() (string, error) { - // Check if stdin is a terminal by trying to get terminal size - if isTerminal() { - // Use stty to disable echo - cmd := exec.Command("stty", "-echo") - cmd.Stdin = os.Stdin - if err := cmd.Run(); err != nil { - // Fallback: read normally - return readPasswordFallback() - } - defer func() { - // Re-enable echo - cmd := exec.Command("stty", "echo") - cmd.Stdin = os.Stdin - cmd.Run() - }() - - reader := bufio.NewReader(os.Stdin) - password, err := reader.ReadString('\n') - fmt.Println() // New line after password input - if err != nil { - return "", err - } - return strings.TrimSpace(password), nil - } - - // Fallback for non-terminal input (e.g., piped input) - return readPasswordFallback() -} - -// isTerminal checks if stdin is a terminal -func isTerminal() bool { - var termios syscall.Termios - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) - return err == 0 -} - -// readPasswordFallback reads password as plain text (fallback mode) -func readPasswordFallback() (string, error) { - reader := bufio.NewReader(os.Stdin) - password, err := reader.ReadString('\n') - if err != nil { - return "", err - } - return strings.TrimSpace(password), nil -} - // ExecuteCommand executes a parsed command // Returns benchmark result map for commands that support it (e.g., ping_server with iterations > 1) func (c *RAGFlowClient) ExecuteCommand(cmd *Command) (ResponseIf, error) { @@ -420,6 +152,14 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { return c.ListAdminTokens(cmd) case "drop_token": return c.DropAdminToken(cmd) + case "list_pool_providers": + return c.ListPoolProviders(cmd) + case "show_pool_provider": + return c.ShowPoolProvider(cmd) + case "list_pool_models": + return c.ListPoolModels(cmd) + case "show_pool_model": + return c.ShowPoolModel(cmd) // TODO: Implement other commands default: return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type) @@ -463,6 +203,14 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.CreateDocMetaIndex(cmd) case "drop_doc_meta_index": return c.DropDocMetaIndex(cmd) + case "list_pool_providers": + return c.ListPoolProviders(cmd) + case "show_pool_provider": + return c.ShowPoolProvider(cmd) + case "list_pool_models": + return c.ListPoolModels(cmd) + case "show_pool_model": + return c.ShowPoolModel(cmd) // ContextEngine commands case "ce_ls": return c.CEList(cmd) @@ -482,366 +230,3 @@ func (c *RAGFlowClient) ShowCurrentUser(cmd *Command) (map[string]interface{}, e // The /admin/auth API only verifies authorization, does not return user info return nil, fmt.Errorf("command 'SHOW CURRENT USER' is not yet implemented") } - -type ResponseIf interface { - Type() string - PrintOut() - TimeCost() float64 - SetOutputFormat(format OutputFormat) -} - -type CommonResponse struct { - Code int `json:"code"` - Data []map[string]interface{} `json:"data"` - Message string `json:"message"` - Duration float64 - outputFormat OutputFormat -} - -func (r *CommonResponse) Type() string { - return "common" -} - -func (r *CommonResponse) TimeCost() float64 { - return r.Duration -} - -func (r *CommonResponse) SetOutputFormat(format OutputFormat) { - r.outputFormat = format -} - -func (r *CommonResponse) PrintOut() { - if r.Code == 0 { - PrintTableSimpleByFormat(r.Data, r.outputFormat) - } else { - fmt.Println("ERROR") - fmt.Printf("%d, %s\n", r.Code, r.Message) - } -} - -type CommonDataResponse struct { - Code int `json:"code"` - Data map[string]interface{} `json:"data"` - Message string `json:"message"` - Duration float64 - outputFormat OutputFormat -} - -func (r *CommonDataResponse) Type() string { - return "show" -} - -func (r *CommonDataResponse) TimeCost() float64 { - return r.Duration -} - -func (r *CommonDataResponse) SetOutputFormat(format OutputFormat) { - r.outputFormat = format -} - -func (r *CommonDataResponse) PrintOut() { - if r.Code == 0 { - table := make([]map[string]interface{}, 0) - table = append(table, r.Data) - PrintTableSimpleByFormat(table, r.outputFormat) - } else { - fmt.Println("ERROR") - fmt.Printf("%d, %s\n", r.Code, r.Message) - } -} - -type SimpleResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Duration float64 - outputFormat OutputFormat -} - -func (r *SimpleResponse) Type() string { - return "simple" -} - -func (r *SimpleResponse) TimeCost() float64 { - return r.Duration -} - -func (r *SimpleResponse) SetOutputFormat(format OutputFormat) { - r.outputFormat = format -} - -func (r *SimpleResponse) PrintOut() { - if r.Code == 0 { - fmt.Println("SUCCESS") - } else { - fmt.Println("ERROR") - fmt.Printf("%d, %s\n", r.Code, r.Message) - } -} - -type RegisterResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Duration float64 - outputFormat OutputFormat -} - -func (r *RegisterResponse) Type() string { - return "register" -} - -func (r *RegisterResponse) TimeCost() float64 { - return r.Duration -} - -func (r *RegisterResponse) SetOutputFormat(format OutputFormat) { - r.outputFormat = format -} - -func (r *RegisterResponse) PrintOut() { - if r.Code == 0 { - fmt.Println("Register successfully") - } else { - fmt.Println("ERROR") - fmt.Printf("%d, %s\n", r.Code, r.Message) - } -} - -type BenchmarkResponse struct { - Code int `json:"code"` - Duration float64 `json:"duration"` - SuccessCount int `json:"success_count"` - FailureCount int `json:"failure_count"` - Concurrency int - outputFormat OutputFormat -} - -func (r *BenchmarkResponse) Type() string { - return "benchmark" -} - -func (r *BenchmarkResponse) SetOutputFormat(format OutputFormat) { - r.outputFormat = format -} - -func (r *BenchmarkResponse) PrintOut() { - if r.Code != 0 { - fmt.Printf("ERROR, Code: %d\n", r.Code) - return - } - - iterations := r.SuccessCount + r.FailureCount - if r.Concurrency == 1 { - if iterations == 1 { - fmt.Printf("Latency: %fs\n", r.Duration) - } else { - fmt.Printf("Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount) - } - } else { - fmt.Printf("Concurrency: %d, Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Concurrency, r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount) - } -} - -func (r *BenchmarkResponse) TimeCost() float64 { - return r.Duration -} - -type KeyValueResponse struct { - Code int `json:"code"` - Key string `json:"key"` - Value string `json:"data"` - Duration float64 - outputFormat OutputFormat -} - -func (r *KeyValueResponse) Type() string { - return "data" -} - -func (r *KeyValueResponse) TimeCost() float64 { - return r.Duration -} - -func (r *KeyValueResponse) SetOutputFormat(format OutputFormat) { - r.outputFormat = format -} - -func (r *KeyValueResponse) PrintOut() { - if r.Code == 0 { - table := make([]map[string]interface{}, 0) - // insert r.key and r.value into table - table = append(table, map[string]interface{}{ - "key": r.Key, - "value": r.Value, - }) - PrintTableSimpleByFormat(table, r.outputFormat) - } else { - fmt.Println("ERROR") - fmt.Printf("%d\n", r.Code) - } -} - -// ==================== ContextEngine Commands ==================== - -// CEListResponse represents the response for ls command -type CEListResponse struct { - Code int `json:"code"` - Data []map[string]interface{} `json:"data"` - Message string `json:"message"` - Duration float64 - outputFormat OutputFormat -} - -func (r *CEListResponse) Type() string { return "ce_ls" } -func (r *CEListResponse) TimeCost() float64 { return r.Duration } -func (r *CEListResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format } -func (r *CEListResponse) PrintOut() { - if r.Code == 0 { - PrintTableSimpleByFormat(r.Data, r.outputFormat) - } else { - fmt.Println("ERROR") - fmt.Printf("%d, %s\n", r.Code, r.Message) - } -} - -// CEList handles the ls command - lists nodes using Context Engine -func (c *RAGFlowClient) CEList(cmd *Command) (ResponseIf, error) { - // Get path from command params, default to "datasets" - path, _ := cmd.Params["path"].(string) - if path == "" { - path = "datasets" - } - - // Parse options - opts := &ce.ListOptions{} - if recursive, ok := cmd.Params["recursive"].(bool); ok { - opts.Recursive = recursive - } - if limit, ok := cmd.Params["limit"].(int); ok { - opts.Limit = limit - } - if offset, ok := cmd.Params["offset"].(int); ok { - opts.Offset = offset - } - - // Execute list command through Context Engine - ctx := context.Background() - result, err := c.ContextEngine.List(ctx, path, opts) - if err != nil { - return nil, err - } - - // Convert to response - var response CEListResponse - response.outputFormat = c.OutputFormat - response.Code = 0 - response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat)) - - return &response, nil -} - -// getStringValue safely converts interface{} to string -func getStringValue(v interface{}) string { - if v == nil { - return "" - } - if s, ok := v.(string); ok { - return s - } - return fmt.Sprintf("%v", v) -} - -// formatTimeValue converts a timestamp (milliseconds or string) to readable format -func formatTimeValue(v interface{}) string { - if v == nil { - return "" - } - - var ts int64 - switch val := v.(type) { - case float64: - ts = int64(val) - case int64: - ts = val - case int: - ts = int64(val) - case string: - // Try to parse as number - if _, err := fmt.Sscanf(val, "%d", &ts); err != nil { - // If it's already a formatted date string, return as is - return val - } - default: - return fmt.Sprintf("%v", v) - } - - // Convert milliseconds to seconds if timestamp is in milliseconds (13 digits) - if ts > 1e12 { - ts = ts / 1000 - } - - t := time.Unix(ts, 0) - return t.Format("2006-01-02 15:04:05") -} - -// CESearchResponse represents the response for search command -type CESearchResponse struct { - Code int `json:"code"` - Data []map[string]interface{} `json:"data"` - Total int `json:"total"` - Message string `json:"message"` - Duration float64 - outputFormat OutputFormat -} - -func (r *CESearchResponse) Type() string { return "ce_search" } -func (r *CESearchResponse) TimeCost() float64 { return r.Duration } -func (r *CESearchResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format } -func (r *CESearchResponse) PrintOut() { - if r.Code == 0 { - fmt.Printf("Found %d results:\n", r.Total) - PrintTableSimpleByFormat(r.Data, r.outputFormat) - } else { - fmt.Println("ERROR") - fmt.Printf("%d, %s\n", r.Code, r.Message) - } -} - -// CESearch handles the search command using Context Engine -func (c *RAGFlowClient) CESearch(cmd *Command) (ResponseIf, error) { - // Get path and query from command params - path, _ := cmd.Params["path"].(string) - if path == "" { - path = "datasets" - } - query, _ := cmd.Params["query"].(string) - - // Parse options - opts := &ce.SearchOptions{ - Query: query, - } - if limit, ok := cmd.Params["limit"].(int); ok { - opts.Limit = limit - } - if offset, ok := cmd.Params["offset"].(int); ok { - opts.Offset = offset - } - if recursive, ok := cmd.Params["recursive"].(bool); ok { - opts.Recursive = recursive - } - - // Execute search command through Context Engine - ctx := context.Background() - result, err := c.ContextEngine.Search(ctx, path, opts) - if err != nil { - return nil, err - } - - // Convert to response - var response CESearchResponse - response.outputFormat = c.OutputFormat - response.Code = 0 - response.Total = result.Total - response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat)) - - return &response, nil -} diff --git a/internal/cli/common_command.go b/internal/cli/common_command.go new file mode 100644 index 0000000000..6bb4eb304f --- /dev/null +++ b/internal/cli/common_command.go @@ -0,0 +1,423 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "os/exec" + "strings" + "syscall" + "unsafe" +) + +// LoginUserInteractive performs interactive login with username and password +func (c *RAGFlowClient) LoginUserInteractive(username, password string) error { + // First, ping the server to check if it's available + // For admin mode, use /admin/ping with useAPIBase=true + // For user mode, use /system/ping with useAPIBase=false + var pingPath string + var useAPIBase bool + if c.ServerType == "admin" { + pingPath = "/admin/ping" + useAPIBase = true + } else { + pingPath = "/system/ping" + useAPIBase = false + } + + resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Can't access server for login (connection failed)") + return err + } + + if resp.StatusCode != 200 { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + + // Check response - admin returns JSON with message "PONG", user returns plain "pong" + resJSON, err := resp.JSON() + if err == nil { + // Admin mode returns {"code":0,"message":"PONG"} + if msg, ok := resJSON["message"].(string); !ok || msg != "pong" { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + } else { + // User mode returns plain "pong" + if string(resp.Body) != "pong" { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + } + + // If password is not provided, prompt for it + if password == "" { + fmt.Printf("password for %s: ", username) + var err error + password, err = readPassword() + if err != nil { + return fmt.Errorf("failed to read password: %w", err) + } + password = strings.TrimSpace(password) + } + + // Login + token, err := c.loginUser(username, password) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Can't access server for login (connection failed)") + return err + } + + c.HTTPClient.LoginToken = token + fmt.Printf("Login user %s successfully\n", username) + return nil +} + +// LoginUser performs user login +func (c *RAGFlowClient) LoginUser(cmd *Command) error { + // First, ping the server to check if it's available + // For admin mode, use /admin/ping with useAPIBase=true + // For user mode, use /system/ping with useAPIBase=false + var pingPath string + var useAPIBase bool + if c.ServerType == "admin" { + pingPath = "/admin/ping" + useAPIBase = true + } else { + pingPath = "/system/ping" + useAPIBase = false + } + + resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Can't access server for login (connection failed)") + return err + } + + if resp.StatusCode != 200 { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + + // Check response - admin returns JSON with message "PONG", user returns plain "pong" + resJSON, err := resp.JSON() + if err == nil { + // Admin mode returns {"code":0,"message":"PONG"} + if msg, ok := resJSON["message"].(string); !ok || msg != "pong" { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + } else { + // User mode returns plain "pong" + if string(resp.Body) != "pong" { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + } + + email, ok := cmd.Params["email"].(string) + if !ok { + return fmt.Errorf("email not provided") + } + + password, ok := cmd.Params["password"].(string) + if !ok { + // Get password from user input (hidden) + fmt.Printf("password for %s: ", email) + password, err = readPassword() + if err != nil { + return fmt.Errorf("failed to read password: %w", err) + } + password = strings.TrimSpace(password) + } + + // Login + token, err := c.loginUser(email, password) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Can't access server for login (connection failed)") + return err + } + + c.HTTPClient.LoginToken = token + fmt.Printf("Login user %s successfully\n", email) + return nil +} + +// loginUser performs the actual login request +func (c *RAGFlowClient) loginUser(email, password string) (string, error) { + // Encrypt password using scrypt (same as Python implementation) + encryptedPassword, err := EncryptPassword(password) + if err != nil { + return "", fmt.Errorf("failed to encrypt password: %w", err) + } + + payload := map[string]interface{}{ + "email": email, + "password": encryptedPassword, + } + + var path string + if c.ServerType == "admin" { + path = "/admin/login" + } else { + path = "/user/login" + } + + resp, err := c.HTTPClient.Request("POST", path, c.ServerType == "admin", "", nil, payload) + if err != nil { + return "", err + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return "", fmt.Errorf("login failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return "", fmt.Errorf("login failed: %s", result.Message) + } + + token := resp.Headers.Get("Authorization") + if token == "" { + return "", fmt.Errorf("login failed: missing Authorization header") + } + + return token, nil +} + +func (c *RAGFlowClient) Logout() (ResponseIf, error) { + if c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("not logged in") + } + + var path string + if c.ServerType == "admin" { + path = "/admin/logout" + } else { + path = "/user/logout" + } + + resp, err := c.HTTPClient.Request("GET", path, c.ServerType == "admin", "web", nil, nil) + if err != nil { + return nil, err + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("login failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("login failed: %s", result.Message) + } + + return &result, nil +} + +func (c *RAGFlowClient) ListPoolProviders(cmd *Command) (ResponseIf, error) { + + var endPoint string + if c.ServerType == "admin" { + endPoint = fmt.Sprintf("/admin/providers") + } else { + endPoint = fmt.Sprintf("/providers") + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list providers: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list providers: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to list providers: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +func (c *RAGFlowClient) ShowPoolProvider(cmd *Command) (ResponseIf, error) { + providerName, ok := cmd.Params["provider_name"].(string) + if !ok { + return nil, fmt.Errorf("provider_name not provided") + } + + var endPoint string + if c.ServerType == "admin" { + endPoint = fmt.Sprintf("/admin/providers/%s", providerName) + } else { + endPoint = fmt.Sprintf("/providers/%s", providerName) + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show provider: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show provider: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to show provider: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +func (c *RAGFlowClient) ListPoolModels(cmd *Command) (ResponseIf, error) { + + providerName, ok := cmd.Params["provider_name"].(string) + if !ok { + return nil, fmt.Errorf("provider_name not provided") + } + + var endPoint string + if c.ServerType == "admin" { + endPoint = fmt.Sprintf("/admin/providers/%s/models", providerName) + } else { + endPoint = fmt.Sprintf("/providers/%s/models", providerName) + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list models: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to list models: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +func (c *RAGFlowClient) ShowPoolModel(cmd *Command) (ResponseIf, error) { + providerName, ok := cmd.Params["provider_name"].(string) + if !ok { + return nil, fmt.Errorf("provider_name not provided") + } + modelName, ok := cmd.Params["model_name"].(string) + if !ok { + return nil, fmt.Errorf("model_name not provided") + } + + var endPoint string + if c.ServerType == "admin" { + endPoint = fmt.Sprintf("/admin/providers/%s/models/%s", providerName, modelName) + } else { + endPoint = fmt.Sprintf("/providers/%s/models/%s", providerName, modelName) + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show model: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to show model: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// readPassword reads password from terminal without echoing +func readPassword() (string, error) { + // Check if stdin is a terminal by trying to get terminal size + if isTerminal() { + // Use stty to disable echo + cmd := exec.Command("stty", "-echo") + cmd.Stdin = os.Stdin + if err := cmd.Run(); err != nil { + // Fallback: read normally + return readPasswordFallback() + } + defer func() { + // Re-enable echo + cmd := exec.Command("stty", "echo") + cmd.Stdin = os.Stdin + cmd.Run() + }() + + reader := bufio.NewReader(os.Stdin) + password, err := reader.ReadString('\n') + fmt.Println() // New line after password input + if err != nil { + return "", err + } + return strings.TrimSpace(password), nil + } + + // Fallback for non-terminal input (e.g., piped input) + return readPasswordFallback() +} + +// isTerminal checks if stdin is a terminal +func isTerminal() bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} + +// readPasswordFallback reads password as plain text (fallback mode) +func readPasswordFallback() (string, error) { + reader := bufio.NewReader(os.Stdin) + password, err := reader.ReadString('\n') + if err != nil { + return "", err + } + return strings.TrimSpace(password), nil +} diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index 2cbcec633a..a6ef735362 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -301,6 +301,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenVectorSize, Value: ident} case "DOC_META": return Token{Type: TokenDocMeta, Value: ident} + case "POOL": + return Token{Type: TokenPool, Value: ident} default: return Token{Type: TokenIdentifier, Value: ident} } diff --git a/internal/cli/response.go b/internal/cli/response.go new file mode 100644 index 0000000000..54c565feed --- /dev/null +++ b/internal/cli/response.go @@ -0,0 +1,262 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +import "fmt" + +type ResponseIf interface { + Type() string + PrintOut() + TimeCost() float64 + SetOutputFormat(format OutputFormat) +} + +type CommonResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *CommonResponse) Type() string { + return "common" +} + +func (r *CommonResponse) TimeCost() float64 { + return r.Duration +} + +func (r *CommonResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *CommonResponse) PrintOut() { + if r.Code == 0 { + PrintTableSimpleByFormat(r.Data, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type CommonDataResponse struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *CommonDataResponse) Type() string { + return "show" +} + +func (r *CommonDataResponse) TimeCost() float64 { + return r.Duration +} + +func (r *CommonDataResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *CommonDataResponse) PrintOut() { + if r.Code == 0 { + table := make([]map[string]interface{}, 0) + table = append(table, r.Data) + PrintTableSimpleByFormat(table, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type SimpleResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *SimpleResponse) Type() string { + return "simple" +} + +func (r *SimpleResponse) TimeCost() float64 { + return r.Duration +} + +func (r *SimpleResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *SimpleResponse) PrintOut() { + if r.Code == 0 { + fmt.Println("SUCCESS") + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type RegisterResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *RegisterResponse) Type() string { + return "register" +} + +func (r *RegisterResponse) TimeCost() float64 { + return r.Duration +} + +func (r *RegisterResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *RegisterResponse) PrintOut() { + if r.Code == 0 { + fmt.Println("Register successfully") + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type BenchmarkResponse struct { + Code int `json:"code"` + Duration float64 `json:"duration"` + SuccessCount int `json:"success_count"` + FailureCount int `json:"failure_count"` + Concurrency int + outputFormat OutputFormat +} + +func (r *BenchmarkResponse) Type() string { + return "benchmark" +} + +func (r *BenchmarkResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *BenchmarkResponse) PrintOut() { + if r.Code != 0 { + fmt.Printf("ERROR, Code: %d\n", r.Code) + return + } + + iterations := r.SuccessCount + r.FailureCount + if r.Concurrency == 1 { + if iterations == 1 { + fmt.Printf("Latency: %fs\n", r.Duration) + } else { + fmt.Printf("Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount) + } + } else { + fmt.Printf("Concurrency: %d, Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Concurrency, r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount) + } +} + +func (r *BenchmarkResponse) TimeCost() float64 { + return r.Duration +} + +type KeyValueResponse struct { + Code int `json:"code"` + Key string `json:"key"` + Value string `json:"data"` + Duration float64 + outputFormat OutputFormat +} + +func (r *KeyValueResponse) Type() string { + return "data" +} + +func (r *KeyValueResponse) TimeCost() float64 { + return r.Duration +} + +func (r *KeyValueResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *KeyValueResponse) PrintOut() { + if r.Code == 0 { + table := make([]map[string]interface{}, 0) + // insert r.key and r.value into table + table = append(table, map[string]interface{}{ + "key": r.Key, + "value": r.Value, + }) + PrintTableSimpleByFormat(table, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d\n", r.Code) + } +} + +// ==================== ContextEngine Commands ==================== + +// CEListResponse represents the response for ls command +type CEListResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *CEListResponse) Type() string { return "ce_ls" } +func (r *CEListResponse) TimeCost() float64 { return r.Duration } +func (r *CEListResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format } +func (r *CEListResponse) PrintOut() { + if r.Code == 0 { + PrintTableSimpleByFormat(r.Data, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +// CESearchResponse represents the response for search command +type CESearchResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Total int `json:"total"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *CESearchResponse) Type() string { return "ce_search" } +func (r *CESearchResponse) TimeCost() float64 { return r.Duration } +func (r *CESearchResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format } +func (r *CESearchResponse) PrintOut() { + if r.Code == 0 { + fmt.Printf("Found %d results:\n", r.Total) + PrintTableSimpleByFormat(r.Data, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} diff --git a/internal/cli/types.go b/internal/cli/types.go index 1da0ed2eea..4b2844a3ca 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -69,6 +69,7 @@ const ( TokenKey TokenKeys TokenGenerate + TokenPool TokenModel TokenModels TokenProvider diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index dda5744094..b53b7bfcd4 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1,8 +1,26 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + package cli import ( + "context" "encoding/json" "fmt" + ce "ragflow/internal/cli/contextengine" "strings" ) @@ -732,3 +750,81 @@ func (c *RAGFlowClient) DropDocMetaIndex(cmd *Command) (ResponseIf, error) { result.Duration = 0 return &result, nil } + +// Context related commands + +// CEList handles the ls command - lists nodes using Context Engine +func (c *RAGFlowClient) CEList(cmd *Command) (ResponseIf, error) { + // Get path from command params, default to "datasets" + path, _ := cmd.Params["path"].(string) + if path == "" { + path = "datasets" + } + + // Parse options + opts := &ce.ListOptions{} + if recursive, ok := cmd.Params["recursive"].(bool); ok { + opts.Recursive = recursive + } + if limit, ok := cmd.Params["limit"].(int); ok { + opts.Limit = limit + } + if offset, ok := cmd.Params["offset"].(int); ok { + opts.Offset = offset + } + + // Execute list command through Context Engine + ctx := context.Background() + result, err := c.ContextEngine.List(ctx, path, opts) + if err != nil { + return nil, err + } + + // Convert to response + var response CEListResponse + response.outputFormat = c.OutputFormat + response.Code = 0 + response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat)) + + return &response, nil +} + +// CESearch handles the search command using Context Engine +func (c *RAGFlowClient) CESearch(cmd *Command) (ResponseIf, error) { + // Get path and query from command params + path, _ := cmd.Params["path"].(string) + if path == "" { + path = "datasets" + } + query, _ := cmd.Params["query"].(string) + + // Parse options + opts := &ce.SearchOptions{ + Query: query, + } + if limit, ok := cmd.Params["limit"].(int); ok { + opts.Limit = limit + } + if offset, ok := cmd.Params["offset"].(int); ok { + opts.Offset = offset + } + if recursive, ok := cmd.Params["recursive"].(bool); ok { + opts.Recursive = recursive + } + + // Execute search command through Context Engine + ctx := context.Background() + result, err := c.ContextEngine.Search(ctx, path, opts) + if err != nil { + return nil, err + } + + // Convert to response + var response CESearchResponse + response.outputFormat = c.OutputFormat + response.Code = 0 + response.Total = result.Total + response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat)) + + return &response, nil +} diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index c4ba7da358..2b939d23ba 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -153,6 +153,8 @@ func (p *Parser) parseListCommand() (*Command, error) { return p.parseListModelProviders() case TokenDefault: return p.parseListDefaultModels() + case TokenPool: + return p.parseCommonListPoolModels() case TokenChats: p.nextToken() // Semicolon is optional for SHOW TOKEN @@ -334,6 +336,8 @@ func (p *Parser) parseShowCommand() (*Command, error) { return p.parseShowVariable() case TokenService: return p.parseShowService() + case TokenPool: + return p.parseCommonShowPoolModel() default: return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value) } diff --git a/internal/dao/database.go b/internal/dao/database.go index 7c13f83fdd..e6418d37e3 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -39,6 +39,7 @@ import ( ) var DB *gorm.DB +var modelProviderManager *entity.ProviderManager // LLMFactoryConfig represents a single LLM factory configuration type LLMFactoryConfig struct { @@ -155,11 +156,17 @@ func InitDB() error { } // Run manual migrations for complex schema changes - if err := RunMigrations(DB); err != nil { + if err = RunMigrations(DB); err != nil { return fmt.Errorf("failed to run manual migrations: %w", err) } logger.Info("Database connected and migrated successfully") + + modelProviderManager, err = entity.NewProviderManager("conf/models") + if err != nil { + log.Fatal("Failed to load model providers:", err) + } + logger.Info("Model providers loaded successfully") return nil } @@ -168,6 +175,11 @@ func GetDB() *gorm.DB { return DB } +// GetModelProviderManager get database instance +func GetModelProviderManager() *entity.ProviderManager { + return modelProviderManager +} + // autoMigrateSafely runs AutoMigrate and ignores duplicate index errors // This handles cases where indexes already exist (e.g., created by Python backend) func autoMigrateSafely(db *gorm.DB, model interface{}) error { diff --git a/internal/entity/model.go b/internal/entity/model.go new file mode 100644 index 0000000000..a25ba026d0 --- /dev/null +++ b/internal/entity/model.go @@ -0,0 +1,511 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package entity + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" +) + +// ReasoningSimple represents simple reasoning capability +type ReasoningSimple struct { + Type string `json:"type"` + Enabled bool `json:"enabled"` + Default bool `json:"default"` +} + +// ReasoningBudget represents budget-based reasoning capability +type ReasoningBudget struct { + Type string `json:"type"` + Enabled bool `json:"enabled"` + DefaultTokens int `json:"default_tokens"` + TokenRange struct { + Min int `json:"min"` + Max int `json:"max"` + } `json:"token_range"` +} + +// ReasoningEffort represents effort-based reasoning capability +type ReasoningEffort struct { + Type string `json:"type"` + Enabled bool `json:"enabled"` + Default string `json:"default"` + Options []string `json:"options"` +} + +// Reasoning represents the reasoning capability (can be one of three types) +type Reasoning struct { + Simple *ReasoningSimple `json:"-"` + Budget *ReasoningBudget `json:"-"` + Effort *ReasoningEffort `json:"-"` + RawType string `json:"type"` +} + +// UnmarshalJSON custom unmarshal for Reasoning +func (r *Reasoning) UnmarshalJSON(data []byte) error { + var temp map[string]interface{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + typeVal, ok := temp["type"].(string) + if !ok { + return fmt.Errorf("reasoning type is required") + } + + r.RawType = typeVal + + switch typeVal { + case "simple": + var simple ReasoningSimple + dataBytes, _ := json.Marshal(temp) + if err := json.Unmarshal(dataBytes, &simple); err != nil { + return err + } + r.Simple = &simple + case "budget": + var budget ReasoningBudget + dataBytes, _ := json.Marshal(temp) + if err := json.Unmarshal(dataBytes, &budget); err != nil { + return err + } + r.Budget = &budget + case "effort": + var effort ReasoningEffort + dataBytes, _ := json.Marshal(temp) + if err := json.Unmarshal(dataBytes, &effort); err != nil { + return err + } + r.Effort = &effort + default: + return fmt.Errorf("unknown reasoning type: %s", typeVal) + } + + return nil +} + +// MarshalJSON custom marshal for Reasoning +func (r *Reasoning) MarshalJSON() ([]byte, error) { + switch r.RawType { + case "simple": + if r.Simple != nil { + return json.Marshal(r.Simple) + } + case "budget": + if r.Budget != nil { + return json.Marshal(r.Budget) + } + case "effort": + if r.Effort != nil { + return json.Marshal(r.Effort) + } + } + return nil, fmt.Errorf("invalid reasoning state") +} + +// Multimodal represents multimodal capability +type Multimodal struct { + Enabled bool `json:"enabled"` + InputModalities []string `json:"input_modalities,omitempty"` + OutputModalities []string `json:"output_modalities,omitempty"` +} + +// Features represents all features of a model +type Features struct { + Multimodal *Multimodal `json:"multimodal,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` +} + +// Model represents a single LLM model +type Model struct { + Name string `json:"name"` + MaxTokens int `json:"max_tokens"` + ModelTypes []string `json:"model_types"` + Features Features `json:"features"` +} + +// Provider represents an LLM provider +type Provider struct { + Name string `json:"name"` + Tags string `json:"tags"` + URL string `json:"url"` + Models []Model `json:"models"` +} + +// ProviderManager manages provider and model operations +type ProviderManager struct { + Providers []Provider `json:"model_providers"` +} + +// ModelResponse represents the standard response structure +type ModelResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` +} + +// NewProviderManager creates a new ProviderManager by reading all JSON files from a directory +func NewProviderManager(dirPath string) (*ProviderManager, error) { + providers := []Provider{} + + // Read all files in the directory + files, err := os.ReadDir(dirPath) + if err != nil { + return nil, fmt.Errorf("error reading directory %s: %w", dirPath, err) + } + + // Iterate through all files + for _, file := range files { + // Skip directories + if file.IsDir() { + continue + } + + // Only process JSON files + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + // Build full file path + filePath := filepath.Join(dirPath, file.Name()) + + // Read the file + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("error reading file %s: %w", filePath, err) + } + + // Parse JSON + var provider Provider + if err = json.Unmarshal(data, &provider); err != nil { + return nil, fmt.Errorf("error parsing JSON from file %s: %w", filePath, err) + } + + // Add to providers list + providers = append(providers, provider) + } + + if len(providers) == 0 { + return nil, fmt.Errorf("no JSON files found in directory %s", dirPath) + } + + return &ProviderManager{ + Providers: providers, + }, nil +} + +// 1. List all providers +func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) { + + var providers []map[string]interface{} + + for _, provider := range pm.Providers { + providerData := map[string]interface{}{ + "name": provider.Name, + "tags": provider.Tags, + "url": provider.URL, + } + providers = append(providers, providerData) + } + + if len(providers) == 0 { + return nil, fmt.Errorf("no providers found") + } + + return providers, nil +} + +// 2. Show specific provider information (including base_url) +func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]interface{}, error) { + + provider := pm.findProvider(providerName) + if provider == nil { + return nil, fmt.Errorf("provider '%s' not found", providerName) + } + + providerInfo := map[string]interface{}{ + "name": provider.Name, + "tags": provider.Tags, + "base_url": provider.URL, + "total_models": len(provider.Models), + } + + return providerInfo, nil +} + +// 3. List models under a specific provider +func (pm *ProviderManager) ListModels(providerName string) ([]map[string]interface{}, error) { + provider := pm.findProvider(providerName) + if provider == nil { + return nil, fmt.Errorf("provider '%s' not found", providerName) + } + + models := []map[string]interface{}{} + for _, model := range provider.Models { + modelData := map[string]interface{}{ + "name": model.Name, + "max_tokens": model.MaxTokens, + "model_types": model.ModelTypes, + "features": getFeaturesMap(model.Features), + } + models = append(models, modelData) + } + + if len(models) == 0 { + return nil, fmt.Errorf("no models found") + } + + return models, nil +} + +func (pm *ProviderManager) GetModelByName(providerName, modelName string) (*Model, error) { + provider := pm.findProvider(providerName) + if provider == nil { + return nil, fmt.Errorf("provider '%s' not found", providerName) + } + model := pm.findModel(provider, modelName) + if model == nil { + return nil, fmt.Errorf("model '%s' not found", modelName) + } + return model, nil +} + +// 4. Search specific model information with filtering by max_tokens or type +func (pm *ProviderManager) SearchModelInfo(providerName, modelName string, filterBy string, filterValue interface{}) ModelResponse { + resp := ModelResponse{ + Code: 0, + Data: []map[string]interface{}{}, + Message: "success", + } + + provider := pm.findProvider(providerName) + if provider == nil { + resp.Code = 404 + resp.Message = fmt.Sprintf("Provider '%s' not found", providerName) + return resp + } + + model := pm.findModel(provider, modelName) + if model == nil { + resp.Code = 404 + resp.Message = fmt.Sprintf("Model '%s' not found in provider '%s'", modelName, providerName) + return resp + } + + // Apply filters + matchFilter := true + if filterBy != "" && filterValue != nil { + switch filterBy { + case "max_tokens": + if maxVal, ok := filterValue.(int); ok { + if model.MaxTokens < maxVal { + matchFilter = false + resp.Code = 400 + resp.Message = fmt.Sprintf("Model does not meet filter criteria: max_tokens (%d) < %d", + model.MaxTokens, maxVal) + } + } + case "type": + if typeVal, ok := filterValue.(string); ok { + if !containsModelType(model.ModelTypes, typeVal) { + matchFilter = false + resp.Code = 400 + resp.Message = fmt.Sprintf("Model does not meet filter criteria: type '%s' not found", typeVal) + } + } + } + } + + if matchFilter { + modelData := map[string]interface{}{ + "name": model.Name, + "max_tokens": model.MaxTokens, + "model_types": model.ModelTypes, + "features": getFeaturesMap(model.Features), + } + + if filterBy != "" && filterValue != nil { + modelData["filter_applied"] = map[string]interface{}{ + "field": filterBy, + "value": filterValue, + } + } + + resp.Data = append(resp.Data, modelData) + } + + return resp +} + +// 5. Display models with specific features +func (pm *ProviderManager) SearchByFeature(featureType string) ModelResponse { + resp := ModelResponse{ + Code: 0, + Data: []map[string]interface{}{}, + Message: "success", + } + + for _, provider := range pm.Providers { + for _, model := range provider.Models { + if modelHasFeature(model.Features, featureType) { + modelData := map[string]interface{}{ + "provider": provider.Name, + "name": model.Name, + "max_tokens": model.MaxTokens, + "model_types": model.ModelTypes, + "features": getFeaturesMap(model.Features), + } + resp.Data = append(resp.Data, modelData) + } + } + } + + if len(resp.Data) == 0 { + resp.Code = 404 + resp.Message = fmt.Sprintf("No models found with feature '%s'", featureType) + } + + return resp +} + +// 6. Display models with specific type +func (pm *ProviderManager) SearchByType(modelType string) ModelResponse { + resp := ModelResponse{ + Code: 0, + Data: []map[string]interface{}{}, + Message: "success", + } + + for _, provider := range pm.Providers { + for _, model := range provider.Models { + if containsModelType(model.ModelTypes, modelType) { + modelData := map[string]interface{}{ + "provider": provider.Name, + "name": model.Name, + "max_tokens": model.MaxTokens, + "model_types": model.ModelTypes, + "features": getFeaturesMap(model.Features), + } + resp.Data = append(resp.Data, modelData) + } + } + } + + if len(resp.Data) == 0 { + resp.Code = 404 + resp.Message = fmt.Sprintf("No models found with type '%s'", modelType) + } + + return resp +} + +// Helper: Get features map for response +func getFeaturesMap(features Features) map[string]interface{} { + featuresMap := make(map[string]interface{}) + + if features.Multimodal != nil && features.Multimodal.Enabled { + multimodalMap := map[string]interface{}{ + "enabled": features.Multimodal.Enabled, + "input_modalities": features.Multimodal.InputModalities, + "output_modalities": features.Multimodal.OutputModalities, + } + featuresMap["multimodal"] = multimodalMap + } + + if features.Reasoning != nil { + reasoningMap := make(map[string]interface{}) + switch features.Reasoning.RawType { + case "simple": + if features.Reasoning.Simple != nil { + reasoningMap["type"] = "simple" + reasoningMap["enabled"] = features.Reasoning.Simple.Enabled + reasoningMap["default"] = features.Reasoning.Simple.Default + } + case "budget": + if features.Reasoning.Budget != nil { + reasoningMap["type"] = "budget" + reasoningMap["enabled"] = features.Reasoning.Budget.Enabled + reasoningMap["default_tokens"] = features.Reasoning.Budget.DefaultTokens + reasoningMap["token_range"] = map[string]int{ + "min": features.Reasoning.Budget.TokenRange.Min, + "max": features.Reasoning.Budget.TokenRange.Max, + } + } + case "effort": + if features.Reasoning.Effort != nil { + reasoningMap["type"] = "effort" + reasoningMap["enabled"] = features.Reasoning.Effort.Enabled + reasoningMap["default"] = features.Reasoning.Effort.Default + reasoningMap["options"] = features.Reasoning.Effort.Options + } + } + featuresMap["reasoning"] = reasoningMap + } + + return featuresMap +} + +// Helper: Check if model has a specific feature +func modelHasFeature(features Features, featureType string) bool { + switch strings.ToLower(featureType) { + case "multimodal": + return features.Multimodal != nil && features.Multimodal.Enabled + case "reasoning": + return features.Reasoning != nil + case "reasoning_simple": + return features.Reasoning != nil && features.Reasoning.RawType == "simple" + case "reasoning_budget": + return features.Reasoning != nil && features.Reasoning.RawType == "budget" + case "reasoning_effort": + return features.Reasoning != nil && features.Reasoning.RawType == "effort" + default: + return false + } +} + +// Helper: Find provider by name +func (pm *ProviderManager) findProvider(name string) *Provider { + for i := range pm.Providers { + if strings.EqualFold(pm.Providers[i].Name, name) { + return &pm.Providers[i] + } + } + return nil +} + +// Helper: Find model by name +func (pm *ProviderManager) findModel(provider *Provider, modelName string) *Model { + for i := range provider.Models { + if strings.EqualFold(provider.Models[i].Name, modelName) { + return &provider.Models[i] + } + } + return nil +} + +// Helper: Check if model types contains target +func containsModelType(types []string, target string) bool { + for _, t := range types { + if strings.EqualFold(t, target) { + return true + } + } + return false +} diff --git a/internal/handler/providers.go b/internal/handler/providers.go new file mode 100644 index 0000000000..267aa70dd6 --- /dev/null +++ b/internal/handler/providers.go @@ -0,0 +1,123 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "ragflow/internal/common" + "ragflow/internal/dao" + + "github.com/gin-gonic/gin" +) + +func ListPoolProviders(c *gin.Context) { + providers, err := dao.GetModelProviderManager().ListProviders() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": providers, + }) +} + +func ShowPoolProvider(c *gin.Context) { + providerName := c.Param("provider_name") + if providerName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + provider, err := dao.GetModelProviderManager().GetProviderByName(providerName) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": provider, + }) +} + +func ListPoolModels(c *gin.Context) { + providerName := c.Param("provider_name") + if providerName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + models, err := dao.GetModelProviderManager().ListModels(providerName) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": models, + }) +} + +func ShowPoolModel(c *gin.Context) { + providerName := c.Param("provider_name") + if providerName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + modelName := c.Param("model_name") + if modelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + model, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": model, + }) +} diff --git a/internal/router/router.go b/internal/router/router.go index 44f783e06e..95ef875818 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -102,6 +102,15 @@ func (r *Router) Setup(engine *gin.Engine) { // User logout endpoint engine.GET("/v1/user/logout", r.userHandler.Logout) + // provider pool route group + provider := engine.Group("/api/v1/providers") + { + provider.GET("/", handler.ListPoolProviders) + provider.GET("/:provider_name", handler.ShowPoolProvider) + provider.GET("/:provider_name/models", handler.ListPoolModels) + provider.GET("/:provider_name/models/:model_name", handler.ShowPoolModel) + } + // Protected routes authorized := engine.Group("") authorized.Use(r.authHandler.AuthMiddleware())