Go: add stream / think chat (#14242)

### What problem does this PR solve?

1. Supports stream and non-stream chat
2. Supports think and non-think chat
3. List supported models from DeepSeek service. (This command can be
used to verify the API validity)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2026-04-21 16:52:32 +08:00
committed by GitHub
parent a2bea30749
commit e48d75987c
18 changed files with 780 additions and 183 deletions

View File

@@ -222,10 +222,20 @@
"glm-4.6v",
"glm-4.5",
"glm-4.5v"
],
"clear": {
"default_value": true
}
]
},
"clear_thinking": {
"default_value": true,
"supported_models": [
"glm-5.1",
"glm-5",
"glm-5v-turbo",
"glm-4.7",
"glm-4.6",
"glm-4.6v",
"glm-4.5",
"glm-4.5v"
]
}
}
}

View File

@@ -164,6 +164,8 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) {
return c.ShowProvider(cmd)
case "list_provider_models":
return c.ListModels(cmd)
case "list_supported_models":
return c.ListSupportedModels(cmd)
case "list_instance_models":
return c.ListInstanceModels(cmd)
case "show_model":
@@ -214,6 +216,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
return c.ShowProvider(cmd)
case "list_provider_models":
return c.ListModels(cmd)
case "list_supported_models":
return c.ListSupportedModels(cmd)
case "list_instance_models":
return c.ListInstanceModels(cmd)
case "show_model":

View File

@@ -335,6 +335,45 @@ func (c *RAGFlowClient) ListModels(cmd *Command) (ResponseIf, error) {
return &result, nil
}
func (c *RAGFlowClient) ListSupportedModels(cmd *Command) (ResponseIf, error) {
providerName, ok := cmd.Params["provider_name"].(string)
if !ok {
return nil, fmt.Errorf("provider_name not provided")
}
instanceName, ok := cmd.Params["instance_name"].(string)
if !ok {
return nil, fmt.Errorf("instance_name not provided")
}
var endPoint string
if c.ServerType == "admin" {
endPoint = fmt.Sprintf("/admin/providers/%s/instances/%s/models?supported=true", providerName, instanceName)
} else {
endPoint = fmt.Sprintf("/providers/%s/instances/%s/models?supported=true", providerName, instanceName)
}
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to list models: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to list models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
}
var result CommonResponse
if err = json.Unmarshal(resp.Body, &result); err != nil {
return nil, fmt.Errorf("failed to list models: invalid JSON (%w)", err)
}
if result.Code != 0 {
return nil, fmt.Errorf("%s", result.Message)
}
result.Duration = resp.Duration
return &result, nil
}
func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
providerName, ok := cmd.Params["provider_name"].(string)
if !ok {

View File

@@ -303,6 +303,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenChat, Value: ident}
case "THINK":
return Token{Type: TokenThink, Value: ident}
case "STREAM":
return Token{Type: TokenStream, Value: ident}
case "LS":
return Token{Type: TokenLS, Value: ident}
case "CAT":
@@ -363,6 +365,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenTable, Value: ident}
case "AVAILABLE":
return Token{Type: TokenAvailable, Value: ident}
case "SUPPORTED":
return Token{Type: TokenSupported, Value: ident}
case "NAME":
return Token{Type: TokenName, Value: ident}
case "INSTANCE":

View File

