Simplify Encode (#14437)

### What problem does this PR solve?

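Simplify the embedding API: `EncodeToEmbedding`, the three-argument `Encode`, and `EncodeQuery` are merged into a single `Encode(modelName, texts, apiConfig, embeddingConfig)` method on every model driver. Callers that need the vector for one query now call `Encode` with a one-element slice and take the first result.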

### Type of change

- [x] Refactoring
This commit is contained in:
qinling0210
2026-04-28 18:07:42 +08:00
committed by GitHub
parent d532151be0
commit dcce864d4c
14 changed files with 31 additions and 173 deletions

View File

@@ -332,21 +332,11 @@ func (z *AliyunModel) ChatStreamlyWithSender(modelName, message *string, apiConf
return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *AliyunModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: AliyunModel does not produce embeddings here.
// Every argument is ignored. The returned slice is always nil and the error
// has the form "<name>, no such method", where <name> is z.Name().
func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	err := fmt.Errorf("%s, no such method", z.Name())
	return nil, err
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *AliyunModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
// Rerank calculates similarity scores between query and texts
func (z *AliyunModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())

View File

@@ -396,21 +396,11 @@ func (z *DeepSeekModel) ChatStreamlyWithSender(modelName, message *string, apiCo
return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: DeepSeekModel does not produce embeddings here.
// None of the arguments are read. It always returns a nil slice and an
// error of the form "<name>, no such method", where <name> is z.Name().
func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	driver := z.Name()
	return nil, fmt.Errorf("%s, no such method", driver)
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *DeepSeekModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
type DSModel struct {
ID string `json:"id"`
Object string `json:"object"`

View File

@@ -53,21 +53,11 @@ func (z *DummyModel) ChatStreamlyWithSender(modelName, message *string, apiConfi
return fmt.Errorf("not implemented")
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: DummyModel has no embedding implementation.
// The arguments are ignored and the call always fails with the error
// "not implemented"; the embeddings result is nil.
func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	var embeddings [][]float64 // stays nil: nothing is ever computed
	return embeddings, fmt.Errorf("not implemented")
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *DummyModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
// ListModels is a placeholder: DummyModel cannot enumerate models.
// apiConfig is not consulted; the result is always a nil list together
// with the error "not implemented".
func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) {
	var names []string // stays nil: there is nothing to list
	return names, fmt.Errorf("not implemented")
}

View File

@@ -362,21 +362,11 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName, message *string, apiConfi
return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *GiteeModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: GiteeModel does not produce embeddings here.
// All arguments are ignored. The call always yields a nil slice and an
// error of the form "<name>, no such method", where <name> is z.Name().
func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	unsupported := fmt.Errorf("%s, no such method", z.Name())
	return nil, unsupported
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *GiteeModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
// Rerank calculates similarity scores between query and texts
func (z *GiteeModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())

View File

@@ -136,8 +136,8 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName, message *string, apiConf
return err
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *GoogleModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: GoogleModel has no embedding implementation.
// The arguments are ignored; the call always returns nil embeddings and
// the error "not implemented".
func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	err := fmt.Errorf("not implemented")
	return nil, err
}
@@ -172,23 +172,6 @@ func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error {
return fmt.Errorf("no such method")
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return z.EncodeToEmbedding(modelName, texts, apiConfig, nil)
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *GoogleModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
embeddings, err := z.Encode(modelName, []string{query}, apiConfig)
if err != nil {
return nil, err
}
if len(embeddings) == 0 {
return nil, fmt.Errorf("no embedding returned")
}
return embeddings[0], nil
}
// Rerank calculates similarity scores between query and texts
func (z *GoogleModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())

View File

@@ -66,21 +66,11 @@ func (z *MinimaxModel) ChatStreamlyWithSender(modelName, message *string, apiCon
return fmt.Errorf("%s, no such method", z.Name())
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *MinimaxModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: MinimaxModel does not produce embeddings here.
// All arguments are ignored. The call always returns a nil slice and an
// error of the form "<name>, no such method", where <name> is z.Name().
//
// The error carries the driver name, matching the other unsupported
// MinimaxModel methods (ChatStreamlyWithSender, ListModels), so a failure
// surfaced to a caller identifies which driver rejected the request.
func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	return nil, fmt.Errorf("%s, no such method", z.Name())
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *MinimaxModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
// ListModels is a placeholder: MinimaxModel cannot enumerate models.
// apiConfig is not consulted; the result is always a nil list and an
// error of the form "<name>, no such method", where <name> is z.Name().
func (z *MinimaxModel) ListModels(apiConfig *APIConfig) ([]string, error) {
	err := fmt.Errorf("%s, no such method", z.Name())
	return nil, err
}

View File

@@ -332,21 +332,11 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName, message *string, apiCo
return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *MoonshotModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: MoonshotModel has no embedding implementation.
// None of the arguments are read; the call always returns nil embeddings
// and the error "not implemented".
func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	var embeddings [][]float64 // stays nil: nothing is ever computed
	return embeddings, fmt.Errorf("not implemented")
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *MoonshotModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {

View File

@@ -381,8 +381,8 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName, message *string, ap
return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode encodes a list of texts into embeddings
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
if len(texts) == 0 {
return [][]float64{}, nil
}
@@ -477,23 +477,6 @@ func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string,
return embeddings, nil
}
// Encode encodes a list of texts into embeddings (convenience method)
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return s.EncodeToEmbedding(modelName, texts, apiConfig, nil)
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (s *SiliconflowModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
embeddings, err := s.Encode(modelName, []string{query}, apiConfig)
if err != nil {
return nil, err
}
if len(embeddings) == 0 {
return nil, fmt.Errorf("no embedding returned")
}
return embeddings[0], nil
}
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {

View File

@@ -1,7 +1,5 @@
package models
import "fmt"
// Message represents a chat message with role
type Message struct {
Role string
@@ -18,12 +16,8 @@ type ModelDriver interface {
ChatWithMessages(modelName string, apiKey *string, messages []Message, modelConfig *ChatConfig) (string, error)
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error
// EncodeToEmbedding encodes a list of texts into embeddings
EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
// Encode encodes a list of texts into embeddings (convenience method)
Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error)
// EncodeQuery encodes a single query string into embedding (convenience method)
EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error)
// Encode encodes a list of texts into embeddings
Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
// Rerank calculates similarity scores between query and texts
Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error)
// List supported models
@@ -89,23 +83,6 @@ func NewEmbeddingModel(driver ModelDriver, modelName *string, apiConfig *APIConf
}
}
// Encode encodes a list of texts into embeddings
func (e *EmbeddingModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return e.ModelDriver.EncodeToEmbedding(modelName, texts, apiConfig, nil)
}
// EncodeQuery encodes a single query string into embedding
func (e *EmbeddingModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
embeddings, err := e.ModelDriver.Encode(modelName, []string{query}, apiConfig)
if err != nil {
return nil, err
}
if len(embeddings) == 0 {
return nil, fmt.Errorf("no embedding returned")
}
return embeddings[0], nil
}
// RerankModel wraps a ModelDriver with rerank-specific configuration
type RerankModel struct {
ModelDriver ModelDriver

View File

@@ -66,21 +66,11 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName, message *string, apiConfi
return fmt.Errorf("%s, no such method", z.Name())
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *VolcEngine) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode is a placeholder: VolcEngine does not produce embeddings here.
// All arguments are ignored. The call always returns a nil slice and an
// error of the form "<name>, no such method", where <name> is z.Name().
//
// The error carries the driver name, matching VolcEngine's other
// unsupported methods (for example ChatStreamlyWithSender), so a failure
// surfaced to a caller identifies which driver rejected the request.
func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	return nil, fmt.Errorf("%s, no such method", z.Name())
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *VolcEngine) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
}
// Rerank calculates similarity scores between query and texts
func (z *VolcEngine) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())

View File

@@ -433,8 +433,8 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiCon
return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
// Encode encodes a list of texts into embeddings
func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
@@ -518,23 +518,6 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC
return embeddings, nil
}
// Encode encodes a list of texts into embeddings (convenience method)
func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
return z.EncodeToEmbedding(modelName, texts, apiConfig, nil)
}
// EncodeQuery encodes a single query string into embedding (convenience method)
func (z *ZhipuAIModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
embeddings, err := z.Encode(modelName, []string{query}, apiConfig)
if err != nil {
return nil, err
}
if len(embeddings) == 0 {
return nil, fmt.Errorf("no embedding returned")
}
return embeddings[0], nil
}
// ListModels is a placeholder: ZhipuAIModel cannot enumerate models.
// apiConfig is not consulted; the result is always a nil list and an
// error of the form "<name>, no such method", where <name> is z.Name().
func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
	driver := z.Name()
	return nil, fmt.Errorf("%s, no such method", driver)
}

View File

@@ -43,9 +43,7 @@ const (
// EmbeddingModel interface for embedding models
type EmbeddingModel interface {
// Encode encodes a list of texts into embeddings
Encode(modelName *string, texts []string, apiConfig *models.APIConfig) ([][]float64, error)
// EncodeQuery encodes a single query string into embedding
EncodeQuery(modelName *string, query string, apiConfig *models.APIConfig) ([]float64, error)
Encode(modelName *string, texts []string, apiConfig *models.APIConfig, embeddingConfig *models.EmbeddingConfig) ([][]float64, error)
}
// ChatModel interface for chat models

View File

@@ -90,7 +90,7 @@ func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
return nil, 0, fmt.Errorf("model is not an embedding model")
}
embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig)
embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig, b.embeddingConfig)
if err != nil {
return nil, 0, err
}
@@ -117,15 +117,18 @@ func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
return nil, 0, fmt.Errorf("model is not an embedding model")
}
embedding, err := embeddingModel.EncodeQuery(&b.modelName, query, b.apiConfig)
embeddings, err := embeddingModel.Encode(&b.modelName, []string{query}, b.apiConfig, b.embeddingConfig)
if err != nil {
return nil, 0, err
}
if len(embeddings) == 0 {
return nil, 0, fmt.Errorf("no embedding returned")
}
// TODO: Calculate actual token count
tokenCount := int64(len(query) / 4)
return embedding, tokenCount, nil
return embeddings[0], tokenCount, nil
}
// Chat sends a chat message and returns response

View File

@@ -597,11 +597,12 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque
// GetVector computes query vector and returns MatchDenseExpr for hybrid search
func (s *RetrievalService) GetVector(txt string, embModel *models.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) {
vector, err := embModel.ModelDriver.EncodeQuery(embModel.ModelName, txt, embModel.APIConfig)
embeddings, err := embModel.ModelDriver.Encode(embModel.ModelName, []string{txt}, embModel.APIConfig, nil)
if err != nil {
return nil, err
}
vector := embeddings[0]
vectorSize := len(vector)
vectorColumnName := fmt.Sprintf("q_%d_vec", vectorSize)