fix: SSE write timeout (#15852)

### What problem does this PR solve?

Fixes #15840.

The Go HTTP server sets `WriteTimeout: 120s`, which also applies to
long-lived SSE responses. Existing Go streaming handlers did not clear
the per-response write deadline, so streams that run longer than the
server timeout can be terminated mid-response.

This PR adds a small handler helper that clears the response write
deadline for SSE requests and calls it only in existing Go streaming
branches:

- conversation completion streaming
- provider chat streaming
- provider transcription streaming
- provider speech streaming

The global server `WriteTimeout` remains unchanged for non-streaming
requests.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

### Test plan

- `/root/go/bin/go test ./internal/handler -run
TestDisableWriteDeadlineForSSEAllowsLongLivedStream -count=1`
- `/root/go/bin/go test ./internal/handler -count=1`
This commit is contained in:
bitloi
2026-06-12 09:49:34 -03:00
committed by GitHub
parent 234f1b7cff
commit cafa0f2e4f
6 changed files with 242 additions and 1 deletions

View File

@@ -285,6 +285,7 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
// Call service
if req.Stream != nil && *req.Stream {
// Streaming response
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")

View File

@@ -967,6 +967,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
// Check if it's a stream request
if req.Stream {
// Set SSE headers
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -1256,6 +1257,7 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) {
// Check if it's a stream request
if req.Stream {
// Set SSE headers
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -1375,6 +1377,7 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) {
// Check if it's a stream request
if req.Stream {
// Set SSE headers
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")

View File

@@ -374,6 +374,7 @@ func (h *SearchBotHandler) Ask(c *gin.Context) {
return
}
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")

View File

@@ -21,17 +21,22 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
// mockChunkService implements ChunkRetriever for testing.
@@ -637,12 +642,34 @@ func TestAskHandler_MissingKbIDs(t *testing.T) {
type fakeStreamingLLM struct {
chunks []string
err error
delay time.Duration
}
func (f *fakeStreamingLLM) ChatStream(_ context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
func (f *fakeStreamingLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
if f.err != nil {
return nil, f.err
}
if f.delay > 0 {
ch := make(chan string)
go func() {
defer close(ch)
for i, chunk := range f.chunks {
if i > 0 {
select {
case <-time.After(f.delay):
case <-ctx.Done():
return
}
}
select {
case ch <- chunk:
case <-ctx.Done():
return
}
}
}()
return ch, nil
}
ch := make(chan string, len(f.chunks)+1)
for _, c := range f.chunks {
ch <- c
@@ -680,6 +707,108 @@ func (w *bufferSSEWriter) Write(_ *gin.Context, data string) {
}
func (w *bufferSSEWriter) String() string { return w.buf.String() }
func setupAskHandlerTenantDB(t *testing.T) {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
TranslateError: true,
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("failed to get sqlite db: %v", err)
}
sqlDB.SetMaxOpenConns(1)
if err := db.AutoMigrate(&entity.Tenant{}); err != nil {
t.Fatalf("failed to migrate tenant table: %v", err)
}
status := "1"
name := "Test Tenant"
if err := db.Create(&entity.Tenant{
ID: "user-1",
Name: &name,
LLMID: "test-model",
EmbdID: "test-embedding",
ASRID: "test-asr",
Img2TxtID: "test-image",
RerankID: "test-rerank",
ParserIDs: "naive",
Status: &status,
}).Error; err != nil {
t.Fatalf("failed to create tenant: %v", err)
}
orig := dao.DB
dao.DB = db
t.Cleanup(func() {
dao.DB = orig
_ = sqlDB.Close()
})
}
func TestAskHandler_DisablesWriteDeadlineForSSE(t *testing.T) {
setupAskHandlerTenantDB(t)
gin.SetMode(gin.TestMode)
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{
Chunks: []map[string]interface{}{
{"id": "c1", "content_with_weight": "test chunk", "docnm_kwd": "Doc", "kb_id": "kb1", "doc_id": "d1"},
},
DocAggs: []map[string]interface{}{{"doc_id": "d1", "count": 1}},
}}
llm := &fakeStreamingLLM{
chunks: []string{"first response chunk", "second response chunk"},
delay: 120 * time.Millisecond,
}
h := NewSearchBotHandler(nil, service.NewTenantService(), nil, ret)
h.SetStreamLLM(llm)
h.SetAskService(service.NewAskService(ret, nil, 0, 1))
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("user", &entity.User{ID: "user-1"})
})
router.POST("/api/v1/searchbots/ask", h.Ask)
server := httptest.NewUnstartedServer(router)
server.Config.WriteTimeout = 30 * time.Millisecond
server.Start()
defer server.Close()
client := server.Client()
client.Timeout = time.Second
resp, err := client.Post(server.URL+"/api/v1/searchbots/ask", "application/json",
strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`))
if err != nil {
t.Fatalf("post ask stream: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.StatusCode)
}
if contentType := resp.Header.Get("Content-Type"); !strings.Contains(contentType, "text/event-stream") {
t.Fatalf("expected SSE content type, got %q", contentType)
}
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read ask stream body: %v", err)
}
body := string(bodyBytes)
for _, want := range []string{"first response chunk", "second response chunk"} {
if !strings.Contains(body, want) {
t.Fatalf("stream body missing %q: %q", want, body)
}
}
}
// ---- Ask handler tests ----
func TestAskHandler_EmptyQuestion(t *testing.T) {

View File

@@ -0,0 +1,28 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
)
func disableWriteDeadlineForSSE(c *gin.Context) {
_ = http.NewResponseController(c.Writer).SetWriteDeadline(time.Time{})
}

View File

@@ -0,0 +1,79 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
)
func TestDisableWriteDeadlineForSSEAllowsLongLivedStream(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.GET("/stream", func(c *gin.Context) {
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Writer.WriteHeader(http.StatusOK)
c.Writer.Flush()
if _, err := c.Writer.Write([]byte("data: first\n\n")); err != nil {
t.Errorf("write first chunk: %v", err)
return
}
c.Writer.Flush()
time.Sleep(120 * time.Millisecond)
if _, err := c.Writer.Write([]byte("data: second\n\n")); err != nil {
t.Errorf("write second chunk: %v", err)
return
}
c.Writer.Flush()
})
server := httptest.NewUnstartedServer(router)
server.Config.WriteTimeout = 30 * time.Millisecond
server.Start()
defer server.Close()
client := server.Client()
client.Timeout = time.Second
resp, err := client.Get(server.URL + "/stream")
if err != nil {
t.Fatalf("get stream: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read stream body: %v", err)
}
body := string(bodyBytes)
for _, want := range []string{"data: first", "data: second"} {
if !strings.Contains(body, want) {
t.Fatalf("stream body missing %q: %q", want, body)
}
}
}