diff --git a/agent/sandbox/executor_manager/models/schemas.py b/agent/sandbox/executor_manager/models/schemas.py index 9baa94b5f7..ed50c26a18 100644 --- a/agent/sandbox/executor_manager/models/schemas.py +++ b/agent/sandbox/executor_manager/models/schemas.py @@ -14,7 +14,7 @@ # limitations under the License. # import base64 -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, Field, field_validator @@ -28,6 +28,12 @@ class ArtifactItem(BaseModel): content_b64: str +class ExecutionStructuredResult(BaseModel): + present: bool + value: Any = None + type: str = "json" + + class CodeExecutionResult(BaseModel): status: ResultStatus stdout: str @@ -47,6 +53,9 @@ class CodeExecutionResult(BaseModel): # File artifacts produced by code execution (images, PDFs, CSVs, etc.) artifacts: list[ArtifactItem] = [] + # Structured return value produced by main() + result: Optional[ExecutionStructuredResult] = None + class CodeExecutionRequest(BaseModel): code_b64: str = Field(..., description="Base64 encoded code string") diff --git a/agent/sandbox/executor_manager/services/execution.py b/agent/sandbox/executor_manager/services/execution.py index 358d122c26..cf78cdc7c2 100644 --- a/agent/sandbox/executor_manager/services/execution.py +++ b/agent/sandbox/executor_manager/services/execution.py @@ -19,14 +19,42 @@ import json import os import time import uuid - from core.config import TIMEOUT from core.container import allocate_container_blocking, release_container from core.logger import logger from models.enums import ResourceLimitType, ResultStatus, RuntimeErrorType, SupportLanguage, UnauthorizedAccessType -from models.schemas import ArtifactItem, CodeExecutionRequest, CodeExecutionResult +from models.schemas import ArtifactItem, CodeExecutionRequest, CodeExecutionResult, ExecutionStructuredResult from utils.common import async_run_command +RESULT_MARKER_PREFIX = "__RAGFLOW_RESULT__:" + + +def _extract_result_envelope(stdout: str) -> tuple[str, ExecutionStructuredResult | None]: + if not stdout: + return "", None + + cleaned_lines: list[str] = [] + envelope: ExecutionStructuredResult | None = None + + for line in str(stdout).splitlines(): + if line.startswith(RESULT_MARKER_PREFIX): + payload_b64 = line[len(RESULT_MARKER_PREFIX) :].strip() + if not payload_b64: + continue + try: + payload = base64.b64decode(payload_b64).decode("utf-8") + envelope = ExecutionStructuredResult.model_validate_json(payload) + except Exception as exc: + logger.warning(f"Failed to decode structured result marker: {exc}") + cleaned_lines.append(line) + continue + cleaned_lines.append(line) + + cleaned_stdout = "\n".join(cleaned_lines) + if stdout.endswith("\n") and cleaned_stdout and not cleaned_stdout.endswith("\n"): + cleaned_stdout += "\n" + return cleaned_stdout, envelope + async def execute_code(req: CodeExecutionRequest): """Fully asynchronous execution logic""" @@ -48,15 +76,14 @@ async def execute_code(req: CodeExecutionRequest): try: if language == SupportLanguage.PYTHON: code_name = "main.py" - # code code_path = os.path.join(workdir, code_name) with open(code_path, "wb") as f: f.write(base64.b64decode(req.code_b64)) - # runner runner_name = "runner.py" runner_path = os.path.join(workdir, runner_name) with open(runner_path, "w") as f: - f.write("""import json + f.write(f"""import base64 +import json import os import sys @@ -65,33 +92,64 @@ os.makedirs(os.path.join(os.getcwd(), "artifacts"), exist_ok=True) sys.path.insert(0, os.path.dirname(__file__)) from main import main +RESULT_MARKER_PREFIX = {RESULT_MARKER_PREFIX!r} + + +def emit_result(value): + payload = json.dumps( + {{ + "present": True, + "value": value, + "type": "json", + }}, + ensure_ascii=False, + separators=(",", ":"), + ) + print(RESULT_MARKER_PREFIX + base64.b64encode(payload.encode("utf-8")).decode("ascii")) + + if __name__ == "__main__": args = json.loads(sys.argv[1]) result = main(**args) - if result is not None: - print(result) + emit_result(result) """) elif language == SupportLanguage.NODEJS: code_name = "main.js" - code_path = os.path.join(workdir, "main.js") + code_path = os.path.join(workdir, code_name) with open(code_path, "wb") as f: f.write(base64.b64decode(req.code_b64)) runner_name = "runner.js" runner_path = os.path.join(workdir, "runner.js") with open(runner_path, "w") as f: - f.write(""" + runner_code = """ const fs = require('fs'); const path = require('path'); const args = JSON.parse(process.argv[2]); const mainPath = path.join(__dirname, 'main.js'); +const RESULT_MARKER_PREFIX = '__RESULT_MARKER_PREFIX__'; function isPromise(value) { return Boolean(value && typeof value.then === 'function'); } +function emitResult(value) { + if (typeof value === 'undefined') { + console.error('Error: main() must return a value. Use null for an empty result.'); + process.exit(1); + } + + const payload = JSON.stringify({ present: true, value, type: 'json' }); + if (typeof payload === 'undefined') { + console.error('Error: main() returned a non-JSON-serializable value.'); + process.exit(1); + } + + console.log(RESULT_MARKER_PREFIX + Buffer.from(payload, 'utf8').toString('base64')); +} + if (fs.existsSync(mainPath)) { const mod = require(mainPath); const main = typeof mod === 'function' ? mod : mod.main; @@ -103,40 +161,38 @@ if (fs.existsSync(mainPath)) { if (typeof args === 'object' && args !== null) { try { - const result = main(args); + const result = Promise.resolve(main(args)); if (isPromise(result)) { result.then(output => { - if (output !== null) { - console.log(output); - } + emitResult(output); }).catch(err => { console.error('Error in async main function:', err); + process.exit(1); }); } else { - if (result !== null) { - console.log(result); - } + emitResult(result); } } catch (err) { console.error('Error when executing main:', err); + process.exit(1); } } else { console.error('Error: args is not a valid object:', args); + process.exit(1); } } else { console.error('main.js not found in the current directory'); + process.exit(1); } -""") - # dirs +""" + f.write(runner_code.replace("__RESULT_MARKER_PREFIX__", RESULT_MARKER_PREFIX)) returncode, _, stderr = await async_run_command("docker", "exec", container, "mkdir", "-p", f"/workspace/{task_id}", timeout=5) if returncode != 0: raise RuntimeError(f"Directory creation failed: {stderr}") - # archive tar_proc = await asyncio.create_subprocess_exec("tar", "czf", "-", "-C", workdir, code_name, runner_name, stdout=asyncio.subprocess.PIPE) tar_stdout, _ = await tar_proc.communicate() - # unarchive docker_proc = await asyncio.create_subprocess_exec( "docker", "exec", "-i", container, "tar", "xzf", "-", "-C", f"/workspace/{task_id}", stdin=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) @@ -145,7 +201,6 @@ if (fs.existsSync(mainPath)) { if docker_proc.returncode != 0: raise RuntimeError(stderr.decode()) - # exec start_time = time.time() try: logger.info(f"Passed in args: {req.arguments}") @@ -160,11 +215,10 @@ if (fs.existsSync(mainPath)) { str(TIMEOUT), language, ] - # flags if language == SupportLanguage.PYTHON: run_args.extend(["-I", "-B"]) elif language == SupportLanguage.NODEJS: - run_args.extend([]) + pass # no additional flags else: assert False, "Will never reach here" run_args.extend([runner_name, args_json]) @@ -184,14 +238,16 @@ if (fs.existsSync(mainPath)) { logger.info(f"{args_json=}") if returncode == 0: + clean_stdout, structured_result = _extract_result_envelope(stdout) artifacts = await _collect_artifacts(container, task_id, workdir) return CodeExecutionResult( status=ResultStatus.SUCCESS, - stdout=str(stdout), + stdout=clean_stdout, stderr=stderr, exit_code=0, time_used_ms=time_used_ms, artifacts=artifacts, + result=structured_result, ) elif returncode == 124: return CodeExecutionResult( @@ -229,7 +285,6 @@ if (fs.existsSync(mainPath)) { return CodeExecutionResult(status=ResultStatus.PROGRAM_RUNNER_ERROR, stdout="", stderr=str(e), exit_code=-3, detail="internal_error") finally: - # cleanup cleanup_tasks = [async_run_command("docker", "exec", container, "rm", "-rf", f"/workspace/{task_id}"), async_run_command("rm", "-rf", workdir)] await asyncio.gather(*cleanup_tasks, return_exceptions=True) await release_container(container, language) diff --git a/agent/sandbox/executor_manager/services/security.py b/agent/sandbox/executor_manager/services/security.py index cbe1ca27e1..13a02ced2e 100644 --- a/agent/sandbox/executor_manager/services/security.py +++ b/agent/sandbox/executor_manager/services/security.py @@ -14,6 +14,7 @@ # limitations under the License. # import ast +import re from typing import List, Tuple from core.logger import logger @@ -151,6 +152,26 @@ class SecurePythonAnalyzer(ast.NodeVisitor): self.generic_visit(node) +class SecureJavaScriptAnalyzer: + DANGEROUS_PATTERNS = [ + (re.compile(r"""require\s*\(\s*['"]child_process['"]\s*\)"""), "Require: child_process"), + (re.compile(r"""require\s*\(\s*['"]fs['"]\s*\)"""), "Require: fs"), + (re.compile(r"""require\s*\(\s*['"]worker_threads['"]\s*\)"""), "Require: worker_threads"), + (re.compile(r"""\beval\s*\("""), "Call: eval"), + (re.compile(r"""\bFunction\s*\("""), "Call: Function"), + (re.compile(r"""\bprocess\s*\.\s*binding\s*\("""), "Call: process.binding"), + ] + + @classmethod + def analyze(cls, code: str) -> List[Tuple[str, int]]: + issues: List[Tuple[str, int]] = [] + for pattern, description in cls.DANGEROUS_PATTERNS: + for match in pattern.finditer(code): + lineno = code.count("\n", 0, match.start()) + 1 + issues.append((description, lineno)) + return issues + + def analyze_code_security(code: str, language: SupportLanguage) -> Tuple[bool, List[Tuple[str, int]]]: """ Analyze the provided code string and return whether it's safe and why. @@ -168,6 +189,9 @@ def analyze_code_security(code: str, language: SupportLanguage) -> Tuple[bool, L except Exception as e: logger.error(f"[SafeCheck] Python parsing failed: {str(e)}") return False, [(f"Parsing Error: {str(e)}", -1)] - else: - logger.warning(f"[SafeCheck] Unsupported language for security analysis: {language} — defaulting to SAFE (manual review recommended)") - return True, [(f"Unsupported language for security analysis: {language} — defaulted to SAFE, manual review recommended", -1)] + if language == SupportLanguage.NODEJS: + issues = SecureJavaScriptAnalyzer.analyze(code) + return len(issues) == 0, issues + + logger.warning(f"[SafeCheck] Unsupported language for security analysis: {language}") + return False, [(f"Unsupported language for security analysis: {language}", -1)] diff --git a/agent/sandbox/providers/aliyun_codeinterpreter.py b/agent/sandbox/providers/aliyun_codeinterpreter.py index 8d9eba691e..8ee99ed1ec 100644 --- a/agent/sandbox/providers/aliyun_codeinterpreter.py +++ b/agent/sandbox/providers/aliyun_codeinterpreter.py @@ -30,6 +30,8 @@ https://api.aliyun.com/api/AgentRun/2025-09-10/CreateSandbox?lang=PYTHON import logging import os import time +import base64 +import json from typing import Dict, Any, List, Optional from datetime import datetime, timezone @@ -40,6 +42,7 @@ from agentrun.utils.exception import ServerError from .base import SandboxProvider, SandboxInstance, ExecutionResult logger = logging.getLogger(__name__) +RESULT_MARKER_PREFIX = "__RAGFLOW_RESULT__:" class AliyunCodeInterpreterProvider(SandboxProvider): @@ -51,9 +54,9 @@ class AliyunCodeInterpreterProvider(SandboxProvider): """ def __init__(self): - self.access_key_id: Optional[str] = None - self.access_key_secret: Optional[str] = None - self.account_id: Optional[str] = None + self.access_key_id: Optional[str] = "" + self.access_key_secret: Optional[str] = "" + self.account_id: Optional[str] = "" self.region: str = "cn-hangzhou" self.template_name: str = "" self.timeout: int = 30 @@ -146,8 +149,6 @@ class AliyunCodeInterpreterProvider(SandboxProvider): try: # Get or create template - from agentrun.sandbox import Sandbox - if self.template_name: # Use existing template template_name = self.template_name @@ -226,48 +227,17 @@ class AliyunCodeInterpreterProvider(SandboxProvider): # Connect to existing sandbox instance sandbox = Sandbox.connect(sandbox_id=instance_id, config=self._config) - # Convert language string to CodeLanguage enum - code_language = CodeLanguage.PYTHON if normalized_lang == "python" else CodeLanguage.JAVASCRIPT + # agentrun-sdk 0.0.26 only exposes CodeLanguage.PYTHON; keep JS as string fallback. + code_language = CodeLanguage.PYTHON if normalized_lang == "python" else "javascript" # Wrap code to call main() function # Matches self_managed provider behavior: call main(**arguments) - if normalized_lang == "python": - # Build arguments string for main() call - if arguments: - import json as json_module - args_json = json_module.dumps(arguments) - wrapped_code = f'''{code} - -if __name__ == "__main__": - import json - result = main(**{args_json}) - print(json.dumps(result) if isinstance(result, dict) else result) -''' - else: - wrapped_code = f'''{code} - -if __name__ == "__main__": - import json - result = main() - print(json.dumps(result) if isinstance(result, dict) else result) -''' - else: # javascript - if arguments: - import json as json_module - args_json = json_module.dumps(arguments) - wrapped_code = f'''{code} - -// Call main and output result -const result = main({args_json}); -console.log(typeof result === 'object' ? JSON.stringify(result) : String(result)); -''' - else: - wrapped_code = f'''{code} - -// Call main and output result -const result = main(); -console.log(typeof result === 'object' ? JSON.stringify(result) : String(result)); -''' + args_json = json.dumps(arguments or {}) + wrapped_code = ( + self._build_python_wrapper(code, args_json) + if normalized_lang == "python" + else self._build_javascript_wrapper(code, args_json) + ) logger.debug(f"Aliyun Code Interpreter: Wrapped code (first 200 chars): {wrapped_code[:200]}") start_time = time.time() @@ -314,6 +284,7 @@ console.log(typeof result === 'object' ? JSON.stringify(result) : String(result) stdout = "\n".join(stdout_parts) stderr = "\n".join(stderr_parts) + stdout, structured_result = self._extract_structured_result(stdout) logger.info(f"Aliyun Code Interpreter: stdout length={len(stdout)}, stderr length={len(stderr)}, exit_code={exit_code}") if stdout: @@ -331,6 +302,9 @@ console.log(typeof result === 'object' ? JSON.stringify(result) : String(result) "language": normalized_lang, "context_id": result.get("contextId") if isinstance(result, dict) else None, "timeout": timeout, + "result_present": structured_result.get("present", False), + "result_value": structured_result.get("value"), + "result_type": structured_result.get("type"), }, ) @@ -390,6 +364,71 @@ console.log(typeof result === 'object' ? JSON.stringify(result) : String(result) # If we get any response (even an error), the service is reachable return "connection" not in str(e).lower() + @staticmethod + def _build_python_wrapper(code: str, args_json: str) -> str: + marker = RESULT_MARKER_PREFIX + return f'''{code} + +if __name__ == "__main__": + import base64 + import json + + result = main(**{args_json}) + payload = json.dumps({{"present": True, "value": result, "type": "json"}}, ensure_ascii=False, separators=(",", ":")) + print("{marker}" + base64.b64encode(payload.encode("utf-8")).decode("ascii")) +''' + + @staticmethod + def _build_javascript_wrapper(code: str, args_json: str) -> str: + marker = RESULT_MARKER_PREFIX + return f'''{code} + +const __ragflowArgs = {args_json}; + +(async () => {{ + try {{ + const output = await Promise.resolve(main(__ragflowArgs)); + if (typeof output === 'undefined') {{ + throw new Error('main() must return a value. Use null for an empty result.'); + }} + const payload = JSON.stringify({{ present: true, value: output, type: 'json' }}); + if (typeof payload === 'undefined') {{ + throw new Error('main() returned a non-JSON-serializable value.'); + }} + console.log('{marker}' + Buffer.from(payload, 'utf8').toString('base64')); + }} catch (err) {{ + console.error(err instanceof Error ? err.stack || err.message : String(err)); + }} +}})(); +''' + + @staticmethod + def _extract_structured_result(stdout: str) -> tuple[str, Dict[str, Any]]: + if not stdout: + return "", {} + + cleaned_lines: list[str] = [] + structured_result: Dict[str, Any] = {} + + for line in str(stdout).splitlines(): + if line.startswith(RESULT_MARKER_PREFIX): + payload_b64 = line[len(RESULT_MARKER_PREFIX) :].strip() + if not payload_b64: + continue + try: + payload = base64.b64decode(payload_b64).decode("utf-8") + structured_result = json.loads(payload) + except Exception as exc: + logger.warning(f"Aliyun Code Interpreter: failed to decode structured result marker: {exc}") + cleaned_lines.append(line) + continue + cleaned_lines.append(line) + + cleaned_stdout = "\n".join(cleaned_lines) + if stdout.endswith("\n") and cleaned_stdout and not cleaned_stdout.endswith("\n"): + cleaned_stdout += "\n" + return cleaned_stdout, structured_result + def get_supported_languages(self) -> List[str]: """ Get list of supported programming languages. diff --git a/agent/sandbox/providers/self_managed.py b/agent/sandbox/providers/self_managed.py index d4e0c6d687..29d8d80e19 100644 --- a/agent/sandbox/providers/self_managed.py +++ b/agent/sandbox/providers/self_managed.py @@ -187,6 +187,7 @@ class SelfManagedProvider(SandboxProvider): ) result = response.json() + structured_result = result.get("result") or {} return ExecutionResult( stdout=result.get("stdout", ""), @@ -200,6 +201,9 @@ class SelfManagedProvider(SandboxProvider): "detail": result.get("detail"), "instance_id": instance_id, "artifacts": result.get("artifacts", []), + "result_present": structured_result.get("present", False), + "result_value": structured_result.get("value"), + "result_type": structured_result.get("type"), } ) diff --git a/agent/sandbox/tests/test_aliyun_codeinterpreter.py b/agent/sandbox/tests/test_aliyun_codeinterpreter.py index 9b4a369b57..3d598da8ff 100644 --- a/agent/sandbox/tests/test_aliyun_codeinterpreter.py +++ b/agent/sandbox/tests/test_aliyun_codeinterpreter.py @@ -101,13 +101,15 @@ class TestAliyunCodeInterpreterProvider: assert provider.region == "cn-hangzhou" assert provider.template_name == "" - @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox") - def test_create_instance_python(self, mock_sandbox_class): + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Template") + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Sandbox") + def test_create_instance_python(self, mock_sandbox_class, mock_template): """Test creating a Python instance.""" # Mock successful instance creation mock_sandbox = MagicMock() mock_sandbox.sandbox_id = "01JCED8Z9Y6XQVK8M2NRST5WXY" - mock_sandbox_class.return_value = mock_sandbox + mock_sandbox_class.create.return_value = mock_sandbox + mock_template.get_by_name.return_value = MagicMock() provider = AliyunCodeInterpreterProvider() provider._initialized = True @@ -119,12 +121,14 @@ class TestAliyunCodeInterpreterProvider: assert instance.status == "READY" assert instance.metadata["language"] == "python" - @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox") - def test_create_instance_javascript(self, mock_sandbox_class): + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Template") + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Sandbox") + def test_create_instance_javascript(self, mock_sandbox_class, mock_template): """Test creating a JavaScript instance.""" mock_sandbox = MagicMock() mock_sandbox.sandbox_id = "01JCED8Z9Y6XQVK8M2NRST5WXY" - mock_sandbox_class.return_value = mock_sandbox + mock_sandbox_class.create.return_value = mock_sandbox + mock_template.get_by_name.return_value = MagicMock() provider = AliyunCodeInterpreterProvider() provider._initialized = True @@ -141,7 +145,7 @@ class TestAliyunCodeInterpreterProvider: with pytest.raises(RuntimeError, match="Provider not initialized"): provider.create_instance("python") - @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox") + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Sandbox") def test_execute_code_success(self, mock_sandbox_class): """Test successful code execution.""" # Mock sandbox instance @@ -150,7 +154,7 @@ class TestAliyunCodeInterpreterProvider: "results": [{"type": "stdout", "text": "Hello, World!"}, {"type": "result", "text": "None"}, {"type": "endOfExecution", "status": "ok"}], "contextId": "kernel-12345-67890", } - mock_sandbox_class.return_value = mock_sandbox + mock_sandbox_class.connect.return_value = mock_sandbox provider = AliyunCodeInterpreterProvider() provider._initialized = True @@ -163,14 +167,14 @@ class TestAliyunCodeInterpreterProvider: assert result.exit_code == 0 assert result.execution_time > 0 - @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox") + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Sandbox") def test_execute_code_timeout(self, mock_sandbox_class): """Test code execution timeout.""" from agentrun.utils.exception import ServerError mock_sandbox = MagicMock() mock_sandbox.context.execute.side_effect = ServerError(408, "Request timeout") - mock_sandbox_class.return_value = mock_sandbox + mock_sandbox_class.connect.return_value = mock_sandbox provider = AliyunCodeInterpreterProvider() provider._initialized = True @@ -179,14 +183,14 @@ class TestAliyunCodeInterpreterProvider: with pytest.raises(TimeoutError, match="Execution timed out"): provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="while True: pass", language="python", timeout=5) - @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox") + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Sandbox") def test_execute_code_with_error(self, mock_sandbox_class): """Test code execution with error.""" mock_sandbox = MagicMock() mock_sandbox.context.execute.return_value = { "results": [{"type": "stderr", "text": "Traceback..."}, {"type": "error", "text": "NameError: name 'x' is not defined"}, {"type": "endOfExecution", "status": "error"}] } - mock_sandbox_class.return_value = mock_sandbox + mock_sandbox_class.connect.return_value = mock_sandbox provider = AliyunCodeInterpreterProvider() provider._initialized = True @@ -197,6 +201,34 @@ class TestAliyunCodeInterpreterProvider: assert result.exit_code != 0 assert len(result.stderr) > 0 + @patch("agent.sandbox.providers.aliyun_codeinterpreter.Sandbox") + def test_execute_code_uses_structured_result_marker_for_async_javascript(self, mock_sandbox_class): + """Test JavaScript wrapper uses the structured result marker and awaits async main.""" + mock_sandbox = MagicMock() + mock_sandbox.context.execute.return_value = { + "results": [{"type": "stdout", "text": "__RAGFLOW_RESULT__:eyJwcmVzZW50Ijp0cnVlLCJ2YWx1ZSI6eyJhIjoiYiJ9LCJ0eXBlIjoianNvbiJ9"}], + "contextId": "kernel-12345-67890", + } + mock_sandbox_class.connect.return_value = mock_sandbox + + provider = AliyunCodeInterpreterProvider() + provider._initialized = True + provider._config = MagicMock() + + result = provider.execute_code( + instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", + code="async function main(args) { return { a: 'b' }; }", + language="javascript", + timeout=10, + ) + + wrapped_code = mock_sandbox.context.execute.call_args.kwargs["code"] + assert "__RAGFLOW_RESULT__:" in wrapped_code + assert "await Promise.resolve(main(" in wrapped_code + assert result.metadata["result_present"] is True + assert result.metadata["result_value"] == {"a": "b"} + assert result.metadata["result_type"] == "json" + def test_get_supported_languages(self): """Test getting supported languages.""" provider = AliyunCodeInterpreterProvider() diff --git a/agent/sandbox/tests/test_providers.py b/agent/sandbox/tests/test_providers.py index fa2e97ad02..cf90bb79ab 100644 --- a/agent/sandbox/tests/test_providers.py +++ b/agent/sandbox/tests/test_providers.py @@ -254,6 +254,41 @@ class TestSelfManagedProvider: assert result.metadata["status"] == "success" assert result.metadata["instance_id"] == "test-123" + @patch('requests.post') + def test_execute_code_maps_structured_result_into_metadata(self, mock_post): + """Test successful code execution with structured result envelope.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "status": "success", + "stdout": "debug line\n", + "stderr": "", + "exit_code": 0, + "time_used_ms": 100.0, + "memory_used_kb": 1024.0, + "result": { + "present": True, + "value": {"items": ["a", "b"]}, + "type": "json", + }, + } + mock_post.return_value = mock_response + + provider = SelfManagedProvider() + provider._initialized = True + + result = provider.execute_code( + instance_id="test-123", + code="def main(): return {'items': ['a', 'b']}", + language="python", + timeout=10 + ) + + assert result.stdout == "debug line\n" + assert result.metadata["result_present"] is True + assert result.metadata["result_value"] == {"items": ["a", "b"]} + assert result.metadata["result_type"] == "json" + @patch('requests.post') def test_execute_code_timeout(self, mock_post): """Test code execution timeout.""" diff --git a/agent/sandbox/tests/test_security.py b/agent/sandbox/tests/test_security.py new file mode 100644 index 0000000000..ed096894e4 --- /dev/null +++ b/agent/sandbox/tests/test_security.py @@ -0,0 +1,55 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +from pathlib import Path + + +EXECUTOR_MANAGER_ROOT = Path(__file__).resolve().parents[1] / "executor_manager" +if str(EXECUTOR_MANAGER_ROOT) not in sys.path: + sys.path.insert(0, str(EXECUTOR_MANAGER_ROOT)) + +from models.enums import SupportLanguage # noqa: E402 +from services.security import analyze_code_security # noqa: E402 + + +def test_javascript_child_process_is_rejected(): + is_safe, issues = analyze_code_security( + "const cp = require('child_process'); async function main() { return 'ok'; }", + SupportLanguage.NODEJS, + ) + + assert is_safe is False + assert any("child_process" in issue for issue, _ in issues) + + +def test_javascript_eval_is_rejected(): + is_safe, issues = analyze_code_security( + "async function main() { return eval('1+1'); }", + SupportLanguage.NODEJS, + ) + + assert is_safe is False + assert any("eval" in issue.lower() for issue, _ in issues) + + +def test_javascript_safe_code_still_passes(): + is_safe, issues = analyze_code_security( + "async function main(args) { return { answer: args.value ?? null }; }", + SupportLanguage.NODEJS, + ) + + assert is_safe is True + assert issues == [] diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index c896de57c1..5d65a2e33a 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -20,6 +20,7 @@ import logging import os import uuid from abc import ABC +from collections.abc import Mapping from typing import Optional from pydantic import BaseModel, Field, field_validator @@ -32,6 +33,183 @@ from common.connection_utils import timeout from common.constants import SANDBOX_ARTIFACT_BUCKET, SANDBOX_ARTIFACT_EXPIRE_DAYS +SYSTEM_OUTPUT_KEYS = frozenset( + { + "content", + "actual_type", + "_ERROR", + "_ARTIFACTS", + "_ATTACHMENT_CONTENT", + "raw_result", + "_created_time", + "_elapsed_time", + } +) + + +class ContractError(ValueError): + pass + + +def _validate_business_output_name(name: str) -> None: + if not name or not name.strip(): + raise ContractError("CodeExec business output name must not be empty") + if name in SYSTEM_OUTPUT_KEYS: + raise ContractError(f"CodeExec reserved output name is not allowed: {name}") + if "." in name: + raise ContractError(f"CodeExec business output name must not contain '.': {name}") + + +def select_business_output(outputs: Mapping[str, object]) -> tuple[str, object]: + if len(outputs) == 1: + only_name, only_meta = next(iter(outputs.items())) + _validate_business_output_name(only_name) + return only_name, only_meta + + business_outputs = [(name, meta) for name, meta in outputs.items() if name not in SYSTEM_OUTPUT_KEYS] + if len(business_outputs) != 1: + raise ContractError( + f"CodeExec contract must contain exactly one business output, got {len(business_outputs)}" + ) + _validate_business_output_name(business_outputs[0][0]) + return business_outputs[0] + + +def normalize_output_value(value): + if isinstance(value, (tuple, list)): + return [normalize_output_value(item) for item in value] + if isinstance(value, dict): + return {key: normalize_output_value(item) for key, item in value.items()} + return value + + +def infer_actual_type(value) -> str: + value = normalize_output_value(value) + if value is None: + return "Null" + if isinstance(value, bool): + return "Boolean" + if _is_number(value): + return "Number" + if isinstance(value, str): + return "String" + if isinstance(value, dict): + return "Object" + if isinstance(value, list): + if not value: + return "Array" + inferred = {infer_actual_type(item) for item in value} + if len(inferred) == 1: + return f"Array<{inferred.pop()}>" + return "Array" + return "Any" + + +def render_canonical_content(value) -> str: + value = normalize_output_value(value) + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (dict, list)): + return json.dumps(value, ensure_ascii=False, indent=2, sort_keys=True) + return str(value) + + +def _is_number(value) -> bool: + return isinstance(value, (int, float)) and not isinstance(value, bool) + + +def _validate_top_level_value_domain(value) -> None: + allowed = value is None or isinstance(value, (bool, str, dict, list)) or _is_number(value) + if not allowed: + raise ContractError( + f"CodeExec unsupported top-level result type: {type(value).__name__}. " + "Allowed top-level values are String, Number, Boolean, Object, Array, or Null." + ) + + +def _normalize_expected_type(expected_type: str) -> str: + etype = expected_type.strip() + low = etype.lower() + simple_types = { + "string": "String", + "number": "Number", + "boolean": "Boolean", + "object": "Object", + "null": "Null", + "any": "Any", + } + if low in simple_types: + return simple_types[low] + if low.startswith("array<") and low.endswith(">"): + inner = etype[etype.find("<") + 1 : -1].strip() + if not inner: + raise ContractError(f"Unsupported expected type: {expected_type}") + return f"Array<{_normalize_expected_type(inner)}>" + return etype + + +def _validate_expected_type(expected_type: str, value, path: str = "") -> None: + etype = _normalize_expected_type(expected_type) + if not etype or etype.lower() == "any": + return + + value = normalize_output_value(value) + + if etype.startswith("Array<") and etype.endswith(">"): + inner_type = etype[6:-1].strip() + if not isinstance(value, list): + raise ContractError( + f"CodeExec contract mismatch at {path or 'value'}: expected type {etype}, got {infer_actual_type(value)}" + ) + for index, item in enumerate(value): + child_path = f"{path}[{index}]" if path else f"[{index}]" + _validate_expected_type(inner_type, item, child_path) + return + + actual_type = infer_actual_type(value) + if etype == "String": + valid = isinstance(value, str) + elif etype == "Number": + valid = _is_number(value) + elif etype == "Boolean": + valid = isinstance(value, bool) + elif etype == "Object": + valid = isinstance(value, dict) + elif etype == "Null": + valid = value is None + else: + raise ContractError(f"Unsupported expected type: {expected_type}") + + if not valid: + raise ContractError( + f"CodeExec contract mismatch at {path or 'value'}: expected type {etype}, got {actual_type}" + ) + + +def build_code_exec_contract(outputs: Mapping[str, object], raw_result) -> dict[str, object]: + business_name, business_meta = select_business_output(outputs) + expected_type = "" + if isinstance(business_meta, Mapping): + expected_type = str(business_meta.get("type") or "") + + normalized_value = normalize_output_value(raw_result) + _validate_top_level_value_domain(normalized_value) + _validate_expected_type(expected_type, normalized_value) + + return { + "business_output": business_name, + "value": normalized_value, + "actual_type": infer_actual_type(normalized_value), + "content": render_canonical_content(normalized_value), + } + + +def _art_field(art, field: str, default=""): + return art.get(field, default) if isinstance(art, dict) else getattr(art, field, default) + + class Language(StrEnum): PYTHON = "python" NODEJS = "nodejs" @@ -190,7 +368,13 @@ class CodeExec(ToolBase, ABC): return artifacts = result.metadata.get("artifacts", []) if result.metadata else [] - return self._process_execution_result(result.stdout, result.stderr, "Provider system", artifacts) + return self._process_execution_result( + result.stdout, + result.stderr, + "Provider system", + artifacts, + execution_metadata=result.metadata, + ) except (ImportError, RuntimeError) as provider_error: # Provider system not available or not configured, fall back to HTTP @@ -226,6 +410,7 @@ class CodeExec(ToolBase, ABC): body.get("stderr"), f"http://{settings.SANDBOX_HOST}:9385/run", body.get("artifacts", []), + execution_metadata=self._build_http_execution_metadata(body), ) else: self.set_output("_ERROR", "There is no response from sandbox") @@ -239,8 +424,18 @@ class CodeExec(ToolBase, ABC): return self.output() - def _process_execution_result(self, stdout: str, stderr: str | None, source: str, artifacts: list | None = None): - if stderr and not stdout and not artifacts: + def _process_execution_result( + self, + stdout: str, + stderr: str | None, + source: str, + artifacts: list | None = None, + execution_metadata: dict | None = None, + ): + has_structured_result = bool((execution_metadata or {}).get("result_present") is True) + resolved_value, used_stdout_fallback = self._resolve_execution_result_value(stdout, execution_metadata) + + if stderr and not has_structured_result and not artifacts and not str(stdout or "").strip(): self.set_output("_ERROR", stderr) return self.output() @@ -250,29 +445,48 @@ class CodeExec(ToolBase, ABC): if stderr: logging.warning(f"[CodeExec]: stderr (non-fatal): {stderr[:500]}") - parsed_stdout = self._deserialize_stdout(stdout) - logging.info(f"[CodeExec]: {source} -> {parsed_stdout}") - self._populate_outputs(parsed_stdout, stdout) + if used_stdout_fallback and str(stdout or "").strip(): + logging.warning("[CodeExec]: Falling back to stdout deserialization because no structured result metadata was provided") + + logging.info(f"[CodeExec]: {source} -> {resolved_value}") content_parts = [] - base_content = self._build_content_text(parsed_stdout, raw_stdout=stdout) + base_content = self._apply_business_output(resolved_value) if base_content: content_parts.append(base_content) if artifacts: artifact_urls = self._upload_artifacts(artifacts) - if artifact_urls: - self.set_output("_ARTIFACTS", artifact_urls) + self.set_output("_ARTIFACTS", artifact_urls or None) attachment_text = self._build_attachment_content(artifacts, artifact_urls) self.set_output("_ATTACHMENT_CONTENT", attachment_text) if attachment_text: content_parts.append(attachment_text) else: + self.set_output("_ARTIFACTS", None) self.set_output("_ATTACHMENT_CONTENT", "") self.set_output("content", "\n\n".join([part for part in content_parts if part]).strip()) return self.output() + def _build_http_execution_metadata(self, body: Mapping | None) -> dict: + if not isinstance(body, Mapping): + return {} + structured_result = body.get("result") + if not isinstance(structured_result, Mapping): + return {} + return { + "result_present": structured_result.get("present", False), + "result_value": structured_result.get("value"), + "result_type": structured_result.get("type"), + } + + def _resolve_execution_result_value(self, stdout: str, execution_metadata: Mapping | None = None): + metadata = execution_metadata or {} + if metadata.get("result_present") is True: + return metadata.get("result_value"), False + return self._deserialize_stdout(stdout), True + @classmethod def _ensure_bucket_lifecycle(cls): if cls._lifecycle_configured: @@ -306,10 +520,10 @@ class CodeExec(ToolBase, ABC): uploaded = [] for art in artifacts: try: - name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "") - content_b64 = art.get("content_b64", "") if isinstance(art, dict) else getattr(art, "content_b64", "") - mime_type = art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", "") - size = art.get("size", 0) if isinstance(art, dict) else getattr(art, "size", 0) + name = _art_field(art, "name") + content_b64 = _art_field(art, "content_b64") + mime_type = _art_field(art, "mime_type") + size = _art_field(art, "size", 0) if not content_b64 or not name: continue @@ -350,119 +564,24 @@ class CodeExec(ToolBase, ABC): continue return text - def _coerce_output_value(self, value, expected_type: Optional[str]): - if expected_type is None: - return value - - etype = expected_type.strip().lower() - inner_type = None - if etype.startswith("array<") and etype.endswith(">"): - inner_type = etype[6:-1].strip() - etype = "array" + def _apply_business_output(self, parsed_stdout) -> str: + normalized_result = normalize_output_value(parsed_stdout) + self.set_output("raw_result", normalized_result) + business_output_names = [name for name in self._param.outputs if name not in SYSTEM_OUTPUT_KEYS] try: - if etype == "string": - return "" if value is None else str(value) + contract = build_code_exec_contract(self._param.outputs, normalized_result) + except ContractError as e: + for output_name in business_output_names: + self.set_output(output_name, None) + self.set_output("actual_type", infer_actual_type(normalized_result)) + self.set_output("_ERROR", str(e)) + logging.warning(f"[CodeExec]: contract validation failed: {e}") + return render_canonical_content(normalized_result) - if etype == "number": - if value is None or value == "": - return None - if isinstance(value, (int, float)): - return value - if isinstance(value, str): - try: - return float(value) - except Exception: - return value - return float(value) - - if etype == "boolean": - if isinstance(value, bool): - return value - if isinstance(value, str): - lv = value.lower() - if lv in ("true", "1", "yes", "y", "on"): - return True - if lv in ("false", "0", "no", "n", "off"): - return False - return bool(value) - - if etype == "array": - candidate = value - if isinstance(candidate, str): - parsed = self._deserialize_stdout(candidate) - candidate = parsed - if isinstance(candidate, tuple): - candidate = list(candidate) - if not isinstance(candidate, list): - candidate = [] if candidate is None else [candidate] - - if inner_type == "string": - return ["" if v is None else str(v) for v in candidate] - if inner_type == "number": - coerced = [] - for v in candidate: - try: - if v is None or v == "": - coerced.append(None) - elif isinstance(v, (int, float)): - coerced.append(v) - else: - coerced.append(float(v)) - except Exception: - coerced.append(v) - return coerced - return candidate - - if etype == "object": - if isinstance(value, dict): - return value - if isinstance(value, str): - parsed = self._deserialize_stdout(value) - if isinstance(parsed, dict): - return parsed - return value - except Exception: - return value - - return value - - def _populate_outputs(self, parsed_stdout, raw_stdout: str): - outputs_items = list(self._param.outputs.items()) - logging.info(f"[CodeExec]: outputs schema keys: {[k for k, _ in outputs_items]}") - if not outputs_items: - return - - if isinstance(parsed_stdout, dict): - for key, meta in outputs_items: - if key.startswith("_"): - continue - val = self._get_by_path(parsed_stdout, key) - if val is None and len(outputs_items) == 1: - val = parsed_stdout - coerced = self._coerce_output_value(val, meta.get("type")) - logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'") - self.set_output(key, coerced) - return - - if isinstance(parsed_stdout, (list, tuple)): - for idx, (key, meta) in enumerate(outputs_items): - if key.startswith("_"): - continue - val = parsed_stdout[idx] if idx < len(parsed_stdout) else None - coerced = self._coerce_output_value(val, meta.get("type")) - logging.info(f"[CodeExec]: populate list key='{key}' raw='{val}' coerced='{coerced}'") - self.set_output(key, coerced) - return - - default_val = parsed_stdout if parsed_stdout is not None else raw_stdout - for idx, (key, meta) in enumerate(outputs_items): - if key.startswith("_"): - continue - val = default_val if idx == 0 else None - coerced = self._coerce_output_value(val, meta.get("type")) - logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'") - self.set_output(key, coerced) + self.set_output("actual_type", contract["actual_type"]) + self.set_output(contract["business_output"], contract["value"]) + return contract["content"] def _build_attachment_content(self, artifacts: list, artifact_urls: list[dict] | None = None) -> str: sections = [] @@ -471,9 +590,9 @@ class CodeExec(ToolBase, ABC): for idx, art in enumerate(artifacts, start=1): key = f"attachment{idx}" try: - name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "") - content_b64 = art.get("content_b64", "") if isinstance(art, dict) else getattr(art, "content_b64", "") - mime_type = art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", "") + name = _art_field(art, "name") + content_b64 = _art_field(art, "content_b64") + mime_type = _art_field(art, "mime_type") if not name or not content_b64: continue @@ -490,11 +609,8 @@ class CodeExec(ToolBase, ABC): logging.info(f"[CodeExec]: parse attachment section key='{key}' from artifact='{name}'") except Exception as e: logging.warning(f"[CodeExec]: Failed to parse artifact for content section '{key}': {e}") - fallback_type = self._normalize_attachment_type( - art.get("name", "") if isinstance(art, dict) else getattr(art, "name", ""), - art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", ""), - ) - fallback_name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "") + fallback_type = self._normalize_attachment_type(name, mime_type) + fallback_name = name fallback_url = "" if idx - 1 < len(artifact_urls): fallback_url = artifact_urls[idx - 1].get("url", "") @@ -529,38 +645,3 @@ class CodeExec(ToolBase, ABC): title += f": {name}" body = parsed if isinstance(parsed, str) else json.dumps(parsed, ensure_ascii=False) return f"{title}\n{body}".strip() - - def _build_content_text(self, parsed_stdout, raw_stdout: str = "") -> str: - if isinstance(parsed_stdout, str): - return parsed_stdout.strip() - if isinstance(parsed_stdout, (dict, list, tuple)): - try: - return json.dumps(parsed_stdout, ensure_ascii=False, indent=2).strip() - except Exception: - return str(parsed_stdout).strip() - if parsed_stdout is None: - return str(raw_stdout or "").strip() - return str(parsed_stdout).strip() - - def _get_by_path(self, data, path: str): - if not path: - return None - cur = data - for part in path.split("."): - part = part.strip() - if not part: - return None - if isinstance(cur, dict): - cur = cur.get(part) - elif isinstance(cur, list): - try: - idx = int(part) - cur = cur[idx] - except Exception: - return None - else: - return None - if cur is None: - return None - logging.info(f"[CodeExec]: resolve path '{path}' -> {cur}") - return cur diff --git a/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py b/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py new file mode 100644 index 0000000000..ff171c3b00 --- /dev/null +++ b/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py @@ -0,0 +1,456 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import patch + +import pytest + + +CODE_EXEC_MODULE_PATH = next( + parent / "agent" / "tools" / "code_exec.py" + for parent in Path(__file__).resolve().parents + if (parent / "agent" / "tools" / "code_exec.py").exists() +) + + +def _load_module(): + return _load_code_exec_runtime_module() + + +def _build_code_exec(output_type: str): + return _build_code_exec_with_outputs({"result": {"value": None, "type": output_type}}) + + +def _build_code_exec_with_outputs(outputs: dict[str, dict]): + module = _load_module() + tool = module.CodeExec.__new__(module.CodeExec) + tool._param = types.SimpleNamespace(outputs=outputs) + tool._canvas = types.SimpleNamespace(get_tenant_id=lambda: "tenant-1") + return tool + + +def _load_code_exec_runtime_module(): + agent_module = types.ModuleType("agent") + tools_module = types.ModuleType("agent.tools") + base_module = types.ModuleType("agent.tools.base") + + class _FakeToolParamBase: + def __init__(self): + self.outputs = {} + + class _FakeToolBase: + def output(self, var_nm=None): + if var_nm: + return self._param.outputs.get(var_nm, {}).get("value", "") + return {k: v.get("value") for k, v in self._param.outputs.items()} + + def set_output(self, key, value): + if key not in self._param.outputs: + self._param.outputs[key] = {"value": None, "type": str(type(value))} + self._param.outputs[key]["value"] = value + + def check_if_canceled(self, *_args, **_kwargs): + return False + + base_module.ToolBase = _FakeToolBase + base_module.ToolMeta = dict + base_module.ToolParamBase = _FakeToolParamBase + + api_module = types.ModuleType("api") + api_db_module = types.ModuleType("api.db") + api_db_services_module = types.ModuleType("api.db.services") + file_service_module = types.ModuleType("api.db.services.file_service") + + class _FakeFileService: + @staticmethod + def parse(*_args, **_kwargs): + return "" + + file_service_module.FileService = _FakeFileService + + common_module = types.ModuleType("common") + common_settings_module = types.ModuleType("common.settings") + common_settings_module.SANDBOX_HOST = "sandbox" + common_settings_module.STORAGE_IMPL = types.SimpleNamespace(put=lambda *_args, **_kwargs: None) + + connection_utils_module = types.ModuleType("common.connection_utils") + + def _timeout(_seconds): + def _decorator(func): + return func + + return _decorator + + connection_utils_module.timeout = _timeout + + constants_module = types.ModuleType("common.constants") + constants_module.SANDBOX_ARTIFACT_BUCKET = "bucket" + constants_module.SANDBOX_ARTIFACT_EXPIRE_DAYS = 7 + + agent_module.tools = tools_module + tools_module.base = base_module + api_module.db = api_db_module + api_db_module.services = api_db_services_module + api_db_services_module.file_service = file_service_module + common_module.settings = common_settings_module + + stub_modules = { + "agent": agent_module, + "agent.tools": tools_module, + "agent.tools.base": base_module, + "api": api_module, + "api.db": api_db_module, + "api.db.services": api_db_services_module, + "api.db.services.file_service": file_service_module, + "common": common_module, + "common.settings": common_settings_module, + "common.connection_utils": connection_utils_module, + "common.constants": constants_module, + } + + spec = importlib.util.spec_from_file_location("code_exec_runtime", CODE_EXEC_MODULE_PATH) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + with patch.dict(sys.modules, stub_modules): + spec.loader.exec_module(module) + return module + + +def test_select_business_output_ignores_system_outputs(): + module = _load_module() + outputs = { + "content": {"value": "", "type": "string"}, + "actual_type": {"value": "", "type": "string"}, + "_ERROR": {"value": "", "type": "string"}, + "_ARTIFACTS": {"value": [], "type": "Array"}, + "_ATTACHMENT_CONTENT": {"value": "", "type": "string"}, + "raw_result": {"value": None, "type": "Any"}, + "_created_time": {"value": 1.0, "type": "Number"}, + "_elapsed_time": {"value": 2.0, "type": "Number"}, + "result": {"value": None, "type": "Array"}, + } + + name, meta = module.select_business_output(outputs) + + assert name == "result" + assert meta["type"] == "Array" + + +def test_array_result_is_preserved_as_single_business_value(): + module = _load_module() + contract = module.build_code_exec_contract( + {"result": {"value": None, "type": "Array"}}, + (1, 2, 3), + ) + + assert contract["business_output"] == "result" + assert contract["value"] == [1, 2, 3] + assert contract["actual_type"] == "Array" + assert contract["content"] == "[\n 1,\n 2,\n 3\n]" + + +def test_object_result_is_not_wrapped_by_business_name(): + module = _load_module() + contract = module.build_code_exec_contract( + {"result": {"value": None, "type": "Object"}}, + {"foo": "bar", "n": 1}, + ) + + assert contract["business_output"] == "result" + assert contract["value"] == {"foo": "bar", "n": 1} + assert contract["content"] == '{\n "foo": "bar",\n "n": 1\n}' + + +def test_canonical_object_rendering_is_key_order_stable(): + module = _load_module() + assert module.render_canonical_content({"b": 1, "a": 2}) == '{\n "a": 2,\n "b": 1\n}' + + +def test_lowercase_object_expected_type_validates(): + module = _load_module() + contract = module.build_code_exec_contract( + {"result": {"value": None, "type": "object"}}, + {"foo": "bar"}, + ) + + assert contract["actual_type"] == "Object" + assert contract["value"] == {"foo": "bar"} + + +def test_tuple_is_normalized_to_array_semantics(): + module = _load_module() + assert module.normalize_output_value((1, 2, 3)) == [1, 2, 3] + assert module.infer_actual_type((1, 2, 3)) == "Array" + + +def test_list_is_preserved_as_list_without_normalization_changes(): + module = _load_module() + values = [1, 2, 3] + normalized = module.normalize_output_value(values) + assert normalized == [1, 2, 3] + assert isinstance(normalized, list) + + +def test_canonical_content_rendering_handles_common_shapes(): + module = _load_module() + assert module.render_canonical_content("hello") == "hello" + assert module.render_canonical_content(None) == "" + assert module.render_canonical_content(1.5) == "1.5" + assert module.render_canonical_content({"x": [1, 2]}) == '{\n "x": [\n 1,\n 2\n ]\n}' + + +def test_any_does_not_allow_unsupported_top_level_python_types(): + module = _load_module() + with pytest.raises(module.ContractError, match="unsupported top-level result type"): + module.build_code_exec_contract( + {"result": {"value": None, "type": "Any"}}, + {1, 2}, + ) + + +def test_mismatch_raises_contract_error(): + module = _load_module() + with pytest.raises(module.ContractError, match="expected type Number"): + module.build_code_exec_contract({"result": {"value": None, "type": "Number"}}, "not-a-number") + + +def test_array_number_rejects_string_elements_without_coercion(): + module = _load_module() + with pytest.raises(module.ContractError, match=r"expected type Number, got String"): + module.build_code_exec_contract({"result": {"value": None, "type": "Array"}}, ["1", 2]) + + +def test_boolean_rejects_string_form_without_coercion(): + module = _load_module() + with pytest.raises(module.ContractError, match=r"expected type Boolean, got String"): + module.build_code_exec_contract({"result": {"value": None, "type": "Boolean"}}, "true") + + +def test_lowercase_array_number_expected_type_validates(): + module = _load_module() + contract = module.build_code_exec_contract( + {"result": {"value": None, "type": "array"}}, + (1, 2, 3), + ) + + assert contract["actual_type"] == "Array" + assert contract["value"] == [1, 2, 3] + + +def test_lowercase_array_string_expected_type_validates(): + module = _load_module() + contract = module.build_code_exec_contract( + {"result": {"value": None, "type": "array"}}, + ("a", "b"), + ) + + assert contract["actual_type"] == "Array" + assert contract["value"] == ["a", "b"] + + +@pytest.mark.parametrize("schema", ["Array<>", "Array< >", "array<>", "array< >"]) +def test_malformed_array_schema_is_rejected(schema): + module = _load_module() + with pytest.raises(module.ContractError, match="Unsupported expected type"): + module.build_code_exec_contract({"result": {"value": None, "type": schema}}, [1, 2]) + + +def test_any_and_empty_expected_type_skip_validation(): + module = _load_module() + assert module.build_code_exec_contract({"result": {"value": None, "type": "Any"}}, {"foo": "bar"})["value"] == { + "foo": "bar" + } + assert module.build_code_exec_contract({"result": {"value": None, "type": ""}}, {"foo": "bar"})["value"] == { + "foo": "bar" + } + assert module.build_code_exec_contract({"result": {"value": None, "type": None}}, {"foo": "bar"})["value"] == { + "foo": "bar" + } + + +def test_legacy_multi_output_schema_is_rejected(): + module = _load_module() + with pytest.raises(module.ContractError, match="exactly one business output"): + module.select_business_output( + { + "result": {"value": None, "type": "Number"}, + "answer": {"value": None, "type": "String"}, + "_ERROR": {"value": "", "type": "string"}, + } + ) + + +@pytest.mark.parametrize("name", ["content", "actual_type", "_ERROR", "_ARTIFACTS", "_ATTACHMENT_CONTENT", "raw_result"]) +def test_reserved_business_output_names_are_rejected(name): + module = _load_module() + with pytest.raises(module.ContractError, match="reserved output name"): + module.build_code_exec_contract( + {name: {"value": None, "type": "String"}}, + "ok", + ) + + +def test_dotted_business_output_name_is_rejected(): + module = _load_module() + with pytest.raises(module.ContractError, match=r"must not contain '.'"): + module.build_code_exec_contract( + {"payload.items": {"value": None, "type": "Array"}}, + ["a"], + ) + + +def test_process_execution_result_preserves_whole_array_for_single_business_output(): + tool = _build_code_exec("Array") + + result = tool._process_execution_result('["a", "b"]', None, "unit-test") + + assert result["result"] == ["a", "b"] + assert result["content"] == '[\n "a",\n "b"\n]' + assert result["raw_result"] == ["a", "b"] + + +def test_process_execution_result_sets_actual_type_from_contract_value(): + tool = _build_code_exec("Object") + + result = tool._process_execution_result('{"foo": "bar"}', None, "unit-test") + + assert result["result"] == {"foo": "bar"} + assert result["actual_type"] == "Object" + + +def test_process_execution_result_contract_mismatch_sets_error_and_clears_business_output(): + tool = _build_code_exec("Number") + + result = tool._process_execution_result('["a", "b"]', None, "unit-test") + + assert "expected type Number" in result["_ERROR"] + assert result["result"] is None + assert result["actual_type"] == "Array" + assert result["raw_result"] == ["a", "b"] + + +def test_process_execution_result_invalid_schema_clears_stale_business_outputs(): + tool = _build_code_exec_with_outputs( + { + "result": {"value": "stale-result", "type": "String"}, + "answer": {"value": {"stale": True}, "type": "Object"}, + "_ERROR": {"value": "", "type": "string"}, + } + ) + + result = tool._process_execution_result('["a", "b"]', None, "unit-test") + + assert "exactly one business output" in result["_ERROR"] + assert result["result"] is None + assert result["answer"] is None + assert result["actual_type"] == "Array" + assert result["raw_result"] == ["a", "b"] + + +def test_process_execution_result_keeps_business_output_when_stderr_is_non_fatal(): + tool = _build_code_exec("Object") + + result = tool._process_execution_result('{"foo": "bar"}', "warning on stderr", "unit-test") + + assert result["_ERROR"] == "" + assert result["result"] == {"foo": "bar"} + assert result["content"] == '{\n "foo": "bar"\n}' + + +def test_process_execution_result_returns_early_for_stderr_only_without_artifacts(): + tool = _build_code_exec("String") + + result = tool._process_execution_result("", "hard failure", "unit-test") + + assert result["_ERROR"] == "hard failure" + assert result.get("result") is None + assert result.get("content") is None + + +def test_process_execution_result_appends_artifact_content_to_canonical_content(): + tool = _build_code_exec("Object") + tool._upload_artifacts = lambda _artifacts: [{"name": "chart.png", "url": "/artifact/chart.png", "mime_type": "image/png", "size": 12}] + tool._build_attachment_content = lambda _artifacts, _artifact_urls: "attachment_count: 1\n\nattachment1 (image): chart.png\nparsed artifact" + + result = tool._process_execution_result( + '{"foo": "bar"}', + None, + "unit-test", + artifacts=[{"name": "chart.png", "content_b64": "ZmFrZQ==", "mime_type": "image/png", "size": 12}], + ) + + assert result["result"] == {"foo": "bar"} + assert result["content"] == '{\n "foo": "bar"\n}\n\nattachment_count: 1\n\nattachment1 (image): chart.png\nparsed artifact' + assert result["_ARTIFACTS"] == [{"name": "chart.png", "url": "/artifact/chart.png", "mime_type": "image/png", "size": 12}] + assert result["_ARTIFACTS"][0]["mime_type"] == "image/png" + assert result["_ATTACHMENT_CONTENT"] == "attachment_count: 1\n\nattachment1 (image): chart.png\nparsed artifact" + assert "attachment1 (image): chart.png" in result["_ATTACHMENT_CONTENT"] + + +def test_process_execution_result_without_artifacts_clears_stale_artifacts_output(): + tool = _build_code_exec_with_outputs( + { + "result": {"value": None, "type": "String"}, + "_ARTIFACTS": {"value": [{"name": "stale"}], "type": "Array"}, + } + ) + + result = tool._process_execution_result('"ok"', None, "unit-test") + + assert result["result"] == "ok" + assert result["_ARTIFACTS"] is None + + +def test_process_execution_result_prefers_structured_result_metadata_over_stdout_guessing(): + tool = _build_code_exec("Object") + + result = tool._process_execution_result( + '{"fake":"stdout-log"}', + None, + "unit-test", + execution_metadata={ + "result_present": True, + "result_value": {"real": "value"}, + "result_type": "json", + }, + ) + + assert result["result"] == {"real": "value"} + assert result["actual_type"] == "Object" + assert result["content"] == '{\n "real": "value"\n}' + + +def test_process_execution_result_preserves_json_looking_string_when_metadata_marks_string(): + tool = _build_code_exec("String") + + result = tool._process_execution_result( + '{"a":1}', + None, + "unit-test", + execution_metadata={ + "result_present": True, + "result_value": '{"a":1}', + "result_type": "json", + }, + ) + + assert result["result"] == '{"a":1}' + assert result["actual_type"] == "String" + assert result["content"] == '{"a":1}' diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts index 6620c7fd32..4ddabe67ad 100644 --- a/web/src/interfaces/database/agent.ts +++ b/web/src/interfaces/database/agent.ts @@ -160,7 +160,7 @@ export interface ICodeForm { arguments: Record; lang: string; script?: string; - outputs: Record; + outputs: Record; } export interface IAgentForm { @@ -192,7 +192,7 @@ export interface IAgentForm { }; } -export type BaseNodeData = { +export type BaseNodeData = { label: string; // operator type name: string; // operator name color?: string; diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 3a1f19f37d..02e31ec496 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -2078,6 +2078,9 @@ This delimiter is used to split the input text into several text pieces echo of }`, datatype: 'MINE type of the HTTP request', insertVariableTip: `Enter / Insert variables`, + mergePath: 'Merge path', + mergePathTip: + 'When enabled, a dot suffix immediately after a variable is merged into a path query, such as {node@result.name}.', historyVersion: 'Version history', version: { created: 'Created', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index e019e608d8..29855e2b43 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -1262,6 +1262,9 @@ export default { categoryName: '分類名稱', nextStep: '下一步', insertVariableTip: `輸入 / 插入變數`, + mergePath: '合併路徑', + mergePathTip: + '開啟後,緊跟在變數後面的點號後綴會合併為路徑查詢,例如 {node@result.name}。', promptMessage: '提示詞是必填項', promptTip: '系統提示為大型模型提供任務描述、規定回覆方式,以及設定其他各種要求。系統提示通常與 key(變數)合用,透過變數設定大型模型的輸入資料。你可以透過斜線或 (x) 按鈕顯示可用的 key。', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 17878fed74..4ec14ab451 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -1818,6 +1818,9 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 categoryName: '分类名称', nextStep: '下一步', insertVariableTip: `输入 / 插入变量`, + mergePath: '合并路径', + mergePathTip: + '开启后,紧跟在变量后面的点号后缀会合并为路径查询,例如 {node@result.name}。', setting: '设置', settings: { agentSetting: 'Agent设置', diff --git a/web/src/pages/agent/constant/index.tsx b/web/src/pages/agent/constant/index.tsx index d5cbd1980e..628a5870e9 100644 --- a/web/src/pages/agent/constant/index.tsx +++ b/web/src/pages/agent/constant/index.tsx @@ -27,6 +27,10 @@ export * from './pipeline'; import { ModelVariableType } from '@/constants/knowledge'; import { t } from 'i18next'; +import { + buildDefaultCodeOutput, + serializeCodeOutputContract, +} from '../form/code-form/utils'; // DuckDuckGo's channel options export enum Channel { @@ -427,7 +431,7 @@ export const initialCodeValues = { arg1: '', arg2: '', }, - outputs: {}, + outputs: serializeCodeOutputContract(buildDefaultCodeOutput()), }; export const initialWaitingDialogueValues = {}; diff --git a/web/src/pages/agent/form-sheet/single-debug-sheet/index.tsx b/web/src/pages/agent/form-sheet/single-debug-sheet/index.tsx index 51ff8eebcb..a105d6137d 100644 --- a/web/src/pages/agent/form-sheet/single-debug-sheet/index.tsx +++ b/web/src/pages/agent/form-sheet/single-debug-sheet/index.tsx @@ -2,6 +2,7 @@ import CopyToClipboard from '@/components/copy-to-clipboard'; import { Sheet, SheetContent, SheetHeader } from '@/components/ui/sheet'; import { useDebugSingle, useFetchInputForm } from '@/hooks/use-agent-request'; import { IModalProps } from '@/interfaces/common'; +import { ICodeForm } from '@/interfaces/database/agent'; import { cn } from '@/lib/utils'; import { isEmpty } from 'lodash'; import { X } from 'lucide-react'; @@ -12,6 +13,70 @@ import 'react18-json-view/src/style.css'; import DebugContent from '../../debug-content'; import { transferInputsArrayToObject } from '../../form/begin-form/use-watch-change'; import { buildBeginInputListFromObject } from '../../form/begin-form/utils'; +import { + deserializeCodeOutputContract, + getBusinessOutputs, +} from '../../form/code-form/utils'; +import useGraphStore from '../../store'; +import { + groupCodeExecDebugOutput, + shouldUseCodeExecDebugLayout, +} from './utils'; + +function DebugRow({ label, value }: { label: string; value: string }) { + return ( +
+ {label} + + {value || '-'} + +
+ ); +} + +function DebugJsonCard({ + title, + value, + error, +}: { + title: string; + value: unknown; + error?: boolean; +}) { + return ( +
+
+ {title} + +
+ +
+ ); +} + +function DebugTextCard({ title, value }: { title: string; value: string }) { + return ( +
+
+ {title} + +
+
+        {value || '-'}
+      
+
+ ); +} interface IProps { componentId?: string; @@ -23,8 +88,13 @@ const SingleDebugSheet = ({ hideModal, }: IModalProps & IProps) => { const { t } = useTranslation(); + const getNode = useGraphStore((state) => state.getNode); const inputForm = useFetchInputForm(componentId); const { debugSingle, data, loading } = useDebugSingle(); + const node = getNode(componentId); + const shouldUseCodeExecLayout = shouldUseCodeExecDebugLayout( + node?.data.label, + ); const list = useMemo(() => { return buildBeginInputListFromObject(inputForm); @@ -42,43 +112,124 @@ const SingleDebugSheet = ({ [componentId, debugSingle], ); + const formData = shouldUseCodeExecLayout + ? (node?.data.form as ICodeForm | undefined) + : undefined; + const debugData = useMemo( + () => + data && typeof data === 'object' + ? (data as Record) + : undefined, + [data], + ); + const { contract, legacyOutputs } = useMemo( + () => deserializeCodeOutputContract(formData), + [formData], + ); + const grouped = useMemo( + () => groupCodeExecDebugOutput(debugData, contract), + [contract, debugData], + ); + const businessOutputPreview = useMemo(() => { + if (contract?.name && debugData && contract.name in debugData) { + return { [contract.name]: debugData[contract.name] }; + } + + if (legacyOutputs.length > 0 && debugData) { + return Object.fromEntries( + Object.keys(getBusinessOutputs(formData?.outputs)).map((key) => [ + key, + debugData[key], + ]), + ); + } + + return {}; + }, [contract, debugData, formData?.outputs, legacyOutputs.length]); const content = JSON.stringify(data, null, 2); + const hasError = shouldUseCodeExecLayout + ? !isEmpty(grouped.systemOutputs._ERROR) + : !isEmpty((data as Record | undefined)?._ERROR); return ( - +
{t('flow.testRun')}
-
+
{!isEmpty(data) ? ( -
-
- JSON - +
+ {shouldUseCodeExecLayout ? ( + <> +
+ + + +
+ {!isEmpty(businessOutputPreview) && ( + + )} + + + + + ) : null} +
+
+ + {shouldUseCodeExecLayout ? 'Raw Component Output' : 'JSON'} + + +
+
-
) : null}
diff --git a/web/src/pages/agent/form-sheet/single-debug-sheet/utils.test.ts b/web/src/pages/agent/form-sheet/single-debug-sheet/utils.test.ts new file mode 100644 index 0000000000..8bbb567df7 --- /dev/null +++ b/web/src/pages/agent/form-sheet/single-debug-sheet/utils.test.ts @@ -0,0 +1,10 @@ +import { Operator } from '../../constant'; +import { shouldUseCodeExecDebugLayout } from './utils'; + +describe('shouldUseCodeExecDebugLayout', () => { + it('returns true only for CodeExec nodes', () => { + expect(shouldUseCodeExecDebugLayout(Operator.Code)).toBe(true); + expect(shouldUseCodeExecDebugLayout(Operator.Http)).toBe(false); + expect(shouldUseCodeExecDebugLayout(undefined)).toBe(false); + }); +}); diff --git a/web/src/pages/agent/form-sheet/single-debug-sheet/utils.ts b/web/src/pages/agent/form-sheet/single-debug-sheet/utils.ts new file mode 100644 index 0000000000..a17a8c64ae --- /dev/null +++ b/web/src/pages/agent/form-sheet/single-debug-sheet/utils.ts @@ -0,0 +1,40 @@ +import { Operator } from '../../constant'; +import { CodeOutputContract } from '../../form/code-form/utils'; + +const SYSTEM_OUTPUT_NAMES = new Set([ + '_ERROR', + '_ARTIFACTS', + '_ATTACHMENT_CONTENT', +]); + +export type GroupedCodeExecDebugOutput = { + expectedType: string; + actualType: string; + rawResult: unknown; + content: string; + systemOutputs: Record; +}; + +export function groupCodeExecDebugOutput( + data: Record | undefined, + contract: CodeOutputContract | null, +): GroupedCodeExecDebugOutput { + const businessName = contract?.name ?? ''; + const source = data ?? {}; + const systemOutputs = Object.fromEntries( + Object.entries(source).filter(([key]) => SYSTEM_OUTPUT_NAMES.has(key)), + ); + + return { + expectedType: contract?.type ?? '', + actualType: String(source.actual_type ?? ''), + rawResult: + source.raw_result ?? (businessName ? source[businessName] : undefined), + content: String(source.content ?? ''), + systemOutputs, + }; +} + +export function shouldUseCodeExecDebugLayout(label?: string): boolean { + return label === Operator.Code; +} diff --git a/web/src/pages/agent/form/code-form/index.tsx b/web/src/pages/agent/form/code-form/index.tsx index f9797ad24e..131348a2f1 100644 --- a/web/src/pages/agent/form/code-form/index.tsx +++ b/web/src/pages/agent/form/code-form/index.tsx @@ -16,6 +16,7 @@ import { RAGFlowSelect } from '@/components/ui/select'; import { ProgrammingLanguage } from '@/constants/agent'; import { ICodeForm } from '@/interfaces/database/agent'; import { zodResolver } from '@hookform/resolvers/zod'; +import { AlertTriangle } from 'lucide-react'; import { memo } from 'react'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; @@ -33,6 +34,11 @@ import { useHandleLanguageChange, useWatchFormChange, } from './use-watch-change'; +import { + CodeExecPanelSystemOutputs, + getBusinessOutputs, + serializeCodeOutputContract, +} from './utils'; loader.config({ paths: { vs: '/vs' } }); @@ -41,18 +47,10 @@ const options = [ ProgrammingLanguage.Javascript, ].map((x) => ({ value: x, label: x })); -const DynamicFieldName = 'outputs'; -const CodeSystemOutputs = { - content: { - type: 'string', - value: '', - }, -}; - function CodeForm({ node }: INextOperatorForm) { const formData = node?.data.form as ICodeForm; const { t } = useTranslation(); - const values = useValues(node); + const { values, legacyOutputs } = useValues(node); const isDarkTheme = useIsDarkTheme(); const form = useForm({ @@ -63,6 +61,13 @@ function CodeForm({ node }: INextOperatorForm) { useWatchFormChange(node?.id, form); const handleLanguageChange = useHandleLanguageChange(node?.id, form); + const lang = form.watch('lang'); + const currentOutput = form.watch('output'); + const outputFieldDirty = !!form.formState.dirtyFields?.output; + const displayedBusinessOutputs = + legacyOutputs.length > 0 && !outputFieldDirty + ? getBusinessOutputs(formData?.outputs) + : serializeCodeOutputContract(currentOutput); return (
@@ -103,7 +108,7 @@ function CodeForm({ node }: INextOperatorForm) { - {formData.lang === ProgrammingLanguage.Python ? ( - - ) : ( -
- - - ( - - Name - - - - - - )} - /> - ( - - Type - - - - - - )} - /> - -
- )} +
+ + {legacyOutputs.length > 0 && ( +
+ +

+ This CodeExec node uses the deprecated multi-output schema:{' '} + {legacyOutputs.join(', ')}. Keep one business output here and + move field extraction to downstream nodes. +

+
+ )} + + ( + + Name + + + + + + )} + /> + ( + + Type + + + + + + )} + /> + +
-
- +
+ + Business + + + System +
); diff --git a/web/src/pages/agent/form/code-form/next-variable.tsx b/web/src/pages/agent/form/code-form/next-variable.tsx index d46c7305c7..ad3b3fea72 100644 --- a/web/src/pages/agent/form/code-form/next-variable.tsx +++ b/web/src/pages/agent/form/code-form/next-variable.tsx @@ -29,9 +29,12 @@ export const TypeOptions = [ 'String', 'Number', 'Boolean', + 'Object', 'Array', 'Array', - 'Object', + 'Array', + 'Array', + 'Any', ].map((x) => ({ label: x, value: x })); export function DynamicVariableForm({ name = 'arguments', isOutputs }: IProps) { diff --git a/web/src/pages/agent/form/code-form/schema.ts b/web/src/pages/agent/form/code-form/schema.ts index fe694444e2..4d22bea1a7 100644 --- a/web/src/pages/agent/form/code-form/schema.ts +++ b/web/src/pages/agent/form/code-form/schema.ts @@ -1,14 +1,22 @@ import { ProgrammingLanguage } from '@/constants/agent'; import { z } from 'zod'; +import { isValidCodeOutputName } from './utils'; export const FormSchema = z.object({ lang: z.enum([ProgrammingLanguage.Python, ProgrammingLanguage.Javascript]), script: z.string(), arguments: z.array(z.object({ name: z.string(), type: z.string() })), - outputs: z.union([ - z.array(z.object({ name: z.string(), type: z.string() })).optional(), - z.object({ name: z.string(), type: z.string() }), - ]), + output: z.object({ + name: z + .string() + .trim() + .min(1, 'Name is required') + .refine( + isValidCodeOutputName, + 'Name cannot use reserved outputs or path syntax', + ), + type: z.string().trim().min(1, 'Type is required'), + }), }); export type FormSchemaType = z.infer; diff --git a/web/src/pages/agent/form/code-form/use-values.ts b/web/src/pages/agent/form/code-form/use-values.ts index f920a9744e..e3d55bc935 100644 --- a/web/src/pages/agent/form/code-form/use-values.ts +++ b/web/src/pages/agent/form/code-form/use-values.ts @@ -1,8 +1,8 @@ -import { ProgrammingLanguage } from '@/constants/agent'; -import { ICodeForm, RAGFlowNodeType } from '@/interfaces/database/agent'; +import { RAGFlowNodeType } from '@/interfaces/database/agent'; import { isEmpty } from 'lodash'; import { useMemo } from 'react'; import { initialCodeValues } from '../../constant'; +import { buildDefaultCodeOutput, deserializeCodeOutputContract } from './utils'; function convertToArray(args: Record) { return Object.entries(args).map(([key, value]) => ({ @@ -11,36 +11,32 @@ function convertToArray(args: Record) { })); } -type OutputsFormType = { name: string; type: string }; - -function convertOutputsToArray({ lang, outputs = {} }: ICodeForm) { - if (lang === ProgrammingLanguage.Python) { - return Object.entries(outputs).map(([key, val]) => ({ - name: key, - type: val.type, - })); - } - return Object.entries(outputs).reduce((pre, [key, val]) => { - pre.name = key; - pre.type = val.type; - return pre; - }, {} as OutputsFormType); -} - export function useValues(node?: RAGFlowNodeType) { - const values = useMemo(() => { + const valueState = useMemo(() => { const formData = node?.data?.form; if (isEmpty(formData)) { - return initialCodeValues; + return { + values: { + ...initialCodeValues, + arguments: convertToArray(initialCodeValues.arguments), + output: buildDefaultCodeOutput(), + }, + legacyOutputs: [], + }; } + const { contract, legacyOutputs } = deserializeCodeOutputContract(formData); + return { - ...formData, - arguments: convertToArray(formData.arguments), - outputs: convertOutputsToArray(formData), + values: { + ...formData, + arguments: convertToArray(formData.arguments), + output: contract ?? buildDefaultCodeOutput(), + }, + legacyOutputs, }; }, [node?.data?.form]); - return values; + return valueState; } diff --git a/web/src/pages/agent/form/code-form/use-watch-change.ts b/web/src/pages/agent/form/code-form/use-watch-change.ts index 80e0c8b15d..c0f313e5ed 100644 --- a/web/src/pages/agent/form/code-form/use-watch-change.ts +++ b/web/src/pages/agent/form/code-form/use-watch-change.ts @@ -1,10 +1,14 @@ import { CodeTemplateStrMap, ProgrammingLanguage } from '@/constants/agent'; -import { ICodeForm } from '@/interfaces/database/agent'; import { isEmpty } from 'lodash'; import { useCallback, useEffect } from 'react'; import { UseFormReturn, useWatch } from 'react-hook-form'; import useGraphStore from '../../store'; import { FormSchemaType } from './schema'; +import { + buildDefaultCodeOutput, + hasLegacyMultiOutputs, + serializeCodeOutputContract, +} from './utils'; function convertToObject(list: FormSchemaType['arguments'] = []) { return list.reduce>((pre, cur) => { @@ -14,58 +18,52 @@ function convertToObject(list: FormSchemaType['arguments'] = []) { }, {}); } -type ArrayOutputs = Extract>; - -type ObjectOutputs = Exclude>; - -function convertOutputsToObject({ lang, outputs }: FormSchemaType) { - if (lang === ProgrammingLanguage.Python) { - return (outputs as ArrayOutputs).reduce( - (pre, cur) => { - pre[cur.name] = { - value: '', - type: cur.type, - }; - - return pre; - }, - {}, - ); - } - const outputsObject = outputs as ObjectOutputs; - if (isEmpty(outputsObject)) { - return {}; - } - return { - [outputsObject.name]: { - value: '', - type: outputsObject.type, - }, - }; -} - export function useWatchFormChange( id?: string, form?: UseFormReturn, ) { - let values = useWatch({ control: form?.control }); + const watchedValues = useWatch({ control: form?.control }); const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + const getNode = useGraphStore((state) => state.getNode); useEffect(() => { // Manually triggered form updates are synchronized to the canvas if (id) { - values = form?.getValues() || {}; - let nextValues: any = { + const values = form?.getValues() || watchedValues || {}; + const currentOutputs = getNode(id)?.data?.form?.outputs; + const shouldPreserveLegacyOutputs = + hasLegacyMultiOutputs(currentOutputs) && + isEmpty(form?.formState.dirtyFields?.output); + const hasCompleteOutputContract = + !!values?.output?.name?.trim() && !!values?.output?.type?.trim(); + const nextValues: any = { ...values, arguments: convertToObject( values?.arguments as FormSchemaType['arguments'], ), - outputs: convertOutputsToObject(values as FormSchemaType), + outputs: shouldPreserveLegacyOutputs + ? currentOutputs + : hasCompleteOutputContract + ? serializeCodeOutputContract({ + name: values.output?.name?.trim() ?? '', + type: values.output?.type?.trim() ?? '', + }) + : (currentOutputs ?? + serializeCodeOutputContract(buildDefaultCodeOutput())), }; + delete nextValues.output; updateNodeForm(id, nextValues); } - }, [form?.formState.isDirty, id, updateNodeForm, values]); + }, [ + form?.formState.dirtyFields?.output, + form?.formState.isDirty, + form, + getNode, + id, + updateNodeForm, + watchedValues, + ]); } export function useHandleLanguageChange( @@ -79,12 +77,14 @@ export function useHandleLanguageChange( if (id) { const script = CodeTemplateStrMap[lang as ProgrammingLanguage]; form?.setValue('script', script); - form?.setValue( - 'outputs', - (lang === ProgrammingLanguage.Python - ? [] - : {}) as FormSchemaType['outputs'], - ); + if ( + !form?.getValues('output')?.name || + !form?.getValues('output')?.type + ) { + form?.setValue('output', buildDefaultCodeOutput(), { + shouldDirty: true, + }); + } updateNodeForm(id, script, ['script']); } }, diff --git a/web/src/pages/agent/form/code-form/utils.ts b/web/src/pages/agent/form/code-form/utils.ts new file mode 100644 index 0000000000..204f1f729b --- /dev/null +++ b/web/src/pages/agent/form/code-form/utils.ts @@ -0,0 +1,117 @@ +import { ICodeForm } from '@/interfaces/database/agent'; + +export type CodeOutputContract = { + name: string; + type: string; +}; + +type DeserializeCodeOutputResult = { + contract: CodeOutputContract | null; + legacyOutputs: string[]; +}; + +const CodeExecReservedOutputKeys = [ + 'content', + 'actual_type', + 'raw_result', + '_ERROR', + '_ARTIFACTS', + '_ATTACHMENT_CONTENT', + '_created_time', + '_elapsed_time', +] as const; + +export const CodeExecPanelSystemOutputs: ICodeForm['outputs'] = { + content: { + type: 'String', + value: '', + }, + actual_type: { + type: 'String', + value: '', + }, +}; + +const CodeExecReservedOutputKeySet = new Set( + CodeExecReservedOutputKeys, +); + +export function buildDefaultCodeOutput(): CodeOutputContract { + return { + name: 'result', + type: 'String', + }; +} + +export function isValidCodeOutputName(name: string): boolean { + const value = name.trim(); + return ( + !!value && !CodeExecReservedOutputKeySet.has(value) && !value.includes('.') + ); +} + +export function getBusinessOutputs( + outputs: ICodeForm['outputs'] = {}, +): ICodeForm['outputs'] { + return Object.entries(outputs).reduce((next, entry) => { + const [name, value] = entry; + + if (!CodeExecReservedOutputKeySet.has(name)) { + next[name] = value; + } + + return next; + }, {}); +} + +export function deserializeCodeOutputContract( + form?: Pick | null, +): DeserializeCodeOutputResult { + const outputs = form?.outputs ?? {}; + const businessOutputs = Object.entries(getBusinessOutputs(outputs)); + + if (businessOutputs.length === 0) { + return { contract: buildDefaultCodeOutput(), legacyOutputs: [] }; + } + + if (businessOutputs.length > 1) { + return { + contract: null, + legacyOutputs: businessOutputs.map(([name]) => name), + }; + } + + const [name, output] = businessOutputs[0]; + + return { + contract: { + name, + type: output.type, + }, + legacyOutputs: [], + }; +} + +export function hasLegacyMultiOutputs( + outputs: ICodeForm['outputs'] = {}, +): boolean { + return Object.keys(getBusinessOutputs(outputs)).length > 1; +} + +export function serializeCodeOutputContract( + contract: CodeOutputContract | null, +): ICodeForm['outputs'] { + const name = contract?.name?.trim(); + const type = contract?.type?.trim(); + + if (!name || !type || !isValidCodeOutputName(name)) { + return {}; + } + + return { + [name]: { + type, + value: null, + }, + }; +} diff --git a/web/src/pages/agent/form/components/prompt-editor/index.tsx b/web/src/pages/agent/form/components/prompt-editor/index.tsx index 92edcf5674..13b8d00c1b 100644 --- a/web/src/pages/agent/form/components/prompt-editor/index.tsx +++ b/web/src/pages/agent/form/components/prompt-editor/index.tsx @@ -15,6 +15,7 @@ import { LexicalNode, } from 'lexical'; +import { Switch } from '@/components/ui/switch'; import { Tooltip, TooltipContent, @@ -24,7 +25,7 @@ import { cn } from '@/lib/utils'; import { JsonSchemaDataType } from '@/pages/agent/constant'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; import { Variable } from 'lucide-react'; -import { forwardRef, ReactNode, useCallback, useState } from 'react'; +import { forwardRef, ReactNode, useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { EnterKeyPlugin } from './enter-key-plugin'; import { PasteHandlerPlugin } from './paste-handler-plugin'; @@ -51,24 +52,30 @@ const Nodes: Array> = [ ]; type PromptContentProps = { + enablePathQueryAutoMerge: boolean; showToolbar?: boolean; multiLine?: boolean; onBlur?: () => void; + onEnablePathQueryAutoMergeChange: (checked: boolean) => void; }; type IProps = { + enablePathQueryAutoMerge?: boolean; + showToolbar?: boolean; + multiLine?: boolean; value?: string; onChange?: (value?: string) => void; onBlur?: () => void; placeholder?: ReactNode; types?: JsonSchemaDataType[]; -} & PromptContentProps & - Pick; +} & Pick; function PromptContent({ + enablePathQueryAutoMerge, showToolbar = true, multiLine = true, onBlur, + onEnablePathQueryAutoMergeChange, }: PromptContentProps) { const [editor] = useLexicalComposerContext(); const [isBlur, setIsBlur] = useState(false); @@ -102,7 +109,7 @@ function PromptContent({ className={cn('border rounded-sm ', { 'border-accent-primary': !isBlur })} > {showToolbar && ( -
+
@@ -113,18 +120,60 @@ function PromptContent({

{t('flow.insertVariableTip')}

+ + + + + +

{t('flow.mergePath')}

+

{t('flow.mergePathTip')}

+
+
)} - + {!showToolbar && ( +
+ + + + + +

{t('flow.mergePath')}

+

{t('flow.mergePathTip')}

+
+
+
)} - onBlur={handleBlur} - onFocus={handleFocus} - /> + +
); } @@ -137,6 +186,7 @@ export const PromptEditor = forwardRef(function PromptEditor( placeholder, showToolbar, multiLine = true, + enablePathQueryAutoMerge = true, extraOptions, baseOptions, types, @@ -144,6 +194,8 @@ export const PromptEditor = forwardRef(function PromptEditor( ref: React.Ref, ) { const { t } = useTranslation(); + const [isPathQueryAutoMergeEnabled, setIsPathQueryAutoMergeEnabled] = + useState(enablePathQueryAutoMerge); const initialConfig: InitialConfigType = { namespace: 'PromptEditor', theme, @@ -151,6 +203,10 @@ export const PromptEditor = forwardRef(function PromptEditor( nodes: Nodes, }; + useEffect(() => { + setIsPathQueryAutoMergeEnabled(enablePathQueryAutoMerge); + }, [enablePathQueryAutoMerge]); + const onValueChange = useCallback( (editorState: EditorState) => { editorState?.read(() => { @@ -171,9 +227,11 @@ export const PromptEditor = forwardRef(function PromptEditor( } placeholder={ @@ -181,7 +239,7 @@ export const PromptEditor = forwardRef(function PromptEditor( className={cn( '-z-10 absolute top-1 left-2 text-text-disabled pointer-events-none', { - 'truncate w-[90%]': !multiLine, + 'truncate max-w-[calc(100%-4rem)]': !multiLine, 'translate-y-9': multiLine, }, )} @@ -200,6 +258,7 @@ export const PromptEditor = forwardRef(function PromptEditor( diff --git a/web/src/pages/agent/form/components/prompt-editor/utils.ts b/web/src/pages/agent/form/components/prompt-editor/utils.ts new file mode 100644 index 0000000000..3da99a93ea --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/utils.ts @@ -0,0 +1,93 @@ +import type { ReactNode } from 'react'; + +type PromptVariableOptionLike = { + label: string; + value: string; + parentLabel?: string | ReactNode; + icon?: ReactNode; + type?: string; +}; + +type PromptVariablePathParts = { + rootValue: string; + pathSuffix: string; +}; + +type PromptVariableLeadingPathMatch = { + pathSuffix: string; + remainingText: string; +}; + +const PromptVariableLeadingPathRegex = + /^(?(?:\.(?:\d+|[A-Za-z_][A-Za-z0-9_]*))+)/; + +function splitPromptVariablePath(value: string): PromptVariablePathParts { + const [nodeId, variable = ''] = value.split('@'); + + if (!nodeId || !variable) { + return { rootValue: value, pathSuffix: '' }; + } + + const dotIndex = variable.indexOf('.'); + if (dotIndex < 0) { + return { rootValue: value, pathSuffix: '' }; + } + + return { + rootValue: `${nodeId}@${variable.slice(0, dotIndex)}`, + pathSuffix: variable.slice(dotIndex), + }; +} + +export function extractLeadingPromptVariablePath( + text: string, +): PromptVariableLeadingPathMatch | undefined { + const match = PromptVariableLeadingPathRegex.exec(text); + const pathSuffix = match?.groups?.pathSuffix; + + if (!pathSuffix) { + return undefined; + } + + return { + pathSuffix, + remainingText: text.slice(pathSuffix.length), + }; +} + +export function appendPromptVariablePath( + option: PromptVariableOptionLike, + pathSuffix: string, +): PromptVariableOptionLike { + if (!pathSuffix) { + return option; + } + + return { + ...option, + value: `${option.value}${pathSuffix}`, + label: `${option.label}${pathSuffix}`, + }; +} + +export function resolvePromptVariableOption( + value: string, + options: PromptVariableOptionLike[], +): PromptVariableOptionLike | undefined { + const exactMatch = options.find((option) => option.value === value); + if (exactMatch) { + return exactMatch; + } + + const { rootValue, pathSuffix } = splitPromptVariablePath(value); + if (!pathSuffix) { + return undefined; + } + + const rootOption = options.find((option) => option.value === rootValue); + if (!rootOption) { + return undefined; + } + + return appendPromptVariablePath(rootOption, pathSuffix); +} diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx index 5b5790ef36..deb29e2e64 100644 --- a/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx +++ b/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx @@ -78,7 +78,7 @@ export class VariableNode extends DecoratorNode { export function $createVariableNode( value: string, label: string, - parentLabel: string | ReactNode, + parentLabel?: string | ReactNode, icon?: ReactNode, ): VariableNode { return new VariableNode(value, label, undefined, parentLabel, icon); diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx index 002face8da..2cab4d9564 100644 --- a/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx +++ b/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx @@ -1,9 +1,11 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { EditorState, LexicalEditor } from 'lexical'; +import { EditorState, LexicalEditor, TextNode } from 'lexical'; import { useEffect } from 'react'; import { ProgrammaticTag } from './constant'; +import { mergeLeadingVariablePathTextNode } from './variable-path-transform'; interface VariableOnChangePluginProps { + enablePathQueryAutoMerge?: boolean; onChange: ( editorState: EditorState, editor?: LexicalEditor, @@ -12,14 +14,17 @@ interface VariableOnChangePluginProps { } export function VariableOnChangePlugin({ + enablePathQueryAutoMerge = true, onChange, }: VariableOnChangePluginProps) { // Access the editor through the LexicalComposerContext const [editor] = useLexicalComposerContext(); // Wrap our listener in useEffect to handle the teardown and avoid stale references. useEffect(() => { - // most listeners return a teardown function that can be called to clean them up. - return editor.registerUpdateListener( + const removeTransform = enablePathQueryAutoMerge + ? editor.registerNodeTransform(TextNode, mergeLeadingVariablePathTextNode) + : () => {}; + const removeUpdateListener = editor.registerUpdateListener( ({ editorState, tags, dirtyElements }) => { // Check if there is a "programmatic" tag const isProgrammaticUpdate = tags.has(ProgrammaticTag); @@ -31,7 +36,12 @@ export function VariableOnChangePlugin({ } }, ); - }, [editor, onChange]); + + return () => { + removeTransform(); + removeUpdateListener(); + }; + }, [editor, enablePathQueryAutoMerge, onChange]); return null; } diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-path-transform.ts b/web/src/pages/agent/form/components/prompt-editor/variable-path-transform.ts new file mode 100644 index 0000000000..d095afe780 --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/variable-path-transform.ts @@ -0,0 +1,43 @@ +import { TextNode } from 'lexical'; +import { + appendPromptVariablePath, + extractLeadingPromptVariablePath, +} from './utils'; +import { $createVariableNode, $isVariableNode } from './variable-node'; + +export function mergeLeadingVariablePathTextNode(textNode: TextNode): boolean { + const previousSibling = textNode.getPreviousSibling(); + + if (!$isVariableNode(previousSibling)) { + return false; + } + + const leadingPath = extractLeadingPromptVariablePath( + textNode.getTextContent(), + ); + if (!leadingPath) { + return false; + } + + const nextVariable = appendPromptVariablePath( + { + value: previousSibling.__value, + label: previousSibling.__label, + parentLabel: previousSibling.__parentLabel, + icon: previousSibling.__icon, + }, + leadingPath.pathSuffix, + ); + + previousSibling.replace( + $createVariableNode( + nextVariable.value, + nextVariable.label, + nextVariable.parentLabel, + nextVariable.icon, + ), + ); + textNode.setTextContent(leadingPath.remainingText); + + return true; +} diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx index 5ea6564b1b..a54980c8d1 100644 --- a/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx +++ b/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx @@ -36,6 +36,7 @@ import React, { } from 'react'; import * as ReactDOM from 'react-dom'; +import { resolvePromptVariableOption } from './utils'; import { $createVariableNode } from './variable-node'; import { ScrollArea } from '@/components/ui/scroll-area'; @@ -530,7 +531,7 @@ export default function VariablePickerMenuPlugin({ return agentStructuredOutput; } - return children.find((x) => x.value === value); + return resolvePromptVariableOption(value, children); }, [findAgentStructuredOutputLabel, options], );