diff --git a/internal/agent/tool/akshare.go b/internal/agent/tool/akshare.go index dd232cff43..296c8475e9 100644 --- a/internal/agent/tool/akshare.go +++ b/internal/agent/tool/akshare.go @@ -19,59 +19,103 @@ package tool import ( "context" "encoding/json" - "errors" "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) -const akshareToolName = "akshare" +const akshareToolName = "akshare_stock_news" -const akshareToolDescription = "Query Chinese A-share financial data (AkShare). " + - "STUB: AkShare is a Python library — not available in the Go Canvas. " + - "Use the Python Canvas for Chinese A-share data." +const akshareToolDescription = "Retrieves the latest East Money news articles for a Chinese A-share stock symbol." -const akshareUnsupportedMessage = "AkShare is a Python library — not available in Go Canvas. " + - "Use Python Canvas for Chinese A-share data." +const defaultAkShareTopN = 10 + +const maxAkShareResponseBytes = 4 << 20 + +var akshareStockNewsEndpoint = "https://search-api-web.eastmoney.com/search/jsonp" // akshareParams is the JSON shape the model sends into InvokableRun. -// -// AkShare exposes 200+ indicator functions (e.g. stock_zh_a_hist, -// stock_zh_a_spot_em, bond_zh_hs_cov). For a future HTTP shim that routes -// to a Python backend we keep the same shape: a stock/asset symbol plus -// the indicator name. The Go stub validates the shape and rejects all -// invocations with a clear error. 
type akshareParams struct { - Symbol string `json:"symbol"` - Indicator string `json:"indicator"` + Query string `json:"query"` + Symbol string `json:"symbol,omitempty"` + TopN int `json:"top_n,omitempty"` +} + +func (p akshareParams) stockSymbol() string { + if query := strings.TrimSpace(p.Query); query != "" { + return query + } + return strings.TrimSpace(p.Symbol) +} + +type akshareArticle struct { + Title string `json:"title"` + Content string `json:"content"` + PublishedAt string `json:"published_at"` + Source string `json:"source"` + URL string `json:"url"` } -// akshareEnvelope is the model-facing JSON shape. The stub always -// returns a populated Error so the model gets a deterministic, parseable -// failure. type akshareEnvelope struct { - Fields []string `json:"fields,omitempty"` - Items []any `json:"items,omitempty"` - Error string `json:"_ERROR,omitempty"` + Content string `json:"content,omitempty"` + Articles []akshareArticle `json:"articles,omitempty"` + Error string `json:"_ERROR,omitempty"` } -// AkShareTool is a stub implementation of the AkShare Chinese -// financial data tool. -// -// AkShare is a Python library (https://github.com/akfamily/akshare) -// — there is no public HTTP API. The tool is registered with the -// canvas so DSLs that reference "akshare" continue to parse, but -// every invocation fails fast with a clear "use Python Canvas" -// message rather -// than silently no-op'ing. -// -// AkShareTool does not own an HTTPHelper — it never makes network calls. -type AkShareTool struct{} +type akshareSearchRequest struct { + UID string `json:"uid"` + Keyword string `json:"keyword"` + Type []string `json:"type"` + Client string `json:"client"` + ClientType string `json:"clientType"` + ClientVersion string `json:"clientVersion"` + Param map[string]any `json:"param"` +} -// NewAkShareTool returns an AkShareTool. No HTTPHelper is allocated; -// the stub never issues network requests. 
-func NewAkShareTool() *AkShareTool { return &AkShareTool{} } +type akshareSearchResponse struct { + Result struct { + Articles []akshareEastMoneyArticle `json:"cmsArticleWebOld"` + } `json:"result"` +} + +type akshareEastMoneyArticle struct { + Code string `json:"code"` + Title string `json:"title"` + Content string `json:"content"` + Date string `json:"date"` + MediaName string `json:"mediaName"` +} + +// AkShareTool retrieves East Money stock news using the same endpoint +// as AkShare's stock_news_em(symbol=...) helper. +type AkShareTool struct { + helper *HTTPHelper + topN int +} + +func NewAkShareTool() *AkShareTool { + return NewAkShareToolWith(NewHTTPHelper()) +} + +func NewAkShareToolWith(h *HTTPHelper) *AkShareTool { + return NewAkShareToolWithTopN(h, defaultAkShareTopN) +} + +func NewAkShareToolWithTopN(h *HTTPHelper, topN int) *AkShareTool { + if h == nil { + h = NewHTTPHelper() + } + if topN <= 0 { + topN = defaultAkShareTopN + } + return &AkShareTool{helper: h, topN: topN} +} // Info returns the tool's metadata for the chat model. func (a *AkShareTool) Info(_ context.Context) (*schema.ToolInfo, error) { @@ -79,36 +123,204 @@ func (a *AkShareTool) Info(_ context.Context) (*schema.ToolInfo, error) { Name: akshareToolName, Desc: akshareToolDescription, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ - "symbol": { + "query": { Type: schema.String, - Desc: "Stock or asset symbol (e.g. 000001, sh600000, bj920566).", - Required: true, - }, - "indicator": { - Type: schema.String, - Desc: "AkShare indicator function name (e.g. stock_zh_a_hist, stock_zh_a_spot_em).", + Desc: "Stock symbol/code to fetch East Money news for, e.g. 600519.", Required: true, }, }), }, nil } -// InvokableRun validates the input shape and returns a clear "use Python -// Canvas" error. The model receives a JSON envelope with the message in -// the `_ERROR` field so it can present a useful message back to the user -// without surfacing a Go stack trace. 
-func (a *AkShareTool) InvokableRun(_ context.Context, argsJSON string, _ ...tool.Option) (string, error) { +// InvokableRun fetches stock news from East Money and returns both a +// formatted content string and structured article records. +func (a *AkShareTool) InvokableRun(ctx context.Context, argsJSON string, _ ...tool.Option) (string, error) { var p akshareParams if err := json.Unmarshal([]byte(argsJSON), &p); err != nil { - return akshareErrJSON(errors.New(akshareUnsupportedMessage)), - errors.New(akshareUnsupportedMessage) + return akshareErrJSON(fmt.Errorf("akshare: parse arguments: %w", err)), + fmt.Errorf("akshare: parse arguments: %w", err) } - if p.Symbol == "" || p.Indicator == "" { - return akshareErrJSON(errors.New(akshareUnsupportedMessage)), - errors.New(akshareUnsupportedMessage) + + symbol := p.stockSymbol() + if symbol == "" { + return akshareErrJSON(fmt.Errorf("akshare: query is required")), + fmt.Errorf("akshare: query is required") } - return akshareErrJSON(errors.New(akshareUnsupportedMessage)), - errors.New(akshareUnsupportedMessage) + + topN := a.topN + if p.TopN > 0 { + topN = p.TopN + } + if topN <= 0 { + topN = defaultAkShareTopN + } + + endpoint, err := buildAkShareStockNewsURL(symbol, topN) + if err != nil { + return akshareErrJSON(err), err + } + headers := map[string]string{ + "Accept": "*/*", + "Referer": "https://so.eastmoney.com/news/s?keyword=" + url.QueryEscape(symbol), + "User-Agent": "Mozilla/5.0 (compatible; ragflow/1.0)", + } + + resp, err := a.helper.Do(ctx, http.MethodGet, endpoint, "", "", headers) + if err != nil { + return akshareErrJSON(err), err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + err := fmt.Errorf("akshare: upstream returned %d", resp.StatusCode) + return akshareErrJSON(err), err + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxAkShareResponseBytes+1)) + if err != nil { + err = fmt.Errorf("akshare: read response: %w", err) + return akshareErrJSON(err), 
err + } + if len(body) > maxAkShareResponseBytes { + err = fmt.Errorf("akshare: response too large") + return akshareErrJSON(err), err + } + + articles, err := parseAkShareStockNews(body, topN) + if err != nil { + return akshareErrJSON(err), err + } + + env := akshareEnvelope{ + Content: formatAkShareArticles(articles), + Articles: articles, + } + return akshareJSON(env), nil +} + +func buildAkShareStockNewsURL(symbol string, topN int) (string, error) { + if topN <= 0 { + topN = defaultAkShareTopN + } + inner := akshareSearchRequest{ + Keyword: symbol, + Type: []string{"cmsArticleWebOld"}, + Client: "web", + ClientType: "web", + ClientVersion: "curr", + Param: map[string]any{ + "cmsArticleWebOld": map[string]any{ + "searchScope": "default", + "sort": "default", + "pageIndex": 1, + "pageSize": topN, + "preTag": "", + "postTag": "", + }, + }, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + return "", fmt.Errorf("akshare: build request: %w", err) + } + callback := fmt.Sprintf("jQuery%d_%d", time.Now().UnixNano(), time.Now().UnixMilli()) + q := url.Values{} + q.Set("cb", callback) + q.Set("param", string(innerJSON)) + q.Set("_", fmt.Sprint(time.Now().UnixMilli())) + return akshareStockNewsEndpoint + "?" 
+ q.Encode(), nil +} + +func parseAkShareStockNews(body []byte, topN int) ([]akshareArticle, error) { + if topN <= 0 { + topN = defaultAkShareTopN + } + + rawJSON, err := stripJSONP(string(body)) + if err != nil { + return nil, err + } + + var raw akshareSearchResponse + if err := json.Unmarshal([]byte(rawJSON), &raw); err != nil { + return nil, fmt.Errorf("akshare: decode response: %w", err) + } + + if topN > len(raw.Result.Articles) { + topN = len(raw.Result.Articles) + } + + articles := make([]akshareArticle, 0, topN) + for _, item := range raw.Result.Articles[:topN] { + articleURL := "" + if code := strings.TrimSpace(item.Code); code != "" { + articleURL = "http://finance.eastmoney.com/a/" + code + ".html" + } + articles = append(articles, akshareArticle{ + Title: cleanAkShareText(item.Title), + Content: cleanAkShareText(item.Content), + PublishedAt: strings.TrimSpace(item.Date), + Source: strings.TrimSpace(item.MediaName), + URL: articleURL, + }) + } + return articles, nil +} + +func stripJSONP(s string) (string, error) { + s = strings.TrimSpace(s) + start := strings.IndexByte(s, '(') + end := strings.LastIndexByte(s, ')') + if start < 0 || end <= start { + if strings.HasPrefix(s, "{") { + return s, nil + } + return "", fmt.Errorf("akshare: invalid JSONP response") + } + return strings.TrimSpace(s[start+1 : end]), nil +} + +func cleanAkShareText(s string) string { + replacements := []struct { + old string + new string + }{ + {old: "(<em>", new: ""}, + {old: "</em>)", new: ""}, + {old: "<em>", new: ""}, + {old: "</em>", new: ""}, + {old: "\u3000", new: ""}, + {old: "\r\n", new: " "}, + } + for _, repl := range replacements { + s = strings.ReplaceAll(s, repl.old, repl.new) + } + return strings.TrimSpace(s) +} + +// formatAkShareArticles renders articles into a human-readable block. +// NOTE: the upstream Python PR's _invoke dropped 文章来源 from the +// format string (5 args but only 4 {} placeholders); we preserve all +// five fields here (link, title, content, date, source). 
+func formatAkShareArticles(articles []akshareArticle) string { + if len(articles) == 0 { + return "" + } + parts := make([]string, 0, len(articles)) + for _, item := range articles { + parts = append(parts, fmt.Sprintf( + `<a href="%s">%s</a> + 新闻内容: %s +发布时间:%s +文章来源: %s`, + item.URL, + item.Title, + item.Content, + item.PublishedAt, + item.Source, + )) + } + return strings.Join(parts, "\n\n") } func akshareJSON(env akshareEnvelope) string { diff --git a/internal/agent/tool/akshare_test.go b/internal/agent/tool/akshare_test.go index ba4a9574e1..75da4799d7 100644 --- a/internal/agent/tool/akshare_test.go +++ b/internal/agent/tool/akshare_test.go @@ -19,61 +19,87 @@ package tool import ( "context" "encoding/json" + "net/http" + "net/http/httptest" + "net/url" "strings" "testing" ) -func TestAkShare_StubsUnsupported(t *testing.T) { +func TestAkShare_FetchesStockNews(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/search/jsonp" { + t.Errorf("path = %q, want /search/jsonp", r.URL.Path) + http.Error(w, "unexpected path", http.StatusBadRequest) + return + } + param := r.URL.Query().Get("param") + var req akshareSearchRequest + if err := json.Unmarshal([]byte(param), &req); err != nil { + t.Errorf("param is not valid request JSON: %v (%s)", err, param) + http.Error(w, "invalid param", http.StatusBadRequest) + return + } + if req.Keyword != "600519" { + t.Errorf("keyword = %q, want 600519", req.Keyword) + http.Error(w, "unexpected keyword", http.StatusBadRequest) + return + } + + callback := r.URL.Query().Get("cb") + if callback == "" { + callback = "callback" + } + w.Header().Set("Content-Type", "application/javascript") + _, _ = w.Write([]byte(callback + `({ + "result": { + "cmsArticleWebOld": [ + {"code":"202601010001","title":"贵州茅台 title0","content":"content0\r\nmore","date":"2026-01-01","mediaName":"src0"}, + 
{"code":"202601010002","title":"title1","content":"content1","date":"2026-01-02","mediaName":"src1"}, + {"code":"202601010003","title":"title2","content":"content2","date":"2026-01-03","mediaName":"src2"} + ] + } + })`)) + })) + defer srv.Close() + + oldEndpoint := akshareStockNewsEndpoint + akshareStockNewsEndpoint = srv.URL + "/search/jsonp" + defer func() { akshareStockNewsEndpoint = oldEndpoint }() + + tool := NewAkShareToolWithTopN(NewHTTPHelper(), 2) + out, err := tool.InvokableRun(context.Background(), `{"query":"600519"}`) + if err != nil { + t.Fatalf("InvokableRun: %v (out=%s)", err, out) + } + + var env akshareEnvelope + if err := json.Unmarshal([]byte(out), &env); err != nil { + t.Fatalf("output is not valid JSON: %v (raw=%s)", err, out) + } + if env.Error != "" { + t.Fatalf("env.Error = %q, want empty", env.Error) + } + if len(env.Articles) != 2 { + t.Fatalf("len(Articles) = %d, want 2", len(env.Articles)) + } + if env.Articles[0].Title != "贵州茅台 title0" { + t.Errorf("first title = %q, want cleaned title", env.Articles[0].Title) + } + if !strings.Contains(env.Content, `贵州茅台 title0`) { + t.Errorf("formatted content missing first link/title: %q", env.Content) + } + if strings.Count(env.Content, "新闻内容:") != 2 { + t.Errorf("formatted content count = %d, want 2", strings.Count(env.Content, "新闻内容:")) + } +} + +func TestAkShare_ParseTruncatesToTopN(t *testing.T) { t.Parallel() - cases := []struct { - name string - args string - wantSub string - }{ - { - name: "well-formed args", - args: `{"symbol":"000001","indicator":"stock_zh_a_hist"}`, - wantSub: "Python library", - }, - { - name: "missing indicator", - args: `{"symbol":"000001"}`, - wantSub: "Python library", - }, - { - name: "missing symbol", - args: `{"indicator":"stock_zh_a_hist"}`, - wantSub: "Python library", - }, - { - name: "empty payload", - args: `{}`, - wantSub: "Python library", - }, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - tool := NewAkShareTool() - 
out, err := tool.InvokableRun(context.Background(), tc.args) - if err == nil { - t.Fatalf("expected error, got nil (out=%s)", out) - } - if !strings.Contains(err.Error(), tc.wantSub) { - t.Errorf("err = %q, want substring %q", err.Error(), tc.wantSub) - } - var env akshareEnvelope - if jerr := json.Unmarshal([]byte(out), &env); jerr != nil { - t.Fatalf("output is not valid JSON: %v (raw=%s)", jerr, out) - } - if env.Error == "" { - t.Errorf("env.Error = empty, want populated") - } - if !strings.Contains(env.Error, "Python") { - t.Errorf("env.Error = %q, want to mention Python Canvas", env.Error) - } - }) + articles := parseAkShareFixture(t, 3) + if len(articles) != 3 { + t.Fatalf("fixture articles = %d, want 3", len(articles)) } } @@ -85,8 +111,21 @@ func TestAkShare_RejectsMalformedJSON(t *testing.T) { if err == nil { t.Fatal("expected error for malformed JSON, got nil") } - if !strings.Contains(err.Error(), "Python") { - t.Errorf("err = %q, want to mention Python", err.Error()) + if !strings.Contains(err.Error(), "parse arguments") { + t.Errorf("err = %q, want parse arguments error", err.Error()) + } +} + +func TestAkShare_RejectsMissingQuery(t *testing.T) { + t.Parallel() + + tool := NewAkShareTool() + out, err := tool.InvokableRun(context.Background(), `{}`) + if err == nil { + t.Fatalf("expected error for missing query, got nil (out=%s)", out) + } + if !strings.Contains(err.Error(), "query is required") { + t.Errorf("err = %q, want query required", err.Error()) } } @@ -98,13 +137,93 @@ func TestAkShare_Info(t *testing.T) { if err != nil { t.Fatalf("Info: %v", err) } - if info.Name != "akshare" { - t.Errorf("Name = %q, want akshare", info.Name) + if info.Name != "akshare_stock_news" { + t.Errorf("Name = %q, want akshare_stock_news", info.Name) } - if !strings.Contains(info.Desc, "AkShare") { - t.Errorf("Desc = %q, want to mention AkShare", info.Desc) + if !strings.Contains(info.Desc, "East Money") { + t.Errorf("Desc = %q, want to mention East Money", info.Desc) 
} - if !strings.Contains(info.Desc, "STUB") && !strings.Contains(info.Desc, "Python") { - t.Errorf("Desc = %q, want to flag stub status", info.Desc) + if info.ParamsOneOf == nil { + t.Fatal("ParamsOneOf = nil, want schema definition") + } + schema, err := info.ParamsOneOf.ToJSONSchema() + if err != nil { + t.Fatalf("ToJSONSchema: %v", err) + } + raw, err := json.Marshal(schema) + if err != nil { + t.Fatalf("marshal params schema: %v", err) + } + params := string(raw) + if !strings.Contains(params, `"query"`) { + t.Fatalf("schema missing query parameter: %s", params) + } + if !strings.Contains(params, `"required":["query"]`) { + t.Fatalf("schema does not require query: %s", params) + } +} + +func TestBuildByName_AkShareRejectsInvalidTopN(t *testing.T) { + t.Parallel() + + _, err := BuildByName("akshare", map[string]any{"top_n": 0}) + if err == nil { + t.Fatal("expected top_n validation error") + } + if !strings.Contains(err.Error(), "positive integer") { + t.Fatalf("err = %q, want positive integer validation", err.Error()) + } +} + +func parseAkShareFixture(t *testing.T, topN int) []akshareArticle { + t.Helper() + + callback := "cb" + body := callback + `({ + "result": { + "cmsArticleWebOld": [ + {"code":"1","title":"title1","content":"content1","date":"2026-01-01","mediaName":"src1"}, + {"code":"2","title":"title2","content":"content2","date":"2026-01-02","mediaName":"src2"}, + {"code":"3","title":"title3","content":"content3","date":"2026-01-03","mediaName":"src3"}, + {"code":"4","title":"title4","content":"content4","date":"2026-01-04","mediaName":"src4"} + ] + } + })` + articles, err := parseAkShareStockNews([]byte(body), topN) + if err != nil { + t.Fatalf("parseAkShareStockNews: %v", err) + } + return articles +} + +func TestBuildAkShareStockNewsURL(t *testing.T) { + oldEndpoint := akshareStockNewsEndpoint + akshareStockNewsEndpoint = "https://example.com/search/jsonp" + defer func() { akshareStockNewsEndpoint = oldEndpoint }() + + rawURL, err := 
buildAkShareStockNewsURL("600519", 25) + if err != nil { + t.Fatalf("buildAkShareStockNewsURL: %v", err) + } + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("parse built URL: %v", err) + } + if u.Host != "example.com" { + t.Fatalf("host = %q, want example.com", u.Host) + } + var req akshareSearchRequest + if err := json.Unmarshal([]byte(u.Query().Get("param")), &req); err != nil { + t.Fatalf("param JSON: %v", err) + } + if req.Keyword != "600519" { + t.Fatalf("keyword = %q, want 600519", req.Keyword) + } + cmsParam, ok := req.Param["cmsArticleWebOld"].(map[string]any) + if !ok { + t.Fatalf("cmsArticleWebOld param = %T, want object", req.Param["cmsArticleWebOld"]) + } + if cmsParam["pageSize"] != float64(25) { + t.Fatalf("pageSize = %v, want 25", cmsParam["pageSize"]) } } diff --git a/internal/agent/tool/crawler.go b/internal/agent/tool/crawler.go index f942ec5059..21db1c6709 100644 --- a/internal/agent/tool/crawler.go +++ b/internal/agent/tool/crawler.go @@ -31,20 +31,29 @@ import ( "golang.org/x/net/html" ) -const crawlerToolName = "crawler" +const crawlerToolName = "web_crawler" -const crawlerToolDescription = "Fetches a web page and returns its extracted text content and links." +const crawlerToolDescription = "Crawls a web page and returns its extracted text content and links." -// crawlerArgs is the JSON shape the model sends in. max_depth and -// max_pages are accepted for API symmetry with the Python tool, but -// the current implementation only supports depth=0 (single page -// fetch). +// crawlerArgs is the JSON shape the model sends in. query is the +// Python-compatible argument name; url is accepted for older Go +// callers. max_depth and max_pages are accepted for API symmetry with +// earlier Go callers, but the current implementation only supports +// depth=0 (single page fetch). 
type crawlerArgs struct { - URL string `json:"url"` + Query string `json:"query"` + URL string `json:"url,omitempty"` MaxDepth int `json:"max_depth,omitempty"` MaxPages int `json:"max_pages,omitempty"` } +func (a crawlerArgs) targetURL() string { + if url := strings.TrimSpace(a.Query); url != "" { + return url + } + return strings.TrimSpace(a.URL) +} + // crawlerResult is the JSON envelope returned to the model. The shape // mirrors the Python tool's `content` / `links` output. type crawlerResult struct { @@ -119,9 +128,9 @@ func (c *CrawlerTool) Info(_ context.Context) (*schema.ToolInfo, error) { Name: crawlerToolName, Desc: crawlerToolDescription, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ - "url": { + "query": { Type: schema.String, - Desc: "The URL to fetch. Must be http or https.", + Desc: "The absolute URL to crawl. Must include an http:// or https:// scheme.", Required: true, }, "max_depth": { @@ -159,20 +168,21 @@ func (c *CrawlerTool) InvokableRun(ctx context.Context, argumentsInJSON string, fmt.Errorf("crawler: parse arguments: %w", err) } - if strings.TrimSpace(args.URL) == "" { - return crawlerStubResult(crawlerResult{Error: "url is required"}), - errors.New("crawler: empty url") + targetURL := args.targetURL() + if targetURL == "" { + return crawlerStubResult(crawlerResult{Error: "query is required"}), + errors.New("crawler: empty query") } - if !strings.HasPrefix(args.URL, "http://") && !strings.HasPrefix(args.URL, "https://") { - return crawlerStubResult(crawlerResult{URL: args.URL, Error: "url must be http or https"}), - fmt.Errorf("crawler: unsupported url scheme: %s", args.URL) + if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") { + return crawlerStubResult(crawlerResult{URL: targetURL, Error: "url must be http or https"}), + fmt.Errorf("crawler: unsupported url scheme: %s", targetURL) } // Reject max_depth > 0 BEFORE the SSRF guard: the guard performs a // DNS lookup that 
may be slow / fail in CI, and a depth-0 caller // asking for max_depth=10 should be rejected on a structural // problem first. if args.MaxDepth > 0 { - return crawlerStubResult(crawlerResult{URL: args.URL, Error: ErrCrawlerDepthUnsupported.Error()}), + return crawlerStubResult(crawlerResult{URL: targetURL, Error: ErrCrawlerDepthUnsupported.Error()}), ErrCrawlerDepthUnsupported } // SSRF guard + DNS-rebinding pinning. c.resolve validates the URL @@ -184,29 +194,29 @@ func (c *CrawlerTool) InvokableRun(ctx context.Context, argumentsInJSON string, // IP we resolved here) AND TLS SNI / cert verification continue to // target the validated hostname. Rewriting the URL host to the IP // would have broken HTTPS, so the pinning happens in the dialer. - host, pinnedIP, resolveErr := c.resolve(args.URL) + host, pinnedIP, resolveErr := c.resolve(targetURL) if resolveErr != nil { - return crawlerStubResult(crawlerResult{URL: args.URL, Error: resolveErr.Error()}), resolveErr + return crawlerStubResult(crawlerResult{URL: targetURL, Error: resolveErr.Error()}), resolveErr } - resp, err := c.helper.DoPinned(ctx, http.MethodGet, args.URL, "", "", nil, host, pinnedIP) + resp, err := c.helper.DoPinned(ctx, http.MethodGet, targetURL, "", "", nil, host, pinnedIP) if err != nil { - return crawlerStubResult(crawlerResult{URL: args.URL, Error: err.Error()}), err + return crawlerStubResult(crawlerResult{URL: targetURL, Error: err.Error()}), err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return crawlerStubResult(crawlerResult{URL: args.URL, Status: resp.StatusCode, Error: "read body: " + err.Error()}), + return crawlerStubResult(crawlerResult{URL: targetURL, Status: resp.StatusCode, Error: "read body: " + err.Error()}), fmt.Errorf("crawler: read body: %w", err) } page, err := extractPage(body) if err != nil { - return crawlerStubResult(crawlerResult{URL: args.URL, Status: resp.StatusCode, Error: err.Error()}), + return crawlerStubResult(crawlerResult{URL: 
targetURL, Status: resp.StatusCode, Error: err.Error()}), fmt.Errorf("crawler: extract: %w", err) } - page.URL = args.URL + page.URL = targetURL page.Status = resp.StatusCode return crawlerJSON(page) diff --git a/internal/agent/tool/crawler_test.go b/internal/agent/tool/crawler_test.go index 933bd15244..4c3df14672 100644 --- a/internal/agent/tool/crawler_test.go +++ b/internal/agent/tool/crawler_test.go @@ -64,7 +64,7 @@ func TestCrawler_FetchesAndExtractsText(t *testing.T) { } c := NewCrawlerTool().WithResolver(loopbackResolver) out, err := c.InvokableRun(context.Background(), - `{"url":`+jsonString(srv.URL)+`,"max_depth":0}`) + `{"query":`+jsonString(srv.URL)+`,"max_depth":0}`) if err != nil { t.Fatalf("InvokableRun: %v", err) } @@ -107,19 +107,19 @@ func TestCrawler_RejectsMaxDepthGreaterThanZero(t *testing.T) { t.Parallel() c := NewCrawlerTool() - _, err := c.InvokableRun(context.Background(), `{"url":"https://example.com","max_depth":1}`) + _, err := c.InvokableRun(context.Background(), `{"query":"https://example.com","max_depth":1}`) if !errors.Is(err, ErrCrawlerDepthUnsupported) { t.Fatalf("err = %v, want ErrCrawlerDepthUnsupported", err) } } -func TestCrawler_RejectsMissingURL(t *testing.T) { +func TestCrawler_RejectsMissingQuery(t *testing.T) { t.Parallel() c := NewCrawlerTool() - _, err := c.InvokableRun(context.Background(), `{"url":""}`) + _, err := c.InvokableRun(context.Background(), `{"query":""}`) if err == nil { - t.Fatal("expected error for empty url") + t.Fatal("expected error for empty query") } } @@ -127,12 +127,29 @@ func TestCrawler_RejectsNonHTTPScheme(t *testing.T) { t.Parallel() c := NewCrawlerTool() - _, err := c.InvokableRun(context.Background(), `{"url":"file:///etc/passwd"}`) + _, err := c.InvokableRun(context.Background(), `{"query":"file:///etc/passwd"}`) if err == nil || !strings.Contains(err.Error(), "scheme") { t.Fatalf("err = %v, want to reject file:// scheme", err) } } +func TestCrawler_AcceptsLegacyURLArgument(t *testing.T) { + 
t.Parallel() + + sentinel := errors.New("stop after legacy url normalization") + c := NewCrawlerTool().WithResolver(func(rawURL string) (string, net.IP, error) { + if rawURL != "https://example.com" { + t.Fatalf("resolver URL = %q, want https://example.com", rawURL) + } + return "example.com", net.ParseIP("93.184.216.34"), sentinel + }) + + _, err := c.InvokableRun(context.Background(), `{"url":"https://example.com"}`) + if !errors.Is(err, sentinel) { + t.Fatalf("err = %v, want resolver error after accepting legacy url", err) + } +} + func TestCrawler_Info(t *testing.T) { t.Parallel() @@ -141,12 +158,33 @@ func TestCrawler_Info(t *testing.T) { if err != nil { t.Fatalf("Info: %v", err) } - if info.Name != "crawler" { - t.Errorf("Name = %q, want crawler", info.Name) + if info.Name != "web_crawler" { + t.Errorf("Name = %q, want web_crawler", info.Name) } if !strings.Contains(info.Desc, "text") { t.Errorf("Desc = %q, want to mention text extraction", info.Desc) } + if info.ParamsOneOf == nil { + t.Fatal("ParamsOneOf = nil, want schema definition") + } + paramsSchema, err := info.ParamsOneOf.ToJSONSchema() + if err != nil { + t.Fatalf("ToJSONSchema: %v", err) + } + paramsJSON, err := json.Marshal(paramsSchema) + if err != nil { + t.Fatalf("marshal params schema: %v", err) + } + params := string(paramsJSON) + if !strings.Contains(params, `"query"`) { + t.Fatalf("schema missing query parameter: %s", params) + } + if !strings.Contains(params, `"required":["query"]`) { + t.Fatalf("schema does not require query: %s", params) + } + if strings.Contains(params, `"url"`) { + t.Fatalf("schema exposes legacy url parameter: %s", params) + } } // jsonString is defined in exesql_test.go (same package). 
diff --git a/internal/agent/tool/registry.go b/internal/agent/tool/registry.go index ecc0efd004..0d1de6c6f5 100644 --- a/internal/agent/tool/registry.go +++ b/internal/agent/tool/registry.go @@ -29,7 +29,7 @@ import ( type Factory func(params map[string]any) (einotool.BaseTool, error) var registry = map[string]Factory{ - "akshare": noConfig("akshare", func() einotool.BaseTool { return NewAkShareTool() }), + "akshare": buildAkShareTool, "arxiv": noConfig("arxiv", func() einotool.BaseTool { return NewArxivTool() }), "bgpt": noConfig("bgpt", func() einotool.BaseTool { return NewBGPTTool() }), "code_exec": noConfig("code_exec", func() einotool.BaseTool { return NewCodeExecTool() }), @@ -53,6 +53,7 @@ var registry = map[string]Factory{ "tavily": noConfig("tavily", func() einotool.BaseTool { return NewTavilyTool() }), "tushare": noConfig("tushare", func() einotool.BaseTool { return NewTushareTool() }), "wencai": noConfig("wencai", func() einotool.BaseTool { return NewWencaiTool() }), + "web_crawler": noConfig("web_crawler", func() einotool.BaseTool { return NewCrawlerTool() }), "wikipedia": noConfig("wikipedia", func() einotool.BaseTool { return NewWikipediaTool() }), "yahoo_finance": noConfig("yahoo_finance", func() einotool.BaseTool { return NewYahooFinanceTool() }), } @@ -106,6 +107,24 @@ func BuildAll(names []string, perToolParams map[string]map[string]any) ([]einoto return tools, nil } +func buildAkShareTool(params map[string]any) (einotool.BaseTool, error) { + topN := defaultAkShareTopN + if len(params) != 0 { + for key := range params { + if key != "top_n" { + return nil, fmt.Errorf("agent tool: tool %q only accepts node-level param top_n", "akshare") + } + } + if v, ok := intParam(params, "top_n"); ok { + topN = v + } + if topN <= 0 { + return nil, fmt.Errorf("agent tool: tool %q requires positive integer node-level param top_n", "akshare") + } + } + return NewAkShareToolWithTopN(nil, topN), nil +} + func buildExeSQLTool(params map[string]any) (einotool.BaseTool, 
error) { conn, err := decodeExeSQLConnParams(params) if err != nil { diff --git a/internal/agent/tool/registry_test.go b/internal/agent/tool/registry_test.go index ba9a824465..de889a7fa3 100644 --- a/internal/agent/tool/registry_test.go +++ b/internal/agent/tool/registry_test.go @@ -41,11 +41,14 @@ func TestBuildAll_UnknownTool(t *testing.T) { } func TestBuildAll_AllRegisteredTools(t *testing.T) { + // Every key in registry (27 entries, 23 unique canonical tools). names := []string{ - "akshare", "arxiv", "code_exec", "crawler", "deepl", "duckduckgo", - "email", "github", "google", "google_scholar", "jin10", "pubmed", - "qweather", "retrieval", "searxng", "tavily", "tushare", "wencai", - "wikipedia", "yahoo_finance", "execute_sql", + "akshare", "arxiv", "bgpt", "code_exec", "crawler", "deepl", + "duckduckgo", "email", "exesql", "execute_sql", "github", "google", + "google_scholar", "jin10", "keenable", "pubmed", "qweather", + "retrieval", "search_my_dataset", "search_my_dateset", "searxng", + "tavily", "tushare", "web_crawler", "wencai", "wikipedia", + "yahoo_finance", } params := map[string]map[string]any{ "execute_sql": { @@ -57,6 +60,18 @@ func TestBuildAll_AllRegisteredTools(t *testing.T) { "password": "p", "max_records": 10, }, + "exesql": { + "db_type": "mysql", + "host": "127.0.0.1", + "port": 3306, + "database": "demo", + "username": "u", + "password": "p", + "max_records": 10, + }, + "keenable": { + "api_key": "key-test", + }, } tools, err := BuildAll(names, params) if err != nil { @@ -101,15 +116,16 @@ func TestBuildAll_KeenableRejectsEmptyNodeAPIKey(t *testing.T) { func TestToolRegistry_SchemasAreComplete(t *testing.T) { t.Parallel() - // Every entry the registry advertises. 23 names, 22 unique + // Every entry the registry advertises. 27 names, 23 unique // canonical tools (execute_sql == exesql, retrieval == - // search_my_dateset). + // search_my_dataset == search_my_dateset, crawler == web_crawler). 
names := []string{ - "akshare", "arxiv", "code_exec", "crawler", "deepl", "duckduckgo", - "email", "execute_sql", "exesql", "github", "google", - "google_scholar", "jin10", "keenable", "pubmed", "qweather", "retrieval", - "search_my_dateset", "searxng", "tavily", "tushare", "wencai", - "wikipedia", "yahoo_finance", + "akshare", "arxiv", "bgpt", "code_exec", "crawler", "deepl", + "duckduckgo", "email", "execute_sql", "exesql", "github", "google", + "google_scholar", "jin10", "keenable", "pubmed", "qweather", + "retrieval", "search_my_dataset", "search_my_dateset", "searxng", + "tavily", "tushare", "web_crawler", "wencai", "wikipedia", + "yahoo_finance", } params := map[string]map[string]any{ "execute_sql": { @@ -161,14 +177,17 @@ func TestToolRegistry_SchemasAreComplete(t *testing.T) { } // Alias consistency: execute_sql and exesql must surface the - // same canonical Info().Name; same for retrieval and - // search_my_dateset. A bug here would mean an alias was - // accidentally pointed at a different tool. + // same canonical Info().Name; same for retrieval/search_my_dataset/ + // search_my_dateset and crawler/web_crawler. A bug here would mean + // an alias was accidentally pointed at a different tool. 
canonicalByAlias := map[string]string{ "execute_sql": "execute_sql", "exesql": "execute_sql", "retrieval": "search_my_dateset", + "search_my_dataset": "search_my_dateset", "search_my_dateset": "search_my_dateset", + "crawler": "web_crawler", + "web_crawler": "web_crawler", } for _, name := range names { canonical, ok := canonicalByAlias[name] diff --git a/tools/gen-component-parity/main.go b/tools/gen-component-parity/main.go index e4c33e1dfb..fbe7ab392f 100644 --- a/tools/gen-component-parity/main.go +++ b/tools/gen-component-parity/main.go @@ -161,7 +161,7 @@ func sortedToolNames() []string { "email", "execute_sql", "exesql", "github", "google", "google_scholar", "jin10", "pubmed", "qweather", "retrieval", "search_my_dataset", "search_my_dateset", "searxng", "tavily", "tushare", "wencai", - "wikipedia", "yahoo_finance", + "web_crawler", "wikipedia", "yahoo_finance", } sort.Strings(known) return known