@@ -190,6 +190,8 @@ func (p *Parser) parseUserCommand() (*Command, error) {
return p.parseEnableCommand()
case TokenDisable:
return p.parseDisableCommand()
case TokenStream:
return p.parseStreamCommand()
case TokenChat:
return p.parseChatCommand()
case TokenThink:

View File

@@ -113,28 +113,33 @@ func (r *SimpleResponse) PrintOut() {
}
}
type MessageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Duration float64
OutputFormat OutputFormat
type NonStreamResponse struct {
Code int `json:"code"`
ReasoningContent string `json:"reasoning_content"`
Answer string `json:"answer"`
Message string `json:"message"`
Duration float64
OutputFormat OutputFormat
}
func (r *MessageResponse) Type() string {
return "message"
func (r *NonStreamResponse) Type() string {
return "non_stream_message"
}
func (r *MessageResponse) TimeCost() float64 {
func (r *NonStreamResponse) TimeCost() float64 {
return r.Duration
}
func (r *MessageResponse) SetOutputFormat(format OutputFormat) {
func (r *NonStreamResponse) SetOutputFormat(format OutputFormat) {
r.OutputFormat = format
}
func (r *MessageResponse) PrintOut() {
func (r *NonStreamResponse) PrintOut() {
if r.Code == 0 {
fmt.Println(r.Message)
if r.ReasoningContent != "" {
fmt.Printf("Thinking: %s\n", r.ReasoningContent)
}
fmt.Printf("Answer: %s\n", r.Answer)
} else {
fmt.Println("ERROR")
fmt.Printf("%d, %s\n", r.Code, r.Message)

View File

@@ -73,6 +73,7 @@ const (
TokenKeys
TokenGenerate
TokenAvailable
TokenSupported
TokenModel
TokenModels
TokenProvider
@@ -80,6 +81,7 @@ const (
TokenDefault
TokenChats
TokenChat
TokenStream
TokenFiles
TokenAs
TokenParse
@@ -106,7 +108,6 @@ const (
TokenIndex
TokenVector
TokenSize
TokenDocMeta
TokenName // For ALTER PROVIDER <name> NAME <new_name>
TokenInstance
TokenInstances

View File

@@ -1436,83 +1436,106 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
}
message := cmd.Params["message"].(string)
reasoning := cmd.Params["reasoning"].(bool)
thinking := cmd.Params["thinking"].(bool)
stream := cmd.Params["stream"].(bool)
url := fmt.Sprintf("/providers/%s/instances/%s/models/%s", providerName, instanceName, modelName)
payload := map[string]interface{}{
"message": message,
"stream": true, // use stream API
"reasoning": reasoning,
"message": message,
"stream": stream, // use stream API
"thinking": thinking,
}
// Call stream http api
reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload)
if err != nil {
return nil, fmt.Errorf("failed to chat model: %w", err)
}
defer reader.Close()
// Parse SSE and output to console
scanner := bufio.NewScanner(reader)
var fullMessage strings.Builder
reasoningPrint := true
messagePrint := true
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data:") {
data := strings.TrimPrefix(line, "data:")
data = strings.TrimSpace(data)
if strings.HasPrefix(data, "[REASONING]") {
data = strings.TrimPrefix(data, "[REASONING]")
if reasoningPrint {
fmt.Print("Thinking: ")
reasoningPrint = false
} else {
fmt.Print(data)
}
os.Stdout.Sync()
}
if strings.HasPrefix(data, "[MESSAGE]") {
data = strings.TrimPrefix(data, "[MESSAGE]")
if messagePrint {
if reasoning {
fmt.Println()
}
fmt.Print("Answer: ")
messagePrint = false
} else {
fmt.Print(data)
os.Stdout.Sync()
fullMessage.WriteString(data)
}
}
} else if strings.HasPrefix(line, "event:error") {
// error event
if scanner.Scan() {
errData := strings.TrimPrefix(scanner.Text(), "data:")
errData = strings.TrimSpace(errData)
return nil, fmt.Errorf("chat error: %s", errData)
}
// If there's an error, return a generic error
return nil, fmt.Errorf("chat error: received error event from server")
if stream {
// Call stream http api
reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload)
if err != nil {
return nil, fmt.Errorf("failed to chat model: %w", err)
}
defer reader.Close()
// Parse SSE and output to console
scanner := bufio.NewScanner(reader)
var fullMessage strings.Builder
reasoningPrint := true
messagePrint := true
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data:") {
data := strings.TrimPrefix(line, "data:")
data = strings.TrimSpace(data)
if strings.HasPrefix(data, "[REASONING]") {
data = strings.TrimPrefix(data, "[REASONING]")
if reasoningPrint {
fmt.Print("Thinking: ")
reasoningPrint = false
} else {
fmt.Print(data)
}
os.Stdout.Sync()
}
if strings.HasPrefix(data, "[MESSAGE]") {
data = strings.TrimPrefix(data, "[MESSAGE]")
if messagePrint {
if thinking {
fmt.Println()
}
fmt.Print("Answer: ")
messagePrint = false
} else {
fmt.Print(data)
os.Stdout.Sync()
fullMessage.WriteString(data)
}
}
} else if strings.HasPrefix(line, "event:error") {
// error event
if scanner.Scan() {
errData := strings.TrimPrefix(scanner.Text(), "data:")
errData = strings.TrimSpace(errData)
return nil, fmt.Errorf("chat error: %s", errData)
}
// If there's an error, return a generic error
return nil, fmt.Errorf("chat error: received error event from server")
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading stream: %w", err)
}
fmt.Println()
result := &StreamMessageResponse{
Code: 0,
Message: fullMessage.String(),
Duration: duration,
}
return result, nil
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading stream: %w", err)
resp, err := c.HTTPClient.Request("POST", url, true, "web", nil, payload)
if err != nil {
return nil, fmt.Errorf("failed to list instance models: %w", err)
}
fmt.Println()
result := &StreamMessageResponse{
Code: 0,
Message: fullMessage.String(),
Duration: duration,
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to list instance models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
}
return result, nil
var result NonStreamResponse
if err = json.Unmarshal(resp.Body, &result); err != nil {
return nil, fmt.Errorf("failed to list instance models: invalid JSON (%w)", err)
}
if result.Code != 0 {
return nil, fmt.Errorf("%s", result.Message)
}
result.Duration = resp.Duration
return &result, nil
}
// UseModel sets the current model for chat

View File

@@ -163,6 +163,8 @@ func (p *Parser) parseListCommand() (*Command, error) {
return p.parseListTokens()
case TokenModel:
return p.parseListModelProviders()
case TokenSupported:
return p.parseListModelsOfProvider()
case TokenModels:
return p.parseListModelsOfProvider()
case TokenProviders:
@@ -2014,11 +2016,55 @@ func (p *Parser) parseSearchCommand() (*Command, error) {
}
func (p *Parser) parseListModelsOfProvider() (*Command, error) {
if p.curToken.Type == TokenSupported {
// List supported models
p.nextToken()
cmd := NewCommand("list_supported_models")
if p.curToken.Type != TokenModels {
return nil, fmt.Errorf("expected MODELS")
}
p.nextToken()
if p.curToken.Type != TokenFrom {
return nil, fmt.Errorf("expected FROM")
}
p.nextToken()
if p.curToken.Type != TokenQuotedString {
return nil, fmt.Errorf("expected quoted string for provider name")
}
firstName, err := p.parseQuotedString()
if err != nil {
return nil, err
}
p.nextToken()
if p.curToken.Type != TokenQuotedString {
return nil, fmt.Errorf("expected quoted string for instance name")
}
secondName, err := p.parseQuotedString()
if err != nil {
return nil, err
}
p.nextToken()
cmd.Params["provider_name"] = firstName
cmd.Params["instance_name"] = secondName
// Semicolon is optional for UNSET TOKEN
if p.curToken.Type == TokenSemicolon {
p.nextToken()
}
return cmd, nil
}
if p.curToken.Type != TokenModels {
return nil, fmt.Errorf("expected MODELS")
}
p.nextToken()
if p.curToken.Type != TokenFrom {
return nil, fmt.Errorf("expected FROM")
}
@@ -2194,19 +2240,47 @@ func (p *Parser) parseChatCommand() (*Command, error) {
cmd.Params["composite_model_name"] = compositeModelName
}
cmd.Params["message"] = message
cmd.Params["reasoning"] = false
cmd.Params["thinking"] = false
cmd.Params["stream"] = false
return cmd, nil
}
func (p *Parser) parseThinkCommand() (*Command, error) {
p.nextToken() // consume THINK
if p.curToken.Type != TokenChat {
return nil, fmt.Errorf("expected CHAT after THINK")
}
command, err := p.parseChatCommand()
if err != nil {
return nil, err
}
command.Type = "think_chat_to_model"
command.Params["reasoning"] = true
command.Params["thinking"] = true
return command, nil
}
func (p *Parser) parseStreamCommand() (*Command, error) {
p.nextToken() // consume STREAM
var command *Command
var err error
if p.curToken.Type == TokenChat {
command, err = p.parseChatCommand()
if err != nil {
return nil, err
}
} else if p.curToken.Type == TokenThink {
command, err = p.parseThinkCommand()
if err != nil {
return nil, err
}
}
command.Params["stream"] = true
return command, nil
}

View File

@@ -61,14 +61,14 @@ type Reasoning struct {
// Reasoning represents the reasoning capability (can be one of three types)
type ClearReasoningContent struct {
DefaultValue bool `json:"default_value"`
DefaultValue bool `json:"default_value"`
SupportedModels []string `json:"supported_models"`
}
// Reasoning represents the reasoning capability (can be one of three types)
type Thinking struct {
DefaultValue bool `json:"default_value"`
SupportedModels []string `json:"supported_models"`
Clear ClearReasoningContent `json:"clear"`
DefaultValue bool `json:"default_value"`
SupportedModels []string `json:"supported_models"`
}
// UnmarshalJSON custom unmarshal for Reasoning
@@ -142,9 +142,10 @@ type Multimodal struct {
// Features represents all features of a model
type Features struct {
Multimodal *Multimodal `json:"multimodal,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
Multimodal *Multimodal `json:"multimodal,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
ClearThinking *ClearReasoningContent `json:"clear_thinking,omitempty"`
}
type ModelThinking struct {
@@ -231,16 +232,29 @@ func NewProviderManager(dirPath string) (*ProviderManager, error) {
}
}
modelClearThinking := make(map[string]bool)
if provider.Features.ClearThinking != nil {
for _, modelName := range provider.Features.ClearThinking.SupportedModels {
modelClearThinking[modelName] = true
}
}
for _, model := range provider.Models {
// if the prefix of mode.Name is matched with keys of modelSupportThinking
for modelPrefix, _ := range modelSupportThinking {
if strings.HasPrefix(model.Name, modelPrefix) {
model.Thinking = &ModelThinking{
DefaultValue: provider.Features.Thinking.DefaultValue,
ClearContent: provider.Features.Thinking.Clear.DefaultValue,
}
}
}
for modelPrefix, _ := range modelClearThinking {
if strings.HasPrefix(model.Name, modelPrefix) {
model.Thinking.ClearContent = true
}
}
model.ModelTypeMap = make(map[string]bool)
for _, modelType := range model.ModelTypes {
model.ModelTypeMap[modelType] = true

View File

@@ -0,0 +1,147 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// DeepSeekModel implements ModelDriver for DeepSeek
type DeepSeekModel struct {
BaseURL map[string]string
URLSuffix URLSuffix
httpClient *http.Client // Reusable HTTP client with connection pool
}
// NewDeepSeekModel creates a new DeepSeek model instance
func NewDeepSeekModel(baseURL map[string]string, urlSuffix URLSuffix) *DeepSeekModel {
return &DeepSeekModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
httpClient: &http.Client{
Timeout: 120 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
},
},
}
}
// Chat sends a message and returns response
func (z *DeepSeekModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
return nil, fmt.Errorf("not implemented")
}
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
func (z *DeepSeekModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
return fmt.Errorf("not implemented")
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("not implemented")
}
/*
{
"object": "list",
"data": [
{
"id": "deepseek-chat",
"object": "model",
"owned_by": "deepseek"
},
{
"id": "deepseek-reasoner",
"object": "model",
"owned_by": "deepseek"
}
]
}
*/
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
}
type ModelList struct {
Object string `json:"object"`
Models []Model `json:"data"`
}
func (z *DeepSeekModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models)
// Build request body
reqBody := map[string]interface{}{}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var modelList ModelList
if err = json.Unmarshal(body, &modelList); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
var models []string
for _, model := range modelList.Models {
models = append(models, model.ID)
}
return models, nil
}

View File

@@ -35,16 +35,20 @@ func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel {
}
// Chat sends a message and returns response
func (z *DummyModel) Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) {
return "", fmt.Errorf("not implemented")
func (z *DummyModel) Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) {
return nil, fmt.Errorf("not implemented")
}
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
func (z *DummyModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
return fmt.Errorf("not implemented")
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("not implemented")
}
func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) {
return nil, fmt.Errorf("not implemented")
}

View File

@@ -35,6 +35,10 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
switch providerLower {
case "zhipu-ai":
return NewZhipuAIModel(baseURL, urlSuffix), nil
case "deepseek":
return NewDeepSeekModel(baseURL, urlSuffix), nil
case "moonshot":
return NewMooshotModel(baseURL, urlSuffix), nil
default:
return NewDummyModel(baseURL, urlSuffix), nil
}

View File

@@ -0,0 +1,118 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// MooshotModel implements ModelDriver for Mooshot
type MooshotModel struct {
BaseURL map[string]string
URLSuffix URLSuffix
httpClient *http.Client // Reusable HTTP client with connection pool
}
// NewMooshotModel creates a new Mooshot model instance
func NewMooshotModel(baseURL map[string]string, urlSuffix URLSuffix) *MooshotModel {
return &MooshotModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
httpClient: &http.Client{
Timeout: 120 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
},
},
}
}
// Chat sends a message and returns response
func (z *MooshotModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
return nil, fmt.Errorf("not implemented")
}
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
func (z *MooshotModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
return fmt.Errorf("not implemented")
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *MooshotModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("not implemented")
}
func (z *MooshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models)
// Build request body
reqBody := map[string]interface{}{}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
models, ok := result["models"].([]string)
if !ok || len(models) == 0 {
return nil, fmt.Errorf("no models in response")
}
return models, nil
}

View File

@@ -3,11 +3,18 @@ package models
// EmbeddingModel interface for embedding models
type ModelDriver interface {
// Chat sends a message and returns response
Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error)
Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error)
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error
ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error
// Encode encodes a list of texts into embeddings
EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error)
EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
// List suppported models
ListModels(apiConfig *APIConfig) ([]string, error)
}
type ChatResponse struct {
Answer *string `json:"answer"`
ReasonContent *string `json:"reason_content"`
}
// URLSuffix represents the URL suffixes for different API endpoints
@@ -17,19 +24,24 @@ type URLSuffix struct {
AsyncResult string `json:"async_result"`
Embedding string `json:"embedding"`
Rerank string `json:"rerank"`
Models string `json:"models"`
Balance string `json:"balance"`
}
type ChatConfig struct {
Stream *bool
Reasoning *bool
Thinking *bool
MaxTokens *int
Temperature *float64
TopP *float64
DoSample *bool
Stop *[]string
Region *string
}
type APIConfig struct {
ApiKey *string
Region *string
}
type EmbeddingConfig struct {
Region *string
}

