mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refactor Go server model provider reading and access (#13831)
### What problem does this PR solve? 1. Refactor model provider json file format 2. Use memory data structure to replace database 3. Add CLI command to access ``` RAGFlow(user)> list pool models from 'xai'; +-------------------------------------------------------------------------------------+------------+-------------+-----------------------+ | features | max_tokens | model_types | name | +-------------------------------------------------------------------------------------+------------+-------------+-----------------------+ | map[] | 256000 | [llm] | grok-4 | | map[] | 131072 | [llm] | grok-3 | | map[] | 131072 | [llm] | grok-3-fast | | map[] | 131072 | [llm] | grok-3-mini | | map[] | 131072 | [llm] | grok-3-mini-mini-fast | | map[multimodal:map[enabled:true input_modalities:[image] output_modalities:[text]]] | 32768 | [vlm] | grok-2-vision | +-------------------------------------------------------------------------------------+------------+-------------+-----------------------+ RAGFlow(user)> show pool model 'grok-2-vision' from 'xai'; +-------------------------------------------------------------------------------------+------------+-------------+---------------+ | features | max_tokens | model_types | name | +-------------------------------------------------------------------------------------+------------+-------------+---------------+ | map[multimodal:map[enabled:true input_modalities:[image] output_modalities:[text]]] | 32768 | [vlm] | grok-2-vision | +-------------------------------------------------------------------------------------+------------+-------------+---------------+ RAGFlow(user)> list pool providers; +--------+------------------------------------------------------------+---------------------------+ | name | tags | url | +--------+------------------------------------------------------------+---------------------------+ | OpenAI | LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION | https://api.openai.com/v1 | | xAI | LLM | https://api.x.ai/v1 | +--------+------------------------------------------------------------+---------------------------+ RAGFlow(user)> show pool provider 'openai'; +---------------------------+--------+------------------------------------------------------------+--------------+ | base_url | name | tags | total_models | +---------------------------+--------+------------------------------------------------------------+--------------+ | https://api.openai.com/v1 | OpenAI | LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION | 27 | +---------------------------+--------+------------------------------------------------------------+--------------+ ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
239
conf/models/openai.json
Normal file
239
conf/models/openai.json
Normal file
@@ -0,0 +1,239 @@
|
||||
{
|
||||
"name": "OpenAI",
|
||||
"tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
|
||||
"url": "https://api.openai.com/v1",
|
||||
"models": [
|
||||
{
|
||||
"name": "gpt-5.2-pro",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5.2",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5.1",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5.1-chat-latest",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5-mini",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5-nano",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-5-chat-latest",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4.1",
|
||||
"max_tokens": 1047576,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4.1-mini",
|
||||
"max_tokens": 1047576,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4.1-nano",
|
||||
"max_tokens": 1047576,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4.5-preview",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"llm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "o3",
|
||||
"max_tokens": 200000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "o4-mini",
|
||||
"max_tokens": 200000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "o4-mini-high",
|
||||
"max_tokens": 200000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o-mini",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-3.5-turbo",
|
||||
"max_tokens": 4096,
|
||||
"model_types": [
|
||||
"llm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-3.5-turbo-16k-0613",
|
||||
"max_tokens": 16385,
|
||||
"model_types": [
|
||||
"llm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "text-embedding-ada-002",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "text-embedding-3-small",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "text-embedding-3-large",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "whisper-1",
|
||||
"max_tokens": 26214400,
|
||||
"model_types": [
|
||||
"speech2text"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"llm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4-turbo",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"llm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "gpt-4-32k",
|
||||
"max_tokens": 32768,
|
||||
"model_types": [
|
||||
"llm"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "tts-1",
|
||||
"max_tokens": 2048,
|
||||
"model_types": [
|
||||
"tts"
|
||||
],
|
||||
"features": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
49
conf/models/xai.json
Normal file
49
conf/models/xai.json
Normal file
@@ -0,0 +1,49 @@
|
||||
{
|
||||
"name": "xAI",
|
||||
"tags": "LLM",
|
||||
"url": "https://api.x.ai/v1",
|
||||
"models": [
|
||||
{
|
||||
"name": "grok-4",
|
||||
"max_tokens": 256000,
|
||||
"model_types": ["llm"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3-fast",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3-mini",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3-mini-mini-fast",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-2-vision",
|
||||
"max_tokens": 32768,
|
||||
"model_types": ["vlm"],
|
||||
"features": {
|
||||
"multimodal": {
|
||||
"enabled": true,
|
||||
"input_modalities": ["image"],
|
||||
"output_modalities": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -17,6 +17,8 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"ragflow/internal/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -46,6 +48,15 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
|
||||
admin.POST("/reports", r.handler.Reports)
|
||||
|
||||
// provider pool route group
|
||||
provider := admin.Group("providers")
|
||||
{
|
||||
provider.GET("/", handler.ListPoolProviders)
|
||||
provider.GET("/:provider_name", handler.ShowPoolProvider)
|
||||
provider.GET("/:provider_name/models", handler.ListPoolModels)
|
||||
provider.GET("/:provider_name/models/:model_name", handler.ShowPoolModel)
|
||||
}
|
||||
|
||||
// Protected routes
|
||||
protected := admin.Group("")
|
||||
protected.Use(r.handler.AuthMiddleware())
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package cli
|
||||
|
||||
import "fmt"
|
||||
@@ -150,6 +166,8 @@ func (p *Parser) parseAdminListCommand() (*Command, error) {
|
||||
return p.parseAdminListModelProviders()
|
||||
case TokenDefault:
|
||||
return p.parseAdminListDefaultModels()
|
||||
case TokenPool:
|
||||
return p.parseCommonListPoolModels()
|
||||
case TokenChats:
|
||||
p.nextToken()
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
@@ -255,6 +273,78 @@ func (p *Parser) parseAdminListDefaultModels() (*Command, error) {
|
||||
return NewCommand("list_user_default_models"), nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseCommonListPoolModels() (*Command, error) {
|
||||
p.nextToken() // consume POOL
|
||||
if p.curToken.Type == TokenProviders {
|
||||
return NewCommand("list_pool_providers"), nil
|
||||
} else if p.curToken.Type == TokenModels {
|
||||
p.nextToken()
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM")
|
||||
}
|
||||
p.nextToken()
|
||||
providerName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := NewCommand("list_pool_models")
|
||||
cmd.Params["provider_name"] = providerName
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("expected PROVIDERS or MODELS")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseCommonShowPoolModel() (*Command, error) {
|
||||
p.nextToken() // consume POOL
|
||||
if p.curToken.Type == TokenProvider {
|
||||
p.nextToken()
|
||||
providerName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := NewCommand("show_pool_provider")
|
||||
cmd.Params["provider_name"] = providerName
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
} else if p.curToken.Type == TokenModel {
|
||||
p.nextToken() // skip model
|
||||
modelName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken() // skip model name
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM")
|
||||
}
|
||||
p.nextToken() // skip from
|
||||
providerName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken() // skip provider name
|
||||
cmd := NewCommand("show_pool_model")
|
||||
cmd.Params["provider_name"] = providerName
|
||||
cmd.Params["model_name"] = modelName
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("expected PROVIDERS or MODELS")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseAdminListFiles() (*Command, error) {
|
||||
p.nextToken() // consume FILES
|
||||
if p.curToken.Type != TokenOf {
|
||||
@@ -319,6 +409,8 @@ func (p *Parser) parseAdminShowCommand() (*Command, error) {
|
||||
return p.parseShowVariable()
|
||||
case TokenService:
|
||||
return p.parseShowService()
|
||||
case TokenPool:
|
||||
return p.parseCommonShowPoolModel()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
@@ -17,17 +17,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
ce "ragflow/internal/cli/contextengine"
|
||||
)
|
||||
|
||||
@@ -103,264 +93,6 @@ func (a *httpClientAdapter) Request(method, path string, useAPIBase bool, authKi
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginUserInteractive performs interactive login with username and password
|
||||
func (c *RAGFlowClient) LoginUserInteractive(username, password string) error {
|
||||
// First, ping the server to check if it's available
|
||||
// For admin mode, use /admin/ping with useAPIBase=true
|
||||
// For user mode, use /system/ping with useAPIBase=false
|
||||
var pingPath string
|
||||
var useAPIBase bool
|
||||
if c.ServerType == "admin" {
|
||||
pingPath = "/admin/ping"
|
||||
useAPIBase = true
|
||||
} else {
|
||||
pingPath = "/system/ping"
|
||||
useAPIBase = false
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
|
||||
// Check response - admin returns JSON with message "PONG", user returns plain "pong"
|
||||
resJSON, err := resp.JSON()
|
||||
if err == nil {
|
||||
// Admin mode returns {"code":0,"message":"PONG"}
|
||||
if msg, ok := resJSON["message"].(string); !ok || msg != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
} else {
|
||||
// User mode returns plain "pong"
|
||||
if string(resp.Body) != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
}
|
||||
|
||||
// If password is not provided, prompt for it
|
||||
if password == "" {
|
||||
fmt.Printf("password for %s: ", username)
|
||||
var err error
|
||||
password, err = readPassword()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
password = strings.TrimSpace(password)
|
||||
}
|
||||
|
||||
// Login
|
||||
token, err := c.loginUser(username, password)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
c.HTTPClient.LoginToken = token
|
||||
fmt.Printf("Login user %s successfully\n", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoginUser performs user login
|
||||
func (c *RAGFlowClient) LoginUser(cmd *Command) error {
|
||||
// First, ping the server to check if it's available
|
||||
// For admin mode, use /admin/ping with useAPIBase=true
|
||||
// For user mode, use /system/ping with useAPIBase=false
|
||||
var pingPath string
|
||||
var useAPIBase bool
|
||||
if c.ServerType == "admin" {
|
||||
pingPath = "/admin/ping"
|
||||
useAPIBase = true
|
||||
} else {
|
||||
pingPath = "/system/ping"
|
||||
useAPIBase = false
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
|
||||
// Check response - admin returns JSON with message "PONG", user returns plain "pong"
|
||||
resJSON, err := resp.JSON()
|
||||
if err == nil {
|
||||
// Admin mode returns {"code":0,"message":"PONG"}
|
||||
if msg, ok := resJSON["message"].(string); !ok || msg != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
} else {
|
||||
// User mode returns plain "pong"
|
||||
if string(resp.Body) != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
}
|
||||
|
||||
email, ok := cmd.Params["email"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("email not provided")
|
||||
}
|
||||
|
||||
password, ok := cmd.Params["password"].(string)
|
||||
if !ok {
|
||||
// Get password from user input (hidden)
|
||||
fmt.Printf("password for %s: ", email)
|
||||
password, err = readPassword()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
password = strings.TrimSpace(password)
|
||||
}
|
||||
|
||||
// Login
|
||||
token, err := c.loginUser(email, password)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
c.HTTPClient.LoginToken = token
|
||||
fmt.Printf("Login user %s successfully\n", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loginUser performs the actual login request
|
||||
func (c *RAGFlowClient) loginUser(email, password string) (string, error) {
|
||||
// Encrypt password using scrypt (same as Python implementation)
|
||||
encryptedPassword, err := EncryptPassword(password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encrypt password: %w", err)
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"email": email,
|
||||
"password": encryptedPassword,
|
||||
}
|
||||
|
||||
var path string
|
||||
if c.ServerType == "admin" {
|
||||
path = "/admin/login"
|
||||
} else {
|
||||
path = "/user/login"
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", path, c.ServerType == "admin", "", nil, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return "", fmt.Errorf("login failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return "", fmt.Errorf("login failed: %s", result.Message)
|
||||
}
|
||||
|
||||
token := resp.Headers.Get("Authorization")
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("login failed: missing Authorization header")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) Logout() (ResponseIf, error) {
|
||||
if c.HTTPClient.LoginToken == "" {
|
||||
return nil, fmt.Errorf("not logged in")
|
||||
}
|
||||
|
||||
var path string
|
||||
if c.ServerType == "admin" {
|
||||
path = "/admin/logout"
|
||||
} else {
|
||||
path = "/user/logout"
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", path, c.ServerType == "admin", "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("login failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("login failed: %s", result.Message)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// readPassword reads password from terminal without echoing
|
||||
func readPassword() (string, error) {
|
||||
// Check if stdin is a terminal by trying to get terminal size
|
||||
if isTerminal() {
|
||||
// Use stty to disable echo
|
||||
cmd := exec.Command("stty", "-echo")
|
||||
cmd.Stdin = os.Stdin
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Fallback: read normally
|
||||
return readPasswordFallback()
|
||||
}
|
||||
defer func() {
|
||||
// Re-enable echo
|
||||
cmd := exec.Command("stty", "echo")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Run()
|
||||
}()
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
password, err := reader.ReadString('\n')
|
||||
fmt.Println() // New line after password input
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(password), nil
|
||||
}
|
||||
|
||||
// Fallback for non-terminal input (e.g., piped input)
|
||||
return readPasswordFallback()
|
||||
}
|
||||
|
||||
// isTerminal checks if stdin is a terminal
|
||||
func isTerminal() bool {
|
||||
var termios syscall.Termios
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
|
||||
return err == 0
|
||||
}
|
||||
|
||||
// readPasswordFallback reads password as plain text (fallback mode)
|
||||
func readPasswordFallback() (string, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
password, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(password), nil
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a parsed command
|
||||
// Returns benchmark result map for commands that support it (e.g., ping_server with iterations > 1)
|
||||
func (c *RAGFlowClient) ExecuteCommand(cmd *Command) (ResponseIf, error) {
|
||||
@@ -420,6 +152,14 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.ListAdminTokens(cmd)
|
||||
case "drop_token":
|
||||
return c.DropAdminToken(cmd)
|
||||
case "list_pool_providers":
|
||||
return c.ListPoolProviders(cmd)
|
||||
case "show_pool_provider":
|
||||
return c.ShowPoolProvider(cmd)
|
||||
case "list_pool_models":
|
||||
return c.ListPoolModels(cmd)
|
||||
case "show_pool_model":
|
||||
return c.ShowPoolModel(cmd)
|
||||
// TODO: Implement other commands
|
||||
default:
|
||||
return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type)
|
||||
@@ -463,6 +203,14 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.CreateDocMetaIndex(cmd)
|
||||
case "drop_doc_meta_index":
|
||||
return c.DropDocMetaIndex(cmd)
|
||||
case "list_pool_providers":
|
||||
return c.ListPoolProviders(cmd)
|
||||
case "show_pool_provider":
|
||||
return c.ShowPoolProvider(cmd)
|
||||
case "list_pool_models":
|
||||
return c.ListPoolModels(cmd)
|
||||
case "show_pool_model":
|
||||
return c.ShowPoolModel(cmd)
|
||||
// ContextEngine commands
|
||||
case "ce_ls":
|
||||
return c.CEList(cmd)
|
||||
@@ -482,366 +230,3 @@ func (c *RAGFlowClient) ShowCurrentUser(cmd *Command) (map[string]interface{}, e
|
||||
// The /admin/auth API only verifies authorization, does not return user info
|
||||
return nil, fmt.Errorf("command 'SHOW CURRENT USER' is not yet implemented")
|
||||
}
|
||||
|
||||
type ResponseIf interface {
|
||||
Type() string
|
||||
PrintOut()
|
||||
TimeCost() float64
|
||||
SetOutputFormat(format OutputFormat)
|
||||
}
|
||||
|
||||
type CommonResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CommonResponse) Type() string {
|
||||
return "common"
|
||||
}
|
||||
|
||||
func (r *CommonResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *CommonResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *CommonResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
PrintTableSimpleByFormat(r.Data, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type CommonDataResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) Type() string {
|
||||
return "show"
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
table := make([]map[string]interface{}, 0)
|
||||
table = append(table, r.Data)
|
||||
PrintTableSimpleByFormat(table, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) Type() string {
|
||||
return "simple"
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Println("SUCCESS")
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) Type() string {
|
||||
return "register"
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Println("Register successfully")
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type BenchmarkResponse struct {
|
||||
Code int `json:"code"`
|
||||
Duration float64 `json:"duration"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailureCount int `json:"failure_count"`
|
||||
Concurrency int
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) Type() string {
|
||||
return "benchmark"
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) PrintOut() {
|
||||
if r.Code != 0 {
|
||||
fmt.Printf("ERROR, Code: %d\n", r.Code)
|
||||
return
|
||||
}
|
||||
|
||||
iterations := r.SuccessCount + r.FailureCount
|
||||
if r.Concurrency == 1 {
|
||||
if iterations == 1 {
|
||||
fmt.Printf("Latency: %fs\n", r.Duration)
|
||||
} else {
|
||||
fmt.Printf("Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Concurrency: %d, Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Concurrency, r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
type KeyValueResponse struct {
|
||||
Code int `json:"code"`
|
||||
Key string `json:"key"`
|
||||
Value string `json:"data"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) Type() string {
|
||||
return "data"
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
table := make([]map[string]interface{}, 0)
|
||||
// insert r.key and r.value into table
|
||||
table = append(table, map[string]interface{}{
|
||||
"key": r.Key,
|
||||
"value": r.Value,
|
||||
})
|
||||
PrintTableSimpleByFormat(table, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d\n", r.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== ContextEngine Commands ====================
|
||||
|
||||
// CEListResponse represents the response for ls command
|
||||
type CEListResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CEListResponse) Type() string { return "ce_ls" }
|
||||
func (r *CEListResponse) TimeCost() float64 { return r.Duration }
|
||||
func (r *CEListResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format }
|
||||
func (r *CEListResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
PrintTableSimpleByFormat(r.Data, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// CEList handles the ls command - lists nodes using Context Engine
|
||||
func (c *RAGFlowClient) CEList(cmd *Command) (ResponseIf, error) {
|
||||
// Get path from command params, default to "datasets"
|
||||
path, _ := cmd.Params["path"].(string)
|
||||
if path == "" {
|
||||
path = "datasets"
|
||||
}
|
||||
|
||||
// Parse options
|
||||
opts := &ce.ListOptions{}
|
||||
if recursive, ok := cmd.Params["recursive"].(bool); ok {
|
||||
opts.Recursive = recursive
|
||||
}
|
||||
if limit, ok := cmd.Params["limit"].(int); ok {
|
||||
opts.Limit = limit
|
||||
}
|
||||
if offset, ok := cmd.Params["offset"].(int); ok {
|
||||
opts.Offset = offset
|
||||
}
|
||||
|
||||
// Execute list command through Context Engine
|
||||
ctx := context.Background()
|
||||
result, err := c.ContextEngine.List(ctx, path, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to response
|
||||
var response CEListResponse
|
||||
response.outputFormat = c.OutputFormat
|
||||
response.Code = 0
|
||||
response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat))
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// getStringValue safely converts interface{} to string
|
||||
func getStringValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// formatTimeValue converts a timestamp (milliseconds or string) to readable format
|
||||
func formatTimeValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var ts int64
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
ts = int64(val)
|
||||
case int64:
|
||||
ts = val
|
||||
case int:
|
||||
ts = int64(val)
|
||||
case string:
|
||||
// Try to parse as number
|
||||
if _, err := fmt.Sscanf(val, "%d", &ts); err != nil {
|
||||
// If it's already a formatted date string, return as is
|
||||
return val
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// Convert milliseconds to seconds if timestamp is in milliseconds (13 digits)
|
||||
if ts > 1e12 {
|
||||
ts = ts / 1000
|
||||
}
|
||||
|
||||
t := time.Unix(ts, 0)
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// CESearchResponse represents the response for search command
|
||||
type CESearchResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Total int `json:"total"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CESearchResponse) Type() string { return "ce_search" }
|
||||
func (r *CESearchResponse) TimeCost() float64 { return r.Duration }
|
||||
func (r *CESearchResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format }
|
||||
func (r *CESearchResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Printf("Found %d results:\n", r.Total)
|
||||
PrintTableSimpleByFormat(r.Data, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// CESearch handles the search command using Context Engine
|
||||
func (c *RAGFlowClient) CESearch(cmd *Command) (ResponseIf, error) {
|
||||
// Get path and query from command params
|
||||
path, _ := cmd.Params["path"].(string)
|
||||
if path == "" {
|
||||
path = "datasets"
|
||||
}
|
||||
query, _ := cmd.Params["query"].(string)
|
||||
|
||||
// Parse options
|
||||
opts := &ce.SearchOptions{
|
||||
Query: query,
|
||||
}
|
||||
if limit, ok := cmd.Params["limit"].(int); ok {
|
||||
opts.Limit = limit
|
||||
}
|
||||
if offset, ok := cmd.Params["offset"].(int); ok {
|
||||
opts.Offset = offset
|
||||
}
|
||||
if recursive, ok := cmd.Params["recursive"].(bool); ok {
|
||||
opts.Recursive = recursive
|
||||
}
|
||||
|
||||
// Execute search command through Context Engine
|
||||
ctx := context.Background()
|
||||
result, err := c.ContextEngine.Search(ctx, path, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to response
|
||||
var response CESearchResponse
|
||||
response.outputFormat = c.OutputFormat
|
||||
response.Code = 0
|
||||
response.Total = result.Total
|
||||
response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat))
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
423
internal/cli/common_command.go
Normal file
423
internal/cli/common_command.go
Normal file
@@ -0,0 +1,423 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// LoginUserInteractive performs interactive login with username and password
|
||||
func (c *RAGFlowClient) LoginUserInteractive(username, password string) error {
|
||||
// First, ping the server to check if it's available
|
||||
// For admin mode, use /admin/ping with useAPIBase=true
|
||||
// For user mode, use /system/ping with useAPIBase=false
|
||||
var pingPath string
|
||||
var useAPIBase bool
|
||||
if c.ServerType == "admin" {
|
||||
pingPath = "/admin/ping"
|
||||
useAPIBase = true
|
||||
} else {
|
||||
pingPath = "/system/ping"
|
||||
useAPIBase = false
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
|
||||
// Check response - admin returns JSON with message "PONG", user returns plain "pong"
|
||||
resJSON, err := resp.JSON()
|
||||
if err == nil {
|
||||
// Admin mode returns {"code":0,"message":"PONG"}
|
||||
if msg, ok := resJSON["message"].(string); !ok || msg != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
} else {
|
||||
// User mode returns plain "pong"
|
||||
if string(resp.Body) != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
}
|
||||
|
||||
// If password is not provided, prompt for it
|
||||
if password == "" {
|
||||
fmt.Printf("password for %s: ", username)
|
||||
var err error
|
||||
password, err = readPassword()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
password = strings.TrimSpace(password)
|
||||
}
|
||||
|
||||
// Login
|
||||
token, err := c.loginUser(username, password)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
c.HTTPClient.LoginToken = token
|
||||
fmt.Printf("Login user %s successfully\n", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoginUser performs user login
|
||||
func (c *RAGFlowClient) LoginUser(cmd *Command) error {
|
||||
// First, ping the server to check if it's available
|
||||
// For admin mode, use /admin/ping with useAPIBase=true
|
||||
// For user mode, use /system/ping with useAPIBase=false
|
||||
var pingPath string
|
||||
var useAPIBase bool
|
||||
if c.ServerType == "admin" {
|
||||
pingPath = "/admin/ping"
|
||||
useAPIBase = true
|
||||
} else {
|
||||
pingPath = "/system/ping"
|
||||
useAPIBase = false
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", pingPath, useAPIBase, "web", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
|
||||
// Check response - admin returns JSON with message "PONG", user returns plain "pong"
|
||||
resJSON, err := resp.JSON()
|
||||
if err == nil {
|
||||
// Admin mode returns {"code":0,"message":"PONG"}
|
||||
if msg, ok := resJSON["message"].(string); !ok || msg != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
} else {
|
||||
// User mode returns plain "pong"
|
||||
if string(resp.Body) != "pong" {
|
||||
fmt.Println("Server is down")
|
||||
return fmt.Errorf("server is down")
|
||||
}
|
||||
}
|
||||
|
||||
email, ok := cmd.Params["email"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("email not provided")
|
||||
}
|
||||
|
||||
password, ok := cmd.Params["password"].(string)
|
||||
if !ok {
|
||||
// Get password from user input (hidden)
|
||||
fmt.Printf("password for %s: ", email)
|
||||
password, err = readPassword()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
password = strings.TrimSpace(password)
|
||||
}
|
||||
|
||||
// Login
|
||||
token, err := c.loginUser(email, password)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Can't access server for login (connection failed)")
|
||||
return err
|
||||
}
|
||||
|
||||
c.HTTPClient.LoginToken = token
|
||||
fmt.Printf("Login user %s successfully\n", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loginUser performs the actual login request
|
||||
func (c *RAGFlowClient) loginUser(email, password string) (string, error) {
|
||||
// Encrypt password using scrypt (same as Python implementation)
|
||||
encryptedPassword, err := EncryptPassword(password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encrypt password: %w", err)
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"email": email,
|
||||
"password": encryptedPassword,
|
||||
}
|
||||
|
||||
var path string
|
||||
if c.ServerType == "admin" {
|
||||
path = "/admin/login"
|
||||
} else {
|
||||
path = "/user/login"
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", path, c.ServerType == "admin", "", nil, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return "", fmt.Errorf("login failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return "", fmt.Errorf("login failed: %s", result.Message)
|
||||
}
|
||||
|
||||
token := resp.Headers.Get("Authorization")
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("login failed: missing Authorization header")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) Logout() (ResponseIf, error) {
|
||||
if c.HTTPClient.LoginToken == "" {
|
||||
return nil, fmt.Errorf("not logged in")
|
||||
}
|
||||
|
||||
var path string
|
||||
if c.ServerType == "admin" {
|
||||
path = "/admin/logout"
|
||||
} else {
|
||||
path = "/user/logout"
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", path, c.ServerType == "admin", "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("login failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("login failed: %s", result.Message)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ListPoolProviders(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
var endPoint string
|
||||
if c.ServerType == "admin" {
|
||||
endPoint = fmt.Sprintf("/admin/providers")
|
||||
} else {
|
||||
endPoint = fmt.Sprintf("/providers")
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list providers: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list providers: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list providers: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ShowPoolProvider(cmd *Command) (ResponseIf, error) {
|
||||
providerName, ok := cmd.Params["provider_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider_name not provided")
|
||||
}
|
||||
|
||||
var endPoint string
|
||||
if c.ServerType == "admin" {
|
||||
endPoint = fmt.Sprintf("/admin/providers/%s", providerName)
|
||||
} else {
|
||||
endPoint = fmt.Sprintf("/providers/%s", providerName)
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to show provider: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to show provider: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonDataResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to show provider: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ListPoolModels(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
providerName, ok := cmd.Params["provider_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider_name not provided")
|
||||
}
|
||||
|
||||
var endPoint string
|
||||
if c.ServerType == "admin" {
|
||||
endPoint = fmt.Sprintf("/admin/providers/%s/models", providerName)
|
||||
} else {
|
||||
endPoint = fmt.Sprintf("/providers/%s/models", providerName)
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list models: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ShowPoolModel(cmd *Command) (ResponseIf, error) {
|
||||
providerName, ok := cmd.Params["provider_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider_name not provided")
|
||||
}
|
||||
modelName, ok := cmd.Params["model_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_name not provided")
|
||||
}
|
||||
|
||||
var endPoint string
|
||||
if c.ServerType == "admin" {
|
||||
endPoint = fmt.Sprintf("/admin/providers/%s/models/%s", providerName, modelName)
|
||||
} else {
|
||||
endPoint = fmt.Sprintf("/providers/%s/models/%s", providerName, modelName)
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to show model: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to show model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonDataResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to show model: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// readPassword reads password from terminal without echoing
|
||||
func readPassword() (string, error) {
|
||||
// Check if stdin is a terminal by trying to get terminal size
|
||||
if isTerminal() {
|
||||
// Use stty to disable echo
|
||||
cmd := exec.Command("stty", "-echo")
|
||||
cmd.Stdin = os.Stdin
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Fallback: read normally
|
||||
return readPasswordFallback()
|
||||
}
|
||||
defer func() {
|
||||
// Re-enable echo
|
||||
cmd := exec.Command("stty", "echo")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Run()
|
||||
}()
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
password, err := reader.ReadString('\n')
|
||||
fmt.Println() // New line after password input
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(password), nil
|
||||
}
|
||||
|
||||
// Fallback for non-terminal input (e.g., piped input)
|
||||
return readPasswordFallback()
|
||||
}
|
||||
|
||||
// isTerminal checks if stdin is a terminal
|
||||
func isTerminal() bool {
|
||||
var termios syscall.Termios
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
|
||||
return err == 0
|
||||
}
|
||||
|
||||
// readPasswordFallback reads password as plain text (fallback mode)
|
||||
func readPasswordFallback() (string, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
password, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(password), nil
|
||||
}
|
||||
@@ -301,6 +301,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenVectorSize, Value: ident}
|
||||
case "DOC_META":
|
||||
return Token{Type: TokenDocMeta, Value: ident}
|
||||
case "POOL":
|
||||
return Token{Type: TokenPool, Value: ident}
|
||||
default:
|
||||
return Token{Type: TokenIdentifier, Value: ident}
|
||||
}
|
||||
|
||||
262
internal/cli/response.go
Normal file
262
internal/cli/response.go
Normal file
@@ -0,0 +1,262 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package cli
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ResponseIf interface {
|
||||
Type() string
|
||||
PrintOut()
|
||||
TimeCost() float64
|
||||
SetOutputFormat(format OutputFormat)
|
||||
}
|
||||
|
||||
type CommonResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CommonResponse) Type() string {
|
||||
return "common"
|
||||
}
|
||||
|
||||
func (r *CommonResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *CommonResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *CommonResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
PrintTableSimpleByFormat(r.Data, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type CommonDataResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) Type() string {
|
||||
return "show"
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *CommonDataResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
table := make([]map[string]interface{}, 0)
|
||||
table = append(table, r.Data)
|
||||
PrintTableSimpleByFormat(table, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) Type() string {
|
||||
return "simple"
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *SimpleResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Println("SUCCESS")
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) Type() string {
|
||||
return "register"
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *RegisterResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Println("Register successfully")
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type BenchmarkResponse struct {
|
||||
Code int `json:"code"`
|
||||
Duration float64 `json:"duration"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailureCount int `json:"failure_count"`
|
||||
Concurrency int
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) Type() string {
|
||||
return "benchmark"
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) PrintOut() {
|
||||
if r.Code != 0 {
|
||||
fmt.Printf("ERROR, Code: %d\n", r.Code)
|
||||
return
|
||||
}
|
||||
|
||||
iterations := r.SuccessCount + r.FailureCount
|
||||
if r.Concurrency == 1 {
|
||||
if iterations == 1 {
|
||||
fmt.Printf("Latency: %fs\n", r.Duration)
|
||||
} else {
|
||||
fmt.Printf("Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Concurrency: %d, Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Concurrency, r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BenchmarkResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
type KeyValueResponse struct {
|
||||
Code int `json:"code"`
|
||||
Key string `json:"key"`
|
||||
Value string `json:"data"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) Type() string {
|
||||
return "data"
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *KeyValueResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
table := make([]map[string]interface{}, 0)
|
||||
// insert r.key and r.value into table
|
||||
table = append(table, map[string]interface{}{
|
||||
"key": r.Key,
|
||||
"value": r.Value,
|
||||
})
|
||||
PrintTableSimpleByFormat(table, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d\n", r.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== ContextEngine Commands ====================
|
||||
|
||||
// CEListResponse represents the response for ls command
|
||||
type CEListResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CEListResponse) Type() string { return "ce_ls" }
|
||||
func (r *CEListResponse) TimeCost() float64 { return r.Duration }
|
||||
func (r *CEListResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format }
|
||||
func (r *CEListResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
PrintTableSimpleByFormat(r.Data, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// CESearchResponse represents the response for search command
|
||||
type CESearchResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Total int `json:"total"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *CESearchResponse) Type() string { return "ce_search" }
|
||||
func (r *CESearchResponse) TimeCost() float64 { return r.Duration }
|
||||
func (r *CESearchResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format }
|
||||
func (r *CESearchResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Printf("Found %d results:\n", r.Total)
|
||||
PrintTableSimpleByFormat(r.Data, r.outputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
@@ -69,6 +69,7 @@ const (
|
||||
TokenKey
|
||||
TokenKeys
|
||||
TokenGenerate
|
||||
TokenPool
|
||||
TokenModel
|
||||
TokenModels
|
||||
TokenProvider
|
||||
|
||||
@@ -1,8 +1,26 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
ce "ragflow/internal/cli/contextengine"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -732,3 +750,81 @@ func (c *RAGFlowClient) DropDocMetaIndex(cmd *Command) (ResponseIf, error) {
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Context related commands
|
||||
|
||||
// CEList handles the ls command - lists nodes using Context Engine
|
||||
func (c *RAGFlowClient) CEList(cmd *Command) (ResponseIf, error) {
|
||||
// Get path from command params, default to "datasets"
|
||||
path, _ := cmd.Params["path"].(string)
|
||||
if path == "" {
|
||||
path = "datasets"
|
||||
}
|
||||
|
||||
// Parse options
|
||||
opts := &ce.ListOptions{}
|
||||
if recursive, ok := cmd.Params["recursive"].(bool); ok {
|
||||
opts.Recursive = recursive
|
||||
}
|
||||
if limit, ok := cmd.Params["limit"].(int); ok {
|
||||
opts.Limit = limit
|
||||
}
|
||||
if offset, ok := cmd.Params["offset"].(int); ok {
|
||||
opts.Offset = offset
|
||||
}
|
||||
|
||||
// Execute list command through Context Engine
|
||||
ctx := context.Background()
|
||||
result, err := c.ContextEngine.List(ctx, path, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to response
|
||||
var response CEListResponse
|
||||
response.outputFormat = c.OutputFormat
|
||||
response.Code = 0
|
||||
response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat))
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// CESearch handles the search command using Context Engine
|
||||
func (c *RAGFlowClient) CESearch(cmd *Command) (ResponseIf, error) {
|
||||
// Get path and query from command params
|
||||
path, _ := cmd.Params["path"].(string)
|
||||
if path == "" {
|
||||
path = "datasets"
|
||||
}
|
||||
query, _ := cmd.Params["query"].(string)
|
||||
|
||||
// Parse options
|
||||
opts := &ce.SearchOptions{
|
||||
Query: query,
|
||||
}
|
||||
if limit, ok := cmd.Params["limit"].(int); ok {
|
||||
opts.Limit = limit
|
||||
}
|
||||
if offset, ok := cmd.Params["offset"].(int); ok {
|
||||
opts.Offset = offset
|
||||
}
|
||||
if recursive, ok := cmd.Params["recursive"].(bool); ok {
|
||||
opts.Recursive = recursive
|
||||
}
|
||||
|
||||
// Execute search command through Context Engine
|
||||
ctx := context.Background()
|
||||
result, err := c.ContextEngine.Search(ctx, path, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to response
|
||||
var response CESearchResponse
|
||||
response.outputFormat = c.OutputFormat
|
||||
response.Code = 0
|
||||
response.Total = result.Total
|
||||
response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat))
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
@@ -153,6 +153,8 @@ func (p *Parser) parseListCommand() (*Command, error) {
|
||||
return p.parseListModelProviders()
|
||||
case TokenDefault:
|
||||
return p.parseListDefaultModels()
|
||||
case TokenPool:
|
||||
return p.parseCommonListPoolModels()
|
||||
case TokenChats:
|
||||
p.nextToken()
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
@@ -334,6 +336,8 @@ func (p *Parser) parseShowCommand() (*Command, error) {
|
||||
return p.parseShowVariable()
|
||||
case TokenService:
|
||||
return p.parseShowService()
|
||||
case TokenPool:
|
||||
return p.parseCommonShowPoolModel()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
var modelProviderManager *entity.ProviderManager
|
||||
|
||||
// LLMFactoryConfig represents a single LLM factory configuration
|
||||
type LLMFactoryConfig struct {
|
||||
@@ -155,11 +156,17 @@ func InitDB() error {
|
||||
}
|
||||
|
||||
// Run manual migrations for complex schema changes
|
||||
if err := RunMigrations(DB); err != nil {
|
||||
if err = RunMigrations(DB); err != nil {
|
||||
return fmt.Errorf("failed to run manual migrations: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Database connected and migrated successfully")
|
||||
|
||||
modelProviderManager, err = entity.NewProviderManager("conf/models")
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load model providers:", err)
|
||||
}
|
||||
logger.Info("Model providers loaded successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -168,6 +175,11 @@ func GetDB() *gorm.DB {
|
||||
return DB
|
||||
}
|
||||
|
||||
// GetModelProviderManager get database instance
|
||||
func GetModelProviderManager() *entity.ProviderManager {
|
||||
return modelProviderManager
|
||||
}
|
||||
|
||||
// autoMigrateSafely runs AutoMigrate and ignores duplicate index errors
|
||||
// This handles cases where indexes already exist (e.g., created by Python backend)
|
||||
func autoMigrateSafely(db *gorm.DB, model interface{}) error {
|
||||
|
||||
511
internal/entity/model.go
Normal file
511
internal/entity/model.go
Normal file
@@ -0,0 +1,511 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ReasoningSimple represents simple reasoning capability
|
||||
type ReasoningSimple struct {
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Default bool `json:"default"`
|
||||
}
|
||||
|
||||
// ReasoningBudget represents budget-based reasoning capability
|
||||
type ReasoningBudget struct {
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
DefaultTokens int `json:"default_tokens"`
|
||||
TokenRange struct {
|
||||
Min int `json:"min"`
|
||||
Max int `json:"max"`
|
||||
} `json:"token_range"`
|
||||
}
|
||||
|
||||
// ReasoningEffort represents effort-based reasoning capability
|
||||
type ReasoningEffort struct {
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Default string `json:"default"`
|
||||
Options []string `json:"options"`
|
||||
}
|
||||
|
||||
// Reasoning represents the reasoning capability (can be one of three types)
|
||||
type Reasoning struct {
|
||||
Simple *ReasoningSimple `json:"-"`
|
||||
Budget *ReasoningBudget `json:"-"`
|
||||
Effort *ReasoningEffort `json:"-"`
|
||||
RawType string `json:"type"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshal for Reasoning
|
||||
func (r *Reasoning) UnmarshalJSON(data []byte) error {
|
||||
var temp map[string]interface{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
typeVal, ok := temp["type"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("reasoning type is required")
|
||||
}
|
||||
|
||||
r.RawType = typeVal
|
||||
|
||||
switch typeVal {
|
||||
case "simple":
|
||||
var simple ReasoningSimple
|
||||
dataBytes, _ := json.Marshal(temp)
|
||||
if err := json.Unmarshal(dataBytes, &simple); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Simple = &simple
|
||||
case "budget":
|
||||
var budget ReasoningBudget
|
||||
dataBytes, _ := json.Marshal(temp)
|
||||
if err := json.Unmarshal(dataBytes, &budget); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Budget = &budget
|
||||
case "effort":
|
||||
var effort ReasoningEffort
|
||||
dataBytes, _ := json.Marshal(temp)
|
||||
if err := json.Unmarshal(dataBytes, &effort); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Effort = &effort
|
||||
default:
|
||||
return fmt.Errorf("unknown reasoning type: %s", typeVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON custom marshal for Reasoning
|
||||
func (r *Reasoning) MarshalJSON() ([]byte, error) {
|
||||
switch r.RawType {
|
||||
case "simple":
|
||||
if r.Simple != nil {
|
||||
return json.Marshal(r.Simple)
|
||||
}
|
||||
case "budget":
|
||||
if r.Budget != nil {
|
||||
return json.Marshal(r.Budget)
|
||||
}
|
||||
case "effort":
|
||||
if r.Effort != nil {
|
||||
return json.Marshal(r.Effort)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("invalid reasoning state")
|
||||
}
|
||||
|
||||
// Multimodal represents multimodal capability
|
||||
type Multimodal struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
InputModalities []string `json:"input_modalities,omitempty"`
|
||||
OutputModalities []string `json:"output_modalities,omitempty"`
|
||||
}
|
||||
|
||||
// Features represents all features of a model
|
||||
type Features struct {
|
||||
Multimodal *Multimodal `json:"multimodal,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
}
|
||||
|
||||
// Model represents a single LLM model
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
Features Features `json:"features"`
|
||||
}
|
||||
|
||||
// Provider represents an LLM provider
|
||||
type Provider struct {
|
||||
Name string `json:"name"`
|
||||
Tags string `json:"tags"`
|
||||
URL string `json:"url"`
|
||||
Models []Model `json:"models"`
|
||||
}
|
||||
|
||||
// ProviderManager manages provider and model operations
|
||||
type ProviderManager struct {
|
||||
Providers []Provider `json:"model_providers"`
|
||||
}
|
||||
|
||||
// ModelResponse represents the standard response structure
|
||||
type ModelResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewProviderManager creates a new ProviderManager by reading all JSON files from a directory
|
||||
func NewProviderManager(dirPath string) (*ProviderManager, error) {
|
||||
providers := []Provider{}
|
||||
|
||||
// Read all files in the directory
|
||||
files, err := os.ReadDir(dirPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading directory %s: %w", dirPath, err)
|
||||
}
|
||||
|
||||
// Iterate through all files
|
||||
for _, file := range files {
|
||||
// Skip directories
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only process JSON files
|
||||
if !strings.HasSuffix(file.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build full file path
|
||||
filePath := filepath.Join(dirPath, file.Name())
|
||||
|
||||
// Read the file
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var provider Provider
|
||||
if err = json.Unmarshal(data, &provider); err != nil {
|
||||
return nil, fmt.Errorf("error parsing JSON from file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Add to providers list
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return nil, fmt.Errorf("no JSON files found in directory %s", dirPath)
|
||||
}
|
||||
|
||||
return &ProviderManager{
|
||||
Providers: providers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 1. List all providers
|
||||
func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) {
|
||||
|
||||
var providers []map[string]interface{}
|
||||
|
||||
for _, provider := range pm.Providers {
|
||||
providerData := map[string]interface{}{
|
||||
"name": provider.Name,
|
||||
"tags": provider.Tags,
|
||||
"url": provider.URL,
|
||||
}
|
||||
providers = append(providers, providerData)
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return nil, fmt.Errorf("no providers found")
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// 2. Show specific provider information (including base_url)
|
||||
func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]interface{}, error) {
|
||||
|
||||
provider := pm.findProvider(providerName)
|
||||
if provider == nil {
|
||||
return nil, fmt.Errorf("provider '%s' not found", providerName)
|
||||
}
|
||||
|
||||
providerInfo := map[string]interface{}{
|
||||
"name": provider.Name,
|
||||
"tags": provider.Tags,
|
||||
"base_url": provider.URL,
|
||||
"total_models": len(provider.Models),
|
||||
}
|
||||
|
||||
return providerInfo, nil
|
||||
}
|
||||
|
||||
// 3. List models under a specific provider
|
||||
func (pm *ProviderManager) ListModels(providerName string) ([]map[string]interface{}, error) {
|
||||
provider := pm.findProvider(providerName)
|
||||
if provider == nil {
|
||||
return nil, fmt.Errorf("provider '%s' not found", providerName)
|
||||
}
|
||||
|
||||
models := []map[string]interface{}{}
|
||||
for _, model := range provider.Models {
|
||||
modelData := map[string]interface{}{
|
||||
"name": model.Name,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"features": getFeaturesMap(model.Features),
|
||||
}
|
||||
models = append(models, modelData)
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
return nil, fmt.Errorf("no models found")
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (pm *ProviderManager) GetModelByName(providerName, modelName string) (*Model, error) {
|
||||
provider := pm.findProvider(providerName)
|
||||
if provider == nil {
|
||||
return nil, fmt.Errorf("provider '%s' not found", providerName)
|
||||
}
|
||||
model := pm.findModel(provider, modelName)
|
||||
if model == nil {
|
||||
return nil, fmt.Errorf("model '%s' not found", modelName)
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// 4. Search specific model information with filtering by max_tokens or type
|
||||
func (pm *ProviderManager) SearchModelInfo(providerName, modelName string, filterBy string, filterValue interface{}) ModelResponse {
|
||||
resp := ModelResponse{
|
||||
Code: 0,
|
||||
Data: []map[string]interface{}{},
|
||||
Message: "success",
|
||||
}
|
||||
|
||||
provider := pm.findProvider(providerName)
|
||||
if provider == nil {
|
||||
resp.Code = 404
|
||||
resp.Message = fmt.Sprintf("Provider '%s' not found", providerName)
|
||||
return resp
|
||||
}
|
||||
|
||||
model := pm.findModel(provider, modelName)
|
||||
if model == nil {
|
||||
resp.Code = 404
|
||||
resp.Message = fmt.Sprintf("Model '%s' not found in provider '%s'", modelName, providerName)
|
||||
return resp
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
matchFilter := true
|
||||
if filterBy != "" && filterValue != nil {
|
||||
switch filterBy {
|
||||
case "max_tokens":
|
||||
if maxVal, ok := filterValue.(int); ok {
|
||||
if model.MaxTokens < maxVal {
|
||||
matchFilter = false
|
||||
resp.Code = 400
|
||||
resp.Message = fmt.Sprintf("Model does not meet filter criteria: max_tokens (%d) < %d",
|
||||
model.MaxTokens, maxVal)
|
||||
}
|
||||
}
|
||||
case "type":
|
||||
if typeVal, ok := filterValue.(string); ok {
|
||||
if !containsModelType(model.ModelTypes, typeVal) {
|
||||
matchFilter = false
|
||||
resp.Code = 400
|
||||
resp.Message = fmt.Sprintf("Model does not meet filter criteria: type '%s' not found", typeVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matchFilter {
|
||||
modelData := map[string]interface{}{
|
||||
"name": model.Name,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"features": getFeaturesMap(model.Features),
|
||||
}
|
||||
|
||||
if filterBy != "" && filterValue != nil {
|
||||
modelData["filter_applied"] = map[string]interface{}{
|
||||
"field": filterBy,
|
||||
"value": filterValue,
|
||||
}
|
||||
}
|
||||
|
||||
resp.Data = append(resp.Data, modelData)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// 5. Display models with specific features
|
||||
func (pm *ProviderManager) SearchByFeature(featureType string) ModelResponse {
|
||||
resp := ModelResponse{
|
||||
Code: 0,
|
||||
Data: []map[string]interface{}{},
|
||||
Message: "success",
|
||||
}
|
||||
|
||||
for _, provider := range pm.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if modelHasFeature(model.Features, featureType) {
|
||||
modelData := map[string]interface{}{
|
||||
"provider": provider.Name,
|
||||
"name": model.Name,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"features": getFeaturesMap(model.Features),
|
||||
}
|
||||
resp.Data = append(resp.Data, modelData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
resp.Code = 404
|
||||
resp.Message = fmt.Sprintf("No models found with feature '%s'", featureType)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// 6. Display models with specific type
|
||||
func (pm *ProviderManager) SearchByType(modelType string) ModelResponse {
|
||||
resp := ModelResponse{
|
||||
Code: 0,
|
||||
Data: []map[string]interface{}{},
|
||||
Message: "success",
|
||||
}
|
||||
|
||||
for _, provider := range pm.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if containsModelType(model.ModelTypes, modelType) {
|
||||
modelData := map[string]interface{}{
|
||||
"provider": provider.Name,
|
||||
"name": model.Name,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"features": getFeaturesMap(model.Features),
|
||||
}
|
||||
resp.Data = append(resp.Data, modelData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
resp.Code = 404
|
||||
resp.Message = fmt.Sprintf("No models found with type '%s'", modelType)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Helper: Get features map for response
|
||||
func getFeaturesMap(features Features) map[string]interface{} {
|
||||
featuresMap := make(map[string]interface{})
|
||||
|
||||
if features.Multimodal != nil && features.Multimodal.Enabled {
|
||||
multimodalMap := map[string]interface{}{
|
||||
"enabled": features.Multimodal.Enabled,
|
||||
"input_modalities": features.Multimodal.InputModalities,
|
||||
"output_modalities": features.Multimodal.OutputModalities,
|
||||
}
|
||||
featuresMap["multimodal"] = multimodalMap
|
||||
}
|
||||
|
||||
if features.Reasoning != nil {
|
||||
reasoningMap := make(map[string]interface{})
|
||||
switch features.Reasoning.RawType {
|
||||
case "simple":
|
||||
if features.Reasoning.Simple != nil {
|
||||
reasoningMap["type"] = "simple"
|
||||
reasoningMap["enabled"] = features.Reasoning.Simple.Enabled
|
||||
reasoningMap["default"] = features.Reasoning.Simple.Default
|
||||
}
|
||||
case "budget":
|
||||
if features.Reasoning.Budget != nil {
|
||||
reasoningMap["type"] = "budget"
|
||||
reasoningMap["enabled"] = features.Reasoning.Budget.Enabled
|
||||
reasoningMap["default_tokens"] = features.Reasoning.Budget.DefaultTokens
|
||||
reasoningMap["token_range"] = map[string]int{
|
||||
"min": features.Reasoning.Budget.TokenRange.Min,
|
||||
"max": features.Reasoning.Budget.TokenRange.Max,
|
||||
}
|
||||
}
|
||||
case "effort":
|
||||
if features.Reasoning.Effort != nil {
|
||||
reasoningMap["type"] = "effort"
|
||||
reasoningMap["enabled"] = features.Reasoning.Effort.Enabled
|
||||
reasoningMap["default"] = features.Reasoning.Effort.Default
|
||||
reasoningMap["options"] = features.Reasoning.Effort.Options
|
||||
}
|
||||
}
|
||||
featuresMap["reasoning"] = reasoningMap
|
||||
}
|
||||
|
||||
return featuresMap
|
||||
}
|
||||
|
||||
// Helper: Check if model has a specific feature
|
||||
func modelHasFeature(features Features, featureType string) bool {
|
||||
switch strings.ToLower(featureType) {
|
||||
case "multimodal":
|
||||
return features.Multimodal != nil && features.Multimodal.Enabled
|
||||
case "reasoning":
|
||||
return features.Reasoning != nil
|
||||
case "reasoning_simple":
|
||||
return features.Reasoning != nil && features.Reasoning.RawType == "simple"
|
||||
case "reasoning_budget":
|
||||
return features.Reasoning != nil && features.Reasoning.RawType == "budget"
|
||||
case "reasoning_effort":
|
||||
return features.Reasoning != nil && features.Reasoning.RawType == "effort"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: Find provider by name
|
||||
func (pm *ProviderManager) findProvider(name string) *Provider {
|
||||
for i := range pm.Providers {
|
||||
if strings.EqualFold(pm.Providers[i].Name, name) {
|
||||
return &pm.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper: Find model by name
|
||||
func (pm *ProviderManager) findModel(provider *Provider, modelName string) *Model {
|
||||
for i := range provider.Models {
|
||||
if strings.EqualFold(provider.Models[i].Name, modelName) {
|
||||
return &provider.Models[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper: Check if model types contains target
|
||||
func containsModelType(types []string, target string) bool {
|
||||
for _, t := range types {
|
||||
if strings.EqualFold(t, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
123
internal/handler/providers.go
Normal file
123
internal/handler/providers.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ListPoolProviders(c *gin.Context) {
|
||||
providers, err := dao.GetModelProviderManager().ListProviders()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": providers,
|
||||
})
|
||||
}
|
||||
|
||||
func ShowPoolProvider(c *gin.Context) {
|
||||
providerName := c.Param("provider_name")
|
||||
if providerName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := dao.GetModelProviderManager().GetProviderByName(providerName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": provider,
|
||||
})
|
||||
}
|
||||
|
||||
func ListPoolModels(c *gin.Context) {
|
||||
providerName := c.Param("provider_name")
|
||||
if providerName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
models, err := dao.GetModelProviderManager().ListModels(providerName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func ShowPoolModel(c *gin.Context) {
|
||||
providerName := c.Param("provider_name")
|
||||
if providerName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
modelName := c.Param("model_name")
|
||||
if modelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
model, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": model,
|
||||
})
|
||||
}
|
||||
@@ -102,6 +102,15 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// User logout endpoint
|
||||
engine.GET("/v1/user/logout", r.userHandler.Logout)
|
||||
|
||||
// provider pool route group
|
||||
provider := engine.Group("/api/v1/providers")
|
||||
{
|
||||
provider.GET("/", handler.ListPoolProviders)
|
||||
provider.GET("/:provider_name", handler.ShowPoolProvider)
|
||||
provider.GET("/:provider_name/models", handler.ListPoolModels)
|
||||
provider.GET("/:provider_name/models/:model_name", handler.ShowPoolModel)
|
||||
}
|
||||
|
||||
// Protected routes
|
||||
authorized := engine.Group("")
|
||||
authorized.Use(r.authHandler.AuthMiddleware())
|
||||
|
||||
Reference in New Issue
Block a user