mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 01:01:56 +08:00
Go: implement remaining interface for OpenRouter (#14657)
### What problem does this PR solve?

1. Implement the `rerank`, `embedding`, `balance`, and `checkConnection` methods for `OpenRouter`.
2. Delete the `chat` method in `internal/entity/models/volcengine.go`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
@@ -5,7 +5,10 @@
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models"
|
||||
"models": "models",
|
||||
"embedding": "embeddings",
|
||||
"rerank": "rerank",
|
||||
"balance": "credits"
|
||||
},
|
||||
"class": "openrouter",
|
||||
"models": [
|
||||
|
||||
@@ -352,8 +352,122 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me
|
||||
}
|
||||
|
||||
func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
dataObj, ok := result["data"].([]interface{})
|
||||
if !ok || len(dataObj) == 0 {
|
||||
return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body))
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
for _, item := range dataObj {
|
||||
dataMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
indexFloat, ok := dataMap["index"].(float64)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
index := int(indexFloat)
|
||||
|
||||
if index < 0 || index >= len(texts) {
|
||||
continue
|
||||
}
|
||||
|
||||
embeddingSlice, ok := dataMap["embedding"].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddingSlice))
|
||||
for j, v := range embeddingSlice {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
embedding[j] = val
|
||||
case float32:
|
||||
embedding[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type")
|
||||
}
|
||||
}
|
||||
|
||||
embeddings[index] = embedding
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// OpenRouterRerankRequest is the JSON body sent to the OpenRouter rerank
// endpoint.
type OpenRouterRerankRequest struct {
	// Model is the rerank model identifier.
	Model string `json:"model"`
	// Query is the text the documents are ranked against.
	Query string `json:"query"`
	// Documents are the candidate texts to score.
	Documents []string `json:"documents"`
	// TopN is omitted from the JSON when zero.
	TopN int `json:"top_n,omitempty"`
}
|
||||
|
||||
// OpenRouterRerankResponse is the JSON body decoded from the OpenRouter rerank
// endpoint.
type OpenRouterRerankResponse struct {
	Model string `json:"model"`
	ID    string `json:"id"`
	// Results holds one entry per scored document.
	Results []struct {
		// Index is the position of the document in the request's documents list.
		Index int `json:"index"`
		// RelevanceScore is the score reported for that document.
		RelevanceScore float64 `json:"relevance_score"`
		// Document is optional in the response; it stays nil when absent.
		Document *struct {
			Text string `json:"text"`
		} `json:"document,omitempty"`
	} `json:"results"`
}
|
||||
|
||||
func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
@@ -366,19 +480,11 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
apiKey := ""
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil {
|
||||
apiKey = *apiConfig.ApiKey
|
||||
}
|
||||
|
||||
reqBody := SiliconflowRerankRequest{
|
||||
Model: *modelName,
|
||||
Query: query,
|
||||
Documents: texts,
|
||||
TopN: len(texts),
|
||||
ReturnDocuments: false,
|
||||
MaxChunksPerDoc: 1024,
|
||||
OverlapTokens: 80,
|
||||
reqBody := OpenRouterRerankRequest{
|
||||
Model: *modelName,
|
||||
Query: query,
|
||||
Documents: texts,
|
||||
TopN: len(texts),
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
@@ -388,15 +494,13 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(o.BaseURL[region], "/"), o.URLSuffix.Rerank)
|
||||
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData)))
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -404,21 +508,25 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, texts []string
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("OpenRouter Rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var rerankResp SiliconflowRerankResponse
|
||||
if err := json.Unmarshal(body, &rerankResp); err != nil {
|
||||
var rerankResp OpenRouterRerankResponse
|
||||
if err = json.Unmarshal(body, &rerankResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
scores := make([]float64, len(texts))
|
||||
|
||||
for _, result := range rerankResp.Results {
|
||||
if result.Index >= 0 && result.Index < len(texts) {
|
||||
if result.Index >= 0 &&
|
||||
result.Index < len(texts) {
|
||||
scores[result.Index] = result.RelevanceScore
|
||||
}
|
||||
}
|
||||
@@ -483,11 +591,58 @@ func (o *OpenRouterModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
}
|
||||
|
||||
func (o *OpenRouterModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Balance)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data struct {
|
||||
TotalCredits float64 `json:"total_credits"`
|
||||
TotalUsage float64 `json:"total_usage"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse balance response: %w", err)
|
||||
}
|
||||
|
||||
remainingBalance := result.Data.TotalCredits - result.Data.TotalUsage
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_credits": result.Data.TotalCredits,
|
||||
"total_usage": result.Data.TotalUsage,
|
||||
"balance": remainingBalance,
|
||||
"currency": "USD",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OpenRouterModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
_, err := o.Balance(apiConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -60,156 +60,6 @@ func (z *VolcEngine) Name() string {
|
||||
return "volcengine"
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *VolcEngine) Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if message == nil {
|
||||
return nil, fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat)
|
||||
|
||||
//Build request body
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": false,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if modelConfig.Stream != nil {
|
||||
reqBody["stream"] = *modelConfig.Stream
|
||||
}
|
||||
|
||||
if modelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *modelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if modelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *modelConfig.Temperature
|
||||
}
|
||||
|
||||
if modelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *modelConfig.TopP
|
||||
}
|
||||
// TODO VolcEngine has `auto` mode
|
||||
if modelConfig.Thinking != nil {
|
||||
if *modelConfig.Thinking {
|
||||
var thinkingFlag string
|
||||
switch *modelConfig.Effort {
|
||||
case "none", "minimal":
|
||||
thinkingFlag = "disabled"
|
||||
reqBody["reasoning_effort"] = "minimal"
|
||||
break
|
||||
case "low":
|
||||
thinkingFlag = "enabled"
|
||||
reqBody["reasoning_effort"] = "low"
|
||||
break
|
||||
case "medium":
|
||||
thinkingFlag = "enabled"
|
||||
reqBody["reasoning_effort"] = "medium"
|
||||
break
|
||||
case "auto", "default":
|
||||
thinkingFlag = "enabled"
|
||||
reqBody["reasoning_effort"] = "medium"
|
||||
break
|
||||
case "high":
|
||||
thinkingFlag = "enabled"
|
||||
reqBody["reasoning_effort"] = "high"
|
||||
break
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid effort level")
|
||||
}
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": thinkingFlag,
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
choices, ok := result["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in responses")
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid choice format")
|
||||
}
|
||||
|
||||
messageMap, ok := firstChoice["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
content, ok := messageMap["content"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content format")
|
||||
}
|
||||
|
||||
var reasonContent string
|
||||
if modelConfig.Thinking != nil && *modelConfig.Thinking {
|
||||
reasonContent, ok = messageMap["reasoning_content"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid reasonContent format")
|
||||
}
|
||||
// if first char of reasonContent is \n remove the \n
|
||||
if reasonContent != "" && reasonContent[0] == '\n' {
|
||||
reasonContent = reasonContent[1:]
|
||||
}
|
||||
}
|
||||
|
||||
chatResponse := &ChatResponse{
|
||||
Answer: &content,
|
||||
ReasonContent: &reasonContent,
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
// ChatWithMessages sends multiple messages with roles and returns response
|
||||
func (z *VolcEngine) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if len(messages) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user