View File

@@ -53,12 +53,17 @@ func NewZhipuAIModel(baseURL map[string]string, urlSuffix URLSuffix) *ZhipuAIMod
}
// Chat sends a message and returns response
func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig *ChatConfig) (string, error) {
func (z *ZhipuAIModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
if message == nil {
return "", fmt.Errorf("message is nil")
return nil, fmt.Errorf("message is nil")
}
url := fmt.Sprintf("%s/%s", z.BaseURL, z.URLSuffix.Chat)
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat)
// Build request body
reqBody := map[string]interface{}{
@@ -70,82 +75,117 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig
"temperature": 1,
}
if chatModelConfig != nil {
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Thinking != nil {
if *chatModelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return "", fmt.Errorf("no choices in response")
return nil, fmt.Errorf("no choices in response")
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid choice format")
return nil, fmt.Errorf("invalid choice format")
}
messageMap, ok := firstChoice["message"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid message format")
return nil, fmt.Errorf("invalid message format")
}
content, ok := messageMap["content"].(string)
if !ok {
return "", fmt.Errorf("invalid content format")
return nil, fmt.Errorf("invalid content format")
}
return content, nil
var reasonContent string
if chatModelConfig.Thinking != nil && *chatModelConfig.Thinking {
reasonContent, ok = messageMap["reasoning_content"].(string)
if !ok {
return nil, fmt.Errorf("invalid content format")
}
// if first char of reasonContent is \n remove the '\n'
if reasonContent != "" && reasonContent[0] == '\n' {
reasonContent = reasonContent[1:]
}
}
chatResponse := &ChatResponse{
Answer: &content,
ReasonContent: &reasonContent,
}
return chatResponse, nil
}
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
var region = "default"
if chatModelConfig.Region != nil {
region = *chatModelConfig.Region
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region])
@@ -160,40 +200,38 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
"temperature": 1,
}
if chatModelConfig != nil {
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.DoSample != nil {
reqBody["do_sample"] = *chatModelConfig.DoSample
}
if chatModelConfig.DoSample != nil {
reqBody["do_sample"] = *chatModelConfig.DoSample
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Reasoning != nil {
if *chatModelConfig.Reasoning {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
if chatModelConfig.Thinking != nil {
if *chatModelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
@@ -209,7 +247,7 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
@@ -292,10 +330,10 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
var region = "default"
if embeddingConfig.Region != nil {
region = *embeddingConfig.Region
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/embedding", z.BaseURL[region])
@@ -319,7 +357,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
@@ -375,3 +413,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri
return embeddings, nil
}
func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
return nil, fmt.Errorf("no such method")
}

View File

@@ -458,6 +458,41 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
})
return
}
keywords := ""
if queryKeywords := c.Query("supported"); queryKeywords != "" {
keywords = queryKeywords
}
// convert keywords to small case
keywords = strings.ToLower(keywords)
if keywords == "true" {
// list supported models
modelList, err := h.modelProviderService.ListSupportedModels(providerName, instanceName, c.GetString("user_id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": err.Error(),
})
return
}
var modelResponse []map[string]string
for _, modelName := range modelList {
modelResponse = append(modelResponse, map[string]string{
"model_name": modelName,
})
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "success",
"data": modelResponse,
})
return
}
modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -533,9 +568,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
}
type ChatToModelRequest struct {
Message string `json:"message" binding:"required"`
Stream bool `json:"stream"`
Reasoning bool `json:"reasoning"`
Message string `json:"message" binding:"required"`
Stream bool `json:"stream"`
Thinking bool `json:"thinking"`
}
func (h *ProviderHandler) ChatToModel(c *gin.Context) {
@@ -610,19 +645,23 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
return nil
}
apiConfig := models.APIConfig{
ApiKey: nil,
Region: nil,
}
chatConfig := models.ChatConfig{
Reasoning: &req.Reasoning,
Thinking: &req.Thinking,
Stream: &req.Stream,
Stop: &[]string{},
DoSample: nil,
MaxTokens: nil,
Temperature: nil,
TopP: nil,
Region: nil,
}
// Stream response using sender function (best performance, no channel)
errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &chatConfig, sender)
errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &apiConfig, &chatConfig, sender)
if errorCode != common.CodeSuccess {
c.SSEvent("error", "stream failed")
@@ -630,19 +669,23 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
return
}
apiConfig := models.APIConfig{
ApiKey: nil,
Region: nil,
}
chatConfig := models.ChatConfig{
Reasoning: &req.Reasoning,
Thinking: &req.Thinking,
Stream: &req.Stream,
Stop: &[]string{},
DoSample: nil,
MaxTokens: nil,
Temperature: nil,
TopP: nil,
Region: nil,
}
// Non-stream response
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &chatConfig)
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &apiConfig, &chatConfig)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": errorCode,
@@ -652,7 +695,8 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": response,
"code": 0,
"reasoning_content": response.ReasonContent,
"answer": response.Answer,
})
}

