mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Refa: refine code_exec component (#13925)
### What problem does this PR solve? Refine code_exec component. ### Type of change - [x] Refactoring
This commit is contained in:
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -32,6 +33,183 @@ from common.connection_utils import timeout
|
||||
from common.constants import SANDBOX_ARTIFACT_BUCKET, SANDBOX_ARTIFACT_EXPIRE_DAYS
|
||||
|
||||
|
||||
SYSTEM_OUTPUT_KEYS = frozenset(
|
||||
{
|
||||
"content",
|
||||
"actual_type",
|
||||
"_ERROR",
|
||||
"_ARTIFACTS",
|
||||
"_ATTACHMENT_CONTENT",
|
||||
"raw_result",
|
||||
"_created_time",
|
||||
"_elapsed_time",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ContractError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _validate_business_output_name(name: str) -> None:
|
||||
if not name or not name.strip():
|
||||
raise ContractError("CodeExec business output name must not be empty")
|
||||
if name in SYSTEM_OUTPUT_KEYS:
|
||||
raise ContractError(f"CodeExec reserved output name is not allowed: {name}")
|
||||
if "." in name:
|
||||
raise ContractError(f"CodeExec business output name must not contain '.': {name}")
|
||||
|
||||
|
||||
def select_business_output(outputs: Mapping[str, object]) -> tuple[str, object]:
|
||||
if len(outputs) == 1:
|
||||
only_name, only_meta = next(iter(outputs.items()))
|
||||
_validate_business_output_name(only_name)
|
||||
return only_name, only_meta
|
||||
|
||||
business_outputs = [(name, meta) for name, meta in outputs.items() if name not in SYSTEM_OUTPUT_KEYS]
|
||||
if len(business_outputs) != 1:
|
||||
raise ContractError(
|
||||
f"CodeExec contract must contain exactly one business output, got {len(business_outputs)}"
|
||||
)
|
||||
_validate_business_output_name(business_outputs[0][0])
|
||||
return business_outputs[0]
|
||||
|
||||
|
||||
def normalize_output_value(value):
|
||||
if isinstance(value, (tuple, list)):
|
||||
return [normalize_output_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: normalize_output_value(item) for key, item in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def infer_actual_type(value) -> str:
|
||||
value = normalize_output_value(value)
|
||||
if value is None:
|
||||
return "Null"
|
||||
if isinstance(value, bool):
|
||||
return "Boolean"
|
||||
if _is_number(value):
|
||||
return "Number"
|
||||
if isinstance(value, str):
|
||||
return "String"
|
||||
if isinstance(value, dict):
|
||||
return "Object"
|
||||
if isinstance(value, list):
|
||||
if not value:
|
||||
return "Array<Any>"
|
||||
inferred = {infer_actual_type(item) for item in value}
|
||||
if len(inferred) == 1:
|
||||
return f"Array<{inferred.pop()}>"
|
||||
return "Array<Any>"
|
||||
return "Any"
|
||||
|
||||
|
||||
def render_canonical_content(value) -> str:
|
||||
value = normalize_output_value(value)
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False, indent=2, sort_keys=True)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _is_number(value) -> bool:
|
||||
return isinstance(value, (int, float)) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _validate_top_level_value_domain(value) -> None:
|
||||
allowed = value is None or isinstance(value, (bool, str, dict, list)) or _is_number(value)
|
||||
if not allowed:
|
||||
raise ContractError(
|
||||
f"CodeExec unsupported top-level result type: {type(value).__name__}. "
|
||||
"Allowed top-level values are String, Number, Boolean, Object, Array, or Null."
|
||||
)
|
||||
|
||||
|
||||
def _normalize_expected_type(expected_type: str) -> str:
|
||||
etype = expected_type.strip()
|
||||
low = etype.lower()
|
||||
simple_types = {
|
||||
"string": "String",
|
||||
"number": "Number",
|
||||
"boolean": "Boolean",
|
||||
"object": "Object",
|
||||
"null": "Null",
|
||||
"any": "Any",
|
||||
}
|
||||
if low in simple_types:
|
||||
return simple_types[low]
|
||||
if low.startswith("array<") and low.endswith(">"):
|
||||
inner = etype[etype.find("<") + 1 : -1].strip()
|
||||
if not inner:
|
||||
raise ContractError(f"Unsupported expected type: {expected_type}")
|
||||
return f"Array<{_normalize_expected_type(inner)}>"
|
||||
return etype
|
||||
|
||||
|
||||
def _validate_expected_type(expected_type: str, value, path: str = "") -> None:
|
||||
etype = _normalize_expected_type(expected_type)
|
||||
if not etype or etype.lower() == "any":
|
||||
return
|
||||
|
||||
value = normalize_output_value(value)
|
||||
|
||||
if etype.startswith("Array<") and etype.endswith(">"):
|
||||
inner_type = etype[6:-1].strip()
|
||||
if not isinstance(value, list):
|
||||
raise ContractError(
|
||||
f"CodeExec contract mismatch at {path or 'value'}: expected type {etype}, got {infer_actual_type(value)}"
|
||||
)
|
||||
for index, item in enumerate(value):
|
||||
child_path = f"{path}[{index}]" if path else f"[{index}]"
|
||||
_validate_expected_type(inner_type, item, child_path)
|
||||
return
|
||||
|
||||
actual_type = infer_actual_type(value)
|
||||
if etype == "String":
|
||||
valid = isinstance(value, str)
|
||||
elif etype == "Number":
|
||||
valid = _is_number(value)
|
||||
elif etype == "Boolean":
|
||||
valid = isinstance(value, bool)
|
||||
elif etype == "Object":
|
||||
valid = isinstance(value, dict)
|
||||
elif etype == "Null":
|
||||
valid = value is None
|
||||
else:
|
||||
raise ContractError(f"Unsupported expected type: {expected_type}")
|
||||
|
||||
if not valid:
|
||||
raise ContractError(
|
||||
f"CodeExec contract mismatch at {path or 'value'}: expected type {etype}, got {actual_type}"
|
||||
)
|
||||
|
||||
|
||||
def build_code_exec_contract(outputs: Mapping[str, object], raw_result) -> dict[str, object]:
|
||||
business_name, business_meta = select_business_output(outputs)
|
||||
expected_type = ""
|
||||
if isinstance(business_meta, Mapping):
|
||||
expected_type = str(business_meta.get("type") or "")
|
||||
|
||||
normalized_value = normalize_output_value(raw_result)
|
||||
_validate_top_level_value_domain(normalized_value)
|
||||
_validate_expected_type(expected_type, normalized_value)
|
||||
|
||||
return {
|
||||
"business_output": business_name,
|
||||
"value": normalized_value,
|
||||
"actual_type": infer_actual_type(normalized_value),
|
||||
"content": render_canonical_content(normalized_value),
|
||||
}
|
||||
|
||||
|
||||
def _art_field(art, field: str, default=""):
|
||||
return art.get(field, default) if isinstance(art, dict) else getattr(art, field, default)
|
||||
|
||||
|
||||
class Language(StrEnum):
|
||||
PYTHON = "python"
|
||||
NODEJS = "nodejs"
|
||||
@@ -190,7 +368,13 @@ class CodeExec(ToolBase, ABC):
|
||||
return
|
||||
|
||||
artifacts = result.metadata.get("artifacts", []) if result.metadata else []
|
||||
return self._process_execution_result(result.stdout, result.stderr, "Provider system", artifacts)
|
||||
return self._process_execution_result(
|
||||
result.stdout,
|
||||
result.stderr,
|
||||
"Provider system",
|
||||
artifacts,
|
||||
execution_metadata=result.metadata,
|
||||
)
|
||||
|
||||
except (ImportError, RuntimeError) as provider_error:
|
||||
# Provider system not available or not configured, fall back to HTTP
|
||||
@@ -226,6 +410,7 @@ class CodeExec(ToolBase, ABC):
|
||||
body.get("stderr"),
|
||||
f"http://{settings.SANDBOX_HOST}:9385/run",
|
||||
body.get("artifacts", []),
|
||||
execution_metadata=self._build_http_execution_metadata(body),
|
||||
)
|
||||
else:
|
||||
self.set_output("_ERROR", "There is no response from sandbox")
|
||||
@@ -239,8 +424,18 @@ class CodeExec(ToolBase, ABC):
|
||||
|
||||
return self.output()
|
||||
|
||||
def _process_execution_result(self, stdout: str, stderr: str | None, source: str, artifacts: list | None = None):
|
||||
if stderr and not stdout and not artifacts:
|
||||
def _process_execution_result(
|
||||
self,
|
||||
stdout: str,
|
||||
stderr: str | None,
|
||||
source: str,
|
||||
artifacts: list | None = None,
|
||||
execution_metadata: dict | None = None,
|
||||
):
|
||||
has_structured_result = bool((execution_metadata or {}).get("result_present") is True)
|
||||
resolved_value, used_stdout_fallback = self._resolve_execution_result_value(stdout, execution_metadata)
|
||||
|
||||
if stderr and not has_structured_result and not artifacts and not str(stdout or "").strip():
|
||||
self.set_output("_ERROR", stderr)
|
||||
return self.output()
|
||||
|
||||
@@ -250,29 +445,48 @@ class CodeExec(ToolBase, ABC):
|
||||
if stderr:
|
||||
logging.warning(f"[CodeExec]: stderr (non-fatal): {stderr[:500]}")
|
||||
|
||||
parsed_stdout = self._deserialize_stdout(stdout)
|
||||
logging.info(f"[CodeExec]: {source} -> {parsed_stdout}")
|
||||
self._populate_outputs(parsed_stdout, stdout)
|
||||
if used_stdout_fallback and str(stdout or "").strip():
|
||||
logging.warning("[CodeExec]: Falling back to stdout deserialization because no structured result metadata was provided")
|
||||
|
||||
logging.info(f"[CodeExec]: {source} -> {resolved_value}")
|
||||
content_parts = []
|
||||
base_content = self._build_content_text(parsed_stdout, raw_stdout=stdout)
|
||||
base_content = self._apply_business_output(resolved_value)
|
||||
if base_content:
|
||||
content_parts.append(base_content)
|
||||
|
||||
if artifacts:
|
||||
artifact_urls = self._upload_artifacts(artifacts)
|
||||
if artifact_urls:
|
||||
self.set_output("_ARTIFACTS", artifact_urls)
|
||||
self.set_output("_ARTIFACTS", artifact_urls or None)
|
||||
attachment_text = self._build_attachment_content(artifacts, artifact_urls)
|
||||
self.set_output("_ATTACHMENT_CONTENT", attachment_text)
|
||||
if attachment_text:
|
||||
content_parts.append(attachment_text)
|
||||
else:
|
||||
self.set_output("_ARTIFACTS", None)
|
||||
self.set_output("_ATTACHMENT_CONTENT", "")
|
||||
|
||||
self.set_output("content", "\n\n".join([part for part in content_parts if part]).strip())
|
||||
|
||||
return self.output()
|
||||
|
||||
def _build_http_execution_metadata(self, body: Mapping | None) -> dict:
|
||||
if not isinstance(body, Mapping):
|
||||
return {}
|
||||
structured_result = body.get("result")
|
||||
if not isinstance(structured_result, Mapping):
|
||||
return {}
|
||||
return {
|
||||
"result_present": structured_result.get("present", False),
|
||||
"result_value": structured_result.get("value"),
|
||||
"result_type": structured_result.get("type"),
|
||||
}
|
||||
|
||||
def _resolve_execution_result_value(self, stdout: str, execution_metadata: Mapping | None = None):
|
||||
metadata = execution_metadata or {}
|
||||
if metadata.get("result_present") is True:
|
||||
return metadata.get("result_value"), False
|
||||
return self._deserialize_stdout(stdout), True
|
||||
|
||||
@classmethod
|
||||
def _ensure_bucket_lifecycle(cls):
|
||||
if cls._lifecycle_configured:
|
||||
@@ -306,10 +520,10 @@ class CodeExec(ToolBase, ABC):
|
||||
uploaded = []
|
||||
for art in artifacts:
|
||||
try:
|
||||
name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "")
|
||||
content_b64 = art.get("content_b64", "") if isinstance(art, dict) else getattr(art, "content_b64", "")
|
||||
mime_type = art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", "")
|
||||
size = art.get("size", 0) if isinstance(art, dict) else getattr(art, "size", 0)
|
||||
name = _art_field(art, "name")
|
||||
content_b64 = _art_field(art, "content_b64")
|
||||
mime_type = _art_field(art, "mime_type")
|
||||
size = _art_field(art, "size", 0)
|
||||
if not content_b64 or not name:
|
||||
continue
|
||||
|
||||
@@ -350,119 +564,24 @@ class CodeExec(ToolBase, ABC):
|
||||
continue
|
||||
return text
|
||||
|
||||
def _coerce_output_value(self, value, expected_type: Optional[str]):
|
||||
if expected_type is None:
|
||||
return value
|
||||
|
||||
etype = expected_type.strip().lower()
|
||||
inner_type = None
|
||||
if etype.startswith("array<") and etype.endswith(">"):
|
||||
inner_type = etype[6:-1].strip()
|
||||
etype = "array"
|
||||
def _apply_business_output(self, parsed_stdout) -> str:
|
||||
normalized_result = normalize_output_value(parsed_stdout)
|
||||
self.set_output("raw_result", normalized_result)
|
||||
|
||||
business_output_names = [name for name in self._param.outputs if name not in SYSTEM_OUTPUT_KEYS]
|
||||
try:
|
||||
if etype == "string":
|
||||
return "" if value is None else str(value)
|
||||
contract = build_code_exec_contract(self._param.outputs, normalized_result)
|
||||
except ContractError as e:
|
||||
for output_name in business_output_names:
|
||||
self.set_output(output_name, None)
|
||||
self.set_output("actual_type", infer_actual_type(normalized_result))
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.warning(f"[CodeExec]: contract validation failed: {e}")
|
||||
return render_canonical_content(normalized_result)
|
||||
|
||||
if etype == "number":
|
||||
if value is None or value == "":
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return value
|
||||
return float(value)
|
||||
|
||||
if etype == "boolean":
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lv = value.lower()
|
||||
if lv in ("true", "1", "yes", "y", "on"):
|
||||
return True
|
||||
if lv in ("false", "0", "no", "n", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
if etype == "array":
|
||||
candidate = value
|
||||
if isinstance(candidate, str):
|
||||
parsed = self._deserialize_stdout(candidate)
|
||||
candidate = parsed
|
||||
if isinstance(candidate, tuple):
|
||||
candidate = list(candidate)
|
||||
if not isinstance(candidate, list):
|
||||
candidate = [] if candidate is None else [candidate]
|
||||
|
||||
if inner_type == "string":
|
||||
return ["" if v is None else str(v) for v in candidate]
|
||||
if inner_type == "number":
|
||||
coerced = []
|
||||
for v in candidate:
|
||||
try:
|
||||
if v is None or v == "":
|
||||
coerced.append(None)
|
||||
elif isinstance(v, (int, float)):
|
||||
coerced.append(v)
|
||||
else:
|
||||
coerced.append(float(v))
|
||||
except Exception:
|
||||
coerced.append(v)
|
||||
return coerced
|
||||
return candidate
|
||||
|
||||
if etype == "object":
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
parsed = self._deserialize_stdout(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return value
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def _populate_outputs(self, parsed_stdout, raw_stdout: str):
|
||||
outputs_items = list(self._param.outputs.items())
|
||||
logging.info(f"[CodeExec]: outputs schema keys: {[k for k, _ in outputs_items]}")
|
||||
if not outputs_items:
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, dict):
|
||||
for key, meta in outputs_items:
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = self._get_by_path(parsed_stdout, key)
|
||||
if val is None and len(outputs_items) == 1:
|
||||
val = parsed_stdout
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, (list, tuple)):
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = parsed_stdout[idx] if idx < len(parsed_stdout) else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate list key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
default_val = parsed_stdout if parsed_stdout is not None else raw_stdout
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = default_val if idx == 0 else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
self.set_output("actual_type", contract["actual_type"])
|
||||
self.set_output(contract["business_output"], contract["value"])
|
||||
return contract["content"]
|
||||
|
||||
def _build_attachment_content(self, artifacts: list, artifact_urls: list[dict] | None = None) -> str:
|
||||
sections = []
|
||||
@@ -471,9 +590,9 @@ class CodeExec(ToolBase, ABC):
|
||||
for idx, art in enumerate(artifacts, start=1):
|
||||
key = f"attachment{idx}"
|
||||
try:
|
||||
name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "")
|
||||
content_b64 = art.get("content_b64", "") if isinstance(art, dict) else getattr(art, "content_b64", "")
|
||||
mime_type = art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", "")
|
||||
name = _art_field(art, "name")
|
||||
content_b64 = _art_field(art, "content_b64")
|
||||
mime_type = _art_field(art, "mime_type")
|
||||
if not name or not content_b64:
|
||||
continue
|
||||
|
||||
@@ -490,11 +609,8 @@ class CodeExec(ToolBase, ABC):
|
||||
logging.info(f"[CodeExec]: parse attachment section key='{key}' from artifact='{name}'")
|
||||
except Exception as e:
|
||||
logging.warning(f"[CodeExec]: Failed to parse artifact for content section '{key}': {e}")
|
||||
fallback_type = self._normalize_attachment_type(
|
||||
art.get("name", "") if isinstance(art, dict) else getattr(art, "name", ""),
|
||||
art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", ""),
|
||||
)
|
||||
fallback_name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "")
|
||||
fallback_type = self._normalize_attachment_type(name, mime_type)
|
||||
fallback_name = name
|
||||
fallback_url = ""
|
||||
if idx - 1 < len(artifact_urls):
|
||||
fallback_url = artifact_urls[idx - 1].get("url", "")
|
||||
@@ -529,38 +645,3 @@ class CodeExec(ToolBase, ABC):
|
||||
title += f": {name}"
|
||||
body = parsed if isinstance(parsed, str) else json.dumps(parsed, ensure_ascii=False)
|
||||
return f"{title}\n{body}".strip()
|
||||
|
||||
def _build_content_text(self, parsed_stdout, raw_stdout: str = "") -> str:
|
||||
if isinstance(parsed_stdout, str):
|
||||
return parsed_stdout.strip()
|
||||
if isinstance(parsed_stdout, (dict, list, tuple)):
|
||||
try:
|
||||
return json.dumps(parsed_stdout, ensure_ascii=False, indent=2).strip()
|
||||
except Exception:
|
||||
return str(parsed_stdout).strip()
|
||||
if parsed_stdout is None:
|
||||
return str(raw_stdout or "").strip()
|
||||
return str(parsed_stdout).strip()
|
||||
|
||||
def _get_by_path(self, data, path: str):
|
||||
if not path:
|
||||
return None
|
||||
cur = data
|
||||
for part in path.split("."):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
return None
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(part)
|
||||
elif isinstance(cur, list):
|
||||
try:
|
||||
idx = int(part)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if cur is None:
|
||||
return None
|
||||
logging.info(f"[CodeExec]: resolve path '{path}' -> {cur}")
|
||||
return cur
|
||||
|
||||
Reference in New Issue
Block a user