mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-02 16:55:42 +08:00
fix: normalize reasoning model families (#15612)
### What problem does this PR solve? Closes #15611. RAGFlow's fallback reasoning parser only recognized the exact model family `qwen3`. For provider-prefixed Qwen model names such as SiliconFlow's `qwen/qwen3-8b`, the derived model class can be `qwen/qwen3`, so inline `<think>...</think>` content was not split from the visible answer when `reasoning_content` was absent. This PR normalizes model-family detection before fallback reasoning extraction, keeps the parser nil-safe, and adds focused tests for Qwen3 variants plus Gitee and SiliconFlow chat responses. It also makes SiliconFlow propagate `ChatConfig.Thinking` into the chat request body, matching the existing Gitee behavior, so Qwen thinking mode is actually enabled when requested. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring ### Validation - `/root/go/bin/gofmt -l internal/entity/models/common.go internal/entity/models/common_test.go internal/entity/models/reasoning_family_provider_test.go internal/entity/models/siliconflow.go` - `git diff --check` - `/root/go/bin/go test ./internal/entity/models -run 'Test(NormalizeModelFamily|GetThinkingAndAnswer|GiteeChatExtractsQwenThinkingFromInlineContent|SiliconflowChatExtractsProviderPrefixedQwenThinkingFromInlineContent)' -vet=off -count=1` Note: the full package command `/root/go/bin/go test ./internal/entity/models -vet=off -count=1` now runs locally, but it currently fails on an unrelated existing `TestAstraflowEmbedReturnsNoSuchMethod` panic in `internal/entity/models/astraflow.go:482`.
This commit is contained in:
@@ -19,14 +19,48 @@ package models
|
||||
import "strings"
|
||||
|
||||
func GetThinkingAndAnswer(modelType *string, content *string) (*string, *string) {
|
||||
switch *modelType {
|
||||
if content == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch NormalizeModelFamily(modelType) {
|
||||
case "qwen3":
|
||||
return extractThinkContent(content)
|
||||
}
|
||||
return nil, content
|
||||
}
|
||||
|
||||
// NormalizeModelFamily normalizes provider-prefixed model class/name strings for shared response parsing.
|
||||
func NormalizeModelFamily(modelType *string) string {
|
||||
if modelType == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
family := strings.ToLower(strings.TrimSpace(*modelType))
|
||||
if family == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if slash := strings.LastIndex(family, "/"); slash >= 0 && slash < len(family)-1 {
|
||||
family = family[slash+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(family, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
|
||||
if dash := strings.Index(family, "-"); dash >= 0 {
|
||||
family = family[:dash]
|
||||
}
|
||||
|
||||
return family
|
||||
}
|
||||
|
||||
func extractThinkContent(content *string) (*string, *string) {
|
||||
if content == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
startTag := "<think>"
|
||||
endTag := "</think>"
|
||||
|
||||
|
||||
84
internal/entity/models/common_test.go
Normal file
84
internal/entity/models/common_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeModelFamily(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input *string
|
||||
want string
|
||||
}{
|
||||
{name: "nil", input: nil, want: ""},
|
||||
{name: "empty", input: modelFamilyTestString(""), want: ""},
|
||||
{name: "qwen3", input: modelFamilyTestString("qwen3"), want: "qwen3"},
|
||||
{name: "qwen3 variant", input: modelFamilyTestString("qwen3-8b"), want: "qwen3"},
|
||||
{name: "provider-prefixed qwen3", input: modelFamilyTestString("qwen/qwen3-8b"), want: "qwen3"},
|
||||
{name: "case-varied qwen3", input: modelFamilyTestString("Qwen/Qwen3.5-4B"), want: "qwen3"},
|
||||
{name: "provider-prefixed non-qwen", input: modelFamilyTestString("deepseek/deepseek-r1"), want: "deepseek"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NormalizeModelFamily(tt.input); got != tt.want {
|
||||
t.Fatalf("NormalizeModelFamily()=%q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetThinkingAndAnswerExtractsQwenThinking(t *testing.T) {
|
||||
tests := []string{
|
||||
"qwen3",
|
||||
"qwen3-8b",
|
||||
"qwen/qwen3",
|
||||
"qwen/qwen3-8b",
|
||||
"Qwen/Qwen3.5-4B",
|
||||
}
|
||||
|
||||
for _, modelFamily := range tests {
|
||||
t.Run(modelFamily, func(t *testing.T) {
|
||||
content := "<think>\nreasoning</think>\nanswer"
|
||||
reasoning, answer := GetThinkingAndAnswer(&modelFamily, &content)
|
||||
|
||||
if reasoning == nil || *reasoning != "reasoning" {
|
||||
t.Fatalf("reasoning=%v, want reasoning", reasoning)
|
||||
}
|
||||
if answer == nil || *answer != "answer" {
|
||||
t.Fatalf("answer=%v, want answer", answer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetThinkingAndAnswerSkipsUnknownFamily(t *testing.T) {
|
||||
modelFamily := "deepseek/deepseek-r1"
|
||||
content := "<think>reasoning</think>answer"
|
||||
|
||||
reasoning, answer := GetThinkingAndAnswer(&modelFamily, &content)
|
||||
if reasoning != nil {
|
||||
t.Fatalf("reasoning=%q, want nil", *reasoning)
|
||||
}
|
||||
if answer == nil || *answer != content {
|
||||
t.Fatalf("answer=%v, want original content", answer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetThinkingAndAnswerHandlesNilInputs(t *testing.T) {
|
||||
reasoning, answer := GetThinkingAndAnswer(nil, nil)
|
||||
if reasoning != nil || answer != nil {
|
||||
t.Fatalf("GetThinkingAndAnswer(nil, nil)=(%v, %v), want nils", reasoning, answer)
|
||||
}
|
||||
|
||||
content := "answer"
|
||||
reasoning, answer = GetThinkingAndAnswer(nil, &content)
|
||||
if reasoning != nil {
|
||||
t.Fatalf("reasoning=%q, want nil", *reasoning)
|
||||
}
|
||||
if answer == nil || *answer != content {
|
||||
t.Fatalf("answer=%v, want original content", answer)
|
||||
}
|
||||
}
|
||||
|
||||
func modelFamilyTestString(value string) *string {
|
||||
return &value
|
||||
}
|
||||
150
internal/entity/models/reasoning_family_provider_test.go
Normal file
150
internal/entity/models/reasoning_family_provider_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newReasoningFamilyChatServer(t *testing.T, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s, want POST", r.Method)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Errorf("path=%s, want /chat/completions", r.URL.Path)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q, want Bearer test-key", got)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") {
|
||||
t.Errorf("Content-Type=%q, want application/json", got)
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Errorf("unmarshal body: %v\nraw=%s", err, string(raw))
|
||||
return
|
||||
}
|
||||
|
||||
handler(t, body, w)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestGiteeChatExtractsQwenThinkingFromInlineContent(t *testing.T) {
|
||||
srv := newReasoningFamilyChatServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "qwen3-8b" {
|
||||
t.Errorf("model=%v, want qwen3-8b", body["model"])
|
||||
}
|
||||
if !assertThinkingEnabled(t, body) {
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"choices": []map[string]interface{}{{
|
||||
"message": map[string]interface{}{
|
||||
"content": "<think>\nreasoning</think>\nanswer",
|
||||
},
|
||||
}},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
thinking := true
|
||||
modelClass := "qwen3-8b"
|
||||
resp, err := NewGiteeModel(
|
||||
map[string]string{"default": srv.URL},
|
||||
URLSuffix{Chat: "chat/completions"},
|
||||
).ChatWithMessages(
|
||||
"qwen3-8b",
|
||||
[]Message{{Role: "user", Content: "ping"}},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&ChatConfig{Thinking: &thinking, ModelClass: &modelClass},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatWithMessages: %v", err)
|
||||
}
|
||||
assertThinkingResponse(t, resp)
|
||||
}
|
||||
|
||||
func TestSiliconflowChatExtractsProviderPrefixedQwenThinkingFromInlineContent(t *testing.T) {
|
||||
srv := newReasoningFamilyChatServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "qwen/qwen3-8b" {
|
||||
t.Errorf("model=%v, want qwen/qwen3-8b", body["model"])
|
||||
}
|
||||
if !assertThinkingEnabled(t, body) {
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"choices": []map[string]interface{}{{
|
||||
"message": map[string]interface{}{
|
||||
"content": "<think>\nreasoning</think>\nanswer",
|
||||
},
|
||||
}},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
thinking := true
|
||||
modelClass := "qwen/qwen3-8b"
|
||||
resp, err := NewSiliconflowModel(
|
||||
map[string]string{"default": srv.URL},
|
||||
URLSuffix{Chat: "chat/completions"},
|
||||
).ChatWithMessages(
|
||||
"qwen/qwen3-8b",
|
||||
[]Message{{Role: "user", Content: "ping"}},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&ChatConfig{Thinking: &thinking, ModelClass: &modelClass},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatWithMessages: %v", err)
|
||||
}
|
||||
assertThinkingResponse(t, resp)
|
||||
}
|
||||
|
||||
func assertThinkingEnabled(t *testing.T, body map[string]interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
thinking, ok := body["thinking"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("thinking=%#v, want object", body["thinking"])
|
||||
return false
|
||||
}
|
||||
if thinking["type"] != "enabled" {
|
||||
t.Errorf("thinking.type=%v, want enabled", thinking["type"])
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func assertThinkingResponse(t *testing.T, resp *ChatResponse) {
|
||||
t.Helper()
|
||||
|
||||
if resp == nil {
|
||||
t.Fatal("response is nil")
|
||||
}
|
||||
if resp.Answer == nil || *resp.Answer != "answer" {
|
||||
t.Fatalf("Answer=%v, want answer", resp.Answer)
|
||||
}
|
||||
if resp.ReasonContent == nil || *resp.ReasonContent != "reasoning" {
|
||||
t.Fatalf("ReasonContent=%v, want reasoning", resp.ReasonContent)
|
||||
}
|
||||
}
|
||||
@@ -128,6 +128,18 @@ func (s *SiliconflowModel) ChatWithMessages(modelName string, messages []Message
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
|
||||
Reference in New Issue
Block a user