Fix: inconsistent state handling for multi-user single-canvas access (#13267)

### What problem does this PR solve?

<img width="700" alt="image"
src="https://github.com/user-attachments/assets/1db7412e-4554-44bc-84ba-16421949aacc"
/>

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
Magicbook1108
2026-02-28 15:09:21 +08:00
committed by GitHub
parent c91e803a38
commit 1027916bfe
6 changed files with 545 additions and 69 deletions

View File

@@ -40,13 +40,13 @@ from api.utils.api_utils import (
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
import time
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from api.apps import login_required, current_user
from api.apps.services.canvas_replica_service import CanvasReplicaService
from api.db.services.canvas_service import completion as agent_completion
@@ -75,9 +75,10 @@ async def rm():
@login_required
async def save():
req = await get_request_json()
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
try:
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
except ValueError as e:
return get_data_error_result(message=str(e))
cate = req.get("canvas_category", CanvasCategory.Agent)
if "id" not in req:
req["user_id"] = current_user.id
@@ -93,8 +94,21 @@ async def save():
code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req)
# save version
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
UserCanvasVersionService.delete_all_versions(req["id"])
UserCanvasVersionService.save_or_replace_latest(
user_canvas_id=req["id"],
dsl=req["dsl"],
title=UserCanvasVersionService.build_version_title(getattr(current_user, "nickname", current_user.id), req.get("title")),
)
replica_ok = CanvasReplicaService.replace_for_set(
canvas_id=req["id"],
tenant_id=str(current_user.id),
runtime_user_id=str(current_user.id),
dsl=req["dsl"],
canvas_category=req.get("canvas_category", cate),
title=req.get("title", ""),
)
if not replica_ok:
return get_data_error_result(message="canvas saved, but replica sync failed.")
return get_json_result(data=req)
@@ -104,6 +118,20 @@ def get(canvas_id):
if not UserCanvasService.accessible(canvas_id, current_user.id):
return get_data_error_result(message="canvas not found.")
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
try:
# DELETE
CanvasReplicaService.bootstrap(
canvas_id=canvas_id,
tenant_id=str(current_user.id),
runtime_user_id=str(current_user.id),
dsl=c.get("dsl"),
canvas_category=c.get("canvas_category", CanvasCategory.Agent),
title=c.get("title", ""),
)
except ValueError as e:
return get_data_error_result(message=str(e))
return get_json_result(data=c)
@@ -137,29 +165,38 @@ async def run():
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
tenant_id = str(current_user.id)
runtime_user_id = req.get("user_id") or tenant_id
user_id = str(runtime_user_id)
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
replica_payload = CanvasReplicaService.load_for_run(
canvas_id=req["id"],
tenant_id=tenant_id,
runtime_user_id=user_id,
)
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not replica_payload:
return get_data_error_result(message="canvas replica not found, please call /get/<canvas_id> first.")
if cvs.canvas_category == CanvasCategory.DataFlow:
replica_dsl = replica_payload.get("dsl", {})
canvas_title = replica_payload.get("title", "")
canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent)
dsl_str = json.dumps(replica_dsl, ensure_ascii=False)
if canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
Pipeline(dsl_str, tenant_id=tenant_id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
try:
canvas = Canvas(cvs.dsl, current_user.id, canvas_id=cvs.id)
canvas = Canvas(dsl_str, tenant_id, canvas_id=req["id"])
except Exception as e:
return server_error_response(e)
@@ -169,8 +206,21 @@ async def run():
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
commit_ok = CanvasReplicaService.commit_after_run(
canvas_id=req["id"],
tenant_id=tenant_id,
runtime_user_id=user_id,
dsl=json.loads(str(canvas)),
canvas_category=canvas_category,
title=canvas_title,
)
if not commit_ok:
logging.error(
"Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
req["id"],
tenant_id,
user_id,
)
except Exception as e:
logging.exception(e)

View File

@@ -27,9 +27,11 @@ from typing import Any, cast
import jwt
from agent.canvas import Canvas
from api.apps.services.canvas_replica_service import CanvasReplicaService
from api.db import CanvasCategory
from api.db.services.canvas_service import UserCanvasService
from api.db.services.file_service import FileService
from api.db.services.user_service import UserService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
@@ -39,6 +41,13 @@ from quart import request, Response
from rag.utils.redis_conn import REDIS_CONN
def _get_user_nickname(user_id: str) -> str:
exists, user = UserService.get_by_id(user_id)
if not exists:
return user_id
return str(getattr(user, "nickname", "") or user_id)
@manager.route('/agents', methods=['GET']) # noqa: F821
@token_required
def list_agents(tenant_id):
@@ -66,10 +75,10 @@ async def create_agent(tenant_id: str):
req["user_id"] = tenant_id
if req.get("dsl") is not None:
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
try:
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
except ValueError as e:
return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
else:
return get_json_result(data=False, message="No DSL data in request.", code=RetCode.ARGUMENT_ERROR)
@@ -87,9 +96,10 @@ async def create_agent(tenant_id: str):
if not UserCanvasService.save(**req):
return get_data_error_result(message="Fail to create agent.")
UserCanvasVersionService.insert(
owner_nickname = _get_user_nickname(tenant_id)
UserCanvasVersionService.save_or_replace_latest(
user_canvas_id=agent_id,
title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")),
title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")),
dsl=req["dsl"]
)
@@ -103,10 +113,10 @@ async def update_agent(tenant_id: str, agent_id: str):
req["user_id"] = tenant_id
if req.get("dsl") is not None:
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
try:
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
except ValueError as e:
return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
if req.get("title") is not None:
req["title"] = req["title"].strip()
@@ -116,17 +126,19 @@ async def update_agent(tenant_id: str, agent_id: str):
data=False, message="Only owner of canvas authorized for this operation.",
code=RetCode.OPERATING_ERROR)
_, current_agent = UserCanvasService.get_by_id(agent_id)
agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
owner_nickname = _get_user_nickname(tenant_id)
UserCanvasService.update_by_id(agent_id, req)
if req.get("dsl") is not None:
UserCanvasVersionService.insert(
UserCanvasVersionService.save_or_replace_latest(
user_canvas_id=agent_id,
title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")),
title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version),
dsl=req["dsl"]
)
UserCanvasVersionService.delete_all_versions(agent_id)
return get_json_result(data=True)

View File

@@ -0,0 +1,258 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import time
from api.db import CanvasCategory
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
class CanvasReplicaService:
"""
Manage per-user canvas runtime replicas stored in Redis.
Lifecycle:
- bootstrap: initialize/refresh replica from DB DSL
- load_for_run: read replica before run
- commit_after_run: atomically persist run result back to replica
"""
TTL_SECS = 3 * 60 * 60
REPLICA_KEY_PREFIX = "canvas:replica"
LOCK_KEY_PREFIX = "canvas:replica:lock"
LOCK_TIMEOUT_SECS = 10
LOCK_BLOCKING_TIMEOUT_SECS = 1
LOCK_RETRY_ATTEMPTS = 3
LOCK_RETRY_SLEEP_SECS = 0.2
@classmethod
def normalize_dsl(cls, dsl):
"""Normalize DSL to a JSON-serializable dict. Raise ValueError on invalid input."""
normalized = dsl
if isinstance(normalized, str):
try:
normalized = json.loads(normalized)
except Exception as e:
raise ValueError("Invalid DSL JSON string.") from e
if not isinstance(normalized, dict):
raise ValueError("DSL must be a JSON object.")
try:
return json.loads(json.dumps(normalized, ensure_ascii=False))
except Exception as e:
raise ValueError("DSL is not JSON-serializable.") from e
@classmethod
def _replica_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str:
return f"{cls.REPLICA_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}"
@classmethod
def _lock_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str:
return f"{cls.LOCK_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}"
@classmethod
def _read_payload(cls, replica_key: str):
"""Read replica payload from Redis; return None on missing/invalid content."""
cache_blob = REDIS_CONN.get(replica_key)
if not cache_blob:
return None
try:
payload = json.loads(cache_blob)
if not isinstance(payload, dict):
return None
payload["dsl"] = cls.normalize_dsl(payload.get("dsl", {}))
return payload
except Exception as e:
logging.warning("Failed to parse canvas replica %s: %s", replica_key, e)
return None
@classmethod
def _write_payload(cls, replica_key: str, payload: dict):
"""Write payload and refresh TTL."""
payload["updated_at"] = int(time.time())
REDIS_CONN.set_obj(replica_key, payload, cls.TTL_SECS)
@classmethod
def _build_payload(
cls,
canvas_id: str,
tenant_id: str,
runtime_user_id: str,
dsl,
canvas_category=CanvasCategory.Agent,
title="",
):
return {
"canvas_id": canvas_id,
"tenant_id": str(tenant_id),
"runtime_user_id": str(runtime_user_id),
"title": title or "",
"canvas_category": canvas_category or CanvasCategory.Agent,
"dsl": cls.normalize_dsl(dsl),
"updated_at": int(time.time()),
}
@classmethod
def create_if_absent(
cls,
canvas_id: str,
tenant_id: str,
runtime_user_id: str,
dsl,
canvas_category=CanvasCategory.Agent,
title="",
):
"""Create a runtime replica if it does not exist; otherwise keep existing state."""
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
payload = cls._read_payload(replica_key)
if payload:
return payload
payload = cls._build_payload(canvas_id, str(tenant_id), str(runtime_user_id), dsl, canvas_category, title)
cls._write_payload(replica_key, payload)
return payload
@classmethod
def bootstrap(
cls,
canvas_id: str,
tenant_id: str,
runtime_user_id: str,
dsl,
canvas_category=CanvasCategory.Agent,
title="",
):
"""Bootstrap replica by creating it when absent and keeping existing runtime state."""
return cls.create_if_absent(
canvas_id=canvas_id,
tenant_id=tenant_id,
runtime_user_id=runtime_user_id,
dsl=dsl,
canvas_category=canvas_category,
title=title,
)
@classmethod
def load_for_run(cls, canvas_id: str, tenant_id: str, runtime_user_id: str):
"""Load current runtime replica used by /completion."""
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
return cls._read_payload(replica_key)
@classmethod
def replace_for_set(
cls,
canvas_id: str,
tenant_id: str,
runtime_user_id: str,
dsl,
canvas_category=CanvasCategory.Agent,
title="",
):
"""Replace replica content for `/set` under lock."""
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
lock_key = cls._lock_key(canvas_id, str(tenant_id), str(runtime_user_id))
lock = cls._acquire_lock_with_retry(lock_key)
if not lock:
logging.error("Failed to acquire canvas replica lock after retry: %s", lock_key)
return False
try:
updated_payload = cls._build_payload(
canvas_id=canvas_id,
tenant_id=str(tenant_id),
runtime_user_id=str(runtime_user_id),
dsl=dsl,
canvas_category=canvas_category,
title=title,
)
cls._write_payload(replica_key, updated_payload)
return True
except Exception:
logging.exception("Failed to replace canvas replica from /set.")
return False
finally:
try:
lock.release()
except Exception:
logging.exception("Failed to release canvas replica lock: %s", lock_key)
@classmethod
def _acquire_lock_with_retry(cls, lock_key: str):
"""Acquire distributed lock with bounded retries; return lock object or None."""
lock = RedisDistributedLock(
lock_key,
timeout=cls.LOCK_TIMEOUT_SECS,
blocking_timeout=cls.LOCK_BLOCKING_TIMEOUT_SECS,
)
for idx in range(cls.LOCK_RETRY_ATTEMPTS):
if lock.acquire():
return lock
if idx < cls.LOCK_RETRY_ATTEMPTS - 1:
time.sleep(cls.LOCK_RETRY_SLEEP_SECS + random.uniform(0, 0.1))
return None
@classmethod
def commit_after_run(
cls,
canvas_id: str,
tenant_id: str,
runtime_user_id: str,
dsl,
canvas_category=CanvasCategory.Agent,
title="",
):
"""
Commit post-run DSL into replica.
Returns:
bool: True on committed/saved, False on commit failure.
"""
new_dsl = cls.normalize_dsl(dsl)
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
try:
latest_payload = cls._read_payload(replica_key)
# Always write latest runtime DSL back to Redis first.
updated_payload = cls._build_payload(
canvas_id=canvas_id,
tenant_id=str(tenant_id),
runtime_user_id=str(runtime_user_id),
dsl=new_dsl,
canvas_category=canvas_category if not latest_payload else (canvas_category or latest_payload.get("canvas_category", CanvasCategory.Agent)),
title=title if not latest_payload else (title or latest_payload.get("title", "")),
)
cls._write_payload(replica_key, updated_payload)
return True
except Exception:
logging.exception("Failed to commit canvas runtime replica.")
return False

View File

@@ -1,3 +1,7 @@
import json
import logging
import time
from api.db.db_models import UserCanvasVersion, DB
from api.db.services.common_service import CommonService
from peewee import DoesNotExist
@@ -6,6 +10,30 @@ from peewee import DoesNotExist
class UserCanvasVersionService(CommonService):
model = UserCanvasVersion
@staticmethod
def build_version_title(user_nickname, agent_title, ts=None):
tenant = str(user_nickname or "").strip() or "tenant"
title = str(agent_title or "").strip() or "agent"
stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ts)) if ts is not None else time.strftime("%Y-%m-%d %H:%M:%S")
return "{0}_{1}_{2}".format(tenant, title, stamp)
@staticmethod
def _normalize_dsl(dsl):
normalized = dsl
if isinstance(normalized, str):
try:
normalized = json.loads(normalized)
except Exception as e:
raise ValueError("Invalid DSL JSON string.") from e
if not isinstance(normalized, dict):
raise ValueError("DSL must be a JSON object.")
try:
return json.loads(json.dumps(normalized, ensure_ascii=False))
except Exception as e:
raise ValueError("DSL is not JSON-serializable.") from e
@classmethod
@DB.connection_context()
def list_by_canvas_id(cls, user_canvas_id):
@@ -59,3 +87,43 @@ class UserCanvasVersionService(CommonService):
return None
except Exception:
return None
@classmethod
@DB.connection_context()
def save_or_replace_latest(cls, user_canvas_id, dsl, title=None, description=None):
"""
Persist a canvas snapshot into version history.
If the latest version has the same DSL content, update that version in place
instead of creating a new row.
"""
try:
normalized_dsl = cls._normalize_dsl(dsl)
latest = (
cls.model.select()
.where(cls.model.user_canvas_id == user_canvas_id)
.order_by(cls.model.create_time.desc())
.first()
)
if latest and cls._normalize_dsl(latest.dsl) == normalized_dsl:
update_data = {"dsl": normalized_dsl}
if title is not None:
update_data["title"] = title
if description is not None:
update_data["description"] = description
cls.update_by_id(latest.id, update_data)
cls.delete_all_versions(user_canvas_id)
return latest.id, False
insert_data = {"user_canvas_id": user_canvas_id, "dsl": normalized_dsl}
if title is not None:
insert_data["title"] = title
if description is not None:
insert_data["description"] = description
cls.insert(**insert_data)
cls.delete_all_versions(user_canvas_id)
return None, True
except Exception as e:
logging.exception(e)
return None, None

