diff --git a/conf/models/openai.json b/conf/models/openai.json index 33e4a10506..b4711ab59c 100644 --- a/conf/models/openai.json +++ b/conf/models/openai.json @@ -6,7 +6,9 @@ "url_suffix": { "chat": "chat/completions", "models": "models", - "embedding": "embeddings" + "embedding": "embeddings", + "asr": "audio/transcriptions", + "tts": "audio/speech" }, "class": "gpt", "models": [ @@ -224,4 +226,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go index c19751231e..afb7f9ff91 100644 --- a/internal/entity/models/openai.go +++ b/internal/entity/models/openai.go @@ -20,10 +20,15 @@ import ( "bufio" "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "os" + "path/filepath" + "strconv" "strings" "time" ) @@ -596,20 +601,437 @@ func (z *OpenAIModel) Rerank(modelName *string, query string, documents []string // TranscribeAudio transcribe audio func (o *OpenAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { - return nil, fmt.Errorf("%s, no such method", o.Name()) + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, responseFormat, err := o.newOpenAIASRRequest(ctx, modelName, file, apiConfig, asrConfig, false) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenAI ASR API error: %s, body: %s", resp.Status, string(respBody)) + } + + return decodeOpenAIASRResponse(respBody, responseFormat) } func (z *OpenAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", z.Name()) + if sender == nil { + return fmt.Errorf("sender is required") + } + + req, responseFormat, err := z.newOpenAIASRRequest(context.Background(), modelName, file, apiConfig, asrConfig, true) + if err != nil { + return err + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("OpenAI ASR stream API error: %s, body: %s", resp.Status, string(respBody)) + } + + if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + response, err := decodeOpenAIASRResponse(respBody, responseFormat) + if err != nil { + return err + } + if response != nil && response.Text != "" { + if err = sender(&response.Text, nil); err != nil { + return err + } + } + done := "[DONE]" + return sender(&done, nil) + } + + sentDelta := false + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 8*1024*1024) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + dataStr := strings.TrimSpace(line[6:]) + if dataStr == "" { + continue + } + if dataStr == "[DONE]" { + break + } + + var event struct { + Type string `json:"type"` + Delta string `json:"delta"` + Text string `json:"text"` + } + if err = json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + switch { + case event.Delta != "": + if err = sender(&event.Delta, nil); err != nil { + return err + } + sentDelta = true + case event.Type == "transcript.text.segment" && event.Text != "": + if err = sender(&event.Text, nil); err != nil { + return err + } + sentDelta = true + case event.Type == "transcript.text.done" && !sentDelta && event.Text != "": + if err = sender(&event.Text, nil); err != nil { + return err + } + } + } + if err = scanner.Err(); err != nil { + return fmt.Errorf("error reading OpenAI ASR stream: %w", err) + } + + done := "[DONE]" + return sender(&done, nil) +} + +func decodeOpenAIASRResponse(respBody []byte, responseFormat string) (*ASRResponse, error) { + switch responseFormat { + case "text", "srt", "vtt": + return &ASRResponse{Text: string(respBody)}, nil + } + + var result struct { + Text string `json:"text"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody)) + } + + return &ASRResponse{Text: result.Text}, nil } // AudioSpeech convert text to audio func (o *OpenAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { - return nil, fmt.Errorf("%s, no such method", o.Name()) + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, _, err := o.newOpenAITTSRequest(ctx, modelName, audioContent, apiConfig, ttsConfig, false) + if err != nil { + return nil, err + } + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenAI TTS API error: %s, body: %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil } func (z *OpenAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", z.Name()) + if sender == nil { + return fmt.Errorf("sender is required") + } + + req, streamFormat, err := z.newOpenAITTSRequest(context.Background(), modelName, audioContent, apiConfig, ttsConfig, true) + if err != nil { + return err + } + if streamFormat == "sse" { + req.Header.Set("Accept", "text/event-stream") + } + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("OpenAI TTS stream API error: %s, body: %s", resp.Status, string(body)) + } + + if streamFormat == "sse" || strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + return readOpenAITTSSSEStream(resp.Body, sender) + } + return readOpenAITTSRawStream(resp.Body, sender) +} + +func (o *OpenAIModel) newOpenAIASRRequest(ctx context.Context, modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, stream bool) (*http.Request, string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, "", fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, "", fmt.Errorf("model name is required") + } + if file == nil || *file == "" { + return nil, "", fmt.Errorf("file is missing") + } + if strings.TrimSpace(o.URLSuffix.ASR) == "" { + return nil, "", fmt.Errorf("openai ASR URL suffix is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := o.baseURLForRegion(region) + if err != nil { + return nil, "", err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(o.URLSuffix.ASR, "/")) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + audioFile, err := os.Open(*file) + if err != nil { + return nil, "", fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, "", fmt.Errorf("failed to create multipart file: %w", err) + } + if _, err = io.Copy(part, audioFile); err != nil { + return nil, "", fmt.Errorf("failed to copy audio data: %w", err) + } + if err = writer.WriteField("model", *modelName); err != nil { + return nil, "", fmt.Errorf("failed to write model field: %w", err) + } + + responseFormat := "" + if asrConfig != nil && asrConfig.Params != nil { + if value, ok := asrConfig.Params["response_format"].(string); ok { + responseFormat = value + } + for key, value := range asrConfig.Params { + if stream && key == "stream" { + continue + } + if err = writeOpenAIMultipartField(writer, key, value); err != nil { + return nil, "", err + } + } + } + if stream { + if err = writer.WriteField("stream", "true"); err != nil { + return nil, "", fmt.Errorf("failed to write stream field: %w", err) + } + } + + if err = writer.Close(); err != nil { + return nil, "", fmt.Errorf("failed to close multipart writer: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, &body) + if err != nil { + return nil, "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + return req, responseFormat, nil +} + +func (o *OpenAIModel) newOpenAITTSRequest(ctx context.Context, modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (*http.Request, string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, "", fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, "", fmt.Errorf("model name is required") + } + if audioContent == nil || *audioContent == "" { + return nil, "", fmt.Errorf("audio content is empty") + } + if strings.TrimSpace(o.URLSuffix.TTS) == "" { + return nil, "", fmt.Errorf("openai TTS URL suffix is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := o.baseURLForRegion(region) + if err != nil { + return nil, "", err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(o.URLSuffix.TTS, "/")) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + if stream { + if _, ok := reqBody["stream_format"]; !ok { + reqBody["stream_format"] = "audio" + } + } + + voice, ok := reqBody["voice"] + if !ok || voice == nil { + return nil, "", fmt.Errorf("voice is required") + } + voiceString, ok := voice.(string) + if !ok || strings.TrimSpace(voiceString) == "" { + return nil, "", fmt.Errorf("voice is required") + } + + streamFormat := "" + if value, ok := reqBody["stream_format"].(string); ok { + streamFormat = value + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + return req, streamFormat, nil +} + +func readOpenAITTSSSEStream(body io.Reader, sender func(*string, *string) error) error { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 64*1024), 8*1024*1024) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + dataStr := strings.TrimSpace(line[6:]) + if dataStr == "" || dataStr == "[DONE]" { + continue + } + + var event struct { + Type string `json:"type"` + Audio string `json:"audio"` + } + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + if event.Type == "speech.audio.delta" && event.Audio != "" { + audioBytes, err := base64.StdEncoding.DecodeString(event.Audio) + if err == nil && len(audioBytes) > 0 { + chunk := string(audioBytes) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + } + if event.Type == "speech.audio.done" { + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading OpenAI TTS stream: %w", err) + } + return nil +} + +func readOpenAITTSRawStream(body io.Reader, sender func(*string, *string) error) error { + buf := make([]byte, 32*1024) + for { + n, err := body.Read(buf) + if n > 0 { + chunk := string(buf[:n]) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("error reading OpenAI TTS stream: %w", err) + } + } +} + +func writeOpenAIMultipartField(writer *multipart.Writer, key string, value interface{}) error { + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case int64: + val = strconv.FormatInt(v, 10) + case float32: + val = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err := writer.WriteField(key, val); err != nil { + return fmt.Errorf("failed to write field %s: %w", key, err) + } + return nil } // OCRFile OCR file diff --git a/internal/entity/models/openai_test.go b/internal/entity/models/openai_test.go new file mode 100644 index 0000000000..8483c4d94c --- /dev/null +++ b/internal/entity/models/openai_test.go @@ -0,0 +1,414 @@ +package models + +import ( + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func newOpenAIForTest(baseURL string) *OpenAIModel { + return NewOpenAIModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + ASR: "audio/transcriptions", + TTS: "audio/speech", + }, + ) +} + +func TestOpenAIConfigAdvertisedAudioModelsHaveSuffixes(t *testing.T) { + raw, err := os.ReadFile("../../../conf/models/openai.json") + if err != nil { + t.Fatalf("read openai config: %v", err) + } + + var cfg struct { + URLSuffix URLSuffix `json:"url_suffix"` + Models []struct { + Name string `json:"name"` + ModelTypes []string `json:"model_types"` + } `json:"models"` + } + if err = json.Unmarshal(raw, &cfg); err != nil { + t.Fatalf("unmarshal openai config: %v", err) + } + + if cfg.URLSuffix.ASR != "audio/transcriptions" { + t.Fatalf("ASR suffix=%q, want audio/transcriptions", cfg.URLSuffix.ASR) + } + if cfg.URLSuffix.TTS != "audio/speech" { + t.Fatalf("TTS suffix=%q, want audio/speech", cfg.URLSuffix.TTS) + } + + var hasASR, hasTTS bool + for _, model := range cfg.Models { + for _, modelType := range model.ModelTypes { + if model.Name == "whisper-1" && modelType == "asr" { + hasASR = true + } + if model.Name == "tts-1" && modelType == "tts" { + hasTTS = true + } + } + } + if !hasASR { + t.Fatal("openai config should advertise whisper-1 as ASR") + } + if !hasTTS { + t.Fatal("openai config should advertise tts-1 as TTS") + } +} + +func TestOpenAITranscribeAudioPostsMultipartToAudioEndpoint(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method=%s, want POST", r.Method) + } + if r.URL.Path != "/audio/transcriptions" { + t.Errorf("path=%s, want /audio/transcriptions", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q, want Bearer test-key", got) + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data") { + t.Errorf("Content-Type=%q, want multipart/form-data", got) + } + + if err := r.ParseMultipartForm(1024 * 1024); err != nil { + t.Errorf("ParseMultipartForm: %v", err) + http.Error(w, "bad multipart", http.StatusBadRequest) + return + } + if got := r.FormValue("model"); got != "whisper-1" { + t.Errorf("model=%q, want whisper-1", got) + } + if got := r.FormValue("language"); got != "en" { + t.Errorf("language=%q, want en", got) + } + if got := r.FormValue("temperature"); got != "0.2" { + t.Errorf("temperature=%q, want 0.2", got) + } + + file, _, err := r.FormFile("file") + if err != nil { + t.Errorf("FormFile: %v", err) + http.Error(w, "missing file", http.StatusBadRequest) + return + } + defer file.Close() + content, err := io.ReadAll(file) + if err != nil { + t.Errorf("read upload: %v", err) + http.Error(w, "read upload failed", http.StatusBadRequest) + return + } + if string(content) != "audio-bytes" { + t.Errorf("file content=%q, want audio-bytes", string(content)) + } + + _ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"}) + })) + defer srv.Close() + + audioPath := t.TempDir() + "/sample.wav" + if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0600); err != nil { + t.Fatalf("write audio fixture: %v", err) + } + + apiKey := "test-key" + model := "whisper-1" + resp, err := newOpenAIForTest(srv.URL).TranscribeAudio( + &model, + &audioPath, + &APIConfig{ApiKey: &apiKey}, + &ASRConfig{Params: map[string]interface{}{ + "language": "en", + "temperature": 0.2, + }}, + ) + if err != nil { + t.Fatalf("TranscribeAudio: %v", err) + } + if resp.Text != "hello world" { + t.Fatalf("Text=%q, want hello world", resp.Text) + } +} + +func TestOpenAITranscribeAudioWithSenderStreamsDeltas(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method=%s, want POST", r.Method) + } + if r.URL.Path != "/audio/transcriptions" { + t.Errorf("path=%s, want /audio/transcriptions", r.URL.Path) + } + if got := r.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept=%q, want text/event-stream", got) + } + if err := r.ParseMultipartForm(1024 * 1024); err != nil { + t.Errorf("ParseMultipartForm: %v", err) + http.Error(w, "bad multipart", http.StatusBadRequest) + return + } + if got := r.FormValue("stream"); got != "true" { + t.Errorf("stream=%q, want true", got) + } + if got := r.FormValue("model"); got != "gpt-4o-mini-transcribe" { + t.Errorf("model=%q, want gpt-4o-mini-transcribe", got) + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + _, _ = w.Write([]byte("data: {\"type\":\"transcript.text.delta\",\"delta\":\"hello\"}\n\n")) + if flusher != nil { + flusher.Flush() + } + _, _ = w.Write([]byte("data: {\"type\":\"transcript.text.delta\",\"delta\":\" world\"}\n\n")) + if flusher != nil { + flusher.Flush() + } + _, _ = w.Write([]byte("data: {\"type\":\"transcript.text.done\",\"text\":\"hello world\"}\n\n")) + })) + defer srv.Close() + + audioPath := t.TempDir() + "/sample.wav" + if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0600); err != nil { + t.Fatalf("write audio fixture: %v", err) + } + + apiKey := "test-key" + model := "gpt-4o-mini-transcribe" + var chunks []string + err := newOpenAIForTest(srv.URL).TranscribeAudioWithSender( + &model, + &audioPath, + &APIConfig{ApiKey: &apiKey}, + nil, + func(content, _ *string) error { + if content != nil { + chunks = append(chunks, *content) + } + return nil + }, + ) + if err != nil { + t.Fatalf("TranscribeAudioWithSender: %v", err) + } + if got := strings.Join(chunks, ""); got != "hello world[DONE]" { + t.Fatalf("streamed text=%q, want hello world[DONE]", got) + } +} + +func TestOpenAIAudioSpeechPostsJSONToAudioEndpoint(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method=%s, want POST", r.Method) + } + if r.URL.Path != "/audio/speech" { + t.Errorf("path=%s, want /audio/speech", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q, want Bearer test-key", got) + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("Content-Type=%q, want application/json", got) + } + + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode body: %v", err) + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if body["model"] != "tts-1" { + t.Errorf("model=%v, want tts-1", body["model"]) + } + if body["input"] != "hello" { + t.Errorf("input=%v, want hello", body["input"]) + } + if body["voice"] != "alloy" { + t.Errorf("voice=%v, want alloy", body["voice"]) + } + if body["response_format"] != "wav" { + t.Errorf("response_format=%v, want wav", body["response_format"]) + } + if body["speed"] != float64(1.25) { + t.Errorf("speed=%v, want 1.25", body["speed"]) + } + + _, _ = w.Write([]byte("audio-bytes")) + })) + defer srv.Close() + + apiKey := "test-key" + model := "tts-1" + input := "hello" + resp, err := newOpenAIForTest(srv.URL).AudioSpeech( + &model, + &input, + &APIConfig{ApiKey: &apiKey}, + &TTSConfig{ + Format: "wav", + Params: map[string]interface{}{ + "voice": "alloy", + "speed": 1.25, + }, + }, + ) + if err != nil { + t.Fatalf("AudioSpeech: %v", err) + } + if string(resp.Audio) != "audio-bytes" { + t.Fatalf("Audio=%q, want audio-bytes", string(resp.Audio)) + } +} + +func TestOpenAIAudioSpeechRequiresVoice(t *testing.T) { + apiKey := "test-key" + model := "tts-1" + input := "hello" + + _, err := newOpenAIForTest("http://unused").AudioSpeech( + &model, + &input, + &APIConfig{ApiKey: &apiKey}, + nil, + ) + if err == nil || !strings.Contains(err.Error(), "voice is required") { + t.Fatalf("err=%v, want voice is required", err) + } +} + +func TestOpenAIAudioSpeechRejectsNonStringVoice(t *testing.T) { + apiKey := "test-key" + model := "tts-1" + input := "hello" + + _, err := newOpenAIForTest("http://unused").AudioSpeech( + &model, + &input, + &APIConfig{ApiKey: &apiKey}, + &TTSConfig{Params: map[string]interface{}{"voice": 123}}, + ) + if err == nil || !strings.Contains(err.Error(), "voice is required") { + t.Fatalf("err=%v, want voice is required", err) + } +} + +func TestOpenAIAudioSpeechWithSenderStreamsRawAudio(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method=%s, want POST", r.Method) + } + if r.URL.Path != "/audio/speech" { + t.Errorf("path=%s, want /audio/speech", r.URL.Path) + } + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode body: %v", err) + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if body["model"] != "tts-1" { + t.Errorf("model=%v, want tts-1", body["model"]) + } + if body["input"] != "hello" { + t.Errorf("input=%v, want hello", body["input"]) + } + if body["voice"] != "alloy" { + t.Errorf("voice=%v, want alloy", body["voice"]) + } + if body["stream_format"] != "audio" { + t.Errorf("stream_format=%v, want audio", body["stream_format"]) + } + + flusher, _ := w.(http.Flusher) + _, _ = w.Write([]byte("audio-")) + if flusher != nil { + flusher.Flush() + } + _, _ = w.Write([]byte("bytes")) + })) + defer srv.Close() + + apiKey := "test-key" + model := "tts-1" + input := "hello" + + var chunks []string + err := newOpenAIForTest(srv.URL).AudioSpeechWithSender( + &model, + &input, + &APIConfig{ApiKey: &apiKey}, + &TTSConfig{Params: map[string]interface{}{"voice": "alloy"}}, + func(content, _ *string) error { + if content != nil { + chunks = append(chunks, *content) + } + return nil + }, + ) + if err != nil { + t.Fatalf("AudioSpeechWithSender: %v", err) + } + if got := strings.Join(chunks, ""); got != "audio-bytes" { + t.Fatalf("streamed audio=%q, want audio-bytes", got) + } +} + +func TestOpenAIAudioSpeechWithSenderStreamsSSEAudioDeltas(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept=%q, want text/event-stream", got) + } + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode body: %v", err) + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if body["stream_format"] != "sse" { + t.Errorf("stream_format=%v, want sse", body["stream_format"]) + } + + w.Header().Set("Content-Type", "text/event-stream") + chunk := base64.StdEncoding.EncodeToString([]byte("audio-bytes")) + _, _ = w.Write([]byte("data: {\"type\":\"speech.audio.delta\",\"audio\":\"" + chunk + "\"}\n\n")) + _, _ = w.Write([]byte("data: {\"type\":\"speech.audio.done\"}\n\n")) + })) + defer srv.Close() + + apiKey := "test-key" + model := "gpt-4o-mini-tts" + input := "hello" + var chunks []string + err := newOpenAIForTest(srv.URL).AudioSpeechWithSender( + &model, + &input, + &APIConfig{ApiKey: &apiKey}, + &TTSConfig{Params: map[string]interface{}{ + "voice": "alloy", + "stream_format": "sse", + }}, + func(content, _ *string) error { + if content != nil { + chunks = append(chunks, *content) + } + return nil + }, + ) + if err != nil { + t.Fatalf("AudioSpeechWithSender: %v", err) + } + if got := strings.Join(chunks, ""); got != "audio-bytes" { + t.Fatalf("streamed audio=%q, want audio-bytes", got) + } +}