Files

494 lines
15 KiB
Go
Raw Permalink Normal View History

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package elasticsearch
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"sync"
"testing"
"time"
"ragflow/internal/tokenizer"
"github.com/elastic/go-elasticsearch/v8"
)
// capturedRequest holds the request body the test server saw, for
// assertions.
type capturedRequest struct {
mu sync.Mutex
path string
body string
method string
}
// newCapturingServer returns an httptest.Server that captures each
// incoming request and replies with the given body / status.
func newCapturingServer(t *testing.T, replyStatus int, replyBody string) (*httptest.Server, *capturedRequest) {
t.Helper()
cap := &capturedRequest{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
cap.mu.Lock()
cap.method = r.Method
cap.path = r.URL.Path
cap.body = string(body)
cap.mu.Unlock()
w.Header().Set("X-Elastic-Product", "Elasticsearch")
w.WriteHeader(replyStatus)
_, _ = w.Write([]byte(replyBody))
}))
t.Cleanup(srv.Close)
return srv, cap
}
// newTestEngine constructs an elasticsearchEngine pointing at the given
// test server. Bypasses NewEngine (which calls ES Info to verify
// connectivity) — the test server is a stub, not a real ES cluster.
func newTestEngine(t *testing.T, srvURL string) *elasticsearchEngine {
t.Helper()
client, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{srvURL},
})
if err != nil {
t.Fatalf("elasticsearch.NewClient: %v", err)
}
return &elasticsearchEngine{client: client}
}
const sampleESResponse = `{
"columns": [
{"name": "doc_id", "type": "text"},
{"name": "docnm", "type": "text"},
{"name": "count", "type": "long"}
],
"rows": [
["d1", "report.pdf", 5],
["d2", "spec.pdf", 3]
]
}`
// TestRunSQL_NoFilterAdded verifies the request body is exactly
// {"query": <sql>} — the redundant `filter` field that the previous
// implementation added is gone. (service.addKBFilter is the source of
// truth for kb_id scoping upstream of RunSQL.)
func TestRunSQL_NoFilterAdded(t *testing.T) {
srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse)
e := newTestEngine(t, srv.URL)
rows, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT doc_id FROM ragflow_t1", nil, "json")
if err != nil {
t.Fatalf("RunSQL: %v", err)
}
if len(rows) != 2 {
t.Fatalf("rows: got %d, want 2", len(rows))
}
cap.mu.Lock()
got := cap.body
cap.mu.Unlock()
var body map[string]interface{}
if err := json.Unmarshal([]byte(got), &body); err != nil {
t.Fatalf("body is not JSON: %v\nbody=%q", err, got)
}
if _, has := body["filter"]; has {
t.Errorf("RunSQL request must NOT include top-level filter (addKBFilter is the source of truth upstream). body=%v", body)
}
if _, has := body["query"]; !has {
t.Errorf("RunSQL request must include query. body=%v", body)
}
}
// TestRunSQL_WhitespaceNormalizedAndPercentStripped verifies the Python
// preprocessing step `re.sub(r"[ `]+", " ", sql)` + `sql.replace("%", "")`
// is applied. Without these, the LLM-generated SQL with stray backticks
// or `%` characters (e.g. from JSON decoding glitches) would fail to
// parse in ES.
func TestRunSQL_WhitespaceNormalizedAndPercentStripped(t *testing.T) {
srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse)
e := newTestEngine(t, srv.URL)
// Input SQL has multiple backticks/spaces and trailing % characters.
in := "SELECT doc_id FROM `ragflow_t1` WHERE count > 0 %"
_, err := e.RunSQL(context.Background(), "ragflow_t1", in, nil, "json")
if err != nil {
t.Fatalf("RunSQL: %v", err)
}
cap.mu.Lock()
got := cap.body
cap.mu.Unlock()
var body map[string]interface{}
if err := json.Unmarshal([]byte(got), &body); err != nil {
t.Fatalf("body is not JSON: %v\nbody=%q", err, got)
}
q, _ := body["query"].(string)
if strings.Contains(q, " ") {
t.Errorf("query still has multiple spaces (whitespace not normalized): %q", q)
}
if strings.Contains(q, "`") {
t.Errorf("query still has backticks (whitespace+backtick regex not applied): %q", q)
}
if strings.Contains(q, "%") {
t.Errorf("query still has %% (percent strip not applied): %q", q)
}
}
// TestRunSQL_PerAttemptTimeout verifies the derived context has a 2s
// deadline. We send a hanging response from the test server and assert
// the call returns well before 30s (the Go ES client's default
// transport-level timeout). With the retry loop in place, the total
// time is 2s (first attempt) + 3s (sleep) + 2s (second attempt) = ~7s.
func TestRunSQL_PerAttemptTimeout(t *testing.T) {
hang := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-hang
}))
t.Cleanup(func() {
close(hang)
srv.Close()
})
e := newTestEngine(t, srv.URL)
start := time.Now()
_, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json")
elapsed := time.Since(start)
if err == nil {
t.Fatalf("RunSQL: got nil error, want timeout error")
}
// 2s + 3s + 2s = 7s; allow generous upper bound for the test runner.
if elapsed < 6*time.Second {
t.Errorf("RunSQL returned in %s; expected ~7s (2 attempts + 3s sleep)", elapsed)
}
if elapsed > 10*time.Second {
t.Errorf("RunSQL took %s; expected ~7s, looks like the retry didn't fire", elapsed)
}
// The final error should mention timeout + 2 attempts.
if !strings.Contains(err.Error(), "timeout after 2 attempts") {
t.Errorf("err: got %q, want substring %q", err.Error(), "timeout after 2 attempts")
}
}
// TestRunSQL_RetryOnTimeoutThenSucceed simulates Python's
// ConnectionTimeout-retry pattern: the first attempt times out, the
// second attempt returns valid rows. The loop should silently retry
// and return the rows.
func TestRunSQL_RetryOnTimeoutThenSucceed(t *testing.T) {
var (
mu sync.Mutex
calls int
)
release := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
calls++
attempt := calls
mu.Unlock()
if attempt == 1 {
// First attempt: hang so the 2s context fires.
select {
case <-release:
case <-r.Context().Done():
}
return
}
w.Header().Set("X-Elastic-Product", "Elasticsearch")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(sampleESResponse))
}))
t.Cleanup(func() {
close(release)
srv.Close()
})
e := newTestEngine(t, srv.URL)
rows, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json")
if err != nil {
t.Fatalf("RunSQL: %v", err)
}
if len(rows) != 2 {
t.Errorf("rows: got %d, want 2 (second attempt should succeed)", len(rows))
}
mu.Lock()
defer mu.Unlock()
if calls != 2 {
t.Errorf("server calls: got %d, want 2 (initial + one retry)", calls)
}
}
// TestRunSQL_NonTimeoutErrorSurfacesImmediately verifies the non-retry
// path: a 4xx ES response should NOT trigger a retry. The error must
// be wrapped as `SQL error: <e>\n\nSQL: <sql>`, matching Python's
// es_conn_base.py:400.
func TestRunSQL_NonTimeoutErrorSurfacesImmediately(t *testing.T) {
var (
mu sync.Mutex
calls int
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
calls++
mu.Unlock()
w.Header().Set("X-Elastic-Product", "Elasticsearch")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error": "syntax error"}`))
}))
t.Cleanup(srv.Close)
e := newTestEngine(t, srv.URL)
_, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT bad", nil, "json")
if err == nil {
t.Fatalf("RunSQL: got nil error, want error")
}
mu.Lock()
defer mu.Unlock()
if calls != 1 {
t.Errorf("server calls: got %d, want 1 (non-timeout error must NOT retry)", calls)
}
// Python wraps as `f"SQL error: {e}\n\nSQL: {sql}"`.
if !strings.Contains(err.Error(), "SQL error:") {
t.Errorf("err: got %q, want substring 'SQL error:'", err.Error())
}
if !strings.Contains(err.Error(), "SQL: SELECT bad") {
t.Errorf("err: got %q, want substring 'SQL: SELECT bad'", err.Error())
}
}
// TestRunSQL_RequestBodyHasFetchSizeAndFormat verifies the request body
// includes fetch_size=128 and the SQLQueryRequest is built with
// format="json", matching the Python defaults at rag/nlp/search.py:773.
func TestRunSQL_RequestBodyHasFetchSizeAndFormat(t *testing.T) {
srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse)
e := newTestEngine(t, srv.URL)
if _, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json"); err != nil {
t.Fatalf("RunSQL: %v", err)
}
cap.mu.Lock()
got := cap.body
cap.mu.Unlock()
var body map[string]interface{}
if err := json.Unmarshal([]byte(got), &body); err != nil {
t.Fatalf("body is not JSON: %v\nbody=%q", err, got)
}
fs, ok := body["fetch_size"]
if !ok {
t.Errorf("body has no fetch_size; got %v", body)
}
if fmt.Sprint(fs) != "128" {
t.Errorf("fetch_size: got %v, want 128", fs)
}
}
// TestRunSQL_EmptyRowsReturnsNilNil verifies the (nil, nil) sentinel
// for empty results — callers treat this as "fall through to vector
// retrieval".
func TestRunSQL_EmptyRowsReturnsNilNil(t *testing.T) {
empty := `{"columns": [{"name": "doc_id", "type": "text"}], "rows": []}`
srv, _ := newCapturingServer(t, http.StatusOK, empty)
e := newTestEngine(t, srv.URL)
rows, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT doc_id FROM ragflow_t1", nil, "json")
if err != nil {
t.Fatalf("RunSQL: %v", err)
}
if rows != nil {
t.Errorf("rows: got %v, want nil (empty-rows sentinel)", rows)
}
}
// TestRunSQL_PostsToSQLPath verifies the request goes to the /_sql
// endpoint (the modern ES SQL API; the older /_xpack/sql path is
// deprecated as of ES 7.x). The Go SDK's esapi.SQLQueryRequest hits
// /_sql; the Python ES client is also pinned to the modern endpoint
// at runtime even though the legacy /_xpack/sql name appears in the
// SDK's method (`es.sql.query(...)`).
func TestRunSQL_PostsToSQLPath(t *testing.T) {
srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse)
e := newTestEngine(t, srv.URL)
if _, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json"); err != nil {
t.Fatalf("RunSQL: %v", err)
}
cap.mu.Lock()
got := cap.path
cap.mu.Unlock()
if got != "/_sql" {
t.Errorf("path: got %q, want /_sql", got)
}
}
// TestMain registers the engine as "infinity" so tokenizer.Tokenize and
// tokenizer.FineGrainedTokenize short-circuit and return the input
// as-is. This lets the rewrite tests assert on the SHAPE of the MATCH()
// substitution without depending on a real tokenizer pool.
func TestMain(m *testing.M) {
tokenizer.RegisterEngineType(func() string { return "infinity" })
m.Run()
}
func TestPreprocess_WhitespaceAndBackticks(t *testing.T) {
cases := []struct {
in, want string
}{
{"a b", "a b"},
{"a b c", "a b c"},
{"a`b`c", "a b c"},
{"a `` b", "a b"},
{" leading and trailing ", " leading and trailing "},
}
for _, c := range cases {
if got := Preprocess(c.in); got != c.want {
t.Errorf("Preprocess(%q) = %q, want %q", c.in, got, c.want)
}
}
}
func TestPreprocess_StripsPercent(t *testing.T) {
cases := []struct {
in, want string
}{
{"count > 0 %", "count > 0 "},
{"100% match", "100 match"},
{"%%%", ""},
}
for _, c := range cases {
if got := Preprocess(c.in); got != c.want {
t.Errorf("Preprocess(%q) = %q, want %q", c.in, got, c.want)
}
}
}
func TestPreprocess_LktksRewrite(t *testing.T) {
cases := []struct {
name string
in string
field string
expect string
}{
{
"like with single-token value (ltks suffix)",
"select content_ltks like 'weather'",
"content_ltks",
"MATCH(content_ltks,",
},
{
"= with multi-word value (ltks suffix)",
"select content_ltks = 'final report'",
"content_ltks",
"MATCH(content_ltks,",
},
{
"tks (no l) suffix",
"select title_tks = 'hello'",
"title_tks",
"MATCH(title_tks,",
},
{
"leading-space anchor: no leading space means no match (mirrors Python regex)",
"content_ltks like 'weather'",
"content_ltks",
"content_ltks like 'weather'",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := Preprocess(c.in)
isAnchorTest := c.expect == c.in
if isAnchorTest {
if got != c.in {
t.Errorf("Preprocess(%q) = %q, want unchanged (leading-space anchor should prevent match)", c.in, got)
}
return
}
if strings.Contains(got, c.field+" ") {
pattern := regexp.MustCompile(c.field + `( like | ?= ?)`)
if pattern.MatchString(got) {
t.Errorf("Preprocess(%q) = %q, still contains the original `<field> like/=` pattern", c.in, got)
}
}
if !strings.Contains(got, c.expect) {
t.Errorf("Preprocess(%q) = %q, want substring %q", c.in, got, c.expect)
}
if !strings.Contains(got, "minimum_should_match=30") {
t.Errorf("Preprocess(%q) = %q, want substring minimum_should_match=30", c.in, got)
}
})
}
}
func TestPreprocess_NoMatchLeavesSQLAlone(t *testing.T) {
in := "SELECT doc_id FROM ragflow_t1"
got := Preprocess(in)
if got != in {
t.Errorf("Preprocess(%q) = %q, want unchanged", in, got)
}
}
// fakeNetTimeoutErr implements net.Error with Timeout()==true.
type fakeNetTimeoutErr struct{}
func (fakeNetTimeoutErr) Error() string { return "i/o timeout" }
func (fakeNetTimeoutErr) Timeout() bool { return true }
func (fakeNetTimeoutErr) Temporary() bool { return true }
func TestIsTimeoutError(t *testing.T) {
cases := []struct {
name string
err error
want bool
}{
{"nil error", nil, false},
{"context.DeadlineExceeded", context.DeadlineExceeded, true},
{"wrapped context.DeadlineExceeded", fmt.Errorf("wrap: %w", context.DeadlineExceeded), true},
{"net.Error.Timeout()==true", fakeNetTimeoutErr{}, true},
{"wrapped net.Error.Timeout", fmt.Errorf("wrap: %w", fakeNetTimeoutErr{}), true},
{"plain string 'i/o timeout'", errors.New("read tcp: i/o timeout"), true},
{"plain string 'deadline exceeded'", errors.New("context deadline exceeded"), true},
{"plain string 'connection timeout'", errors.New("connection timeout while reading"), true},
{"unrelated error", errors.New("parse: invalid character"), false},
{"EOF is not a timeout", errors.New("EOF"), false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
if got := isTimeoutError(c.err); got != c.want {
t.Errorf("isTimeoutError(%v) = %v, want %v", c.err, got, c.want)
}
})
}
}
func TestIsTimeoutError_NonTimeoutNetError(t *testing.T) {
e := &net.OpError{
Op: "dial",
Err: errors.New("connection refused"),
}
if isTimeoutError(e) {
t.Errorf("isTimeoutError(connection-refused) = true, want false")
}
}