mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(go-agent): Ported retrieval node, added Keenable web search tool (#16396)
Ported retrieval node, added Keenable web search tool - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -415,6 +415,7 @@ class Canvas(Graph):
|
||||
if not self.globals["sys.conversation_turns"] :
|
||||
self.globals["sys.conversation_turns"] = 0
|
||||
self.globals["sys.conversation_turns"] += 1
|
||||
is_resume = bool(self.path) and self.path[0].lower().find("userfillup") >= 0
|
||||
|
||||
def decorate(event, dt):
|
||||
nonlocal created_at
|
||||
@@ -427,16 +428,16 @@ class Canvas(Graph):
|
||||
"data": dt
|
||||
}
|
||||
|
||||
if not self.path or self.path[-1].lower().find("userfillup") < 0:
|
||||
if not is_resume:
|
||||
self.path.append("begin")
|
||||
self.retrieval.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled before starting."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
if not is_resume:
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
async def _run_batch(f, t):
|
||||
@@ -501,7 +502,7 @@ class Canvas(Graph):
|
||||
})
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
idx = 0 if is_resume else len(self.path) - 1
|
||||
partials = []
|
||||
tts_mdl = None
|
||||
while idx < len(self.path):
|
||||
@@ -647,9 +648,14 @@ class Canvas(Graph):
|
||||
o = self.get_component_obj(c)
|
||||
if o.component_name.lower() == "userfillup":
|
||||
o.invoke()
|
||||
another_inputs.update(o.get_input_elements())
|
||||
another_inputs.update({
|
||||
k: v for k, v in o.get_input_elements().items()
|
||||
if not self._is_input_field_satisfied(v)
|
||||
})
|
||||
if o.get_param("enable_tips"):
|
||||
tips = o.output("tips")
|
||||
if not another_inputs:
|
||||
continue
|
||||
self.path = path
|
||||
yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
|
||||
return
|
||||
@@ -734,7 +740,25 @@ class Canvas(Graph):
|
||||
|
||||
def add_user_input(self, question):
|
||||
self.history.append(("user", question))
|
||||
self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
|
||||
rendered = json.dumps(question, ensure_ascii=False) if isinstance(question, dict) else question
|
||||
self.globals["sys.history"].append(f"{self.history[-1][0]}: {rendered}")
|
||||
|
||||
@staticmethod
|
||||
def _is_input_field_satisfied(field: Any) -> bool:
|
||||
if not isinstance(field, dict):
|
||||
return field is not None
|
||||
|
||||
value = field.get("value")
|
||||
field_type = str(field.get("type", "")).lower()
|
||||
if field_type.find("file") >= 0:
|
||||
if field.get("optional") and value is None:
|
||||
return True
|
||||
return value not in (None, [], "")
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_prologue(self):
|
||||
return self.components["begin"]["obj"]._param.prologue
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.fillup import UserFillUpParam, UserFillUp
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class BeginParam(UserFillUpParam):
|
||||
@@ -42,20 +41,11 @@ class Begin(UserFillUp):
|
||||
return
|
||||
|
||||
layout_recognize = self._param.layout_recognize or None
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
merged_inputs = self._merge_runtime_inputs(kwargs.get("inputs", {}))
|
||||
for k, v in merged_inputs.items():
|
||||
if self.check_if_canceled("Begin processing"):
|
||||
return
|
||||
|
||||
if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
file_value = v["value"]
|
||||
# Support both single file (backward compatibility) and multiple files
|
||||
files = file_value if isinstance(file_value, list) else [file_value]
|
||||
v = FileService.get_files(files, layout_recognize=layout_recognize)
|
||||
else:
|
||||
v = v.get("value")
|
||||
v = self._resolve_input_value(v, layout_recognize)
|
||||
self.set_output(k, v)
|
||||
self.set_input_value(k, v)
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
_INITIAL_USER_INPUT_CONSUMED_KEY = "sys.__initial_user_input_consumed__"
|
||||
|
||||
|
||||
class UserFillUpParam(ComponentParamBase):
|
||||
|
||||
def __init__(self):
|
||||
@@ -36,6 +39,52 @@ class UserFillUpParam(ComponentParamBase):
|
||||
class UserFillUp(ComponentBase):
|
||||
component_name = "UserFillUp"
|
||||
|
||||
def _merge_runtime_inputs(self, runtime_inputs):
|
||||
if runtime_inputs:
|
||||
return runtime_inputs
|
||||
|
||||
fields = self.get_input_elements()
|
||||
if not fields:
|
||||
return {}
|
||||
|
||||
if self._canvas.globals.get(_INITIAL_USER_INPUT_CONSUMED_KEY):
|
||||
return {}
|
||||
|
||||
query = self._canvas.globals.get("sys.query")
|
||||
if query is None or query == "":
|
||||
return {}
|
||||
|
||||
if isinstance(query, dict):
|
||||
matched = {
|
||||
key: value if isinstance(value, dict) else {"value": value}
|
||||
for key, value in query.items()
|
||||
if key in fields
|
||||
}
|
||||
if matched:
|
||||
self._canvas.globals[_INITIAL_USER_INPUT_CONSUMED_KEY] = True
|
||||
return matched
|
||||
|
||||
if len(fields) == 1:
|
||||
field_name = next(iter(fields))
|
||||
self._canvas.globals[_INITIAL_USER_INPUT_CONSUMED_KEY] = True
|
||||
return {field_name: {"value": query}}
|
||||
|
||||
return {}
|
||||
|
||||
def _resolve_input_value(self, value, layout_recognize):
|
||||
if isinstance(value, dict) and value.get("type", "").lower().find("file") >= 0:
|
||||
if value.get("optional") and value.get("value", None) is None:
|
||||
return None
|
||||
|
||||
file_value = value["value"]
|
||||
files = file_value if isinstance(file_value, list) else [file_value]
|
||||
return FileService.get_files(files, layout_recognize=layout_recognize)
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value.get("value")
|
||||
|
||||
return value
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
@@ -63,20 +112,13 @@ class UserFillUp(ComponentBase):
|
||||
|
||||
self.set_output("tips", content)
|
||||
layout_recognize = self._param.layout_recognize or None
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
merged_inputs = self._merge_runtime_inputs(kwargs.get("inputs", {}))
|
||||
for k, v in merged_inputs.items():
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
file_value = v["value"]
|
||||
# Support both single file (backward compatibility) and multiple files
|
||||
files = file_value if isinstance(file_value, list) else [file_value]
|
||||
v = FileService.get_files(files, layout_recognize=layout_recognize)
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
resolved = self._resolve_input_value(v, layout_recognize)
|
||||
self.set_output(k, resolved)
|
||||
self.set_input_value(k, resolved)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Waiting for your input..."
|
||||
|
||||
@@ -38,9 +38,17 @@ class ListOperationsParam(ComponentParamBase):
|
||||
"type": "?"
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_operation_name(operation):
|
||||
op = "" if operation is None else str(operation).strip()
|
||||
if op.lower() == "topn":
|
||||
return "head"
|
||||
return op or "nth"
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.query, "query")
|
||||
self.operations = self._normalize_operation_name(self.operations)
|
||||
self.check_valid_value(
|
||||
self.operations,
|
||||
"Support operations",
|
||||
|
||||
@@ -226,7 +226,7 @@ class AliyunCodeInterpreterProvider(SandboxProvider):
|
||||
# Connect to existing sandbox instance
|
||||
sandbox = Sandbox.connect(sandbox_id=instance_id, config=self._config)
|
||||
|
||||
# agentrun-sdk 0.0.26 only exposes CodeLanguage.PYTHON; keep JS as string fallback.
|
||||
# CodeLanguage enum only exposes PYTHON across agentrun-sdk 0.0.26+; keep JS as string fallback.
|
||||
code_language = CodeLanguage.PYTHON if normalized_lang == "python" else "javascript"
|
||||
|
||||
# Wrap code to call main() function
|
||||
@@ -355,7 +355,7 @@ class AliyunCodeInterpreterProvider(SandboxProvider):
|
||||
# Try to list templates to verify connection
|
||||
from agentrun.sandbox import Template
|
||||
|
||||
templates = Template.list(config=self._config)
|
||||
templates = Template.list_templates(config=self._config)
|
||||
return templates is not None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -45,11 +45,11 @@ class TestAliyunCodeInterpreterProvider:
|
||||
assert provider.timeout == 30
|
||||
assert not provider._initialized
|
||||
|
||||
@patch("agent.sandbox.providers.aliyun_codeinterpreter.Template")
|
||||
@patch("agentrun.sandbox.Template")
|
||||
def test_initialize_success(self, mock_template):
|
||||
"""Test successful initialization."""
|
||||
# Mock health check response
|
||||
mock_template.list.return_value = []
|
||||
mock_template.list_templates.return_value = []
|
||||
|
||||
provider = AliyunCodeInterpreterProvider()
|
||||
result = provider.initialize(
|
||||
@@ -89,10 +89,10 @@ class TestAliyunCodeInterpreterProvider:
|
||||
result = provider2.initialize({"access_key_id": "LTAI5tXXXXXXXXXX", "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"})
|
||||
assert result is False
|
||||
|
||||
@patch("agent.sandbox.providers.aliyun_codeinterpreter.Template")
|
||||
@patch("agentrun.sandbox.Template")
|
||||
def test_initialize_default_config(self, mock_template):
|
||||
"""Test initialization with default config."""
|
||||
mock_template.list.return_value = []
|
||||
mock_template.list_templates.return_value = []
|
||||
|
||||
provider = AliyunCodeInterpreterProvider()
|
||||
result = provider.initialize({"access_key_id": "LTAI5tXXXXXXXXXX", "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", "account_id": "1234567890123456"})
|
||||
|
||||
Reference in New Issue
Block a user