diff --git a/internal/cli/client.go b/internal/cli/client.go index 2bd50cb695..0523b36c05 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -267,6 +267,12 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.EmbedUserText(cmd) case "rarank_user_document": return c.RerankUserDocument(cmd) + case "tts_user_command": + return c.TTSUserCommand(cmd) + case "asr_user_command": + return c.ASRUserCommand(cmd) + case "ocr_user_command": + return c.OCRUserCommand(cmd) case "check_provider_connection": return c.CheckProviderConnection(cmd) case "use_model": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index e373c5a874..0bba27847b 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -201,6 +201,12 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseEmbedCommand() case TokenRerank: return p.parseRerankCommand() + case TokenASR: + return p.parseASRCommand() + case TokenTTS: + return p.parseTTSCommand() + case TokenOCR: + return p.parseOCRCommand() case TokenCheck: return p.parseCheckCommand() case TokenLS: diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 14a058aa25..abc06c443d 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -27,6 +27,7 @@ import ( "net" netUrl "net/url" "os" + "path/filepath" ce "ragflow/internal/cli/filesystem" "strings" "time" @@ -1622,16 +1623,35 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { } } - //audios, ok := cmd.Params["audios"].([]string) - //if !ok { - // return nil, fmt.Errorf("images not provided") - //} + audios, ok := cmd.Params["audios"].([]string) + if !ok { + return nil, fmt.Errorf("images not provided") + } + if len(audios) > 0 { + if len(audios) != 1 { + return nil, fmt.Errorf("only one audio file is supported") + } + audioFile := audios[0] + audioContent, err := os.ReadFile(audioFile) + if err != nil { + return nil, fmt.Errorf("failed to read 
audio: %w", err) + } + // file type: wav or mp3 + format := filepath.Ext(audioFile) // file type: wav or mp3 + format = strings.TrimPrefix(format, ".") + contents = append(contents, map[string]interface{}{ + "type": "input_audio", + "input_audio": map[string]interface{}{ + "data": base64.StdEncoding.EncodeToString(audioContent), + "format": format, + }, + }) + } files, ok := cmd.Params["files"].([]string) if !ok { return nil, fmt.Errorf("images not provided") } - if len(files) > 0 { for _, file := range files { if isValidURL(file) { @@ -1660,21 +1680,6 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { url := "/chat/completions" - //message = strings.TrimSpace(message) - //var content interface{} = message - //if strings.HasPrefix(message, "[") && strings.HasSuffix(message, "]") { - // var parts []map[string]interface{} - // if err := json.Unmarshal([]byte(message), &parts); err == nil { - // content = parts - // } - //} - //formattedMessage := []map[string]interface{}{ - // { - // "role": "user", - // "content": content, - // }, - //} - payload := map[string]interface{}{ "provider_name": providerName, "instance_name": instanceName, @@ -1922,6 +1927,210 @@ func (c *RAGFlowClient) RerankUserDocument(cmd *Command) (ResponseIf, error) { return &result, nil } +func (c *RAGFlowClient) TTSUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. 
Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + } + + text, ok := cmd.Params["text"].(string) + if !ok { + return nil, fmt.Errorf("text not provided") + } + + //fileToSave, ok := cmd.Params["file"].(string) + //if !ok { + // return nil, fmt.Errorf("file not provided") + //} + + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + "text": text, + } + + url := "/audio/speech" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to TTS document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to TTS document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("TTS document failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + // save file + //err = os.WriteFile(fileToSave, resp.Body, 0644) + //if err != nil { + // result.Message += fmt.Sprintf("failed to 
save file: %s", err.Error()) + // result.Code = 1 + //} + + return &result, nil +} + +func (c *RAGFlowClient) ASRUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. 
Use 'use model' command first") + } + + audioFile, ok := cmd.Params["audio_file"].(string) + if !ok { + return nil, fmt.Errorf("text not provided") + } + + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + "audio_file": audioFile, + } + + url := "/audio/transcriptions" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to ASR document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to ASR document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("ASR document failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + +func (c *RAGFlowClient) OCRUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. 
Use 'use model' command first") + } + + filename, ok := cmd.Params["file"].(string) + if !ok { + return nil, fmt.Errorf("text not provided") + } + + // read file and convert to base64 + text, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + base64Text := base64.StdEncoding.EncodeToString(text) + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + "content": base64Text, + } + + url := "/file/ocr" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to OCR document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to OCR document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("OCR document failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + func (c *RAGFlowClient) CheckProviderConnection(cmd *Command) (ResponseIf, error) { if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { return nil, fmt.Errorf("API token not set. 
Please login first") diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index c49eeee11a..5c98b52f42 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -2587,16 +2587,29 @@ func (p *Parser) parseStreamCommand() (*Command, error) { var command *Command var err error - if p.curToken.Type == TokenChat { + switch p.curToken.Type { + case TokenChat: command, err = p.parseChatCommand() if err != nil { return nil, err } - } else if p.curToken.Type == TokenThink { + case TokenThink: command, err = p.parseThinkCommand() if err != nil { return nil, err } + case TokenASR: + command, err = p.parseASRCommand() + if err != nil { + return nil, err + } + case TokenTTS: + command, err = p.parseTTSCommand() + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("expected CHAT, THINK, ASR, or TTS after STREAM") } command.Params["stream"] = true @@ -2723,6 +2736,109 @@ documentLoop: return cmd, nil } +func (p *Parser) parseASRCommand() (*Command, error) { + p.nextToken() // consume ASR + + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expected WITH after ASR") + } + p.nextToken() // consume WITH + + compositeModelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + if p.curToken.Type != TokenAudio { + return nil, fmt.Errorf("expected AUDIO to ASR") + } + p.nextToken() // consume FILE + + audioFile, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + cmd := NewCommand("asr_user_command") + cmd.Params["composite_model_name"] = compositeModelName + cmd.Params["audio_file"] = audioFile + return cmd, nil +} + +func (p *Parser) parseTTSCommand() (*Command, error) { + p.nextToken() // consume TTS + + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expected WITH after TTS") + } + p.nextToken() // consume WITH + + 
compositeModelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + if p.curToken.Type != TokenText { + return nil, fmt.Errorf("expected TEXT to TTS") + } + p.nextToken() // consume FILE + + text, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + cmd := NewCommand("tts_user_command") + cmd.Params["composite_model_name"] = compositeModelName + cmd.Params["text"] = text + return cmd, nil +} + +func (p *Parser) parseOCRCommand() (*Command, error) { + p.nextToken() // consume OCR + + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expected WITH after OCR") + } + p.nextToken() // consume WITH + + compositeModelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + if p.curToken.Type != TokenFile { + return nil, fmt.Errorf("expected FILE to OCR") + } + p.nextToken() // consume FILE + + file, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + cmd := NewCommand("ocr_user_command") + cmd.Params["composite_model_name"] = compositeModelName + cmd.Params["file"] = file + return cmd, nil +} + func (p *Parser) parseCheckCommand() (*Command, error) { p.nextToken() // consume CHECK diff --git a/internal/entity/models/aliyun.go b/internal/entity/models/aliyun.go index 325eb0ac6d..e010bfecdc 100644 --- a/internal/entity/models/aliyun.go +++ b/internal/entity/models/aliyun.go @@ -36,6 +36,11 @@ type AliyunModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (z *AliyunModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewAliyunModel creates a new Aliyun model instance func NewAliyunModel(baseURL map[string]string, urlSuffix URLSuffix) *AliyunModel { return &AliyunModel{ @@ -555,6 +560,29 @@ func (z *AliyunModel) Rerank(modelName *string, query string, 
documents []string return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (z *AliyunModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AliyunModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *AliyunModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AliyunModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *AliyunModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + type AliyunModelItem struct { ModelName string `json:"model_name"` BaseCapacity int `json:"base_capacity"` diff --git a/internal/entity/models/baichuan.go b/internal/entity/models/baichuan.go index 5a8282164a..1b0cf78a9f 100644 --- a/internal/entity/models/baichuan.go +++ b/internal/entity/models/baichuan.go @@ -380,6 +380,29 @@ func (b *BaichuanModel) Rerank(modelName *string, query string, documents []stri return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (z *BaichuanModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *BaichuanModel) TranscribeAudioWithSender(modelName *string, file *string, 
apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *BaichuanModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *BaichuanModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *BaichuanModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + func (b *BaichuanModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("no such method") } diff --git a/internal/entity/models/baidu.go b/internal/entity/models/baidu.go index 15fb4f4284..7e81995a70 100644 --- a/internal/entity/models/baidu.go +++ b/internal/entity/models/baidu.go @@ -18,6 +18,11 @@ type BaiduModel struct { httpClient *http.Client } +func (b *BaiduModel) ParseFile() { + //TODO implement me + panic("implement me") +} + func (b *BaiduModel) NewInstance(baseURL map[string]string) ModelDriver { return &BaiduModel{ BaseURL: baseURL, @@ -568,6 +573,29 @@ func (b *BaiduModel) Rerank(modelName *string, query string, documents []string, return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (b *BaiduModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", b.Name()) +} + +func (z *BaiduModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no 
such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (b *BaiduModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", b.Name()) +} + +func (z *BaiduModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (b *BaiduModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", b.Name()) +} + func (b *BaiduModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { diff --git a/internal/entity/models/cohere.go b/internal/entity/models/cohere.go index f327400676..61dc60551b 100644 --- a/internal/entity/models/cohere.go +++ b/internal/entity/models/cohere.go @@ -17,6 +17,11 @@ type CoHereModel struct { httpClient *http.Client } +func (c *CoHereModel) ParseFile() { + //TODO implement me + panic("implement me") +} + func (c *CoHereModel) NewInstance(baseURL map[string]string) ModelDriver { return &CoHereModel{ BaseURL: baseURL, @@ -480,6 +485,29 @@ func (c *CoHereModel) Rerank(modelName *string, query string, documents []string return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (c *CoHereModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", c.Name()) +} + +func (z *CoHereModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to 
text +func (c *CoHereModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", c.Name()) +} + +func (z *CoHereModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (c *CoHereModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", c.Name()) +} + func (c *CoHereModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { diff --git a/internal/entity/models/deepseek.go b/internal/entity/models/deepseek.go index 1f4e107e42..8b52418cb7 100644 --- a/internal/entity/models/deepseek.go +++ b/internal/entity/models/deepseek.go @@ -36,6 +36,11 @@ type DeepSeekModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (z *DeepSeekModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewDeepSeekModel creates a new DeepSeek model instance func NewDeepSeekModel(baseURL map[string]string, urlSuffix URLSuffix) *DeepSeekModel { return &DeepSeekModel{ @@ -584,3 +589,26 @@ func (z *DeepSeekModel) CheckConnection(apiConfig *APIConfig) error { func (z *DeepSeekModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (d *DeepSeekModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z 
*DeepSeekModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (d *DeepSeekModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DeepSeekModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (d *DeepSeekModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index 149c69af73..2dd29e0929 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -26,6 +26,11 @@ type DummyModel struct { URLSuffix URLSuffix } +func (d *DummyModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewDummyModel creates a new Dummy AI model instance func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel { return &DummyModel{ @@ -34,42 +39,65 @@ func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel { } } -func (z *DummyModel) NewInstance(baseURL map[string]string) ModelDriver { +func (d *DummyModel) NewInstance(baseURL map[string]string) ModelDriver { return nil } -func (z *DummyModel) Name() string { +func (d *DummyModel) Name() string { return "dummy" } // ChatWithMessages sends multiple messages with roles and returns response -func (z *DummyModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) 
(*ChatResponse, error) { +func (d *DummyModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { return nil, fmt.Errorf("not implemented") } // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel) -func (z *DummyModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { +func (d *DummyModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { return fmt.Errorf("not implemented") } // Embed embeds a list of texts into embeddings -func (z *DummyModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { +func (d *DummyModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { return nil, fmt.Errorf("not implemented") } -func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { +func (d *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("not implemented") } -func (z *DummyModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { +func (d *DummyModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { return nil, fmt.Errorf("no such method") } -func (z *DummyModel) CheckConnection(apiConfig *APIConfig) error { +func (d *DummyModel) CheckConnection(apiConfig *APIConfig) error { return fmt.Errorf("no such method") } // Rerank calculates similarity scores between query and documents -func (z *DummyModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +func (d 
*DummyModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", d.Name()) +} + +// TranscribeAudio transcribe audio +func (d *DummyModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DummyModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (d *DummyModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DummyModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (d *DummyModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) } diff --git a/internal/entity/models/fishaudio.go b/internal/entity/models/fishaudio.go index d767816006..66ff4b1dda 100644 --- a/internal/entity/models/fishaudio.go +++ b/internal/entity/models/fishaudio.go @@ -17,6 +17,11 @@ type FishAudioModel struct { httpClient *http.Client } +func (f *FishAudioModel) ParseFile() { + //TODO implement me + panic("implement me") +} + func NewFishAudioModel(baseURL map[string]string, urlSuffix URLSuffix) *FishAudioModel { return &FishAudioModel{ BaseURL: baseURL, @@ -56,6 +61,30 @@ func (f *FishAudioModel) Embed(modelName *string, texts []string, apiConfig *API func (f 
*FishAudioModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("no such method") } + +// TranscribeAudio transcribe audio +func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", f.Name()) +} + +func (z *FishAudioModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (f *FishAudioModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", f.Name()) +} + +func (z *FishAudioModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (f *FishAudioModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", f.Name()) +} + func (f *FishAudioModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { diff --git a/internal/entity/models/gitee.go b/internal/entity/models/gitee.go index 335ec63484..ac7424bfde 100644 --- a/internal/entity/models/gitee.go +++ b/internal/entity/models/gitee.go @@ -36,6 +36,11 @@ type GiteeModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (g *GiteeModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewGiteeModel creates a new Gitee model 
instance func NewGiteeModel(baseURL map[string]string, urlSuffix URLSuffix) *GiteeModel { return &GiteeModel{ @@ -53,16 +58,16 @@ func NewGiteeModel(baseURL map[string]string, urlSuffix URLSuffix) *GiteeModel { } } -func (z *GiteeModel) NewInstance(baseURL map[string]string) ModelDriver { +func (g *GiteeModel) NewInstance(baseURL map[string]string) ModelDriver { return nil } -func (z *GiteeModel) Name() string { +func (g *GiteeModel) Name() string { return "gitee" } // ChatWithMessages sends multiple messages with roles and returns response -func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { +func (g *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { return nil, fmt.Errorf("api key is nil or empty") } @@ -75,7 +80,7 @@ func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiC if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Chat) // Convert messages to the format expected by API apiMessages := make([]map[string]interface{}, len(messages)) @@ -144,7 +149,7 @@ func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiC req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -213,7 +218,7 @@ func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiC } // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no 
channel) -func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { +func (g *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { if len(messages) == 0 { return fmt.Errorf("messages is empty") } @@ -223,7 +228,7 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region]) + url := fmt.Sprintf("%s/chat/completions", g.BaseURL[region]) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -291,7 +296,7 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -417,7 +422,7 @@ type giteeUsage struct { } // Embed embeds a list of texts into embeddings -func (z *GiteeModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { +func (g *GiteeModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if len(texts) == 0 { return []EmbeddingData{}, nil } @@ -435,9 +440,9 @@ func (z *GiteeModel) Embed(modelName *string, texts []string, apiConfig *APIConf region = *apiConfig.Region } - baseURL := z.BaseURL["default"] + baseURL := g.BaseURL["default"] if region != "default" { - if regional, ok := z.BaseURL[region]; ok && regional != "" { + if regional, ok := g.BaseURL[region]; ok && regional != "" { baseURL = regional } } @@ -445,7 +450,7 @@ func (z 
*GiteeModel) Embed(modelName *string, texts []string, apiConfig *APIConf return nil, fmt.Errorf("gitee: no base URL configured for default region") } - url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Embedding) + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), g.URLSuffix.Embedding) reqBody := map[string]interface{}{ "model": *modelName, @@ -471,7 +476,7 @@ func (z *GiteeModel) Embed(modelName *string, texts []string, apiConfig *APIConf req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -511,7 +516,7 @@ type giteeRerankRequest struct { } // Rerank calculates similarity scores between query and documents -func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { +func (g *GiteeModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { if len(documents) == 0 { return &RerankResponse{}, nil } @@ -529,9 +534,9 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, region = *apiConfig.Region } - baseURL := z.BaseURL["default"] + baseURL := g.BaseURL["default"] if region != "default" { - if regional, ok := z.BaseURL[region]; ok && regional != "" { + if regional, ok := g.BaseURL[region]; ok && regional != "" { baseURL = regional } } @@ -539,7 +544,7 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, return nil, fmt.Errorf("gitee: no base URL configured for default region") } - url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Rerank) + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), g.URLSuffix.Rerank) var topN = 
rerankConfig.TopN if rerankConfig.TopN == 0 { @@ -570,7 +575,7 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -593,13 +598,36 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, return &rerankResponse, nil } -func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { +// TranscribeAudio transcribe audio +func (g *GiteeModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GiteeModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (g *GiteeModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GiteeModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (g *GiteeModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (g *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], 
z.URLSuffix.Models) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Models) // Build request body reqBody := map[string]interface{}{} @@ -617,7 +645,7 @@ func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -650,13 +678,13 @@ func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { return models, nil } -func (z *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { +func (g *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { var region = "default" if apiConfig.Region != nil { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Balance) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Balance) // Build request body reqBody := map[string]interface{}{} @@ -674,7 +702,7 @@ func (z *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, erro req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -705,13 +733,13 @@ func (z *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, erro return response, nil } -func (z *GiteeModel) CheckConnection(apiConfig *APIConfig) error { +func (g *GiteeModel) CheckConnection(apiConfig *APIConfig) error { var region = "default" if apiConfig.Region != nil { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Status) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Status) // Build request body reqBody := 
map[string]interface{}{} @@ -729,7 +757,7 @@ func (z *GiteeModel) CheckConnection(apiConfig *APIConfig) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to send request: %w", err) } diff --git a/internal/entity/models/google.go b/internal/entity/models/google.go index fabd51e4c3..b0bcbf4026 100644 --- a/internal/entity/models/google.go +++ b/internal/entity/models/google.go @@ -77,6 +77,11 @@ type GoogleModel struct { URLSuffix URLSuffix } +func (g *GoogleModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewGoogleModel creates a new Google AI model instance func NewGoogleModel(baseURL map[string]string, urlSuffix URLSuffix) *GoogleModel { return &GoogleModel{ @@ -85,15 +90,15 @@ func NewGoogleModel(baseURL map[string]string, urlSuffix URLSuffix) *GoogleModel } } -func (z *GoogleModel) NewInstance(baseURL map[string]string) ModelDriver { +func (g *GoogleModel) NewInstance(baseURL map[string]string) ModelDriver { return nil } -func (z *GoogleModel) Name() string { +func (g *GoogleModel) Name() string { return "google" } -func (z *GoogleModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { +func (g *GoogleModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { return nil, fmt.Errorf("api key is nil or empty") } @@ -167,7 +172,7 @@ func (z *GoogleModel) ChatWithMessages(modelName string, messages []Message, api } // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel) -func (z *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Message, 
apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { +func (g *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { if len(messages) == 0 { return fmt.Errorf("messages is empty") } @@ -261,7 +266,7 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Messag // Embed generates embeddings for a batch of texts using the Gemini embeddings API. // The SDK routes to batchEmbedContents internally, so all texts are sent in one request. -func (z *GoogleModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { +func (g *GoogleModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { return nil, fmt.Errorf("api key is required") } @@ -318,7 +323,7 @@ func (z *GoogleModel) Embed(modelName *string, texts []string, apiConfig *APICon return result, nil } -func (z *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { +func (g *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" { return nil, fmt.Errorf("api key is required") } @@ -326,16 +331,39 @@ func (z *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { return googleListModels(context.Background(), *apiConfig.ApiKey) } -func (z *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { +func (g *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { return nil, fmt.Errorf("no such method") } -func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error { - _, err := z.ListModels(apiConfig) +func (g *GoogleModel) CheckConnection(apiConfig *APIConfig) 
error { + _, err := g.ListModels(apiConfig) return err } // Rerank calculates similarity scores between query and documents -func (z *GoogleModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +func (g *GoogleModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", g.Name()) +} + +// TranscribeAudio transcribe audio +func (g *GoogleModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GoogleModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (g *GoogleModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GoogleModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (g *GoogleModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) } diff --git a/internal/entity/models/huggingface.go b/internal/entity/models/huggingface.go index 8684aedca1..b2dedbc7f5 100644 --- a/internal/entity/models/huggingface.go +++ b/internal/entity/models/huggingface.go @@ -19,6 +19,11 @@ type HuggingFaceModel 
struct { httpClient *http.Client } +func (h *HuggingFaceModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewHuggingFaceModel creates a new huggingFace model instance func NewHuggingFaceModel(baseURL map[string]string, urlSuffix URLSuffix) *HuggingFaceModel { return &HuggingFaceModel{ @@ -411,6 +416,29 @@ func (h *HuggingFaceModel) Rerank(modelName *string, query string, documents []s return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (h *HuggingFaceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", h.Name()) +} + +func (z *HuggingFaceModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (h *HuggingFaceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", h.Name()) +} + +func (z *HuggingFaceModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (h *HuggingFaceModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", h.Name()) +} + func (h *HuggingFaceModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil && *apiConfig.Region != "" { diff --git a/internal/entity/models/lmstudio.go b/internal/entity/models/lmstudio.go index 136d8bb571..e62814a505 100644 --- a/internal/entity/models/lmstudio.go +++ 
b/internal/entity/models/lmstudio.go @@ -20,6 +20,11 @@ type LmStudioModel struct { httpClient *http.Client } +func (l *LmStudioModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewLmStudioModel func NewLmStudioModel(baseURL map[string]string, urlSuffix URLSuffix) *LmStudioModel { return &LmStudioModel{ @@ -447,6 +452,29 @@ func (l *LmStudioModel) Rerank(modelName *string, query string, documents []stri return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (z *LmStudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LmStudioModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *LmStudioModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LmStudioModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (l *LmStudioModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + // ListModels list supported models func (l *LmStudioModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" diff --git a/internal/entity/models/minimax.go b/internal/entity/models/minimax.go index 67b4e83907..9919933bd6 100644 --- a/internal/entity/models/minimax.go +++ b/internal/entity/models/minimax.go @@ -35,6 
+35,11 @@ type MinimaxModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (z *MinimaxModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewMinimaxModel creates a new Minimax model instance func NewMinimaxModel(baseURL map[string]string, urlSuffix URLSuffix) *MinimaxModel { return &MinimaxModel{ @@ -447,3 +452,26 @@ func (z *MinimaxModel) CheckConnection(apiConfig *APIConfig) error { func (z *MinimaxModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (z *MinimaxModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MinimaxModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *MinimaxModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MinimaxModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *MinimaxModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/entity/models/moonshot.go b/internal/entity/models/moonshot.go index 2c1443251b..9e8e5a99a9 100644 --- 
a/internal/entity/models/moonshot.go +++ b/internal/entity/models/moonshot.go @@ -35,6 +35,11 @@ type MoonshotModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (m *MoonshotModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewMoonshotModel creates a new Moonshot model instance func NewMoonshotModel(baseURL map[string]string, urlSuffix URLSuffix) *MoonshotModel { return &MoonshotModel{ @@ -487,3 +492,26 @@ func (z *MoonshotModel) CheckConnection(apiConfig *APIConfig) error { func (z *MoonshotModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (z *MoonshotModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MoonshotModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *MoonshotModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MoonshotModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *MoonshotModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/entity/models/nvidia.go 
b/internal/entity/models/nvidia.go index 88029dac15..9dc9763510 100644 --- a/internal/entity/models/nvidia.go +++ b/internal/entity/models/nvidia.go @@ -19,6 +19,11 @@ type NvidiaModel struct { httpClient *http.Client } +func (n NvidiaModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewNvidiaModel creates a new Nvidia model instance func NewNvidiaModel(baseURL map[string]string, urlSuffix URLSuffix) *NvidiaModel { return &NvidiaModel{ @@ -552,6 +557,29 @@ func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (n *NvidiaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (z *NvidiaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (n *NvidiaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (z *NvidiaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *NvidiaModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + // ListModels calls /v1/models on the configured NVIDIA NIM base URL // and returns the list of available model ids. 
The endpoint is // OpenAI-compatible, so the parsing follows the same shape used by diff --git a/internal/entity/models/ollama.go b/internal/entity/models/ollama.go index d1b05588d7..2ba36b27f3 100644 --- a/internal/entity/models/ollama.go +++ b/internal/entity/models/ollama.go @@ -20,6 +20,11 @@ type OllamaModel struct { httpClient *http.Client } +func (o *OllamaModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewOllamaModel creates a new Ollama AI model instance func NewOllamaModel(baseURL map[string]string, urlSuffix URLSuffix) *OllamaModel { return &OllamaModel{ @@ -445,6 +450,29 @@ func (o *OllamaModel) Rerank(modelName *string, query string, documents []string return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (o *OllamaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OllamaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *OllamaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OllamaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *OllamaModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + func (o *OllamaModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = 
"default" diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go index 6461444e7b..69ea5cf190 100644 --- a/internal/entity/models/openai.go +++ b/internal/entity/models/openai.go @@ -37,6 +37,11 @@ type OpenAIModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (o *OpenAIModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewOpenAIModel creates a new OpenAI model instance. // // We clone http.DefaultTransport so we keep Go's defaults for @@ -593,3 +598,26 @@ func (z *OpenAIModel) CheckConnection(apiConfig *APIConfig) error { func (z *OpenAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (o *OpenAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *OpenAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *OpenAIModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", 
m.Name()) +} diff --git a/internal/entity/models/openrouter.go b/internal/entity/models/openrouter.go index 7ebf09b5fb..41bed6f81e 100644 --- a/internal/entity/models/openrouter.go +++ b/internal/entity/models/openrouter.go @@ -19,6 +19,11 @@ type OpenRouterModel struct { httpClient *http.Client } +func (o *OpenRouterModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewOpenRouterModel creates a new OpenRouter AI model instance func NewOpenRouterModel(baseURL map[string]string, urlSuffix URLSuffix) *OpenRouterModel { return &OpenRouterModel{ @@ -529,6 +534,29 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, documents []st return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (o *OpenRouterModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenRouterModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *OpenRouterModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenRouterModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *OpenRouterModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + func (o *OpenRouterModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if 
apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index 3659ddef02..a530086850 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -36,6 +36,11 @@ type SiliconflowModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (s *SiliconflowModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewSiliconflowModel creates a new Siliconflow model instance func NewSiliconflowModel(baseURL map[string]string, urlSuffix URLSuffix) *SiliconflowModel { return &SiliconflowModel{ @@ -720,3 +725,26 @@ func (s *SiliconflowModel) Rerank(modelName *string, query string, documents []s } return &rerankResponse, nil } + +// TranscribeAudio transcribe audio +func (o *SiliconflowModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *SiliconflowModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *SiliconflowModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *SiliconflowModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *SiliconflowModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", 
m.Name()) +} diff --git a/internal/entity/models/stepfun.go b/internal/entity/models/stepfun.go index ddccbabb3d..2fd0a9e829 100644 --- a/internal/entity/models/stepfun.go +++ b/internal/entity/models/stepfun.go @@ -457,3 +457,26 @@ func (s *StepFunModel) CheckConnection(apiConfig *APIConfig) error { func (s *StepFunModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("no such method") } + +// TranscribeAudio transcribe audio +func (z *StepFunModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *StepFunModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *StepFunModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *StepFunModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *StepFunModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 3a32cec9dd..991ceedbce 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -26,6 +26,14 @@ type ModelDriver interface { Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig 
*EmbeddingConfig) ([]EmbeddingData, error) // Rerank calculates similarity scores between query and texts Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) + // TranscribeAudio transcribe audio + TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) + TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error + // AudioSpeech convert audio to text + AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) + AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error + // OCRFile OCR file + OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) // ListModels List supported models ListModels(apiConfig *APIConfig) ([]string, error) @@ -53,6 +61,15 @@ type RerankResponse struct { Data []RerankResult `json:"data"` } +type ASRResponse struct { +} + +type TTSResponse struct { +} + +type OCRResponse struct { +} + // URLSuffix represents the URL suffixes for different API endpoints type URLSuffix struct { Chat string `json:"chat"` @@ -93,6 +110,15 @@ type RerankConfig struct { TopN int } +type ASRConfig struct { +} + +type TTSConfig struct { +} + +type OCRConfig struct { +} + // EmbeddingModel wraps a ModelDriver with embedding-specific configuration type EmbeddingModel struct { ModelDriver ModelDriver diff --git a/internal/entity/models/upstage.go b/internal/entity/models/upstage.go index fad7f857ac..c68abcce08 100644 --- a/internal/entity/models/upstage.go +++ b/internal/entity/models/upstage.go @@ -584,3 +584,26 @@ func (u *UpstageModel) CheckConnection(apiConfig *APIConfig) error { func (u *UpstageModel) Rerank(modelName *string, 
query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("no such method") } + +// TranscribeAudio transcribe audio +func (z *UpstageModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *UpstageModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (z *UpstageModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *UpstageModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *UpstageModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/vllm.go b/internal/entity/models/vllm.go index a7e3e118fb..2fe1f78fd7 100644 --- a/internal/entity/models/vllm.go +++ b/internal/entity/models/vllm.go @@ -36,6 +36,11 @@ type VllmModel struct { httpClient *http.Client // Reusable HTTP client with connection pool } +func (v *VllmModel) ParseFile() { + //TODO implement me + panic("implement me") +} + // NewVllmModel creates a new Vllm AI model instance func NewVllmModel(baseURL map[string]string, urlSuffix URLSuffix) *VllmModel { return &VllmModel{ @@ -551,3 +556,26 @@ func (z *VllmModel) CheckConnection(apiConfig *APIConfig) error { func (z *VllmModel) Rerank(modelName *string, 
query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (o *VllmModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VllmModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *VllmModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VllmModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *VllmModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/entity/models/volcengine.go b/internal/entity/models/volcengine.go index 22da539936..e5ad964525 100644 --- a/internal/entity/models/volcengine.go +++ b/internal/entity/models/volcengine.go @@ -510,6 +510,29 @@ func (z *VolcEngine) Rerank(modelName *string, query string, documents []string, return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } +// TranscribeAudio transcribe audio +func (o *VolcEngine) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VolcEngine) 
TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *VolcEngine) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VolcEngine) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *VolcEngine) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + func (z *VolcEngine) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { diff --git a/internal/entity/models/xai.go b/internal/entity/models/xai.go index 1b3175d4b7..bc0391adb7 100644 --- a/internal/entity/models/xai.go +++ b/internal/entity/models/xai.go @@ -492,3 +492,26 @@ func (z *XAIModel) CheckConnection(apiConfig *APIConfig) error { func (z *XAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (o *XAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *XAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no 
such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *XAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *XAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *XAIModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index e4041614f8..a381105534 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -157,7 +157,7 @@ func (z *ZhipuAIModel) ChatWithMessages(modelName string, messages []Message, ap // Parse response var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { + if err = json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } @@ -610,3 +610,26 @@ func (z *ZhipuAIModel) Rerank(modelName *string, query string, documents []strin return &rerankResponse, nil } + +// TranscribeAudio transcribe audio +func (o *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert audio to text +func (o *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig 
*TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *ZhipuAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *ZhipuAIModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/handler/providers.go b/internal/handler/providers.go index af101c60e3..f71f1220a4 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -1047,3 +1047,311 @@ func (h *ProviderHandler) RerankDocument(c *gin.Context) { "message": "success", }) } + +type TranscribeAudioRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + File *string `json:"file"` + Language []string `json:"language"` + Prompt int `json:"prompt"` + Stream bool `json:"stream"` +} + +func (h *ProviderHandler) TranscribeAudio(c *gin.Context) { + var req TranscribeAudioRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := 
c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + asrConfig := models.ASRConfig{} + + // Check if it's a stream request + if req.Stream { + // Set SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + // Create sender function that writes directly to response + sender := func(content, reasoningContent *string) error { + // Check for [DONE] marker (OpenAI compatible) + if content != nil { + if *content == "[DONE]" { + c.SSEvent("done", "[DONE]") + return nil + } + message := fmt.Sprintf("[MESSAGE]%s", *content) + c.SSEvent("message", message) + c.Writer.Flush() + } + + if reasoningContent != nil { + message := fmt.Sprintf("[REASONING]%s", *reasoningContent) + c.SSEvent("message", message) + c.Writer.Flush() + } + + //logger.Info(data) + return nil + } + + // Stream response using sender function (best performance, no channel) + errorCode, err := h.modelProviderService.TranscribeAudioStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig, sender) + + if errorCode != common.CodeSuccess { + c.SSEvent("error", err.Error()) + } + return + } + + // Non-stream response + var response *models.ASRResponse + var errorCode common.ErrorCode + var err error + + response, errorCode, err = h.modelProviderService.TranscribeAudio(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} + +type AudioSpeechRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + Text *string `json:"text"` + Language []string 
`json:"language"` + Voice int `json:"voice"` + Stream bool `json:"stream"` + Volume bool `json:"volume"` +} + +func (h *ProviderHandler) AudioSpeech(c *gin.Context) { + var req AudioSpeechRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + ttsConfig := models.TTSConfig{} + + // Check if it's a stream request + if req.Stream { + // Set SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + // Create sender function that writes directly to response + sender := func(content, reasoningContent *string) error { + // Check for [DONE] marker (OpenAI compatible) + if content != nil { + if *content == "[DONE]" { + c.SSEvent("done", "[DONE]") + return nil + } + message := fmt.Sprintf("[MESSAGE]%s", *content) + c.SSEvent("message", message) + c.Writer.Flush() + } + + if reasoningContent != nil { + message := fmt.Sprintf("[REASONING]%s", *reasoningContent) + c.SSEvent("message", message) + c.Writer.Flush() + } + + //logger.Info(data) + return nil + } + + // Stream response using sender function (best performance, no channel) + errorCode, err := 
h.modelProviderService.AudioSpeechStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig, sender) + + if errorCode != common.CodeSuccess { + c.SSEvent("error", err.Error()) + } + return + } + + // Non-stream response + var response *models.TTSResponse + var errorCode common.ErrorCode + var err error + + response, errorCode, err = h.modelProviderService.AudioSpeech(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} + +type OCRFileRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + File *string `json:"file"` +} + +func (h *ProviderHandler) OCRFile(c *gin.Context) { + var req OCRFileRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + OCRConfig := models.OCRConfig{} + + // Non-stream response + var response *models.OCRResponse + var errorCode common.ErrorCode + var err error + + response, 
errorCode, err = h.modelProviderService.OCRFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &OCRConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} diff --git a/internal/router/router.go b/internal/router/router.go index 67ae4e0a12..05a56ff8c8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -272,6 +272,9 @@ func (r *Router) Setup(engine *gin.Engine) { v1.POST("/chat/completions", r.providerHandler.ChatToModel) v1.POST("/embeddings", r.providerHandler.EmbedText) v1.POST("/rerank", r.providerHandler.RerankDocument) + v1.POST("/audio/transcriptions", r.providerHandler.TranscribeAudio) + v1.POST("/audio/speech", r.providerHandler.AudioSpeech) + v1.POST("/file/ocr", r.providerHandler.OCRFile) } model := v1.Group("/models") diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 5ac2495198..446e2f90cb 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -1100,6 +1100,487 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN return nil, common.CodeServerError, errors.New("model is disabled") } +// TranscribeAudio transcribe audio file to text +func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, modelName, userID string, audioFile *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if asrConfig == nil { + asrConfig = &modelModule.ASRConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, 
errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.ASRResponse + response, err = providerInfo.ModelDriver.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region
= &region + apiConfig.ApiKey = &instance.APIKey + + newURL := map[string]string{ + region: extra["base_url"], + } + newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) + + var response *modelModule.ASRResponse + response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + +// TranscribeAudioStream streams the transcription result directly via sender function (best performance, no channel) +func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName, modelName, userID string, audioFile *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) (common.ErrorCode, error) { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return common.CodeServerError, err + } + + if len(tenants) == 0 { + return common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, err + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return
common.CodeNotFound, err + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + err = providerInfo.ModelDriver.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) + if err != nil { + return common.CodeServerError, err + } + + return common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + newURL := map[string]string{ + region: extra["base_url"], + } + newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) + + err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil + } + + return common.CodeServerError, errors.New("model is disabled") +} + +// AudioSpeech synthesizes speech (TTS) through the selected provider model +func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName, userID string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if ttsConfig == nil { + ttsConfig = &modelModule.TTSConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil,
common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.TTSResponse + response, err = providerInfo.ModelDriver.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region :=
extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + newURL := map[string]string{ + region: extra["base_url"], + } + newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) + + var response *modelModule.TTSResponse + response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + +func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, modelName, userID string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) (common.ErrorCode, error) { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return common.CodeServerError, err + } + + if len(tenants) == 0 { + return common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, err + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return common.CodeNotFound, err + } + + var extra map[string]string + err =
json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + err = providerInfo.ModelDriver.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) + if err != nil { + return common.CodeServerError, err + } + + return common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + newURL := map[string]string{ + region: extra["base_url"], + } + newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) + + err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil + } + + return common.CodeServerError, errors.New("model is disabled") +} + +func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, userID string, fileContent *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if ocrConfig == nil { + ocrConfig = &modelModule.OCRConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists +
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.OCRResponse + response, err = providerInfo.ModelDriver.OCRFile(&modelName, fileContent, apiConfig, ocrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = &region + apiConfig.ApiKey = &instance.APIKey + + newURL := map[string]string{ + region: extra["base_url"], + } +
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) + + var response *modelModule.OCRResponse + response, err = newProviderInfo.OCRFile(&modelName, fileContent, apiConfig, ocrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + // GetEmbeddingModel returns an EmbeddingModel wrapper for the given tenant func (m *ModelProviderService) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) { driver, modelName, apiConfig, maxTokens, err := m.getModelConfig(tenantID, compositeModelName)