mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 08:15:44 +08:00
Fix: inconsistent state handling for multi-user single-canvas access (#13267)
### What problem does this PR solve? <img width="700" alt="image" src="https://github.com/user-attachments/assets/1db7412e-4554-44bc-84ba-16421949aacc" /> ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
@@ -40,13 +40,13 @@ from api.utils.api_utils import (
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
from api.apps.services.canvas_replica_service import CanvasReplicaService
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
|
||||
|
||||
@@ -75,9 +75,10 @@ async def rm():
|
||||
@login_required
|
||||
async def save():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
try:
|
||||
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
|
||||
except ValueError as e:
|
||||
return get_data_error_result(message=str(e))
|
||||
cate = req.get("canvas_category", CanvasCategory.Agent)
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
@@ -93,8 +94,21 @@ async def save():
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.update_by_id(req["id"], req)
|
||||
# save version
|
||||
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||
UserCanvasVersionService.delete_all_versions(req["id"])
|
||||
UserCanvasVersionService.save_or_replace_latest(
|
||||
user_canvas_id=req["id"],
|
||||
dsl=req["dsl"],
|
||||
title=UserCanvasVersionService.build_version_title(getattr(current_user, "nickname", current_user.id), req.get("title")),
|
||||
)
|
||||
replica_ok = CanvasReplicaService.replace_for_set(
|
||||
canvas_id=req["id"],
|
||||
tenant_id=str(current_user.id),
|
||||
runtime_user_id=str(current_user.id),
|
||||
dsl=req["dsl"],
|
||||
canvas_category=req.get("canvas_category", cate),
|
||||
title=req.get("title", ""),
|
||||
)
|
||||
if not replica_ok:
|
||||
return get_data_error_result(message="canvas saved, but replica sync failed.")
|
||||
return get_json_result(data=req)
|
||||
|
||||
|
||||
@@ -104,6 +118,20 @@ def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
try:
|
||||
# DELETE
|
||||
CanvasReplicaService.bootstrap(
|
||||
canvas_id=canvas_id,
|
||||
tenant_id=str(current_user.id),
|
||||
runtime_user_id=str(current_user.id),
|
||||
dsl=c.get("dsl"),
|
||||
canvas_category=c.get("canvas_category", CanvasCategory.Agent),
|
||||
title=c.get("title", ""),
|
||||
)
|
||||
except ValueError as e:
|
||||
return get_data_error_result(message=str(e))
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@@ -137,29 +165,38 @@ async def run():
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
tenant_id = str(current_user.id)
|
||||
runtime_user_id = req.get("user_id") or tenant_id
|
||||
user_id = str(runtime_user_id)
|
||||
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
replica_payload = CanvasReplicaService.load_for_run(
|
||||
canvas_id=req["id"],
|
||||
tenant_id=tenant_id,
|
||||
runtime_user_id=user_id,
|
||||
)
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
if not replica_payload:
|
||||
return get_data_error_result(message="canvas replica not found, please call /get/<canvas_id> first.")
|
||||
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
replica_dsl = replica_payload.get("dsl", {})
|
||||
canvas_title = replica_payload.get("title", "")
|
||||
canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent)
|
||||
dsl_str = json.dumps(replica_dsl, ensure_ascii=False)
|
||||
|
||||
if canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
Pipeline(dsl_str, tenant_id=tenant_id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, current_user.id, canvas_id=cvs.id)
|
||||
canvas = Canvas(dsl_str, tenant_id, canvas_id=req["id"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@@ -169,8 +206,21 @@ async def run():
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
commit_ok = CanvasReplicaService.commit_after_run(
|
||||
canvas_id=req["id"],
|
||||
tenant_id=tenant_id,
|
||||
runtime_user_id=user_id,
|
||||
dsl=json.loads(str(canvas)),
|
||||
canvas_category=canvas_category,
|
||||
title=canvas_title,
|
||||
)
|
||||
if not commit_ok:
|
||||
logging.error(
|
||||
"Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
|
||||
req["id"],
|
||||
tenant_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@@ -27,9 +27,11 @@ from typing import Any, cast
|
||||
import jwt
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.apps.services.canvas_replica_service import CanvasReplicaService
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.user_service import UserService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
@@ -39,6 +41,13 @@ from quart import request, Response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
def _get_user_nickname(user_id: str) -> str:
|
||||
exists, user = UserService.get_by_id(user_id)
|
||||
if not exists:
|
||||
return user_id
|
||||
return str(getattr(user, "nickname", "") or user_id)
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_agents(tenant_id):
|
||||
@@ -66,10 +75,10 @@ async def create_agent(tenant_id: str):
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
try:
|
||||
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
|
||||
except ValueError as e:
|
||||
return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
|
||||
else:
|
||||
return get_json_result(data=False, message="No DSL data in request.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
@@ -87,9 +96,10 @@ async def create_agent(tenant_id: str):
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to create agent.")
|
||||
|
||||
UserCanvasVersionService.insert(
|
||||
owner_nickname = _get_user_nickname(tenant_id)
|
||||
UserCanvasVersionService.save_or_replace_latest(
|
||||
user_canvas_id=agent_id,
|
||||
title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")),
|
||||
title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")),
|
||||
dsl=req["dsl"]
|
||||
)
|
||||
|
||||
@@ -103,10 +113,10 @@ async def update_agent(tenant_id: str, agent_id: str):
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
try:
|
||||
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
|
||||
except ValueError as e:
|
||||
return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if req.get("title") is not None:
|
||||
req["title"] = req["title"].strip()
|
||||
@@ -116,17 +126,19 @@ async def update_agent(tenant_id: str, agent_id: str):
|
||||
data=False, message="Only owner of canvas authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
_, current_agent = UserCanvasService.get_by_id(agent_id)
|
||||
agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
|
||||
owner_nickname = _get_user_nickname(tenant_id)
|
||||
|
||||
UserCanvasService.update_by_id(agent_id, req)
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
UserCanvasVersionService.insert(
|
||||
UserCanvasVersionService.save_or_replace_latest(
|
||||
user_canvas_id=agent_id,
|
||||
title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")),
|
||||
title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version),
|
||||
dsl=req["dsl"]
|
||||
)
|
||||
|
||||
UserCanvasVersionService.delete_all_versions(agent_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
|
||||
258
api/apps/services/canvas_replica_service.py
Normal file
258
api/apps/services/canvas_replica_service.py
Normal file
@@ -0,0 +1,258 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
from api.db import CanvasCategory
|
||||
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
||||
|
||||
|
||||
class CanvasReplicaService:
|
||||
"""
|
||||
Manage per-user canvas runtime replicas stored in Redis.
|
||||
|
||||
Lifecycle:
|
||||
- bootstrap: initialize/refresh replica from DB DSL
|
||||
- load_for_run: read replica before run
|
||||
- commit_after_run: atomically persist run result back to replica
|
||||
"""
|
||||
|
||||
TTL_SECS = 3 * 60 * 60
|
||||
REPLICA_KEY_PREFIX = "canvas:replica"
|
||||
LOCK_KEY_PREFIX = "canvas:replica:lock"
|
||||
LOCK_TIMEOUT_SECS = 10
|
||||
LOCK_BLOCKING_TIMEOUT_SECS = 1
|
||||
LOCK_RETRY_ATTEMPTS = 3
|
||||
LOCK_RETRY_SLEEP_SECS = 0.2
|
||||
|
||||
|
||||
@classmethod
|
||||
def normalize_dsl(cls, dsl):
|
||||
"""Normalize DSL to a JSON-serializable dict. Raise ValueError on invalid input."""
|
||||
normalized = dsl
|
||||
if isinstance(normalized, str):
|
||||
try:
|
||||
normalized = json.loads(normalized)
|
||||
except Exception as e:
|
||||
raise ValueError("Invalid DSL JSON string.") from e
|
||||
|
||||
if not isinstance(normalized, dict):
|
||||
raise ValueError("DSL must be a JSON object.")
|
||||
|
||||
try:
|
||||
return json.loads(json.dumps(normalized, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
raise ValueError("DSL is not JSON-serializable.") from e
|
||||
|
||||
|
||||
@classmethod
|
||||
def _replica_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str:
|
||||
return f"{cls.REPLICA_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}"
|
||||
|
||||
|
||||
@classmethod
|
||||
def _lock_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str:
|
||||
return f"{cls.LOCK_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}"
|
||||
|
||||
|
||||
@classmethod
|
||||
def _read_payload(cls, replica_key: str):
|
||||
"""Read replica payload from Redis; return None on missing/invalid content."""
|
||||
cache_blob = REDIS_CONN.get(replica_key)
|
||||
if not cache_blob:
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(cache_blob)
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
payload["dsl"] = cls.normalize_dsl(payload.get("dsl", {}))
|
||||
return payload
|
||||
except Exception as e:
|
||||
logging.warning("Failed to parse canvas replica %s: %s", replica_key, e)
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def _write_payload(cls, replica_key: str, payload: dict):
|
||||
"""Write payload and refresh TTL."""
|
||||
payload["updated_at"] = int(time.time())
|
||||
REDIS_CONN.set_obj(replica_key, payload, cls.TTL_SECS)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _build_payload(
|
||||
cls,
|
||||
canvas_id: str,
|
||||
tenant_id: str,
|
||||
runtime_user_id: str,
|
||||
dsl,
|
||||
canvas_category=CanvasCategory.Agent,
|
||||
title="",
|
||||
):
|
||||
return {
|
||||
"canvas_id": canvas_id,
|
||||
"tenant_id": str(tenant_id),
|
||||
"runtime_user_id": str(runtime_user_id),
|
||||
"title": title or "",
|
||||
"canvas_category": canvas_category or CanvasCategory.Agent,
|
||||
"dsl": cls.normalize_dsl(dsl),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_if_absent(
|
||||
cls,
|
||||
canvas_id: str,
|
||||
tenant_id: str,
|
||||
runtime_user_id: str,
|
||||
dsl,
|
||||
canvas_category=CanvasCategory.Agent,
|
||||
title="",
|
||||
):
|
||||
"""Create a runtime replica if it does not exist; otherwise keep existing state."""
|
||||
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
|
||||
payload = cls._read_payload(replica_key)
|
||||
if payload:
|
||||
return payload
|
||||
payload = cls._build_payload(canvas_id, str(tenant_id), str(runtime_user_id), dsl, canvas_category, title)
|
||||
cls._write_payload(replica_key, payload)
|
||||
return payload
|
||||
|
||||
|
||||
@classmethod
|
||||
def bootstrap(
|
||||
cls,
|
||||
canvas_id: str,
|
||||
tenant_id: str,
|
||||
runtime_user_id: str,
|
||||
dsl,
|
||||
canvas_category=CanvasCategory.Agent,
|
||||
title="",
|
||||
):
|
||||
"""Bootstrap replica by creating it when absent and keeping existing runtime state."""
|
||||
return cls.create_if_absent(
|
||||
canvas_id=canvas_id,
|
||||
tenant_id=tenant_id,
|
||||
runtime_user_id=runtime_user_id,
|
||||
dsl=dsl,
|
||||
canvas_category=canvas_category,
|
||||
title=title,
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def load_for_run(cls, canvas_id: str, tenant_id: str, runtime_user_id: str):
|
||||
"""Load current runtime replica used by /completion."""
|
||||
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
|
||||
return cls._read_payload(replica_key)
|
||||
|
||||
|
||||
@classmethod
|
||||
def replace_for_set(
|
||||
cls,
|
||||
canvas_id: str,
|
||||
tenant_id: str,
|
||||
runtime_user_id: str,
|
||||
dsl,
|
||||
canvas_category=CanvasCategory.Agent,
|
||||
title="",
|
||||
):
|
||||
"""Replace replica content for `/set` under lock."""
|
||||
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
|
||||
lock_key = cls._lock_key(canvas_id, str(tenant_id), str(runtime_user_id))
|
||||
lock = cls._acquire_lock_with_retry(lock_key)
|
||||
if not lock:
|
||||
logging.error("Failed to acquire canvas replica lock after retry: %s", lock_key)
|
||||
return False
|
||||
|
||||
try:
|
||||
updated_payload = cls._build_payload(
|
||||
canvas_id=canvas_id,
|
||||
tenant_id=str(tenant_id),
|
||||
runtime_user_id=str(runtime_user_id),
|
||||
dsl=dsl,
|
||||
canvas_category=canvas_category,
|
||||
title=title,
|
||||
)
|
||||
cls._write_payload(replica_key, updated_payload)
|
||||
return True
|
||||
except Exception:
|
||||
logging.exception("Failed to replace canvas replica from /set.")
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
lock.release()
|
||||
except Exception:
|
||||
logging.exception("Failed to release canvas replica lock: %s", lock_key)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _acquire_lock_with_retry(cls, lock_key: str):
|
||||
"""Acquire distributed lock with bounded retries; return lock object or None."""
|
||||
lock = RedisDistributedLock(
|
||||
lock_key,
|
||||
timeout=cls.LOCK_TIMEOUT_SECS,
|
||||
blocking_timeout=cls.LOCK_BLOCKING_TIMEOUT_SECS,
|
||||
)
|
||||
for idx in range(cls.LOCK_RETRY_ATTEMPTS):
|
||||
if lock.acquire():
|
||||
return lock
|
||||
if idx < cls.LOCK_RETRY_ATTEMPTS - 1:
|
||||
time.sleep(cls.LOCK_RETRY_SLEEP_SECS + random.uniform(0, 0.1))
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def commit_after_run(
|
||||
cls,
|
||||
canvas_id: str,
|
||||
tenant_id: str,
|
||||
runtime_user_id: str,
|
||||
dsl,
|
||||
canvas_category=CanvasCategory.Agent,
|
||||
title="",
|
||||
):
|
||||
"""
|
||||
Commit post-run DSL into replica.
|
||||
|
||||
Returns:
|
||||
bool: True on committed/saved, False on commit failure.
|
||||
"""
|
||||
new_dsl = cls.normalize_dsl(dsl)
|
||||
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
|
||||
|
||||
try:
|
||||
latest_payload = cls._read_payload(replica_key)
|
||||
|
||||
# Always write latest runtime DSL back to Redis first.
|
||||
updated_payload = cls._build_payload(
|
||||
canvas_id=canvas_id,
|
||||
tenant_id=str(tenant_id),
|
||||
runtime_user_id=str(runtime_user_id),
|
||||
dsl=new_dsl,
|
||||
canvas_category=canvas_category if not latest_payload else (canvas_category or latest_payload.get("canvas_category", CanvasCategory.Agent)),
|
||||
title=title if not latest_payload else (title or latest_payload.get("title", "")),
|
||||
)
|
||||
cls._write_payload(replica_key, updated_payload)
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
logging.exception("Failed to commit canvas runtime replica.")
|
||||
return False
|
||||
@@ -1,3 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from api.db.db_models import UserCanvasVersion, DB
|
||||
from api.db.services.common_service import CommonService
|
||||
from peewee import DoesNotExist
|
||||
@@ -6,6 +10,30 @@ from peewee import DoesNotExist
|
||||
class UserCanvasVersionService(CommonService):
|
||||
model = UserCanvasVersion
|
||||
|
||||
@staticmethod
|
||||
def build_version_title(user_nickname, agent_title, ts=None):
|
||||
tenant = str(user_nickname or "").strip() or "tenant"
|
||||
title = str(agent_title or "").strip() or "agent"
|
||||
stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ts)) if ts is not None else time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return "{0}_{1}_{2}".format(tenant, title, stamp)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_dsl(dsl):
|
||||
normalized = dsl
|
||||
if isinstance(normalized, str):
|
||||
try:
|
||||
normalized = json.loads(normalized)
|
||||
except Exception as e:
|
||||
raise ValueError("Invalid DSL JSON string.") from e
|
||||
|
||||
if not isinstance(normalized, dict):
|
||||
raise ValueError("DSL must be a JSON object.")
|
||||
|
||||
try:
|
||||
return json.loads(json.dumps(normalized, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
raise ValueError("DSL is not JSON-serializable.") from e
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def list_by_canvas_id(cls, user_canvas_id):
|
||||
@@ -59,3 +87,43 @@ class UserCanvasVersionService(CommonService):
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save_or_replace_latest(cls, user_canvas_id, dsl, title=None, description=None):
|
||||
"""
|
||||
Persist a canvas snapshot into version history.
|
||||
|
||||
If the latest version has the same DSL content, update that version in place
|
||||
instead of creating a new row.
|
||||
"""
|
||||
try:
|
||||
normalized_dsl = cls._normalize_dsl(dsl)
|
||||
latest = (
|
||||
cls.model.select()
|
||||
.where(cls.model.user_canvas_id == user_canvas_id)
|
||||
.order_by(cls.model.create_time.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if latest and cls._normalize_dsl(latest.dsl) == normalized_dsl:
|
||||
update_data = {"dsl": normalized_dsl}
|
||||
if title is not None:
|
||||
update_data["title"] = title
|
||||
if description is not None:
|
||||
update_data["description"] = description
|
||||
cls.update_by_id(latest.id, update_data)
|
||||
cls.delete_all_versions(user_canvas_id)
|
||||
return latest.id, False
|
||||
|
||||
insert_data = {"user_canvas_id": user_canvas_id, "dsl": normalized_dsl}
|
||||
if title is not None:
|
||||
insert_data["title"] = title
|
||||
if description is not None:
|
||||
insert_data["description"] = description
|
||||
cls.insert(**insert_data)
|
||||
cls.delete_all_versions(user_canvas_id)
|
||||
return None, True
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return None, None
|
||||
|
||||
@@ -265,6 +265,14 @@ def _load_agents_app(monkeypatch):
|
||||
def delete_all_versions(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def save_or_replace_latest(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def build_version_title(*_args, **_kwargs):
|
||||
return "stub_version_title"
|
||||
|
||||
canvas_version_mod.UserCanvasVersionService = _StubUserCanvasVersionService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", canvas_version_mod)
|
||||
services_pkg.user_canvas_version = canvas_version_mod
|
||||
@@ -280,6 +288,67 @@ def _load_agents_app(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
services_pkg.tenant_llm_service = tenant_llm_service_mod
|
||||
|
||||
user_service_mod = ModuleType("api.db.services.user_service")
|
||||
|
||||
class _StubUserService:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(_id):
|
||||
return False, None
|
||||
|
||||
user_service_mod.UserService = _StubUserService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
|
||||
services_pkg.user_service = user_service_mod
|
||||
services_pkg.UserService = _StubUserService
|
||||
|
||||
# Stub api.apps package to prevent api/apps/__init__.py from executing
|
||||
# (it triggers heavy imports like quart, settings, DB connections).
|
||||
api_apps_pkg = ModuleType("api.apps")
|
||||
api_apps_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.apps", api_apps_pkg)
|
||||
|
||||
api_apps_services_pkg = ModuleType("api.apps.services")
|
||||
api_apps_services_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services", api_apps_services_pkg)
|
||||
api_apps_pkg.services = api_apps_services_pkg
|
||||
|
||||
canvas_replica_mod = ModuleType("api.apps.services.canvas_replica_service")
|
||||
|
||||
class _StubCanvasReplicaService:
|
||||
@classmethod
|
||||
def normalize_dsl(cls, dsl):
|
||||
import json
|
||||
if isinstance(dsl, str):
|
||||
return json.loads(dsl)
|
||||
return dsl
|
||||
|
||||
@classmethod
|
||||
def bootstrap(cls, *_args, **_kwargs):
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def load_for_run(cls, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def commit_after_run(cls, *_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def replace_for_set(cls, *_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def create_if_absent(cls, *_args, **_kwargs):
|
||||
return {}
|
||||
|
||||
canvas_replica_mod.CanvasReplicaService = _StubCanvasReplicaService
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services.canvas_replica_service", canvas_replica_mod)
|
||||
api_apps_services_pkg.canvas_replica_service = canvas_replica_mod
|
||||
|
||||
redis_obj = _StubRedisConn()
|
||||
redis_mod = ModuleType("rag.utils.redis_conn")
|
||||
redis_mod.REDIS_CONN = redis_obj
|
||||
@@ -368,7 +437,7 @@ def test_agents_crud_unit_branches(monkeypatch):
|
||||
res = _run(module.update_agent.__wrapped__("tenant-1", "agent-1"))
|
||||
assert res["code"] == module.RetCode.OPERATING_ERROR
|
||||
|
||||
calls = {"update": 0, "insert": 0, "delete_versions": 0}
|
||||
calls = {"update": 0, "save_or_replace_latest": 0}
|
||||
monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
module.UserCanvasService,
|
||||
@@ -377,17 +446,12 @@ def test_agents_crud_unit_branches(monkeypatch):
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.UserCanvasVersionService,
|
||||
"insert",
|
||||
lambda **_kwargs: calls.__setitem__("insert", calls["insert"] + 1),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.UserCanvasVersionService,
|
||||
"delete_all_versions",
|
||||
lambda *_args, **_kwargs: calls.__setitem__("delete_versions", calls["delete_versions"] + 1),
|
||||
"save_or_replace_latest",
|
||||
lambda *_args, **_kwargs: calls.__setitem__("save_or_replace_latest", calls["save_or_replace_latest"] + 1),
|
||||
)
|
||||
res = _run(module.update_agent.__wrapped__("tenant-1", "agent-1"))
|
||||
assert res["code"] == module.RetCode.SUCCESS
|
||||
assert calls == {"update": 1, "insert": 1, "delete_versions": 1}
|
||||
assert calls == {"update": 1, "save_or_replace_latest": 1}
|
||||
|
||||
monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: False)
|
||||
res = module.delete_agent.__wrapped__("tenant-1", "agent-1")
|
||||
|
||||
@@ -184,10 +184,50 @@ def _load_canvas_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
|
||||
apps_mod = ModuleType("api.apps")
|
||||
apps_mod.__path__ = []
|
||||
apps_mod.current_user = SimpleNamespace(id="user-1")
|
||||
apps_mod.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
|
||||
apps_services_pkg = ModuleType("api.apps.services")
|
||||
apps_services_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services", apps_services_pkg)
|
||||
apps_mod.services = apps_services_pkg
|
||||
|
||||
canvas_replica_mod = ModuleType("api.apps.services.canvas_replica_service")
|
||||
|
||||
class _StubCanvasReplicaService:
|
||||
@classmethod
|
||||
def normalize_dsl(cls, dsl):
|
||||
import json
|
||||
if isinstance(dsl, str):
|
||||
return json.loads(dsl)
|
||||
return dsl
|
||||
|
||||
@classmethod
|
||||
def bootstrap(cls, *_args, **_kwargs):
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def load_for_run(cls, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def commit_after_run(cls, *_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def replace_for_set(cls, *_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def create_if_absent(cls, *_args, **_kwargs):
|
||||
return {}
|
||||
|
||||
canvas_replica_mod.CanvasReplicaService = _StubCanvasReplicaService
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services.canvas_replica_service", canvas_replica_mod)
|
||||
apps_services_pkg.canvas_replica_service = canvas_replica_mod
|
||||
|
||||
db_pkg = ModuleType("api.db")
|
||||
db_pkg.CanvasCategory = _DummyCanvasCategory
|
||||
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
|
||||
@@ -310,6 +350,8 @@ def _load_canvas_module(monkeypatch):
|
||||
delete_all_versions=lambda *_args, **_kwargs: True,
|
||||
list_by_canvas_id=lambda *_args, **_kwargs: [],
|
||||
get_by_id=lambda *_args, **_kwargs: (True, None),
|
||||
save_or_replace_latest=lambda *_args, **_kwargs: True,
|
||||
build_version_title=lambda *_args, **_kwargs: "stub_version_title",
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", canvas_version_mod)
|
||||
|
||||
@@ -492,18 +534,12 @@ def test_templates_rm_save_get_matrix_unit(monkeypatch):
|
||||
monkeypatch.setattr(module, "get_uuid", lambda: "canvas-new")
|
||||
monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: [])
|
||||
monkeypatch.setattr(module.UserCanvasService, "save", lambda **kwargs: created["save"].append(kwargs) or True)
|
||||
monkeypatch.setattr(module.UserCanvasVersionService, "insert", lambda **kwargs: created["versions"].append(("insert", kwargs)))
|
||||
monkeypatch.setattr(
|
||||
module.UserCanvasVersionService,
|
||||
"delete_all_versions",
|
||||
lambda canvas_id: created["versions"].append(("delete", canvas_id)),
|
||||
)
|
||||
monkeypatch.setattr(module.UserCanvasVersionService, "save_or_replace_latest", lambda *_args, **kwargs: created["versions"].append(("save_or_replace_latest", kwargs)))
|
||||
res = _run(inspect.unwrap(module.save)())
|
||||
assert res["code"] == module.RetCode.SUCCESS
|
||||
assert res["data"]["id"] == "canvas-new"
|
||||
assert created["save"]
|
||||
assert any(item[0] == "insert" for item in created["versions"])
|
||||
assert any(item[0] == "delete" for item in created["versions"])
|
||||
assert any(item[0] == "save_or_replace_latest" for item in created["versions"])
|
||||
|
||||
_set_request_json(monkeypatch, module, {"id": "canvas-1", "title": "Renamed", "dsl": "{\"m\": 1}"})
|
||||
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: False)
|
||||
@@ -515,13 +551,11 @@ def test_templates_rm_save_get_matrix_unit(monkeypatch):
|
||||
_set_request_json(monkeypatch, module, {"id": "canvas-1", "title": "Renamed", "dsl": "{\"m\": 1}"})
|
||||
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.UserCanvasService, "update_by_id", lambda canvas_id, payload: updates.append((canvas_id, payload)))
|
||||
monkeypatch.setattr(module.UserCanvasVersionService, "insert", lambda **kwargs: versions.append(("insert", kwargs)))
|
||||
monkeypatch.setattr(module.UserCanvasVersionService, "delete_all_versions", lambda canvas_id: versions.append(("delete", canvas_id)))
|
||||
monkeypatch.setattr(module.UserCanvasVersionService, "save_or_replace_latest", lambda *_args, **kwargs: versions.append(("save_or_replace_latest", kwargs)))
|
||||
res = _run(inspect.unwrap(module.save)())
|
||||
assert res["code"] == module.RetCode.SUCCESS
|
||||
assert updates and updates[0][0] == "canvas-1"
|
||||
assert any(item[0] == "insert" for item in versions)
|
||||
assert any(item[0] == "delete" for item in versions)
|
||||
assert any(item[0] == "save_or_replace_latest" for item in versions)
|
||||
|
||||
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: False)
|
||||
res = module.get("canvas-1")
|
||||
@@ -587,25 +621,16 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
|
||||
|
||||
_set_request_json(monkeypatch, module, {"id": "c1"})
|
||||
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (False, None))
|
||||
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: None)
|
||||
res = _run(inspect.unwrap(module.run)())
|
||||
assert res["message"] == "canvas not found."
|
||||
|
||||
class _CanvasRecord:
|
||||
def __init__(self, *, canvas_id, dsl, canvas_category):
|
||||
self.id = canvas_id
|
||||
self.dsl = dsl
|
||||
self.canvas_category = canvas_category
|
||||
|
||||
def to_dict(self):
|
||||
return {"id": self.id, "dsl": self.dsl}
|
||||
assert res["message"] == "canvas replica not found, please call /get/<canvas_id> first."
|
||||
|
||||
pipeline_calls = []
|
||||
monkeypatch.setattr(module, "Pipeline", lambda *args, **kwargs: pipeline_calls.append((args, kwargs)))
|
||||
monkeypatch.setattr(module, "get_uuid", lambda: "task-1")
|
||||
|
||||
_set_request_json(monkeypatch, module, {"id": "df-1", "files": ["f1"], "user_id": "exp-1"})
|
||||
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="df-1", dsl={"n": 1}, canvas_category=module.CanvasCategory.DataFlow)))
|
||||
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {"n": 1}, "title": "df", "canvas_category": module.CanvasCategory.DataFlow})
|
||||
monkeypatch.setattr(module, "queue_dataflow", lambda *_args, **_kwargs: (False, "queue failed"))
|
||||
res = _run(inspect.unwrap(module.run)())
|
||||
assert res["code"] == module.RetCode.DATA_ERROR
|
||||
@@ -619,7 +644,7 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
|
||||
assert res["data"]["message_id"] == "task-1"
|
||||
|
||||
_set_request_json(monkeypatch, module, {"id": "ag-1", "query": "q", "files": [], "inputs": {}})
|
||||
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="ag-1", dsl={"x": 1}, canvas_category=module.CanvasCategory.Agent)))
|
||||
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {"x": 1}, "title": "ag", "canvas_category": module.CanvasCategory.Agent})
|
||||
monkeypatch.setattr(module, "Canvas", lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("canvas init failed")))
|
||||
res = _run(inspect.unwrap(module.run)())
|
||||
assert res["code"] == module.RetCode.EXCEPTION_ERROR
|
||||
@@ -642,14 +667,13 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
|
||||
|
||||
_set_request_json(monkeypatch, module, {"id": "ag-2", "query": "q", "files": [], "inputs": {}, "user_id": "exp-2"})
|
||||
monkeypatch.setattr(module, "Canvas", _CanvasSSESuccess)
|
||||
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="ag-2", dsl="{}", canvas_category=module.CanvasCategory.Agent)))
|
||||
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {}, "title": "ag2", "canvas_category": module.CanvasCategory.Agent})
|
||||
monkeypatch.setattr(module.UserCanvasService, "update_by_id", lambda canvas_id, payload: updates.append((canvas_id, payload)))
|
||||
resp = _run(inspect.unwrap(module.run)())
|
||||
assert isinstance(resp, _StubResponse)
|
||||
assert resp.headers.get("Content-Type") == "text/event-stream; charset=utf-8"
|
||||
chunks = _run(_collect_stream(resp.response))
|
||||
assert any('"answer": "stream-ok"' in chunk for chunk in chunks)
|
||||
assert updates and updates[0][0] == "ag-2"
|
||||
|
||||
class _CanvasSSEError:
|
||||
last_instance = None
|
||||
@@ -670,7 +694,7 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
|
||||
|
||||
_set_request_json(monkeypatch, module, {"id": "ag-3", "query": "q", "files": [], "inputs": {}, "user_id": "exp-3"})
|
||||
monkeypatch.setattr(module, "Canvas", _CanvasSSEError)
|
||||
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="ag-3", dsl="{}", canvas_category=module.CanvasCategory.Agent)))
|
||||
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {}, "title": "ag3", "canvas_category": module.CanvasCategory.Agent})
|
||||
resp = _run(inspect.unwrap(module.run)())
|
||||
chunks = _run(_collect_stream(resp.response))
|
||||
assert any('"code": 500' in chunk and "stream boom" in chunk for chunk in chunks)
|
||||
|
||||
Reference in New Issue
Block a user