From b8b741555f9e4fb536b79cbd76928ccd3209abdc Mon Sep 17 00:00:00 2001 From: Panda Dev <56657208+pandadev66@users.noreply.github.com> Date: Thu, 7 May 2026 07:09:51 +0200 Subject: [PATCH] Go: implement provider: OpenAI (#14605) ### What problem does this PR solve? Add a Go driver for OpenAI (GPT models). The config file conf/models/openai.json has been in the repo for a while with the full GPT-5 model list, but internal/entity/models/factory.go had no case for "openai". So any tenant that configured OpenAI as a model provider in the Go layer fell through to the default branch and got the dummy driver. Chat, list models, and check connection all returned dummy responses instead of reaching the API. OpenAI is the most commonly requested provider and the JSON config already ships with the repo, so this gap is high impact even though the JSON has been there for some time. ### What this PR includes - New file internal/entity/models/openai.go with an OpenAIModel that implements the ModelDriver interface. - factory.go: route the "openai" provider name to NewOpenAIModel. - conf/models/openai.json: add "models": "models" under url_suffix so ListModels can hit /v1/models with no hardcoded fallback. ### How the driver works - OpenAI exposes the canonical OpenAI-compatible API at https://api.openai.com/v1. - ChatWithMessages and ChatStreamlyWithSender post to /chat/completions in the same shape the moonshot, vllm, and xai drivers use. - ListModels and CheckConnection call /models to list available ids and confirm the API key works. - reasoning_content is passed through for the o-series and other reasoning models, in both the non-stream and stream paths. - Encode (embeddings) is left as "not implemented" for now, the same way the other recent provider drivers do it. Rerank and Balance are not part of OpenAI's public API surface in this layer and return a clear "not implemented" or "no such method" error. ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - go build ./internal/entity/models/... in a clean go 1.25 image (the go.mod minimum) returns exit 0 with no errors. - Method set of OpenAIModel matches the ModelDriver interface: NewInstance, Name, ChatWithMessages, ChatStreamlyWithSender, Encode, Rerank, ListModels, Balance, CheckConnection. - Pattern parity with the merged moonshot (#14433), volcengine (#14460), minimax (#14478), vllm (#14532), xai (#14550), and lm-studio (#14586) PRs. Closes #14604 --- conf/models/openai.json | 3 +- internal/entity/models/factory.go | 2 + internal/entity/models/openai.go | 502 ++++++++++++++++++++++++++++++ 3 files changed, 506 insertions(+), 1 deletion(-) create mode 100644 internal/entity/models/openai.go diff --git a/conf/models/openai.json b/conf/models/openai.json index f4c3bdc9b1..696c6f93b3 100644 --- a/conf/models/openai.json +++ b/conf/models/openai.json @@ -4,7 +4,8 @@ "default": "https://api.openai.com/v1" }, "url_suffix": { - "chat": "chat/completions" + "chat": "chat/completions", + "models": "models" }, "class": "gpt", "models": [ diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 9d941a534c..9efd33e472 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -57,6 +57,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewXAIModel(baseURL, urlSuffix), nil case "lmstudio": return NewLmStudioModel(baseURL, urlSuffix), nil + case "openai": + return NewOpenAIModel(baseURL, urlSuffix), nil default: return NewDummyModel(baseURL, urlSuffix), nil } diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go new file mode 100644 index 0000000000..0d3e259ff2 --- /dev/null +++ b/internal/entity/models/openai.go @@ -0,0 +1,502 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// OpenAIModel implements ModelDriver for OpenAI (GPT models). +// The non-streaming call timeout is the shared nonStreamCallTimeout +// constant defined alongside the xAI driver in this package. +type OpenAIModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client // Reusable HTTP client with connection pool +} + +// NewOpenAIModel creates a new OpenAI model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the few connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewOpenAIModel(baseURL map[string]string, urlSuffix URLSuffix) *OpenAIModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + // Cap how long the client waits for the first response header. + // This protects ChatStreamlyWithSender, which has no client-wide + // timeout, against a server that opens the TCP connection and + // then never sends a response. + transport.ResponseHeaderTimeout = 60 * time.Second + + return &OpenAIModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (z *OpenAIModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewOpenAIModel(baseURL, z.URLSuffix) +} + +func (z *OpenAIModel) Name() string { + return "openai" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (z *OpenAIModel) baseURLForRegion(region string) (string, error) { + base, ok := z.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("openai: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response +func (z *OpenAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig.Region != nil { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Chat) + + // Convert messages to the format expected by the API + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so SSE/stream + // must always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + // OpenAI reasoning models (o-series and similar) return reasoning text in + // the reasoning_content field. Pass it through when present. + var reasonContent string + if rc, ok := messageMap["reasoning_content"].(string); ok { + reasonContent = rc + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + } + + return chatResponse, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. Used for streaming chat responses with no extra channel. +func (z *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Chat) + + // Convert messages to API format (supports multimodal content) + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body with streaming on by default + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a non-SSE + // JSON response would be parsed as if it were a stream and produce + // no chunks. Better to fail clearly. Leave reqBody["stream"] as + // the default (true) when Stream is nil or true. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // Use an explicit background context here so the request is at least + // cancellable in principle. We do not attach a hard deadline because + // SSE streams are long-lived. The transport's ResponseHeaderTimeout + // caps the connection-establishment phase. Threading a real ctx + // through the ModelDriver interface is a wider change for a follow-up. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line. The default bufio.Scanner buffer + // is 64KB, which can be too small for long SSE chunks. Bump it to + // 1MB so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + // sawTerminal flips to true when the upstream actually told us the + // stream is over (either a "[DONE]" marker or a non-empty + // finish_reason). If the body closes before either of those, we + // must not emit a synthetic "[DONE]" because that would hide a + // truncated response from the caller. + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of the stream + if data == "[DONE]" { + sawTerminal = true + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("openai: stream ended before [DONE] or finish_reason") + } + + // Send the [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +// Encode encodes a list of texts into embeddings. OpenAI does expose +// embedding endpoints (text-embedding-3-* and text-embedding-ada-002), +// but this initial driver intentionally leaves embedding support +// unimplemented. A follow-up PR can add it. +func (z *OpenAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { + return nil, fmt.Errorf("not implemented") +} + +// ListModels returns the list of model ids visible to the API key. +func (z *OpenAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // GET has no body, so Content-Type is not needed. + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by the OpenAI API, so this returns "no such method". +func (z *OpenAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (z *OpenAIModel) CheckConnection(apiConfig *APIConfig) error { + _, err := z.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and texts. OpenAI does +// not expose a rerank API, so this is left unimplemented. +func (z *OpenAIModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +}