mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 01:29:35 +08:00
fix(agent/tools): port AkShare to ToolBase so it works as an Agent tool (#16417)
### What problem does this PR solve? Closes #16416. The **AkShare** agent tool (`agent/tools/akshare.py`) was never ported to the modern `ToolBase`/`_invoke` interface during the agent module redesign and was still written against the removed legacy `_run`/`be_output` API, so it was non-functional: 1. **Adding it to an Agent raised `AttributeError`.** `AkShare` extended `ComponentBase` (not `ToolBase`) and `AkShareParam` defined no `meta`, so it had no `get_meta()`. `agent/component/agent_with_tools.py` builds each tool's function descriptor via `cpn.get_meta()`, so constructing an Agent that includes the AkShare tool raised `AttributeError: 'AkShare' object has no attribute 'get_meta'`. 2. **It could never run.** `invoke()` dispatches to `self._invoke`, but `AkShare` only implemented the legacy `_run`, so `_invoke` fell through to `ComponentBase._invoke` → `NotImplementedError`. `_run` also called `be_output(...)`, which no longer exists on the base classes. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### Changes - Port `AkShareParam` to `ToolParamBase` with a `ToolMeta` (defined before `super().__init__()`, matching `ArXivParam`/`TavilyExtractParam`) exposing a required `query` parameter — the stock symbol to look up, default `{sys.query}`. `query` matches the `{sys.query}` convention shared by the other tools. - Rewrite the component with `_invoke`/`set_output("formalized_content", ...)` (errors surfaced via `_ERROR`), keeping `top_n` and importing `akshare` lazily. - Add regression tests (`test/unit_test/agent/component/test_akshare.py`) covering param construction, validation, and the tool descriptor. Same class of defect as #16329 (DeepL) and #16414 (Crawler). Backend-only; no frontend changes. --------- Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
@@ -13,44 +13,85 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
||||
class AkShareParam(ComponentParamBase):
|
||||
class AkShareParam(ToolParamBase):
|
||||
"""
|
||||
Define the AkShare component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta: ToolMeta = {
|
||||
"name": "akshare_stock_news",
|
||||
"description": "AkShare retrieves the latest news articles for a given Chinese A-share stock from East Money (东方财富).",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The stock symbol/code to fetch news for, e.g. '600519'.",
|
||||
"default": "{sys.query}",
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {"query": {"name": "Stock symbol", "type": "line"}}
|
||||
|
||||
class AkShare(ComponentBase, ABC):
|
||||
|
||||
class AkShare(ToolBase, ABC):
|
||||
component_name = "AkShare"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
import akshare as ak
|
||||
ans = self.get_input()
|
||||
ans = ",".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return AkShare.be_output("")
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("AkShare processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
ak_res = []
|
||||
stock_news_em_df = ak.stock_news_em(symbol=ans)
|
||||
stock_news_em_df = stock_news_em_df.head(self._param.top_n)
|
||||
ak_res = [{"content": '<a href="' + i["新闻链接"] + '">' + i["新闻标题"] + '</a>\n 新闻内容: ' + i[
|
||||
"新闻内容"] + " \n发布时间:" + i["发布时间"] + " \n文章来源: " + i["文章来源"]} for index, i in stock_news_em_df.iterrows()]
|
||||
except Exception as e:
|
||||
return AkShare.be_output("**ERROR**: " + str(e))
|
||||
symbol = kwargs.get("query")
|
||||
if not symbol:
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
if not ak_res:
|
||||
return AkShare.be_output("")
|
||||
last_e = None
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("AkShare processing"):
|
||||
return
|
||||
|
||||
return pd.DataFrame(ak_res)
|
||||
try:
|
||||
import akshare as ak
|
||||
|
||||
df = ak.stock_news_em(symbol=symbol).head(self._param.top_n)
|
||||
|
||||
if self.check_if_canceled("AkShare processing"):
|
||||
return
|
||||
|
||||
items = ['<a href="{}">{}</a>\n 新闻内容: {} \n发布时间:{} \n文章来源: {}'.format(i["新闻链接"], i["新闻标题"], i["新闻内容"], i["发布时间"], i["文章来源"]) for _, i in df.iterrows()]
|
||||
res = "\n\n".join(items)
|
||||
self.set_output("formalized_content", res)
|
||||
return res
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("AkShare processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"AkShare error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"AkShare error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking up the latest stock news for: {}".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
99
test/unit_test/agent/component/test_akshare.py
Normal file
99
test/unit_test/agent/component/test_akshare.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.tools.akshare import AkShare, AkShareParam
|
||||
|
||||
|
||||
def _make_tool(top_n=10):
|
||||
# Bypass the canvas-bound __init__ (mirrors test_pubmed_unit.py) and stub the
|
||||
# canvas-touching helpers so we can exercise _invoke's execution path.
|
||||
tool = AkShare.__new__(AkShare)
|
||||
param = AkShareParam()
|
||||
param.top_n = top_n
|
||||
tool._param = param
|
||||
tool.check_if_canceled = lambda *a, **k: False
|
||||
out = {}
|
||||
tool.set_output = lambda k, v: out.__setitem__(k, v)
|
||||
tool.output = lambda k=None: (out.get(k) if k else out)
|
||||
return tool, out
|
||||
|
||||
|
||||
def _fake_news_df(n):
|
||||
import pandas as pd
|
||||
|
||||
return pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"新闻链接": f"https://u{i}",
|
||||
"新闻标题": f"title{i}",
|
||||
"新闻内容": f"content{i}",
|
||||
"发布时间": "2026-01-01",
|
||||
"文章来源": "src",
|
||||
}
|
||||
for i in range(n)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_param_instantiates():
|
||||
AkShareParam()
|
||||
|
||||
|
||||
def test_check_passes_with_defaults():
|
||||
AkShareParam().check()
|
||||
|
||||
|
||||
def test_meta_exposes_query_parameter():
|
||||
# Regression: AkShare extended ComponentBase and defined no `meta`, so it
|
||||
# had no get_meta() and crashed agent_with_tools when added to an Agent.
|
||||
meta = AkShareParam().get_meta()
|
||||
params = meta["function"]["parameters"]
|
||||
assert "query" in params["properties"]
|
||||
assert "query" in params["required"]
|
||||
|
||||
|
||||
def test_check_rejects_non_positive_top_n():
|
||||
param = AkShareParam()
|
||||
param.top_n = 0
|
||||
with pytest.raises(ValueError):
|
||||
param.check()
|
||||
|
||||
|
||||
def test_invoke_returns_content_and_sets_formalized_content(monkeypatch):
|
||||
# Regression for the restored runtime path: _invoke(query=...) must fetch
|
||||
# news, return the formatted content, write it to formalized_content, and
|
||||
# respect top_n.
|
||||
pytest.importorskip("akshare")
|
||||
import akshare
|
||||
|
||||
monkeypatch.setattr(akshare, "stock_news_em", lambda symbol: _fake_news_df(5))
|
||||
|
||||
tool, out = _make_tool(top_n=2)
|
||||
res = tool._invoke(query="600519")
|
||||
|
||||
assert "title0" in res and "https://u0" in res
|
||||
assert out["formalized_content"] == res
|
||||
# top_n is applied via .head(top_n): only 2 articles formatted.
|
||||
assert res.count("新闻内容:") == 2
|
||||
|
||||
|
||||
def test_invoke_empty_query_returns_empty():
|
||||
# Empty query short-circuits without calling akshare.
|
||||
tool, out = _make_tool()
|
||||
assert tool._invoke(query="") == ""
|
||||
assert out.get("formalized_content") == ""
|
||||
Reference in New Issue
Block a user