View File

@@ -229,6 +229,54 @@ func (m *ModelProviderService) DeleteModelProvider(providerName, userID string)
return common.CodeSuccess, nil
}
func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, userID string) ([]string, error) {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
if err != nil {
return nil, errors.New("fail to get tenant")
}
if len(tenants) == 0 {
return nil, errors.New("user has no tenants")
}
tenantID := tenants[0].TenantID
// Check if provider exists
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
if err != nil {
return nil, err
}
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
if err != nil {
return nil, err
}
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
if providerInfo == nil {
return nil, fmt.Errorf("provider %s not found", providerName)
}
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, err
}
apiConfig := &modelModule.APIConfig{
ApiKey: nil,
Region: nil,
}
region := extra["region"]
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
return providerInfo.ModelDriver.ListModels(apiConfig)
}
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID, region string) (common.ErrorCode, error) {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
@@ -531,7 +579,7 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod
return common.CodeSuccess, nil
}
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig) (*string, common.ErrorCode, error) {
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, common.ErrorCode, error) {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
@@ -575,22 +623,23 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
}
region := extra["region"]
modelConfig.Region = &region
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
var response string
response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, modelConfig)
var response *modelModule.ChatResponse
response, err = providerInfo.ModelDriver.Chat(&modelName, &message, apiConfig, modelConfig)
if err != nil {
return nil, common.CodeServerError, err
}
return &response, common.CodeSuccess, nil
return response, common.CodeSuccess, nil
}
return nil, common.CodeServerError, errors.New("model is disabled")
}
// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel)
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
if err != nil {
@@ -633,10 +682,11 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
}
region := extra["region"]
modelConfig.Region = &region
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
// Direct call with sender function
err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender)
err = providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &message, apiConfig, modelConfig, sender)
if err != nil {
return common.CodeServerError
}