From 3f02ca7ba10adf31eabfff87ba56746a202e1c17 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Fri, 22 May 2026 18:02:01 +0800 Subject: [PATCH] Go: implement embed, rerank, tts for AstraFlow (#15135) ### What problem does this PR solve? implement embed, rerank, tts for AstraFlow **Verify from CLI** ``` # Astraflow RAGFlow(user)> tts with 'IndexTeam/IndexTTS-2@test3@astraflow' text 'hello? show yourself' play format 'wav' param '{"voice": "jack_cheng"}' SUCCESS RAGFlow(user)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'bge-reranker-v2-m3@test3@astraflow' top 3; +-------+---------------------+ | index | relevance_score | +-------+---------------------+ | 0 | 0.9837390184402466 | | 2 | 0.06322699040174484 | | 1 | 0.04663187265396118 | +-------+---------------------+ RAGFlow(user)> embed text 'walkerwhat' 'jumperwho' with 'text-embedding-3-large@test3@astraflow' dimension 16 +-----------+-------+ | dimension | index | +-----------+-------+ | 3072 | 0 | | 3072 | 1 | +-----------+-------+ # Xinference ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- conf/models/astraflow.json | 102 +++++++++++--- internal/entity/models/astraflow.go | 201 +++++++++++++++++++++++++++- 2 files changed, 279 insertions(+), 24 deletions(-) diff --git a/conf/models/astraflow.json b/conf/models/astraflow.json index 746c92c073..27119a4fde 100644 --- a/conf/models/astraflow.json +++ b/conf/models/astraflow.json @@ -1,103 +1,163 @@ { "name": "Astraflow", "url": { - "default": "https://api-us-ca.umodelverse.ai/v1" + "default": "https://api.modelverse.cn/v1", + "us-ca": "https://api-us-ca.umodelverse.ai/v1" }, "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings", + "rerank": "rerank", + "tts": "audio/speech" }, "class": "astraflow", "models": [ + { + "name": "text-embedding-3-large", + "max_tokens": 16384, + "model_types": [ + "embedding" + ] + }, + { + "name": "bge-reranker-v2-m3", + "max_tokens": 8192, + "model_types": [ + "rerank" + ] + }, + { + "name": "IndexTeam/IndexTTS-2", + "model_types": [ + "tts" + ] + }, { "name": "claude-opus-4-7", "max_tokens": 200000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "claude-opus-4-6", "max_tokens": 200000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "claude-sonnet-4-5-20250929", "max_tokens": 200000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "claude-haiku-4-5-20251001", "max_tokens": 200000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "gpt-5.4", "max_tokens": 400000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "gpt-5.4-mini", "max_tokens": 400000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "gpt-5.4-nano", "max_tokens": 400000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "gpt-4o-mini", "max_tokens": 128000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "Qwen/Qwen3-Max", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "Qwen/Qwen3-Coder", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "Qwen/Qwen3-32B", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "Qwen/Qwen3-VL-235B-A22B-Instruct", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "kimi-k2.6", "max_tokens": 200000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "glm-5.1", "max_tokens": 128000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "MiniMax-M2.7", "max_tokens": 1000000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "MiniMax-M2", "max_tokens": 1000000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "gemini-2.5-pro", "max_tokens": 1000000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "gemini-2.5-flash", "max_tokens": 1000000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] } ] -} +} \ No newline at end of file diff --git a/internal/entity/models/astraflow.go b/internal/entity/models/astraflow.go index 4d5479db4a..3c2b5873bc 100644 --- a/internal/entity/models/astraflow.go +++ b/internal/entity/models/astraflow.go @@ -452,11 +452,149 @@ func (a *AstraflowModel) CheckConnection(apiConfig *APIConfig) error { // chat, mirroring how Novita / TogetherAI / DeepInfra landed // method-by-method. func (a *AstraflowModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { - return nil, fmt.Errorf("%s, no such method", a.Name()) + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Astraflow embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsedResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + + if err = json.Unmarshal(body, &parsedResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(parsedResponse.Data) == 0 { + return nil, fmt.Errorf("Astraflow embedding response contains no data: %s", string(body)) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsedResponse.Data { + embeddings = append(embeddings, EmbeddingData{ + Embedding: dataElem.Embedding, + Index: dataElem.Index, + }) + } + + return embeddings, nil } func (a *AstraflowModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, no such method", a.Name()) + if len(documents) == 0 { + return &RerankResponse{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Rerank) + + var topN = rerankConfig.TopN + if rerankConfig.TopN != 0 { + topN = rerankConfig.TopN + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Astraflow Rerank API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var rerankResp struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` + } + + if err = json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + var rerankResponse RerankResponse + for _, result := range rerankResp.Results { + rerankResult := RerankResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + } + rerankResponse.Data = append(rerankResponse.Data, rerankResult) + } + + return &rerankResponse, nil } func (a *AstraflowModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { @@ -472,7 +610,64 @@ func (a *AstraflowModel) TranscribeAudioWithSender(modelName *string, file *stri } func (a *AstraflowModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { - return nil, fmt.Errorf("%s, no such method", a.Name()) + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("Astraflow API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil } func (a *AstraflowModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {