mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 16:25:44 +08:00
Go: implement embed, rerank, tts for AstraFlow (#15135)
### What problem does this PR solve?
implement embed, rerank, tts for AstraFlow
**Verify from CLI**
```
# Astraflow
RAGFlow(user)> tts with 'IndexTeam/IndexTTS-2@test3@astraflow' text 'hello? show yourself' play format 'wav' param '{"voice": "jack_cheng"}'
SUCCESS
RAGFlow(user)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'bge-reranker-v2-m3@test3@astraflow' top 3;
+-------+---------------------+
| index | relevance_score |
+-------+---------------------+
| 0 | 0.9837390184402466 |
| 2 | 0.06322699040174484 |
| 1 | 0.04663187265396118 |
+-------+---------------------+
RAGFlow(user)> embed text 'walkerwhat' 'jumperwho' with 'text-embedding-3-large@test3@astraflow' dimension 16
+-----------+-------+
| dimension | index |
+-----------+-------+
| 3072 | 0 |
| 3072 | 1 |
+-----------+-------+
# Xinference
```
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
@@ -1,103 +1,163 @@
|
||||
{
|
||||
"name": "Astraflow",
|
||||
"url": {
|
||||
"default": "https://api-us-ca.umodelverse.ai/v1"
|
||||
"default": "https://api.modelverse.cn/v1",
|
||||
"us-ca": "https://api-us-ca.umodelverse.ai/v1"
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models"
|
||||
"models": "models",
|
||||
"embedding": "embeddings",
|
||||
"rerank": "rerank",
|
||||
"tts": "audio/speech"
|
||||
},
|
||||
"class": "astraflow",
|
||||
"models": [
|
||||
{
|
||||
"name": "text-embedding-3-large",
|
||||
"max_tokens": 16384,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "bge-reranker-v2-m3",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"rerank"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "IndexTeam/IndexTTS-2",
|
||||
"model_types": [
|
||||
"tts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-7",
|
||||
"max_tokens": 200000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-6",
|
||||
"max_tokens": 200000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-5-20250929",
|
||||
"max_tokens": 200000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "claude-haiku-4-5-20251001",
|
||||
"max_tokens": 200000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gpt-5.4",
|
||||
"max_tokens": 400000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gpt-5.4-mini",
|
||||
"max_tokens": 400000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gpt-5.4-nano",
|
||||
"max_tokens": 400000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o-mini",
|
||||
"max_tokens": 128000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen3-Max",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen3-Coder",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen3-32B",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen3-VL-235B-A22B-Instruct",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "kimi-k2.6",
|
||||
"max_tokens": 200000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "glm-5.1",
|
||||
"max_tokens": 128000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MiniMax-M2.7",
|
||||
"max_tokens": 1000000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MiniMax-M2",
|
||||
"max_tokens": 1000000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.5-pro",
|
||||
"max_tokens": 1000000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.5-flash",
|
||||
"max_tokens": 1000000,
|
||||
"model_types": ["chat"]
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -452,11 +452,149 @@ func (a *AstraflowModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
// chat, mirroring how Novita / TogetherAI / DeepInfra landed
|
||||
// method-by-method.
|
||||
func (a *AstraflowModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", a.Name())
|
||||
if len(texts) == 0 {
|
||||
return []EmbeddingData{}, nil
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Astraflow embedding API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var parsedResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(body, &parsedResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(parsedResponse.Data) == 0 {
|
||||
return nil, fmt.Errorf("Astraflow embedding response contains no data: %s", string(body))
|
||||
}
|
||||
|
||||
var embeddings []EmbeddingData
|
||||
for _, dataElem := range parsedResponse.Data {
|
||||
embeddings = append(embeddings, EmbeddingData{
|
||||
Embedding: dataElem.Embedding,
|
||||
Index: dataElem.Index,
|
||||
})
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (a *AstraflowModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", a.Name())
|
||||
if len(documents) == 0 {
|
||||
return &RerankResponse{}, nil
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Rerank)
|
||||
|
||||
var topN = rerankConfig.TopN
|
||||
if rerankConfig.TopN != 0 {
|
||||
topN = rerankConfig.TopN
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Astraflow Rerank API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var rerankResp struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(body, &rerankResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
var rerankResponse RerankResponse
|
||||
for _, result := range rerankResp.Results {
|
||||
rerankResult := RerankResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
rerankResponse.Data = append(rerankResponse.Data, rerankResult)
|
||||
}
|
||||
|
||||
return &rerankResponse, nil
|
||||
}
|
||||
|
||||
func (a *AstraflowModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
@@ -472,7 +610,64 @@ func (a *AstraflowModel) TranscribeAudioWithSender(modelName *string, file *stri
|
||||
}
|
||||
|
||||
func (a *AstraflowModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", a.Name())
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("Astraflow API key is missing")
|
||||
}
|
||||
|
||||
if audioContent == nil || *audioContent == "" {
|
||||
return nil, fmt.Errorf("text content is missing")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.TTS)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": *audioContent,
|
||||
}
|
||||
|
||||
if ttsConfig != nil && ttsConfig.Params != nil {
|
||||
for key, value := range ttsConfig.Params {
|
||||
reqBody[key] = value
|
||||
}
|
||||
}
|
||||
if ttsConfig != nil && ttsConfig.Format != "" {
|
||||
reqBody["format"] = ttsConfig.Format
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
return &TTSResponse{Audio: body}, nil
|
||||
}
|
||||
|
||||
func (a *AstraflowModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
||||
|
||||
Reference in New Issue
Block a user