mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: add stream / think chat (#14242)
### What problem does this PR solve? 1. Supports stream and non-stream chat 2. Supports think and non-think chat 3. List supported models from DeepSeek service. (This command can be used to verify the API validity) ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -222,10 +222,20 @@
|
||||
"glm-4.6v",
|
||||
"glm-4.5",
|
||||
"glm-4.5v"
|
||||
],
|
||||
"clear": {
|
||||
"default_value": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"clear_thinking": {
|
||||
"default_value": true,
|
||||
"supported_models": [
|
||||
"glm-5.1",
|
||||
"glm-5",
|
||||
"glm-5v-turbo",
|
||||
"glm-4.7",
|
||||
"glm-4.6",
|
||||
"glm-4.6v",
|
||||
"glm-4.5",
|
||||
"glm-4.5v"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -164,6 +164,8 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.ShowProvider(cmd)
|
||||
case "list_provider_models":
|
||||
return c.ListModels(cmd)
|
||||
case "list_supported_models":
|
||||
return c.ListSupportedModels(cmd)
|
||||
case "list_instance_models":
|
||||
return c.ListInstanceModels(cmd)
|
||||
case "show_model":
|
||||
@@ -214,6 +216,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.ShowProvider(cmd)
|
||||
case "list_provider_models":
|
||||
return c.ListModels(cmd)
|
||||
case "list_supported_models":
|
||||
return c.ListSupportedModels(cmd)
|
||||
case "list_instance_models":
|
||||
return c.ListInstanceModels(cmd)
|
||||
case "show_model":
|
||||
|
||||
@@ -335,6 +335,45 @@ func (c *RAGFlowClient) ListModels(cmd *Command) (ResponseIf, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ListSupportedModels(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
providerName, ok := cmd.Params["provider_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider_name not provided")
|
||||
}
|
||||
instanceName, ok := cmd.Params["instance_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("instance_name not provided")
|
||||
}
|
||||
|
||||
var endPoint string
|
||||
if c.ServerType == "admin" {
|
||||
endPoint = fmt.Sprintf("/admin/providers/%s/instances/%s/models?supported=true", providerName, instanceName)
|
||||
} else {
|
||||
endPoint = fmt.Sprintf("/providers/%s/instances/%s/models?supported=true", providerName, instanceName)
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list models: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
|
||||
providerName, ok := cmd.Params["provider_name"].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -303,6 +303,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenChat, Value: ident}
|
||||
case "THINK":
|
||||
return Token{Type: TokenThink, Value: ident}
|
||||
case "STREAM":
|
||||
return Token{Type: TokenStream, Value: ident}
|
||||
case "LS":
|
||||
return Token{Type: TokenLS, Value: ident}
|
||||
case "CAT":
|
||||
@@ -363,6 +365,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenTable, Value: ident}
|
||||
case "AVAILABLE":
|
||||
return Token{Type: TokenAvailable, Value: ident}
|
||||
case "SUPPORTED":
|
||||
return Token{Type: TokenSupported, Value: ident}
|
||||
case "NAME":
|
||||
return Token{Type: TokenName, Value: ident}
|
||||
case "INSTANCE":
|
||||
|
||||
@@ -190,6 +190,8 @@ func (p *Parser) parseUserCommand() (*Command, error) {
|
||||
return p.parseEnableCommand()
|
||||
case TokenDisable:
|
||||
return p.parseDisableCommand()
|
||||
case TokenStream:
|
||||
return p.parseStreamCommand()
|
||||
case TokenChat:
|
||||
return p.parseChatCommand()
|
||||
case TokenThink:
|
||||
|
||||
@@ -113,28 +113,33 @@ func (r *SimpleResponse) PrintOut() {
|
||||
}
|
||||
}
|
||||
|
||||
type MessageResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
OutputFormat OutputFormat
|
||||
type NonStreamResponse struct {
|
||||
Code int `json:"code"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Answer string `json:"answer"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
OutputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *MessageResponse) Type() string {
|
||||
return "message"
|
||||
func (r *NonStreamResponse) Type() string {
|
||||
return "non_stream_message"
|
||||
}
|
||||
|
||||
func (r *MessageResponse) TimeCost() float64 {
|
||||
func (r *NonStreamResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *MessageResponse) SetOutputFormat(format OutputFormat) {
|
||||
func (r *NonStreamResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.OutputFormat = format
|
||||
}
|
||||
|
||||
func (r *MessageResponse) PrintOut() {
|
||||
func (r *NonStreamResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
fmt.Println(r.Message)
|
||||
if r.ReasoningContent != "" {
|
||||
fmt.Printf("Thinking: %s\n", r.ReasoningContent)
|
||||
}
|
||||
fmt.Printf("Answer: %s\n", r.Answer)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
|
||||
@@ -73,6 +73,7 @@ const (
|
||||
TokenKeys
|
||||
TokenGenerate
|
||||
TokenAvailable
|
||||
TokenSupported
|
||||
TokenModel
|
||||
TokenModels
|
||||
TokenProvider
|
||||
@@ -80,6 +81,7 @@ const (
|
||||
TokenDefault
|
||||
TokenChats
|
||||
TokenChat
|
||||
TokenStream
|
||||
TokenFiles
|
||||
TokenAs
|
||||
TokenParse
|
||||
@@ -106,7 +108,6 @@ const (
|
||||
TokenIndex
|
||||
TokenVector
|
||||
TokenSize
|
||||
TokenDocMeta
|
||||
TokenName // For ALTER PROVIDER <name> NAME <new_name>
|
||||
TokenInstance
|
||||
TokenInstances
|
||||
|
||||
@@ -1436,83 +1436,106 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
message := cmd.Params["message"].(string)
|
||||
reasoning := cmd.Params["reasoning"].(bool)
|
||||
thinking := cmd.Params["thinking"].(bool)
|
||||
stream := cmd.Params["stream"].(bool)
|
||||
|
||||
url := fmt.Sprintf("/providers/%s/instances/%s/models/%s", providerName, instanceName, modelName)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"message": message,
|
||||
"stream": true, // use stream API
|
||||
"reasoning": reasoning,
|
||||
"message": message,
|
||||
"stream": stream, // use stream API
|
||||
"thinking": thinking,
|
||||
}
|
||||
|
||||
// Call stream http api
|
||||
reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to chat model: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Parse SSE and output to console
|
||||
scanner := bufio.NewScanner(reader)
|
||||
var fullMessage strings.Builder
|
||||
|
||||
reasoningPrint := true
|
||||
messagePrint := true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimPrefix(line, "data:")
|
||||
data = strings.TrimSpace(data)
|
||||
|
||||
if strings.HasPrefix(data, "[REASONING]") {
|
||||
data = strings.TrimPrefix(data, "[REASONING]")
|
||||
if reasoningPrint {
|
||||
fmt.Print("Thinking: ")
|
||||
reasoningPrint = false
|
||||
} else {
|
||||
fmt.Print(data)
|
||||
}
|
||||
os.Stdout.Sync()
|
||||
}
|
||||
if strings.HasPrefix(data, "[MESSAGE]") {
|
||||
data = strings.TrimPrefix(data, "[MESSAGE]")
|
||||
if messagePrint {
|
||||
if reasoning {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print("Answer: ")
|
||||
messagePrint = false
|
||||
} else {
|
||||
fmt.Print(data)
|
||||
os.Stdout.Sync()
|
||||
fullMessage.WriteString(data)
|
||||
}
|
||||
}
|
||||
} else if strings.HasPrefix(line, "event:error") {
|
||||
// error event
|
||||
if scanner.Scan() {
|
||||
errData := strings.TrimPrefix(scanner.Text(), "data:")
|
||||
errData = strings.TrimSpace(errData)
|
||||
return nil, fmt.Errorf("chat error: %s", errData)
|
||||
}
|
||||
// If there's an error, return a generic error
|
||||
return nil, fmt.Errorf("chat error: received error event from server")
|
||||
if stream {
|
||||
// Call stream http api
|
||||
reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to chat model: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Parse SSE and output to console
|
||||
scanner := bufio.NewScanner(reader)
|
||||
var fullMessage strings.Builder
|
||||
|
||||
reasoningPrint := true
|
||||
messagePrint := true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimPrefix(line, "data:")
|
||||
data = strings.TrimSpace(data)
|
||||
|
||||
if strings.HasPrefix(data, "[REASONING]") {
|
||||
data = strings.TrimPrefix(data, "[REASONING]")
|
||||
if reasoningPrint {
|
||||
fmt.Print("Thinking: ")
|
||||
reasoningPrint = false
|
||||
} else {
|
||||
fmt.Print(data)
|
||||
}
|
||||
os.Stdout.Sync()
|
||||
}
|
||||
if strings.HasPrefix(data, "[MESSAGE]") {
|
||||
data = strings.TrimPrefix(data, "[MESSAGE]")
|
||||
if messagePrint {
|
||||
if thinking {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print("Answer: ")
|
||||
messagePrint = false
|
||||
} else {
|
||||
fmt.Print(data)
|
||||
os.Stdout.Sync()
|
||||
fullMessage.WriteString(data)
|
||||
}
|
||||
}
|
||||
} else if strings.HasPrefix(line, "event:error") {
|
||||
// error event
|
||||
if scanner.Scan() {
|
||||
errData := strings.TrimPrefix(scanner.Text(), "data:")
|
||||
errData = strings.TrimSpace(errData)
|
||||
return nil, fmt.Errorf("chat error: %s", errData)
|
||||
}
|
||||
// If there's an error, return a generic error
|
||||
return nil, fmt.Errorf("chat error: received error event from server")
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
|
||||
result := &StreamMessageResponse{
|
||||
Code: 0,
|
||||
Message: fullMessage.String(),
|
||||
Duration: duration,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||
resp, err := c.HTTPClient.Request("POST", url, true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list instance models: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
|
||||
result := &StreamMessageResponse{
|
||||
Code: 0,
|
||||
Message: fullMessage.String(),
|
||||
Duration: duration,
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list instance models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
return result, nil
|
||||
|
||||
var result NonStreamResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list instance models: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// UseModel sets the current model for chat
|
||||
|
||||
@@ -163,6 +163,8 @@ func (p *Parser) parseListCommand() (*Command, error) {
|
||||
return p.parseListTokens()
|
||||
case TokenModel:
|
||||
return p.parseListModelProviders()
|
||||
case TokenSupported:
|
||||
return p.parseListModelsOfProvider()
|
||||
case TokenModels:
|
||||
return p.parseListModelsOfProvider()
|
||||
case TokenProviders:
|
||||
@@ -2014,11 +2016,55 @@ func (p *Parser) parseSearchCommand() (*Command, error) {
|
||||
}
|
||||
|
||||
func (p *Parser) parseListModelsOfProvider() (*Command, error) {
|
||||
|
||||
if p.curToken.Type == TokenSupported {
|
||||
// List supported models
|
||||
p.nextToken()
|
||||
|
||||
cmd := NewCommand("list_supported_models")
|
||||
if p.curToken.Type != TokenModels {
|
||||
return nil, fmt.Errorf("expected MODELS")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenQuotedString {
|
||||
return nil, fmt.Errorf("expected quoted string for provider name")
|
||||
}
|
||||
firstName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenQuotedString {
|
||||
return nil, fmt.Errorf("expected quoted string for instance name")
|
||||
}
|
||||
secondName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
cmd.Params["provider_name"] = firstName
|
||||
cmd.Params["instance_name"] = secondName
|
||||
|
||||
// Semicolon is optional for LIST SUPPORTED MODELS
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
if p.curToken.Type != TokenModels {
|
||||
return nil, fmt.Errorf("expected MODELS")
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM")
|
||||
}
|
||||
@@ -2194,19 +2240,47 @@ func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
}
|
||||
cmd.Params["message"] = message
|
||||
cmd.Params["reasoning"] = false
|
||||
cmd.Params["thinking"] = false
|
||||
cmd.Params["stream"] = false
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseThinkCommand() (*Command, error) {
|
||||
|
||||
p.nextToken() // consume THINK
|
||||
|
||||
if p.curToken.Type != TokenChat {
|
||||
return nil, fmt.Errorf("expected CHAT after THINK")
|
||||
}
|
||||
|
||||
command, err := p.parseChatCommand()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
command.Type = "think_chat_to_model"
|
||||
command.Params["reasoning"] = true
|
||||
command.Params["thinking"] = true
|
||||
return command, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseStreamCommand() (*Command, error) {
|
||||
|
||||
p.nextToken() // consume STREAM
|
||||
|
||||
var command *Command
|
||||
var err error
|
||||
|
||||
if p.curToken.Type == TokenChat {
|
||||
command, err = p.parseChatCommand()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if p.curToken.Type == TokenThink {
|
||||
command, err = p.parseThinkCommand()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
command.Params["stream"] = true
|
||||
return command, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -61,14 +61,14 @@ type Reasoning struct {
|
||||
|
||||
// ClearReasoningContent holds the default value and supported models of the "clear_thinking" feature
|
||||
type ClearReasoningContent struct {
|
||||
DefaultValue bool `json:"default_value"`
|
||||
DefaultValue bool `json:"default_value"`
|
||||
SupportedModels []string `json:"supported_models"`
|
||||
}
|
||||
|
||||
// Thinking holds the default value and supported models of the "thinking" feature
|
||||
type Thinking struct {
|
||||
DefaultValue bool `json:"default_value"`
|
||||
SupportedModels []string `json:"supported_models"`
|
||||
Clear ClearReasoningContent `json:"clear"`
|
||||
DefaultValue bool `json:"default_value"`
|
||||
SupportedModels []string `json:"supported_models"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshal for Reasoning
|
||||
@@ -142,9 +142,10 @@ type Multimodal struct {
|
||||
|
||||
// Features represents all features of a model
|
||||
type Features struct {
|
||||
Multimodal *Multimodal `json:"multimodal,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
Multimodal *Multimodal `json:"multimodal,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
ClearThinking *ClearReasoningContent `json:"clear_thinking,omitempty"`
|
||||
}
|
||||
|
||||
type ModelThinking struct {
|
||||
@@ -231,16 +232,29 @@ func NewProviderManager(dirPath string) (*ProviderManager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
modelClearThinking := make(map[string]bool)
|
||||
if provider.Features.ClearThinking != nil {
|
||||
for _, modelName := range provider.Features.ClearThinking.SupportedModels {
|
||||
modelClearThinking[modelName] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range provider.Models {
|
||||
// if the prefix of mode.Name is matched with keys of modelSupportThinking
|
||||
for modelPrefix, _ := range modelSupportThinking {
|
||||
if strings.HasPrefix(model.Name, modelPrefix) {
|
||||
model.Thinking = &ModelThinking{
|
||||
DefaultValue: provider.Features.Thinking.DefaultValue,
|
||||
ClearContent: provider.Features.Thinking.Clear.DefaultValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelPrefix, _ := range modelClearThinking {
|
||||
if strings.HasPrefix(model.Name, modelPrefix) {
|
||||
model.Thinking.ClearContent = true
|
||||
}
|
||||
}
|
||||
|
||||
model.ModelTypeMap = make(map[string]bool)
|
||||
for _, modelType := range model.ModelTypes {
|
||||
model.ModelTypeMap[modelType] = true
|
||||
|
||||
147
internal/entity/models/deepseek.go
Normal file
147
internal/entity/models/deepseek.go
Normal file
@@ -0,0 +1,147 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DeepSeekModel implements ModelDriver for DeepSeek
|
||||
type DeepSeekModel struct {
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
httpClient *http.Client // Reusable HTTP client with connection pool
|
||||
}
|
||||
|
||||
// NewDeepSeekModel creates a new DeepSeek model instance
|
||||
func NewDeepSeekModel(baseURL map[string]string, urlSuffix URLSuffix) *DeepSeekModel {
|
||||
return &DeepSeekModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *DeepSeekModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *DeepSeekModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
/*
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "deepseek-chat",
|
||||
"object": "model",
|
||||
"owned_by": "deepseek"
|
||||
},
|
||||
{
|
||||
"id": "deepseek-reasoner",
|
||||
"object": "model",
|
||||
"owned_by": "deepseek"
|
||||
}
|
||||
]
|
||||
}
|
||||
*/
|
||||
|
||||
type (
	// Model is one entry of the model-list response.
	Model struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		OwnedBy string `json:"owned_by"`
	}

	// ModelList is the envelope of the model-list response; the entries are
	// carried in the JSON "data" array.
	ModelList struct {
		Object string  `json:"object"`
		Models []Model `json:"data"`
	}
)
|
||||
|
||||
func (z *DeepSeekModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models)
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var modelList ModelList
|
||||
if err = json.Unmarshal(body, &modelList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range modelList.Models {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
@@ -35,16 +35,20 @@ func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel {
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *DummyModel) Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) {
|
||||
return "", fmt.Errorf("not implemented")
|
||||
func (z *DummyModel) Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
func (z *DummyModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -35,6 +35,10 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
|
||||
switch providerLower {
|
||||
case "zhipu-ai":
|
||||
return NewZhipuAIModel(baseURL, urlSuffix), nil
|
||||
case "deepseek":
|
||||
return NewDeepSeekModel(baseURL, urlSuffix), nil
|
||||
case "moonshot":
|
||||
return NewMooshotModel(baseURL, urlSuffix), nil
|
||||
default:
|
||||
return NewDummyModel(baseURL, urlSuffix), nil
|
||||
}
|
||||
|
||||
118
internal/entity/models/moonshot.go
Normal file
118
internal/entity/models/moonshot.go
Normal file
@@ -0,0 +1,118 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MooshotModel implements ModelDriver for Mooshot
|
||||
type MooshotModel struct {
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
httpClient *http.Client // Reusable HTTP client with connection pool
|
||||
}
|
||||
|
||||
// NewMooshotModel creates a new Mooshot model instance
|
||||
func NewMooshotModel(baseURL map[string]string, urlSuffix URLSuffix) *MooshotModel {
|
||||
return &MooshotModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *MooshotModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *MooshotModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *MooshotModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (z *MooshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models)
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
models, ok := result["models"].([]string)
|
||||
if !ok || len(models) == 0 {
|
||||
return nil, fmt.Errorf("no models in response")
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
@@ -3,11 +3,18 @@ package models
|
||||
// ModelDriver is the interface implemented by every model provider driver
|
||||
type ModelDriver interface {
|
||||
// Chat sends a message and returns response
|
||||
Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error)
|
||||
Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error)
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error
|
||||
ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error
|
||||
// Encode encodes a list of texts into embeddings
|
||||
EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error)
|
||||
EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
|
||||
// ListModels lists the models supported by the provider
|
||||
ListModels(apiConfig *APIConfig) ([]string, error)
|
||||
}
|
||||
|
||||
// ChatResponse is the result of a non-streaming chat call.
type ChatResponse struct {
	// Answer is the model's reply text.
	Answer *string `json:"answer"`
	// ReasonContent is the reasoning text, serialized as "reason_content".
	// NOTE(review): the CLI's NonStreamResponse reads "reasoning_content";
	// confirm the handler maps between the two keys.
	ReasonContent *string `json:"reason_content"`
}
|
||||
|
||||
// URLSuffix represents the URL suffixes for different API endpoints
|
||||
@@ -17,19 +24,24 @@ type URLSuffix struct {
|
||||
AsyncResult string `json:"async_result"`
|
||||
Embedding string `json:"embedding"`
|
||||
Rerank string `json:"rerank"`
|
||||
Models string `json:"models"`
|
||||
Balance string `json:"balance"`
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
Stream *bool
|
||||
Reasoning *bool
|
||||
Thinking *bool
|
||||
MaxTokens *int
|
||||
Temperature *float64
|
||||
TopP *float64
|
||||
DoSample *bool
|
||||
Stop *[]string
|
||||
Region *string
|
||||
}
|
||||
|
||||
// APIConfig carries the per-call connection settings handed to a ModelDriver.
type APIConfig struct {
	// ApiKey is the credential sent as the bearer token.
	ApiKey *string
	// Region selects which base URL to use; drivers fall back to "default"
	// when it is nil.
	Region *string
}
|
||||
|
||||
// EmbeddingConfig carries the optional settings passed to EncodeToEmbedding.
type EmbeddingConfig struct {
	// Region is an optional region name; nil when unset.
	Region *string
}
|
||||
|
||||
@@ -53,12 +53,17 @@ func NewZhipuAIModel(baseURL map[string]string, urlSuffix URLSuffix) *ZhipuAIMod
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig *ChatConfig) (string, error) {
|
||||
func (z *ZhipuAIModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if message == nil {
|
||||
return "", fmt.Errorf("message is nil")
|
||||
return nil, fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL, z.URLSuffix.Chat)
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat)
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{
|
||||
@@ -70,82 +75,117 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
choices, ok := result["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in response")
|
||||
return nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid choice format")
|
||||
return nil, fmt.Errorf("invalid choice format")
|
||||
}
|
||||
|
||||
messageMap, ok := firstChoice["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid message format")
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
content, ok := messageMap["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid content format")
|
||||
return nil, fmt.Errorf("invalid content format")
|
||||
}
|
||||
|
||||
return content, nil
|
||||
var reasonContent string
|
||||
if chatModelConfig.Thinking != nil && *chatModelConfig.Thinking {
|
||||
reasonContent, ok = messageMap["reasoning_content"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content format")
|
||||
}
|
||||
// If reasonContent starts with a newline ('\n'), strip that single leading character
|
||||
if reasonContent != "" && reasonContent[0] == '\n' {
|
||||
reasonContent = reasonContent[1:]
|
||||
}
|
||||
}
|
||||
|
||||
chatResponse := &ChatResponse{
|
||||
Answer: &content,
|
||||
ReasonContent: &reasonContent,
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
var region = "default"
|
||||
if chatModelConfig.Region != nil {
|
||||
region = *chatModelConfig.Region
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region])
|
||||
@@ -160,40 +200,38 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
|
||||
if chatModelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *chatModelConfig.DoSample
|
||||
}
|
||||
if chatModelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *chatModelConfig.DoSample
|
||||
}
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if chatModelConfig.Reasoning != nil {
|
||||
if *chatModelConfig.Reasoning {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -209,7 +247,7 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -292,10 +330,10 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
var region = "default"
|
||||
if embeddingConfig.Region != nil {
|
||||
region = *embeddingConfig.Region
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/embedding", z.BaseURL[region])
|
||||
@@ -319,7 +357,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -375,3 +413,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
@@ -458,6 +458,41 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keywords := ""
|
||||
if queryKeywords := c.Query("supported"); queryKeywords != "" {
|
||||
keywords = queryKeywords
|
||||
}
|
||||
|
||||
// convert keywords to lower case
|
||||
keywords = strings.ToLower(keywords)
|
||||
if keywords == "true" {
|
||||
// list supported models
|
||||
|
||||
modelList, err := h.modelProviderService.ListSupportedModels(providerName, instanceName, c.GetString("user_id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var modelResponse []map[string]string
|
||||
for _, modelName := range modelList {
|
||||
modelResponse = append(modelResponse, map[string]string{
|
||||
"model_name": modelName,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": modelResponse,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -533,9 +568,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
|
||||
}
|
||||
|
||||
type ChatToModelRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
Stream bool `json:"stream"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
Message string `json:"message" binding:"required"`
|
||||
Stream bool `json:"stream"`
|
||||
Thinking bool `json:"thinking"`
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
@@ -610,19 +645,23 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiConfig := models.APIConfig{
|
||||
ApiKey: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
chatConfig := models.ChatConfig{
|
||||
Reasoning: &req.Reasoning,
|
||||
Thinking: &req.Thinking,
|
||||
Stream: &req.Stream,
|
||||
Stop: &[]string{},
|
||||
DoSample: nil,
|
||||
MaxTokens: nil,
|
||||
Temperature: nil,
|
||||
TopP: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
// Stream response using sender function (best performance, no channel)
|
||||
errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &chatConfig, sender)
|
||||
errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &apiConfig, &chatConfig, sender)
|
||||
|
||||
if errorCode != common.CodeSuccess {
|
||||
c.SSEvent("error", "stream failed")
|
||||
@@ -630,19 +669,23 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
apiConfig := models.APIConfig{
|
||||
ApiKey: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
chatConfig := models.ChatConfig{
|
||||
Reasoning: &req.Reasoning,
|
||||
Thinking: &req.Thinking,
|
||||
Stream: &req.Stream,
|
||||
Stop: &[]string{},
|
||||
DoSample: nil,
|
||||
MaxTokens: nil,
|
||||
Temperature: nil,
|
||||
TopP: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
// Non-stream response
|
||||
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &chatConfig)
|
||||
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &apiConfig, &chatConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
@@ -652,7 +695,8 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": response,
|
||||
"code": 0,
|
||||
"reasoning_content": response.ReasonContent,
|
||||
"answer": response.Answer,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -229,6 +229,54 @@ func (m *ModelProviderService) DeleteModelProvider(providerName, userID string)
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, userID string) ([]string, error) {
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
return nil, errors.New("fail to get tenant")
|
||||
}
|
||||
|
||||
if len(tenants) == 0 {
|
||||
return nil, errors.New("user has no tenants")
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Check if provider exists
|
||||
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("provider %s not found", providerName)
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiConfig := &modelModule.APIConfig{
|
||||
ApiKey: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
region := extra["region"]
|
||||
apiConfig.Region = &region
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
return providerInfo.ModelDriver.ListModels(apiConfig)
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID, region string) (common.ErrorCode, error) {
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
@@ -531,7 +579,7 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig) (*string, common.ErrorCode, error) {
|
||||
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, common.ErrorCode, error) {
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
@@ -575,22 +623,23 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
|
||||
}
|
||||
|
||||
region := extra["region"]
|
||||
modelConfig.Region = &region
|
||||
apiConfig.Region = &region
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
var response string
|
||||
response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, modelConfig)
|
||||
var response *modelModule.ChatResponse
|
||||
response, err = providerInfo.ModelDriver.Chat(&modelName, &message, apiConfig, modelConfig)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
return &response, common.CodeSuccess, nil
|
||||
return response, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
return nil, common.CodeServerError, errors.New("model is disabled")
|
||||
}
|
||||
|
||||
// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel)
|
||||
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
|
||||
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
@@ -633,10 +682,11 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
|
||||
}
|
||||
|
||||
region := extra["region"]
|
||||
modelConfig.Region = &region
|
||||
apiConfig.Region = &region
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
// Direct call with sender function
|
||||
err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender)
|
||||
err = providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &message, apiConfig, modelConfig, sender)
|
||||
if err != nil {
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user