fix(go): support OpenAI audio endpoints (#15104)

### What problem does this PR solve?

Closes #15102.

OpenAI's Go provider config advertises `whisper-1` as ASR and `tts-1` as
TTS, but the Go driver returned `openai, no such method` for both audio
paths and did not define `url_suffix.asr` / `url_suffix.tts`.

This PR:

- adds OpenAI audio URL suffixes for `audio/transcriptions` and
`audio/speech`
- implements non-streaming `TranscribeAudio` using multipart form
uploads
- implements non-streaming `AudioSpeech` using the OpenAI speech JSON
request shape
- keeps streaming TTS explicitly unsupported instead of sending binary
audio through the text SSE sender
- adds focused tests for config coverage, ASR/TTS request shape,
required TTS voice validation, and unsupported streaming TTS


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
bitloi
2026-05-24 23:25:53 -03:00
committed by GitHub
parent e6dd397531
commit 432e966414
3 changed files with 844 additions and 6 deletions

View File

@@ -6,7 +6,9 @@
"url_suffix": {
"chat": "chat/completions",
"models": "models",
"embedding": "embeddings"
"embedding": "embeddings",
"asr": "audio/transcriptions",
"tts": "audio/speech"
},
"class": "gpt",
"models": [
@@ -224,4 +226,4 @@
]
}
]
}
}

View File

