Refa:migrate agent webhook routes to REST APIs (#14330)

### What problem does this PR solve?

migrate agent webhook routes to REST APIs

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2026-04-24 17:55:53 +08:00
committed by GitHub
parent b8d831c1c3
commit 9ad752f497
8 changed files with 2256 additions and 840 deletions

View File

@@ -14,18 +14,25 @@
# limitations under the License.
#
import inspect
import asyncio
import base64
import copy
import hashlib
import hmac
import inspect
import ipaddress
import json
import logging
import time
from functools import partial
import jwt
from quart import Response, jsonify, request
from agent.component import LLM
from agent.canvas import Canvas
from agent.component import LLM
from agent.dsl_migration import normalize_chunker_dsl
from api.apps import login_required
from api.apps import current_user, login_required
from api.apps.services.canvas_replica_service import CanvasReplicaService
from api.db import CanvasCategory
from api.db.db_models import Task
@@ -52,15 +59,14 @@ from api.utils.api_utils import (
server_error_response,
validate_request,
)
from common import settings
from common.constants import RetCode
from common.misc_utils import get_uuid, thread_pool_exec
from common import settings
from peewee import MySQLDatabase, PostgresqlDatabase
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
def _get_user_nickname(user_id: str) -> str:
exists, user = UserService.get_by_id(user_id)
if not exists:
@@ -1045,3 +1051,795 @@ async def agent_chat_completion(tenant_id):
if return_trace and final_ans:
final_ans["data"]["trace"] = trace_items
return get_result(data=final_ans)
@manager.route("/agents/<agent_id>/webhook", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
@manager.route("/agents/<agent_id>/webhook/test",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
async def webhook(agent_id: str):
is_test = request.path.startswith(f"/api/v1/agents/{agent_id}/webhook/test")
start_ts = time.time()
# 1. Fetch canvas by agent_id
exists, cvs = UserCanvasService.get_by_id(agent_id)
if not exists:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
# 2. Check canvas category
if cvs.canvas_category == CanvasCategory.DataFlow:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
# 3. Load DSL from canvas
dsl = getattr(cvs, "dsl", None)
if not isinstance(dsl, dict):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL
webhook_cfg = {}
components = dsl.get("components", {})
for k, _ in components.items():
cpn_obj = components[k]["obj"]
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
webhook_cfg = cpn_obj["params"]
if not webhook_cfg:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
# 5. Validate request method against webhook_cfg.methods
allowed_methods = webhook_cfg.get("methods", [])
request_method = request.method.upper()
if allowed_methods and request_method not in allowed_methods:
return get_data_error_result(
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
),RetCode.BAD_REQUEST
# 6. Validate webhook security
async def validate_webhook_security(security_cfg: dict):
"""Validate webhook security rules based on security configuration."""
if not security_cfg:
return # No security config → allowed by default
# 1. Validate max body size
await _validate_max_body_size(security_cfg)
# 2. Validate IP whitelist
_validate_ip_whitelist(security_cfg)
# # 3. Validate rate limiting
_validate_rate_limit(security_cfg)
# 4. Validate authentication
auth_type = security_cfg.get("auth_type", "none")
if auth_type == "none":
return
if auth_type == "token":
_validate_token_auth(security_cfg)
elif auth_type == "basic":
_validate_basic_auth(security_cfg)
elif auth_type == "jwt":
_validate_jwt_auth(security_cfg)
else:
raise Exception(f"Unsupported auth_type: {auth_type}")
async def _validate_max_body_size(security_cfg):
"""Check request size does not exceed max_body_size."""
max_size = security_cfg.get("max_body_size")
if not max_size:
return
# Convert "10MB" → bytes
units = {"kb": 1024, "mb": 1024**2}
size_str = max_size.lower()
for suffix, factor in units.items():
if size_str.endswith(suffix):
limit = int(size_str.replace(suffix, "")) * factor
break
else:
raise Exception("Invalid max_body_size format")
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
if limit > MAX_LIMIT:
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
content_length = request.content_length or 0
if content_length > limit:
raise Exception(f"Request body too large: {content_length} > {limit}")
def _validate_ip_whitelist(security_cfg):
"""Allow only IPs listed in ip_whitelist."""
whitelist = security_cfg.get("ip_whitelist", [])
if not whitelist:
return
client_ip = request.remote_addr
for rule in whitelist:
if "/" in rule:
# CIDR notation
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
return
else:
# Single IP
if client_ip == rule:
return
raise Exception(f"IP {client_ip} is not allowed by whitelist")
def _validate_rate_limit(security_cfg):
"""Simple in-memory rate limiting."""
rl = security_cfg.get("rate_limit")
if not rl:
return
limit = int(rl.get("limit", 60))
if limit <= 0:
raise Exception("rate_limit.limit must be > 0")
per = rl.get("per", "minute")
window = {
"second": 1,
"minute": 60,
"hour": 3600,
"day": 86400,
}.get(per)
if not window:
raise Exception(f"Invalid rate_limit.per: {per}")
capacity = limit
rate = limit / window
cost = 1
key = f"rl:tb:{agent_id}"
now = time.time()
try:
res = REDIS_CONN.lua_token_bucket(
keys=[key],
args=[capacity, rate, now, cost],
client=REDIS_CONN.REDIS,
)
allowed = int(res[0])
if allowed != 1:
raise Exception("Too many requests (rate limit exceeded)")
except Exception as e:
raise Exception(f"Rate limit error: {e}")
def _validate_token_auth(security_cfg):
"""Validate header-based token authentication."""
token_cfg = security_cfg.get("token",{})
header = token_cfg.get("token_header")
token_value = token_cfg.get("token_value")
provided = request.headers.get(header)
if provided != token_value:
raise Exception("Invalid token authentication")
def _validate_basic_auth(security_cfg):
"""Validate HTTP Basic Auth credentials."""
auth_cfg = security_cfg.get("basic_auth", {})
username = auth_cfg.get("username")
password = auth_cfg.get("password")
auth = request.authorization
if not auth or auth.username != username or auth.password != password:
raise Exception("Invalid Basic Auth credentials")
def _validate_jwt_auth(security_cfg):
"""Validate JWT token in Authorization header."""
jwt_cfg = security_cfg.get("jwt", {})
secret = jwt_cfg.get("secret")
if not secret:
raise Exception("JWT secret not configured")
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise Exception("Missing Bearer token")
token = auth_header[len("Bearer "):].strip()
if not token:
raise Exception("Empty Bearer token")
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
decode_kwargs = {
"key": secret,
"algorithms": [alg],
}
options = {}
if jwt_cfg.get("audience"):
decode_kwargs["audience"] = jwt_cfg["audience"]
options["verify_aud"] = True
else:
options["verify_aud"] = False
if jwt_cfg.get("issuer"):
decode_kwargs["issuer"] = jwt_cfg["issuer"]
options["verify_iss"] = True
else:
options["verify_iss"] = False
try:
decoded = jwt.decode(
token,
options=options,
**decode_kwargs,
)
except Exception as e:
raise Exception(f"Invalid JWT: {str(e)}")
raw_required_claims = jwt_cfg.get("required_claims", [])
if isinstance(raw_required_claims, str):
required_claims = [raw_required_claims]
elif isinstance(raw_required_claims, (list, tuple, set)):
required_claims = list(raw_required_claims)
else:
required_claims = []
required_claims = [
c for c in required_claims
if isinstance(c, str) and c.strip()
]
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
for claim in required_claims:
if claim in RESERVED_CLAIMS:
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
for claim in required_claims:
if claim not in decoded:
raise Exception(f"Missing JWT claim: {claim}")
return decoded
try:
security_config=webhook_cfg.get("security", {})
await validate_webhook_security(security_config)
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
if not isinstance(cvs.dsl, str):
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
try:
canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
except Exception as e:
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
resp.status_code = RetCode.BAD_REQUEST
return resp
# 7. Parse request body
async def parse_webhook_request(content_type):
"""Parse request based on content-type and return structured data."""
# 1. Query
query_data = {k: v for k, v in request.args.items()}
# 2. Headers
header_data = {k: v for k, v in request.headers.items()}
# 3. Body
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
if ctype and ctype != content_type:
raise ValueError(
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
)
body_data: dict = {}
try:
if ctype == "application/json":
body_data = await request.get_json() or {}
elif ctype == "multipart/form-data":
nonlocal canvas
form = await request.form
files = await request.files
body_data = {}
for key, value in form.items():
body_data[key] = value
if len(files) > 10:
raise Exception("Too many uploaded files")
for key, file in files.items():
desc = FileService.upload_info(
cvs.user_id, # user
file, # FileStorage
None # url (None for webhook)
)
file_parsed= await canvas.get_files_async([desc])
body_data[key] = file_parsed
elif ctype == "application/x-www-form-urlencoded":
form = await request.form
body_data = dict(form)
else:
# text/plain / octet-stream / empty / unknown
raw = await request.get_data()
if raw:
try:
body_data = json.loads(raw.decode("utf-8"))
except Exception:
body_data = {}
else:
body_data = {}
except Exception:
body_data = {}
return {
"query": query_data,
"headers": header_data,
"body": body_data,
"content_type": ctype,
}
def extract_by_schema(data, schema, name="section"):
"""
Extract only fields defined in schema.
Required fields must exist.
Optional fields default to type-based default values.
Type validation included.
"""
props = schema.get("properties", {})
required = schema.get("required", [])
extracted = {}
for field, field_schema in props.items():
field_type = field_schema.get("type")
# 1. Required field missing
if field in required and field not in data:
raise Exception(f"{name} missing required field: {field}")
# 2. Optional → default value
if field not in data:
extracted[field] = default_for_type(field_type)
continue
raw_value = data[field]
# 3. Auto convert value
try:
value = auto_cast_value(raw_value, field_type)
except Exception as e:
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
# 4. Type validation
if not validate_type(value, field_type):
raise Exception(
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
)
extracted[field] = value
return extracted
def default_for_type(t):
"""Return default value for the given schema type."""
if t == "file":
return []
if t == "object":
return {}
if t == "boolean":
return False
if t == "number":
return 0
if t == "string":
return ""
if t and t.startswith("array"):
return []
if t == "null":
return None
return None
def auto_cast_value(value, expected_type):
"""Convert string values into schema type when possible."""
# Non-string values already good
if not isinstance(value, str):
return value
v = value.strip()
# Boolean
if expected_type == "boolean":
if v.lower() in ["true", "1"]:
return True
if v.lower() in ["false", "0"]:
return False
raise Exception(f"Cannot convert '{value}' to boolean")
# Number
if expected_type == "number":
# integer
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
return int(v)
# float
try:
return float(v)
except Exception:
raise Exception(f"Cannot convert '{value}' to number")
# Object
if expected_type == "object":
try:
parsed = json.loads(v)
if isinstance(parsed, dict):
return parsed
else:
raise Exception("JSON is not an object")
except Exception:
raise Exception(f"Cannot convert '{value}' to object")
# Array <T>
if expected_type.startswith("array"):
try:
parsed = json.loads(v)
if isinstance(parsed, list):
return parsed
else:
raise Exception("JSON is not an array")
except Exception:
raise Exception(f"Cannot convert '{value}' to array")
# String (accept original)
if expected_type == "string":
return value
# File
if expected_type == "file":
return value
# Default: do nothing
return value
def validate_type(value, t):
"""Validate value type against schema type t."""
if t == "file":
return isinstance(value, list)
if t == "string":
return isinstance(value, str)
if t == "number":
return isinstance(value, (int, float))
if t == "boolean":
return isinstance(value, bool)
if t == "object":
return isinstance(value, dict)
# array<string> / array<number> / array<object>
if t.startswith("array"):
if not isinstance(value, list):
return False
if "<" in t and ">" in t:
inner = t[t.find("<") + 1 : t.find(">")]
# Check each element type
for item in value:
if not validate_type(item, inner):
return False
return True
return True
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
# Extract strictly by schema
try:
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
clean_request = {
"query": query_clean,
"headers": header_clean,
"body": body_clean,
"input": parsed
}
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
response_cfg = webhook_cfg.get("response", {})
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
key = f"webhook-trace-{agent_id}-logs"
raw = REDIS_CONN.get(key)
obj = json.loads(raw) if raw else {"webhooks": {}}
ws = obj["webhooks"].setdefault(
str(start_ts),
{"start_ts": start_ts, "events": []}
)
ws["events"].append({
"ts": time.time(),
**event
})
REDIS_CONN.set_obj(key, obj, ttl)
if execution_mode == "Immediately":
status = response_cfg.get("status", 200)
try:
status = int(status)
except (TypeError, ValueError):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
if not (200 <= status <= 399):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
body_tpl = response_cfg.get("body_template", "")
def parse_body(body: str):
if not body:
return None, "application/json"
try:
parsed = json.loads(body)
return parsed, "application/json"
except (json.JSONDecodeError, TypeError):
return body, "text/plain"
body, content_type = parse_body(body_tpl)
resp = Response(
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
status=status,
content_type=content_type,
)
async def background_run():
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request
):
if is_test:
append_webhook_trace(agent_id, start_ts, ans)
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
)
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
except Exception as e:
logging.exception("Webhook background run failed")
if is_test:
try:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
)
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
)
except Exception:
logging.exception("Failed to append webhook trace")
asyncio.create_task(background_run())
return resp
else:
async def sse():
nonlocal canvas
contents: list[str] = []
status = 200
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request,
):
if ans["event"] == "message":
content = ans["data"]["content"]
if ans["data"].get("start_to_think", False):
content = "<think>"
elif ans["data"].get("end_to_think", False):
content = "</think>"
if content:
contents.append(content)
if ans["event"] == "message_end":
status = int(ans["data"].get("status", status))
if is_test:
append_webhook_trace(
agent_id,
start_ts,
ans
)
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
)
final_content = "".join(contents)
return {
"message": final_content,
"success": True,
"code": status,
}
except Exception as e:
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
)
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
)
return {"code": 400, "message": str(e),"success":False}
result = await sse()
return Response(
json.dumps(result),
status=result["code"],
mimetype="application/json",
)
@manager.route("/agents/<agent_id>/webhook/logs", methods=["GET"]) # noqa: F821
@login_required
async def webhook_trace(agent_id: str):
exists, cvs = UserCanvasService.get_by_id(agent_id)
if not exists or str(cvs.user_id) != str(current_user.id):
return get_data_error_result(
message="Canvas not found.",
)
def encode_webhook_id(start_ts: str) -> str:
WEBHOOK_ID_SECRET = "webhook_id_secret"
sig = hmac.new(
WEBHOOK_ID_SECRET.encode("utf-8"),
start_ts.encode("utf-8"),
hashlib.sha256,
).digest()
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
for ts in webhooks.keys():
if encode_webhook_id(ts) == enc_id:
return ts
return None
since_ts = request.args.get("since_ts", type=float)
webhook_id = request.args.get("webhook_id")
key = f"webhook-trace-{agent_id}-logs"
raw = REDIS_CONN.get(key)
if since_ts is None:
now = time.time()
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": now,
"finished": False,
}
)
if not raw:
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": since_ts,
"finished": False,
}
)
obj = json.loads(raw)
webhooks = obj.get("webhooks", {})
if webhook_id is None:
candidates = [
float(k) for k in webhooks.keys() if float(k) > since_ts
]
if not candidates:
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": since_ts,
"finished": False,
}
)
start_ts = min(candidates)
real_id = str(start_ts)
webhook_id = encode_webhook_id(real_id)
return get_json_result(
data={
"webhook_id": webhook_id,
"events": [],
"next_since_ts": start_ts,
"finished": False,
}
)
real_id = decode_webhook_id(webhook_id, webhooks)
if not real_id:
return get_json_result(
data={
"webhook_id": webhook_id,
"events": [],
"next_since_ts": since_ts,
"finished": True,
}
)
ws = webhooks.get(str(real_id))
events = ws.get("events", [])
new_events = [e for e in events if e.get("ts", 0) > since_ts]
next_ts = since_ts
for e in new_events:
next_ts = max(next_ts, e["ts"])
finished = any(e.get("event") == "finished" for e in new_events)
return get_json_result(
data={
"webhook_id": webhook_id,
"events": new_events,
"next_since_ts": next_ts,
"finished": finished,
}
)

