diff --git a/internal/entity/models/jiekouai.go b/internal/entity/models/jiekouai.go index 459b8d959c..69c798fe31 100644 --- a/internal/entity/models/jiekouai.go +++ b/internal/entity/models/jiekouai.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "ragflow/internal/common" "strings" "time" ) @@ -54,9 +53,27 @@ func (j *JieKouAIModel) Name() string { return "jiekouai" } +func validateJieKouAIAPIKey(apiConfig *APIConfig) (string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" { + return "", fmt.Errorf("api key is required") + } + return strings.TrimSpace(*apiConfig.ApiKey), nil +} + +func validateJieKouAIModelName(modelName *string) (string, error) { + if modelName == nil || strings.TrimSpace(*modelName) == "" { + return "", fmt.Errorf("model name is required") + } + return strings.TrimSpace(*modelName), nil +} + func (j *JieKouAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { - if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { - return nil, fmt.Errorf("api key is required") + apiKey, err := validateJieKouAIAPIKey(apiConfig) + if err != nil { + return nil, err + } + if modelName = strings.TrimSpace(modelName); modelName == "" { + return nil, fmt.Errorf("model name is required") } if len(messages) == 0 { return nil, fmt.Errorf("messages is empty") @@ -87,10 +104,6 @@ func (j *JieKouAIModel) ChatWithMessages(modelName string, messages []Message, a } if chatModelConfig != nil { - if chatModelConfig.Stream != nil { - reqBody["stream"] = *chatModelConfig.Stream - } - if chatModelConfig.MaxTokens != nil { reqBody["max_tokens"] = *chatModelConfig.MaxTokens } @@ -126,7 +139,8 @@ func (j *JieKouAIModel) ChatWithMessages(modelName string, messages []Message, a } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Accept", "application/json") resp, err := j.httpClient.Do(req) if err != nil { @@ -190,6 +204,16 @@ func (j *JieKouAIModel) ChatWithMessages(modelName string, messages []Message, a } func (j *JieKouAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + apiKey, err := validateJieKouAIAPIKey(apiConfig) + if err != nil { + return err + } + if modelName = strings.TrimSpace(modelName); modelName == "" { + return fmt.Errorf("model name is required") + } + if sender == nil { + return fmt.Errorf("sender is required") + } if len(messages) == 0 { return fmt.Errorf("messages is empty") } @@ -219,10 +243,6 @@ func (j *JieKouAIModel) ChatStreamlyWithSender(modelName string, messages []Mess } if modelConfig != nil { - if modelConfig.Stream != nil { - reqBody["stream"] = *modelConfig.Stream - } - if modelConfig.MaxTokens != nil { reqBody["max_tokens"] = *modelConfig.MaxTokens } @@ -261,7 +281,8 @@ func (j *JieKouAIModel) ChatStreamlyWithSender(modelName string, messages []Mess } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Accept", "text/event-stream") resp, err := j.httpClient.Do(req) if err != nil { @@ -278,7 +299,6 @@ func (j *JieKouAIModel) ChatStreamlyWithSender(modelName string, messages []Mess scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() - common.Info(line) // SSE data line starts with "data:" if !strings.HasPrefix(line, "data:") { @@ -347,6 +367,14 @@ func (j *JieKouAIModel) Embed(modelName *string, texts []string, apiConfig *APIC if len(texts) == 0 { return []EmbeddingData{}, fmt.Errorf("texts is empty") } + model, err := validateJieKouAIModelName(modelName) + if err != nil { + return nil, err + } + apiKey, err := validateJieKouAIAPIKey(apiConfig) + if err != nil { + return nil, err + } var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { @@ -356,7 +384,7 @@ func (j *JieKouAIModel) Embed(modelName *string, texts []string, apiConfig *APIC url := fmt.Sprintf("%s/%s", strings.TrimSuffix(j.BaseURL[region], "/"), j.URLSuffix.Embedding) reqBody := map[string]interface{}{ - "model": *modelName, + "model": model, "input": texts, } @@ -371,7 +399,7 @@ func (j *JieKouAIModel) Embed(modelName *string, texts []string, apiConfig *APIC } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) resp, err := j.httpClient.Do(req) if err != nil { @@ -418,6 +446,17 @@ func (j *JieKouAIModel) Rerank(modelName *string, query string, documents []stri if len(documents) == 0 { return &RerankResponse{}, nil } + if strings.TrimSpace(query) == "" { + return nil, fmt.Errorf("query is required") + } + model, err := validateJieKouAIModelName(modelName) + if err != nil { + return nil, err + } + apiKey, err := validateJieKouAIAPIKey(apiConfig) + if err != nil { + return nil, err + } var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { @@ -426,16 +465,13 @@ func (j *JieKouAIModel) Rerank(modelName *string, query string, documents []stri url := fmt.Sprintf("%s/%s", strings.TrimSuffix(j.BaseURL[region], "/"), j.URLSuffix.Rerank) - var topN = rerankConfig.TopN - if rerankConfig.TopN != 0 { - topN = rerankConfig.TopN - } - reqBody := map[string]interface{}{ - "model": *modelName, - "query": query, + "model": model, + "query": strings.TrimSpace(query), "documents": documents, - "top_n": topN, + } + if rerankConfig != nil && rerankConfig.TopN != 0 { + reqBody["top_n"] = rerankConfig.TopN } jsonData, err := json.Marshal(reqBody) @@ -449,7 +485,7 @@ func (j *JieKouAIModel) Rerank(modelName *string, query string, documents []stri } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) resp, err := j.httpClient.Do(req) if err != nil { @@ -514,6 +550,10 @@ func (j *JieKouAIModel) ParseFile(modelName *string, content []byte, url *string } func (j *JieKouAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + apiKey, err := validateJieKouAIAPIKey(apiConfig) + if err != nil { + return nil, err + } var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region @@ -521,20 +561,13 @@ func (j *JieKouAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { url := fmt.Sprintf("%s/%s", j.BaseURL[region], j.URLSuffix.Models) - reqBody := map[string]string{} - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Accept", "application/json") resp, err := j.httpClient.Do(req) if err != nil { @@ -551,18 +584,24 @@ func (j *JieKouAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } - // Parse response - var result map[string]interface{} + var result struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } if err = json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } + if result.Data == nil { + return nil, fmt.Errorf("models response missing data") + } - // convert result["data"] to []map[string]interface{} - models := make([]string, 0) - for _, model := range result["data"].([]interface{}) { - modelMap := model.(map[string]interface{}) - modelName := modelMap["id"].(string) - models = append(models, modelName) + models := make([]string, 0, len(result.Data)) + for _, model := range result.Data { + if strings.TrimSpace(model.ID) == "" { + return nil, fmt.Errorf("models response contains empty id") + } + models = append(models, strings.TrimSpace(model.ID)) } return models, nil diff --git a/internal/entity/models/jiekouai_test.go b/internal/entity/models/jiekouai_test.go new file mode 100644 index 0000000000..b95360268b --- /dev/null +++ b/internal/entity/models/jiekouai_test.go @@ -0,0 +1,370 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newJieKouAIServer(t *testing.T, handler func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + + var body map[string]interface{} + if r.Method == http.MethodPost { + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + } else { + if r.ContentLength > 0 { + t.Errorf("expected %s request without body, ContentLength=%d", r.Method, r.ContentLength) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if len(raw) != 0 { + t.Errorf("expected %s request without body, got %q", r.Method, string(raw)) + return + } + } + + handler(t, r, body, w) + })) +} + +func newJieKouAIForTest(baseURL string) *JieKouAIModel { + return NewJieKouAIModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "openai/v1/chat/completions", + Embedding: "openai/v1/embeddings", + Rerank: "openai/v1/rerank", + Models: "openai/v1/models", + }, + ) +} + +func TestJieKouAIChatForcesNonStreaming(t *testing.T) { + srv := newJieKouAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.Method != http.MethodPost { + t.Errorf("method=%s, want POST", r.Method) + } + if r.URL.Path != "/openai/v1/chat/completions" { + t.Errorf("path=%s, want /openai/v1/chat/completions", r.URL.Path) + } + if body["stream"] != false { + t.Errorf("stream=%v, want false", body["stream"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": "pong", + "reasoning_content": "\nthought", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + stream := true + thinking := true + resp, err := newJieKouAIForTest(srv.URL).ChatWithMessages( + " gpt-5 ", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream, Thinking: &thinking}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Errorf("Answer=%v, want pong", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "thought" { + t.Errorf("ReasonContent=%v, want thought", resp.ReasonContent) + } +} + +func TestJieKouAIStreamForcesStreaming(t *testing.T) { + srv := newJieKouAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.URL.Path != "/openai/v1/chat/completions" { + t.Errorf("path=%s, want /openai/v1/chat/completions", r.URL.Path) + } + if body["stream"] != true { + t.Errorf("stream=%v, want true", body["stream"]) + } + if got := r.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept=%q, want text/event-stream", got) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}`, + `data: {"choices":[{"delta":{"content":"hello"}}]}`, + `data: [DONE]`, + ``, + }, "\n")) + }) + defer srv.Close() + + apiKey := "test-key" + stream := false + var content, reasoning []string + err := newJieKouAIForTest(srv.URL).ChatStreamlyWithSender( + "gpt-5", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream}, + func(answer, reason *string) error { + if answer != nil { + content = append(content, *answer) + } + if reason != nil { + reasoning = append(reasoning, *reason) + } + return nil + }, + ) + if err != nil { + t.Fatalf("ChatStreamlyWithSender: %v", err) + } + if got := strings.Join(content, ""); got != "hello[DONE]" { + t.Errorf("content=%q, want hello[DONE]", got) + } + if got := strings.Join(reasoning, ""); got != "thinking" { + t.Errorf("reasoning=%q, want thinking", got) + } +} + +func TestJieKouAIListModelsHappyPath(t *testing.T) { + srv := newJieKouAIServer(t, func(t *testing.T, r *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + if r.Method != http.MethodGet { + t.Errorf("method=%s, want GET", r.Method) + } + if r.URL.Path != "/openai/v1/models" { + t.Errorf("path=%s, want /openai/v1/models", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]string{ + {"id": "gpt-5"}, + {"id": " text-embedding-3-large "}, + }, + }) + }) + defer srv.Close() + + apiKey := "test-key" + models, err := newJieKouAIForTest(srv.URL).ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if got := strings.Join(models, ","); got != "gpt-5,text-embedding-3-large" { + t.Errorf("models=%q", got) + } +} + +func TestJieKouAIListModelsRejectsMalformedResponse(t *testing.T) { + apiKey := "test-key" + for name, response := range map[string]interface{}{ + "missing data": map[string]interface{}{"object": "list"}, + "empty id": map[string]interface{}{"data": []map[string]string{{"id": ""}}}, + } { + t.Run(name, func(t *testing.T) { + srv := newJieKouAIServer(t, func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(response) + }) + defer srv.Close() + + if _, err := newJieKouAIForTest(srv.URL).ListModels(&APIConfig{ApiKey: &apiKey}); err == nil { + t.Fatal("expected malformed response error") + } + }) + } +} + +func TestJieKouAIEmbedSendsValidatedRequest(t *testing.T) { + srv := newJieKouAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.URL.Path != "/openai/v1/embeddings" { + t.Errorf("path=%s, want /openai/v1/embeddings", r.URL.Path) + } + if body["model"] != "text-embedding-3-large" { + t.Errorf("model=%v, want text-embedding-3-large", body["model"]) + } + inputs, ok := body["input"].([]interface{}) + if !ok || len(inputs) != 1 || inputs[0] != "hello" { + t.Errorf("input=%v, want [hello]", body["input"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{ + "embedding": []float64{0.1, 0.2}, + "index": 0, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + model := " text-embedding-3-large " + embeddings, err := newJieKouAIForTest(srv.URL).Embed(&model, []string{"hello"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(embeddings) != 1 || embeddings[0].Index != 0 || len(embeddings[0].Embedding) != 2 { + t.Fatalf("embeddings=%v", embeddings) + } +} + +func TestJieKouAIRerankHandlesNilConfig(t *testing.T) { + srv := newJieKouAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.URL.Path != "/openai/v1/rerank" { + t.Errorf("path=%s, want /openai/v1/rerank", r.URL.Path) + } + if body["model"] != "baai/bge-reranker-v2-m3" { + t.Errorf("model=%v, want baai/bge-reranker-v2-m3", body["model"]) + } + if body["query"] != "question" { + t.Errorf("query=%v, want question", body["query"]) + } + if _, ok := body["top_n"]; ok { + t.Errorf("top_n=%v, want omitted", body["top_n"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{{ + "index": 0, + "relevance_score": 0.9, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + model := " baai/bge-reranker-v2-m3 " + resp, err := newJieKouAIForTest(srv.URL).Rerank(&model, " question ", []string{"doc"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if resp == nil || len(resp.Data) != 1 || resp.Data[0].Index != 0 || resp.Data[0].RelevanceScore != 0.9 { + t.Fatalf("Rerank response=%v", resp) + } +} + +func TestJieKouAIValidatesInputs(t *testing.T) { + apiKey := "test-key" + emptyKey := " " + model := "gpt-5" + send := func(*string, *string) error { return nil } + + tests := []struct { + name string + run func() error + want string + }{ + { + name: "chat api key", + run: func() error { + _, err := newJieKouAIForTest("http://unused").ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &emptyKey}, nil) + return err + }, + want: "api key is required", + }, + { + name: "chat model", + run: func() error { + _, err := newJieKouAIForTest("http://unused").ChatWithMessages(" ", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + return err + }, + want: "model name is required", + }, + { + name: "stream api key", + run: func() error { + return newJieKouAIForTest("http://unused").ChatStreamlyWithSender("gpt-5", []Message{{Role: "user", Content: "x"}}, nil, nil, send) + }, + want: "api key is required", + }, + { + name: "stream sender", + run: func() error { + return newJieKouAIForTest("http://unused").ChatStreamlyWithSender("gpt-5", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil, nil) + }, + want: "sender is required", + }, + { + name: "embed model", + run: func() error { + _, err := newJieKouAIForTest("http://unused").Embed(nil, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil) + return err + }, + want: "model name is required", + }, + { + name: "embed api key", + run: func() error { + _, err := newJieKouAIForTest("http://unused").Embed(&model, []string{"x"}, nil, nil) + return err + }, + want: "api key is required", + }, + { + name: "rerank model", + run: func() error { + _, err := newJieKouAIForTest("http://unused").Rerank(nil, "q", []string{"doc"}, &APIConfig{ApiKey: &apiKey}, nil) + return err + }, + want: "model name is required", + }, + { + name: "rerank api key", + run: func() error { + _, err := newJieKouAIForTest("http://unused").Rerank(&model, "q", []string{"doc"}, &APIConfig{ApiKey: &emptyKey}, nil) + return err + }, + want: "api key is required", + }, + { + name: "rerank query", + run: func() error { + _, err := newJieKouAIForTest("http://unused").Rerank(&model, " ", []string{"doc"}, &APIConfig{ApiKey: &apiKey}, nil) + return err + }, + want: "query is required", + }, + { + name: "models api key", + run: func() error { + _, err := newJieKouAIForTest("http://unused").ListModels(&APIConfig{}) + return err + }, + want: "api key is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.run() + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected %q error, got %v", tt.want, err) + } + }) + } +}