Files
ragflow/sdk/python/ragflow_sdk/ragflow.py

400 lines
14 KiB
Python
Raw Normal View History

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any
import requests
from .modules.agent import Agent
from .modules.chat import Chat
from .modules.chunk import Chunk
from .modules.dataset import DataSet
from .modules.memory import Memory
class RAGFlow:
    """Thin HTTP client for the RAGFlow REST API.

    Every request goes to ``<base_url>/api/<version><path>`` with an
    ``Authorization: Bearer <api_key>`` header. Responses are decoded as JSON
    envelopes carrying ``code``, ``message`` and ``data``. A ``code`` other
    than 0 is raised as an ``Exception`` carrying the server's message.
    """

    def __init__(self, api_key, base_url, version="v1"):
        """Create a client bound to one server and one API key.

        Args:
            api_key: Key sent as ``Authorization: Bearer <api_key>``.
            base_url: Server root, e.g. ``http://<host_address>``. Trailing
                slashes are ignored, so ``http://host/`` and ``http://host``
                give the same ``api_url``.
            version: API version path segment. ``api_url`` becomes
                ``<base_url>/api/<version>``, e.g. ``http://<host_address>/api/v1``.
        """
        self.user_key = api_key
        # Strip trailing "/" so "http://host/" does not yield "http://host//api/v1".
        self.api_url = f"{str(base_url).rstrip('/')}/api/{version}"
        self.authorization_header = {"Authorization": f"Bearer {self.user_key}"}

    # Low-level HTTP helpers. Each one prefixes ``path`` with ``self.api_url``,
    # attaches the auth header and returns the raw ``requests.Response``
    # without checking the HTTP status; callers decode the JSON envelope.

    def post(self, path, json=None, stream=False, files=None):
        """Send ``POST api_url + path`` and return the raw response."""
        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
        return res

    def get(self, path, params=None, json=None):
        """Send ``GET api_url + path`` (query ``params``, optional JSON body) and return the raw response."""
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
        return res

    def delete(self, path, json):
        """Send ``DELETE api_url + path`` with a JSON body and return the raw response."""
        res = requests.delete(url=self.api_url + path, json=json, headers=self.authorization_header)
        return res

    def put(self, path, json):
        """Send ``PUT api_url + path`` with a JSON body and return the raw response."""
        res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header)
        return res

    def patch(self, path, json):
        """Send ``PATCH api_url + path`` with a JSON body and return the raw response."""
        res = requests.patch(url=self.api_url + path, json=json, headers=self.authorization_header)
        return res

    # ------------------------------------------------------------------ datasets

    def create_dataset(
        self,
        name: str,
        avatar: Optional[str] = None,
        description: Optional[str] = None,
        embedding_model: Optional[str] = None,
        permission: str = "me",
        chunk_method: str = "naive",
        parser_config: Optional[DataSet.ParserConfig] = None,
        auto_metadata_config: Optional[dict[str, Any]] = None,
    ) -> DataSet:
        """Create a dataset via ``POST /datasets``.

        ``parser_config`` (serialized with ``to_json()``) and
        ``auto_metadata_config`` are only sent when not ``None``.

        Returns:
            A ``DataSet`` wrapping the response's ``data``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        payload = {
            "name": name,
            "avatar": avatar,
            "description": description,
            "embedding_model": embedding_model,
            "permission": permission,
            "chunk_method": chunk_method,
        }
        if parser_config is not None:
            payload["parser_config"] = parser_config.to_json()
        if auto_metadata_config is not None:
            payload["auto_metadata_config"] = auto_metadata_config
        res = self.post("/datasets", payload).json()
        if res.get("code") == 0:
            return DataSet(self, res["data"])
        raise Exception(res["message"])

    def delete_datasets(self, ids: list[str] | None = None, delete_all: bool = False):
        """Send ``DELETE /datasets`` with ``{"ids": ids, "delete_all": delete_all}``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.delete("/datasets", {"ids": ids, "delete_all": delete_all}).json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    def get_dataset(self, name: str):
        """Return the first dataset that ``list_datasets(name=name)`` yields.

        Raises:
            Exception: if the listing is empty ("Dataset <name> not found") or
                the listing request fails.
        """
        matches = self.list_datasets(name=name)
        if matches:
            return matches[0]
        raise Exception(f"Dataset {name} not found")

    def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]:
        """List datasets via ``GET /datasets`` with paging/sorting/filter params.

        ``requests`` omits query params whose value is ``None``, so unset
        ``id``/``name`` filters are not sent.

        Returns:
            One ``DataSet`` per item of the response's ``data`` list.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.get(
            "/datasets",
            {
                "page": page,
                "page_size": page_size,
                "orderby": orderby,
                "desc": desc,
                "id": id,
                "name": name,
            },
        ).json()
        if res.get("code") == 0:
            return [DataSet(self, item) for item in res["data"]]
        raise Exception(res["message"])

    # --------------------------------------------------------------------- chats

    def create_chat(
        self,
        name: str,
        icon: str = "",
        dataset_ids: list[str] | None = None,
        llm_id: str | None = None,
        llm_setting: dict | None = None,
        prompt_config: dict | None = None,
        **kwargs,
    ) -> Chat:
        """Create a chat via ``POST /chats``.

        ``llm_id``, ``llm_setting`` and ``prompt_config`` are only sent when
        not ``None``. Extra keyword arguments are merged into the payload last,
        so they can override any of the named fields.

        Returns:
            A ``Chat`` wrapping the response's ``data``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        payload = {"name": name, "icon": icon, "dataset_ids": dataset_ids or []}
        if llm_id is not None:
            payload["llm_id"] = llm_id
        if llm_setting is not None:
            payload["llm_setting"] = llm_setting
        if prompt_config is not None:
            payload["prompt_config"] = prompt_config
        payload.update(kwargs)
        res = self.post("/chats", payload).json()
        if res.get("code") == 0:
            return Chat(self, res["data"])
        raise Exception(res["message"])

    def delete_chats(self, ids: list[str] | None = None, delete_all: bool = False):
        """Send ``DELETE /chats`` with ``{"ids": ids, "delete_all": delete_all}``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.delete("/chats", {"ids": ids, "delete_all": delete_all}).json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    def get_chat(self, chat_id: str) -> Chat:
        """Fetch one chat via ``GET /chats/<chat_id>``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.get(f"/chats/{chat_id}").json()
        if res.get("code") == 0:
            return Chat(self, res["data"])
        raise Exception(res["message"])

    def list_chats(
        self,
        page: int = 1,
        page_size: int = 30,
        orderby: str = "create_time",
        desc: bool = True,
        id: str | None = None,
        name: str | None = None,
        keywords: str | None = None,
        owner_ids: str | list[str] | None = None,
    ) -> list[Chat]:
        """List chats via ``GET /chats``; items are read from ``data["chats"]``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.get(
            "/chats",
            {
                "page": page,
                "page_size": page_size,
                "orderby": orderby,
                "desc": desc,
                "id": id,
                "name": name,
                "keywords": keywords,
                "owner_ids": owner_ids,
            },
        ).json()
        if res.get("code") == 0:
            return [Chat(self, item) for item in res["data"]["chats"]]
        raise Exception(res["message"])

    # ----------------------------------------------------------------- retrieval

    def retrieve(
        self,
        dataset_ids: list[str],
        document_ids: list[str] | None = None,
        question: str = "",
        page: int = 1,
        page_size: int = 30,
        similarity_threshold: float = 0.2,
        vector_similarity_weight: float = 0.3,
        top_k: int = 1024,
        rerank_id: str | None = None,
        keyword: bool = False,
        cross_languages: list[str] | None = None,
        metadata_condition: dict | None = None,
        use_kg: bool = False,
        toc_enhance: bool = False,
    ) -> list[Chunk]:
        """Run a retrieval query via ``POST /retrieval``.

        All arguments are forwarded in the JSON body; ``document_ids``
        defaults to an empty list.

        Returns:
            One ``Chunk`` per entry of ``data["chunks"]``. The result is an
            empty list when ``data`` or ``chunks`` is missing or null.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        if document_ids is None:
            document_ids = []
        data_json = {
            "page": page,
            "page_size": page_size,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "rerank_id": rerank_id,
            "keyword": keyword,
            "question": question,
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
            "cross_languages": cross_languages,
            "metadata_condition": metadata_condition,
            "use_kg": use_kg,
            "toc_enhance": toc_enhance,
        }
        res = self.post("/retrieval", json=data_json).json()
        if res.get("code") == 0:
            # Guard against a null/absent "data" or "chunks": iterating None
            # would raise TypeError instead of returning no results.
            chunk_list = (res.get("data") or {}).get("chunks") or []
            return [Chunk(self, chunk_data) for chunk_data in chunk_list]
        raise Exception(res.get("message"))

    # -------------------------------------------------------------------- agents

    def list_agents(self, page: int = 1, page_size: int = 30, orderby: str = "update_time", desc: bool = True) -> list[Agent]:
        """List agents via ``GET /agents``; items are read from ``data["canvas"]``.

        Returns:
            One ``Agent`` per listed item, or an empty list when ``data`` or
            ``canvas`` is missing or null.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.get(
            "/agents",
            {
                "page": page,
                "page_size": page_size,
                "orderby": orderby,
                "desc": desc,
            },
        ).json()
        if res.get("code") == 0:
            envelope = res.get("data") or {}
            return [Agent(self, item) for item in envelope.get("canvas") or []]
        raise Exception(res["message"])

    def get_agent(self, agent_id: str) -> Agent:
        """Fetch one agent via ``GET /agents/<agent_id>``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.get(f"/agents/{agent_id}").json()
        if res.get("code") == 0:
            return Agent(self, res["data"])
        raise Exception(res["message"])

    def create_agent(
        self,
        title: str,
        dsl: dict,
        description: str | None = None,
        canvas_type: str | None = None,
    ) -> None:
        """Create an agent via ``POST /agents``.

        ``description`` and ``canvas_type`` are only sent when not ``None``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        req = {"title": title, "dsl": dsl}
        if description is not None:
            req["description"] = description
        if canvas_type is not None:
            req["canvas_type"] = canvas_type
        res = self.post("/agents", req).json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    def update_agent(
        self,
        agent_id: str,
        title: str | None = None,
        description: str | None = None,
        dsl: dict | None = None,
        canvas_type: str | None = None,
    ) -> None:
        """Update an agent, sending only the fields that are not ``None``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        req = {}
        if title is not None:
            req["title"] = title
        if description is not None:
            req["description"] = description
        if dsl is not None:
            req["dsl"] = dsl
        if canvas_type is not None:
            req["canvas_type"] = canvas_type
        # NOTE(review): the "Feat: Agent api (#14157)" changelog lists
        # PATCH /agents/<agent_id> as the current update route and
        # PUT /api/v1/agents/<agent_id> as the previous one. This still sends
        # PUT; confirm the server keeps accepting it, or switch to self.patch.
        res = self.put(f"/agents/{agent_id}", req).json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    def delete_agent(self, agent_id: str) -> None:
        """Delete an agent via ``DELETE /agents/<agent_id>`` (empty JSON body).

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.delete(f"/agents/{agent_id}", {}).json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    # ------------------------------------------------------------------ memories

    def create_memory(self, name: str, memory_type: list[str], embd_id: str, llm_id: str):
        """Create a memory via ``POST /memories``.

        Returns:
            A ``Memory`` wrapping the response's ``data``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        payload = {"name": name, "memory_type": memory_type, "embd_id": embd_id, "llm_id": llm_id}
        res = self.post("/memories", payload).json()
        if res.get("code") != 0:
            raise Exception(res["message"])
        return Memory(self, res["data"])

    def list_memory(self, page: int = 1, page_size: int = 50, tenant_id: str | list[str] | None = None, memory_type: str | list[str] | None = None, storage_type: str | None = None, keywords: str | None = None) -> dict:
        """List memories via ``GET /memories``.

        Returns:
            A dict with ``code``, ``message``, ``memory_list`` (one ``Memory``
            per item of ``data["memory_list"]``) and ``total_count``.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.get(
            "/memories",
            {
                "page": page,
                "page_size": page_size,
                "tenant_id": tenant_id,
                "memory_type": memory_type,
                "storage_type": storage_type,
                "keywords": keywords,
            },
        ).json()
        if res.get("code") != 0:
            raise Exception(res["message"])
        return {
            "code": res.get("code", 0),
            "message": res.get("message"),
            "memory_list": [Memory(self, item) for item in res["data"]["memory_list"]],
            "total_count": res["data"]["total_count"],
        }

    def delete_memory(self, memory_id: str):
        """Delete a memory via ``DELETE /memories/<memory_id>`` (empty JSON body).

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        res = self.delete(f"/memories/{memory_id}", {}).json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    def add_message(self, memory_id: list[str], agent_id: str, session_id: str, user_input: str, agent_response: str, user_id: str = "") -> str:
        """Append a user-input / agent-response pair to memories via ``POST /messages``.

        Args:
            memory_id: Ids of the memories to write to.
            agent_id: Agent the exchange belongs to.
            session_id: Session the exchange belongs to.
            user_input: The user's message.
            agent_response: The agent's reply.
            user_id: External subject id. It is always included in the request
                body, but the server only honors it for API-key
                authentication. For JWT-authenticated callers, the server
                stores the authenticated user's id instead.

        Returns:
            The ``message`` field of the successful response.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        payload = {
            "memory_id": memory_id,
            "agent_id": agent_id,
            "session_id": session_id,
            "user_input": user_input,
            "agent_response": agent_response,
            "user_id": user_id,
        }
        res = self.post("/messages", payload).json()
        if res.get("code") != 0:
            raise Exception(res["message"])
        return res["message"]

    def search_message(self, query: str, memory_id: list[str], agent_id: str | None = None, session_id: str | None = None, user_id: str | None = None, similarity_threshold: float = 0.2, keywords_similarity_weight: float = 0.7, top_n: int = 10) -> list[dict]:
        """Search stored messages via ``GET /messages/search``.

        Returns:
            The response's ``data`` unchanged.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        params = {
            "query": query,
            "memory_id": memory_id,
            "agent_id": agent_id,
            "session_id": session_id,
            "user_id": user_id,
            "similarity_threshold": similarity_threshold,
            "keywords_similarity_weight": keywords_similarity_weight,
            "top_n": top_n,
        }
        res = self.get("/messages/search", params).json()
        if res.get("code") != 0:
            raise Exception(res["message"])
        return res["data"]

    def get_recent_messages(self, memory_id: list[str], agent_id: str | None = None, session_id: str | None = None, limit: int = 10) -> list[dict]:
        """Fetch messages via ``GET /messages`` (at most ``limit`` is requested).

        Returns:
            The response's ``data`` unchanged.

        Raises:
            Exception: if the response ``code`` is not 0.
        """
        params = {
            "memory_id": memory_id,
            "agent_id": agent_id,
            "session_id": session_id,
            "limit": limit,
        }
        res = self.get("/messages", params).json()
        if res.get("code") != 0:
            raise Exception(res["message"])
        return res["data"]