Remove model_bundle.go, modify chat_session.go (#14458)

### What problem does this PR solve?

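Remove model_bundle.go (the ModelBundle wrapper) together with the now-unused EmbeddingModel/ChatModel/RerankModel interfaces and ModelConfig in the entity package, and modify chat_session.go so that asyncChatSolo and asyncChatSoloStream obtain the chat model through ModelProviderService.GetChatModel. parseModelName now also accepts the "model@instance@provider" format in addition to the legacy "model@provider" form.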

### Type of change

- [x] Refactoring
This commit is contained in:
qinling0210
2026-04-29 14:44:12 +08:00
committed by GitHub
parent ce933357c6
commit f3c232cf47
6 changed files with 187 additions and 422 deletions

View File

@@ -16,10 +16,6 @@
package entity
import (
"ragflow/internal/entity/models"
)
// ModelType represents the type of model
type ModelType string
@@ -40,38 +36,6 @@ const (
ModelTypeOCR ModelType = "ocr"
)
// EmbeddingModel interface for embedding models
type EmbeddingModel interface {
// Encode encodes a list of texts into embeddings
Encode(modelName *string, texts []string, apiConfig *models.APIConfig, embeddingConfig *models.EmbeddingConfig) ([][]float64, error)
}
// ChatModel interface for chat models
type ChatModel interface {
// Chat sends a message and returns response
Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, error)
// ChatStreamly sends a message and streams response
ChatStreamly(system string, history []map[string]string, genConf map[string]interface{}) (<-chan string, error)
}
// RerankModel interface for rerank models
type RerankModel interface {
// Rerank calculates similarity between query and texts
Rerank(query string, texts []string, apiConfig *models.APIConfig) ([]float64, error)
}
// ModelConfig represents configuration for a model
type ModelConfig struct {
TenantID string `json:"tenant_id"`
LLMFactory string `json:"llm_factory"`
ModelType ModelType `json:"model_type"`
LLMName string `json:"llm_name"`
APIKey string `json:"api_key"`
APIBase string `json:"api_base"`
MaxTokens int64 `json:"max_tokens"`
IsTools bool `json:"is_tools"`
}
// ModelCredentials holds the credentials for a model
type ModelCredentials struct {
ProviderName string

View File

@@ -200,7 +200,8 @@ type CompletionRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
Messages []map[string]interface{} `json:"messages" binding:"required"`
LLMID string `json:"llm_id,omitempty"`
Stream bool `json:"stream,omitempty"`
Stream *bool `json:"stream,omitempty"`
Thinking *bool `json:"thinking,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
@@ -252,6 +253,12 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
if req.MaxTokens != 0 {
chatModelConfig["max_tokens"] = req.MaxTokens
}
if req.Stream != nil {
chatModelConfig["stream"] = *req.Stream
}
if req.Thinking != nil {
chatModelConfig["thinking"] = *req.Thinking
}
// Process messages - filter out system messages and initial assistant messages
var processedMessages []map[string]interface{}
@@ -276,7 +283,7 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
}
// Call service
if req.Stream {
if req.Stream != nil && *req.Stream {
// Streaming response
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")

View File

@@ -24,24 +24,29 @@ import (
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"ragflow/internal/dao"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/logger"
)
// ChatSessionService chat session (conversation) service
type ChatSessionService struct {
chatSessionDAO *dao.ChatSessionDAO
chatDAO *dao.ChatDAO
userTenantDAO *dao.UserTenantDAO
chatSessionDAO *dao.ChatSessionDAO
chatDAO *dao.ChatDAO
userTenantDAO *dao.UserTenantDAO
modelProviderSvc *ModelProviderService
}
// NewChatSessionService create chat session service
func NewChatSessionService() *ChatSessionService {
return &ChatSessionService{
chatSessionDAO: dao.NewChatSessionDAO(),
chatDAO: dao.NewChatDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
chatSessionDAO: dao.NewChatSessionDAO(),
chatDAO: dao.NewChatDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
modelProviderSvc: NewModelProviderService(),
}
}
@@ -433,97 +438,6 @@ func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (b
return true, nil
}
func (s *ChatSessionService) performChat(dialog *entity.Chat, messages []map[string]interface{}, config map[string]interface{}) (string, error) {
// Get system prompt from dialog
systemPrompt := ""
if dialog.PromptConfig != nil {
if sys, ok := dialog.PromptConfig["system"].(string); ok {
systemPrompt = sys
}
}
// Convert messages to history format
history := make([]map[string]string, 0)
for _, msg := range messages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
if role != "" && content != "" {
history = append(history, map[string]string{
"role": role,
"content": content,
})
}
}
// Use ModelBundle to perform chat
bundle, err := NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
if err != nil {
return "", err
}
// Merge dialog's LLM setting with request config
genConf := make(map[string]interface{})
if dialog.LLMSetting != nil {
for k, v := range dialog.LLMSetting {
genConf[k] = v
}
}
for k, v := range config {
genConf[k] = v
}
response, _, err := bundle.Chat(systemPrompt, history, genConf)
return response, err
}
func (s *ChatSessionService) performChatStream(dialog *entity.Chat, messages []map[string]interface{}, config map[string]interface{}) (<-chan string, error) {
// Get system prompt from dialog
systemPrompt := ""
if dialog.PromptConfig != nil {
if sys, ok := dialog.PromptConfig["system"].(string); ok {
systemPrompt = sys
}
}
// Convert messages to history format
history := make([]map[string]string, 0)
for _, msg := range messages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
if role != "" && content != "" {
history = append(history, map[string]string{
"role": role,
"content": content,
})
}
}
// Use ModelBundle to perform streaming chat
bundle, err := NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
if err != nil {
return nil, err
}
// Merge dialog's LLM setting with request config
genConf := make(map[string]interface{})
if dialog.LLMSetting != nil {
for k, v := range dialog.LLMSetting {
genConf[k] = v
}
}
for k, v := range config {
genConf[k] = v
}
// Get chat model and call ChatStreamly
chatModel, ok := bundle.GetModel().(entity.ChatModel)
if !ok {
return nil, fmt.Errorf("model is not a chat model")
}
return chatModel.ChatStreamly(systemPrompt, history, genConf)
}
func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} {
return map[string]interface{}{
"answer": answer,
@@ -610,39 +524,52 @@ func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entit
// asyncChatSolo performs simple chat without RAG (non-streaming)
func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
logger.Info("asyncChatSolo started",
zap.String("tenant_id", dialog.TenantID),
zap.String("llm_id", dialog.LLMID),
zap.String("dialog_id", dialog.ID),
zap.Int("message_count", len(messages)))
// Get system prompt
systemPrompt := s.buildSystemPrompt(dialog)
// Process messages - handle attachments and image files
processedMessages := s.processMessages(messages, dialog)
// Get LLM type
llmType := s.getLLMType(dialog.LLMID)
// Build generation config
genConf := s.buildGenConf(dialog, config)
// Create ModelBundle for chat
var bundle *ModelBundle
var err error
if llmType == "image2text" {
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeImage2Text, dialog.LLMID)
} else {
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
}
chatModel, err := s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID)
if err != nil {
logger.Error("asyncChatSolo failed to get chat model", err)
return nil, err
}
// Convert messages to history format
history := s.convertToHistory(processedMessages)
// Convert messages to Message format
var msgs []modelModule.Message
if systemPrompt != "" {
msgs = append(msgs, modelModule.Message{Role: "system", Content: systemPrompt})
}
for _, msg := range processedMessages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
if role != "" && content != "" && role != "system" {
msgs = append(msgs, modelModule.Message{Role: role, Content: content})
}
}
// Get ChatConfig directly from dialog and config
chatConfig := s.buildChatConfig(dialog, config)
// Perform chat
response, _, err := bundle.Chat(systemPrompt, history, genConf)
response, err := chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, chatModel.APIConfig.ApiKey, msgs, chatConfig)
if err != nil {
logger.Error("asyncChatSolo chat failed", err)
return nil, err
}
logger.Info("asyncChatSolo completed",
zap.String("tenant_id", dialog.TenantID),
zap.String("llm_id", dialog.LLMID),
zap.Int("response_length", len(response)))
// Structure the answer
ans := map[string]interface{}{
"answer": response,
@@ -655,57 +582,67 @@ func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.
// asyncChatSoloStream performs simple streaming chat without RAG
func (s *ChatSessionService) asyncChatSoloStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) {
logger.Info("asyncChatSoloStream started",
zap.String("tenant_id", dialog.TenantID),
zap.String("llm_id", dialog.LLMID),
zap.String("dialog_id", dialog.ID),
zap.Int("message_count", len(messages)))
// Get system prompt
systemPrompt := s.buildSystemPrompt(dialog)
// Process messages
processedMessages := s.processMessages(messages, dialog)
// Get LLM type
llmType := s.getLLMType(dialog.LLMID)
// Build generation config
genConf := s.buildGenConf(dialog, config)
// Create ModelBundle
var bundle *ModelBundle
var err error
if llmType == "image2text" {
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeImage2Text, dialog.LLMID)
} else {
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
}
chatModel, err := s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID)
if err != nil {
logger.Error("asyncChatSoloStream failed to get chat model", err)
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
return
}
// Convert messages to history
history := s.convertToHistory(processedMessages)
// Get chat model
chatModel, ok := bundle.GetModel().(entity.ChatModel)
if !ok {
resultChan <- s.structureAnswer(session, "**ERROR**: model is not a chat model", messageID, session.ID, reference)
return
// Convert messages to single string for ChatStreamlyWithSender
var msgBuilder strings.Builder
if systemPrompt != "" {
msgBuilder.WriteString("System: " + systemPrompt + "\n")
}
// Perform streaming chat
streamChan, err := chatModel.ChatStreamly(systemPrompt, history, genConf)
if err != nil {
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
return
for _, msg := range processedMessages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
if role != "" && content != "" && role != "system" {
msgBuilder.WriteString(role + ": " + content + "\n")
}
}
messageStr := msgBuilder.String()
// Stream results
// Get ChatConfig directly from dialog and config
chatConfig := s.buildChatConfig(dialog, config)
// Perform streaming chat using ChatStreamlyWithSender
fullAnswer := ""
for chunk := range streamChan {
fullAnswer += chunk
// Clean up reasoning content
fullAnswer = s.removeReasoningContent(fullAnswer)
ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference)
resultChan <- ans
err = chatModel.ModelDriver.ChatStreamlyWithSender(chatModel.ModelName, &messageStr, chatModel.APIConfig, chatConfig, func(answer *string, reason *string) error {
if reason != nil && *reason != "" {
fullAnswer += *reason
ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference)
resultChan <- ans
}
if answer != nil && *answer != "" {
fullAnswer += *answer
fullAnswer = s.removeReasoningContent(fullAnswer)
ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference)
resultChan <- ans
}
return nil
})
if err != nil {
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
return
}
logger.Info("asyncChatSoloStream completed",
zap.String("tenant_id", dialog.TenantID),
zap.String("llm_id", dialog.LLMID),
zap.Int("response_length", len(fullAnswer)))
}
// buildSystemPrompt builds the system prompt from dialog configuration
@@ -745,50 +682,6 @@ func (s *ChatSessionService) cleanContent(content string) string {
return content
}
// convertToHistory converts messages to history format for LLM
func (s *ChatSessionService) convertToHistory(messages []map[string]interface{}) []map[string]string {
history := make([]map[string]string, 0)
for _, msg := range messages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
if role != "" && content != "" && role != "system" {
history = append(history, map[string]string{
"role": role,
"content": content,
})
}
}
return history
}
// buildGenConf builds generation config from dialog and request
func (s *ChatSessionService) buildGenConf(dialog *entity.Chat, config map[string]interface{}) map[string]interface{} {
genConf := make(map[string]interface{})
// Start with dialog's LLM setting
if dialog.LLMSetting != nil {
for k, v := range dialog.LLMSetting {
genConf[k] = v
}
}
// Override with request config
for k, v := range config {
genConf[k] = v
}
return genConf
}
// getLLMType gets the LLM type from model ID
func (s *ChatSessionService) getLLMType(llmID string) string {
// Simplified - would need to query TenantLLMService
if strings.Contains(llmID, "image") || strings.Contains(llmID, "vision") {
return "image2text"
}
return "chat"
}
// removeReasoningContent removes reasoning/thinking content from answer
func (s *ChatSessionService) removeReasoningContent(answer string) string {
// Remove </think> tags
@@ -891,3 +784,78 @@ func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []in
}
return formatted
}
// buildChatConfig builds a ChatConfig from dialog.LLMSetting and the
// per-request config.
//
// The dialog's stored settings are applied first and the request config is
// applied second, so any key present in both is taken from the request.
// Keys that are missing, or whose value has an unsupported type, leave the
// corresponding ChatConfig field untouched (nil if neither source set it).
//
// Recognized keys: stream, thinking, max_tokens, temperature, top_p,
// do_sample, stop, model_class, effort, verbosity.
func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[string]interface{}) *modelModule.ChatConfig {
	cfg := &modelModule.ChatConfig{}

	// asInt accepts the integer representations a settings map can carry.
	// Values that went through encoding/json into interface{} are float64,
	// so asserting only on int would silently drop max_tokens for them.
	asInt := func(raw interface{}) (int, bool) {
		switch n := raw.(type) {
		case int:
			return n, true
		case int32:
			return int(n), true
		case int64:
			return int(n), true
		case float64:
			// Only accept whole numbers; this also rejects NaN.
			if i := int(n); float64(i) == n {
				return i, true
			}
		}
		return 0, false
	}

	// asStrings accepts a []string as-is, and a []interface{} (the shape
	// encoding/json produces for arrays) when every element is a string.
	asStrings := func(raw interface{}) ([]string, bool) {
		switch list := raw.(type) {
		case []string:
			return list, true
		case []interface{}:
			out := make([]string, 0, len(list))
			for _, item := range list {
				str, ok := item.(string)
				if !ok {
					return nil, false
				}
				out = append(out, str)
			}
			return out, true
		}
		return nil, false
	}

	// apply copies every recognized, well-typed key from src onto cfg.
	// Indexing a nil map is safe in Go, so nil sources need no guard.
	apply := func(src map[string]interface{}) {
		if v, ok := src["stream"].(bool); ok {
			cfg.Stream = &v
		}
		if v, ok := src["thinking"].(bool); ok {
			cfg.Thinking = &v
		}
		if v, ok := asInt(src["max_tokens"]); ok {
			cfg.MaxTokens = &v
		}
		if v, ok := src["temperature"].(float64); ok {
			cfg.Temperature = &v
		}
		if v, ok := src["top_p"].(float64); ok {
			cfg.TopP = &v
		}
		if v, ok := src["do_sample"].(bool); ok {
			cfg.DoSample = &v
		}
		if v, ok := asStrings(src["stop"]); ok {
			cfg.Stop = &v
		}
		if v, ok := src["model_class"].(string); ok {
			cfg.ModelClass = &v
		}
		if v, ok := src["effort"].(string); ok {
			cfg.Effort = &v
		}
		if v, ok := src["verbosity"].(string); ok {
			cfg.Verbosity = &v
		}
	}

	// Start with the dialog's LLM setting, then let the request override it.
	apply(dialog.LLMSetting)
	apply(config)

	return cfg
}

View File

@@ -671,7 +671,7 @@ func normalizeDatasetUUID1(id string) (string, error) {
}
func (s *DatasetsService) verifyEmbeddingAvailability(embdID string, tenantID string) (bool, string) {
modelName, provider, err := parseModelName(embdID)
modelName, _, provider, err := parseModelName(embdID)
if err != nil {
return false, "Embedding model identifier must follow <model_name>@<provider> format"
}

View File

@@ -1,181 +0,0 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"fmt"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
)
// ModelBundle provides a unified interface for various model operations
// Similar to Python's LLMBundle but with a more generic name
type ModelBundle struct {
tenantID string
modelType entity.ModelType
modelName string
model interface{} // underlying model instance
apiConfig *modelModule.APIConfig
embeddingConfig *modelModule.EmbeddingConfig
}
// NewModelBundle creates a new ModelBundle for the given tenant and model type
// If modelName is empty, uses the default model for the tenant and type
func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...string) (*ModelBundle, error) {
bundle := &ModelBundle{
tenantID: tenantID,
modelType: modelType,
embeddingConfig: &modelModule.EmbeddingConfig{},
}
// Use provided model name if available
if len(modelName) > 0 && modelName[0] != "" {
bundle.modelName = modelName[0]
}
// Get model instance based on type
modelProviderSvc := NewModelProviderService()
switch modelType {
case entity.ModelTypeEmbedding:
embd, err := modelProviderSvc.GetEmbeddingModel(tenantID, bundle.modelName)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model: %w", err)
}
bundle.model = embd.ModelDriver
bundle.apiConfig = embd.APIConfig
case entity.ModelTypeChat:
chatMdl, err := modelProviderSvc.GetChatModel(tenantID, bundle.modelName)
if err != nil {
return nil, fmt.Errorf("failed to get chat model: %w", err)
}
bundle.model = chatMdl.ModelDriver
bundle.apiConfig = chatMdl.APIConfig
case entity.ModelTypeRerank:
rerankMdl, err := modelProviderSvc.GetRerankModel(tenantID, bundle.modelName)
if err != nil {
return nil, fmt.Errorf("failed to get rerank model: %w", err)
}
bundle.model = rerankMdl.ModelDriver
bundle.apiConfig = rerankMdl.APIConfig
default:
return nil, fmt.Errorf("unsupported model type: %s", modelType)
}
return bundle, nil
}
// Encode encodes a list of texts into embeddings
// Returns embeddings and token count (for compatibility with Python interface)
func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
if b.modelType != entity.ModelTypeEmbedding {
return nil, 0, fmt.Errorf("model type %s does not support encode", b.modelType)
}
embeddingModel, ok := b.model.(entity.EmbeddingModel)
if !ok {
return nil, 0, fmt.Errorf("model is not an embedding model")
}
embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig, b.embeddingConfig)
if err != nil {
return nil, 0, err
}
// TODO: Calculate actual token count
// For now, return a dummy token count
tokenCount := int64(0)
for _, text := range texts {
tokenCount += int64(len(text) / 4) // rough approximation
}
return embeddings, tokenCount, nil
}
// EncodeQuery encodes a single query string into embedding
// Returns embedding and token count
func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
if b.modelType != entity.ModelTypeEmbedding {
return nil, 0, fmt.Errorf("model type %s does not support encode query", b.modelType)
}
embeddingModel, ok := b.model.(entity.EmbeddingModel)
if !ok {
return nil, 0, fmt.Errorf("model is not an embedding model")
}
embeddings, err := embeddingModel.Encode(&b.modelName, []string{query}, b.apiConfig, b.embeddingConfig)
if err != nil {
return nil, 0, err
}
if len(embeddings) == 0 {
return nil, 0, fmt.Errorf("no embedding returned")
}
// TODO: Calculate actual token count
tokenCount := int64(len(query) / 4)
return embeddings[0], tokenCount, nil
}
// Chat sends a chat message and returns response
func (b *ModelBundle) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, int64, error) {
if b.modelType != entity.ModelTypeChat {
return "", 0, fmt.Errorf("model type %s does not support chat", b.modelType)
}
chatModel, ok := b.model.(entity.ChatModel)
if !ok {
return "", 0, fmt.Errorf("model is not a chat model")
}
response, err := chatModel.Chat(system, history, genConf)
if err != nil {
return "", 0, err
}
// TODO: Calculate actual token count
tokenCount := int64(len(response) / 4)
return response, tokenCount, nil
}
// Rerank calculates similarity between query and texts
func (b *ModelBundle) Rerank(query string, texts []string) ([]float64, int64, error) {
if b.modelType != entity.ModelTypeRerank {
return nil, 0, fmt.Errorf("model type %s does not support rerank", b.modelType)
}
rerankModel, ok := b.model.(entity.RerankModel)
if !ok {
return nil, 0, fmt.Errorf("model is not a rerank model")
}
similarities, err := rerankModel.Rerank(query, texts, b.apiConfig)
if err != nil {
return nil, 0, err
}
// TODO: Calculate actual token count
tokenCount := int64(len(query)/4) + int64(len(texts)*10)
return similarities, tokenCount, nil
}
// GetModel returns the underlying model instance
func (b *ModelBundle) GetModel() interface{} {
return b.model
}

View File

@@ -28,16 +28,20 @@ import (
"time"
)
// parseModelName parses a composite model name in format "model_name@provider"
// Returns modelName and provider separately
func parseModelName(compositeName string) (modelName, provider string, err error) {
// parseModelName parses a composite model name in format "model@instance@provider" or "model@provider"
// Returns modelName, instanceName, providerName separately
func parseModelName(compositeName string) (modelName, instanceName, providerName string, err error) {
parts := strings.Split(compositeName, "@")
if len(parts) == 2 {
return parts[0], parts[1], nil
if len(parts) == 3 {
// Format: model@instance@provider
return parts[0], parts[1], parts[2], nil
} else if len(parts) == 2 {
// Format: model@provider (legacy)
return parts[0], "", parts[1], nil
} else if len(parts) == 1 {
return parts[0], "", fmt.Errorf("provider name missing in model name: %s", compositeName)
return parts[0], "", "", fmt.Errorf("provider name missing in model name: %s", compositeName)
} else {
return "", "", fmt.Errorf("invalid model name format: %s", compositeName)
return "", "", "", fmt.Errorf("invalid model name format: %s", compositeName)
}
}
@@ -848,7 +852,7 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string)
// getModelConfig returns the model driver, model name, and API config for a model
func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, error) {
modelName, providerName, err := parseModelName(compositeModelName)
modelName, instanceName, providerName, err := parseModelName(compositeModelName)
if err != nil {
return nil, "", nil, err
}
@@ -862,7 +866,10 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin
return nil, "", nil, fmt.Errorf("provider %s not found", providerName)
}
instanceName := "default_instance"
if instanceName == "" {
instanceName = "default_instance"
}
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
if err != nil {
return nil, "", nil, err