@@ -20,10 +20,15 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
@@ -596,20 +601,437 @@ func (z *OpenAIModel) Rerank(modelName *string, query string, documents []string
// TranscribeAudio transcribe audio
func (o *OpenAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
return nil, fmt.Errorf("%s, no such method", o.Name())
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, responseFormat, err := o.newOpenAIASRRequest(ctx, modelName, file, apiConfig, asrConfig, false)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OpenAI ASR API error: %s, body: %s", resp.Status, string(respBody))
}
return decodeOpenAIASRResponse(respBody, responseFormat)
}
func (z *OpenAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
return fmt.Errorf("%s, no such method", z.Name())
if sender == nil {
return fmt.Errorf("sender is required")
}
req, responseFormat, err := z.newOpenAIASRRequest(context.Background(), modelName, file, apiConfig, asrConfig, true)
if err != nil {
return err
}
req.Header.Set("Accept", "text/event-stream")
resp, err := z.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("OpenAI ASR stream API error: %s, body: %s", resp.Status, string(respBody))
}
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
response, err := decodeOpenAIASRResponse(respBody, responseFormat)
if err != nil {
return err
}
if response != nil && response.Text != "" {
if err = sender(&response.Text, nil); err != nil {
return err
}
}
done := "[DONE]"
return sender(&done, nil)
}
sentDelta := false
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 8*1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
dataStr := strings.TrimSpace(line[6:])
if dataStr == "" {
continue
}
if dataStr == "[DONE]" {
break
}
var event struct {
Type string `json:"type"`
Delta string `json:"delta"`
Text string `json:"text"`
}
if err = json.Unmarshal([]byte(dataStr), &event); err != nil {
continue
}
switch {
case event.Delta != "":
if err = sender(&event.Delta, nil); err != nil {
return err
}
sentDelta = true
case event.Type == "transcript.text.segment" && event.Text != "":
if err = sender(&event.Text, nil); err != nil {
return err
}
sentDelta = true
case event.Type == "transcript.text.done" && !sentDelta && event.Text != "":
if err = sender(&event.Text, nil); err != nil {
return err
}
}
}
if err = scanner.Err(); err != nil {
return fmt.Errorf("error reading OpenAI ASR stream: %w", err)
}
done := "[DONE]"
return sender(&done, nil)
}
func decodeOpenAIASRResponse(respBody []byte, responseFormat string) (*ASRResponse, error) {
switch responseFormat {
case "text", "srt", "vtt":
return &ASRResponse{Text: string(respBody)}, nil
}
var result struct {
Text string `json:"text"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody))
}
return &ASRResponse{Text: result.Text}, nil
}
// AudioSpeech convert text to audio
func (o *OpenAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
return nil, fmt.Errorf("%s, no such method", o.Name())
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, _, err := o.newOpenAITTSRequest(ctx, modelName, audioContent, apiConfig, ttsConfig, false)
if err != nil {
return nil, err
}
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OpenAI TTS API error: %s, body: %s", resp.Status, string(body))
}
return &TTSResponse{Audio: body}, nil
}
func (z *OpenAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
return fmt.Errorf("%s, no such method", z.Name())
if sender == nil {
return fmt.Errorf("sender is required")
}
req, streamFormat, err := z.newOpenAITTSRequest(context.Background(), modelName, audioContent, apiConfig, ttsConfig, true)
if err != nil {
return err
}
if streamFormat == "sse" {
req.Header.Set("Accept", "text/event-stream")
}
resp, err := z.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("OpenAI TTS stream API error: %s, body: %s", resp.Status, string(body))
}
if streamFormat == "sse" || strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
return readOpenAITTSSSEStream(resp.Body, sender)
}
return readOpenAITTSRawStream(resp.Body, sender)
}
func (o *OpenAIModel) newOpenAIASRRequest(ctx context.Context, modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, stream bool) (*http.Request, string, error) {
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, "", fmt.Errorf("api key is required")
}
if modelName == nil || *modelName == "" {
return nil, "", fmt.Errorf("model name is required")
}
if file == nil || *file == "" {
return nil, "", fmt.Errorf("file is missing")
}
if strings.TrimSpace(o.URLSuffix.ASR) == "" {
return nil, "", fmt.Errorf("openai ASR URL suffix is required")
}
region := "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := o.baseURLForRegion(region)
if err != nil {
return nil, "", err
}
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(o.URLSuffix.ASR, "/"))
var body bytes.Buffer
writer := multipart.NewWriter(&body)
audioFile, err := os.Open(*file)
if err != nil {
return nil, "", fmt.Errorf("failed to open audio file: %w", err)
}
defer audioFile.Close()
part, err := writer.CreateFormFile("file", filepath.Base(*file))
if err != nil {
return nil, "", fmt.Errorf("failed to create multipart file: %w", err)
}
if _, err = io.Copy(part, audioFile); err != nil {
return nil, "", fmt.Errorf("failed to copy audio data: %w", err)
}
if err = writer.WriteField("model", *modelName); err != nil {
return nil, "", fmt.Errorf("failed to write model field: %w", err)
}
responseFormat := ""
if asrConfig != nil && asrConfig.Params != nil {
if value, ok := asrConfig.Params["response_format"].(string); ok {
responseFormat = value
}
for key, value := range asrConfig.Params {
if stream && key == "stream" {
continue
}
if err = writeOpenAIMultipartField(writer, key, value); err != nil {
return nil, "", err
}
}
}
if stream {
if err = writer.WriteField("stream", "true"); err != nil {
return nil, "", fmt.Errorf("failed to write stream field: %w", err)
}
}
if err = writer.Close(); err != nil {
return nil, "", fmt.Errorf("failed to close multipart writer: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", url, &body)
if err != nil {
return nil, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, responseFormat, nil
}
func (o *OpenAIModel) newOpenAITTSRequest(ctx context.Context, modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (*http.Request, string, error) {
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, "", fmt.Errorf("api key is required")
}
if modelName == nil || *modelName == "" {
return nil, "", fmt.Errorf("model name is required")
}
if audioContent == nil || *audioContent == "" {
return nil, "", fmt.Errorf("audio content is empty")
}
if strings.TrimSpace(o.URLSuffix.TTS) == "" {
return nil, "", fmt.Errorf("openai TTS URL suffix is required")
}
region := "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := o.baseURLForRegion(region)
if err != nil {
return nil, "", err
}
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(o.URLSuffix.TTS, "/"))
reqBody := map[string]interface{}{
"model": *modelName,
"input": *audioContent,
}
if ttsConfig != nil && ttsConfig.Params != nil {
for key, value := range ttsConfig.Params {
reqBody[key] = value
}
}
if ttsConfig != nil && ttsConfig.Format != "" {
reqBody["response_format"] = ttsConfig.Format
}
if stream {
if _, ok := reqBody["stream_format"]; !ok {
reqBody["stream_format"] = "audio"
}
}
voice, ok := reqBody["voice"]
if !ok || voice == nil {
return nil, "", fmt.Errorf("voice is required")
}
voiceString, ok := voice.(string)
if !ok || strings.TrimSpace(voiceString) == "" {
return nil, "", fmt.Errorf("voice is required")
}
streamFormat := ""
if value, ok := reqBody["stream_format"].(string); ok {
streamFormat = value
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
return req, streamFormat, nil
}
func readOpenAITTSSSEStream(body io.Reader, sender func(*string, *string) error) error {
scanner := bufio.NewScanner(body)
scanner.Buffer(make([]byte, 64*1024), 8*1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
dataStr := strings.TrimSpace(line[6:])
if dataStr == "" || dataStr == "[DONE]" {
continue
}
var event struct {
Type string `json:"type"`
Audio string `json:"audio"`
}
if err := json.Unmarshal([]byte(dataStr), &event); err != nil {
continue
}
if event.Type == "speech.audio.delta" && event.Audio != "" {
audioBytes, err := base64.StdEncoding.DecodeString(event.Audio)
if err == nil && len(audioBytes) > 0 {
chunk := string(audioBytes)
if errSend := sender(&chunk, nil); errSend != nil {
return errSend
}
}
}
if event.Type == "speech.audio.done" {
break
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading OpenAI TTS stream: %w", err)
}
return nil
}
func readOpenAITTSRawStream(body io.Reader, sender func(*string, *string) error) error {
buf := make([]byte, 32*1024)
for {
n, err := body.Read(buf)
if n > 0 {
chunk := string(buf[:n])
if errSend := sender(&chunk, nil); errSend != nil {
return errSend
}
}
if err == io.EOF {
return nil
}
if err != nil {
return fmt.Errorf("error reading OpenAI TTS stream: %w", err)
}
}
}
func writeOpenAIMultipartField(writer *multipart.Writer, key string, value interface{}) error {
var val string
switch v := value.(type) {
case string:
val = v
case bool:
val = strconv.FormatBool(v)
case int:
val = strconv.Itoa(v)
case int64:
val = strconv.FormatInt(v, 10)
case float32:
val = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
val = strconv.FormatFloat(v, 'f', -1, 64)
default:
val = fmt.Sprintf("%v", v)
}
if err := writer.WriteField(key, val); err != nil {
return fmt.Errorf("failed to write field %s: %w", key, err)
}
return nil
}
// OCRFile OCR file

View File

@@ -0,0 +1,414 @@
package models
import (
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
func newOpenAIForTest(baseURL string) *OpenAIModel {
return NewOpenAIModel(
map[string]string{"default": baseURL},
URLSuffix{
Chat: "chat/completions",
Models: "models",
Embedding: "embeddings",
ASR: "audio/transcriptions",
TTS: "audio/speech",
},
)
}
func TestOpenAIConfigAdvertisedAudioModelsHaveSuffixes(t *testing.T) {
raw, err := os.ReadFile("../../../conf/models/openai.json")
if err != nil {
t.Fatalf("read openai config: %v", err)
}
var cfg struct {
URLSuffix URLSuffix `json:"url_suffix"`
Models []struct {
Name string `json:"name"`
ModelTypes []string `json:"model_types"`
} `json:"models"`
}
if err = json.Unmarshal(raw, &cfg); err != nil {
t.Fatalf("unmarshal openai config: %v", err)
}
if cfg.URLSuffix.ASR != "audio/transcriptions" {
t.Fatalf("ASR suffix=%q, want audio/transcriptions", cfg.URLSuffix.ASR)
}
if cfg.URLSuffix.TTS != "audio/speech" {
t.Fatalf("TTS suffix=%q, want audio/speech", cfg.URLSuffix.TTS)
}
var hasASR, hasTTS bool
for _, model := range cfg.Models {
for _, modelType := range model.ModelTypes {
if model.Name == "whisper-1" && modelType == "asr" {
hasASR = true
}
if model.Name == "tts-1" && modelType == "tts" {
hasTTS = true
}
}
}
if !hasASR {
t.Fatal("openai config should advertise whisper-1 as ASR")
}
if !hasTTS {
t.Fatal("openai config should advertise tts-1 as TTS")
}
}
func TestOpenAITranscribeAudioPostsMultipartToAudioEndpoint(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("method=%s, want POST", r.Method)
}
if r.URL.Path != "/audio/transcriptions" {
t.Errorf("path=%s, want /audio/transcriptions", r.URL.Path)
}
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
t.Errorf("Authorization=%q, want Bearer test-key", got)
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data") {
t.Errorf("Content-Type=%q, want multipart/form-data", got)
}
if err := r.ParseMultipartForm(1024 * 1024); err != nil {
t.Errorf("ParseMultipartForm: %v", err)
http.Error(w, "bad multipart", http.StatusBadRequest)
return
}
if got := r.FormValue("model"); got != "whisper-1" {
t.Errorf("model=%q, want whisper-1", got)
}
if got := r.FormValue("language"); got != "en" {
t.Errorf("language=%q, want en", got)
}
if got := r.FormValue("temperature"); got != "0.2" {
t.Errorf("temperature=%q, want 0.2", got)
}
file, _, err := r.FormFile("file")
if err != nil {
t.Errorf("FormFile: %v", err)
http.Error(w, "missing file", http.StatusBadRequest)
return
}
defer file.Close()
content, err := io.ReadAll(file)
if err != nil {
t.Errorf("read upload: %v", err)
http.Error(w, "read upload failed", http.StatusBadRequest)
return
}
if string(content) != "audio-bytes" {
t.Errorf("file content=%q, want audio-bytes", string(content))
}
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
}))
defer srv.Close()
audioPath := t.TempDir() + "/sample.wav"
if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0600); err != nil {
t.Fatalf("write audio fixture: %v", err)
}
apiKey := "test-key"
model := "whisper-1"
resp, err := newOpenAIForTest(srv.URL).TranscribeAudio(
&model,
&audioPath,
&APIConfig{ApiKey: &apiKey},
&ASRConfig{Params: map[string]interface{}{
"language": "en",
"temperature": 0.2,
}},
)
if err != nil {
t.Fatalf("TranscribeAudio: %v", err)
}
if resp.Text != "hello world" {
t.Fatalf("Text=%q, want hello world", resp.Text)
}
}
func TestOpenAITranscribeAudioWithSenderStreamsDeltas(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("method=%s, want POST", r.Method)
}
if r.URL.Path != "/audio/transcriptions" {
t.Errorf("path=%s, want /audio/transcriptions", r.URL.Path)
}
if got := r.Header.Get("Accept"); got != "text/event-stream" {
t.Errorf("Accept=%q, want text/event-stream", got)
}
if err := r.ParseMultipartForm(1024 * 1024); err != nil {
t.Errorf("ParseMultipartForm: %v", err)
http.Error(w, "bad multipart", http.StatusBadRequest)
return
}
if got := r.FormValue("stream"); got != "true" {
t.Errorf("stream=%q, want true", got)
}
if got := r.FormValue("model"); got != "gpt-4o-mini-transcribe" {
t.Errorf("model=%q, want gpt-4o-mini-transcribe", got)
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
_, _ = w.Write([]byte("data: {\"type\":\"transcript.text.delta\",\"delta\":\"hello\"}\n\n"))
if flusher != nil {
flusher.Flush()
}
_, _ = w.Write([]byte("data: {\"type\":\"transcript.text.delta\",\"delta\":\" world\"}\n\n"))
if flusher != nil {
flusher.Flush()
}
_, _ = w.Write([]byte("data: {\"type\":\"transcript.text.done\",\"text\":\"hello world\"}\n\n"))
}))
defer srv.Close()
audioPath := t.TempDir() + "/sample.wav"
if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0600); err != nil {
t.Fatalf("write audio fixture: %v", err)
}
apiKey := "test-key"
model := "gpt-4o-mini-transcribe"
var chunks []string
err := newOpenAIForTest(srv.URL).TranscribeAudioWithSender(
&model,
&audioPath,
&APIConfig{ApiKey: &apiKey},
nil,
func(content, _ *string) error {
if content != nil {
chunks = append(chunks, *content)
}
return nil
},
)
if err != nil {
t.Fatalf("TranscribeAudioWithSender: %v", err)
}
if got := strings.Join(chunks, ""); got != "hello world[DONE]" {
t.Fatalf("streamed text=%q, want hello world[DONE]", got)
}
}
func TestOpenAIAudioSpeechPostsJSONToAudioEndpoint(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("method=%s, want POST", r.Method)
}
if r.URL.Path != "/audio/speech" {
t.Errorf("path=%s, want /audio/speech", r.URL.Path)
}
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
t.Errorf("Authorization=%q, want Bearer test-key", got)
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") {
t.Errorf("Content-Type=%q, want application/json", got)
}
var body map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("decode body: %v", err)
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if body["model"] != "tts-1" {
t.Errorf("model=%v, want tts-1", body["model"])
}
if body["input"] != "hello" {
t.Errorf("input=%v, want hello", body["input"])
}
if body["voice"] != "alloy" {
t.Errorf("voice=%v, want alloy", body["voice"])
}
if body["response_format"] != "wav" {
t.Errorf("response_format=%v, want wav", body["response_format"])
}
if body["speed"] != float64(1.25) {
t.Errorf("speed=%v, want 1.25", body["speed"])
}
_, _ = w.Write([]byte("audio-bytes"))
}))
defer srv.Close()
apiKey := "test-key"
model := "tts-1"
input := "hello"
resp, err := newOpenAIForTest(srv.URL).AudioSpeech(
&model,
&input,
&APIConfig{ApiKey: &apiKey},
&TTSConfig{
Format: "wav",
Params: map[string]interface{}{
"voice": "alloy",
"speed": 1.25,
},
},
)
if err != nil {
t.Fatalf("AudioSpeech: %v", err)
}
if string(resp.Audio) != "audio-bytes" {
t.Fatalf("Audio=%q, want audio-bytes", string(resp.Audio))
}
}
func TestOpenAIAudioSpeechRequiresVoice(t *testing.T) {
apiKey := "test-key"
model := "tts-1"
input := "hello"
_, err := newOpenAIForTest("http://unused").AudioSpeech(
&model,
&input,
&APIConfig{ApiKey: &apiKey},
nil,
)
if err == nil || !strings.Contains(err.Error(), "voice is required") {
t.Fatalf("err=%v, want voice is required", err)
}
}
func TestOpenAIAudioSpeechRejectsNonStringVoice(t *testing.T) {
apiKey := "test-key"
model := "tts-1"
input := "hello"
_, err := newOpenAIForTest("http://unused").AudioSpeech(
&model,
&input,
&APIConfig{ApiKey: &apiKey},
&TTSConfig{Params: map[string]interface{}{"voice": 123}},
)
if err == nil || !strings.Contains(err.Error(), "voice is required") {
t.Fatalf("err=%v, want voice is required", err)
}
}
func TestOpenAIAudioSpeechWithSenderStreamsRawAudio(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("method=%s, want POST", r.Method)
}
if r.URL.Path != "/audio/speech" {
t.Errorf("path=%s, want /audio/speech", r.URL.Path)
}
var body map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("decode body: %v", err)
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if body["model"] != "tts-1" {
t.Errorf("model=%v, want tts-1", body["model"])
}
if body["input"] != "hello" {
t.Errorf("input=%v, want hello", body["input"])
}
if body["voice"] != "alloy" {
t.Errorf("voice=%v, want alloy", body["voice"])
}
if body["stream_format"] != "audio" {
t.Errorf("stream_format=%v, want audio", body["stream_format"])
}
flusher, _ := w.(http.Flusher)
_, _ = w.Write([]byte("audio-"))
if flusher != nil {
flusher.Flush()
}
_, _ = w.Write([]byte("bytes"))
}))
defer srv.Close()
apiKey := "test-key"
model := "tts-1"
input := "hello"
var chunks []string
err := newOpenAIForTest(srv.URL).AudioSpeechWithSender(
&model,
&input,
&APIConfig{ApiKey: &apiKey},
&TTSConfig{Params: map[string]interface{}{"voice": "alloy"}},
func(content, _ *string) error {
if content != nil {
chunks = append(chunks, *content)
}
return nil
},
)
if err != nil {
t.Fatalf("AudioSpeechWithSender: %v", err)
}
if got := strings.Join(chunks, ""); got != "audio-bytes" {
t.Fatalf("streamed audio=%q, want audio-bytes", got)
}
}
func TestOpenAIAudioSpeechWithSenderStreamsSSEAudioDeltas(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Accept"); got != "text/event-stream" {
t.Errorf("Accept=%q, want text/event-stream", got)
}
var body map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("decode body: %v", err)
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if body["stream_format"] != "sse" {
t.Errorf("stream_format=%v, want sse", body["stream_format"])
}
w.Header().Set("Content-Type", "text/event-stream")
chunk := base64.StdEncoding.EncodeToString([]byte("audio-bytes"))
_, _ = w.Write([]byte("data: {\"type\":\"speech.audio.delta\",\"audio\":\"" + chunk + "\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"speech.audio.done\"}\n\n"))
}))
defer srv.Close()
apiKey := "test-key"
model := "gpt-4o-mini-tts"
input := "hello"
var chunks []string
err := newOpenAIForTest(srv.URL).AudioSpeechWithSender(
&model,
&input,
&APIConfig{ApiKey: &apiKey},
&TTSConfig{Params: map[string]interface{}{
"voice": "alloy",
"stream_format": "sse",
}},
func(content, _ *string) error {
if content != nil {
chunks = append(chunks, *content)
}
return nil
},
)
if err != nil {
t.Fatalf("AudioSpeechWithSender: %v", err)
}
if got := strings.Join(chunks, ""); got != "audio-bytes" {
t.Fatalf("streamed audio=%q, want audio-bytes", got)
}
}