View File

@@ -265,6 +265,14 @@ def _load_agents_app(monkeypatch):
def delete_all_versions(*_args, **_kwargs):
return True
@staticmethod
def save_or_replace_latest(*_args, **_kwargs):
return True
@staticmethod
def build_version_title(*_args, **_kwargs):
return "stub_version_title"
canvas_version_mod.UserCanvasVersionService = _StubUserCanvasVersionService
monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", canvas_version_mod)
services_pkg.user_canvas_version = canvas_version_mod
@@ -280,6 +288,67 @@ def _load_agents_app(monkeypatch):
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
services_pkg.tenant_llm_service = tenant_llm_service_mod
user_service_mod = ModuleType("api.db.services.user_service")
class _StubUserService:
@staticmethod
def query(**_kwargs):
return []
@staticmethod
def get_by_id(_id):
return False, None
user_service_mod.UserService = _StubUserService
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
services_pkg.user_service = user_service_mod
services_pkg.UserService = _StubUserService
# Stub api.apps package to prevent api/apps/__init__.py from executing
# (it triggers heavy imports like quart, settings, DB connections).
api_apps_pkg = ModuleType("api.apps")
api_apps_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "api.apps", api_apps_pkg)
api_apps_services_pkg = ModuleType("api.apps.services")
api_apps_services_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "api.apps.services", api_apps_services_pkg)
api_apps_pkg.services = api_apps_services_pkg
canvas_replica_mod = ModuleType("api.apps.services.canvas_replica_service")
class _StubCanvasReplicaService:
@classmethod
def normalize_dsl(cls, dsl):
import json
if isinstance(dsl, str):
return json.loads(dsl)
return dsl
@classmethod
def bootstrap(cls, *_args, **_kwargs):
return {}
@classmethod
def load_for_run(cls, *_args, **_kwargs):
return None
@classmethod
def commit_after_run(cls, *_args, **_kwargs):
return True
@classmethod
def replace_for_set(cls, *_args, **_kwargs):
return True
@classmethod
def create_if_absent(cls, *_args, **_kwargs):
return {}
canvas_replica_mod.CanvasReplicaService = _StubCanvasReplicaService
monkeypatch.setitem(sys.modules, "api.apps.services.canvas_replica_service", canvas_replica_mod)
api_apps_services_pkg.canvas_replica_service = canvas_replica_mod
redis_obj = _StubRedisConn()
redis_mod = ModuleType("rag.utils.redis_conn")
redis_mod.REDIS_CONN = redis_obj
@@ -368,7 +437,7 @@ def test_agents_crud_unit_branches(monkeypatch):
res = _run(module.update_agent.__wrapped__("tenant-1", "agent-1"))
assert res["code"] == module.RetCode.OPERATING_ERROR
calls = {"update": 0, "insert": 0, "delete_versions": 0}
calls = {"update": 0, "save_or_replace_latest": 0}
monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: True)
monkeypatch.setattr(
module.UserCanvasService,
@@ -377,17 +446,12 @@ def test_agents_crud_unit_branches(monkeypatch):
)
monkeypatch.setattr(
module.UserCanvasVersionService,
"insert",
lambda **_kwargs: calls.__setitem__("insert", calls["insert"] + 1),
)
monkeypatch.setattr(
module.UserCanvasVersionService,
"delete_all_versions",
lambda *_args, **_kwargs: calls.__setitem__("delete_versions", calls["delete_versions"] + 1),
"save_or_replace_latest",
lambda *_args, **_kwargs: calls.__setitem__("save_or_replace_latest", calls["save_or_replace_latest"] + 1),
)
res = _run(module.update_agent.__wrapped__("tenant-1", "agent-1"))
assert res["code"] == module.RetCode.SUCCESS
assert calls == {"update": 1, "insert": 1, "delete_versions": 1}
assert calls == {"update": 1, "save_or_replace_latest": 1}
monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: False)
res = module.delete_agent.__wrapped__("tenant-1", "agent-1")

