mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
Simplify Encode (#14437)
### What problem does this PR solve? Simplify Encode ### Type of change - [x] Refactoring
This commit is contained in:
@@ -332,21 +332,11 @@ func (z *AliyunModel) ChatStreamlyWithSender(modelName, message *string, apiConf
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *AliyunModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *AliyunModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *AliyunModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
|
||||
@@ -396,21 +396,11 @@ func (z *DeepSeekModel) ChatStreamlyWithSender(modelName, message *string, apiCo
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *DeepSeekModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
type DSModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
|
||||
@@ -53,21 +53,11 @@ func (z *DummyModel) ChatStreamlyWithSender(modelName, message *string, apiConfi
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *DummyModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -362,21 +362,11 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName, message *string, apiConfi
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *GiteeModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *GiteeModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *GiteeModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
|
||||
@@ -136,8 +136,8 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName, message *string, apiConf
|
||||
return err
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *GoogleModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -172,23 +172,6 @@ func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return z.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *GoogleModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := z.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *GoogleModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
|
||||
@@ -66,21 +66,11 @@ func (z *MinimaxModel) ChatStreamlyWithSender(modelName, message *string, apiCon
|
||||
return fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *MinimaxModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *MinimaxModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *MinimaxModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
@@ -332,21 +332,11 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName, message *string, apiCo
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *MoonshotModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *MoonshotModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
|
||||
@@ -381,8 +381,8 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName, message *string, ap
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
@@ -477,23 +477,6 @@ func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string,
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return s.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (s *SiliconflowModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := s.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package models
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Message represents a chat message with role
|
||||
type Message struct {
|
||||
Role string
|
||||
@@ -18,12 +16,8 @@ type ModelDriver interface {
|
||||
ChatWithMessages(modelName string, apiKey *string, messages []Message, modelConfig *ChatConfig) (string, error)
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error)
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error)
|
||||
// Encode encodes a list of texts into embeddings
|
||||
Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error)
|
||||
// List suppported models
|
||||
@@ -89,23 +83,6 @@ func NewEmbeddingModel(driver ModelDriver, modelName *string, apiConfig *APIConf
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (e *EmbeddingModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return e.ModelDriver.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
func (e *EmbeddingModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := e.ModelDriver.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// RerankModel wraps a ModelDriver with rerank-specific configuration
|
||||
type RerankModel struct {
|
||||
ModelDriver ModelDriver
|
||||
|
||||
@@ -66,21 +66,11 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName, message *string, apiConfi
|
||||
return fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *VolcEngine) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *VolcEngine) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *VolcEngine) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
|
||||
@@ -433,8 +433,8 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiCon
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
@@ -518,23 +518,6 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return z.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *ZhipuAIModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := z.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
@@ -43,9 +43,7 @@ const (
|
||||
// EmbeddingModel interface for embedding models
|
||||
type EmbeddingModel interface {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
Encode(modelName *string, texts []string, apiConfig *models.APIConfig) ([][]float64, error)
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
EncodeQuery(modelName *string, query string, apiConfig *models.APIConfig) ([]float64, error)
|
||||
Encode(modelName *string, texts []string, apiConfig *models.APIConfig, embeddingConfig *models.EmbeddingConfig) ([][]float64, error)
|
||||
}
|
||||
|
||||
// ChatModel interface for chat models
|
||||
|
||||
@@ -90,7 +90,7 @@ func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
|
||||
return nil, 0, fmt.Errorf("model is not an embedding model")
|
||||
}
|
||||
|
||||
embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig)
|
||||
embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig, b.embeddingConfig)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -117,15 +117,18 @@ func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
|
||||
return nil, 0, fmt.Errorf("model is not an embedding model")
|
||||
}
|
||||
|
||||
embedding, err := embeddingModel.EncodeQuery(&b.modelName, query, b.apiConfig)
|
||||
embeddings, err := embeddingModel.Encode(&b.modelName, []string{query}, b.apiConfig, b.embeddingConfig)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, 0, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
|
||||
// TODO: Calculate actual token count
|
||||
tokenCount := int64(len(query) / 4)
|
||||
|
||||
return embedding, tokenCount, nil
|
||||
return embeddings[0], tokenCount, nil
|
||||
}
|
||||
|
||||
// Chat sends a chat message and returns response
|
||||
|
||||
@@ -597,11 +597,12 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque
|
||||
|
||||
// GetVector computes query vector and returns MatchDenseExpr for hybrid search
|
||||
func (s *RetrievalService) GetVector(txt string, embModel *models.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) {
|
||||
vector, err := embModel.ModelDriver.EncodeQuery(embModel.ModelName, txt, embModel.APIConfig)
|
||||
embeddings, err := embModel.ModelDriver.Encode(embModel.ModelName, []string{txt}, embModel.APIConfig, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vector := embeddings[0]
|
||||
vectorSize := len(vector)
|
||||
vectorColumnName := fmt.Sprintf("q_%d_vec", vectorSize)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user