From f3c232cf47626c332d0aa7caee614715afeb214c Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Wed, 29 Apr 2026 14:44:12 +0800 Subject: [PATCH] Remove model_bundle.go, modify chat_session.go (#14458) ### What problem does this PR solve? Remove model_bundle.go, modify chat_session.go ### Type of change - [x] Refactoring --- internal/entity/types.go | 36 --- internal/handler/chat_session.go | 11 +- internal/service/chat_session.go | 354 ++++++++++++++---------------- internal/service/datasets.go | 2 +- internal/service/model_bundle.go | 181 --------------- internal/service/model_service.go | 25 ++- 6 files changed, 187 insertions(+), 422 deletions(-) delete mode 100644 internal/service/model_bundle.go diff --git a/internal/entity/types.go b/internal/entity/types.go index 41154dcf41..f342310acb 100644 --- a/internal/entity/types.go +++ b/internal/entity/types.go @@ -16,10 +16,6 @@ package entity -import ( - "ragflow/internal/entity/models" -) - // ModelType represents the type of model type ModelType string @@ -40,38 +36,6 @@ const ( ModelTypeOCR ModelType = "ocr" ) -// EmbeddingModel interface for embedding models -type EmbeddingModel interface { - // Encode encodes a list of texts into embeddings - Encode(modelName *string, texts []string, apiConfig *models.APIConfig, embeddingConfig *models.EmbeddingConfig) ([][]float64, error) -} - -// ChatModel interface for chat models -type ChatModel interface { - // Chat sends a message and returns response - Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, error) - // ChatStreamly sends a message and streams response - ChatStreamly(system string, history []map[string]string, genConf map[string]interface{}) (<-chan string, error) -} - -// RerankModel interface for rerank models -type RerankModel interface { - // Rerank calculates similarity between query and texts - Rerank(query string, texts []string, apiConfig *models.APIConfig) ([]float64, error) -} - -// ModelConfig represents configuration for a model -type ModelConfig struct { - TenantID string `json:"tenant_id"` - LLMFactory string `json:"llm_factory"` - ModelType ModelType `json:"model_type"` - LLMName string `json:"llm_name"` - APIKey string `json:"api_key"` - APIBase string `json:"api_base"` - MaxTokens int64 `json:"max_tokens"` - IsTools bool `json:"is_tools"` -} - // ModelCredentials holds the credentials for a model type ModelCredentials struct { ProviderName string diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index ebf293957e..897e62f18a 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -200,7 +200,8 @@ type CompletionRequest struct { ConversationID string `json:"conversation_id" binding:"required"` Messages []map[string]interface{} `json:"messages" binding:"required"` LLMID string `json:"llm_id,omitempty"` - Stream bool `json:"stream,omitempty"` + Stream *bool `json:"stream,omitempty"` + Thinking *bool `json:"thinking,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` @@ -252,6 +253,12 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { if req.MaxTokens != 0 { chatModelConfig["max_tokens"] = req.MaxTokens } + if req.Stream != nil { + chatModelConfig["stream"] = *req.Stream + } + if req.Thinking != nil { + chatModelConfig["thinking"] = *req.Thinking + } // Process messages - filter out system messages and initial assistant messages var processedMessages []map[string]interface{} @@ -276,7 +283,7 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { } // Call service - if req.Stream { + if req.Stream != nil && *req.Stream { // Streaming response c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 1ec6c4f846..d563a2c363 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -24,24 +24,29 @@ import ( "time" "github.com/google/uuid" + "go.uber.org/zap" "ragflow/internal/dao" "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/logger" ) // ChatSessionService chat session (conversation) service type ChatSessionService struct { - chatSessionDAO *dao.ChatSessionDAO - chatDAO *dao.ChatDAO - userTenantDAO *dao.UserTenantDAO + chatSessionDAO *dao.ChatSessionDAO + chatDAO *dao.ChatDAO + userTenantDAO *dao.UserTenantDAO + modelProviderSvc *ModelProviderService } // NewChatSessionService create chat session service func NewChatSessionService() *ChatSessionService { return &ChatSessionService{ - chatSessionDAO: dao.NewChatSessionDAO(), - chatDAO: dao.NewChatDAO(), - userTenantDAO: dao.NewUserTenantDAO(), + chatSessionDAO: dao.NewChatSessionDAO(), + chatDAO: dao.NewChatDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + modelProviderSvc: NewModelProviderService(), } } @@ -433,97 +438,6 @@ func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (b return true, nil } -func (s *ChatSessionService) performChat(dialog *entity.Chat, messages []map[string]interface{}, config map[string]interface{}) (string, error) { - // Get system prompt from dialog - systemPrompt := "" - if dialog.PromptConfig != nil { - if sys, ok := dialog.PromptConfig["system"].(string); ok { - systemPrompt = sys - } - } - - // Convert messages to history format - history := make([]map[string]string, 0) - for _, msg := range messages { - role, _ := msg["role"].(string) - content, _ := msg["content"].(string) - if role != "" && content != "" { - history = append(history, map[string]string{ - "role": role, - "content": content, - }) - } - } - - // Use ModelBundle to perform chat - bundle, err := NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID) - if err != nil { - return "", err - } - - // Merge dialog's LLM setting with request config - genConf := make(map[string]interface{}) - if dialog.LLMSetting != nil { - for k, v := range dialog.LLMSetting { - genConf[k] = v - } - } - for k, v := range config { - genConf[k] = v - } - - response, _, err := bundle.Chat(systemPrompt, history, genConf) - return response, err -} - -func (s *ChatSessionService) performChatStream(dialog *entity.Chat, messages []map[string]interface{}, config map[string]interface{}) (<-chan string, error) { - // Get system prompt from dialog - systemPrompt := "" - if dialog.PromptConfig != nil { - if sys, ok := dialog.PromptConfig["system"].(string); ok { - systemPrompt = sys - } - } - - // Convert messages to history format - history := make([]map[string]string, 0) - for _, msg := range messages { - role, _ := msg["role"].(string) - content, _ := msg["content"].(string) - if role != "" && content != "" { - history = append(history, map[string]string{ - "role": role, - "content": content, - }) - } - } - - // Use ModelBundle to perform streaming chat - bundle, err := NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID) - if err != nil { - return nil, err - } - - // Merge dialog's LLM setting with request config - genConf := make(map[string]interface{}) - if dialog.LLMSetting != nil { - for k, v := range dialog.LLMSetting { - genConf[k] = v - } - } - for k, v := range config { - genConf[k] = v - } - - // Get chat model and call ChatStreamly - chatModel, ok := bundle.GetModel().(entity.ChatModel) - if !ok { - return nil, fmt.Errorf("model is not a chat model") - } - - return chatModel.ChatStreamly(systemPrompt, history, genConf) -} - func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} { return map[string]interface{}{ "answer": answer, @@ -610,39 +524,52 @@ func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entit // asyncChatSolo performs simple chat without RAG (non-streaming) func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { + logger.Info("asyncChatSolo started", + zap.String("tenant_id", dialog.TenantID), + zap.String("llm_id", dialog.LLMID), + zap.String("dialog_id", dialog.ID), + zap.Int("message_count", len(messages))) + // Get system prompt systemPrompt := s.buildSystemPrompt(dialog) // Process messages - handle attachments and image files processedMessages := s.processMessages(messages, dialog) - // Get LLM type - llmType := s.getLLMType(dialog.LLMID) - - // Build generation config - genConf := s.buildGenConf(dialog, config) - - // Create ModelBundle for chat - var bundle *ModelBundle - var err error - if llmType == "image2text" { - bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeImage2Text, dialog.LLMID) - } else { - bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID) - } + chatModel, err := s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID) if err != nil { + logger.Error("asyncChatSolo failed to get chat model", err) return nil, err } - // Convert messages to history format - history := s.convertToHistory(processedMessages) + // Convert messages to Message format + var msgs []modelModule.Message + if systemPrompt != "" { + msgs = append(msgs, modelModule.Message{Role: "system", Content: systemPrompt}) + } + for _, msg := range processedMessages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + if role != "" && content != "" && role != "system" { + msgs = append(msgs, modelModule.Message{Role: role, Content: content}) + } + } + + // Get ChatConfig directly from dialog and config + chatConfig := s.buildChatConfig(dialog, config) // Perform chat - response, _, err := bundle.Chat(systemPrompt, history, genConf) + response, err := chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, chatModel.APIConfig.ApiKey, msgs, chatConfig) if err != nil { + logger.Error("asyncChatSolo chat failed", err) return nil, err } + logger.Info("asyncChatSolo completed", + zap.String("tenant_id", dialog.TenantID), + zap.String("llm_id", dialog.LLMID), + zap.Int("response_length", len(response))) + // Structure the answer ans := map[string]interface{}{ "answer": response, @@ -655,57 +582,67 @@ func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity. // asyncChatSoloStream performs simple streaming chat without RAG func (s *ChatSessionService) asyncChatSoloStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) { + logger.Info("asyncChatSoloStream started", + zap.String("tenant_id", dialog.TenantID), + zap.String("llm_id", dialog.LLMID), + zap.String("dialog_id", dialog.ID), + zap.Int("message_count", len(messages))) + // Get system prompt systemPrompt := s.buildSystemPrompt(dialog) // Process messages processedMessages := s.processMessages(messages, dialog) - // Get LLM type - llmType := s.getLLMType(dialog.LLMID) - - // Build generation config - genConf := s.buildGenConf(dialog, config) - - // Create ModelBundle - var bundle *ModelBundle - var err error - if llmType == "image2text" { - bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeImage2Text, dialog.LLMID) - } else { - bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID) - } + chatModel, err := s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID) if err != nil { + logger.Error("asyncChatSoloStream failed to get chat model", err) resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) return } - // Convert messages to history - history := s.convertToHistory(processedMessages) - - // Get chat model - chatModel, ok := bundle.GetModel().(entity.ChatModel) - if !ok { - resultChan <- s.structureAnswer(session, "**ERROR**: model is not a chat model", messageID, session.ID, reference) - return + // Convert messages to single string for ChatStreamlyWithSender + var msgBuilder strings.Builder + if systemPrompt != "" { + msgBuilder.WriteString("System: " + systemPrompt + "\n") } - - // Perform streaming chat - streamChan, err := chatModel.ChatStreamly(systemPrompt, history, genConf) - if err != nil { - resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) - return + for _, msg := range processedMessages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + if role != "" && content != "" && role != "system" { + msgBuilder.WriteString(role + ": " + content + "\n") + } } + messageStr := msgBuilder.String() - // Stream results + // Get ChatConfig directly from dialog and config + chatConfig := s.buildChatConfig(dialog, config) + + // Perform streaming chat using ChatStreamlyWithSender fullAnswer := "" - for chunk := range streamChan { - fullAnswer += chunk - // Clean up reasoning content - fullAnswer = s.removeReasoningContent(fullAnswer) - ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference) - resultChan <- ans + err = chatModel.ModelDriver.ChatStreamlyWithSender(chatModel.ModelName, &messageStr, chatModel.APIConfig, chatConfig, func(answer *string, reason *string) error { + if reason != nil && *reason != "" { + fullAnswer += *reason + ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference) + resultChan <- ans + } + if answer != nil && *answer != "" { + fullAnswer += *answer + fullAnswer = s.removeReasoningContent(fullAnswer) + ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference) + resultChan <- ans + } + return nil + }) + if err != nil { + resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) + return } + + logger.Info("asyncChatSoloStream completed", + zap.String("tenant_id", dialog.TenantID), + zap.String("llm_id", dialog.LLMID), + zap.Int("response_length", len(fullAnswer))) } // buildSystemPrompt builds the system prompt from dialog configuration @@ -745,50 +682,6 @@ func (s *ChatSessionService) cleanContent(content string) string { return content } -// convertToHistory converts messages to history format for LLM -func (s *ChatSessionService) convertToHistory(messages []map[string]interface{}) []map[string]string { - history := make([]map[string]string, 0) - for _, msg := range messages { - role, _ := msg["role"].(string) - content, _ := msg["content"].(string) - if role != "" && content != "" && role != "system" { - history = append(history, map[string]string{ - "role": role, - "content": content, - }) - } - } - return history -} - -// buildGenConf builds generation config from dialog and request -func (s *ChatSessionService) buildGenConf(dialog *entity.Chat, config map[string]interface{}) map[string]interface{} { - genConf := make(map[string]interface{}) - - // Start with dialog's LLM setting - if dialog.LLMSetting != nil { - for k, v := range dialog.LLMSetting { - genConf[k] = v - } - } - - // Override with request config - for k, v := range config { - genConf[k] = v - } - - return genConf -} - -// getLLMType gets the LLM type from model ID -func (s *ChatSessionService) getLLMType(llmID string) string { - // Simplified - would need to query TenantLLMService - if strings.Contains(llmID, "image") || strings.Contains(llmID, "vision") { - return "image2text" - } - return "chat" -} - // removeReasoningContent removes reasoning/thinking content from answer func (s *ChatSessionService) removeReasoningContent(answer string) string { // Remove tags @@ -891,3 +784,78 @@ func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []in } return formatted } + +// buildChatConfig builds ChatConfig directly from dialog.LLMSetting and config +func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[string]interface{}) *modelModule.ChatConfig { + cfg := &modelModule.ChatConfig{} + + // Start with dialog's LLM setting + if dialog.LLMSetting != nil { + if v, ok := dialog.LLMSetting["stream"].(bool); ok { + cfg.Stream = &v + } + if v, ok := dialog.LLMSetting["thinking"].(bool); ok { + cfg.Thinking = &v + } + if v, ok := dialog.LLMSetting["max_tokens"].(int); ok { + cfg.MaxTokens = &v + } + if v, ok := dialog.LLMSetting["temperature"].(float64); ok { + cfg.Temperature = &v + } + if v, ok := dialog.LLMSetting["top_p"].(float64); ok { + cfg.TopP = &v + } + if v, ok := dialog.LLMSetting["do_sample"].(bool); ok { + cfg.DoSample = &v + } + if v, ok := dialog.LLMSetting["stop"].([]string); ok { + cfg.Stop = &v + } + if v, ok := dialog.LLMSetting["model_class"].(string); ok { + cfg.ModelClass = &v + } + if v, ok := dialog.LLMSetting["effort"].(string); ok { + cfg.Effort = &v + } + if v, ok := dialog.LLMSetting["verbosity"].(string); ok { + cfg.Verbosity = &v + } + } + + // Override with request config + if config != nil { + if v, ok := config["stream"].(bool); ok { + cfg.Stream = &v + } + if v, ok := config["thinking"].(bool); ok { + cfg.Thinking = &v + } + if v, ok := config["max_tokens"].(int); ok { + cfg.MaxTokens = &v + } + if v, ok := config["temperature"].(float64); ok { + cfg.Temperature = &v + } + if v, ok := config["top_p"].(float64); ok { + cfg.TopP = &v + } + if v, ok := config["do_sample"].(bool); ok { + cfg.DoSample = &v + } + if v, ok := config["stop"].([]string); ok { + cfg.Stop = &v + } + if v, ok := config["model_class"].(string); ok { + cfg.ModelClass = &v + } + if v, ok := config["effort"].(string); ok { + cfg.Effort = &v + } + if v, ok := config["verbosity"].(string); ok { + cfg.Verbosity = &v + } + } + + return cfg +} diff --git a/internal/service/datasets.go b/internal/service/datasets.go index 4c6172043f..271f457a20 100644 --- a/internal/service/datasets.go +++ b/internal/service/datasets.go @@ -671,7 +671,7 @@ func normalizeDatasetUUID1(id string) (string, error) { } func (s *DatasetsService) verifyEmbeddingAvailability(embdID string, tenantID string) (bool, string) { - modelName, provider, err := parseModelName(embdID) + modelName, _, provider, err := parseModelName(embdID) if err != nil { return false, "Embedding model identifier must follow @ format" } diff --git a/internal/service/model_bundle.go b/internal/service/model_bundle.go deleted file mode 100644 index 528de89d02..0000000000 --- a/internal/service/model_bundle.go +++ /dev/null @@ -1,181 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package service - -import ( - "fmt" - "ragflow/internal/entity" - modelModule "ragflow/internal/entity/models" -) - -// ModelBundle provides a unified interface for various model operations -// Similar to Python's LLMBundle but with a more generic name -type ModelBundle struct { - tenantID string - modelType entity.ModelType - modelName string - model interface{} // underlying model instance - apiConfig *modelModule.APIConfig - embeddingConfig *modelModule.EmbeddingConfig -} - -// NewModelBundle creates a new ModelBundle for the given tenant and model type -// If modelName is empty, uses the default model for the tenant and type -func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...string) (*ModelBundle, error) { - bundle := &ModelBundle{ - tenantID: tenantID, - modelType: modelType, - embeddingConfig: &modelModule.EmbeddingConfig{}, - } - - // Use provided model name if available - if len(modelName) > 0 && modelName[0] != "" { - bundle.modelName = modelName[0] - } - - // Get model instance based on type - modelProviderSvc := NewModelProviderService() - switch modelType { - case entity.ModelTypeEmbedding: - embd, err := modelProviderSvc.GetEmbeddingModel(tenantID, bundle.modelName) - if err != nil { - return nil, fmt.Errorf("failed to get embedding model: %w", err) - } - bundle.model = embd.ModelDriver - bundle.apiConfig = embd.APIConfig - case entity.ModelTypeChat: - chatMdl, err := modelProviderSvc.GetChatModel(tenantID, bundle.modelName) - if err != nil { - return nil, fmt.Errorf("failed to get chat model: %w", err) - } - bundle.model = chatMdl.ModelDriver - bundle.apiConfig = chatMdl.APIConfig - case entity.ModelTypeRerank: - rerankMdl, err := modelProviderSvc.GetRerankModel(tenantID, bundle.modelName) - if err != nil { - return nil, fmt.Errorf("failed to get rerank model: %w", err) - } - bundle.model = rerankMdl.ModelDriver - bundle.apiConfig = rerankMdl.APIConfig - default: - return nil, fmt.Errorf("unsupported model type: %s", modelType) - } - - return bundle, nil -} - -// Encode encodes a list of texts into embeddings -// Returns embeddings and token count (for compatibility with Python interface) -func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) { - if b.modelType != entity.ModelTypeEmbedding { - return nil, 0, fmt.Errorf("model type %s does not support encode", b.modelType) - } - - embeddingModel, ok := b.model.(entity.EmbeddingModel) - if !ok { - return nil, 0, fmt.Errorf("model is not an embedding model") - } - - embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig, b.embeddingConfig) - if err != nil { - return nil, 0, err - } - - // TODO: Calculate actual token count - // For now, return a dummy token count - tokenCount := int64(0) - for _, text := range texts { - tokenCount += int64(len(text) / 4) // rough approximation - } - - return embeddings, tokenCount, nil -} - -// EncodeQuery encodes a single query string into embedding -// Returns embedding and token count -func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) { - if b.modelType != entity.ModelTypeEmbedding { - return nil, 0, fmt.Errorf("model type %s does not support encode query", b.modelType) - } - - embeddingModel, ok := b.model.(entity.EmbeddingModel) - if !ok { - return nil, 0, fmt.Errorf("model is not an embedding model") - } - - embeddings, err := embeddingModel.Encode(&b.modelName, []string{query}, b.apiConfig, b.embeddingConfig) - if err != nil { - return nil, 0, err - } - if len(embeddings) == 0 { - return nil, 0, fmt.Errorf("no embedding returned") - } - - // TODO: Calculate actual token count - tokenCount := int64(len(query) / 4) - - return embeddings[0], tokenCount, nil -} - -// Chat sends a chat message and returns response -func (b *ModelBundle) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, int64, error) { - if b.modelType != entity.ModelTypeChat { - return "", 0, fmt.Errorf("model type %s does not support chat", b.modelType) - } - - chatModel, ok := b.model.(entity.ChatModel) - if !ok { - return "", 0, fmt.Errorf("model is not a chat model") - } - - response, err := chatModel.Chat(system, history, genConf) - if err != nil { - return "", 0, err - } - - // TODO: Calculate actual token count - tokenCount := int64(len(response) / 4) - - return response, tokenCount, nil -} - -// Rerank calculates similarity between query and texts -func (b *ModelBundle) Rerank(query string, texts []string) ([]float64, int64, error) { - if b.modelType != entity.ModelTypeRerank { - return nil, 0, fmt.Errorf("model type %s does not support rerank", b.modelType) - } - - rerankModel, ok := b.model.(entity.RerankModel) - if !ok { - return nil, 0, fmt.Errorf("model is not a rerank model") - } - - similarities, err := rerankModel.Rerank(query, texts, b.apiConfig) - if err != nil { - return nil, 0, err - } - - // TODO: Calculate actual token count - tokenCount := int64(len(query)/4) + int64(len(texts)*10) - - return similarities, tokenCount, nil -} - -// GetModel returns the underlying model instance -func (b *ModelBundle) GetModel() interface{} { - return b.model -} diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 3387cbb9f5..85edf695bd 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -28,16 +28,20 @@ import ( "time" ) -// parseModelName parses a composite model name in format "model_name@provider" -// Returns modelName and provider separately -func parseModelName(compositeName string) (modelName, provider string, err error) { +// parseModelName parses a composite model name in format "model@instance@provider" or "model@provider" +// Returns modelName, instanceName, providerName separately +func parseModelName(compositeName string) (modelName, instanceName, providerName string, err error) { parts := strings.Split(compositeName, "@") - if len(parts) == 2 { - return parts[0], parts[1], nil + if len(parts) == 3 { + // Format: model@instance@provider + return parts[0], parts[1], parts[2], nil + } else if len(parts) == 2 { + // Format: model@provider (legacy) + return parts[0], "", parts[1], nil } else if len(parts) == 1 { - return parts[0], "", fmt.Errorf("provider name missing in model name: %s", compositeName) + return parts[0], "", "", fmt.Errorf("provider name missing in model name: %s", compositeName) } else { - return "", "", fmt.Errorf("invalid model name format: %s", compositeName) + return "", "", "", fmt.Errorf("invalid model name format: %s", compositeName) } } @@ -848,7 +852,7 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string) // getModelConfig returns the model driver, model name, and API config for a model func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, error) { - modelName, providerName, err := parseModelName(compositeModelName) + modelName, instanceName, providerName, err := parseModelName(compositeModelName) if err != nil { return nil, "", nil, err } @@ -862,7 +866,10 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin return nil, "", nil, fmt.Errorf("provider %s not found", providerName) } - instanceName := "default_instance" + if instanceName == "" { + instanceName = "default_instance" + } + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) if err != nil { return nil, "", nil, err