diff --git a/internal/handler/searchbot.go b/internal/handler/searchbot.go index e549c8cf8b..9a056f045a 100644 --- a/internal/handler/searchbot.go +++ b/internal/handler/searchbot.go @@ -132,6 +132,24 @@ type SearchBotRetrievalTestRequest struct { // Highlight *bool `json:"highlight,omitempty"` } +// UnmarshalJSON accepts both kb_id (Python API) and kb_ids (Go compatibility). +func (r *SearchBotRetrievalTestRequest) UnmarshalJSON(data []byte) error { + type Alias SearchBotRetrievalTestRequest + aux := struct { + *Alias + KbID common.StringSlice `json:"kb_id"` + }{ + Alias: (*Alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + if len(r.KbIDs) == 0 && len(aux.KbID) > 0 { + r.KbIDs = aux.KbID + } + return nil +} + // SearchBotRequest is the request body for POST /api/v1/searchbots/related_questions. type SearchBotRequest struct { Question string `json:"question" binding:"required"` diff --git a/internal/service/chunk/chunk.go b/internal/service/chunk/chunk.go index 36b05b5657..ecda520cd9 100644 --- a/internal/service/chunk/chunk.go +++ b/internal/service/chunk/chunk.go @@ -346,22 +346,28 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s common.Debug("LabelQuestion result", zap.Any("labels", labels)) // Determine embedding model + modelProviderSvc := service.NewModelProviderService() var embdID string var tenantLLM *entity.TenantLLM + var embeddingModel *models.EmbeddingModel if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 { tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID) if err != nil { return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err) } } else if kbRecords[0].EmbdID != "" { - parts := strings.Split(kbRecords[0].EmbdID, "@") - if len(parts) == 2 && parts[1] != "" { - tenantLLM, embdID, err = dao.LookupTenantLLMByFactory(dao.NewTenantLLMDAO(), tenantIDs[0], parts[1], parts[0], entity.ModelTypeEmbedding) + if strings.Contains(kbRecords[0].EmbdID, "@") { + driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID) + if embErr != nil { + return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr) + } + embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) + embdID = kbRecords[0].EmbdID } else { tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding) - } - if err != nil { - return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err) + if err != nil { + return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err) + } } } else { tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding) @@ -375,10 +381,11 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s } // Get embedding model for the tenant - modelProviderSvc := service.NewModelProviderService() - embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID) - if err != nil { - return nil, fmt.Errorf("failed to get embedding model: %w", err) + if embeddingModel == nil { + embeddingModel, err = modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID) + if err != nil { + return nil, fmt.Errorf("failed to get embedding model: %w", err) + } } common.Info("Fetched embedding model for retrieval", zap.String("tenantID", tenantIDs[0]), @@ -397,10 +404,7 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err) } } else if req.RerankID != nil && *req.RerankID != "" { - _, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank) - if err != nil { - return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err) - } + rerankCompositeName = *req.RerankID } if rerankCompositeName != "" { driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName)