[Go] Fix searchbot retrieval_test accept kb_id as array, fix model recognize (#16452)

This commit is contained in:
Wang Qi
2026-06-29 17:17:20 +08:00
committed by GitHub
parent 3202ec6abf
commit c0f64295c2
2 changed files with 36 additions and 14 deletions

View File

@@ -132,6 +132,24 @@ type SearchBotRetrievalTestRequest struct {
// Highlight *bool `json:"highlight,omitempty"`
}
// UnmarshalJSON accepts both kb_id (Python API) and kb_ids (Go compatibility).
func (r *SearchBotRetrievalTestRequest) UnmarshalJSON(data []byte) error {
type Alias SearchBotRetrievalTestRequest
aux := struct {
*Alias
KbID common.StringSlice `json:"kb_id"`
}{
Alias: (*Alias)(r),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if len(r.KbIDs) == 0 && len(aux.KbID) > 0 {
r.KbIDs = aux.KbID
}
return nil
}
// SearchBotRequest is the request body for POST /api/v1/searchbots/related_questions.
type SearchBotRequest struct {
Question string `json:"question" binding:"required"`

View File

@@ -346,22 +346,28 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
common.Debug("LabelQuestion result", zap.Any("labels", labels))
// Determine embedding model
modelProviderSvc := service.NewModelProviderService()
var embdID string
var tenantLLM *entity.TenantLLM
var embeddingModel *models.EmbeddingModel
if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 {
tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
}
} else if kbRecords[0].EmbdID != "" {
parts := strings.Split(kbRecords[0].EmbdID, "@")
if len(parts) == 2 && parts[1] != "" {
tenantLLM, embdID, err = dao.LookupTenantLLMByFactory(dao.NewTenantLLMDAO(), tenantIDs[0], parts[1], parts[0], entity.ModelTypeEmbedding)
if strings.Contains(kbRecords[0].EmbdID, "@") {
driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID)
if embErr != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr)
}
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
embdID = kbRecords[0].EmbdID
} else {
tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
}
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
}
}
} else {
tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding)
@@ -375,10 +381,11 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
}
// Get embedding model for the tenant
modelProviderSvc := service.NewModelProviderService()
embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model: %w", err)
if embeddingModel == nil {
embeddingModel, err = modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model: %w", err)
}
}
common.Info("Fetched embedding model for retrieval",
zap.String("tenantID", tenantIDs[0]),
@@ -397,10 +404,7 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
}
} else if req.RerankID != nil && *req.RerankID != "" {
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank)
if err != nil {
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
}
rerankCompositeName = *req.RerankID
}
if rerankCompositeName != "" {
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName)