mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 01:29:35 +08:00
### Summary Closes #15381 Every provider in `internal/entity/models/` reads its streaming response with `bufio.NewScanner(resp.Body)` and iterates over `scanner.Scan()`. The default `bufio.Scanner` maximum token size is 64KB, so when an upstream sends a single SSE `data:` line larger than 64KB (long content deltas, large tool or function call argument blobs, bundled `reasoning_content`, or providers that emit a whole message in one event) `scanner.Scan()` returns `false` and `scanner.Err()` returns `bufio.ErrTooLong`. Streaming chat then ends with an error partway through the response. This change adds `scanner.Buffer(make([]byte, 64*1024), 1024*1024)` immediately after every SSE scanner that was still bare, raising the cap to 1MB. 1MB is the value already used for streaming chat in `openai.go`, `modelscope.go`, `groq.go`, `mistral.go`, `xai.go` and the other already patched providers (the 8MB cap in the repo is reserved for TTS and embedding paths), so this simply converges the remaining providers onto the established pattern. Nothing else changes: line parsing, `data:` prefix handling, `[DONE]` detection, JSON unmarshalling, error handling, and the existing `scanner.Err()` checks all stay the same. Providers covered (23 scanners across 22 files): 302ai, aliyun, baichuan, baidu, cohere, deepinfra, deepseek, gitee, huggingface, lmstudio, minimax (the chat scanner, whose TTS scanner was already bumped), moonshot, nvidia, ollama, openrouter, orcarouter, paddleocr, siliconflow, tokenhub, vllm, volcengine, xunfei, zhipu-ai. `jiekouai.go` is excluded because it is covered by the in flight #15337. A table driven regression test (`sse_scanner_buffer_test.go`) streams a single 128KB `data:` content delta followed by `data: [DONE]` through an `httptest` server and asserts that `ChatStreamlyWithSender` delivers the full content with no error across a representative subset of providers. Without the buffer fix the test fails with `bufio.Scanner: token too long`. 
This PR also removes three duplicate declarations of the package-level `roundTripperFunc` test helper, which several recently merged provider PRs had each added independently and which left the `internal/entity/models` test package unable to compile. The helper now lives in a single place and is shared. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
302 lines
8.8 KiB
Go
302 lines
8.8 KiB
Go
package models
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type PaddleOCRModel struct {
|
|
BaseURL map[string]string
|
|
URLSuffix URLSuffix
|
|
httpClient *http.Client
|
|
}
|
|
|
|
func NewPaddleOCRModel(baseURL map[string]string, urlSuffix URLSuffix) *PaddleOCRModel {
|
|
return &PaddleOCRModel{
|
|
BaseURL: baseURL,
|
|
URLSuffix: urlSuffix,
|
|
httpClient: &http.Client{
|
|
Timeout: 120 * time.Second,
|
|
Transport: &http.Transport{
|
|
MaxIdleConns: 100,
|
|
MaxIdleConnsPerHost: 10,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
DisableCompression: false,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (p PaddleOCRModel) NewInstance(baseURL map[string]string) ModelDriver {
|
|
return &PaddleOCRModel{
|
|
BaseURL: baseURL,
|
|
URLSuffix: p.URLSuffix,
|
|
httpClient: &http.Client{
|
|
Timeout: 120 * time.Second,
|
|
Transport: &http.Transport{
|
|
MaxIdleConns: 100,
|
|
MaxIdleConnsPerHost: 10,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
DisableCompression: false,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (p *PaddleOCRModel) Name() string {
|
|
return "paddle_ocr.net"
|
|
}
|
|
|
|
func (p *PaddleOCRModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
// paddleSubmitResponse is the subset of the job-submission response body that
// OCRFile reads: the identifier of the newly created job.
type paddleSubmitResponse struct {
	Data struct {
		// JobId is appended to the submit URL to form the polling URL.
		JobId string `json:"jobId"`
	} `json:"data"`
}
|
|
|
|
// paddlePollResponse is the subset of the job-status response body that
// OCRFile reads while polling a submitted job.
type paddlePollResponse struct {
	Data struct {
		// State is the job state; OCRFile stops polling on "done" or "failed"
		// and keeps polling on any other value.
		State string `json:"state"`
		// ErrorMsg is reported back to the caller when State is "failed".
		ErrorMsg string `json:"errorMsg"`
		// ResultUrl carries the location of the result file of a finished job.
		ResultUrl struct {
			// JsonUrl is the URL of the JSONL result that OCRFile downloads.
			JsonUrl string `json:"jsonUrl"`
		} `json:"resultUrl"`
	} `json:"data"`
}
|
|
|
|
// paddleJsonlLine is the subset of one line of the JSONL result file that
// OCRFile reads: the markdown text of each layout parsing result.
type paddleJsonlLine struct {
	Result struct {
		LayoutParsingResults []struct {
			Markdown struct {
				// Text is the markdown fragment that OCRFile appends to the
				// combined output.
				Text string `json:"text"`
			} `json:"markdown"`
		} `json:"layoutParsingResults"`
	} `json:"result"`
}
|
|
|
|
func (p *PaddleOCRModel) OCRFile(modelName *string, content []byte, fileURL *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) {
|
|
if (content == nil || len(content) == 0) && (fileURL == nil || *fileURL == "") {
|
|
return nil, fmt.Errorf("content and fileURL cannot be both empty")
|
|
}
|
|
|
|
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
|
return nil, fmt.Errorf("api key is required")
|
|
}
|
|
|
|
var region = "default"
|
|
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
|
region = *apiConfig.Region
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/%s", p.BaseURL[region], p.URLSuffix.OCR)
|
|
|
|
optionalPayload := map[string]bool{
|
|
"useDocOrientationClassify": false,
|
|
"useDocUnwarping": false,
|
|
"useChartRecognition": false,
|
|
}
|
|
optBytes, _ := json.Marshal(optionalPayload)
|
|
|
|
var req *http.Request
|
|
var err error
|
|
|
|
if fileURL != nil && strings.HasPrefix(*fileURL, "http") {
|
|
reqData := map[string]interface{}{
|
|
"fileUrl": *fileURL,
|
|
"model": *modelName,
|
|
"optionalPayload": optionalPayload,
|
|
}
|
|
jsonData, err := json.Marshal(reqData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal json: %w", err)
|
|
}
|
|
req, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
} else {
|
|
body := &bytes.Buffer{}
|
|
writer := multipart.NewWriter(body)
|
|
|
|
_ = writer.WriteField("model", *modelName)
|
|
_ = writer.WriteField("optionalPayload", string(optBytes))
|
|
|
|
part, err := writer.CreateFormFile("file", "document.pdf")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create form file: %w", err)
|
|
}
|
|
part.Write(content)
|
|
writer.Close()
|
|
|
|
req, err = http.NewRequest("POST", url, body)
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("bearer %s", *apiConfig.ApiKey))
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to submit job: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("submit job failed: %s", string(respBody))
|
|
}
|
|
|
|
var submitResp paddleSubmitResponse
|
|
if err := json.Unmarshal(respBody, &submitResp); err != nil {
|
|
return nil, fmt.Errorf("failed to parse submit response: %w", err)
|
|
}
|
|
|
|
jobId := submitResp.Data.JobId
|
|
if jobId == "" {
|
|
return nil, fmt.Errorf("failed to get jobId from response")
|
|
}
|
|
|
|
pollUrl := fmt.Sprintf("%s/%s", url, jobId)
|
|
var jsonlUrl string
|
|
|
|
for {
|
|
time.Sleep(3 * time.Second)
|
|
|
|
pollReq, _ := http.NewRequest("GET", pollUrl, nil)
|
|
pollReq.Header.Set("Authorization", fmt.Sprintf("bearer %s", *apiConfig.ApiKey))
|
|
|
|
pollResp, err := p.httpClient.Do(pollReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to poll job status: %w", err)
|
|
}
|
|
|
|
pollBody, _ := io.ReadAll(pollResp.Body)
|
|
pollResp.Body.Close()
|
|
|
|
if pollResp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("poll job failed: %s", string(pollBody))
|
|
}
|
|
|
|
var pollData paddlePollResponse
|
|
if err = json.Unmarshal(pollBody, &pollData); err != nil {
|
|
return nil, fmt.Errorf("failed to parse poll response: %w", err)
|
|
}
|
|
|
|
// end if 'done' or 'failed'
|
|
state := pollData.Data.State
|
|
if state == "done" {
|
|
jsonlUrl = pollData.Data.ResultUrl.JsonUrl
|
|
break
|
|
} else if state == "failed" {
|
|
return nil, fmt.Errorf("ocr job failed on server: %s", pollData.Data.ErrorMsg)
|
|
}
|
|
}
|
|
|
|
if jsonlUrl == "" {
|
|
return nil, fmt.Errorf("job done but jsonl url is empty")
|
|
}
|
|
|
|
resReq, err := http.NewRequest("GET", jsonlUrl, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request for jsonl: %w", err)
|
|
}
|
|
|
|
resResp, err := p.httpClient.Do(resReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to download jsonl result: %w", err)
|
|
}
|
|
defer resResp.Body.Close()
|
|
|
|
if resResp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to download jsonl, status: %d", resResp.StatusCode)
|
|
}
|
|
|
|
var fullMarkdown strings.Builder
|
|
scanner := bufio.NewScanner(resResp.Body)
|
|
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
var lineData paddleJsonlLine
|
|
if err := json.Unmarshal([]byte(line), &lineData); err != nil {
|
|
continue
|
|
}
|
|
|
|
for _, layoutRes := range lineData.Result.LayoutParsingResults {
|
|
fullMarkdown.WriteString(layoutRes.Markdown.Text)
|
|
fullMarkdown.WriteString("\n\n")
|
|
}
|
|
}
|
|
|
|
if err = scanner.Err(); err != nil {
|
|
return nil, fmt.Errorf("error reading jsonl: %w", err)
|
|
}
|
|
|
|
extractedText := strings.TrimSpace(fullMarkdown.String())
|
|
|
|
return &OCRFileResponse{Text: &extractedText}, nil
|
|
}
|
|
|
|
func (p *PaddleOCRModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) CheckConnection(apiConfig *APIConfig) error {
|
|
return fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|
|
|
|
func (p *PaddleOCRModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", p.Name())
|
|
}
|