Go: implement remaining interface for OpenRouter (#14657)

### What problem does this PR solve?

1. implement `rerank`, `embedding`, `balance`, `checkConnet` method for
`OpenRouter`
2. delete `chat` method in `internal/entity/models/volcengine.go`

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
Haruko386
2026-05-08 13:56:45 +08:00
committed by GitHub
parent 731c887ba0
commit d13a240dc0
3 changed files with 189 additions and 181 deletions

View File

@@ -5,7 +5,10 @@
},
"url_suffix": {
"chat": "chat/completions",
"models": "models"
"models": "models",
"embedding": "embeddings",
"rerank": "rerank",
"balance": "credits"
},
"class": "openrouter",
"models": [

View File

@@ -352,8 +352,122 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me
}
func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
//TODO implement me
panic("implement me")
if len(texts) == 0 {
return [][]float64{}, nil
}
var region = "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Embedding)
reqBody := map[string]interface{}{
"model": *modelName,
"input": texts,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
}
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body))
}
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
dataObj, ok := result["data"].([]interface{})
if !ok || len(dataObj) == 0 {
return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body))
}
embeddings := make([][]float64, len(texts))
for _, item := range dataObj {
dataMap, ok := item.(map[string]interface{})
if !ok {
continue
}
indexFloat, ok := dataMap["index"].(float64)
if !ok {
continue
}
index := int(indexFloat)
if index < 0 || index >= len(texts) {
continue
}
embeddingSlice, ok := dataMap["embedding"].([]interface{})
if !ok {
continue
}
embedding := make([]float64, len(embeddingSlice))
for j, v := range embeddingSlice {
switch val := v.(type) {
case float64:
embedding[j] = val
case float32:
embedding[j] = float64(val)
default:
return nil, fmt.Errorf("unexpected embedding value type")
}
}
embeddings[index] = embedding
}
return embeddings, nil
}
// OpenRouterRerankRequest OpenRouter official rerank request format
type OpenRouterRerankRequest struct {
Model string `json:"model"`
Query string `json:"query"`
Documents []string `json:"documents"`
TopN int `json:"top_n,omitempty"`
}
// OpenRouterRerankResponse OpenRouter official rerank response format
type OpenRouterRerankResponse struct {
Model string `json:"model"`
ID string `json:"id"`
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Document *struct {
Text string `json:"text"`
} `json:"document,omitempty"`
} `json:"results"`
}
func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
@@ -366,19 +480,11 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string
region = *apiConfig.Region
}
apiKey := ""
if apiConfig != nil && apiConfig.ApiKey != nil {
apiKey = *apiConfig.ApiKey
}
reqBody := SiliconflowRerankRequest{
Model: *modelName,
Query: query,
Documents: texts,
TopN: len(texts),
ReturnDocuments: false,
MaxChunksPerDoc: 1024,
OverlapTokens: 80,
reqBody := OpenRouterRerankRequest{
Model: *modelName,
Query: query,
Documents: texts,
TopN: len(texts),
}
jsonData, err := json.Marshal(reqBody)
@@ -388,15 +494,13 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(o.BaseURL[region], "/"), o.URLSuffix.Rerank)
req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData)))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -404,21 +508,25 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body))
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OpenRouter Rerank API error: %s, body: %s", resp.Status, string(body))
}
var rerankResp SiliconflowRerankResponse
if err := json.Unmarshal(body, &rerankResp); err != nil {
var rerankResp OpenRouterRerankResponse
if err = json.Unmarshal(body, &rerankResp); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
scores := make([]float64, len(texts))
for _, result := range rerankResp.Results {
if result.Index >= 0 && result.Index < len(texts) {
if result.Index >= 0 &&
result.Index < len(texts) {
scores[result.Index] = result.RelevanceScore
}
}
@@ -483,11 +591,58 @@ func (o *OpenRouterModel) ListModels(apiConfig *APIConfig) ([]string, error) {
}
func (o *OpenRouterModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
//TODO implement me
panic("implement me")
region := "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Balance)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
var result struct {
Data struct {
TotalCredits float64 `json:"total_credits"`
TotalUsage float64 `json:"total_usage"`
} `json:"data"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse balance response: %w", err)
}
remainingBalance := result.Data.TotalCredits - result.Data.TotalUsage
return map[string]interface{}{
"total_credits": result.Data.TotalCredits,
"total_usage": result.Data.TotalUsage,
"balance": remainingBalance,
"currency": "USD",
}, nil
}
func (o *OpenRouterModel) CheckConnection(apiConfig *APIConfig) error {
//TODO implement me
panic("implement me")
_, err := o.Balance(apiConfig)
return err
}

View File

@@ -60,156 +60,6 @@ func (z *VolcEngine) Name() string {
return "volcengine"
}
// Chat sends a message and returns response
func (z *VolcEngine) Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) {
if message == nil {
return nil, fmt.Errorf("message is nil")
}
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat)
//Build request body
reqBody := map[string]interface{}{
"model": modelName,
"messages": []map[string]string{
{"role": "user", "content": *message},
},
"stream": false,
"temperature": 1,
}
if modelConfig.Stream != nil {
reqBody["stream"] = *modelConfig.Stream
}
if modelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *modelConfig.MaxTokens
}
if modelConfig.Temperature != nil {
reqBody["temperature"] = *modelConfig.Temperature
}
if modelConfig.TopP != nil {
reqBody["top_p"] = *modelConfig.TopP
}
// TODO VolcEngine has `auto` mode
if modelConfig.Thinking != nil {
if *modelConfig.Thinking {
var thinkingFlag string
switch *modelConfig.Effort {
case "none", "minimal":
thinkingFlag = "disabled"
reqBody["reasoning_effort"] = "minimal"
break
case "low":
thinkingFlag = "enabled"
reqBody["reasoning_effort"] = "low"
break
case "medium":
thinkingFlag = "enabled"
reqBody["reasoning_effort"] = "medium"
break
case "auto", "default":
thinkingFlag = "enabled"
reqBody["reasoning_effort"] = "medium"
break
case "high":
thinkingFlag = "enabled"
reqBody["reasoning_effort"] = "high"
break
default:
return nil, fmt.Errorf("invalid effort level")
}
reqBody["thinking"] = map[string]interface{}{
"type": thinkingFlag,
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return nil, fmt.Errorf("no choices in responses")
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid choice format")
}
messageMap, ok := firstChoice["message"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid message format")
}
content, ok := messageMap["content"].(string)
if !ok {
return nil, fmt.Errorf("invalid content format")
}
var reasonContent string
if modelConfig.Thinking != nil && *modelConfig.Thinking {
reasonContent, ok = messageMap["reasoning_content"].(string)
if !ok {
return nil, fmt.Errorf("invalid reasonContent format")
}
// if first char of reasonContent is \n remove the \n
if reasonContent != "" && reasonContent[0] == '\n' {
reasonContent = reasonContent[1:]
}
}
chatResponse := &ChatResponse{
Answer: &content,
ReasonContent: &reasonContent,
}
return chatResponse, nil
}
// ChatWithMessages sends multiple messages with roles and returns response
func (z *VolcEngine) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
if len(messages) == 0 {