From d13a240dc0b1f2a3ea53d59dfe31a5c625cb3434 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Fri, 8 May 2026 13:56:45 +0800 Subject: [PATCH] Go: implement remaining interface for OpenRouter (#14657) ### What problem does this PR solve? 1. implement `rerank`, `embedding`, `balance`, `checkConnection` methods for `OpenRouter` 2. delete `chat` method in `internal/entity/models/volcengine.go` ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- conf/models/openrouter.json | 5 +- internal/entity/models/openrouter.go | 215 +++++++++++++++++++++++---- internal/entity/models/volcengine.go | 150 ------------------- 3 files changed, 189 insertions(+), 181 deletions(-) diff --git a/conf/models/openrouter.json b/conf/models/openrouter.json index 4d9fca3665..6af1e2d15d 100644 --- a/conf/models/openrouter.json +++ b/conf/models/openrouter.json @@ -5,7 +5,10 @@ }, "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings", + "rerank": "rerank", + "balance": "credits" }, "class": "openrouter", "models": [ diff --git a/internal/entity/models/openrouter.go b/internal/entity/models/openrouter.go index fbc8e3394e..b5ab500d11 100644 --- a/internal/entity/models/openrouter.go +++ b/internal/entity/models/openrouter.go @@ -352,8 +352,122 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me } func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - //TODO implement me - panic("implement me") + if len(texts) == 0 { + return [][]float64{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := 
json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + dataObj, ok := result["data"].([]interface{}) + if !ok || len(dataObj) == 0 { + return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body)) + } + + embeddings := make([][]float64, len(texts)) + + for _, item := range dataObj { + dataMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + + indexFloat, ok := dataMap["index"].(float64) + if !ok { + continue + } + index := int(indexFloat) + + if index < 0 || index >= len(texts) { + continue + } + + embeddingSlice, ok := dataMap["embedding"].([]interface{}) + if !ok { + continue + } + + embedding := make([]float64, len(embeddingSlice)) + for j, v := range embeddingSlice { + switch val := v.(type) { + case float64: + embedding[j] = val + case float32: + embedding[j] = float64(val) + default: + return nil, fmt.Errorf("unexpected embedding value type") + } + } + + embeddings[index] = embedding + } + + return embeddings, nil +} + 
+// OpenRouterRerankRequest OpenRouter official rerank request format +type OpenRouterRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN int `json:"top_n,omitempty"` +} + +// OpenRouterRerankResponse OpenRouter official rerank response format +type OpenRouterRerankResponse struct { + Model string `json:"model"` + ID string `json:"id"` + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + Document *struct { + Text string `json:"text"` + } `json:"document,omitempty"` + } `json:"results"` } func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { @@ -366,19 +480,11 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string region = *apiConfig.Region } - apiKey := "" - if apiConfig != nil && apiConfig.ApiKey != nil { - apiKey = *apiConfig.ApiKey - } - - reqBody := SiliconflowRerankRequest{ - Model: *modelName, - Query: query, - Documents: texts, - TopN: len(texts), - ReturnDocuments: false, - MaxChunksPerDoc: 1024, - OverlapTokens: 80, + reqBody := OpenRouterRerankRequest{ + Model: *modelName, + Query: query, + Documents: texts, + TopN: len(texts), } jsonData, err := json.Marshal(reqBody) @@ -388,15 +494,13 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string url := fmt.Sprintf("%s/%s", strings.TrimSuffix(o.BaseURL[region], "/"), o.URLSuffix.Rerank) - req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData))) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") - if apiKey != "" { - req.Header.Set("Authorization", "Bearer "+apiKey) - } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) resp, err := o.httpClient.Do(req) if err != nil { 
@@ -404,21 +508,25 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body)) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) } - body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenRouter Rerank API error: %s, body: %s", resp.Status, string(body)) + } - var rerankResp SiliconflowRerankResponse - if err := json.Unmarshal(body, &rerankResp); err != nil { + var rerankResp OpenRouterRerankResponse + if err = json.Unmarshal(body, &rerankResp); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } scores := make([]float64, len(texts)) + for _, result := range rerankResp.Results { - if result.Index >= 0 && result.Index < len(texts) { + if result.Index >= 0 && + result.Index < len(texts) { scores[result.Index] = result.RelevanceScore } } @@ -483,11 +591,58 @@ func (o *OpenRouterModel) ListModels(apiConfig *APIConfig) ([]string, error) { } func (o *OpenRouterModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { - //TODO implement me - panic("implement me") + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Balance) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := 
io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Data struct { + TotalCredits float64 `json:"total_credits"` + TotalUsage float64 `json:"total_usage"` + } `json:"data"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse balance response: %w", err) + } + + remainingBalance := result.Data.TotalCredits - result.Data.TotalUsage + + return map[string]interface{}{ + "total_credits": result.Data.TotalCredits, + "total_usage": result.Data.TotalUsage, + "balance": remainingBalance, + "currency": "USD", + }, nil } func (o *OpenRouterModel) CheckConnection(apiConfig *APIConfig) error { - //TODO implement me - panic("implement me") + _, err := o.Balance(apiConfig) + return err } diff --git a/internal/entity/models/volcengine.go b/internal/entity/models/volcengine.go index 2364502762..6269ebef5a 100644 --- a/internal/entity/models/volcengine.go +++ b/internal/entity/models/volcengine.go @@ -60,156 +60,6 @@ func (z *VolcEngine) Name() string { return "volcengine" } -// Chat sends a message and returns response -func (z *VolcEngine) Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) { - if message == nil { - return nil, fmt.Errorf("message is nil") - } - - var region = "default" - if apiConfig.Region != nil { - region = *apiConfig.Region - } - - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) - - //Build request body - reqBody := map[string]interface{}{ - "model": modelName, - "messages": []map[string]string{ - {"role": "user", "content": *message}, - }, - "stream": false, - "temperature": 1, - } - - if modelConfig.Stream != nil { - reqBody["stream"] = *modelConfig.Stream - } - - if modelConfig.MaxTokens != nil { - reqBody["max_tokens"] = 
*modelConfig.MaxTokens - } - - if modelConfig.Temperature != nil { - reqBody["temperature"] = *modelConfig.Temperature - } - - if modelConfig.TopP != nil { - reqBody["top_p"] = *modelConfig.TopP - } - // TODO VolcEngine has `auto` mode - if modelConfig.Thinking != nil { - if *modelConfig.Thinking { - var thinkingFlag string - switch *modelConfig.Effort { - case "none", "minimal": - thinkingFlag = "disabled" - reqBody["reasoning_effort"] = "minimal" - break - case "low": - thinkingFlag = "enabled" - reqBody["reasoning_effort"] = "low" - break - case "medium": - thinkingFlag = "enabled" - reqBody["reasoning_effort"] = "medium" - break - case "auto", "default": - thinkingFlag = "enabled" - reqBody["reasoning_effort"] = "medium" - break - case "high": - thinkingFlag = "enabled" - reqBody["reasoning_effort"] = "high" - break - default: - return nil, fmt.Errorf("invalid effort level") - } - reqBody["thinking"] = map[string]interface{}{ - "type": thinkingFlag, - } - } else { - reqBody["thinking"] = map[string]interface{}{ - "type": "disabled", - } - } - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - - resp, err := z.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Parse response - var result map[string]interface{} - if err = json.Unmarshal(body, &result); err != nil { - 
return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - choices, ok := result["choices"].([]interface{}) - if !ok || len(choices) == 0 { - return nil, fmt.Errorf("no choices in responses") - } - - firstChoice, ok := choices[0].(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid choice format") - } - - messageMap, ok := firstChoice["message"].(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid message format") - } - - content, ok := messageMap["content"].(string) - if !ok { - return nil, fmt.Errorf("invalid content format") - } - - var reasonContent string - if modelConfig.Thinking != nil && *modelConfig.Thinking { - reasonContent, ok = messageMap["reasoning_content"].(string) - if !ok { - return nil, fmt.Errorf("invalid reasonContent format") - } - // if first char of reasonContent is \n remove the \n - if reasonContent != "" && reasonContent[0] == '\n' { - reasonContent = reasonContent[1:] - } - } - - chatResponse := &ChatResponse{ - Answer: &content, - ReasonContent: &reasonContent, - } - - return chatResponse, nil -} - // ChatWithMessages sends multiple messages with roles and returns response func (z *VolcEngine) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if len(messages) == 0 {