diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index 1c95fa4179..a79a51a443 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -9,11 +9,11 @@ "async_result": "async-result", "embedding": "embeddings", "rerank": "rerank", + "ocr": "layout_parsing", "asr": "audio/transcriptions", "tts": "audio/speech", "files": "files", - "models": "models", - "ocr": "layout_parsing" + "models": "models" }, "class": "glm", "models": [ @@ -269,12 +269,6 @@ "model_types": [ "rerank" ] - }, - { - "name": "glm-ocr", - "model_types": [ - "ocr" - ] } ] } diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index cd0bf86fc3..9e9cabad66 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -19,6 +19,7 @@ package models import ( "bufio" "bytes" + "encoding/base64" "encoding/json" "fmt" "io" @@ -584,6 +585,10 @@ type zhipuRerankResponse struct { } `json:"results"` } +type zhipuOCRResponse struct { + MarkdownResults *string `json:"md_results"` +} + // Rerank calculates similarity scores between query and documents using // the ZhipuAI /rerank endpoint (e.g. glm-rerank). The result is one // score per input text, in the same order the documents were given. @@ -933,8 +938,88 @@ func (z *ZhipuAIModel) buildTTSRequest(modelName *string, audioContent *string, } // OCRFile OCR file -func (m *ZhipuAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func (z *ZhipuAIModel) OCRFile(modelName *string, content []byte, fileURL *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + if (fileURL == nil || *fileURL == "") && len(content) == 0 { + return nil, fmt.Errorf("file url or content is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, ok := z.BaseURL[region] + if !ok || baseURL == "" { + return nil, fmt.Errorf("zhipu-ai: no base URL configured for region %q", region) + } + + if z.URLSuffix.OCR == "" { + return nil, fmt.Errorf("zhipu-ai: no OCR URL suffix configured") + } + + file := "" + if fileURL != nil && *fileURL != "" { + file = *fileURL + } else { + mimeType := http.DetectContentType(content) + if len(content) > 4 && string(content[:4]) == "%PDF" { + mimeType = "application/pdf" + } + file = fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(content)) + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "file": file, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(z.URLSuffix.OCR, "/")) + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ZhipuAI OCR API error: %s, body: %s", resp.Status, string(body)) + } + + var zhipuResp zhipuOCRResponse + if err = json.Unmarshal(body, &zhipuResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if zhipuResp.MarkdownResults == nil { + return nil, fmt.Errorf("ZhipuAI OCR API response missing md_results") + } + + return &OCRFileResponse{Text: zhipuResp.MarkdownResults}, nil } // ParseFile parse file diff --git a/internal/entity/models/zhipu-ai_test.go b/internal/entity/models/zhipu-ai_test.go new file mode 100644 index 0000000000..44b302bc19 --- /dev/null +++ b/internal/entity/models/zhipu-ai_test.go @@ -0,0 +1,158 @@ +package models + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestZhipuAIOCRFileSendsLayoutParsingRequest(t *testing.T) { + apiKey := "test-key" + modelName := "glm-ocr" + fileURL := "https://example.com/doc.png" + expectedText := "# OCR result" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/layout_parsing" { + t.Errorf("path = %s, want /layout_parsing", r.URL.Path) + return + } + if r.Method != http.MethodPost { + t.Errorf("method = %s, want POST", r.Method) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer "+apiKey { + t.Errorf("Authorization = %q", got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type = %q", got) + return + } + + var req map[string]string + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("decode request: %v", err) + return + } + if req["model"] != modelName { + t.Errorf("model = %q, want %q", req["model"], modelName) + return + } + if req["file"] != fileURL { + t.Errorf("file = %q, want %q", req["file"], fileURL) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"md_results": expectedText}) + })) + defer server.Close() + + model := NewZhipuAIModel(map[string]string{"default": server.URL}, URLSuffix{OCR: "layout_parsing"}) + resp, err := model.OCRFile(&modelName, nil, &fileURL, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("OCRFile returned error: %v", err) + } + if resp == nil || resp.Text == nil || *resp.Text != expectedText { + t.Fatalf("OCRFile text = %#v, want %q", resp, expectedText) + } +} + +func TestZhipuAIOCRFileEncodesContent(t *testing.T) { + apiKey := "test-key" + modelName := "glm-ocr" + content := []byte("sample image bytes") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]string + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("decode request: %v", err) + return + } + if !strings.HasPrefix(req["file"], "data:text/plain; charset=utf-8;base64,") { + t.Errorf("file = %q, want base64 data URL", req["file"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]string{"md_results": "ok"}) + })) + defer server.Close() + + model := NewZhipuAIModel(map[string]string{"default": server.URL}, URLSuffix{OCR: "layout_parsing"}) + if _, err := model.OCRFile(&modelName, content, nil, &APIConfig{ApiKey: &apiKey}, nil); err != nil { + t.Fatalf("OCRFile returned error: %v", err) + } +} + +func TestZhipuAIOCRFileDetectsPDFContent(t *testing.T) { + apiKey := "test-key" + modelName := "glm-ocr" + content := []byte("%PDF-1.7 sample") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]string + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("decode request: %v", err) + return + } + if !strings.HasPrefix(req["file"], "data:application/pdf;base64,") { + t.Errorf("file = %q, want PDF data URL", req["file"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]string{"md_results": "ok"}) + })) + defer server.Close() + + model := NewZhipuAIModel(map[string]string{"default": server.URL}, URLSuffix{OCR: "layout_parsing"}) + if _, err := model.OCRFile(&modelName, content, nil, &APIConfig{ApiKey: &apiKey}, nil); err != nil { + t.Fatalf("OCRFile returned error: %v", err) + } +} + +func TestZhipuAIOCRFileValidation(t *testing.T) { + apiKey := "test-key" + modelName := "glm-ocr" + fileURL := "https://example.com/doc.png" + model := NewZhipuAIModel(map[string]string{"default": "https://example.com"}, URLSuffix{OCR: "layout_parsing"}) + + tests := []struct { + name string + modelName *string + fileURL *string + apiConfig *APIConfig + want string + }{ + { + name: "missing api key", + modelName: &modelName, + fileURL: &fileURL, + apiConfig: &APIConfig{}, + want: "api key is required", + }, + { + name: "missing model name", + modelName: nil, + fileURL: &fileURL, + apiConfig: &APIConfig{ApiKey: &apiKey}, + want: "model name is required", + }, + { + name: "missing file", + modelName: &modelName, + fileURL: nil, + apiConfig: &APIConfig{ApiKey: &apiKey}, + want: "file url or content is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := model.OCRFile(tt.modelName, nil, tt.fileURL, tt.apiConfig, nil) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("error = %v, want containing %q", err, tt.want) + } + }) + } +}