View File

@@ -1,819 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import base64
import hashlib
import hmac
import ipaddress
import json
import logging
import time
import jwt
from agent.canvas import Canvas
from api.db import CanvasCategory
from api.db.services.canvas_service import UserCanvasService
from api.db.services.file_service import FileService
from common.constants import RetCode
from api.utils.api_utils import get_data_error_result, get_json_result
from quart import request, Response
from rag.utils.redis_conn import REDIS_CONN
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
async def webhook(agent_id: str):
is_test = request.path.startswith("/api/v1/webhook_test")
start_ts = time.time()
# 1. Fetch canvas by agent_id
exists, cvs = UserCanvasService.get_by_id(agent_id)
if not exists:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
# 2. Check canvas category
if cvs.canvas_category == CanvasCategory.DataFlow:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
# 3. Load DSL from canvas
dsl = getattr(cvs, "dsl", None)
if not isinstance(dsl, dict):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL
webhook_cfg = {}
components = dsl.get("components", {})
for k, _ in components.items():
cpn_obj = components[k]["obj"]
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
webhook_cfg = cpn_obj["params"]
if not webhook_cfg:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
# 5. Validate request method against webhook_cfg.methods
allowed_methods = webhook_cfg.get("methods", [])
request_method = request.method.upper()
if allowed_methods and request_method not in allowed_methods:
return get_data_error_result(
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
),RetCode.BAD_REQUEST
# 6. Validate webhook security
async def validate_webhook_security(security_cfg: dict):
"""Validate webhook security rules based on security configuration."""
if not security_cfg:
return # No security config → allowed by default
# 1. Validate max body size
await _validate_max_body_size(security_cfg)
# 2. Validate IP whitelist
_validate_ip_whitelist(security_cfg)
# # 3. Validate rate limiting
_validate_rate_limit(security_cfg)
# 4. Validate authentication
auth_type = security_cfg.get("auth_type", "none")
if auth_type == "none":
return
if auth_type == "token":
_validate_token_auth(security_cfg)
elif auth_type == "basic":
_validate_basic_auth(security_cfg)
elif auth_type == "jwt":
_validate_jwt_auth(security_cfg)
else:
raise Exception(f"Unsupported auth_type: {auth_type}")
async def _validate_max_body_size(security_cfg):
"""Check request size does not exceed max_body_size."""
max_size = security_cfg.get("max_body_size")
if not max_size:
return
# Convert "10MB" → bytes
units = {"kb": 1024, "mb": 1024**2}
size_str = max_size.lower()
for suffix, factor in units.items():
if size_str.endswith(suffix):
limit = int(size_str.replace(suffix, "")) * factor
break
else:
raise Exception("Invalid max_body_size format")
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
if limit > MAX_LIMIT:
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
content_length = request.content_length or 0
if content_length > limit:
raise Exception(f"Request body too large: {content_length} > {limit}")
def _validate_ip_whitelist(security_cfg):
"""Allow only IPs listed in ip_whitelist."""
whitelist = security_cfg.get("ip_whitelist", [])
if not whitelist:
return
client_ip = request.remote_addr
for rule in whitelist:
if "/" in rule:
# CIDR notation
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
return
else:
# Single IP
if client_ip == rule:
return
raise Exception(f"IP {client_ip} is not allowed by whitelist")
def _validate_rate_limit(security_cfg):
"""Simple in-memory rate limiting."""
rl = security_cfg.get("rate_limit")
if not rl:
return
limit = int(rl.get("limit", 60))
if limit <= 0:
raise Exception("rate_limit.limit must be > 0")
per = rl.get("per", "minute")
window = {
"second": 1,
"minute": 60,
"hour": 3600,
"day": 86400,
}.get(per)
if not window:
raise Exception(f"Invalid rate_limit.per: {per}")
capacity = limit
rate = limit / window
cost = 1
key = f"rl:tb:{agent_id}"
now = time.time()
try:
res = REDIS_CONN.lua_token_bucket(
keys=[key],
args=[capacity, rate, now, cost],
client=REDIS_CONN.REDIS,
)
allowed = int(res[0])
if allowed != 1:
raise Exception("Too many requests (rate limit exceeded)")
except Exception as e:
raise Exception(f"Rate limit error: {e}")
def _validate_token_auth(security_cfg):
"""Validate header-based token authentication."""
token_cfg = security_cfg.get("token",{})
header = token_cfg.get("token_header")
token_value = token_cfg.get("token_value")
provided = request.headers.get(header)
if provided != token_value:
raise Exception("Invalid token authentication")
def _validate_basic_auth(security_cfg):
"""Validate HTTP Basic Auth credentials."""
auth_cfg = security_cfg.get("basic_auth", {})
username = auth_cfg.get("username")
password = auth_cfg.get("password")
auth = request.authorization
if not auth or auth.username != username or auth.password != password:
raise Exception("Invalid Basic Auth credentials")
def _validate_jwt_auth(security_cfg):
"""Validate JWT token in Authorization header."""
jwt_cfg = security_cfg.get("jwt", {})
secret = jwt_cfg.get("secret")
if not secret:
raise Exception("JWT secret not configured")
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise Exception("Missing Bearer token")
token = auth_header[len("Bearer "):].strip()
if not token:
raise Exception("Empty Bearer token")
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
decode_kwargs = {
"key": secret,
"algorithms": [alg],
}
options = {}
if jwt_cfg.get("audience"):
decode_kwargs["audience"] = jwt_cfg["audience"]
options["verify_aud"] = True
else:
options["verify_aud"] = False
if jwt_cfg.get("issuer"):
decode_kwargs["issuer"] = jwt_cfg["issuer"]
options["verify_iss"] = True
else:
options["verify_iss"] = False
try:
decoded = jwt.decode(
token,
options=options,
**decode_kwargs,
)
except Exception as e:
raise Exception(f"Invalid JWT: {str(e)}")
raw_required_claims = jwt_cfg.get("required_claims", [])
if isinstance(raw_required_claims, str):
required_claims = [raw_required_claims]
elif isinstance(raw_required_claims, (list, tuple, set)):
required_claims = list(raw_required_claims)
else:
required_claims = []
required_claims = [
c for c in required_claims
if isinstance(c, str) and c.strip()
]
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
for claim in required_claims:
if claim in RESERVED_CLAIMS:
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
for claim in required_claims:
if claim not in decoded:
raise Exception(f"Missing JWT claim: {claim}")
return decoded
try:
security_config=webhook_cfg.get("security", {})
await validate_webhook_security(security_config)
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
if not isinstance(cvs.dsl, str):
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
try:
canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
except Exception as e:
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
resp.status_code = RetCode.BAD_REQUEST
return resp
# 7. Parse request body
async def parse_webhook_request(content_type):
"""Parse request based on content-type and return structured data."""
# 1. Query
query_data = {k: v for k, v in request.args.items()}
# 2. Headers
header_data = {k: v for k, v in request.headers.items()}
# 3. Body
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
if ctype and ctype != content_type:
raise ValueError(
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
)
body_data: dict = {}
try:
if ctype == "application/json":
body_data = await request.get_json() or {}
elif ctype == "multipart/form-data":
nonlocal canvas
form = await request.form
files = await request.files
body_data = {}
for key, value in form.items():
body_data[key] = value
if len(files) > 10:
raise Exception("Too many uploaded files")
for key, file in files.items():
desc = FileService.upload_info(
cvs.user_id, # user
file, # FileStorage
None # url (None for webhook)
)
file_parsed= await canvas.get_files_async([desc])
body_data[key] = file_parsed
elif ctype == "application/x-www-form-urlencoded":
form = await request.form
body_data = dict(form)
else:
# text/plain / octet-stream / empty / unknown
raw = await request.get_data()
if raw:
try:
body_data = json.loads(raw.decode("utf-8"))
except Exception:
body_data = {}
else:
body_data = {}
except Exception:
body_data = {}
return {
"query": query_data,
"headers": header_data,
"body": body_data,
"content_type": ctype,
}
def extract_by_schema(data, schema, name="section"):
"""
Extract only fields defined in schema.
Required fields must exist.
Optional fields default to type-based default values.
Type validation included.
"""
props = schema.get("properties", {})
required = schema.get("required", [])
extracted = {}
for field, field_schema in props.items():
field_type = field_schema.get("type")
# 1. Required field missing
if field in required and field not in data:
raise Exception(f"{name} missing required field: {field}")
# 2. Optional → default value
if field not in data:
extracted[field] = default_for_type(field_type)
continue
raw_value = data[field]
# 3. Auto convert value
try:
value = auto_cast_value(raw_value, field_type)
except Exception as e:
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
# 4. Type validation
if not validate_type(value, field_type):
raise Exception(
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
)
extracted[field] = value
return extracted
def default_for_type(t):
"""Return default value for the given schema type."""
if t == "file":
return []
if t == "object":
return {}
if t == "boolean":
return False
if t == "number":
return 0
if t == "string":
return ""
if t and t.startswith("array"):
return []
if t == "null":
return None
return None
def auto_cast_value(value, expected_type):
"""Convert string values into schema type when possible."""
# Non-string values already good
if not isinstance(value, str):
return value
v = value.strip()
# Boolean
if expected_type == "boolean":
if v.lower() in ["true", "1"]:
return True
if v.lower() in ["false", "0"]:
return False
raise Exception(f"Cannot convert '{value}' to boolean")
# Number
if expected_type == "number":
# integer
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
return int(v)
# float
try:
return float(v)
except Exception:
raise Exception(f"Cannot convert '{value}' to number")
# Object
if expected_type == "object":
try:
parsed = json.loads(v)
if isinstance(parsed, dict):
return parsed
else:
raise Exception("JSON is not an object")
except Exception:
raise Exception(f"Cannot convert '{value}' to object")
# Array <T>
if expected_type.startswith("array"):
try:
parsed = json.loads(v)
if isinstance(parsed, list):
return parsed
else:
raise Exception("JSON is not an array")
except Exception:
raise Exception(f"Cannot convert '{value}' to array")
# String (accept original)
if expected_type == "string":
return value
# File
if expected_type == "file":
return value
# Default: do nothing
return value
def validate_type(value, t):
"""Validate value type against schema type t."""
if t == "file":
return isinstance(value, list)
if t == "string":
return isinstance(value, str)
if t == "number":
return isinstance(value, (int, float))
if t == "boolean":
return isinstance(value, bool)
if t == "object":
return isinstance(value, dict)
# array<string> / array<number> / array<object>
if t.startswith("array"):
if not isinstance(value, list):
return False
if "<" in t and ">" in t:
inner = t[t.find("<") + 1 : t.find(">")]
# Check each element type
for item in value:
if not validate_type(item, inner):
return False
return True
return True
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
# Extract strictly by schema
try:
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
clean_request = {
"query": query_clean,
"headers": header_clean,
"body": body_clean,
"input": parsed
}
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
response_cfg = webhook_cfg.get("response", {})
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
key = f"webhook-trace-{agent_id}-logs"
raw = REDIS_CONN.get(key)
obj = json.loads(raw) if raw else {"webhooks": {}}
ws = obj["webhooks"].setdefault(
str(start_ts),
{"start_ts": start_ts, "events": []}
)
ws["events"].append({
"ts": time.time(),
**event
})
REDIS_CONN.set_obj(key, obj, ttl)
if execution_mode == "Immediately":
status = response_cfg.get("status", 200)
try:
status = int(status)
except (TypeError, ValueError):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
if not (200 <= status <= 399):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
body_tpl = response_cfg.get("body_template", "")
def parse_body(body: str):
if not body:
return None, "application/json"
try:
parsed = json.loads(body)
return parsed, "application/json"
except (json.JSONDecodeError, TypeError):
return body, "text/plain"
body, content_type = parse_body(body_tpl)
resp = Response(
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
status=status,
content_type=content_type,
)
async def background_run():
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request
):
if is_test:
append_webhook_trace(agent_id, start_ts, ans)
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
)
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
except Exception as e:
logging.exception("Webhook background run failed")
if is_test:
try:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
)
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
)
except Exception:
logging.exception("Failed to append webhook trace")
asyncio.create_task(background_run())
return resp
else:
async def sse():
nonlocal canvas
contents: list[str] = []
status = 200
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request,
):
if ans["event"] == "message":
content = ans["data"]["content"]
if ans["data"].get("start_to_think", False):
content = "<think>"
elif ans["data"].get("end_to_think", False):
content = "</think>"
if content:
contents.append(content)
if ans["event"] == "message_end":
status = int(ans["data"].get("status", status))
if is_test:
append_webhook_trace(
agent_id,
start_ts,
ans
)
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
)
final_content = "".join(contents)
return {
"message": final_content,
"success": True,
"code": status,
}
except Exception as e:
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
)
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
)
return {"code": 400, "message": str(e),"success":False}
result = await sse()
return Response(
json.dumps(result),
status=result["code"],
mimetype="application/json",
)
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
async def webhook_trace(agent_id: str):
def encode_webhook_id(start_ts: str) -> str:
WEBHOOK_ID_SECRET = "webhook_id_secret"
sig = hmac.new(
WEBHOOK_ID_SECRET.encode("utf-8"),
start_ts.encode("utf-8"),
hashlib.sha256,
).digest()
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
for ts in webhooks.keys():
if encode_webhook_id(ts) == enc_id:
return ts
return None
since_ts = request.args.get("since_ts", type=float)
webhook_id = request.args.get("webhook_id")
key = f"webhook-trace-{agent_id}-logs"
raw = REDIS_CONN.get(key)
if since_ts is None:
now = time.time()
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": now,
"finished": False,
}
)
if not raw:
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": since_ts,
"finished": False,
}
)
obj = json.loads(raw)
webhooks = obj.get("webhooks", {})
if webhook_id is None:
candidates = [
float(k) for k in webhooks.keys() if float(k) > since_ts
]
if not candidates:
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": since_ts,
"finished": False,
}
)
start_ts = min(candidates)
real_id = str(start_ts)
webhook_id = encode_webhook_id(real_id)
return get_json_result(
data={
"webhook_id": webhook_id,
"events": [],
"next_since_ts": start_ts,
"finished": False,
}
)
real_id = decode_webhook_id(webhook_id, webhooks)
if not real_id:
return get_json_result(
data={
"webhook_id": webhook_id,
"events": [],
"next_since_ts": since_ts,
"finished": True,
}
)
ws = webhooks.get(str(real_id))
events = ws.get("events", [])
new_events = [e for e in events if e.get("ts", 0) > since_ts]
next_ts = since_ts
for e in new_events:
next_ts = max(next_ts, e["ts"])
finished = any(e.get("event") == "finished" for e in new_events)
return get_json_result(
data={
"webhook_id": webhook_id,
"events": new_events,
"next_since_ts": next_ts,
"finished": finished,
}
)

