diff --git a/internal/entity/models/aliyun.go b/internal/entity/models/aliyun.go index 4975ed295e..48ef6b7066 100644 --- a/internal/entity/models/aliyun.go +++ b/internal/entity/models/aliyun.go @@ -332,21 +332,11 @@ func (z *AliyunModel) ChatStreamlyWithSender(modelName, message *string, apiConf return scanner.Err() } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *AliyunModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *AliyunModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - // Rerank calculates similarity scores between query and texts func (z *AliyunModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) diff --git a/internal/entity/models/deepseek.go b/internal/entity/models/deepseek.go index eee8b800d3..ee47918a54 100644 --- a/internal/entity/models/deepseek.go +++ b/internal/entity/models/deepseek.go @@ -396,21 +396,11 @@ func (z *DeepSeekModel) ChatStreamlyWithSender(modelName, message *string, apiCo return scanner.Err() } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *DeepSeekModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - type DSModel struct { ID string `json:"id"` Object string `json:"object"` diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index e93de49fe4..59a84b49fe 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -53,21 +53,11 @@ func (z *DummyModel) ChatStreamlyWithSender(modelName, message *string, apiConfi return fmt.Errorf("not implemented") } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("not implemented") } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *DummyModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("not implemented") } diff --git a/internal/entity/models/gitee.go b/internal/entity/models/gitee.go index d1ceee5f5a..b28bedea13 100644 --- a/internal/entity/models/gitee.go +++ b/internal/entity/models/gitee.go @@ -362,21 +362,11 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName, message *string, apiConfi return scanner.Err() } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *GiteeModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *GiteeModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - // Rerank calculates similarity scores between query and texts func (z *GiteeModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) diff --git a/internal/entity/models/google.go b/internal/entity/models/google.go index c0c3b20f7d..cbc42b2812 100644 --- a/internal/entity/models/google.go +++ b/internal/entity/models/google.go @@ -136,8 +136,8 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName, message *string, apiConf return err } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *GoogleModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("not implemented") } @@ -172,23 +172,6 @@ func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error { return fmt.Errorf("no such method") } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return z.EncodeToEmbedding(modelName, texts, apiConfig, nil) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *GoogleModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - embeddings, err := z.Encode(modelName, []string{query}, apiConfig) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - // Rerank calculates similarity scores between query and texts func (z *GoogleModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) diff --git a/internal/entity/models/minimax.go b/internal/entity/models/minimax.go index 2e512d3392..c1001d50c8 100644 --- a/internal/entity/models/minimax.go +++ b/internal/entity/models/minimax.go @@ -66,21 +66,11 @@ func (z *MinimaxModel) ChatStreamlyWithSender(modelName, message *string, apiCon return fmt.Errorf("%s, no such method", z.Name()) } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *MinimaxModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("not implemented") } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *MinimaxModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - func (z *MinimaxModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } diff --git a/internal/entity/models/moonshot.go b/internal/entity/models/moonshot.go index 448a822686..b436d672f1 100644 --- a/internal/entity/models/moonshot.go +++ b/internal/entity/models/moonshot.go @@ -332,21 +332,11 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName, message *string, apiCo return scanner.Err() } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *MoonshotModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("not implemented") } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *MoonshotModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil { diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index 6b6d63d07b..2c191b3349 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -381,8 +381,8 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName, message *string, ap return scanner.Err() } -// EncodeToEmbedding encodes a list of texts into embeddings -func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { if len(texts) == 0 { return [][]float64{}, nil } @@ -477,23 +477,6 @@ func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, return embeddings, nil } -// Encode encodes a list of texts into embeddings (convenience method) -func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return s.EncodeToEmbedding(modelName, texts, apiConfig, nil) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (s *SiliconflowModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - embeddings, err := s.Encode(modelName, []string{query}, apiConfig) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil { diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index cb9cbec3e7..fd4e031b0a 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -1,7 +1,5 @@ package models -import "fmt" - // Message represents a chat message with role type Message struct { Role string @@ -18,12 +16,8 @@ type ModelDriver interface { ChatWithMessages(modelName string, apiKey *string, messages []Message, modelConfig *ChatConfig) (string, error) // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error - // EncodeToEmbedding encodes a list of texts into embeddings - EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) - // Encode encodes a list of texts into embeddings (convenience method) - Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) - // EncodeQuery encodes a single query string into embedding (convenience method) - EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) + // Encode encodes a list of texts into embeddings + Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) // Rerank calculates similarity scores between query and texts Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) // List suppported models @@ -89,23 +83,6 @@ func NewEmbeddingModel(driver ModelDriver, modelName *string, apiConfig *APIConf } } -// Encode encodes a list of texts into embeddings -func (e *EmbeddingModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return e.ModelDriver.EncodeToEmbedding(modelName, texts, apiConfig, nil) -} - -// EncodeQuery encodes a single query string into embedding -func (e *EmbeddingModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - embeddings, err := e.ModelDriver.Encode(modelName, []string{query}, apiConfig) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - // RerankModel wraps a ModelDriver with rerank-specific configuration type RerankModel struct { ModelDriver ModelDriver diff --git a/internal/entity/models/volcengine.go b/internal/entity/models/volcengine.go index 044b21c0ef..f203412caf 100644 --- a/internal/entity/models/volcengine.go +++ b/internal/entity/models/volcengine.go @@ -66,21 +66,11 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName, message *string, apiConfi return fmt.Errorf("%s, no such method", z.Name()) } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *VolcEngine) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("not implemented") } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *VolcEngine) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) -} - // Rerank calculates similarity scores between query and texts func (z *VolcEngine) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index c041f39152..cc30578102 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -433,8 +433,8 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiCon return scanner.Err() } -// EncodeToEmbedding encodes a list of texts into embeddings -func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Encode encodes a list of texts into embeddings +func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { var region = "default" if apiConfig.Region != nil { region = *apiConfig.Region @@ -518,23 +518,6 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC return embeddings, nil } -// Encode encodes a list of texts into embeddings (convenience method) -func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { - return z.EncodeToEmbedding(modelName, texts, apiConfig, nil) -} - -// EncodeQuery encodes a single query string into embedding (convenience method) -func (z *ZhipuAIModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { - embeddings, err := z.Encode(modelName, []string{query}, apiConfig) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } diff --git a/internal/entity/types.go b/internal/entity/types.go index 8f78dd33f6..41154dcf41 100644 --- a/internal/entity/types.go +++ b/internal/entity/types.go @@ -43,9 +43,7 @@ const ( // EmbeddingModel interface for embedding models type EmbeddingModel interface { // Encode encodes a list of texts into embeddings - Encode(modelName *string, texts []string, apiConfig *models.APIConfig) ([][]float64, error) - // EncodeQuery encodes a single query string into embedding - EncodeQuery(modelName *string, query string, apiConfig *models.APIConfig) ([]float64, error) + Encode(modelName *string, texts []string, apiConfig *models.APIConfig, embeddingConfig *models.EmbeddingConfig) ([][]float64, error) } // ChatModel interface for chat models diff --git a/internal/service/model_bundle.go b/internal/service/model_bundle.go index 0f3fc6a65a..528de89d02 100644 --- a/internal/service/model_bundle.go +++ b/internal/service/model_bundle.go @@ -90,7 +90,7 @@ func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) { return nil, 0, fmt.Errorf("model is not an embedding model") } - embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig) + embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig, b.embeddingConfig) if err != nil { return nil, 0, err } @@ -117,15 +117,18 @@ func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) { return nil, 0, fmt.Errorf("model is not an embedding model") } - embedding, err := embeddingModel.EncodeQuery(&b.modelName, query, b.apiConfig) + embeddings, err := embeddingModel.Encode(&b.modelName, []string{query}, b.apiConfig, b.embeddingConfig) if err != nil { return nil, 0, err } + if len(embeddings) == 0 { + return nil, 0, fmt.Errorf("no embedding returned") + } // TODO: Calculate actual token count tokenCount := int64(len(query) / 4) - return embedding, tokenCount, nil + return embeddings[0], tokenCount, nil } // Chat sends a chat message and returns response diff --git a/internal/service/nlp/retrieval.go b/internal/service/nlp/retrieval.go index a03339a385..c271d32f40 100644 --- a/internal/service/nlp/retrieval.go +++ b/internal/service/nlp/retrieval.go @@ -597,11 +597,12 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque // GetVector computes query vector and returns MatchDenseExpr for hybrid search func (s *RetrievalService) GetVector(txt string, embModel *models.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) { - vector, err := embModel.ModelDriver.EncodeQuery(embModel.ModelName, txt, embModel.APIConfig) + embeddings, err := embModel.ModelDriver.Encode(embModel.ModelName, []string{txt}, embModel.APIConfig, nil) if err != nil { return nil, err } + vector := embeddings[0] vectorSize := len(vector) vectorColumnName := fmt.Sprintf("q_%d_vec", vectorSize)