mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Implement UpdateDataset and UpdateMetadata in GO (#13928)
### What problem does this PR solve? Implement UpdateDataset and UpdateMetadata in Go. Add CLI commands: (1) UPDATE CHUNK <chunk_id> OF DATASET <dataset_name> SET <update_fields>; (2) REMOVE TAGS 'tag1', 'tag2' FROM DATASET 'dataset_name'; (3) SET METADATA OF DOCUMENT <doc_id> TO <meta>; ### Type of change - [ ] Refactoring
This commit is contained in:
@@ -99,6 +99,9 @@ sql_command: login_user
|
||||
| list_chunks
|
||||
| insert_dataset_from_file
|
||||
| insert_metadata_from_file
|
||||
| update_chunk
|
||||
| set_metadata
|
||||
| remove_tags
|
||||
| create_chat_session
|
||||
| drop_chat_session
|
||||
| list_chat_sessions
|
||||
@@ -114,10 +117,12 @@ sql_command: login_user
|
||||
// meta command definition
|
||||
meta_command: "\\" meta_command_name [meta_args]
|
||||
|
||||
COMMA: ","
|
||||
|
||||
meta_command_name: /[a-zA-Z?]+/
|
||||
meta_args: (meta_arg)+
|
||||
|
||||
meta_arg: /[^\\s"']+/ | quoted_string
|
||||
meta_arg: /[^\s"',]+/ | quoted_string
|
||||
|
||||
// command definition
|
||||
|
||||
@@ -215,8 +220,11 @@ SIZE: "SIZE"i
|
||||
KEYWORDS: "KEYWORDS"i
|
||||
AVAILABLE: "AVAILABLE"i
|
||||
FILE: "FILE"i
|
||||
UPDATE: "UPDATE"i
|
||||
REMOVE: "REMOVE"i
|
||||
TAGS: "TAGS"i
|
||||
|
||||
login_user: LOGIN USER quoted_string ";"
|
||||
login_user: LOGIN USER quoted_string (PASSWORD quoted_string)? ";"
|
||||
list_services: LIST SERVICES ";"
|
||||
show_service: SHOW SERVICE NUMBER ";"
|
||||
startup_service: STARTUP SERVICE NUMBER ";"
|
||||
@@ -299,6 +307,9 @@ user_statement: ping_server
|
||||
| list_user_default_models
|
||||
| import_docs_into_dataset
|
||||
| search_on_datasets
|
||||
| update_chunk
|
||||
| set_metadata
|
||||
| remove_tags
|
||||
| create_chat_session
|
||||
| drop_chat_session
|
||||
| list_chat_sessions
|
||||
@@ -328,8 +339,8 @@ create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING q
|
||||
drop_user_dataset: DROP DATASET quoted_string ";"
|
||||
list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
|
||||
list_user_dataset_documents: LIST DOCUMENTS OF DATASET quoted_string ";"
|
||||
list_user_datasets_metadata: LIST METADATA OF DATASETS quoted_string ("," quoted_string)* ";"
|
||||
list_user_documents_metadata_summary: LIST METADATA SUMMARY OF DATASET quoted_string (DOCUMENTS quoted_string ("," quoted_string)*)? ";"
|
||||
list_user_datasets_metadata: LIST METADATA OF DATASETS quoted_string (COMMA quoted_string)* ";"
|
||||
list_user_documents_metadata_summary: LIST METADATA SUMMARY OF DATASET quoted_string (DOCUMENTS quoted_string (COMMA quoted_string)*)? ";"
|
||||
list_user_agents: LIST AGENTS ";"
|
||||
list_user_chats: LIST CHATS ";"
|
||||
create_user_chat: CREATE CHAT quoted_string ";"
|
||||
@@ -353,11 +364,15 @@ parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
|
||||
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
|
||||
parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
|
||||
|
||||
update_chunk: UPDATE CHUNK quoted_string OF DATASET quoted_string SET quoted_string ";"
|
||||
set_metadata: SET METADATA OF DOCUMENT quoted_string TO quoted_string ";"
|
||||
remove_tags: REMOVE TAGS quoted_string (COMMA quoted_string)* FROM DATASET quoted_string ";"
|
||||
|
||||
// Internal CLI for GO
|
||||
insert_dataset_from_file: INSERT DATASET FROM FILE quoted_string ";"
|
||||
insert_metadata_from_file: INSERT METADATA FROM FILE quoted_string ";"
|
||||
|
||||
identifier_list: identifier ("," identifier)*
|
||||
identifier_list: identifier (COMMA identifier)*
|
||||
|
||||
identifier: WORD
|
||||
quoted_string: QUOTED_STRING
|
||||
@@ -381,7 +396,13 @@ class RAGFlowCLITransformer(Transformer):
|
||||
|
||||
def login_user(self, items):
|
||||
email = items[2].children[0].strip("'\"")
|
||||
return {"type": "login_user", "email": email}
|
||||
if len(items) == 5:
|
||||
# With password: LOGIN USER email PASSWORD password
|
||||
password = items[4].children[0].strip("'\"")
|
||||
return {"type": "login_user", "email": email, "password": password}
|
||||
else:
|
||||
# Without password: LOGIN USER email
|
||||
return {"type": "login_user", "email": email}
|
||||
|
||||
def ping_server(self, items):
|
||||
return {"type": "ping_server"}
|
||||
@@ -766,6 +787,44 @@ class RAGFlowCLITransformer(Transformer):
|
||||
file_path = items[4].children[0].strip("'\"")
|
||||
return {"type": "insert_metadata_from_file", "file_path": file_path}
|
||||
|
||||
def update_chunk(self, items):
    """Build the command dict for an UPDATE CHUNK statement.

    The chunk id, dataset name and update body are read from positions
    2, 5 and 7 of ``items``; surrounding quote characters are stripped
    from each value.
    """
    def unquote(node):
        # Nodes with a non-empty ``children`` list carry the raw text as
        # their first child; anything else is stringified directly.
        has_payload = hasattr(node, 'children') and node.children
        raw = node.children[0] if has_payload else str(node)
        return raw.strip("'\"")

    chunk_id, dataset_name, json_body = (unquote(items[pos]) for pos in (2, 5, 7))
    return {
        "type": "update_chunk",
        "chunk_id": chunk_id,
        "dataset_name": dataset_name,
        "json_body": json_body,
    }
|
||||
|
||||
def set_metadata(self, items):
    """Build the command dict for SET METADATA OF DOCUMENT ... TO ...

    The document id is taken from ``items[4]`` and the metadata text from
    ``items[6]``; both have surrounding quote characters stripped.
    """
    def unquote(node):
        return node.children[0].strip("'\"")

    doc_id = unquote(items[4])
    meta_text = unquote(items[6])
    return {"type": "set_metadata", "doc_id": doc_id, "meta": meta_text}
|
||||
|
||||
def remove_tags(self, items):
    """Build the command dict for REMOVE TAGS ... FROM DATASET ...

    Tags are the quoted strings found from index 2 up to the FROM token;
    the dataset name is the quoted string right after the DATASET token
    (``None`` when no DATASET token is present).
    """
    def is_token(node, kind):
        return hasattr(node, 'type') and node.type == kind

    # Collect every tree-like item between TAGS and FROM; separator
    # tokens have no children and are skipped.
    tags = []
    for node in items[2:]:
        if is_token(node, 'FROM'):
            break
        if hasattr(node, 'children') and node.children:
            tags.append(node.children[0].strip("'\""))

    # The dataset name follows the first DATASET token.
    dataset_name = None
    for pos, node in enumerate(items):
        if is_token(node, 'DATASET'):
            dataset_name = items[pos + 1].children[0].strip("'\"")
            break

    return {"type": "remove_tags", "dataset_name": dataset_name, "tags": tags}
|
||||
|
||||
def list_chunks(self, items):
|
||||
doc_id = items[4].children[0].strip("'\"")
|
||||
result = {"type": "list_chunks", "doc_id": doc_id}
|
||||
|
||||
@@ -18,6 +18,9 @@ import sys
|
||||
import argparse
|
||||
import base64
|
||||
import getpass
|
||||
import os
|
||||
import atexit
|
||||
import readline
|
||||
from cmd import Cmd
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -61,6 +64,12 @@ class RAGFlowCLI(Cmd):
|
||||
self.port: int = 0
|
||||
self.mode: str = "admin"
|
||||
self.ragflow_client = None
|
||||
# History file for readline persistence
|
||||
self.history_file = os.path.expanduser("~/.ragflow_cli_history")
|
||||
# Load existing history
|
||||
self._load_history()
|
||||
# Register cleanup to save history on exit
|
||||
atexit.register(self._save_history)
|
||||
|
||||
intro = r"""Type "\h" for help."""
|
||||
prompt = "ragflow> "
|
||||
@@ -99,6 +108,7 @@ class RAGFlowCLI(Cmd):
|
||||
return {"type": "empty"}
|
||||
|
||||
self.command_history.append(command_str)
|
||||
readline.add_history(command_str)
|
||||
|
||||
try:
|
||||
result = self.parser.parse(command_str)
|
||||
@@ -210,6 +220,21 @@ class RAGFlowCLI(Cmd):
|
||||
|
||||
print(separator)
|
||||
|
||||
def _load_history(self):
    """Read readline history from ``self.history_file`` when that file exists."""
    try:
        history_path = self.history_file
        if os.path.exists(history_path):
            readline.read_history_file(history_path)
    except Exception:
        # Best effort: any failure while loading history is ignored.
        return None
|
||||
|
||||
def _save_history(self):
    """Write the current readline history to ``self.history_file``."""
    try:
        history_path = self.history_file
        readline.write_history_file(history_path)
    except Exception:
        # Best effort: any failure while saving history is ignored.
        return None
|
||||
|
||||
def run_interactive(self, args):
|
||||
if self.verify_auth(args, single_command=False, auth=args["auth"]):
|
||||
print(r"""
|
||||
|
||||
@@ -24,7 +24,6 @@ from http_client import HttpClient
|
||||
from lark import Tree
|
||||
from user import encrypt_password, login_user
|
||||
|
||||
import getpass
|
||||
import base64
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from Cryptodome.PublicKey import RSA
|
||||
@@ -63,10 +62,16 @@ class RAGFlowClient:
|
||||
return
|
||||
|
||||
email: str = command["email"]
|
||||
user_password = getpass.getpass(f"password for {email}: ").strip()
|
||||
user_password: str = command.get("password")
|
||||
if not user_password:
|
||||
import getpass
|
||||
user_password = getpass.getpass("Password: ")
|
||||
try:
|
||||
token = login_user(self.http_client, self.server_type, email, user_password)
|
||||
self.http_client.login_token = token
|
||||
# Also store as api_key for API endpoint authentication
|
||||
if self.server_type == "user":
|
||||
self.http_client.api_key = token
|
||||
print(f"Login user {email} successfully")
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
@@ -1506,6 +1511,108 @@ class RAGFlowClient:
|
||||
else:
|
||||
print(f"Fail to insert metadata from file, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def update_chunk(self, command_dict):
    """Update one chunk of a dataset (USER mode only).

    ``command_dict`` must contain ``chunk_id``, ``dataset_name`` and
    ``json_body`` (a JSON document given as text). The method:

    1. parses ``json_body`` and stops on malformed JSON before any request
       is made;
    2. resolves the dataset id via ``self._get_dataset_id``;
    3. looks up the chunk's ``doc_id`` with ``GET /chunk/get``;
    4. sends the parsed body with
       ``PUT /datasets/{dataset_id}/documents/{doc_id}/chunks/{chunk_id}``.

    Outcomes are reported with ``print``; the method always returns ``None``.
    """
    if self.server_type != "user":
        print("This command is only allowed in USER mode")
        return

    chunk_id = command_dict["chunk_id"]
    dataset_name = command_dict["dataset_name"]
    json_body_str = command_dict["json_body"]

    # Validate the user-supplied body first: a typo in the JSON should not
    # cost a dataset lookup and a chunk lookup before it is reported.
    try:
        payload = json.loads(json_body_str)
    except json.JSONDecodeError as e:
        print(f"Invalid JSON body: {e}")
        return

    # Get dataset_id from dataset_name
    dataset_id = self._get_dataset_id(dataset_name)
    if dataset_id is None:
        return

    # The update endpoint is addressed by document id, which the command
    # does not carry, so it is fetched from the chunk itself.
    response = self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False,
                                        auth_kind="web")
    res_json = response.json()
    if response.status_code != 200:
        print(f"Fail to get chunk info, code: {res_json.get('code')}, message: {res_json.get('message')}")
        return

    doc_id = None
    if res_json.get("code") == 0 and res_json.get("data"):
        doc_id = res_json["data"].get("doc_id")

    if not doc_id:
        print(f"Could not find document_id for chunk {chunk_id}")
        return

    # Call PUT /datasets/{dataset_id}/documents/{doc_id}/chunks/{chunk_id}
    path = f"/datasets/{dataset_id}/documents/{doc_id}/chunks/{chunk_id}"
    response = self.http_client.request("PUT", path, json_body=payload, use_api_base=True, auth_kind="api")
    res_json = response.json()
    if response.status_code == 200:
        if res_json.get("code") == 0:
            print(f"Success to update chunk: {chunk_id}")
        else:
            print(f"Fail to update chunk, code: {res_json.get('code')}, message: {res_json.get('message')}")
    else:
        print(f"Fail to update chunk, HTTP {response.status_code}")
|
||||
|
||||
def set_metadata(self, command_dict):
    """Set the metadata of a document via ``POST /document/set_meta`` (USER mode only).

    ``command_dict`` supplies ``doc_id`` and ``meta``; ``meta`` is forwarded
    unchanged as a string. The outcome is printed and ``None`` is returned.
    """
    if self.server_type != "user":
        print("This command is only allowed in USER mode")
        return

    doc_id = command_dict["doc_id"]
    request_body = {"doc_id": doc_id, "meta": command_dict["meta"]}

    reply = self.http_client.request("POST", "/document/set_meta", json_body=request_body,
                                     use_api_base=False, auth_kind="web")
    reply_json = reply.json()

    if reply.status_code != 200:
        print(f"Fail to set metadata, HTTP {reply.status_code}")
        return

    if reply_json.get("code") == 0:
        print(f"Success to set metadata for document: {doc_id}")
        return

    print(f"Fail to set metadata, code: {reply_json.get('code')}, message: {reply_json.get('message')}")
|
||||
|
||||
def remove_tags(self, command_dict):
    """Remove tags from a dataset via ``POST /kb/{dataset_id}/rm_tags`` (USER mode only).

    ``command_dict`` supplies ``dataset_name`` (resolved to an id through
    ``self._get_dataset_id``) and ``tags``. The outcome is printed and
    ``None`` is returned.
    """
    if self.server_type != "user":
        print("This command is only allowed in USER mode")
        return

    dataset_name = command_dict["dataset_name"]
    dataset_id = self._get_dataset_id(dataset_name)
    if dataset_id is None:
        print(f"Dataset not found: {dataset_name}")
        return

    request_body = {"tags": command_dict["tags"]}
    reply = self.http_client.request("POST", f"/kb/{dataset_id}/rm_tags", json_body=request_body,
                                     use_api_base=False, auth_kind="web")
    reply_json = reply.json()

    if reply.status_code != 200:
        print(f"Fail to remove tags, HTTP {reply.status_code}")
        return

    if reply_json.get("code") == 0:
        print(f"Success to remove tags from dataset: {dataset_name}")
        return

    print(f"Fail to remove tags, code: {reply_json.get('code')}, message: {reply_json.get('message')}")
|
||||
|
||||
def list_chunks(self, command_dict):
|
||||
if self.server_type != "user":
|
||||
print("This command is only allowed in USER mode")
|
||||
@@ -1903,6 +2010,12 @@ def run_command(client: RAGFlowClient, command_dict: dict):
|
||||
return client.insert_dataset_from_file(command_dict)
|
||||
case "insert_metadata_from_file":
|
||||
return client.insert_metadata_from_file(command_dict)
|
||||
case "update_chunk":
|
||||
return client.update_chunk(command_dict)
|
||||
case "set_metadata":
|
||||
return client.set_metadata(command_dict)
|
||||
case "remove_tags":
|
||||
return client.remove_tags(command_dict)
|
||||
case "list_chunks":
|
||||
return client.list_chunks(command_dict)
|
||||
case "meta":
|
||||
|
||||
@@ -136,6 +136,24 @@ def _load_user():
|
||||
return user[0]
|
||||
except Exception as e_api_token:
|
||||
logging.warning(f"load_user got exception {e_api_token}")
|
||||
# Fallback: try raw authorization value as access_token (for login tokens sent without JWT)
|
||||
try:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization and len(authorization.split()) == 1:
|
||||
# Single value without "Bearer " prefix - try as raw access_token
|
||||
access_token = authorization.strip()
|
||||
if access_token and len(access_token) >= 32:
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e_raw_token:
|
||||
logging.warning(f"load_user raw token fallback got exception {e_raw_token}")
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
|
||||
@@ -307,18 +307,43 @@ def token_required(func):
|
||||
raise err
|
||||
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
err = WerkzeugUnauthorized(description="Authentication error: API key is invalid!")
|
||||
err.code = RetCode.AUTHENTICATION_ERROR
|
||||
raise err
|
||||
|
||||
# On success, inject tenant_id into the route function's kwargs
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
# First try API token (explicit API token authentication)
|
||||
objs = APIToken.query(token=token)
|
||||
if objs:
|
||||
# On success, inject tenant_id into the route function's kwargs
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
# Fallback: try login token (for clients that use login token as API token)
|
||||
# Login tokens are JWT-encoded (URLSafeTimedSerializer), need to decode to get raw access_token
|
||||
from api.db.services.user_service import UserService
|
||||
from common.constants import StatusEnum
|
||||
from common import settings
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
try:
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
raw_token = str(jwt.loads(token))
|
||||
user = UserService.query(access_token=raw_token, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
# On success, inject tenant_id from user's tenant
|
||||
from api.db.services.user_service import UserTenantService
|
||||
tenants = UserTenantService.query(user_id=user[0].id)
|
||||
if tenants:
|
||||
kwargs["tenant_id"] = tenants[0].tenant_id
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
err = WerkzeugUnauthorized(description="Authentication error: API key is invalid!")
|
||||
err.code = RetCode.AUTHENTICATION_ERROR
|
||||
raise err
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -35,6 +35,17 @@ func (p *Parser) parseAdminLoginUser() (*Command, error) {
|
||||
cmd.Params["email"] = email
|
||||
|
||||
p.nextToken()
|
||||
// Optional: PASSWORD 'password'
|
||||
if p.curToken.Type == TokenPassword {
|
||||
p.nextToken()
|
||||
password, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd.Params["password"] = password
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
|
||||
@@ -333,6 +333,7 @@ func looksLikeSQL(s string) bool {
|
||||
"LOGIN ", "REGISTER ", "PING", "GRANT ", "REVOKE ",
|
||||
"SET ", "UNSET ", "UPDATE ", "DELETE ", "INSERT ",
|
||||
"SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE", "THINK",
|
||||
"REMOVE ",
|
||||
}
|
||||
for _, prefix := range sqlPrefixes {
|
||||
if strings.HasPrefix(s, prefix) {
|
||||
@@ -988,6 +989,7 @@ Meta Commands:
|
||||
|
||||
Commands (User Mode):
|
||||
LOGIN USER 'email'; - Login as user
|
||||
LOGIN USER 'email' PASSWORD 'pwd'; - Login as user with password
|
||||
REGISTER USER 'name' AS 'nickname' PASSWORD 'pwd'; - Register new user
|
||||
SHOW VERSION; - Show version info
|
||||
PING; - Ping server
|
||||
|
||||
@@ -262,6 +262,12 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.InsertDatasetFromFile(cmd)
|
||||
case "insert_metadata_from_file":
|
||||
return c.InsertMetadataFromFile(cmd)
|
||||
case "update_chunk":
|
||||
return c.UpdateChunk(cmd)
|
||||
case "set_meta":
|
||||
return c.SetMeta(cmd)
|
||||
case "rm_tags":
|
||||
return c.RmTags(cmd)
|
||||
// TODO: Implement other commands
|
||||
default:
|
||||
return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type)
|
||||
|
||||
@@ -84,10 +84,19 @@ func (c *HTTPClient) BuildURL(path string, useAPIBase bool) string {
|
||||
// Headers builds the request headers
|
||||
func (c *HTTPClient) Headers(authKind string, extra map[string]string) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
if c.APIToken != "" {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIToken)
|
||||
} else if c.LoginToken != "" {
|
||||
headers["Authorization"] = c.LoginToken
|
||||
|
||||
switch authKind {
|
||||
case "api":
|
||||
if c.APIToken != "" {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIToken)
|
||||
} else if c.LoginToken != "" {
|
||||
// Fallback to login token for API requests (user mode)
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.LoginToken)
|
||||
}
|
||||
case "web", "admin":
|
||||
if c.LoginToken != "" {
|
||||
headers["Authorization"] = c.LoginToken
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range extra {
|
||||
|
||||
@@ -327,6 +327,16 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenMetadata, Value: ident}
|
||||
case "USE":
|
||||
return Token{Type: TokenUse, Value: ident}
|
||||
case "UPDATE":
|
||||
return Token{Type: TokenUpdate, Value: ident}
|
||||
case "REMOVE":
|
||||
return Token{Type: TokenRemove, Value: ident}
|
||||
case "CHUNK":
|
||||
return Token{Type: TokenChunk, Value: ident}
|
||||
case "DOCUMENT":
|
||||
return Token{Type: TokenDocument, Value: ident}
|
||||
case "TAGS":
|
||||
return Token{Type: TokenTag, Value: ident}
|
||||
default:
|
||||
return Token{Type: TokenIdentifier, Value: ident}
|
||||
}
|
||||
|
||||
@@ -196,6 +196,10 @@ func (p *Parser) parseUserCommand() (*Command, error) {
|
||||
return p.parseThinkCommand()
|
||||
case TokenUse:
|
||||
return p.parseUseCommand()
|
||||
case TokenUpdate:
|
||||
return p.parseUpdateCommand()
|
||||
case TokenRemove:
|
||||
return p.parseRemoveCommand()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown command: %s", p.curToken.Value)
|
||||
}
|
||||
@@ -233,7 +237,7 @@ func (p *Parser) expectSemicolon() error {
|
||||
}
|
||||
|
||||
func isKeyword(tokenType int) bool {
|
||||
return tokenType >= TokenLogin && tokenType <= TokenMetadata
|
||||
return tokenType >= TokenLogin && tokenType <= TokenTag
|
||||
}
|
||||
|
||||
// isCECommand checks if the given string is a ContextEngine command
|
||||
|
||||
@@ -115,6 +115,11 @@ const (
|
||||
TokenInsert
|
||||
TokenFile
|
||||
TokenMetadata
|
||||
TokenUpdate
|
||||
TokenRemove
|
||||
TokenChunk
|
||||
TokenDocument
|
||||
TokenTag
|
||||
|
||||
// Literals
|
||||
TokenIdentifier
|
||||
|
||||
@@ -199,13 +199,13 @@ func (c *RAGFlowClient) ListUserDatasets(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
// getDatasetID gets dataset ID by name
|
||||
func (c *RAGFlowClient) getDatasetID(datasetName string) (string, error) {
|
||||
resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil)
|
||||
resp, err := c.HTTPClient.Request("GET", "/datasets", true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to list datasets: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("failed to list datasets: HTTP %d", resp.StatusCode)
|
||||
return "", fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
resJSON, err := resp.JSON()
|
||||
@@ -219,17 +219,12 @@ func (c *RAGFlowClient) getDatasetID(datasetName string) (string, error) {
|
||||
return "", fmt.Errorf("failed to list datasets: %s", msg)
|
||||
}
|
||||
|
||||
data, ok := resJSON["data"].(map[string]interface{})
|
||||
data, ok := resJSON["data"].([]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format")
|
||||
}
|
||||
|
||||
kbs, ok := data["kbs"].([]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format: kbs not found")
|
||||
}
|
||||
|
||||
for _, kb := range kbs {
|
||||
for _, kb := range data {
|
||||
if kbMap, ok := kb.(map[string]interface{}); ok {
|
||||
if name, _ := kbMap["name"].(string); name == datasetName {
|
||||
if id, _ := kbMap["id"].(string); id != "" {
|
||||
@@ -1487,3 +1482,195 @@ func (c *RAGFlowClient) InsertMetadataFromFile(cmd *Command) (ResponseIf, error)
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// UpdateChunk updates a chunk in a dataset
|
||||
func (c *RAGFlowClient) UpdateChunk(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
chunkID, ok := cmd.Params["chunk_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("chunk_id not provided")
|
||||
}
|
||||
|
||||
datasetName, ok := cmd.Params["dataset_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("dataset_name not provided")
|
||||
}
|
||||
|
||||
jsonBody, ok := cmd.Params["json_body"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("json_body not provided")
|
||||
}
|
||||
|
||||
// Look up dataset_id from dataset_name
|
||||
datasetID, err := c.getDatasetID(datasetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset ID: %w", err)
|
||||
}
|
||||
|
||||
// Try to get doc_id from the chunk retrieval endpoint
|
||||
getResp, err := c.HTTPClient.Request("GET", "/chunk/get?chunk_id="+chunkID, false, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get chunk info: %w", err)
|
||||
}
|
||||
|
||||
var docID string
|
||||
if getResp.StatusCode == 200 {
|
||||
getJSON, err := getResp.JSON()
|
||||
if err == nil {
|
||||
if data, ok := getJSON["data"].(map[string]interface{}); ok {
|
||||
if d, ok := data["doc_id"].(string); ok {
|
||||
docID = d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if docID == "" {
|
||||
return nil, fmt.Errorf("could not find document_id for chunk %s. Please provide document_id explicitly", chunkID)
|
||||
}
|
||||
|
||||
// Parse the JSON body
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonBody), &payload); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON body: %w", err)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/datasets/%s/documents/%s/chunks/%s", datasetID, docID, chunkID)
|
||||
resp, err := c.HTTPClient.Request("PUT", path, true, "api", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update chunk: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to update chunk: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
resJSON, err := resp.JSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON response: %w", err)
|
||||
}
|
||||
|
||||
code, ok := resJSON["code"].(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response format: code is not a number")
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
result.Code = int(code)
|
||||
if result.Code == 0 {
|
||||
result.Message = fmt.Sprintf("Success to update chunk: %s", chunkID)
|
||||
} else {
|
||||
result.Message = fmt.Sprintf("Failed to update chunk: %v", resJSON)
|
||||
}
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SetMeta sets metadata for a document
|
||||
func (c *RAGFlowClient) SetMeta(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
docID, ok := cmd.Params["doc_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("doc_id not provided")
|
||||
}
|
||||
|
||||
metaJSON, ok := cmd.Params["meta"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("meta not provided")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"doc_id": docID,
|
||||
"meta": metaJSON,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", "/document/set_meta", false, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set metadata: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to set metadata: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
resJSON, err := resp.JSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON response: %w", err)
|
||||
}
|
||||
|
||||
code, ok := resJSON["code"].(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response format: code is not a number")
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
result.Code = int(code)
|
||||
if result.Code == 0 {
|
||||
result.Message = fmt.Sprintf("Success to set metadata for document: %s", docID)
|
||||
} else {
|
||||
result.Message = fmt.Sprintf("Failed to set metadata: %v", resJSON)
|
||||
}
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// RmTags removes tags from chunks in a dataset
|
||||
func (c *RAGFlowClient) RmTags(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
datasetName, ok := cmd.Params["dataset_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("dataset_name not provided")
|
||||
}
|
||||
|
||||
kbID, err := c.getDatasetID(datasetName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tags, ok := cmd.Params["tags"].([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tags not provided")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", "/kb/"+kbID+"/rm_tags", false, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to remove tags: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to remove tags: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
resJSON, err := resp.JSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON response: %w", err)
|
||||
}
|
||||
|
||||
code, ok := resJSON["code"].(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response format: code is not a number")
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
result.Code = int(code)
|
||||
if result.Code == 0 {
|
||||
result.Message = fmt.Sprintf("Success to remove tags from dataset: %s", kbID)
|
||||
} else {
|
||||
result.Message = fmt.Sprintf("Failed to remove tags: %v", resJSON)
|
||||
}
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
@@ -33,12 +33,8 @@ func (p *Parser) parseLoginUser() (*Command, error) {
|
||||
cmd.Params["email"] = email
|
||||
|
||||
p.nextToken()
|
||||
// Optional: WITH PASSWORD 'password'
|
||||
if p.curToken.Type == TokenWith {
|
||||
p.nextToken()
|
||||
if p.curToken.Type != TokenPassword {
|
||||
return nil, fmt.Errorf("expected PASSWORD after WITH")
|
||||
}
|
||||
// Optional: PASSWORD 'password'
|
||||
if p.curToken.Type == TokenPassword {
|
||||
p.nextToken()
|
||||
password, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
@@ -853,6 +849,17 @@ func (p *Parser) parseDeleteCommand() (*Command, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseRemoveCommand() (*Command, error) {
|
||||
p.nextToken() // consume RM
|
||||
|
||||
switch p.curToken.Type {
|
||||
case TokenTag:
|
||||
return p.parseRemoveTags()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown REMOVE target: %s", p.curToken.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseDropToken() (*Command, error) {
|
||||
p.nextToken() // consume TOKEN
|
||||
|
||||
@@ -1574,6 +1581,9 @@ func (p *Parser) parseSetCommand() (*Command, error) {
|
||||
if p.curToken.Type == TokenToken {
|
||||
return p.parseSetToken()
|
||||
}
|
||||
if p.curToken.Type == TokenMetadata {
|
||||
return p.parseSetMeta()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown SET target: %s", p.curToken.Value)
|
||||
}
|
||||
@@ -2229,7 +2239,10 @@ func (p *Parser) parseUserStatement() (*Command, error) {
|
||||
return p.parseInsertCommand()
|
||||
case TokenSearch:
|
||||
return p.parseSearchCommand()
|
||||
|
||||
case TokenUpdate:
|
||||
return p.parseUpdateCommand()
|
||||
case TokenRemove:
|
||||
return p.parseRemoveCommand()
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid user statement: %s", p.curToken.Value)
|
||||
}
|
||||
@@ -2318,3 +2331,164 @@ func (p *Parser) parseUnsetCommand() (*Command, error) {
|
||||
}
|
||||
return NewCommand("unset_token"), nil
|
||||
}
|
||||
|
||||
// parseUpdateCommand parses UPDATE CHUNK command
|
||||
// UPDATE CHUNK 'chunk_id' OF DATASET 'dataset_name' SET '{"content": "..."}'
|
||||
func (p *Parser) parseUpdateCommand() (*Command, error) {
|
||||
p.nextToken() // consume UPDATE
|
||||
|
||||
if p.curToken.Type != TokenChunk {
|
||||
return nil, fmt.Errorf("expected CHUNK after UPDATE")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Parse chunk_id
|
||||
chunkID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected chunk_id: %w", err)
|
||||
}
|
||||
|
||||
cmd := NewCommand("update_chunk")
|
||||
cmd.Params["chunk_id"] = chunkID
|
||||
|
||||
p.nextToken()
|
||||
if p.curToken.Type != TokenOf {
|
||||
return nil, fmt.Errorf("expected OF after chunk_id")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenDataset {
|
||||
return nil, fmt.Errorf("expected DATASET after OF")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Parse dataset_name
|
||||
datasetName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected dataset_name: %w", err)
|
||||
}
|
||||
cmd.Params["dataset_name"] = datasetName
|
||||
|
||||
p.nextToken()
|
||||
if p.curToken.Type != TokenSet {
|
||||
return nil, fmt.Errorf("expected SET after dataset_name")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Parse JSON body
|
||||
jsonBody, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected JSON body: %w", err)
|
||||
}
|
||||
cmd.Params["json_body"] = jsonBody
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// parseSetMeta parses: SET METADATA OF DOCUMENT 'doc_id' TO '{"key": "value"}'
|
||||
func (p *Parser) parseSetMeta() (*Command, error) {
|
||||
p.nextToken() // consume METADATA
|
||||
|
||||
// Expect OF
|
||||
if p.curToken.Type != TokenOf {
|
||||
return nil, fmt.Errorf("expected OF after SET METADATA")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Expect DOCUMENT
|
||||
if p.curToken.Type != TokenDocument {
|
||||
return nil, fmt.Errorf("expected DOCUMENT after SET METADATA OF")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Parse doc_id
|
||||
docID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected doc_id: %w", err)
|
||||
}
|
||||
cmd := NewCommand("set_meta")
|
||||
cmd.Params["doc_id"] = docID
|
||||
|
||||
p.nextToken()
|
||||
// Expect TO
|
||||
if p.curToken.Type != TokenTo {
|
||||
return nil, fmt.Errorf("expected TO after doc_id")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Parse meta JSON
|
||||
meta, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected meta JSON: %w", err)
|
||||
}
|
||||
cmd.Params["meta"] = meta
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// parseRemoveTags parses: REMOVE TAGS 'tag1', 'tag2' from DATASET 'dataset_name';
|
||||
func (p *Parser) parseRemoveTags() (*Command, error) {
|
||||
p.nextToken() // consume TAGS
|
||||
|
||||
// Parse first tag
|
||||
tag, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected tag: %w", err)
|
||||
}
|
||||
tags := []string{tag}
|
||||
|
||||
// Parse additional tags separated by commas
|
||||
for {
|
||||
p.nextToken()
|
||||
if p.curToken.Type == TokenComma {
|
||||
p.nextToken()
|
||||
tag, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected tag after comma: %w", err)
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
cmd := NewCommand("rm_tags")
|
||||
cmd.Params["tags"] = tags
|
||||
|
||||
// Expect from
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM after tags")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Expect DATASET
|
||||
if p.curToken.Type != TokenDataset {
|
||||
return nil, fmt.Errorf("expected DATASET after FROM")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Parse dataset_name
|
||||
datasetName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected dataset_name: %w", err)
|
||||
}
|
||||
cmd.Params["dataset_name"] = datasetName
|
||||
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
@@ -163,3 +163,16 @@ func (e *elasticsearchEngine) InsertMetadata(ctx context.Context, documents []ma
|
||||
// TODO
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
|
||||
// UpdateDataset is the Elasticsearch implementation of the chunk update
// operation (update rows matching condition with newValue).
// It is not implemented yet: the body ignores every argument and returns nil,
// so callers observe success although nothing has been written.
|
||||
func (e *elasticsearchEngine) UpdateDataset(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, tableNamePrefix string, knowledgebaseID string) error {
|
||||
// TODO: implement; currently a no-op that reports success.
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateMetadata is the Elasticsearch implementation of the document metadata
// update operation for a tenant's metadata index.
// It is not implemented yet: the body ignores every argument and returns nil,
// so callers observe success although nothing has been written.
|
||||
func (e *elasticsearchEngine) UpdateMetadata(ctx context.Context, docID string, kbID string, metaFields map[string]interface{}, tenantID string) error {
|
||||
// TODO: implement; currently a no-op that reports success.
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,6 +50,10 @@ type DocEngine interface {
|
||||
InsertDataset(ctx context.Context, documents []map[string]interface{}, indexName string, knowledgebaseID string) ([]string, error)
|
||||
InsertMetadata(ctx context.Context, documents []map[string]interface{}, tenantID string) ([]string, error)
|
||||
|
||||
// Update operations
|
||||
UpdateDataset(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, tableNamePrefix string, knowledgebaseID string) error
|
||||
UpdateMetadata(ctx context.Context, docID string, kbID string, metaFields map[string]interface{}, tenantID string) error
|
||||
|
||||
// Document operations
|
||||
IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error
|
||||
BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error)
|
||||
|
||||
@@ -346,22 +346,254 @@ func (e *infinityEngine) CreateDocMetaIndex(ctx context.Context, indexName strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// TransformChunkFields transforms chunk field name for insert/update
|
||||
// It handles field name conversions and value transformations:
|
||||
// - docnm_kwd -> docnm
|
||||
// - title_kwd/title_sm_tks -> docnm (if docnm_kwd not set)
|
||||
// - important_kwd -> important_keywords (+ important_kwd_empty_count)
|
||||
// - content_with_weight/content_ltks/content_sm_ltks -> content
|
||||
// - authors_tks/authors_sm_tks -> authors
|
||||
// - question_kwd -> questions (joined with \n), question_tks -> questions (if question_kwd not set)
|
||||
// - kb_id: list -> str (first element)
|
||||
// - position_int: list -> hex_joined string
|
||||
// - page_num_int, top_int: list -> hex string
|
||||
// - *_feas fields -> JSON string
|
||||
// - keyword fields with list values -> ### joined string
|
||||
// - chunk_data: dict -> JSON string
|
||||
// - Missing embeddings filled with zeros if embeddingCols provided
|
||||
func TransformChunkFields(chunk map[string]interface{}, embeddingCols [][2]interface{}) map[string]interface{} {
|
||||
d := make(map[string]interface{})
|
||||
|
||||
for k, v := range chunk {
|
||||
switch k {
|
||||
case "docnm_kwd":
|
||||
d["docnm"] = v
|
||||
case "title_kwd":
|
||||
if _, exists := chunk["docnm_kwd"]; !exists {
|
||||
d["docnm"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "title_sm_tks":
|
||||
if _, exists := chunk["docnm_kwd"]; !exists {
|
||||
d["docnm"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "important_kwd":
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
emptyCount := 0
|
||||
tokens := make([]string, 0)
|
||||
for _, item := range list {
|
||||
if str, ok := item.(string); ok {
|
||||
if str == "" {
|
||||
emptyCount++
|
||||
} else {
|
||||
tokens = append(tokens, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
d["important_keywords"] = strings.Join(tokens, ",")
|
||||
d["important_kwd_empty_count"] = emptyCount
|
||||
} else {
|
||||
d["important_keywords"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "important_tks":
|
||||
if _, exists := chunk["important_kwd"]; !exists {
|
||||
d["important_keywords"] = v
|
||||
}
|
||||
case "content_with_weight":
|
||||
d["content"] = v
|
||||
case "content_ltks":
|
||||
if _, exists := chunk["content_with_weight"]; !exists {
|
||||
d["content"] = v
|
||||
}
|
||||
case "content_sm_ltks":
|
||||
if _, exists := chunk["content_with_weight"]; !exists {
|
||||
d["content"] = v
|
||||
}
|
||||
case "authors_tks":
|
||||
d["authors"] = v
|
||||
case "authors_sm_tks":
|
||||
if _, exists := chunk["authors_tks"]; !exists {
|
||||
d["authors"] = v
|
||||
}
|
||||
case "question_kwd":
|
||||
d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n")
|
||||
case "tag_kwd":
|
||||
d["tag_kwd"] = strings.Join(utility.ConvertToStringSlice(v), "###")
|
||||
case "question_tks":
|
||||
if _, exists := chunk["question_kwd"]; !exists {
|
||||
d["questions"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "kb_id":
|
||||
if list, ok := v.([]interface{}); ok && len(list) > 0 {
|
||||
d["kb_id"] = list[0]
|
||||
} else {
|
||||
d["kb_id"] = v
|
||||
}
|
||||
case "position_int":
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
d["position_int"] = utility.ConvertPositionIntArrayToHex(list)
|
||||
} else {
|
||||
d["position_int"] = v
|
||||
}
|
||||
case "page_num_int", "top_int":
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
d[k] = utility.ConvertIntArrayToHex(list)
|
||||
} else {
|
||||
d[k] = v
|
||||
}
|
||||
case "chunk_data":
|
||||
d["chunk_data"] = utility.ConvertMapToJSONString(v)
|
||||
default:
|
||||
// Check for *_feas fields
|
||||
if strings.HasSuffix(k, "_feas") {
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
d[k] = string(jsonBytes)
|
||||
} else if fieldKeyword(k) {
|
||||
// keyword fields with list values -> ### joined
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
d[k] = strings.Join(utility.ConvertToStringSlice(list), "###")
|
||||
} else {
|
||||
d[k] = v
|
||||
}
|
||||
} else {
|
||||
d[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove intermediate token fields
|
||||
for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks",
|
||||
"content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks",
|
||||
"question_kwd", "question_tks"} {
|
||||
delete(d, key)
|
||||
}
|
||||
|
||||
// Fill missing embedding columns with zeros if embedding info provided
|
||||
for _, ec := range embeddingCols {
|
||||
name, size := ec[0].(string), ec[1].(int)
|
||||
if _, exists := d[name]; !exists {
|
||||
zeros := make([]float64, size)
|
||||
for i := range zeros {
|
||||
zeros[i] = 0
|
||||
}
|
||||
d[name] = zeros
|
||||
}
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// existsCondition builds a NOT EXISTS or field!='' condition
|
||||
func existsCondition(field string, tableColumns map[string]struct {
|
||||
Type string
|
||||
Default interface{}
|
||||
}) string {
|
||||
col, colOk := tableColumns[field]
|
||||
if !colOk {
|
||||
logger.Warn(fmt.Sprintf("Column '%s' not found in table columns", field))
|
||||
return fmt.Sprintf("%s!=null", field)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(col.Type), "char") {
|
||||
if col.Default != nil {
|
||||
return fmt.Sprintf(" %s!='%v' ", field, col.Default)
|
||||
}
|
||||
return fmt.Sprintf(" %s!='' ", field)
|
||||
}
|
||||
if col.Default != nil {
|
||||
return fmt.Sprintf("%s!=%v", field, col.Default)
|
||||
}
|
||||
return fmt.Sprintf("%s!=null", field)
|
||||
}
|
||||
|
||||
func buildFilterFromCondition(condition map[string]interface{}, tableColumns map[string]struct {
|
||||
Type string
|
||||
Default interface{}
|
||||
}) string {
|
||||
var conditions []string
|
||||
|
||||
for k, v := range condition {
|
||||
if v == nil || v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle must_not conditions -> NOT (...)
|
||||
if k == "must_not" {
|
||||
if mustNotMap, ok := v.(map[string]interface{}); ok {
|
||||
for kk, vv := range mustNotMap {
|
||||
if kk == "exists" {
|
||||
if existsField, ok := vv.(string); ok {
|
||||
conditions = append(conditions, fmt.Sprintf("NOT (%s)", existsCondition(existsField, tableColumns)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle keyword fields -> filter_fulltext with converted field name
|
||||
if fieldKeyword(k) {
|
||||
if listVal, ok := v.([]interface{}); ok {
|
||||
var orConds []string
|
||||
for _, item := range listVal {
|
||||
if strItem, ok := item.(string); ok {
|
||||
strItem = strings.ReplaceAll(strItem, "'", "''")
|
||||
orConds = append(orConds, fmt.Sprintf("filter_fulltext('%s', '%s')", convertMatchingField(k), strItem))
|
||||
}
|
||||
}
|
||||
if len(orConds) > 0 {
|
||||
conditions = append(conditions, "("+strings.Join(orConds, " OR ")+")")
|
||||
}
|
||||
} else if strVal, ok := v.(string); ok {
|
||||
strVal = strings.ReplaceAll(strVal, "'", "''")
|
||||
conditions = append(conditions, fmt.Sprintf("filter_fulltext('%s', '%s')", convertMatchingField(k), strVal))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle list values (IN condition)
|
||||
if listVal, ok := v.([]interface{}); ok {
|
||||
var inVals []string
|
||||
for _, item := range listVal {
|
||||
if strItem, ok := item.(string); ok {
|
||||
strItem = strings.ReplaceAll(strItem, "'", "''")
|
||||
inVals = append(inVals, fmt.Sprintf("'%s'", strItem))
|
||||
} else {
|
||||
inVals = append(inVals, fmt.Sprintf("%v", item))
|
||||
}
|
||||
}
|
||||
if len(inVals) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("%s IN (%s)", k, strings.Join(inVals, ", ")))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle exists condition
|
||||
if k == "exists" {
|
||||
if existsField, ok := v.(string); ok {
|
||||
conditions = append(conditions, existsCondition(existsField, tableColumns))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle string values
|
||||
if strVal, ok := v.(string); ok {
|
||||
strVal = strings.ReplaceAll(strVal, "'", "''")
|
||||
conditions = append(conditions, fmt.Sprintf("%s='%s'", k, strVal))
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle other values
|
||||
conditions = append(conditions, fmt.Sprintf("%s=%v", k, v))
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return "1=1"
|
||||
}
|
||||
return strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
// InsertDataset inserts chunks into a dataset table
|
||||
// Table name format: {tableNamePrefix}_{knowledgebaseID}
|
||||
// Auto-create the table if it doesn't exist
|
||||
// Transform chunks before insert:
|
||||
// - docnm_kwd -> docnm
|
||||
// - title_kwd/title_sm_tks -> docnm (if docnm_kwd not set)
|
||||
// - content_with_weight/content_ltks/content_sm_ltks -> content
|
||||
// - important_kwd -> important_keywords (+ important_kwd_empty_count)
|
||||
// - question_kwd -> questions (joined with \n)
|
||||
// - kb_id: list -> str (first element)
|
||||
// - position_int: list -> hex_joined string
|
||||
// - chunk_data: dict -> JSON string
|
||||
// - meta_fields: dict -> JSON string
|
||||
// - *_feas fields -> JSON string
|
||||
// - keyword fields with list values -> ### joined string
|
||||
// - Missing embeddings filled with zeros
|
||||
// Delete existing rows with matching IDs before insert
|
||||
func (e *infinityEngine) InsertDataset(ctx context.Context, chunks []map[string]interface{}, tableNamePrefix string, knowledgebaseID string) ([]string, error) {
|
||||
tableName := fmt.Sprintf("%s_%s", tableNamePrefix, knowledgebaseID)
|
||||
@@ -443,125 +675,10 @@ func (e *infinityEngine) InsertDataset(ctx context.Context, chunks []map[string]
|
||||
}
|
||||
}
|
||||
|
||||
// Transform chunks
|
||||
// Transform chunks using helper function
|
||||
insertChunks := make([]map[string]interface{}, len(chunks))
|
||||
for i, chunk := range chunks {
|
||||
d := make(map[string]interface{})
|
||||
|
||||
for k, v := range chunk {
|
||||
switch k {
|
||||
case "docnm_kwd":
|
||||
d["docnm"] = v
|
||||
case "title_kwd":
|
||||
if _, exists := chunk["docnm_kwd"]; !exists {
|
||||
d["docnm"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "title_sm_tks":
|
||||
if _, exists := chunk["docnm_kwd"]; !exists {
|
||||
d["docnm"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "important_kwd":
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
emptyCount := 0
|
||||
tokens := make([]string, 0)
|
||||
for _, item := range list {
|
||||
if str, ok := item.(string); ok {
|
||||
if str == "" {
|
||||
emptyCount++
|
||||
} else {
|
||||
tokens = append(tokens, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
d["important_keywords"] = strings.Join(tokens, ",")
|
||||
d["important_kwd_empty_count"] = emptyCount
|
||||
} else {
|
||||
d["important_keywords"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "important_tks":
|
||||
if _, exists := chunk["important_kwd"]; !exists {
|
||||
d["important_keywords"] = v
|
||||
}
|
||||
case "content_with_weight":
|
||||
d["content"] = v
|
||||
case "content_ltks":
|
||||
if _, exists := chunk["content_with_weight"]; !exists {
|
||||
d["content"] = v
|
||||
}
|
||||
case "content_sm_ltks":
|
||||
if _, exists := chunk["content_with_weight"]; !exists {
|
||||
d["content"] = v
|
||||
}
|
||||
case "authors_tks":
|
||||
d["authors"] = v
|
||||
case "authors_sm_tks":
|
||||
if _, exists := chunk["authors_tks"]; !exists {
|
||||
d["authors"] = v
|
||||
}
|
||||
case "question_kwd":
|
||||
d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n")
|
||||
case "question_tks":
|
||||
if _, exists := chunk["question_kwd"]; !exists {
|
||||
d["questions"] = utility.ConvertToString(v)
|
||||
}
|
||||
case "kb_id":
|
||||
if list, ok := v.([]interface{}); ok && len(list) > 0 {
|
||||
d["kb_id"] = list[0]
|
||||
} else {
|
||||
d["kb_id"] = v
|
||||
}
|
||||
case "position_int":
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
d["position_int"] = utility.ConvertPositionIntArrayToHex(list)
|
||||
} else {
|
||||
d["position_int"] = v
|
||||
}
|
||||
case "page_num_int", "top_int":
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
d[k] = utility.ConvertIntArrayToHex(list)
|
||||
} else {
|
||||
d[k] = v
|
||||
}
|
||||
case "chunk_data":
|
||||
d["chunk_data"] = utility.ConvertMapToJSONString(v)
|
||||
default:
|
||||
// Check for *_feas fields
|
||||
if strings.HasSuffix(k, "_feas") {
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
d[k] = string(jsonBytes)
|
||||
} else if fieldKeyword(k) {
|
||||
// keyword fields with list values -> ### joined
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
d[k] = strings.Join(utility.ConvertToStringSlice(list), "###")
|
||||
} else {
|
||||
d[k] = v
|
||||
}
|
||||
} else {
|
||||
d[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove intermediate token fields
|
||||
for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks",
|
||||
"content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks",
|
||||
"question_kwd", "question_tks"} {
|
||||
delete(d, key)
|
||||
}
|
||||
|
||||
// Fill missing embedding columns with zeros (raw slice, matching Python SDK)
|
||||
for _, ec := range embeddingCols {
|
||||
name, size := ec[0].(string), ec[1].(int)
|
||||
if _, exists := d[name]; !exists {
|
||||
zeros := make([]float64, size)
|
||||
for i := range zeros {
|
||||
zeros[i] = 0
|
||||
}
|
||||
d[name] = zeros
|
||||
}
|
||||
}
|
||||
|
||||
insertChunks[i] = d
|
||||
insertChunks[i] = TransformChunkFields(chunk, embeddingCols)
|
||||
}
|
||||
|
||||
// Delete existing rows with matching IDs
|
||||
@@ -590,6 +707,154 @@ func (e *infinityEngine) InsertDataset(ctx context.Context, chunks []map[string]
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// UpdateDataset updates chunks in a dataset table
|
||||
// Table name format: {tableNamePrefix}_{knowledgebaseID}
|
||||
func (e *infinityEngine) UpdateDataset(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, tableNamePrefix string, knowledgebaseID string) error {
|
||||
tableName := fmt.Sprintf("%s_%s", tableNamePrefix, knowledgebaseID)
|
||||
logger.Info("InfinityConnection.UpdateDataset called", zap.String("tableName", tableName), zap.Any("condition", condition))
|
||||
|
||||
db, err := e.client.conn.GetDatabase(e.client.dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get database: %w", err)
|
||||
}
|
||||
|
||||
table, err := db.GetTable(tableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get table %s: %w", tableName, err)
|
||||
}
|
||||
|
||||
// Get table columns
|
||||
clmns := make(map[string]struct {
|
||||
Type string
|
||||
Default interface{}
|
||||
})
|
||||
colsResp, err := table.ShowColumns()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get columns: %w", err)
|
||||
}
|
||||
result, ok := colsResp.(*infinity.QueryResult)
|
||||
if ok {
|
||||
if nameArr, ok := result.Data["name"]; ok {
|
||||
if typeArr, ok := result.Data["type"]; ok {
|
||||
if defArr, ok := result.Data["default"]; ok {
|
||||
for i := 0; i < len(nameArr); i++ {
|
||||
colName, _ := nameArr[i].(string)
|
||||
colType, _ := typeArr[i].(string)
|
||||
var colDefault interface{}
|
||||
if i < len(defArr) {
|
||||
colDefault = defArr[i]
|
||||
}
|
||||
clmns[colName] = struct {
|
||||
Type string
|
||||
Default interface{}
|
||||
}{colType, colDefault}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build filter string from condition
|
||||
filter := buildFilterFromCondition(condition, clmns)
|
||||
|
||||
// Process remove operation first
|
||||
removeValue := make(map[string]interface{})
|
||||
if removeData, ok := newValue["remove"].(map[string]interface{}); ok {
|
||||
removeValue = removeData
|
||||
}
|
||||
delete(newValue, "remove")
|
||||
|
||||
// Transform new_value fields using helper function (no embeddings needed for update)
|
||||
transformed := TransformChunkFields(newValue, nil)
|
||||
for k, v := range transformed {
|
||||
newValue[k] = v
|
||||
}
|
||||
|
||||
// Remove original fields that were transformed (they're now in transformed with new names/types)
|
||||
// Also remove intermediate token fields that shouldn't be stored in Infinity
|
||||
// This must match Python's delete list in infinity_conn.py
|
||||
for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks",
|
||||
"content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks",
|
||||
"question_kwd", "question_tks"} {
|
||||
delete(newValue, key)
|
||||
}
|
||||
|
||||
// Handle remove operations if any
|
||||
if len(removeValue) > 0 {
|
||||
colToRemove := make([]string, 0, len(removeValue))
|
||||
for k := range removeValue {
|
||||
colToRemove = append(colToRemove, k)
|
||||
}
|
||||
colToRemove = append(colToRemove, "id")
|
||||
|
||||
// Query rows to be updated
|
||||
queryResult, err := table.Output(colToRemove).Filter(filter).ToResult()
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to query rows for remove operation: %v", err))
|
||||
} else {
|
||||
qr, ok := queryResult.(*infinity.QueryResult)
|
||||
if ok && len(qr.Data) > 0 {
|
||||
// Get the id column and columns to remove
|
||||
idCol := qr.Data["id"]
|
||||
removeOpt := make(map[string]map[string][]string); // column -> value -> [ids]
|
||||
|
||||
for colName, colData := range qr.Data {
|
||||
if colName == "id" {
|
||||
continue
|
||||
}
|
||||
removeVal := removeValue[colName]
|
||||
for i, id := range idCol {
|
||||
if i < len(colData) {
|
||||
existingVal := colData[i]
|
||||
if removeStr, ok := removeVal.(string); ok {
|
||||
// Split existing value by ### and remove the target value
|
||||
if existingStr, ok := existingVal.(string); ok {
|
||||
parts := strings.Split(existingStr, "###")
|
||||
var newParts []string
|
||||
for _, p := range parts {
|
||||
if p != removeStr {
|
||||
newParts = append(newParts, p)
|
||||
}
|
||||
}
|
||||
if len(newParts) != len(parts) {
|
||||
idStr := fmt.Sprintf("%v", id)
|
||||
if removeOpt[colName] == nil {
|
||||
removeOpt[colName] = make(map[string][]string)
|
||||
}
|
||||
removeOpt[colName][strings.Join(newParts, "###")] = append(removeOpt[colName][strings.Join(newParts, "###")], idStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute remove updates
|
||||
for colName, valueToIDs := range removeOpt {
|
||||
for newVal, ids := range valueToIDs {
|
||||
idFilter := filter + " AND id IN (" + strings.Join(ids, ", ") + ")"
|
||||
logger.Info(fmt.Sprintf("INFINITY remove update: table=%s, idFilter=%s, column=%s, newValue=%v", tableName, idFilter, colName, newVal))
|
||||
_, err := table.Update(idFilter, map[string]interface{}{colName: newVal})
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to remove value from column %s: %v", colName, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the main update
|
||||
logger.Info(fmt.Sprintf("INFINITY update: table=%s, filter=%s, newValue=%v", tableName, filter, newValue))
|
||||
_, err = table.Update(filter, newValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to update chunks: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("InfinityConnection.UpdateDataset completes", zap.String("tableName", tableName))
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertMetadata inserts document metadata into tenant's metadata table
|
||||
// Table name format: ragflow_doc_meta_{tenant_id}
|
||||
// Auto-create the table if it doesn't exist
|
||||
@@ -663,3 +928,77 @@ func (e *infinityEngine) InsertMetadata(ctx context.Context, metadata []map[stri
|
||||
logger.Info("InfinityConnection.InsertMetadata result", zap.String("tableName", tableName), zap.Int("metaCount", len(metadata)))
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// UpdateMetadata updates document metadata in tenant's metadata table
|
||||
// Table name format: ragflow_doc_meta_{tenant_id}
|
||||
func (e *infinityEngine) UpdateMetadata(ctx context.Context, docID string, kbID string, metaFields map[string]interface{}, tenantID string) error {
|
||||
tableName := fmt.Sprintf("ragflow_doc_meta_%s", tenantID)
|
||||
logger.Info("InfinityConnection.UpdateMetadata called", zap.String("tableName", tableName), zap.String("docID", docID), zap.String("kbID", kbID))
|
||||
|
||||
db, err := e.client.conn.GetDatabase(e.client.dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database: %w", err)
|
||||
}
|
||||
|
||||
table, err := db.GetTable(tableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get metadata table %s: %w", tableName, err)
|
||||
}
|
||||
|
||||
// Query existing metadata using the chainable API
|
||||
filter := fmt.Sprintf("id = '%s' AND kb_id = '%s'", docID, kbID)
|
||||
|
||||
// Use chainable API: Output().Filter().Limit().Offset()
|
||||
queryTable := table.Output([]string{"id", "kb_id", "meta_fields"}).Filter(filter).Limit(1).Offset(0)
|
||||
|
||||
// Execute query
|
||||
result, err := queryTable.ToResult()
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to query existing metadata: %v", err))
|
||||
// If query fails, just insert new metadata
|
||||
} else {
|
||||
// Get results
|
||||
rows, ok := result.([]map[string]interface{})
|
||||
if ok && len(rows) > 0 {
|
||||
existingMetaFieldsVal := rows[0]["meta_fields"]
|
||||
|
||||
// Parse existing meta_fields if it's a string
|
||||
var existingMetaFields map[string]interface{}
|
||||
if existingMetaFieldsVal != nil {
|
||||
switch v := existingMetaFieldsVal.(type) {
|
||||
case string:
|
||||
if err := json.Unmarshal([]byte(v), &existingMetaFields); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to parse existing meta_fields: %v", err))
|
||||
existingMetaFields = make(map[string]interface{})
|
||||
}
|
||||
case map[string]interface{}:
|
||||
existingMetaFields = v
|
||||
}
|
||||
}
|
||||
|
||||
// Merge new meta_fields with existing
|
||||
if existingMetaFields == nil {
|
||||
existingMetaFields = make(map[string]interface{})
|
||||
}
|
||||
for k, v := range metaFields {
|
||||
existingMetaFields[k] = v
|
||||
}
|
||||
metaFields = existingMetaFields
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare updated metadata
|
||||
updatedFields := map[string]interface{}{
|
||||
"meta_fields": utility.ConvertMapToJSONString(metaFields),
|
||||
}
|
||||
|
||||
// Update metadata
|
||||
logger.Info(fmt.Sprintf("INFINITY metadata update: table=%s, filter=%s, newValue=%v", tableName, filter, updatedFields))
|
||||
_, err = table.Update(filter, updatedFields)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update metadata: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("InfinityConnection.UpdateMetadata completes", zap.String("tableName", tableName), zap.String("docID", docID))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -344,6 +344,7 @@ func convertMatchingField(fieldWeightStr string) string {
|
||||
"content_sm_ltks": "content@ft_content_rag_fine",
|
||||
"authors_tks": "authors@ft_authors_rag_coarse",
|
||||
"authors_sm_tks": "authors@ft_authors_rag_fine",
|
||||
"tag_kwd": "tag_kwd@ft_tag_kwd_whitespace__",
|
||||
}
|
||||
|
||||
if newField, ok := fieldMapping[field]; ok {
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
|
||||
@@ -246,3 +247,122 @@ func (h *ChunkHandler) List(c *gin.Context) {
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateChunk updates a chunk
|
||||
// @Summary Update Chunk
|
||||
// @Description Update chunk fields
|
||||
// @Tags chunks
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param dataset_id path string true "Dataset ID"
|
||||
// @Param document_id path string true "Document ID"
|
||||
// @Param chunk_id path string true "Chunk ID"
|
||||
// @Param request body service.UpdateChunkRequest true "update chunk"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} [put]
|
||||
func (h *ChunkHandler) UpdateChunk(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
// Get path parameters
|
||||
datasetID := c.Param("dataset_id")
|
||||
documentID := c.Param("document_id")
|
||||
chunkID := c.Param("chunk_id")
|
||||
|
||||
if datasetID == "" || documentID == "" || chunkID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "dataset_id, document_id, and chunk_id are required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate allowed update fields
|
||||
var rawBody map[string]interface{}
|
||||
if err := json.NewDecoder(c.Request.Body).Decode(&rawBody); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "invalid JSON body: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Allowed fields for update
|
||||
allowedFields := map[string]bool{
|
||||
"content": true,
|
||||
"important_keywords": true,
|
||||
"questions": true,
|
||||
"available": true,
|
||||
"positions": true,
|
||||
"tag_kwd": true,
|
||||
"tag_feas": true,
|
||||
}
|
||||
for field := range rawBody {
|
||||
if !allowedFields[field] {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Update field '" + field + "' is not supported. Updatable fields: content, important_keywords, questions, available, positions, tag_kwd, tag_feas",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Build UpdateChunkRequest from rawBody
|
||||
var req service.UpdateChunkRequest
|
||||
if content, ok := rawBody["content"].(string); ok {
|
||||
req.Content = &content
|
||||
}
|
||||
if importantKwd, ok := rawBody["important_keywords"].([]interface{}); ok {
|
||||
req.ImportantKwd = make([]string, len(importantKwd))
|
||||
for i, v := range importantKwd {
|
||||
if s, ok := v.(string); ok {
|
||||
req.ImportantKwd[i] = s
|
||||
}
|
||||
}
|
||||
}
|
||||
if questions, ok := rawBody["questions"].([]interface{}); ok {
|
||||
req.Questions = make([]string, len(questions))
|
||||
for i, v := range questions {
|
||||
if s, ok := v.(string); ok {
|
||||
req.Questions[i] = s
|
||||
}
|
||||
}
|
||||
}
|
||||
if available, ok := rawBody["available"].(bool); ok {
|
||||
req.Available = &available
|
||||
}
|
||||
if positions, ok := rawBody["positions"].([]interface{}); ok {
|
||||
req.Positions = positions
|
||||
}
|
||||
if tagKwd, ok := rawBody["tag_kwd"].([]interface{}); ok {
|
||||
req.TagKwd = make([]string, len(tagKwd))
|
||||
for i, v := range tagKwd {
|
||||
if s, ok := v.(string); ok {
|
||||
req.TagKwd[i] = s
|
||||
}
|
||||
}
|
||||
}
|
||||
req.TagFeas = rawBody["tag_feas"]
|
||||
|
||||
// Set path parameters
|
||||
req.DatasetID = datasetID
|
||||
req.DocumentID = documentID
|
||||
req.ChunkID = chunkID
|
||||
|
||||
err := h.chunkService.UpdateChunk(&req, user.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "chunk updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"strconv"
|
||||
@@ -372,3 +374,103 @@ func (h *DocumentHandler) MetadataSummary(c *gin.Context) {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// SetMetaRequest represents the request for setting document metadata.
// It is the JSON body bound by the SetMeta handler.
type SetMetaRequest struct {
	// DocID is the ID of the document whose metadata is being set.
	DocID string `json:"doc_id" binding:"required"`
	// Meta is the metadata encoded as a JSON string (not a nested object);
	// SetMeta unmarshals it into a map before validating the values.
	Meta string `json:"meta" binding:"required"`
}
|
||||
|
||||
// SetMeta handles the set metadata request for a document
|
||||
// @Summary Set Document Metadata
|
||||
// @Description Set metadata for a specific document
|
||||
// @Tags documents
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body SetMetaRequest true "metadata info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/document/set_meta [post]
|
||||
func (h *DocumentHandler) SetMeta(c *gin.Context) {
|
||||
_, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req SetMetaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 1,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.DocID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 1,
|
||||
"message": "doc_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse meta JSON string
|
||||
var meta map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(req.Meta), &meta); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 1,
|
||||
"message": "Json syntax error: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if meta == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 1,
|
||||
"message": "meta is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate meta values - must be str, int, float, or list of those
|
||||
for k, v := range meta {
|
||||
switch val := v.(type) {
|
||||
case string, int, float64:
|
||||
// Valid
|
||||
case []interface{}:
|
||||
for _, item := range val {
|
||||
if _, ok := item.(string); !ok {
|
||||
if _, ok := item.(float64); !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 1,
|
||||
"message": fmt.Sprintf("Unsupported type in list for key %s: %T", k, item),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 1,
|
||||
"message": fmt.Sprintf("Unsupported type for key %s: %T", k, v),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err := h.documentService.SetDocumentMetadata(req.DocID, meta)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 1,
|
||||
"message": "Failed to set metadata: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
@@ -444,6 +444,34 @@ func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get KB to find tenant_id and build index name
|
||||
kb, err := h.kbService.GetByID(kbID)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeDataError, "knowledge base not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Build index name prefix: ragflow_<tenant_id>
|
||||
indexName := "ragflow_" + kb.TenantID
|
||||
|
||||
// For each tag, call UpdateChunk to remove it from documents
|
||||
for _, tag := range req.Tags {
|
||||
condition := map[string]interface{}{
|
||||
"tag_kwd": tag,
|
||||
"kb_id": kbID,
|
||||
}
|
||||
newValue := map[string]interface{}{
|
||||
"remove": map[string]interface{}{
|
||||
"tag_kwd": tag,
|
||||
},
|
||||
}
|
||||
err := h.kbService.RemoveTag(condition, newValue, indexName, kbID)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeServerError, "Failed to remove tag: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
|
||||
@@ -162,6 +162,12 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.DELETE("", r.datasetsHandler.DeleteDatasets)
|
||||
}
|
||||
|
||||
// RESTful dataset chunk routes
|
||||
datasetChunks := v1.Group("/datasets/:dataset_id/documents/:document_id/chunks")
|
||||
{
|
||||
datasetChunks.PUT("/:chunk_id", r.chunkHandler.UpdateChunk)
|
||||
}
|
||||
|
||||
// Author routes
|
||||
authors := v1.Group("/authors")
|
||||
{
|
||||
@@ -256,6 +262,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
{
|
||||
doc.POST("/list", r.documentHandler.ListDocuments)
|
||||
doc.POST("/metadata/summary", r.documentHandler.MetadataSummary)
|
||||
doc.POST("/set_meta", r.documentHandler.SetMeta)
|
||||
}
|
||||
|
||||
// Chunk routes
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"ragflow/internal/logger"
|
||||
|
||||
"ragflow/internal/service/nlp"
|
||||
"ragflow/internal/tokenizer"
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
@@ -855,3 +856,156 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR
|
||||
Doc: docInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateChunkRequest request for updating a chunk.
// The three ID fields locate the chunk; the remaining fields are optional
// updates. In ChunkService.UpdateChunk a nil optional field is not written,
// except that content is always rewritten (falling back to the stored value).
type UpdateChunkRequest struct {
	// DatasetID is the dataset (knowledge base) that owns the document.
	DatasetID string `json:"dataset_id"`
	// DocumentID is the document the chunk belongs to.
	DocumentID string `json:"document_id"`
	// ChunkID is the chunk to update; the service rejects an empty value.
	ChunkID string `json:"chunk_id"`
	// Content is the new chunk text; nil keeps the existing content.
	Content *string `json:"content,omitempty"`
	// ImportantKwd is stored as "important_kwd" when non-nil.
	ImportantKwd []string `json:"important_keywords,omitempty"`
	// Questions is trimmed, stripped of empty entries and stored as
	// "question_kwd" when non-nil.
	Questions []string `json:"questions,omitempty"`
	// Available is stored as "available_int" (1 for true, 0 for false) when non-nil.
	Available *bool `json:"available,omitempty"`
	// Positions is stored as "position_int" when non-nil.
	Positions []interface{} `json:"positions,omitempty"`
	// TagKwd is stored as "tag_kwd" when non-nil.
	TagKwd []string `json:"tag_kwd,omitempty"`
	// TagFeas is stored as "tag_feas" when non-nil; its shape is not checked here.
	TagFeas interface{} `json:"tag_feas,omitempty"`
}
|
||||
|
||||
// UpdateChunk updates a chunk fields
|
||||
func (s *ChunkService) UpdateChunk(req *UpdateChunkRequest, userID string) error {
|
||||
if s.docEngine == nil {
|
||||
return fmt.Errorf("doc engine not initialized")
|
||||
}
|
||||
|
||||
if req.ChunkID == "" {
|
||||
return fmt.Errorf("chunk_id is required")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get user's tenants
|
||||
tenants, err := s.userTenantDAO.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user tenants: %w", err)
|
||||
}
|
||||
if len(tenants) == 0 {
|
||||
return fmt.Errorf("user has no accessible tenants")
|
||||
}
|
||||
|
||||
// Find the tenant that owns this dataset
|
||||
var targetTenantID string
|
||||
for _, tenant := range tenants {
|
||||
kb, err := s.kbDAO.GetByIDAndTenantID(req.DatasetID, tenant.TenantID)
|
||||
if err == nil && kb != nil {
|
||||
targetTenantID = tenant.TenantID
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetTenantID == "" {
|
||||
return fmt.Errorf("user does not have access to this dataset")
|
||||
}
|
||||
|
||||
// Verify document belongs to dataset
|
||||
docDAO := dao.NewDocumentDAO()
|
||||
doc, err := docDAO.GetByID(req.DocumentID)
|
||||
if err != nil || doc == nil {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
if doc.KbID != req.DatasetID {
|
||||
return fmt.Errorf("document does not belong to this dataset")
|
||||
}
|
||||
|
||||
// Fetch existing chunk first (like Python does)
|
||||
indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
|
||||
existingChunk, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, []string{req.DatasetID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing chunk: %w", err)
|
||||
}
|
||||
|
||||
existing, ok := existingChunk.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid chunk format")
|
||||
}
|
||||
|
||||
// Build update dict like Python does (doc.py:1476-1523)
|
||||
d := make(map[string]interface{})
|
||||
|
||||
// Content - use new value or existing
|
||||
if req.Content != nil {
|
||||
d["content_with_weight"] = *req.Content
|
||||
} else {
|
||||
if v, ok := existing["content_with_weight"].(string); ok {
|
||||
d["content_with_weight"] = v
|
||||
} else if v, ok := existing["content"].(string); ok {
|
||||
d["content_with_weight"] = v
|
||||
} else {
|
||||
d["content_with_weight"] = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Tokenize content
|
||||
contentStr := d["content_with_weight"].(string)
|
||||
d["content_ltks"], _ = tokenizer.Tokenize(contentStr)
|
||||
d["content_sm_ltks"], _ = tokenizer.FineGrainedTokenize(d["content_ltks"].(string))
|
||||
|
||||
// Important keywords - convert []string to []interface{} for transformChunkFields
|
||||
if req.ImportantKwd != nil {
|
||||
impKwd := make([]interface{}, len(req.ImportantKwd))
|
||||
for i, v := range req.ImportantKwd {
|
||||
impKwd[i] = v
|
||||
}
|
||||
d["important_kwd"] = impKwd
|
||||
}
|
||||
|
||||
// Questions
|
||||
if req.Questions != nil {
|
||||
// Filter out empty questions and trim
|
||||
filteredQuestions := []string{}
|
||||
for _, q := range req.Questions {
|
||||
q = strings.TrimSpace(q)
|
||||
if q != "" {
|
||||
filteredQuestions = append(filteredQuestions, q)
|
||||
}
|
||||
}
|
||||
d["question_kwd"] = filteredQuestions
|
||||
}
|
||||
|
||||
// Available
|
||||
if req.Available != nil {
|
||||
if *req.Available {
|
||||
d["available_int"] = 1
|
||||
} else {
|
||||
d["available_int"] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Positions
|
||||
if req.Positions != nil {
|
||||
d["position_int"] = req.Positions
|
||||
}
|
||||
|
||||
// Tag keywords
|
||||
if req.TagKwd != nil {
|
||||
d["tag_kwd"] = req.TagKwd
|
||||
}
|
||||
|
||||
// Tag features
|
||||
if req.TagFeas != nil {
|
||||
d["tag_feas"] = req.TagFeas
|
||||
}
|
||||
|
||||
// Always include id
|
||||
d["id"] = req.ChunkID
|
||||
|
||||
// Call update
|
||||
condition := map[string]interface{}{
|
||||
"id": req.ChunkID,
|
||||
}
|
||||
|
||||
err = s.docEngine.UpdateDataset(ctx, condition, d, indexName, req.DatasetID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update chunk: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -276,6 +276,29 @@ func (s *DocumentService) GetMetadataSummary(kbID string, docIDs []string) (map[
|
||||
return aggregateMetadata(searchResult.Chunks), nil
|
||||
}
|
||||
|
||||
// SetDocumentMetadata sets metadata for a document in the document engine
|
||||
func (s *DocumentService) SetDocumentMetadata(docID string, meta map[string]interface{}) error {
|
||||
// Get document to find kb_id
|
||||
doc, err := s.documentDAO.GetByID(docID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("document not found: %w", err)
|
||||
}
|
||||
|
||||
// Get tenant ID
|
||||
tenantID, err := s.metadataSvc.GetTenantIDByKBID(doc.KbID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tenant ID: %w", err)
|
||||
}
|
||||
|
||||
// Update metadata using the document engine (merges with existing)
|
||||
err = s.docEngine.UpdateMetadata(nil, docID, doc.KbID, meta, tenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update metadata: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDocumentMetadataByID get metadata for a specific document
|
||||
func (s *DocumentService) GetDocumentMetadataByID(docID string) (map[string]interface{}, error) {
|
||||
// Get document to find kb_id
|
||||
|
||||
@@ -475,6 +475,11 @@ func (s *KnowledgebaseService) Accessible(kbID, userID string) bool {
|
||||
return s.kbDAO.Accessible(kbID, userID)
|
||||
}
|
||||
|
||||
// RemoveTag removes a tag from documents in a dataset
|
||||
func (s *KnowledgebaseService) RemoveTag(condition map[string]interface{}, newValue map[string]interface{}, indexName, kbID string) error {
|
||||
return s.docEngine.UpdateDataset(context.Background(), condition, newValue, indexName, kbID)
|
||||
}
|
||||
|
||||
// GetByID retrieves a knowledge base by ID
|
||||
func (s *KnowledgebaseService) GetByID(kbID string) (*entity.Knowledgebase, error) {
|
||||
return s.kbDAO.GetByID(kbID)
|
||||
|
||||
@@ -37,6 +37,9 @@ func ExtractAccessToken(authorization, secretKey string) (string, error) {
|
||||
return "", errors.New("empty authorization")
|
||||
}
|
||||
|
||||
// Strip "Bearer " prefix if present
|
||||
token := strings.TrimPrefix(authorization, "Bearer ")
|
||||
|
||||
// Create URLSafeTimedSerializer with correct configuration
|
||||
// Matching Python itsdangerous configuration:
|
||||
// - salt: "itsdangerous"
|
||||
@@ -53,7 +56,7 @@ func ExtractAccessToken(authorization, secretKey string) (string, error) {
|
||||
)
|
||||
|
||||
// Unsign the token (verifies signature and extracts payload)
|
||||
encodedValue, err := signer.Unsign(authorization, 0)
|
||||
encodedValue, err := signer.Unsign(token, 0)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode token: %w", err)
|
||||
}
|
||||
|
||||
@@ -82,6 +82,8 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
field = "authors@ft_authors_rag_coarse"
|
||||
elif field == "authors_sm_tks":
|
||||
field = "authors@ft_authors_rag_fine"
|
||||
elif field == "tag_kwd":
|
||||
field = "tag_kwd@ft_tag_kwd_whitespace__"
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user