diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index 882d3da87b..3c395e88fc 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -285,6 +285,7 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { // Call service if req.Stream != nil && *req.Stream { // Streaming response + disableWriteDeadlineForSSE(c) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 5a13475723..ddf2f05690 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -967,6 +967,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { // Check if it's a stream request if req.Stream { // Set SSE headers + disableWriteDeadlineForSSE(c) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") @@ -1256,6 +1257,7 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) { // Check if it's a stream request if req.Stream { // Set SSE headers + disableWriteDeadlineForSSE(c) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") @@ -1375,6 +1377,7 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) { // Check if it's a stream request if req.Stream { // Set SSE headers + disableWriteDeadlineForSSE(c) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") diff --git a/internal/handler/searchbot.go b/internal/handler/searchbot.go index 72c427c0d9..56ec1eb091 100644 --- a/internal/handler/searchbot.go +++ b/internal/handler/searchbot.go @@ -374,6 +374,7 @@ func (h *SearchBotHandler) Ask(c *gin.Context) { return } + disableWriteDeadlineForSSE(c) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") diff --git a/internal/handler/searchbot_test.go b/internal/handler/searchbot_test.go index 8592b0feac..20602defc1 100644 --- a/internal/handler/searchbot_test.go +++ b/internal/handler/searchbot_test.go @@ -21,17 +21,22 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/http/httptest" "strings" "testing" + "time" "ragflow/internal/common" + "ragflow/internal/dao" "ragflow/internal/entity" modelModule "ragflow/internal/entity/models" "ragflow/internal/service" "github.com/gin-gonic/gin" + "github.com/glebarez/sqlite" + "gorm.io/gorm" ) // mockChunkService implements ChunkRetriever for testing. @@ -637,12 +642,34 @@ func TestAskHandler_MissingKbIDs(t *testing.T) { type fakeStreamingLLM struct { chunks []string err error + delay time.Duration } -func (f *fakeStreamingLLM) ChatStream(_ context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { +func (f *fakeStreamingLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { if f.err != nil { return nil, f.err } + if f.delay > 0 { + ch := make(chan string) + go func() { + defer close(ch) + for i, chunk := range f.chunks { + if i > 0 { + select { + case <-time.After(f.delay): + case <-ctx.Done(): + return + } + } + select { + case ch <- chunk: + case <-ctx.Done(): + return + } + } + }() + return ch, nil + } ch := make(chan string, len(f.chunks)+1) for _, c := range f.chunks { ch <- c @@ -680,6 +707,108 @@ func (w *bufferSSEWriter) Write(_ *gin.Context, data string) { } func (w *bufferSSEWriter) String() string { return w.buf.String() } + +func setupAskHandlerTenantDB(t *testing.T) { + t.Helper() + + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ + TranslateError: true, + }) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + sqlDB, err := db.DB() + if err != nil { + t.Fatalf("failed to get sqlite db: %v", err) + } + sqlDB.SetMaxOpenConns(1) + + if err := db.AutoMigrate(&entity.Tenant{}); err != nil { + t.Fatalf("failed to migrate tenant table: %v", err) + } + + status := "1" + name := "Test Tenant" + if err := db.Create(&entity.Tenant{ + ID: "user-1", + Name: &name, + LLMID: "test-model", + EmbdID: "test-embedding", + ASRID: "test-asr", + Img2TxtID: "test-image", + RerankID: "test-rerank", + ParserIDs: "naive", + Status: &status, + }).Error; err != nil { + t.Fatalf("failed to create tenant: %v", err) + } + + orig := dao.DB + dao.DB = db + t.Cleanup(func() { + dao.DB = orig + _ = sqlDB.Close() + }) +} + +func TestAskHandler_DisablesWriteDeadlineForSSE(t *testing.T) { + setupAskHandlerTenantDB(t) + gin.SetMode(gin.TestMode) + + ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{ + Chunks: []map[string]interface{}{ + {"id": "c1", "content_with_weight": "test chunk", "docnm_kwd": "Doc", "kb_id": "kb1", "doc_id": "d1"}, + }, + DocAggs: []map[string]interface{}{{"doc_id": "d1", "count": 1}}, + }} + llm := &fakeStreamingLLM{ + chunks: []string{"first response chunk", "second response chunk"}, + delay: 120 * time.Millisecond, + } + h := NewSearchBotHandler(nil, service.NewTenantService(), nil, ret) + h.SetStreamLLM(llm) + h.SetAskService(service.NewAskService(ret, nil, 0, 1)) + + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user-1"}) + }) + router.POST("/api/v1/searchbots/ask", h.Ask) + + server := httptest.NewUnstartedServer(router) + server.Config.WriteTimeout = 30 * time.Millisecond + server.Start() + defer server.Close() + + client := server.Client() + client.Timeout = time.Second + resp, err := client.Post(server.URL+"/api/v1/searchbots/ask", "application/json", + strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`)) + if err != nil { + t.Fatalf("post ask stream: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + if contentType := resp.Header.Get("Content-Type"); !strings.Contains(contentType, "text/event-stream") { + t.Fatalf("expected SSE content type, got %q", contentType) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read ask stream body: %v", err) + } + + body := string(bodyBytes) + for _, want := range []string{"first response chunk", "second response chunk"} { + if !strings.Contains(body, want) { + t.Fatalf("stream body missing %q: %q", want, body) + } + } +} + // ---- Ask handler tests ---- func TestAskHandler_EmptyQuestion(t *testing.T) { diff --git a/internal/handler/streaming.go b/internal/handler/streaming.go new file mode 100644 index 0000000000..b59f8e30a6 --- /dev/null +++ b/internal/handler/streaming.go @@ -0,0 +1,28 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" +) + +func disableWriteDeadlineForSSE(c *gin.Context) { + _ = http.NewResponseController(c.Writer).SetWriteDeadline(time.Time{}) +} diff --git a/internal/handler/streaming_test.go b/internal/handler/streaming_test.go new file mode 100644 index 0000000000..1237a7abb5 --- /dev/null +++ b/internal/handler/streaming_test.go @@ -0,0 +1,79 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func TestDisableWriteDeadlineForSSEAllowsLongLivedStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.GET("/stream", func(c *gin.Context) { + disableWriteDeadlineForSSE(c) + c.Header("Content-Type", "text/event-stream") + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + if _, err := c.Writer.Write([]byte("data: first\n\n")); err != nil { + t.Errorf("write first chunk: %v", err) + return + } + c.Writer.Flush() + + time.Sleep(120 * time.Millisecond) + + if _, err := c.Writer.Write([]byte("data: second\n\n")); err != nil { + t.Errorf("write second chunk: %v", err) + return + } + c.Writer.Flush() + }) + + server := httptest.NewUnstartedServer(router) + server.Config.WriteTimeout = 30 * time.Millisecond + server.Start() + defer server.Close() + + client := server.Client() + client.Timeout = time.Second + resp, err := client.Get(server.URL + "/stream") + if err != nil { + t.Fatalf("get stream: %v", err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read stream body: %v", err) + } + + body := string(bodyBytes) + for _, want := range []string{"data: first", "data: second"} { + if !strings.Contains(body, want) { + t.Fatalf("stream body missing %q: %q", want, body) + } + } +}