View File

@@ -552,6 +552,7 @@ def _load_agent_api_module(monkeypatch):
api_apps_mod = ModuleType("api.apps")
api_apps_mod.__path__ = [str(repo_root / "api" / "apps")]
api_apps_mod.current_user = SimpleNamespace(id="tenant-1")
api_apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", api_apps_mod)

File diff suppressed because it is too large Load Diff

View File

@@ -90,11 +90,17 @@ class TestChunksList:
res = update_chunk(WebApiAuth, dataset_id, document_id, chunk_id, {"content": "unchanged content", "available": False})
assert res["code"] == 0, res
from time import sleep
from time import sleep, time
sleep(1)
res = list_chunks(WebApiAuth, dataset_id, document_id, params={"available": "false"})
assert res["code"] == 0, res
deadline = time() + 5
res = None
while time() < deadline:
res = list_chunks(WebApiAuth, dataset_id, document_id, params={"available": "false"})
assert res["code"] == 0, res
if res["data"]["chunks"]:
break
sleep(0.5)
assert res is not None
assert len(res["data"]["chunks"]) >= 1, res
assert all(chunk["available"] is False for chunk in res["data"]["chunks"]), res
@@ -104,20 +110,23 @@ class TestChunksList:
@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_page_size",
"params, expected_page_size, minimum_page_size",
[
({"keywords": None}, 5),
({"keywords": ""}, 5),
({"keywords": "1"}, 1),
pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
({"keywords": "unknown"}, 0),
({"keywords": None}, 5, None),
({"keywords": ""}, 5, None),
({"keywords": "1"}, 1, None),
pytest.param({"keywords": "chunk"}, None, 3, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
({"keywords": "unknown"}, 0, None),
],
)
def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size):
def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size, minimum_page_size):
dataset_id, document_id, _ = add_chunks
res = list_chunks(WebApiAuth, dataset_id, document_id, params=params)
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == expected_page_size, res
if minimum_page_size is not None:
assert len(res["data"]["chunks"]) >= minimum_page_size, res
else:
assert len(res["data"]["chunks"]) == expected_page_size, res
@pytest.mark.p3
def test_invalid_params(self, WebApiAuth, add_chunks):

