From 93d3deb5e425966197575449a88ccb077312251a Mon Sep 17 00:00:00 2001 From: Jake Armstrong <65635253+jakearmstrong59@users.noreply.github.com> Date: Mon, 18 May 2026 01:08:45 -1000 Subject: [PATCH] Fix admin CLI system variable commands (#14956) ## What Fixes #12409. Implements admin CLI support for: - `list vars;` - `show var ;` - `set var ;` ## Changes - Wire Go CLI variable commands to the admin API. - Support integer and quoted string values in `SET VAR`. - Return variable rows as `data_type`, `name`, `setting_type`, and `value`. - Add exact-name lookup with prefix fallback for `SHOW VAR`. - Validate values by stored data type: `string`, `integer`, `bool`, and `json`. - Keep the legacy Python admin CLI/server behavior aligned. - Update admin CLI docs and add focused tests. ## Verification - `go test -count=1 ./internal/cli` - `python3.12 -m py_compile admin/server/services.py admin/server/routes.py api/db/services/system_settings_service.py admin/client/parser.py admin/client/ragflow_client.py` - Python admin CLI parser smoke test for `SET VAR`, quoted values, `SHOW VAR`, and `LIST VARS`. - Attempted `./run_go_tests.sh`; local environment is missing native tokenizer/linker artifacts: - `internal/cpp/cmake-build-release/librag_tokenizer_c_api.a` - `-lstdc++` Co-authored-by: Jin Hai --- admin/client/parser.py | 3 +- admin/client/ragflow_client.py | 15 +- admin/server/routes.py | 4 +- admin/server/services.py | 74 ++++--- api/db/services/system_settings_service.py | 8 +- docs/administrator/admin/ragflow_cli.md | 26 +-- internal/admin/service.go | 102 ++++++--- internal/admin/service_variables_test.go | 65 ++++++ internal/cli/admin_command.go | 136 ++++++++++++ internal/cli/admin_parser.go | 2 +- internal/cli/admin_variables_test.go | 235 +++++++++++++++++++++ internal/cli/client.go | 6 + internal/cli/parser.go | 9 + internal/cli/response.go | 28 +++ internal/cli/user_parser.go | 2 +- internal/dao/system_settings.go | 15 +- 16 files changed, 646 insertions(+), 84 deletions(-) create mode 100644 internal/admin/service_variables_test.go create mode 100644 internal/cli/admin_variables_test.go diff --git a/admin/client/parser.py b/admin/client/parser.py index cdb20b491d..7e668c4e29 100644 --- a/admin/client/parser.py +++ b/admin/client/parser.py @@ -264,7 +264,7 @@ generate_key: GENERATE KEY FOR USER quoted_string ";" list_keys: LIST KEYS OF quoted_string ";" drop_key: DROP KEY quoted_string OF quoted_string ";" -set_variable: SET VAR identifier identifier ";" +set_variable: SET VAR identifier variable_value ";" show_variable: SHOW VAR identifier ";" list_variables: LIST VARS ";" list_configs: LIST CONFIGS ";" @@ -378,6 +378,7 @@ update_chunk: UPDATE CHUNK quoted_string OF DATASET quoted_string SET quoted_str identifier_list: identifier (COMMA identifier)* identifier: WORD +variable_value: WORD | NUMBER | QUOTED_STRING quoted_string: QUOTED_STRING status: ON | WORD diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index 148af4b45f..71a5541bba 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -43,6 +43,12 @@ def encrypt(input_string): return base64.b64encode(cipher_text).decode("utf-8") +def _strip_tree_value(value): + if isinstance(value, Tree): + value = value.children[0] + return str(value).strip("'\"") + + class RAGFlowClient: def __init__(self, http_client: HttpClient, server_type: str): self.http_client = http_client @@ -526,10 +532,8 @@ class RAGFlowClient: if self.server_type != "admin": print("This command is only allowed in ADMIN mode") - var_name_tree: Tree = command["var_name"] - var_name = var_name_tree.children[0].strip("'\"") - var_value_tree: Tree = command["var_value"] - var_value = var_value_tree.children[0].strip("'\"") + var_name = _strip_tree_value(command["var_name"]) + var_value = _strip_tree_value(command["var_value"]) response = self.http_client.request("PUT", "/admin/variables", json_body={"var_name": var_name, "var_value": var_value}, use_api_base=True, auth_kind="admin") @@ -544,8 +548,7 @@ class RAGFlowClient: if self.server_type != "admin": print("This command is only allowed in ADMIN mode") - var_name_tree: Tree = command["var_name"] - var_name = var_name_tree.children[0].strip("'\"") + var_name = _strip_tree_value(command["var_name"]) response = self.http_client.request(method="GET", path="/admin/variables", json_body={"var_name": var_name}, use_api_base=True, auth_kind="admin") res_json = response.json() diff --git a/admin/server/routes.py b/admin/server/routes.py index 658cec48c0..0313d8230b 100644 --- a/admin/server/routes.py +++ b/admin/server/routes.py @@ -421,7 +421,7 @@ def get_user_permission(user_name: str): def set_variable(): try: data = request.get_json() - if not data and "var_name" not in data: + if not data or "var_name" not in data: return error_response("Var name is required", 400) if "var_value" not in data: @@ -449,7 +449,7 @@ def get_variable(): # get var data = request.get_json() - if not data and "var_name" not in data: + if not data or "var_name" not in data: return error_response("Var name is required", 400) var_name: str = data["var_name"] res = SettingsMgr.get_by_name(var_name) diff --git a/admin/server/services.py b/admin/server/services.py index 43646d7918..cc68f98dc3 100644 --- a/admin/server/services.py +++ b/admin/server/services.py @@ -330,36 +330,65 @@ class ServiceMgr: class SettingsMgr: + @staticmethod + def _format_setting(setting): + return { + "data_type": setting.data_type, + "name": setting.name, + "setting_type": "config", + "value": setting.value, + } + + @staticmethod + def _validate_value(name: str, data_type: str, value: str): + data_type = data_type.lower() + value = str(value) + if data_type == "string": + return + if data_type == "integer": + try: + int(value) + except ValueError: + raise AdminException(f"Invalid integer value for {name}: {value}") + return + if data_type in {"bool", "boolean"}: + if value not in {"true", "false"}: + raise AdminException(f"Invalid bool value for {name}: expected true or false") + return + if data_type == "json": + try: + json.loads(value) + except json.JSONDecodeError: + raise AdminException(f"Invalid JSON value for {name}") + return + raise AdminException(f"Unsupported data type for {name}: {data_type}") + + @staticmethod + def _infer_data_type(name: str): + if name.startswith("sandbox."): + return "json" + if name.endswith(".enabled"): + return "bool" + return "string" + @staticmethod def get_all(): - settings = SystemSettingsService.get_all() + settings = SystemSettingsService.get_all(reverse=False, order_by="name") result = [] for setting in settings: - result.append( - { - "name": setting.name, - "source": setting.source, - "data_type": setting.data_type, - "value": setting.value, - } - ) + result.append(SettingsMgr._format_setting(setting)) return result @staticmethod def get_by_name(name: str): settings = SystemSettingsService.get_by_name(name) if len(settings) == 0: - raise AdminException(f"Can't get setting: {name}") + settings = SystemSettingsService.get_by_name_prefix(name) + if len(settings) == 0: + raise AdminException(f"Can't get setting: {name}") result = [] for setting in settings: - result.append( - { - "name": setting.name, - "source": setting.source, - "data_type": setting.data_type, - "value": setting.value, - } - ) + result.append(SettingsMgr._format_setting(setting)) return result @staticmethod @@ -367,6 +396,7 @@ class SettingsMgr: settings = SystemSettingsService.get_by_name(name) if len(settings) == 1: setting = settings[0] + SettingsMgr._validate_value(name, setting.data_type, value) setting.value = value setting_dict = setting.to_dict() SystemSettingsService.update_by_name(name, setting_dict) @@ -376,12 +406,8 @@ class SettingsMgr: # Create new setting if it doesn't exist # Determine data_type based on name and value - if name.startswith("sandbox."): - data_type = "json" - elif name.endswith(".enabled"): - data_type = "boolean" - else: - data_type = "string" + data_type = SettingsMgr._infer_data_type(name) + SettingsMgr._validate_value(name, data_type, value) new_setting = { "name": name, diff --git a/api/db/services/system_settings_service.py b/api/db/services/system_settings_service.py index eac7019e6a..0b0bde8024 100644 --- a/api/db/services/system_settings_service.py +++ b/api/db/services/system_settings_service.py @@ -26,7 +26,13 @@ class SystemSettingsService(CommonService): @classmethod @DB.connection_context() def get_by_name(cls, name): - objs = cls.model.select().where(cls.model.name == name) + objs = cls.model.select().where(cls.model.name == name).order_by(cls.model.name.asc()) + return objs + + @classmethod + @DB.connection_context() + def get_by_name_prefix(cls, name_prefix): + objs = cls.model.select().where(cls.model.name.startswith(name_prefix)).order_by(cls.model.name.asc()) return objs @classmethod diff --git a/docs/administrator/admin/ragflow_cli.md b/docs/administrator/admin/ragflow_cli.md index 6c7bc5943c..682105116c 100644 --- a/docs/administrator/admin/ragflow_cli.md +++ b/docs/administrator/admin/ragflow_cli.md @@ -468,18 +468,18 @@ Revoke successfully! ``` ragflow> list vars; +-----------+---------------------+--------------+-----------+ -| data_type | name | source | value | +| data_type | name | setting_type | value | +-----------+---------------------+--------------+-----------+ -| string | default_role | variable | user | -| bool | enable_whitelist | variable | true | -| string | mail.default_sender | variable | | -| string | mail.password | variable | | -| integer | mail.port | variable | 15 | -| string | mail.server | variable | localhost | -| integer | mail.timeout | variable | 10 | -| bool | mail.use_ssl | variable | true | -| bool | mail.use_tls | variable | false | -| string | mail.username | variable | | +| string | default_role | config | user | +| bool | enable_whitelist | config | true | +| string | mail.default_sender | config | | +| string | mail.password | config | | +| integer | mail.port | config | 15 | +| string | mail.server | config | localhost | +| integer | mail.timeout | config | 10 | +| bool | mail.use_ssl | config | true | +| bool | mail.use_tls | config | false | +| string | mail.username | config | | +-----------+---------------------+--------------+-----------+ ``` @@ -490,9 +490,9 @@ ragflow> list vars; ``` ragflow> show var mail.server; +-----------+-------------+--------------+-----------+ -| data_type | name | source | value | +| data_type | name | setting_type | value | +-----------+-------------+--------------+-----------+ -| string | mail.server | variable | localhost | +| string | mail.server | config | localhost | +-----------+-------------+--------------+-----------+ ``` diff --git a/internal/admin/service.go b/internal/admin/service.go index b857ac59aa..4b30b4f26c 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "encoding/base64" "encoding/hex" + "encoding/json" "errors" "fmt" "net/http" @@ -1451,9 +1452,59 @@ func NewAdminException(message string) *AdminException { } } +func formatSystemSetting(setting entity.SystemSettings) map[string]interface{} { + return map[string]interface{}{ + "data_type": setting.DataType, + "name": setting.Name, + "setting_type": "config", + "value": setting.Value, + } +} + +func formatSystemSettings(settings []entity.SystemSettings) []map[string]interface{} { + result := make([]map[string]interface{}, 0, len(settings)) + for _, setting := range settings { + result = append(result, formatSystemSetting(setting)) + } + return result +} + +func validateSystemSettingValue(setting entity.SystemSettings, value string) error { + dataType := strings.ToLower(setting.DataType) + switch dataType { + case "string": + return nil + case "integer", "int": + if _, err := strconv.Atoi(value); err != nil { + return NewAdminException(fmt.Sprintf("Invalid integer value for %s: %s", setting.Name, value)) + } + case "bool", "boolean": + if value != "true" && value != "false" { + return NewAdminException(fmt.Sprintf("Invalid bool value for %s: expected true or false", setting.Name)) + } + case "json": + if !json.Valid([]byte(value)) { + return NewAdminException(fmt.Sprintf("Invalid JSON value for %s", setting.Name)) + } + default: + return NewAdminException(fmt.Sprintf("Unsupported data type for %s: %s", setting.Name, setting.DataType)) + } + return nil +} + +func inferSystemSettingDataType(name string) string { + if strings.HasPrefix(name, "sandbox.") { + return "json" + } + if strings.HasSuffix(name, ".enabled") { + return "bool" + } + return "string" +} + // GetVariable get variable by name -// Returns the system setting with the given name -// Returns AdminException if the setting is not found +// Returns the exact system setting with the given name, or settings matching the +// given name prefix when an exact setting does not exist. func (s *Service) GetVariable(varName string) ([]map[string]interface{}, error) { settings, err := s.systemSettingsDAO.GetByName(varName) if err != nil { @@ -1461,19 +1512,15 @@ func (s *Service) GetVariable(varName string) ([]map[string]interface{}, error) } if len(settings) == 0 { - return nil, NewAdminException("Can't get setting: " + varName) + settings, err = s.systemSettingsDAO.GetByNamePrefix(varName) + if err != nil { + return nil, err + } + if len(settings) == 0 { + return nil, NewAdminException("Can't get setting: " + varName) + } } - - result := make([]map[string]interface{}, 0, len(settings)) - for _, setting := range settings { - result = append(result, map[string]interface{}{ - "name": setting.Name, - "source": setting.Source, - "data_type": setting.DataType, - "value": setting.Value, - }) - } - return result, nil + return formatSystemSettings(settings), nil } // GetAllVariables get all variables @@ -1484,16 +1531,7 @@ func (s *Service) GetAllVariables() ([]map[string]interface{}, error) { return nil, err } - result := make([]map[string]interface{}, 0, len(settings)) - for _, setting := range settings { - result = append(result, map[string]interface{}{ - "name": setting.Name, - "source": setting.Source, - "data_type": setting.DataType, - "value": setting.Value, - }) - } - return result, nil + return formatSystemSettings(settings), nil } // SetVariable set variable @@ -1507,27 +1545,25 @@ func (s *Service) SetVariable(varName, varValue string) error { if len(settings) == 1 { setting := &settings[0] + if err := validateSystemSettingValue(*setting, varValue); err != nil { + return err + } setting.Value = varValue return s.systemSettingsDAO.UpdateByName(varName, setting) } else if len(settings) > 1 { return NewAdminException("Can't update more than 1 setting: " + varName) } - // Create new setting if it doesn't exist - // Determine data_type based on name and value - dataType := "string" - if len(varName) >= 7 && varName[:7] == "sandbox" { - dataType = "json" - } else if len(varName) >= 9 && varName[len(varName)-9:] == ".enabled" { - dataType = "boolean" - } - + dataType := inferSystemSettingDataType(varName) newSetting := &entity.SystemSettings{ Name: varName, Value: varValue, Source: "admin", DataType: dataType, } + if err := validateSystemSettingValue(*newSetting, varValue); err != nil { + return err + } return s.systemSettingsDAO.Create(newSetting) } diff --git a/internal/admin/service_variables_test.go b/internal/admin/service_variables_test.go new file mode 100644 index 0000000000..2b94a09088 --- /dev/null +++ b/internal/admin/service_variables_test.go @@ -0,0 +1,65 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package admin + +import ( + "ragflow/internal/entity" + "testing" +) + +func TestValidateSystemSettingValue(t *testing.T) { + tests := []struct { + name string + dataType string + value string + wantError bool + }{ + {name: "string accepts arbitrary text", dataType: "string", value: "local host"}, + {name: "integer accepts digits", dataType: "integer", value: "15"}, + {name: "integer rejects text", dataType: "integer", value: "localhost", wantError: true}, + {name: "bool accepts true", dataType: "bool", value: "true"}, + {name: "bool accepts false", dataType: "bool", value: "false"}, + {name: "bool rejects non bool", dataType: "bool", value: "yes", wantError: true}, + {name: "json accepts object", dataType: "json", value: `{"endpoint":"http://localhost:9385"}`}, + {name: "json rejects invalid", dataType: "json", value: "{", wantError: true}, + {name: "unknown type rejects", dataType: "float", value: "1.2", wantError: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setting := entity.SystemSettings{Name: "test.setting", DataType: tt.dataType} + err := validateSystemSettingValue(setting, tt.value) + if (err != nil) != tt.wantError { + t.Fatalf("validateSystemSettingValue() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestInferSystemSettingDataType(t *testing.T) { + tests := map[string]string{ + "sandbox.self_managed": "json", + "mail.enabled": "bool", + "mail.server": "string", + } + + for name, want := range tests { + if got := inferSystemSettingDataType(name); got != want { + t.Fatalf("inferSystemSettingDataType(%q) = %q, want %q", name, got, want) + } + } +} diff --git a/internal/cli/admin_command.go b/internal/cli/admin_command.go index f6ab603af5..d1c3763601 100644 --- a/internal/cli/admin_command.go +++ b/internal/cli/admin_command.go @@ -596,6 +596,142 @@ func (c *RAGFlowClient) ShowService(cmd *Command) (ResponseIf, error) { return &result, nil } +func normalizeVariableRows(rows []map[string]interface{}) { + for _, row := range rows { + if _, ok := row["setting_type"]; ok { + delete(row, "source") + continue + } + if _, ok := row["source"]; ok { + row["setting_type"] = "config" + delete(row, "source") + } + } +} + +// ListVariables lists all system variables (admin mode only). +func (c *RAGFlowClient) ListVariables(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + return c.HTTPClient.RequestWithIterations("GET", "/admin/variables", "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/variables", "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list variables: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list variables: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list variables failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + normalizeVariableRows(result.Data) + result.Duration = resp.Duration + return &result, nil +} + +// ShowVariable shows system variables by exact name or name prefix (admin mode only). +func (c *RAGFlowClient) ShowVariable(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + varName, ok := cmd.Params["var_name"].(string) + if !ok { + return nil, fmt.Errorf("var_name not provided") + } + + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + payload := map[string]interface{}{"var_name": varName} + if iterations > 1 { + return c.HTTPClient.RequestWithIterations("GET", "/admin/variables", "admin", nil, payload, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/variables", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to show variable: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show variable: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show variable failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + normalizeVariableRows(result.Data) + result.Duration = resp.Duration + return &result, nil +} + +// SetVariable updates a system variable (admin mode only). +func (c *RAGFlowClient) SetVariable(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + varName, ok := cmd.Params["var_name"].(string) + if !ok { + return nil, fmt.Errorf("var_name not provided") + } + varValue, ok := cmd.Params["var_value"].(string) + if !ok { + return nil, fmt.Errorf("var_value not provided") + } + + payload := map[string]interface{}{ + "var_name": varName, + "var_value": varValue, + } + resp, err := c.HTTPClient.Request("PUT", "/admin/variables", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to set variable: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to set variable: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result MessageResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("set variable failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + // ListUsers lists all users (admin mode only) // Returns (result_map, error) - result_map is non-nil for benchmark mode func (c *RAGFlowClient) ListUsers(cmd *Command) (ResponseIf, error) { diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index c1b2edab5a..9f4c6228e8 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -1173,7 +1173,7 @@ func (p *Parser) parseAdminSetVariable() (*Command, error) { } p.nextToken() - varValue, err := p.parseIdentifier() + varValue, err := p.parseVariableValue() if err != nil { return nil, err } diff --git a/internal/cli/admin_variables_test.go b/internal/cli/admin_variables_test.go new file mode 100644 index 0000000000..c871bca18e --- /dev/null +++ b/internal/cli/admin_variables_test.go @@ -0,0 +1,235 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +import ( + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" +) + +func TestParseAdminVariableCommands(t *testing.T) { + tests := []struct { + name string + input string + command string + varName string + varValue string + hasValue bool + adminMode bool + }{ + { + name: "list variables", + input: "list vars;", + command: "list_variables", + adminMode: true, + }, + { + name: "show variables by prefix", + input: "show var mail;", + command: "show_variable", + varName: "mail", + adminMode: true, + }, + { + name: "set integer variable", + input: "set var mail.port 15;", + command: "set_variable", + varName: "mail.port", + varValue: "15", + hasValue: true, + adminMode: true, + }, + { + name: "set quoted string variable", + input: `set var mail.server "local host";`, + command: "set_variable", + varName: "mail.server", + varValue: "local host", + hasValue: true, + adminMode: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, err := NewParser(tt.input).Parse(tt.adminMode) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cmd.Type != tt.command { + t.Fatalf("command type = %q, want %q", cmd.Type, tt.command) + } + if tt.varName != "" && cmd.Params["var_name"] != tt.varName { + t.Fatalf("var_name = %v, want %q", cmd.Params["var_name"], tt.varName) + } + if tt.hasValue && cmd.Params["var_value"] != tt.varValue { + t.Fatalf("var_value = %v, want %q", cmd.Params["var_value"], tt.varValue) + } + }) + } +} + +func newAdminTestClient(t *testing.T, handler http.HandlerFunc) (*RAGFlowClient, func()) { + t.Helper() + + server := httptest.NewServer(handler) + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("parse test server URL: %v", err) + } + host, portText, err := net.SplitHostPort(serverURL.Host) + if err != nil { + t.Fatalf("split host port: %v", err) + } + port, err := strconv.Atoi(portText) + if err != nil { + t.Fatalf("parse port: %v", err) + } + + client := NewRAGFlowClient("admin") + client.HTTPClient.Host = host + client.HTTPClient.Port = port + client.HTTPClient.client = server.Client() + client.HTTPClient.LoginToken = "test-token" + + return client, server.Close +} + +func TestListVariablesUsesAdminVariablesEndpoint(t *testing.T) { + client, closeServer := newAdminTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + return + } + if r.URL.Path != "/api/v1/admin/variables" { + t.Errorf("path = %s, want /api/v1/admin/variables", r.URL.Path) + return + } + if r.Header.Get("Authorization") != "test-token" { + t.Errorf("Authorization header = %q", r.Header.Get("Authorization")) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 0, + "message": "", + "data": []map[string]interface{}{ + {"data_type": "string", "name": "mail.server", "source": "variable", "value": "localhost"}, + }, + }) + }) + defer closeServer() + + resp, err := client.ListVariables(NewCommand("list_variables")) + if err != nil { + t.Fatalf("ListVariables() error = %v", err) + } + result := resp.(*CommonResponse) + if got := result.Data[0]["setting_type"]; got != "config" { + t.Fatalf("setting_type = %v, want config", got) + } + if _, ok := result.Data[0]["source"]; ok { + t.Fatalf("source column should be normalized away: %#v", result.Data[0]) + } +} + +func TestShowVariableSendsRequestedName(t *testing.T) { + client, closeServer := newAdminTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var request map[string]string + if err := json.Unmarshal(body, &request); err != nil { + t.Errorf("request body is not JSON: %v", err) + return + } + if request["var_name"] != "mail" { + t.Errorf("var_name = %q, want mail", request["var_name"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 0, + "message": "", + "data": []map[string]interface{}{ + {"data_type": "string", "name": "mail.server", "setting_type": "config", "value": "localhost"}, + }, + }) + }) + defer closeServer() + + cmd := NewCommand("show_variable") + cmd.Params["var_name"] = "mail" + resp, err := client.ShowVariable(cmd) + if err != nil { + t.Fatalf("ShowVariable() error = %v", err) + } + result := resp.(*CommonResponse) + if got := result.Data[0]["name"]; got != "mail.server" { + t.Fatalf("name = %v, want mail.server", got) + } +} + +func TestSetVariableReturnsServerConfirmation(t *testing.T) { + client, closeServer := newAdminTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + t.Errorf("method = %s, want PUT", r.Method) + return + } + var request map[string]string + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + t.Errorf("request body is not JSON: %v", err) + return + } + if request["var_name"] != "mail.server" { + t.Errorf("var_name = %q, want mail.server", request["var_name"]) + return + } + if request["var_value"] != "localhost" { + t.Errorf("var_value = %q, want localhost", request["var_value"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 0, + "message": "Set variable successfully", + "data": nil, + }) + }) + defer closeServer() + + cmd := NewCommand("set_variable") + cmd.Params["var_name"] = "mail.server" + cmd.Params["var_value"] = "localhost" + resp, err := client.SetVariable(cmd) + if err != nil { + t.Fatalf("SetVariable() error = %v", err) + } + result := resp.(*MessageResponse) + if result.Message != "Set variable successfully" { + t.Fatalf("message = %q, want Set variable successfully", result.Message) + } +} diff --git a/internal/cli/client.go b/internal/cli/client.go index 679e1d19b5..eebfb33d0d 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -155,6 +155,12 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { return c.ShowAdminVersion(cmd) case "show_user": return c.ShowUser(cmd) + case "list_variables": + return c.ListVariables(cmd) + case "show_variable": + return c.ShowVariable(cmd) + case "set_variable": + return c.SetVariable(cmd) case "list_user_datasets": return c.ListUserDatasets(cmd) case "list_agents": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index 035d6b12e5..02723f5467 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -285,6 +285,15 @@ func (p *Parser) parseIdentifier() (string, error) { return p.curToken.Value, nil } +func (p *Parser) parseVariableValue() (string, error) { + switch p.curToken.Type { + case TokenIdentifier, TokenQuotedString, TokenInteger, TokenFloat: + return p.curToken.Value, nil + default: + return "", fmt.Errorf("expected variable value, got %s", p.curToken.Value) + } +} + func (p *Parser) parseNumber() (int, error) { if p.curToken.Type != TokenInteger { return 0, fmt.Errorf("expected number, got %s", p.curToken.Value) diff --git a/internal/cli/response.go b/internal/cli/response.go index 440345be9d..19daffc9e6 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -149,6 +149,34 @@ func (r *SimpleResponse) PrintOut() { } } +type MessageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat +} + +func (r *MessageResponse) Type() string { + return "message" +} + +func (r *MessageResponse) TimeCost() float64 { + return r.Duration +} + +func (r *MessageResponse) SetOutputFormat(format OutputFormat) { + r.OutputFormat = format +} + +func (r *MessageResponse) PrintOut() { + if r.Code == 0 { + fmt.Println(r.Message) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + type NonStreamResponse struct { Code int `json:"code"` ReasoningContent string `json:"reasoning_content"` diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 38b2d2d29d..36b7a35297 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -1875,7 +1875,7 @@ func (p *Parser) parseSetVariable() (*Command, error) { } p.nextToken() - varValue, err := p.parseIdentifier() + varValue, err := p.parseVariableValue() if err != nil { return nil, err } diff --git a/internal/dao/system_settings.go b/internal/dao/system_settings.go index c224ad0fcd..dd7762e42e 100644 --- a/internal/dao/system_settings.go +++ b/internal/dao/system_settings.go @@ -35,7 +35,7 @@ func NewSystemSettingsDAO() *SystemSettingsDAO { // Returns all system settings records from database func (d *SystemSettingsDAO) GetAll() ([]entity.SystemSettings, error) { var settings []entity.SystemSettings - err := DB.Find(&settings).Error + err := DB.Order("name ASC").Find(&settings).Error if err != nil { return nil, err } @@ -46,7 +46,18 @@ func (d *SystemSettingsDAO) GetAll() ([]entity.SystemSettings, error) { // Returns settings records that match the given name func (d *SystemSettingsDAO) GetByName(name string) ([]entity.SystemSettings, error) { var settings []entity.SystemSettings - err := DB.Where("name = ?", name).Find(&settings).Error + err := DB.Where("name = ?", name).Order("name ASC").Find(&settings).Error + if err != nil { + return nil, err + } + return settings, nil +} + +// GetByNamePrefix get system settings by name prefix +// Returns settings records whose names start with the given prefix. +func (d *SystemSettingsDAO) GetByNamePrefix(namePrefix string) ([]entity.SystemSettings, error) { + var settings []entity.SystemSettings + err := DB.Where("name LIKE ?", namePrefix+"%").Order("name ASC").Find(&settings).Error if err != nil { return nil, err }