mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
[Go] Fix searchbot retrieval_test accept kb_id as array, fix model recognize (#16452)
This commit is contained in:
@@ -132,6 +132,24 @@ type SearchBotRetrievalTestRequest struct {
|
||||
// Highlight *bool `json:"highlight,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON accepts both kb_id (Python API) and kb_ids (Go compatibility).
|
||||
func (r *SearchBotRetrievalTestRequest) UnmarshalJSON(data []byte) error {
|
||||
type Alias SearchBotRetrievalTestRequest
|
||||
aux := struct {
|
||||
*Alias
|
||||
KbID common.StringSlice `json:"kb_id"`
|
||||
}{
|
||||
Alias: (*Alias)(r),
|
||||
}
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(r.KbIDs) == 0 && len(aux.KbID) > 0 {
|
||||
r.KbIDs = aux.KbID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchBotRequest is the request body for POST /api/v1/searchbots/related_questions.
|
||||
type SearchBotRequest struct {
|
||||
Question string `json:"question" binding:"required"`
|
||||
|
||||
@@ -346,22 +346,28 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
common.Debug("LabelQuestion result", zap.Any("labels", labels))
|
||||
|
||||
// Determine embedding model
|
||||
modelProviderSvc := service.NewModelProviderService()
|
||||
var embdID string
|
||||
var tenantLLM *entity.TenantLLM
|
||||
var embeddingModel *models.EmbeddingModel
|
||||
if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 {
|
||||
tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
|
||||
}
|
||||
} else if kbRecords[0].EmbdID != "" {
|
||||
parts := strings.Split(kbRecords[0].EmbdID, "@")
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
tenantLLM, embdID, err = dao.LookupTenantLLMByFactory(dao.NewTenantLLMDAO(), tenantIDs[0], parts[1], parts[0], entity.ModelTypeEmbedding)
|
||||
if strings.Contains(kbRecords[0].EmbdID, "@") {
|
||||
driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID)
|
||||
if embErr != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr)
|
||||
}
|
||||
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
embdID = kbRecords[0].EmbdID
|
||||
} else {
|
||||
tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding)
|
||||
@@ -375,10 +381,11 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
}
|
||||
|
||||
// Get embedding model for the tenant
|
||||
modelProviderSvc := service.NewModelProviderService()
|
||||
embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
if embeddingModel == nil {
|
||||
embeddingModel, err = modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
}
|
||||
common.Info("Fetched embedding model for retrieval",
|
||||
zap.String("tenantID", tenantIDs[0]),
|
||||
@@ -397,10 +404,7 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
|
||||
}
|
||||
} else if req.RerankID != nil && *req.RerankID != "" {
|
||||
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
|
||||
}
|
||||
rerankCompositeName = *req.RerankID
|
||||
}
|
||||
if rerankCompositeName != "" {
|
||||
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName)
|
||||
|
||||
Reference in New Issue
Block a user