View File

@@ -3,6 +3,6 @@ import { useParams } from 'react-router';
export function useBuildWebhookUrl() {
const { id } = useParams();
const text = `${location.protocol}//${location.host}/api/v1/webhook/${id}`;
const text = `${location.protocol}//${location.host}/api/v1/agents/${id}/webhook`;
return text;
}

View File

@@ -28,7 +28,7 @@ enum WebhookTraceTabType {
const WebhookSheet = ({ hideModal }: RunSheetProps) => {
const { t } = useTranslation();
const { id } = useParams();
const text = `${location.protocol}//${location.host}/api/v1/webhook_test/${id}`;
const text = `${location.protocol}//${location.host}/api/v1/agents/${id}/webhook/test`;
const { data } = useFetchWebhookTrace(true);

View File

@@ -211,8 +211,8 @@ export default {
prompt: `${restAPIv1}/agents/prompts`,
cancelDataflow: (id: string) => `${webAPI}/canvas/cancel/${id}`,
downloadFile: `${restAPIv1}/agents/download`,
testWebhook: (id: string) => `${restAPIv1}/webhook_test/${id}`,
fetchWebhookTrace: (id: string) => `${restAPIv1}/webhook_trace/${id}`,
testWebhook: (id: string) => `${restAPIv1}/agents/${id}/webhook/test`,
fetchWebhookTrace: (id: string) => `${restAPIv1}/agents/${id}/webhook/logs`,
// explore