View File

@@ -184,10 +184,50 @@ def _load_canvas_module(monkeypatch):
monkeypatch.setitem(sys.modules, "api", api_pkg)
apps_mod = ModuleType("api.apps")
apps_mod.__path__ = []
apps_mod.current_user = SimpleNamespace(id="user-1")
apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
apps_services_pkg = ModuleType("api.apps.services")
apps_services_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "api.apps.services", apps_services_pkg)
apps_mod.services = apps_services_pkg
canvas_replica_mod = ModuleType("api.apps.services.canvas_replica_service")
class _StubCanvasReplicaService:
@classmethod
def normalize_dsl(cls, dsl):
import json
if isinstance(dsl, str):
return json.loads(dsl)
return dsl
@classmethod
def bootstrap(cls, *_args, **_kwargs):
return {}
@classmethod
def load_for_run(cls, *_args, **_kwargs):
return None
@classmethod
def commit_after_run(cls, *_args, **_kwargs):
return True
@classmethod
def replace_for_set(cls, *_args, **_kwargs):
return True
@classmethod
def create_if_absent(cls, *_args, **_kwargs):
return {}
canvas_replica_mod.CanvasReplicaService = _StubCanvasReplicaService
monkeypatch.setitem(sys.modules, "api.apps.services.canvas_replica_service", canvas_replica_mod)
apps_services_pkg.canvas_replica_service = canvas_replica_mod
db_pkg = ModuleType("api.db")
db_pkg.CanvasCategory = _DummyCanvasCategory
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
@@ -310,6 +350,8 @@ def _load_canvas_module(monkeypatch):
delete_all_versions=lambda *_args, **_kwargs: True,
list_by_canvas_id=lambda *_args, **_kwargs: [],
get_by_id=lambda *_args, **_kwargs: (True, None),
save_or_replace_latest=lambda *_args, **_kwargs: True,
build_version_title=lambda *_args, **_kwargs: "stub_version_title",
)
monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", canvas_version_mod)
@@ -492,18 +534,12 @@ def test_templates_rm_save_get_matrix_unit(monkeypatch):
monkeypatch.setattr(module, "get_uuid", lambda: "canvas-new")
monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: [])
monkeypatch.setattr(module.UserCanvasService, "save", lambda **kwargs: created["save"].append(kwargs) or True)
monkeypatch.setattr(module.UserCanvasVersionService, "insert", lambda **kwargs: created["versions"].append(("insert", kwargs)))
monkeypatch.setattr(
module.UserCanvasVersionService,
"delete_all_versions",
lambda canvas_id: created["versions"].append(("delete", canvas_id)),
)
monkeypatch.setattr(module.UserCanvasVersionService, "save_or_replace_latest", lambda *_args, **kwargs: created["versions"].append(("save_or_replace_latest", kwargs)))
res = _run(inspect.unwrap(module.save)())
assert res["code"] == module.RetCode.SUCCESS
assert res["data"]["id"] == "canvas-new"
assert created["save"]
assert any(item[0] == "insert" for item in created["versions"])
assert any(item[0] == "delete" for item in created["versions"])
assert any(item[0] == "save_or_replace_latest" for item in created["versions"])
_set_request_json(monkeypatch, module, {"id": "canvas-1", "title": "Renamed", "dsl": "{\"m\": 1}"})
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: False)
@@ -515,13 +551,11 @@ def test_templates_rm_save_get_matrix_unit(monkeypatch):
_set_request_json(monkeypatch, module, {"id": "canvas-1", "title": "Renamed", "dsl": "{\"m\": 1}"})
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: True)
monkeypatch.setattr(module.UserCanvasService, "update_by_id", lambda canvas_id, payload: updates.append((canvas_id, payload)))
monkeypatch.setattr(module.UserCanvasVersionService, "insert", lambda **kwargs: versions.append(("insert", kwargs)))
monkeypatch.setattr(module.UserCanvasVersionService, "delete_all_versions", lambda canvas_id: versions.append(("delete", canvas_id)))
monkeypatch.setattr(module.UserCanvasVersionService, "save_or_replace_latest", lambda *_args, **kwargs: versions.append(("save_or_replace_latest", kwargs)))
res = _run(inspect.unwrap(module.save)())
assert res["code"] == module.RetCode.SUCCESS
assert updates and updates[0][0] == "canvas-1"
assert any(item[0] == "insert" for item in versions)
assert any(item[0] == "delete" for item in versions)
assert any(item[0] == "save_or_replace_latest" for item in versions)
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: False)
res = module.get("canvas-1")
@@ -587,25 +621,16 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
_set_request_json(monkeypatch, module, {"id": "c1"})
monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_args, **_kwargs: True)
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (False, None))
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: None)
res = _run(inspect.unwrap(module.run)())
assert res["message"] == "canvas not found."
class _CanvasRecord:
def __init__(self, *, canvas_id, dsl, canvas_category):
self.id = canvas_id
self.dsl = dsl
self.canvas_category = canvas_category
def to_dict(self):
return {"id": self.id, "dsl": self.dsl}
assert res["message"] == "canvas replica not found, please call /get/<canvas_id> first."
pipeline_calls = []
monkeypatch.setattr(module, "Pipeline", lambda *args, **kwargs: pipeline_calls.append((args, kwargs)))
monkeypatch.setattr(module, "get_uuid", lambda: "task-1")
_set_request_json(monkeypatch, module, {"id": "df-1", "files": ["f1"], "user_id": "exp-1"})
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="df-1", dsl={"n": 1}, canvas_category=module.CanvasCategory.DataFlow)))
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {"n": 1}, "title": "df", "canvas_category": module.CanvasCategory.DataFlow})
monkeypatch.setattr(module, "queue_dataflow", lambda *_args, **_kwargs: (False, "queue failed"))
res = _run(inspect.unwrap(module.run)())
assert res["code"] == module.RetCode.DATA_ERROR
@@ -619,7 +644,7 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
assert res["data"]["message_id"] == "task-1"
_set_request_json(monkeypatch, module, {"id": "ag-1", "query": "q", "files": [], "inputs": {}})
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="ag-1", dsl={"x": 1}, canvas_category=module.CanvasCategory.Agent)))
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {"x": 1}, "title": "ag", "canvas_category": module.CanvasCategory.Agent})
monkeypatch.setattr(module, "Canvas", lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("canvas init failed")))
res = _run(inspect.unwrap(module.run)())
assert res["code"] == module.RetCode.EXCEPTION_ERROR
@@ -642,14 +667,13 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
_set_request_json(monkeypatch, module, {"id": "ag-2", "query": "q", "files": [], "inputs": {}, "user_id": "exp-2"})
monkeypatch.setattr(module, "Canvas", _CanvasSSESuccess)
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="ag-2", dsl="{}", canvas_category=module.CanvasCategory.Agent)))
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {}, "title": "ag2", "canvas_category": module.CanvasCategory.Agent})
monkeypatch.setattr(module.UserCanvasService, "update_by_id", lambda canvas_id, payload: updates.append((canvas_id, payload)))
resp = _run(inspect.unwrap(module.run)())
assert isinstance(resp, _StubResponse)
assert resp.headers.get("Content-Type") == "text/event-stream; charset=utf-8"
chunks = _run(_collect_stream(resp.response))
assert any('"answer": "stream-ok"' in chunk for chunk in chunks)
assert updates and updates[0][0] == "ag-2"
class _CanvasSSEError:
last_instance = None
@@ -670,7 +694,7 @@ def test_run_dataflow_and_canvas_sse_matrix_unit(monkeypatch):
_set_request_json(monkeypatch, module, {"id": "ag-3", "query": "q", "files": [], "inputs": {}, "user_id": "exp-3"})
monkeypatch.setattr(module, "Canvas", _CanvasSSEError)
monkeypatch.setattr(module.UserCanvasService, "get_by_id", lambda _canvas_id: (True, _CanvasRecord(canvas_id="ag-3", dsl="{}", canvas_category=module.CanvasCategory.Agent)))
monkeypatch.setattr(module.CanvasReplicaService, "load_for_run", lambda *_args, **_kwargs: {"dsl": {}, "title": "ag3", "canvas_category": module.CanvasCategory.Agent})
resp = _run(inspect.unwrap(module.run)())
chunks = _run(_collect_stream(resp.response))
assert any('"code": 500' in chunk and "stream boom" in chunk for chunk in chunks)