mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(go): support OpenAI audio endpoints (#15104)
### What problem does this PR solve? Closes #15102. OpenAI's Go provider config advertises `whisper-1` as ASR and `tts-1` as TTS, but the Go driver returned `openai, no such method` for both audio paths and did not define `url_suffix.asr` / `url_suffix.tts`. This PR: - adds OpenAI audio URL suffixes for `audio/transcriptions` and `audio/speech` - implements non-streaming `TranscribeAudio` using multipart form uploads - implements non-streaming `AudioSpeech` using the OpenAI speech JSON request shape - keeps streaming TTS explicitly unsupported instead of sending binary audio through the text SSE sender - adds focused tests for config coverage, ASR/TTS request shape, required TTS voice validation, and unsupported streaming TTS ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -6,7 +6,9 @@
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models",
|
||||
"embedding": "embeddings"
|
||||
"embedding": "embeddings",
|
||||
"asr": "audio/transcriptions",
|
||||
"tts": "audio/speech"
|
||||
},
|
||||
"class": "gpt",
|
||||
"models": [
|
||||
@@ -224,4 +226,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,10 +20,15 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -596,20 +601,437 @@ func (z *OpenAIModel) Rerank(modelName *string, query string, documents []string
|
||||
|
||||
// TranscribeAudio transcribe audio
|
||||
func (o *OpenAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", o.Name())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, responseFormat, err := o.newOpenAIASRRequest(ctx, modelName, file, apiConfig, asrConfig, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("OpenAI ASR API error: %s, body: %s", resp.Status, string(respBody))
|
||||
}
|
||||
|
||||
return decodeOpenAIASRResponse(respBody, responseFormat)
|
||||
}
|
||||
|
||||
func (z *OpenAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", z.Name())
|
||||
if sender == nil {
|
||||
return fmt.Errorf("sender is required")
|
||||
}
|
||||
|
||||
req, responseFormat, err := z.newOpenAIASRRequest(context.Background(), modelName, file, apiConfig, asrConfig, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("OpenAI ASR stream API error: %s, body: %s", resp.Status, string(respBody))
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
response, err := decodeOpenAIASRResponse(respBody, responseFormat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response != nil && response.Text != "" {
|
||||
if err = sender(&response.Text, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
done := "[DONE]"
|
||||
return sender(&done, nil)
|
||||
}
|
||||
|
||||
sentDelta := false
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 8*1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
dataStr := strings.TrimSpace(line[6:])
|
||||
if dataStr == "" {
|
||||
continue
|
||||
}
|
||||
if dataStr == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Delta string `json:"delta"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err = json.Unmarshal([]byte(dataStr), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
case event.Delta != "":
|
||||
if err = sender(&event.Delta, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
sentDelta = true
|
||||
case event.Type == "transcript.text.segment" && event.Text != "":
|
||||
if err = sender(&event.Text, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
sentDelta = true
|
||||
case event.Type == "transcript.text.done" && !sentDelta && event.Text != "":
|
||||
if err = sender(&event.Text, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return fmt.Errorf("error reading OpenAI ASR stream: %w", err)
|
||||
}
|
||||
|
||||
done := "[DONE]"
|
||||
return sender(&done, nil)
|
||||
}
|
||||
|
||||
func decodeOpenAIASRResponse(respBody []byte, responseFormat string) (*ASRResponse, error) {
|
||||
switch responseFormat {
|
||||
case "text", "srt", "vtt":
|
||||
return &ASRResponse{Text: string(respBody)}, nil
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody))
|
||||
}
|
||||
|
||||
return &ASRResponse{Text: result.Text}, nil
|
||||
}
|
||||
|
||||
// AudioSpeech convert text to audio
|
||||
func (o *OpenAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", o.Name())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, _, err := o.newOpenAITTSRequest(ctx, modelName, audioContent, apiConfig, ttsConfig, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("OpenAI TTS API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
return &TTSResponse{Audio: body}, nil
|
||||
}
|
||||
|
||||
func (z *OpenAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", z.Name())
|
||||
if sender == nil {
|
||||
return fmt.Errorf("sender is required")
|
||||
}
|
||||
|
||||
req, streamFormat, err := z.newOpenAITTSRequest(context.Background(), modelName, audioContent, apiConfig, ttsConfig, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if streamFormat == "sse" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("OpenAI TTS stream API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
if streamFormat == "sse" || strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||||
return readOpenAITTSSSEStream(resp.Body, sender)
|
||||
}
|
||||
return readOpenAITTSRawStream(resp.Body, sender)
|
||||
}
|
||||
|
||||
func (o *OpenAIModel) newOpenAIASRRequest(ctx context.Context, modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, stream bool) (*http.Request, string, error) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, "", fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, "", fmt.Errorf("model name is required")
|
||||
}
|
||||
if file == nil || *file == "" {
|
||||
return nil, "", fmt.Errorf("file is missing")
|
||||
}
|
||||
if strings.TrimSpace(o.URLSuffix.ASR) == "" {
|
||||
return nil, "", fmt.Errorf("openai ASR URL suffix is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := o.baseURLForRegion(region)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(o.URLSuffix.ASR, "/"))
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
audioFile, err := os.Open(*file)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to open audio file: %w", err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(*file))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create multipart file: %w", err)
|
||||
}
|
||||
if _, err = io.Copy(part, audioFile); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to copy audio data: %w", err)
|
||||
}
|
||||
if err = writer.WriteField("model", *modelName); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
|
||||
responseFormat := ""
|
||||
if asrConfig != nil && asrConfig.Params != nil {
|
||||
if value, ok := asrConfig.Params["response_format"].(string); ok {
|
||||
responseFormat = value
|
||||
}
|
||||
for key, value := range asrConfig.Params {
|
||||
if stream && key == "stream" {
|
||||
continue
|
||||
}
|
||||
if err = writeOpenAIMultipartField(writer, key, value); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
if stream {
|
||||
if err = writer.WriteField("stream", "true"); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to write stream field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, &body)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, responseFormat, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIModel) newOpenAITTSRequest(ctx context.Context, modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (*http.Request, string, error) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, "", fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, "", fmt.Errorf("model name is required")
|
||||
}
|
||||
if audioContent == nil || *audioContent == "" {
|
||||
return nil, "", fmt.Errorf("audio content is empty")
|
||||
}
|
||||
if strings.TrimSpace(o.URLSuffix.TTS) == "" {
|
||||
return nil, "", fmt.Errorf("openai TTS URL suffix is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := o.baseURLForRegion(region)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(o.URLSuffix.TTS, "/"))
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": *audioContent,
|
||||
}
|
||||
|
||||
if ttsConfig != nil && ttsConfig.Params != nil {
|
||||
for key, value := range ttsConfig.Params {
|
||||
reqBody[key] = value
|
||||
}
|
||||
}
|
||||
if ttsConfig != nil && ttsConfig.Format != "" {
|
||||
reqBody["response_format"] = ttsConfig.Format
|
||||
}
|
||||
if stream {
|
||||
if _, ok := reqBody["stream_format"]; !ok {
|
||||
reqBody["stream_format"] = "audio"
|
||||
}
|
||||
}
|
||||
|
||||
voice, ok := reqBody["voice"]
|
||||
if !ok || voice == nil {
|
||||
return nil, "", fmt.Errorf("voice is required")
|
||||
}
|
||||
voiceString, ok := voice.(string)
|
||||
if !ok || strings.TrimSpace(voiceString) == "" {
|
||||
return nil, "", fmt.Errorf("voice is required")
|
||||
}
|
||||
|
||||
streamFormat := ""
|
||||
if value, ok := reqBody["stream_format"].(string); ok {
|
||||
streamFormat = value
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
return req, streamFormat, nil
|
||||
}
|
||||
|
||||
func readOpenAITTSSSEStream(body io.Reader, sender func(*string, *string) error) error {
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 8*1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
dataStr := strings.TrimSpace(line[6:])
|
||||
if dataStr == "" || dataStr == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Audio string `json:"audio"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(dataStr), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if event.Type == "speech.audio.delta" && event.Audio != "" {
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(event.Audio)
|
||||
if err == nil && len(audioBytes) > 0 {
|
||||
chunk := string(audioBytes)
|
||||
if errSend := sender(&chunk, nil); errSend != nil {
|
||||
return errSend
|
||||
}
|
||||
}
|
||||
}
|
||||
if event.Type == "speech.audio.done" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("error reading OpenAI TTS stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readOpenAITTSRawStream(body io.Reader, sender func(*string, *string) error) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := string(buf[:n])
|
||||
if errSend := sender(&chunk, nil); errSend != nil {
|
||||
return errSend
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading OpenAI TTS stream: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeOpenAIMultipartField(writer *multipart.Writer, key string, value interface{}) error {
|
||||
var val string
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
val = v
|
||||
case bool:
|
||||
val = strconv.FormatBool(v)
|
||||
case int:
|
||||
val = strconv.Itoa(v)
|
||||
case int64:
|
||||
val = strconv.FormatInt(v, 10)
|
||||
case float32:
|
||||
val = strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||
case float64:
|
||||
val = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
default:
|
||||
val = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if err := writer.WriteField(key, val); err != nil {
|
||||
return fmt.Errorf("failed to write field %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OCRFile OCR file
|
||||
|
||||
414
internal/entity/models/openai_test.go
Normal file
414
internal/entity/models/openai_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newOpenAIForTest(baseURL string) *OpenAIModel {
|
||||
return NewOpenAIModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{
|
||||
Chat: "chat/completions",
|
||||
Models: "models",
|
||||
Embedding: "embeddings",
|
||||
ASR: "audio/transcriptions",
|
||||
TTS: "audio/speech",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestOpenAIConfigAdvertisedAudioModelsHaveSuffixes(t *testing.T) {
|
||||
raw, err := os.ReadFile("../../../conf/models/openai.json")
|
||||
if err != nil {
|
||||
t.Fatalf("read openai config: %v", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
URLSuffix URLSuffix `json:"url_suffix"`
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
} `json:"models"`
|
||||
}
|
||||
if err = json.Unmarshal(raw, &cfg); err != nil {
|
||||
t.Fatalf("unmarshal openai config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.URLSuffix.ASR != "audio/transcriptions" {
|
||||
t.Fatalf("ASR suffix=%q, want audio/transcriptions", cfg.URLSuffix.ASR)
|
||||
}
|
||||
if cfg.URLSuffix.TTS != "audio/speech" {
|
||||
t.Fatalf("TTS suffix=%q, want audio/speech", cfg.URLSuffix.TTS)
|
||||
}
|
||||
|
||||
var hasASR, hasTTS bool
|
||||
for _, model := range cfg.Models {
|
||||
for _, modelType := range model.ModelTypes {
|
||||
if model.Name == "whisper-1" && modelType == "asr" {
|
||||
hasASR = true
|
||||
}
|
||||
if model.Name == "tts-1" && modelType == "tts" {
|
||||
hasTTS = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasASR {
|
||||
t.Fatal("openai config should advertise whisper-1 as ASR")
|
||||
}
|
||||
if !hasTTS {
|
||||
t.Fatal("openai config should advertise tts-1 as TTS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITranscribeAudioPostsMultipartToAudioEndpoint(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/audio/transcriptions" {
|
||||
t.Errorf("path=%s, want /audio/transcriptions", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q, want Bearer test-key", got)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data") {
|
||||
t.Errorf("Content-Type=%q, want multipart/form-data", got)
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(1024 * 1024); err != nil {
|
||||
t.Errorf("ParseMultipartForm: %v", err)
|
||||
http.Error(w, "bad multipart", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.FormValue("model"); got != "whisper-1" {
|
||||
t.Errorf("model=%q, want whisper-1", got)
|
||||
}
|
||||
if got := r.FormValue("language"); got != "en" {
|
||||
t.Errorf("language=%q, want en", got)
|
||||
}
|
||||
if got := r.FormValue("temperature"); got != "0.2" {
|
||||
t.Errorf("temperature=%q, want 0.2", got)
|
||||
}
|
||||
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
t.Errorf("FormFile: %v", err)
|
||||
http.Error(w, "missing file", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
t.Errorf("read upload: %v", err)
|
||||
http.Error(w, "read upload failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if string(content) != "audio-bytes" {
|
||||
t.Errorf("file content=%q, want audio-bytes", string(content))
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
audioPath := t.TempDir() + "/sample.wav"
|
||||
if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0600); err != nil {
|
||||
t.Fatalf("write audio fixture: %v", err)
|
||||
}
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "whisper-1"
|
||||
resp, err := newOpenAIForTest(srv.URL).TranscribeAudio(
|
||||
&model,
|
||||
&audioPath,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&ASRConfig{Params: map[string]interface{}{
|
||||
"language": "en",
|
||||
"temperature": 0.2,
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("TranscribeAudio: %v", err)
|
||||
}
|
||||
if resp.Text != "hello world" {
|
||||
t.Fatalf("Text=%q, want hello world", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITranscribeAudioWithSenderStreamsDeltas(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/audio/transcriptions" {
|
||||
t.Errorf("path=%s, want /audio/transcriptions", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept=%q, want text/event-stream", got)
|
||||
}
|
||||
if err := r.ParseMultipartForm(1024 * 1024); err != nil {
|
||||
t.Errorf("ParseMultipartForm: %v", err)
|
||||
http.Error(w, "bad multipart", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.FormValue("stream"); got != "true" {
|
||||
t.Errorf("stream=%q, want true", got)
|
||||
}
|
||||
if got := r.FormValue("model"); got != "gpt-4o-mini-transcribe" {
|
||||
t.Errorf("model=%q, want gpt-4o-mini-transcribe", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"transcript.text.delta\",\"delta\":\"hello\"}\n\n"))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"transcript.text.delta\",\"delta\":\" world\"}\n\n"))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"transcript.text.done\",\"text\":\"hello world\"}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
audioPath := t.TempDir() + "/sample.wav"
|
||||
if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0600); err != nil {
|
||||
t.Fatalf("write audio fixture: %v", err)
|
||||
}
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "gpt-4o-mini-transcribe"
|
||||
var chunks []string
|
||||
err := newOpenAIForTest(srv.URL).TranscribeAudioWithSender(
|
||||
&model,
|
||||
&audioPath,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
func(content, _ *string) error {
|
||||
if content != nil {
|
||||
chunks = append(chunks, *content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("TranscribeAudioWithSender: %v", err)
|
||||
}
|
||||
if got := strings.Join(chunks, ""); got != "hello world[DONE]" {
|
||||
t.Fatalf("streamed text=%q, want hello world[DONE]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIAudioSpeechPostsJSONToAudioEndpoint(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/audio/speech" {
|
||||
t.Errorf("path=%s, want /audio/speech", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q, want Bearer test-key", got)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") {
|
||||
t.Errorf("Content-Type=%q, want application/json", got)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("decode body: %v", err)
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if body["model"] != "tts-1" {
|
||||
t.Errorf("model=%v, want tts-1", body["model"])
|
||||
}
|
||||
if body["input"] != "hello" {
|
||||
t.Errorf("input=%v, want hello", body["input"])
|
||||
}
|
||||
if body["voice"] != "alloy" {
|
||||
t.Errorf("voice=%v, want alloy", body["voice"])
|
||||
}
|
||||
if body["response_format"] != "wav" {
|
||||
t.Errorf("response_format=%v, want wav", body["response_format"])
|
||||
}
|
||||
if body["speed"] != float64(1.25) {
|
||||
t.Errorf("speed=%v, want 1.25", body["speed"])
|
||||
}
|
||||
|
||||
_, _ = w.Write([]byte("audio-bytes"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "tts-1"
|
||||
input := "hello"
|
||||
resp, err := newOpenAIForTest(srv.URL).AudioSpeech(
|
||||
&model,
|
||||
&input,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&TTSConfig{
|
||||
Format: "wav",
|
||||
Params: map[string]interface{}{
|
||||
"voice": "alloy",
|
||||
"speed": 1.25,
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AudioSpeech: %v", err)
|
||||
}
|
||||
if string(resp.Audio) != "audio-bytes" {
|
||||
t.Fatalf("Audio=%q, want audio-bytes", string(resp.Audio))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIAudioSpeechRequiresVoice(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
model := "tts-1"
|
||||
input := "hello"
|
||||
|
||||
_, err := newOpenAIForTest("http://unused").AudioSpeech(
|
||||
&model,
|
||||
&input,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), "voice is required") {
|
||||
t.Fatalf("err=%v, want voice is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIAudioSpeechRejectsNonStringVoice(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
model := "tts-1"
|
||||
input := "hello"
|
||||
|
||||
_, err := newOpenAIForTest("http://unused").AudioSpeech(
|
||||
&model,
|
||||
&input,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&TTSConfig{Params: map[string]interface{}{"voice": 123}},
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), "voice is required") {
|
||||
t.Fatalf("err=%v, want voice is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIAudioSpeechWithSenderStreamsRawAudio(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/audio/speech" {
|
||||
t.Errorf("path=%s, want /audio/speech", r.URL.Path)
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("decode body: %v", err)
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if body["model"] != "tts-1" {
|
||||
t.Errorf("model=%v, want tts-1", body["model"])
|
||||
}
|
||||
if body["input"] != "hello" {
|
||||
t.Errorf("input=%v, want hello", body["input"])
|
||||
}
|
||||
if body["voice"] != "alloy" {
|
||||
t.Errorf("voice=%v, want alloy", body["voice"])
|
||||
}
|
||||
if body["stream_format"] != "audio" {
|
||||
t.Errorf("stream_format=%v, want audio", body["stream_format"])
|
||||
}
|
||||
|
||||
flusher, _ := w.(http.Flusher)
|
||||
_, _ = w.Write([]byte("audio-"))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
_, _ = w.Write([]byte("bytes"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "tts-1"
|
||||
input := "hello"
|
||||
|
||||
var chunks []string
|
||||
err := newOpenAIForTest(srv.URL).AudioSpeechWithSender(
|
||||
&model,
|
||||
&input,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&TTSConfig{Params: map[string]interface{}{"voice": "alloy"}},
|
||||
func(content, _ *string) error {
|
||||
if content != nil {
|
||||
chunks = append(chunks, *content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AudioSpeechWithSender: %v", err)
|
||||
}
|
||||
if got := strings.Join(chunks, ""); got != "audio-bytes" {
|
||||
t.Fatalf("streamed audio=%q, want audio-bytes", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIAudioSpeechWithSenderStreamsSSEAudioDeltas(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept=%q, want text/event-stream", got)
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("decode body: %v", err)
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if body["stream_format"] != "sse" {
|
||||
t.Errorf("stream_format=%v, want sse", body["stream_format"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
chunk := base64.StdEncoding.EncodeToString([]byte("audio-bytes"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"speech.audio.delta\",\"audio\":\"" + chunk + "\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"speech.audio.done\"}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "gpt-4o-mini-tts"
|
||||
input := "hello"
|
||||
var chunks []string
|
||||
err := newOpenAIForTest(srv.URL).AudioSpeechWithSender(
|
||||
&model,
|
||||
&input,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&TTSConfig{Params: map[string]interface{}{
|
||||
"voice": "alloy",
|
||||
"stream_format": "sse",
|
||||
}},
|
||||
func(content, _ *string) error {
|
||||
if content != nil {
|
||||
chunks = append(chunks, *content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AudioSpeechWithSender: %v", err)
|
||||
}
|
||||
if got := strings.Join(chunks, ""); got != "audio-bytes" {
|
||||
t.Fatalf("streamed audio=%q, want audio-bytes", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user