diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 0803a2b11b..efaf3285d9 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -616,6 +616,7 @@ async def update_agent_tags(tenant_id, canvas_id): @add_tenant_id_to_kwargs async def create_agent(tenant_id): req = {k: v for k, v in (await get_request_json()).items() if v is not None} + req["canvas_type"] = req.get("canvas_type","") req["user_id"] = tenant_id req["canvas_category"] = req.get("canvas_category") or CanvasCategory.Agent req["release"] = bool(req.get("release", "")) @@ -872,6 +873,7 @@ def delete_agent(agent_id, tenant_id): @_require_canvas_access_async async def update_agent(agent_id, tenant_id): req = {k: v for k, v in (await get_request_json()).items() if v is not None} + req["canvas_type"] = req.get("canvas_type","") req["release"] = bool(req.get("release", "")) if req.get("dsl") is not None: diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 679f5ba5f3..5228c18c7f 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -257,19 +257,35 @@ class RAGFlow: return Agent(self, res["data"]) raise Exception(res["message"]) - def create_agent(self, title: str, dsl: dict, description: str | None = None) -> None: + def create_agent( + self, + title: str, + dsl: dict, + description: str | None = None, + canvas_type: str | None = None, + ) -> None: req = {"title": title, "dsl": dsl} if description is not None: req["description"] = description + if canvas_type is not None: + req["canvas_type"] = canvas_type + res = self.post("/agents", req) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) - def update_agent(self, agent_id: str, title: str | None = None, description: str | None = None, dsl: dict | None = None) -> None: + def update_agent( + self, + agent_id: str, + title: str | None = None, + description: str | None = None, + dsl: dict | None = None, + canvas_type: str | None = None, + ) -> None: req = {} if title is not None: @@ -281,6 +297,9 @@ class RAGFlow: if dsl is not None: req["dsl"] = dsl + if canvas_type is not None: + req["canvas_type"] = canvas_type + res = self.put(f"/agents/{agent_id}", req) res = res.json() diff --git a/test/testcases/test_sdk_api/test_agent_management/test_agent_crud_unit.py b/test/testcases/test_sdk_api/test_agent_management/test_agent_crud_unit.py index 1642c14dde..b74a4dbdbd 100644 --- a/test/testcases/test_sdk_api/test_agent_management/test_agent_crud_unit.py +++ b/test/testcases/test_sdk_api/test_agent_management/test_agent_crud_unit.py @@ -80,6 +80,9 @@ def test_create_agent_payload_and_error(monkeypatch): client.create_agent("agent-title", {"graph": {}}, description="desc") assert calls[-1][1] == {"title": "agent-title", "dsl": {"graph": {}}, "description": "desc"} + client.create_agent("agent-title", {"graph": {}}, canvas_type="Marketing") + assert calls[-1][1] == {"title": "agent-title", "dsl": {"graph": {}}, "canvas_type": "Marketing"} + monkeypatch.setattr(client, "post", lambda *_args, **_kwargs: _DummyResponse({"code": 1, "message": "create boom"})) with pytest.raises(Exception) as exception_info: client.create_agent("agent-title", {"graph": {}}) @@ -104,6 +107,7 @@ def test_update_agent_payload_matrix_and_error(monkeypatch): {"title": "new-title", "description": "new-description", "dsl": {"nodes": []}}, {"title": "new-title", "description": "new-description", "dsl": {"nodes": []}}, ), + ({"canvas_type": "Agent"}, {"canvas_type": "Agent"}), ] for kwargs, expected_payload in cases: client.update_agent("agent-1", **kwargs)