From 2d3a1a448393a0a8a22dbdb922ed7ae94f6e546f Mon Sep 17 00:00:00 2001 From: web-dev0521 Date: Wed, 20 May 2026 23:52:56 -0400 Subject: [PATCH] feat(go-models): add Azure OpenAI model driver (#15022) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What problem does this PR solve? Closes #15021. The Go model-provider layer had no support for **Azure OpenAI**. Azure OpenAI is *not* a drop-in base-URL swap of the OpenAI driver — it differs in authentication, endpoint structure, and how models are listed — so it needs its own `ModelDriver` implementation. ## Type of change - [x] New feature (non-breaking change which adds functionality) Co-authored-by: Jin Hai --- conf/models/azure-openai.json | 9 + internal/entity/models/azure_openai.go | 609 +++++++++++++++++++++++++ internal/entity/models/factory.go | 2 + 3 files changed, 620 insertions(+) create mode 100644 conf/models/azure-openai.json create mode 100644 internal/entity/models/azure_openai.go diff --git a/conf/models/azure-openai.json b/conf/models/azure-openai.json new file mode 100644 index 0000000000..f8cbb7a52c --- /dev/null +++ b/conf/models/azure-openai.json @@ -0,0 +1,9 @@ +{ + "name": "Azure-OpenAI", + "url_suffix": { + "chat": "chat/completions", + "embedding": "embeddings", + "models": "deployments" + }, + "class": "gpt" +} diff --git a/internal/entity/models/azure_openai.go b/internal/entity/models/azure_openai.go new file mode 100644 index 0000000000..e2fe1f1b58 --- /dev/null +++ b/internal/entity/models/azure_openai.go @@ -0,0 +1,609 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// azureAPIVersion is the Azure OpenAI REST API version sent as the +// api-version query parameter on every request. Azure requires this on +// all data-plane calls; 2024-10-21 is the latest GA (non-preview) version. +const azureAPIVersion = "2024-10-21" + +// AzureOpenAIModel implements ModelDriver for Azure OpenAI. +// +// Azure OpenAI is not a base-URL swap of the OpenAI driver. It differs in +// three ways that this driver handles: +// - Endpoints are deployment-scoped: +// {baseURL}/deployments/{deployment}/{op}?api-version={azureAPIVersion} +// The model name passed in is the Azure deployment name, which goes in +// the URL path rather than the request body. +// - Authentication uses the "api-key" header, not "Authorization: Bearer". +// - Listing models means listing deployments via {baseURL}/deployments. +// +// The base URL is user-supplied (e.g. https://.openai.azure.com/openai) +// because each Azure resource has its own endpoint; there is no shared default. +type AzureOpenAIModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client // Reusable HTTP client with connection pool +} + +// NewAzureOpenAIModel creates a new Azure OpenAI model instance. +// +// The transport mirrors the OpenAI driver: clone http.DefaultTransport to +// keep Go's defaults (proxy, dial keep-alive, HTTP/2, TLS handshake) and +// only tune the connection pool. The Client has no overall Timeout so SSE +// streams in ChatStreamlyWithSender are not cut off; non-streaming callers +// wrap each request in context.WithTimeout, and ResponseHeaderTimeout caps +// how long we wait for the first response header. +func NewAzureOpenAIModel(baseURL map[string]string, urlSuffix URLSuffix) *AzureOpenAIModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &AzureOpenAIModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (z *AzureOpenAIModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewAzureOpenAIModel(baseURL, z.URLSuffix) +} + +func (z *AzureOpenAIModel) Name() string { + return "azure-openai" +} + +// baseURLForRegion returns the base URL for the given region, or an error if +// no entry exists. A misconfigured region fails fast with a clear message +// instead of silently producing a relative URL the transport then rejects. +func (z *AzureOpenAIModel) baseURLForRegion(region string) (string, error) { + base, ok := z.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("azure-openai: no base URL configured for region %q", region) + } + return base, nil +} + +// deploymentURL builds a deployment-scoped data-plane URL of the form +// {baseURL}/deployments/{deployment}/{op}?api-version={azureAPIVersion}. +func (z *AzureOpenAIModel) deploymentURL(baseURL, deployment, op string) string { + return fmt.Sprintf("%s/deployments/%s/%s?api-version=%s", + strings.TrimRight(baseURL, "/"), deployment, op, azureAPIVersion) +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (z *AzureOpenAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + if modelName == "" { + return nil, fmt.Errorf("deployment name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := z.deploymentURL(baseURL, modelName, z.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // The deployment (and therefore the model) is identified by the URL path, + // so Azure does not take a "model" field in the body. + reqBody := map[string]interface{}{ + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + var reasonContent string + if rc, ok := messageMap["reasoning_content"].(string); ok { + reasonContent = rc + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. Used for streaming chat responses with no extra channel. +func (z *AzureOpenAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + if modelName == "" { + return fmt.Errorf("deployment name is required") + } + + if sender == nil { + return fmt.Errorf("sender is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return err + } + url := z.deploymentURL(baseURL, modelName, z.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // This code path only knows how to read SSE, so refuse an explicit + // stream=false rather than mis-parsing a single JSON response as a + // stream and emitting no chunks. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // Background context: SSE streams are long-lived so we attach no hard + // deadline. The transport's ResponseHeaderTimeout caps connection setup. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer to 1MB so a long data: line is + // never silently truncated by the default 64KB cap. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + // sawTerminal flips true when upstream signals the stream is done (a + // "[DONE]" marker or a non-empty finish_reason). If the body closes + // before either, we must not emit a synthetic "[DONE]" that would hide + // a truncated response from the caller. + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + if reasoningContent, ok := delta["reasoning_content"].(string); ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + if content, ok := delta["content"].(string); ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + if finishReason, ok := firstChoice["finish_reason"].(string); ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("azure-openai: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + return sender(&endOfStream, nil) +} + +type azureEmbeddingResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` +} + +// Embed turns a list of texts into embedding vectors using the Azure OpenAI +// embeddings deployment. The output has one vector per input, in the same +// order the inputs were given. +func (z *AzureOpenAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("deployment name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := z.deploymentURL(baseURL, *modelName, z.URLSuffix.Embedding) + + // As with chat, the deployment is in the URL path, so no "model" field. + reqBody := map[string]interface{}{ + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Azure OpenAI embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed azureEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Azure returns one item per input, but does not guarantee response + // order. Each item's Index refers to its position in the input list, so + // place vectors by Index to honor the documented input-order guarantee. + // Reject an out-of-range index instead of panicking. + embeddings := make([]EmbeddingData, len(texts)) + for _, d := range parsed.Data { + if d.Index < 0 || d.Index >= len(embeddings) { + return nil, fmt.Errorf("embedding index %d out of range [0,%d)", d.Index, len(embeddings)) + } + embeddings[d.Index] = EmbeddingData{ + Embedding: d.Embedding, + Index: d.Index, + } + } + + return embeddings, nil +} + +// ListModels returns the deployment names visible to the configured API key. +// Azure exposes deployments (not a shared model catalog) at +// {baseURL}/deployments?api-version={azureAPIVersion}. +func (z *AzureOpenAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s?api-version=%s", + strings.TrimRight(baseURL, "/"), z.URLSuffix.Models, azureAPIVersion) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("api-key", *apiConfig.ApiKey) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid deployments list format") + } + + models := make([]string, 0, len(data)) + for _, item := range data { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + // Each deployment object exposes its name under "id". + id, ok := m["id"].(string) + if !ok { + continue + } + models = append(models, id) + } + + return models, nil +} + +// CheckConnection runs a lightweight ListModels call to verify the endpoint +// and API key. +func (z *AzureOpenAIModel) CheckConnection(apiConfig *APIConfig) error { + _, err := z.ListModels(apiConfig) + return err +} + +// Balance is not exposed by the Azure OpenAI API. +func (z *AzureOpenAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// Rerank is not exposed by the Azure OpenAI API. +func (z *AzureOpenAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AzureOpenAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index f7dad653b9..30848e9595 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -63,6 +63,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewOllamaModel(baseURL, urlSuffix), nil case "openai": return NewOpenAIModel(baseURL, urlSuffix), nil + case "azure-openai": + return NewAzureOpenAIModel(baseURL, urlSuffix), nil case "nvidia": return NewNvidiaModel(baseURL, urlSuffix), nil case "openrouter":