From a725e114f9c4d0bacb8dec7c0267728cf4f8e24d Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Thu, 21 May 2026 18:28:06 +0800 Subject: [PATCH] Go: implement ASR and TTS for Xinference (#15096) ### What problem does this PR solve? implement ASR and TTS for Xinference ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- conf/models/novita.json | 1 + internal/entity/models/fishaudio.go | 7 +- internal/entity/models/moonshot.go | 61 +++---- internal/entity/models/novita.go | 68 +++++++- internal/entity/models/xinference.go | 162 +++++++++++++++++- .../test_dify_retrieval_routes_unit.py | 6 +- .../api/apps/sdk/test_dify_retrieval.py | 5 +- 7 files changed, 266 insertions(+), 44 deletions(-) diff --git a/conf/models/novita.json b/conf/models/novita.json index 6dad88c2fa..dfc11e0382 100644 --- a/conf/models/novita.json +++ b/conf/models/novita.json @@ -7,6 +7,7 @@ "chat": "openai/v1/chat/completions", "models": "openai/v1/models", "embedding": "openai/v1/embeddings", + "balance": "openapi/v1/billing/balance/detail", "rerank": "openai/v1/rerank" }, "class": "novita", diff --git a/internal/entity/models/fishaudio.go b/internal/entity/models/fishaudio.go index 70f6707621..5e2696f51f 100644 --- a/internal/entity/models/fishaudio.go +++ b/internal/entity/models/fishaudio.go @@ -66,7 +66,6 @@ func (f *FishAudioModel) Rerank(modelName *string, query string, documents []str // TranscribeAudio transcribe audio func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { - if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { return nil, fmt.Errorf("FishAudio API key is missing") } @@ -151,11 +150,7 @@ func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiCon } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf( - "FishAudio ASR error: %s - %s", - resp.Status, - string(respBody), - ) + return nil, fmt.Errorf("FishAudio ASR error: %s - %s", resp.Status, string(respBody)) } // result diff --git a/internal/entity/models/moonshot.go b/internal/entity/models/moonshot.go index fa1ad76ec4..6114bd2546 100644 --- a/internal/entity/models/moonshot.go +++ b/internal/entity/models/moonshot.go @@ -210,7 +210,7 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", k.BaseURL[region]) + url := fmt.Sprintf("%s/%s", k.BaseURL[region], k.URLSuffix.Chat) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -228,38 +228,40 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess "stream": true, } - if chatModelConfig.Stream != nil { - reqBody["stream"] = *chatModelConfig.Stream - } + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } - if chatModelConfig.MaxTokens != nil { - reqBody["max_tokens"] = *chatModelConfig.MaxTokens - } + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } - if chatModelConfig.Temperature != nil { - reqBody["temperature"] = *chatModelConfig.Temperature - } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } - if chatModelConfig.DoSample != nil { - reqBody["do_sample"] = *chatModelConfig.DoSample - } + if chatModelConfig.DoSample != nil { + reqBody["do_sample"] = *chatModelConfig.DoSample + } - if chatModelConfig.TopP != nil { - reqBody["top_p"] = *chatModelConfig.TopP - } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } - if chatModelConfig.Stop != nil { - reqBody["stop"] = *chatModelConfig.Stop - } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } - if chatModelConfig.Thinking != nil { - if *chatModelConfig.Thinking { - reqBody["thinking"] = map[string]interface{}{ - "type": "enabled", - } - } else { - reqBody["thinking"] = map[string]interface{}{ - "type": "disabled", + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } } } } @@ -364,7 +366,7 @@ func (z *MoonshotModel) Embed(modelName *string, texts []string, apiConfig *APIC func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } @@ -419,9 +421,8 @@ func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) { } func (z *MoonshotModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { - var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } diff --git a/internal/entity/models/novita.go b/internal/entity/models/novita.go index 7335dbff68..980a61949b 100644 --- a/internal/entity/models/novita.go +++ b/internal/entity/models/novita.go @@ -24,6 +24,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "time" ) @@ -841,9 +842,72 @@ func (n *NovitaModel) Rerank(modelName *string, query string, documents []string return &rerankResponse, nil } -// Balance is not exposed by the Novita API. +// Balance Get remaining credit func (n *NovitaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { - return nil, fmt.Errorf("%s, no such method", n.Name()) + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", n.BaseURL[region], n.URLSuffix.Balance) + + // Build request body + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + balanceInterface, exists := result["availableBalance"] + if !exists || balanceInterface == nil { + return nil, fmt.Errorf("missing 'availableBalance' in response. Raw body: %s", string(body)) + } + + balanceStr, ok := balanceInterface.(string) + if !ok { + return nil, fmt.Errorf("'availableBalance' is not a string. Raw body: %s", string(body)) + } + balance, err := strconv.ParseFloat(balanceStr, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse 'availableBalance' as float: %w. Raw body: %s", err, string(body)) + } + + var response = map[string]interface{}{ + "balance": balance, + "currency": "USD", + } + + return response, nil } func (n *NovitaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { diff --git a/internal/entity/models/xinference.go b/internal/entity/models/xinference.go index 971948cfc4..52307da44d 100644 --- a/internal/entity/models/xinference.go +++ b/internal/entity/models/xinference.go @@ -12,7 +12,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// package models @@ -23,7 +22,11 @@ import ( "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "os" + "path/filepath" + "strconv" "strings" "sync" "time" @@ -589,15 +592,166 @@ func (x *XinferenceModel) Rerank(modelName *string, query string, documents []st } func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { - return nil, fmt.Errorf("%s, no such method", x.Name()) + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.ASR) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + + if err = writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model name: %w", err) + } + + // extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err := writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + // request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("FishAudio ASR error: %s - %s", resp.Status, string(respBody)) + } + + // result + var result struct { + Text string `json:"text"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &ASRResponse{ + Text: result.Text, + }, nil } func (x *XinferenceModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { return fmt.Errorf("%s, no such method", x.Name()) } -func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { - return nil, fmt.Errorf("%s, no such method", x.Name()) +func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil } func (x *XinferenceModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py index 01b23e6107..b348503030 100644 --- a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -290,7 +290,11 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch): } monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever()) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"}))) + monkeypatch.setattr( + module.DocumentService, + "get_by_ids", + lambda doc_ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={"origin": f"meta-{doc_id}"}) for doc_id in doc_ids], + ) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) res = _run(inspect.unwrap(module.retrieval)("tenant-1")) diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py index 113ff139f0..a74da5a649 100644 --- a/test/unit_test/api/apps/sdk/test_dify_retrieval.py +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -97,7 +97,10 @@ def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=No _stub( monkeypatch, "api.db.services.document_service", - DocumentService=SimpleNamespace(get_by_id=lambda _id: (True, SimpleNamespace(meta_fields={}))), + DocumentService=SimpleNamespace( + get_by_id=lambda _id: (True, SimpleNamespace(id=_id, meta_fields={})), + get_by_ids=lambda ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={}) for doc_id in ids], + ), ) _stub( monkeypatch,