mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix: SSE write timeout (#15852)
### What problem does this PR solve?

Fixes #15840.

The Go HTTP server sets `WriteTimeout: 120s`, which also applies to long-lived SSE responses. Existing Go streaming handlers did not clear the per-response write deadline, so streams that run longer than the server timeout can be terminated mid-response.

This PR adds a small handler helper that clears the response write deadline for SSE requests and calls it only in existing Go streaming branches:

- conversation completion streaming
- provider chat streaming
- provider transcription streaming
- provider speech streaming

The global server `WriteTimeout` remains unchanged for non-streaming requests.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

### Test plan

- `/root/go/bin/go test ./internal/handler -run TestDisableWriteDeadlineForSSEAllowsLongLivedStream -count=1`
- `/root/go/bin/go test ./internal/handler -count=1`
This commit is contained in:
@@ -285,6 +285,7 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
// Call service
|
||||
if req.Stream != nil && *req.Stream {
|
||||
// Streaming response
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
@@ -967,6 +967,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
// Check if it's a stream request
|
||||
if req.Stream {
|
||||
// Set SSE headers
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -1256,6 +1257,7 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) {
|
||||
// Check if it's a stream request
|
||||
if req.Stream {
|
||||
// Set SSE headers
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -1375,6 +1377,7 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) {
|
||||
// Check if it's a stream request
|
||||
if req.Stream {
|
||||
// Set SSE headers
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
@@ -374,6 +374,7 @@ func (h *SearchBotHandler) Ask(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
@@ -21,17 +21,22 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// mockChunkService implements ChunkRetriever for testing.
|
||||
@@ -637,12 +642,34 @@ func TestAskHandler_MissingKbIDs(t *testing.T) {
|
||||
// fakeStreamingLLM is a test double for the streaming LLM dependency. Its
// ChatStream method replays a fixed list of chunks instead of calling a model.
type fakeStreamingLLM struct {
	// chunks are the response fragments sent on the stream, in order.
	chunks []string
	// err, when non-nil, is returned by ChatStream instead of a stream.
	err error
	// delay, when positive, is the pause inserted before every chunk after
	// the first, so a test can keep the stream open longer than a server
	// write timeout.
	delay time.Duration
}
|
||||
|
||||
func (f *fakeStreamingLLM) ChatStream(_ context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
func (f *fakeStreamingLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
if f.delay > 0 {
|
||||
ch := make(chan string)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for i, chunk := range f.chunks {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-time.After(f.delay):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case ch <- chunk:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
ch := make(chan string, len(f.chunks)+1)
|
||||
for _, c := range f.chunks {
|
||||
ch <- c
|
||||
@@ -680,6 +707,108 @@ func (w *bufferSSEWriter) Write(_ *gin.Context, data string) {
|
||||
}
|
||||
|
||||
// String returns the current contents of w.buf as a string.
func (w *bufferSSEWriter) String() string {
	return w.buf.String()
}
|
||||
|
||||
func setupAskHandlerTenantDB(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get sqlite db: %v", err)
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
if err := db.AutoMigrate(&entity.Tenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate tenant table: %v", err)
|
||||
}
|
||||
|
||||
status := "1"
|
||||
name := "Test Tenant"
|
||||
if err := db.Create(&entity.Tenant{
|
||||
ID: "user-1",
|
||||
Name: &name,
|
||||
LLMID: "test-model",
|
||||
EmbdID: "test-embedding",
|
||||
ASRID: "test-asr",
|
||||
Img2TxtID: "test-image",
|
||||
RerankID: "test-rerank",
|
||||
ParserIDs: "naive",
|
||||
Status: &status,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() {
|
||||
dao.DB = orig
|
||||
_ = sqlDB.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestAskHandler_DisablesWriteDeadlineForSSE(t *testing.T) {
|
||||
setupAskHandlerTenantDB(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"id": "c1", "content_with_weight": "test chunk", "docnm_kwd": "Doc", "kb_id": "kb1", "doc_id": "d1"},
|
||||
},
|
||||
DocAggs: []map[string]interface{}{{"doc_id": "d1", "count": 1}},
|
||||
}}
|
||||
llm := &fakeStreamingLLM{
|
||||
chunks: []string{"first response chunk", "second response chunk"},
|
||||
delay: 120 * time.Millisecond,
|
||||
}
|
||||
h := NewSearchBotHandler(nil, service.NewTenantService(), nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
h.SetAskService(service.NewAskService(ret, nil, 0, 1))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
})
|
||||
router.POST("/api/v1/searchbots/ask", h.Ask)
|
||||
|
||||
server := httptest.NewUnstartedServer(router)
|
||||
server.Config.WriteTimeout = 30 * time.Millisecond
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := server.Client()
|
||||
client.Timeout = time.Second
|
||||
resp, err := client.Post(server.URL+"/api/v1/searchbots/ask", "application/json",
|
||||
strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`))
|
||||
if err != nil {
|
||||
t.Fatalf("post ask stream: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if contentType := resp.Header.Get("Content-Type"); !strings.Contains(contentType, "text/event-stream") {
|
||||
t.Fatalf("expected SSE content type, got %q", contentType)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read ask stream body: %v", err)
|
||||
}
|
||||
|
||||
body := string(bodyBytes)
|
||||
for _, want := range []string{"first response chunk", "second response chunk"} {
|
||||
if !strings.Contains(body, want) {
|
||||
t.Fatalf("stream body missing %q: %q", want, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Ask handler tests ----
|
||||
|
||||
func TestAskHandler_EmptyQuestion(t *testing.T) {
|
||||
|
||||
28
internal/handler/streaming.go
Normal file
28
internal/handler/streaming.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func disableWriteDeadlineForSSE(c *gin.Context) {
|
||||
_ = http.NewResponseController(c.Writer).SetWriteDeadline(time.Time{})
|
||||
}
|
||||
79
internal/handler/streaming_test.go
Normal file
79
internal/handler/streaming_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestDisableWriteDeadlineForSSEAllowsLongLivedStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/stream", func(c *gin.Context) {
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Writer.Flush()
|
||||
|
||||
if _, err := c.Writer.Write([]byte("data: first\n\n")); err != nil {
|
||||
t.Errorf("write first chunk: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
if _, err := c.Writer.Write([]byte("data: second\n\n")); err != nil {
|
||||
t.Errorf("write second chunk: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
})
|
||||
|
||||
server := httptest.NewUnstartedServer(router)
|
||||
server.Config.WriteTimeout = 30 * time.Millisecond
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := server.Client()
|
||||
client.Timeout = time.Second
|
||||
resp, err := client.Get(server.URL + "/stream")
|
||||
if err != nil {
|
||||
t.Fatalf("get stream: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read stream body: %v", err)
|
||||
}
|
||||
|
||||
body := string(bodyBytes)
|
||||
for _, want := range []string{"data: first", "data: second"} {
|
||||
if !strings.Contains(body, want) {
|
||||
t.Fatalf("stream body missing %q: %q", want, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user