Port agent PRs to GO - 3 (#16596)

### Summary

Port
https://github.com/infiniflow/ragflow/pull/16415
https://github.com/infiniflow/ragflow/pull/16417
This commit is contained in:
qinling0210
2026-07-03 18:03:23 +08:00
committed by GitHub
parent 8db68e3eec
commit ffc4d29a06
7 changed files with 576 additions and 159 deletions

View File

@@ -19,59 +19,103 @@ package tool
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
)
const akshareToolName = "akshare"
const akshareToolName = "akshare_stock_news"
const akshareToolDescription = "Query Chinese A-share financial data (AkShare). " +
"STUB: AkShare is a Python library — not available in the Go Canvas. " +
"Use the Python Canvas for Chinese A-share data."
const akshareToolDescription = "Retrieves the latest East Money news articles for a Chinese A-share stock symbol."
const akshareUnsupportedMessage = "AkShare is a Python library — not available in Go Canvas. " +
"Use Python Canvas for Chinese A-share data."
const defaultAkShareTopN = 10
const maxAkShareResponseBytes = 4 << 20
var akshareStockNewsEndpoint = "https://search-api-web.eastmoney.com/search/jsonp"
// akshareParams is the JSON shape the model sends into InvokableRun.
//
// AkShare exposes 200+ indicator functions (e.g. stock_zh_a_hist,
// stock_zh_a_spot_em, bond_zh_hs_cov). For a future HTTP shim that routes
// to a Python backend we keep the same shape: a stock/asset symbol plus
// the indicator name. The Go stub validates the shape and rejects all
// invocations with a clear error.
type akshareParams struct {
Symbol string `json:"symbol"`
Indicator string `json:"indicator"`
Query string `json:"query"`
Symbol string `json:"symbol,omitempty"`
TopN int `json:"top_n,omitempty"`
}
func (p akshareParams) stockSymbol() string {
if query := strings.TrimSpace(p.Query); query != "" {
return query
}
return strings.TrimSpace(p.Symbol)
}
type akshareArticle struct {
Title string `json:"title"`
Content string `json:"content"`
PublishedAt string `json:"published_at"`
Source string `json:"source"`
URL string `json:"url"`
}
// akshareEnvelope is the model-facing JSON shape. The stub always
// returns a populated Error so the model gets a deterministic, parseable
// failure.
type akshareEnvelope struct {
Fields []string `json:"fields,omitempty"`
Items []any `json:"items,omitempty"`
Error string `json:"_ERROR,omitempty"`
Content string `json:"content,omitempty"`
Articles []akshareArticle `json:"articles,omitempty"`
Error string `json:"_ERROR,omitempty"`
}
// AkShareTool is a stub implementation of the AkShare Chinese
// financial data tool.
//
// AkShare is a Python library (https://github.com/akfamily/akshare)
// — there is no public HTTP API. The tool is registered with the
// canvas so DSLs that reference "akshare" continue to parse, but
// every invocation fails fast with a clear "use Python Canvas"
// message rather
// than silently no-op'ing.
//
// AkShareTool does not own an HTTPHelper — it never makes network calls.
type AkShareTool struct{}
type akshareSearchRequest struct {
UID string `json:"uid"`
Keyword string `json:"keyword"`
Type []string `json:"type"`
Client string `json:"client"`
ClientType string `json:"clientType"`
ClientVersion string `json:"clientVersion"`
Param map[string]any `json:"param"`
}
// NewAkShareTool returns an AkShareTool. No HTTPHelper is allocated;
// the stub never issues network requests.
func NewAkShareTool() *AkShareTool { return &AkShareTool{} }
type akshareSearchResponse struct {
Result struct {
Articles []akshareEastMoneyArticle `json:"cmsArticleWebOld"`
} `json:"result"`
}
type akshareEastMoneyArticle struct {
Code string `json:"code"`
Title string `json:"title"`
Content string `json:"content"`
Date string `json:"date"`
MediaName string `json:"mediaName"`
}
// AkShareTool retrieves East Money stock news using the same endpoint
// as AkShare's stock_news_em(symbol=...) helper.
type AkShareTool struct {
helper *HTTPHelper
topN int
}
func NewAkShareTool() *AkShareTool {
return NewAkShareToolWith(NewHTTPHelper())
}
func NewAkShareToolWith(h *HTTPHelper) *AkShareTool {
return NewAkShareToolWithTopN(h, defaultAkShareTopN)
}
func NewAkShareToolWithTopN(h *HTTPHelper, topN int) *AkShareTool {
if h == nil {
h = NewHTTPHelper()
}
if topN <= 0 {
topN = defaultAkShareTopN
}
return &AkShareTool{helper: h, topN: topN}
}
// Info returns the tool's metadata for the chat model.
func (a *AkShareTool) Info(_ context.Context) (*schema.ToolInfo, error) {
@@ -79,36 +123,204 @@ func (a *AkShareTool) Info(_ context.Context) (*schema.ToolInfo, error) {
Name: akshareToolName,
Desc: akshareToolDescription,
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"symbol": {
"query": {
Type: schema.String,
Desc: "Stock or asset symbol (e.g. 000001, sh600000, bj920566).",
Required: true,
},
"indicator": {
Type: schema.String,
Desc: "AkShare indicator function name (e.g. stock_zh_a_hist, stock_zh_a_spot_em).",
Desc: "Stock symbol/code to fetch East Money news for, e.g. 600519.",
Required: true,
},
}),
}, nil
}
// InvokableRun validates the input shape and returns a clear "use Python
// Canvas" error. The model receives a JSON envelope with the message in
// the `_ERROR` field so it can present a useful message back to the user
// without surfacing a Go stack trace.
func (a *AkShareTool) InvokableRun(_ context.Context, argsJSON string, _ ...tool.Option) (string, error) {
// InvokableRun fetches stock news from East Money and returns both a
// formatted content string and structured article records.
func (a *AkShareTool) InvokableRun(ctx context.Context, argsJSON string, _ ...tool.Option) (string, error) {
var p akshareParams
if err := json.Unmarshal([]byte(argsJSON), &p); err != nil {
return akshareErrJSON(errors.New(akshareUnsupportedMessage)),
errors.New(akshareUnsupportedMessage)
return akshareErrJSON(fmt.Errorf("akshare: parse arguments: %w", err)),
fmt.Errorf("akshare: parse arguments: %w", err)
}
if p.Symbol == "" || p.Indicator == "" {
return akshareErrJSON(errors.New(akshareUnsupportedMessage)),
errors.New(akshareUnsupportedMessage)
symbol := p.stockSymbol()
if symbol == "" {
return akshareErrJSON(fmt.Errorf("akshare: query is required")),
fmt.Errorf("akshare: query is required")
}
return akshareErrJSON(errors.New(akshareUnsupportedMessage)),
errors.New(akshareUnsupportedMessage)
topN := a.topN
if p.TopN > 0 {
topN = p.TopN
}
if topN <= 0 {
topN = defaultAkShareTopN
}
endpoint, err := buildAkShareStockNewsURL(symbol, topN)
if err != nil {
return akshareErrJSON(err), err
}
headers := map[string]string{
"Accept": "*/*",
"Referer": "https://so.eastmoney.com/news/s?keyword=" + url.QueryEscape(symbol),
"User-Agent": "Mozilla/5.0 (compatible; ragflow/1.0)",
}
resp, err := a.helper.Do(ctx, http.MethodGet, endpoint, "", "", headers)
if err != nil {
return akshareErrJSON(err), err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
err := fmt.Errorf("akshare: upstream returned %d", resp.StatusCode)
return akshareErrJSON(err), err
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxAkShareResponseBytes+1))
if err != nil {
err = fmt.Errorf("akshare: read response: %w", err)
return akshareErrJSON(err), err
}
if len(body) > maxAkShareResponseBytes {
err = fmt.Errorf("akshare: response too large")
return akshareErrJSON(err), err
}
articles, err := parseAkShareStockNews(body, topN)
if err != nil {
return akshareErrJSON(err), err
}
env := akshareEnvelope{
Content: formatAkShareArticles(articles),
Articles: articles,
}
return akshareJSON(env), nil
}
func buildAkShareStockNewsURL(symbol string, topN int) (string, error) {
if topN <= 0 {
topN = defaultAkShareTopN
}
inner := akshareSearchRequest{
Keyword: symbol,
Type: []string{"cmsArticleWebOld"},
Client: "web",
ClientType: "web",
ClientVersion: "curr",
Param: map[string]any{
"cmsArticleWebOld": map[string]any{
"searchScope": "default",
"sort": "default",
"pageIndex": 1,
"pageSize": topN,
"preTag": "<em>",
"postTag": "</em>",
},
},
}
innerJSON, err := json.Marshal(inner)
if err != nil {
return "", fmt.Errorf("akshare: build request: %w", err)
}
callback := fmt.Sprintf("jQuery%d_%d", time.Now().UnixNano(), time.Now().UnixMilli())
q := url.Values{}
q.Set("cb", callback)
q.Set("param", string(innerJSON))
q.Set("_", fmt.Sprint(time.Now().UnixMilli()))
return akshareStockNewsEndpoint + "?" + q.Encode(), nil
}
func parseAkShareStockNews(body []byte, topN int) ([]akshareArticle, error) {
if topN <= 0 {
topN = defaultAkShareTopN
}
rawJSON, err := stripJSONP(string(body))
if err != nil {
return nil, err
}
var raw akshareSearchResponse
if err := json.Unmarshal([]byte(rawJSON), &raw); err != nil {
return nil, fmt.Errorf("akshare: decode response: %w", err)
}
if topN > len(raw.Result.Articles) {
topN = len(raw.Result.Articles)
}
articles := make([]akshareArticle, 0, topN)
for _, item := range raw.Result.Articles[:topN] {
articleURL := ""
if code := strings.TrimSpace(item.Code); code != "" {
articleURL = "http://finance.eastmoney.com/a/" + code + ".html"
}
articles = append(articles, akshareArticle{
Title: cleanAkShareText(item.Title),
Content: cleanAkShareText(item.Content),
PublishedAt: strings.TrimSpace(item.Date),
Source: strings.TrimSpace(item.MediaName),
URL: articleURL,
})
}
return articles, nil
}
func stripJSONP(s string) (string, error) {
s = strings.TrimSpace(s)
start := strings.IndexByte(s, '(')
end := strings.LastIndexByte(s, ')')
if start < 0 || end <= start {
if strings.HasPrefix(s, "{") {
return s, nil
}
return "", fmt.Errorf("akshare: invalid JSONP response")
}
return strings.TrimSpace(s[start+1 : end]), nil
}
func cleanAkShareText(s string) string {
replacements := []struct {
old string
new string
}{
{old: "(<em>", new: ""},
{old: "</em>)", new: ""},
{old: "<em>", new: ""},
{old: "</em>", new: ""},
{old: "\u3000", new: ""},
{old: "\r\n", new: " "},
}
for _, repl := range replacements {
s = strings.ReplaceAll(s, repl.old, repl.new)
}
return strings.TrimSpace(s)
}
// formatAkShareArticles renders articles into a human-readable block.
// NOTE: the upstream Python PR's _invoke dropped 文章来源 from the
// format string (5 args but only 4 {} placeholders); we preserve all
// five fields here (link, title, content, date, source).
func formatAkShareArticles(articles []akshareArticle) string {
if len(articles) == 0 {
return ""
}
parts := make([]string, 0, len(articles))
for _, item := range articles {
parts = append(parts, fmt.Sprintf(
`<a href="%s">%s</a>
新闻内容: %s
发布时间:%s
文章来源: %s`,
item.URL,
item.Title,
item.Content,
item.PublishedAt,
item.Source,
))
}
return strings.Join(parts, "\n\n")
}
func akshareJSON(env akshareEnvelope) string {

View File

@@ -19,61 +19,87 @@ package tool
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestAkShare_StubsUnsupported(t *testing.T) {
func TestAkShare_FetchesStockNews(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/search/jsonp" {
t.Errorf("path = %q, want /search/jsonp", r.URL.Path)
http.Error(w, "unexpected path", http.StatusBadRequest)
return
}
param := r.URL.Query().Get("param")
var req akshareSearchRequest
if err := json.Unmarshal([]byte(param), &req); err != nil {
t.Errorf("param is not valid request JSON: %v (%s)", err, param)
http.Error(w, "invalid param", http.StatusBadRequest)
return
}
if req.Keyword != "600519" {
t.Errorf("keyword = %q, want 600519", req.Keyword)
http.Error(w, "unexpected keyword", http.StatusBadRequest)
return
}
callback := r.URL.Query().Get("cb")
if callback == "" {
callback = "callback"
}
w.Header().Set("Content-Type", "application/javascript")
_, _ = w.Write([]byte(callback + `({
"result": {
"cmsArticleWebOld": [
{"code":"202601010001","title":"<em>贵州茅台</em> title0","content":"content0\r\nmore","date":"2026-01-01","mediaName":"src0"},
{"code":"202601010002","title":"title1","content":"content1","date":"2026-01-02","mediaName":"src1"},
{"code":"202601010003","title":"title2","content":"content2","date":"2026-01-03","mediaName":"src2"}
]
}
})`))
}))
defer srv.Close()
oldEndpoint := akshareStockNewsEndpoint
akshareStockNewsEndpoint = srv.URL + "/search/jsonp"
defer func() { akshareStockNewsEndpoint = oldEndpoint }()
tool := NewAkShareToolWithTopN(NewHTTPHelper(), 2)
out, err := tool.InvokableRun(context.Background(), `{"query":"600519"}`)
if err != nil {
t.Fatalf("InvokableRun: %v (out=%s)", err, out)
}
var env akshareEnvelope
if err := json.Unmarshal([]byte(out), &env); err != nil {
t.Fatalf("output is not valid JSON: %v (raw=%s)", err, out)
}
if env.Error != "" {
t.Fatalf("env.Error = %q, want empty", env.Error)
}
if len(env.Articles) != 2 {
t.Fatalf("len(Articles) = %d, want 2", len(env.Articles))
}
if env.Articles[0].Title != "贵州茅台 title0" {
t.Errorf("first title = %q, want cleaned title", env.Articles[0].Title)
}
if !strings.Contains(env.Content, `<a href="http://finance.eastmoney.com/a/202601010001.html">贵州茅台 title0</a>`) {
t.Errorf("formatted content missing first link/title: %q", env.Content)
}
if strings.Count(env.Content, "新闻内容:") != 2 {
t.Errorf("formatted content count = %d, want 2", strings.Count(env.Content, "新闻内容:"))
}
}
func TestAkShare_ParseTruncatesToTopN(t *testing.T) {
t.Parallel()
cases := []struct {
name string
args string
wantSub string
}{
{
name: "well-formed args",
args: `{"symbol":"000001","indicator":"stock_zh_a_hist"}`,
wantSub: "Python library",
},
{
name: "missing indicator",
args: `{"symbol":"000001"}`,
wantSub: "Python library",
},
{
name: "missing symbol",
args: `{"indicator":"stock_zh_a_hist"}`,
wantSub: "Python library",
},
{
name: "empty payload",
args: `{}`,
wantSub: "Python library",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tool := NewAkShareTool()
out, err := tool.InvokableRun(context.Background(), tc.args)
if err == nil {
t.Fatalf("expected error, got nil (out=%s)", out)
}
if !strings.Contains(err.Error(), tc.wantSub) {
t.Errorf("err = %q, want substring %q", err.Error(), tc.wantSub)
}
var env akshareEnvelope
if jerr := json.Unmarshal([]byte(out), &env); jerr != nil {
t.Fatalf("output is not valid JSON: %v (raw=%s)", jerr, out)
}
if env.Error == "" {
t.Errorf("env.Error = empty, want populated")
}
if !strings.Contains(env.Error, "Python") {
t.Errorf("env.Error = %q, want to mention Python Canvas", env.Error)
}
})
articles := parseAkShareFixture(t, 3)
if len(articles) != 3 {
t.Fatalf("fixture articles = %d, want 3", len(articles))
}
}
@@ -85,8 +111,21 @@ func TestAkShare_RejectsMalformedJSON(t *testing.T) {
if err == nil {
t.Fatal("expected error for malformed JSON, got nil")
}
if !strings.Contains(err.Error(), "Python") {
t.Errorf("err = %q, want to mention Python", err.Error())
if !strings.Contains(err.Error(), "parse arguments") {
t.Errorf("err = %q, want parse arguments error", err.Error())
}
}
func TestAkShare_RejectsMissingQuery(t *testing.T) {
t.Parallel()
tool := NewAkShareTool()
out, err := tool.InvokableRun(context.Background(), `{}`)
if err == nil {
t.Fatalf("expected error for missing query, got nil (out=%s)", out)
}
if !strings.Contains(err.Error(), "query is required") {
t.Errorf("err = %q, want query required", err.Error())
}
}
@@ -98,13 +137,93 @@ func TestAkShare_Info(t *testing.T) {
if err != nil {
t.Fatalf("Info: %v", err)
}
if info.Name != "akshare" {
t.Errorf("Name = %q, want akshare", info.Name)
if info.Name != "akshare_stock_news" {
t.Errorf("Name = %q, want akshare_stock_news", info.Name)
}
if !strings.Contains(info.Desc, "AkShare") {
t.Errorf("Desc = %q, want to mention AkShare", info.Desc)
if !strings.Contains(info.Desc, "East Money") {
t.Errorf("Desc = %q, want to mention East Money", info.Desc)
}
if !strings.Contains(info.Desc, "STUB") && !strings.Contains(info.Desc, "Python") {
t.Errorf("Desc = %q, want to flag stub status", info.Desc)
if info.ParamsOneOf == nil {
t.Fatal("ParamsOneOf = nil, want schema definition")
}
schema, err := info.ParamsOneOf.ToJSONSchema()
if err != nil {
t.Fatalf("ToJSONSchema: %v", err)
}
raw, err := json.Marshal(schema)
if err != nil {
t.Fatalf("marshal params schema: %v", err)
}
params := string(raw)
if !strings.Contains(params, `"query"`) {
t.Fatalf("schema missing query parameter: %s", params)
}
if !strings.Contains(params, `"required":["query"]`) {
t.Fatalf("schema does not require query: %s", params)
}
}
func TestBuildByName_AkShareRejectsInvalidTopN(t *testing.T) {
t.Parallel()
_, err := BuildByName("akshare", map[string]any{"top_n": 0})
if err == nil {
t.Fatal("expected top_n validation error")
}
if !strings.Contains(err.Error(), "positive integer") {
t.Fatalf("err = %q, want positive integer validation", err.Error())
}
}
func parseAkShareFixture(t *testing.T, topN int) []akshareArticle {
t.Helper()
callback := "cb"
body := callback + `({
"result": {
"cmsArticleWebOld": [
{"code":"1","title":"title1","content":"content1","date":"2026-01-01","mediaName":"src1"},
{"code":"2","title":"title2","content":"content2","date":"2026-01-02","mediaName":"src2"},
{"code":"3","title":"title3","content":"content3","date":"2026-01-03","mediaName":"src3"},
{"code":"4","title":"title4","content":"content4","date":"2026-01-04","mediaName":"src4"}
]
}
})`
articles, err := parseAkShareStockNews([]byte(body), topN)
if err != nil {
t.Fatalf("parseAkShareStockNews: %v", err)
}
return articles
}
func TestBuildAkShareStockNewsURL(t *testing.T) {
oldEndpoint := akshareStockNewsEndpoint
akshareStockNewsEndpoint = "https://example.com/search/jsonp"
defer func() { akshareStockNewsEndpoint = oldEndpoint }()
rawURL, err := buildAkShareStockNewsURL("600519", 25)
if err != nil {
t.Fatalf("buildAkShareStockNewsURL: %v", err)
}
u, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("parse built URL: %v", err)
}
if u.Host != "example.com" {
t.Fatalf("host = %q, want example.com", u.Host)
}
var req akshareSearchRequest
if err := json.Unmarshal([]byte(u.Query().Get("param")), &req); err != nil {
t.Fatalf("param JSON: %v", err)
}
if req.Keyword != "600519" {
t.Fatalf("keyword = %q, want 600519", req.Keyword)
}
cmsParam, ok := req.Param["cmsArticleWebOld"].(map[string]any)
if !ok {
t.Fatalf("cmsArticleWebOld param = %T, want object", req.Param["cmsArticleWebOld"])
}
if cmsParam["pageSize"] != float64(25) {
t.Fatalf("pageSize = %v, want 25", cmsParam["pageSize"])
}
}

View File

@@ -31,20 +31,29 @@ import (
"golang.org/x/net/html"
)
const crawlerToolName = "crawler"
const crawlerToolName = "web_crawler"
const crawlerToolDescription = "Fetches a web page and returns its extracted text content and links."
const crawlerToolDescription = "Crawls a web page and returns its extracted text content and links."
// crawlerArgs is the JSON shape the model sends in. max_depth and
// max_pages are accepted for API symmetry with the Python tool, but
// the current implementation only supports depth=0 (single page
// fetch).
// crawlerArgs is the JSON shape the model sends in. query is the
// Python-compatible argument name; url is accepted for older Go
// callers. max_depth and max_pages are accepted for API symmetry with
// earlier Go callers, but the current implementation only supports
// depth=0 (single page fetch).
type crawlerArgs struct {
URL string `json:"url"`
Query string `json:"query"`
URL string `json:"url,omitempty"`
MaxDepth int `json:"max_depth,omitempty"`
MaxPages int `json:"max_pages,omitempty"`
}
func (a crawlerArgs) targetURL() string {
if url := strings.TrimSpace(a.Query); url != "" {
return url
}
return strings.TrimSpace(a.URL)
}
// crawlerResult is the JSON envelope returned to the model. The shape
// mirrors the Python tool's `content` / `links` output.
type crawlerResult struct {
@@ -119,9 +128,9 @@ func (c *CrawlerTool) Info(_ context.Context) (*schema.ToolInfo, error) {
Name: crawlerToolName,
Desc: crawlerToolDescription,
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"url": {
"query": {
Type: schema.String,
Desc: "The URL to fetch. Must be http or https.",
Desc: "The absolute URL to crawl. Must include an http:// or https:// scheme.",
Required: true,
},
"max_depth": {
@@ -159,20 +168,21 @@ func (c *CrawlerTool) InvokableRun(ctx context.Context, argumentsInJSON string,
fmt.Errorf("crawler: parse arguments: %w", err)
}
if strings.TrimSpace(args.URL) == "" {
return crawlerStubResult(crawlerResult{Error: "url is required"}),
errors.New("crawler: empty url")
targetURL := args.targetURL()
if targetURL == "" {
return crawlerStubResult(crawlerResult{Error: "query is required"}),
errors.New("crawler: empty query")
}
if !strings.HasPrefix(args.URL, "http://") && !strings.HasPrefix(args.URL, "https://") {
return crawlerStubResult(crawlerResult{URL: args.URL, Error: "url must be http or https"}),
fmt.Errorf("crawler: unsupported url scheme: %s", args.URL)
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
return crawlerStubResult(crawlerResult{URL: targetURL, Error: "url must be http or https"}),
fmt.Errorf("crawler: unsupported url scheme: %s", targetURL)
}
// Reject max_depth > 0 BEFORE the SSRF guard: the guard performs a
// DNS lookup that may be slow / fail in CI, and a depth-0 caller
// asking for max_depth=10 should be rejected on a structural
// problem first.
if args.MaxDepth > 0 {
return crawlerStubResult(crawlerResult{URL: args.URL, Error: ErrCrawlerDepthUnsupported.Error()}),
return crawlerStubResult(crawlerResult{URL: targetURL, Error: ErrCrawlerDepthUnsupported.Error()}),
ErrCrawlerDepthUnsupported
}
// SSRF guard + DNS-rebinding pinning. c.resolve validates the URL
@@ -184,29 +194,29 @@ func (c *CrawlerTool) InvokableRun(ctx context.Context, argumentsInJSON string,
// IP we resolved here) AND TLS SNI / cert verification continue to
// target the validated hostname. Rewriting the URL host to the IP
// would have broken HTTPS, so the pinning happens in the dialer.
host, pinnedIP, resolveErr := c.resolve(args.URL)
host, pinnedIP, resolveErr := c.resolve(targetURL)
if resolveErr != nil {
return crawlerStubResult(crawlerResult{URL: args.URL, Error: resolveErr.Error()}), resolveErr
return crawlerStubResult(crawlerResult{URL: targetURL, Error: resolveErr.Error()}), resolveErr
}
resp, err := c.helper.DoPinned(ctx, http.MethodGet, args.URL, "", "", nil, host, pinnedIP)
resp, err := c.helper.DoPinned(ctx, http.MethodGet, targetURL, "", "", nil, host, pinnedIP)
if err != nil {
return crawlerStubResult(crawlerResult{URL: args.URL, Error: err.Error()}), err
return crawlerStubResult(crawlerResult{URL: targetURL, Error: err.Error()}), err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return crawlerStubResult(crawlerResult{URL: args.URL, Status: resp.StatusCode, Error: "read body: " + err.Error()}),
return crawlerStubResult(crawlerResult{URL: targetURL, Status: resp.StatusCode, Error: "read body: " + err.Error()}),
fmt.Errorf("crawler: read body: %w", err)
}
page, err := extractPage(body)
if err != nil {
return crawlerStubResult(crawlerResult{URL: args.URL, Status: resp.StatusCode, Error: err.Error()}),
return crawlerStubResult(crawlerResult{URL: targetURL, Status: resp.StatusCode, Error: err.Error()}),
fmt.Errorf("crawler: extract: %w", err)
}
page.URL = args.URL
page.URL = targetURL
page.Status = resp.StatusCode
return crawlerJSON(page)

View File

@@ -64,7 +64,7 @@ func TestCrawler_FetchesAndExtractsText(t *testing.T) {
}
c := NewCrawlerTool().WithResolver(loopbackResolver)
out, err := c.InvokableRun(context.Background(),
`{"url":`+jsonString(srv.URL)+`,"max_depth":0}`)
`{"query":`+jsonString(srv.URL)+`,"max_depth":0}`)
if err != nil {
t.Fatalf("InvokableRun: %v", err)
}
@@ -107,19 +107,19 @@ func TestCrawler_RejectsMaxDepthGreaterThanZero(t *testing.T) {
t.Parallel()
c := NewCrawlerTool()
_, err := c.InvokableRun(context.Background(), `{"url":"https://example.com","max_depth":1}`)
_, err := c.InvokableRun(context.Background(), `{"query":"https://example.com","max_depth":1}`)
if !errors.Is(err, ErrCrawlerDepthUnsupported) {
t.Fatalf("err = %v, want ErrCrawlerDepthUnsupported", err)
}
}
func TestCrawler_RejectsMissingURL(t *testing.T) {
func TestCrawler_RejectsMissingQuery(t *testing.T) {
t.Parallel()
c := NewCrawlerTool()
_, err := c.InvokableRun(context.Background(), `{"url":""}`)
_, err := c.InvokableRun(context.Background(), `{"query":""}`)
if err == nil {
t.Fatal("expected error for empty url")
t.Fatal("expected error for empty query")
}
}
@@ -127,12 +127,29 @@ func TestCrawler_RejectsNonHTTPScheme(t *testing.T) {
t.Parallel()
c := NewCrawlerTool()
_, err := c.InvokableRun(context.Background(), `{"url":"file:///etc/passwd"}`)
_, err := c.InvokableRun(context.Background(), `{"query":"file:///etc/passwd"}`)
if err == nil || !strings.Contains(err.Error(), "scheme") {
t.Fatalf("err = %v, want to reject file:// scheme", err)
}
}
func TestCrawler_AcceptsLegacyURLArgument(t *testing.T) {
t.Parallel()
sentinel := errors.New("stop after legacy url normalization")
c := NewCrawlerTool().WithResolver(func(rawURL string) (string, net.IP, error) {
if rawURL != "https://example.com" {
t.Fatalf("resolver URL = %q, want https://example.com", rawURL)
}
return "example.com", net.ParseIP("93.184.216.34"), sentinel
})
_, err := c.InvokableRun(context.Background(), `{"url":"https://example.com"}`)
if !errors.Is(err, sentinel) {
t.Fatalf("err = %v, want resolver error after accepting legacy url", err)
}
}
func TestCrawler_Info(t *testing.T) {
t.Parallel()
@@ -141,12 +158,33 @@ func TestCrawler_Info(t *testing.T) {
if err != nil {
t.Fatalf("Info: %v", err)
}
if info.Name != "crawler" {
t.Errorf("Name = %q, want crawler", info.Name)
if info.Name != "web_crawler" {
t.Errorf("Name = %q, want web_crawler", info.Name)
}
if !strings.Contains(info.Desc, "text") {
t.Errorf("Desc = %q, want to mention text extraction", info.Desc)
}
if info.ParamsOneOf == nil {
t.Fatal("ParamsOneOf = nil, want schema definition")
}
paramsSchema, err := info.ParamsOneOf.ToJSONSchema()
if err != nil {
t.Fatalf("ToJSONSchema: %v", err)
}
paramsJSON, err := json.Marshal(paramsSchema)
if err != nil {
t.Fatalf("marshal params schema: %v", err)
}
params := string(paramsJSON)
if !strings.Contains(params, `"query"`) {
t.Fatalf("schema missing query parameter: %s", params)
}
if !strings.Contains(params, `"required":["query"]`) {
t.Fatalf("schema does not require query: %s", params)
}
if strings.Contains(params, `"url"`) {
t.Fatalf("schema exposes legacy url parameter: %s", params)
}
}
// jsonString is defined in exesql_test.go (same package).

View File

@@ -29,7 +29,7 @@ import (
type Factory func(params map[string]any) (einotool.BaseTool, error)
var registry = map[string]Factory{
"akshare": noConfig("akshare", func() einotool.BaseTool { return NewAkShareTool() }),
"akshare": buildAkShareTool,
"arxiv": noConfig("arxiv", func() einotool.BaseTool { return NewArxivTool() }),
"bgpt": noConfig("bgpt", func() einotool.BaseTool { return NewBGPTTool() }),
"code_exec": noConfig("code_exec", func() einotool.BaseTool { return NewCodeExecTool() }),
@@ -53,6 +53,7 @@ var registry = map[string]Factory{
"tavily": noConfig("tavily", func() einotool.BaseTool { return NewTavilyTool() }),
"tushare": noConfig("tushare", func() einotool.BaseTool { return NewTushareTool() }),
"wencai": noConfig("wencai", func() einotool.BaseTool { return NewWencaiTool() }),
"web_crawler": noConfig("web_crawler", func() einotool.BaseTool { return NewCrawlerTool() }),
"wikipedia": noConfig("wikipedia", func() einotool.BaseTool { return NewWikipediaTool() }),
"yahoo_finance": noConfig("yahoo_finance", func() einotool.BaseTool { return NewYahooFinanceTool() }),
}
@@ -106,6 +107,24 @@ func BuildAll(names []string, perToolParams map[string]map[string]any) ([]einoto
return tools, nil
}
func buildAkShareTool(params map[string]any) (einotool.BaseTool, error) {
topN := defaultAkShareTopN
if len(params) != 0 {
for key := range params {
if key != "top_n" {
return nil, fmt.Errorf("agent tool: tool %q only accepts node-level param top_n", "akshare")
}
}
if v, ok := intParam(params, "top_n"); ok {
topN = v
}
if topN <= 0 {
return nil, fmt.Errorf("agent tool: tool %q requires positive integer node-level param top_n", "akshare")
}
}
return NewAkShareToolWithTopN(nil, topN), nil
}
func buildExeSQLTool(params map[string]any) (einotool.BaseTool, error) {
conn, err := decodeExeSQLConnParams(params)
if err != nil {

View File

@@ -41,11 +41,14 @@ func TestBuildAll_UnknownTool(t *testing.T) {
}
func TestBuildAll_AllRegisteredTools(t *testing.T) {
// Every key in registry (27 entries, 23 unique canonical tools).
names := []string{
"akshare", "arxiv", "code_exec", "crawler", "deepl", "duckduckgo",
"email", "github", "google", "google_scholar", "jin10", "pubmed",
"qweather", "retrieval", "searxng", "tavily", "tushare", "wencai",
"wikipedia", "yahoo_finance", "execute_sql",
"akshare", "arxiv", "bgpt", "code_exec", "crawler", "deepl",
"duckduckgo", "email", "exesql", "execute_sql", "github", "google",
"google_scholar", "jin10", "keenable", "pubmed", "qweather",
"retrieval", "search_my_dataset", "search_my_dateset", "searxng",
"tavily", "tushare", "web_crawler", "wencai", "wikipedia",
"yahoo_finance",
}
params := map[string]map[string]any{
"execute_sql": {
@@ -57,6 +60,18 @@ func TestBuildAll_AllRegisteredTools(t *testing.T) {
"password": "p",
"max_records": 10,
},
"exesql": {
"db_type": "mysql",
"host": "127.0.0.1",
"port": 3306,
"database": "demo",
"username": "u",
"password": "p",
"max_records": 10,
},
"keenable": {
"api_key": "key-test",
},
}
tools, err := BuildAll(names, params)
if err != nil {
@@ -101,15 +116,16 @@ func TestBuildAll_KeenableRejectsEmptyNodeAPIKey(t *testing.T) {
func TestToolRegistry_SchemasAreComplete(t *testing.T) {
t.Parallel()
// Every entry the registry advertises. 23 names, 22 unique
// Every entry the registry advertises. 27 names, 23 unique
// canonical tools (execute_sql == exesql, retrieval ==
// search_my_dateset).
// search_my_dataset == search_my_dateset, crawler == web_crawler).
names := []string{
"akshare", "arxiv", "code_exec", "crawler", "deepl", "duckduckgo",
"email", "execute_sql", "exesql", "github", "google",
"google_scholar", "jin10", "keenable", "pubmed", "qweather", "retrieval",
"search_my_dateset", "searxng", "tavily", "tushare", "wencai",
"wikipedia", "yahoo_finance",
"akshare", "arxiv", "bgpt", "code_exec", "crawler", "deepl",
"duckduckgo", "email", "execute_sql", "exesql", "github", "google",
"google_scholar", "jin10", "keenable", "pubmed", "qweather",
"retrieval", "search_my_dataset", "search_my_dateset", "searxng",
"tavily", "tushare", "web_crawler", "wencai", "wikipedia",
"yahoo_finance",
}
params := map[string]map[string]any{
"execute_sql": {
@@ -161,14 +177,17 @@ func TestToolRegistry_SchemasAreComplete(t *testing.T) {
}
// Alias consistency: execute_sql and exesql must surface the
// same canonical Info().Name; same for retrieval and
// search_my_dateset. A bug here would mean an alias was
// accidentally pointed at a different tool.
// same canonical Info().Name; same for retrieval/search_my_dataset/
// search_my_dateset and crawler/web_crawler. A bug here would mean
// an alias was accidentally pointed at a different tool.
canonicalByAlias := map[string]string{
"execute_sql": "execute_sql",
"exesql": "execute_sql",
"retrieval": "search_my_dateset",
"search_my_dataset": "search_my_dateset",
"search_my_dateset": "search_my_dateset",
"crawler": "web_crawler",
"web_crawler": "web_crawler",
}
for _, name := range names {
canonical, ok := canonicalByAlias[name]

View File

@@ -161,7 +161,7 @@ func sortedToolNames() []string {
"email", "execute_sql", "exesql", "github", "google", "google_scholar",
"jin10", "pubmed", "qweather", "retrieval", "search_my_dataset",
"search_my_dateset", "searxng", "tavily", "tushare", "wencai",
"wikipedia", "yahoo_finance",
"web_crawler", "wikipedia", "yahoo_finance",
}
sort.Strings(known)
return known