From effc84a042bfadd96363cb5ca2732bacf4ef93bf Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:59:01 +0800 Subject: [PATCH] Refactor model in GO (#14398) ### What problem does this PR solve? Refactor model in GO ### Type of change - [x] Refactoring --- conf/models/siliconflow.json | 7 + conf/models/zhipu-ai.json | 2 +- internal/entity/models/aliyun.go | 15 + internal/entity/models/deepseek.go | 15 + internal/entity/models/dummy.go | 15 + internal/entity/models/gitee.go | 15 + internal/entity/models/google.go | 22 + internal/entity/models/minimax.go | 15 + internal/entity/models/moonshot.go | 15 + internal/entity/models/siliconflow.go | 203 +++++++++- internal/entity/models/types.go | 80 +++- internal/entity/models/zhipu-ai.go | 26 +- internal/entity/types.go | 12 +- internal/handler/providers.go | 3 + internal/router/router.go | 2 +- internal/service/chunk.go | 20 +- internal/service/model_bundle.go | 46 ++- internal/service/model_service.go | 224 +++++------ internal/service/models/deepseek_model.go | 33 -- internal/service/models/factory.go | 119 ------ internal/service/models/gitee_model.go | 127 ------ internal/service/models/moonshot_model.go | 33 -- .../models/openai_api_compatible_model.go | 33 -- internal/service/models/openai_model.go | 124 ------ internal/service/models/siliconflow_model.go | 380 ------------------ internal/service/models/zhipu_model.go | 33 -- internal/service/nlp/reranker.go | 16 +- internal/service/nlp/retrieval.go | 18 +- 28 files changed, 575 insertions(+), 1078 deletions(-) delete mode 100644 internal/service/models/deepseek_model.go delete mode 100644 internal/service/models/factory.go delete mode 100644 internal/service/models/gitee_model.go delete mode 100644 internal/service/models/moonshot_model.go delete mode 100644 internal/service/models/openai_api_compatible_model.go delete mode 100644 internal/service/models/openai_model.go delete mode 100644 internal/service/models/siliconflow_model.go delete mode 100644 internal/service/models/zhipu_model.go diff --git a/conf/models/siliconflow.json b/conf/models/siliconflow.json index ad9e2bde28..d9340365d0 100644 --- a/conf/models/siliconflow.json +++ b/conf/models/siliconflow.json @@ -37,6 +37,13 @@ "model_types": [ "rerank" ] + }, + { + "name": "Qwen/Qwen3-Embedding-0.6B", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] } ] } diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index d7414e94c4..1027dc5273 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -7,7 +7,7 @@ "chat": "chat/completions", "async_chat": "async/chat/completions", "async_result": "async-result", - "embedding": "embedding", + "embedding": "embeddings", "rerank": "rerank", "files": "files" }, diff --git a/internal/entity/models/aliyun.go b/internal/entity/models/aliyun.go index f3ed09a68a..4975ed295e 100644 --- a/internal/entity/models/aliyun.go +++ b/internal/entity/models/aliyun.go @@ -337,6 +337,21 @@ func (z *AliyunModel) EncodeToEmbedding(modelName *string, texts []string, apiCo return nil, fmt.Errorf("%s, no such method", z.Name()) } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *AliyunModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) +} + +// Rerank calculates similarity scores between query and texts +func (z *AliyunModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} + type AliyunModelItem struct { ModelName string `json:"model_name"` BaseCapacity int `json:"base_capacity"` diff --git a/internal/entity/models/deepseek.go b/internal/entity/models/deepseek.go index 9ca5f534f8..eee8b800d3 100644 --- a/internal/entity/models/deepseek.go +++ b/internal/entity/models/deepseek.go @@ -401,6 +401,16 @@ func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, api return nil, fmt.Errorf("%s, no such method", z.Name()) } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *DeepSeekModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) +} + type DSModel struct { ID string `json:"id"` Object string `json:"object"` @@ -476,3 +486,8 @@ func (z *DeepSeekModel) CheckConnection(apiConfig *APIConfig) error { } return nil } + +// Rerank calculates similarity scores between query and texts +func (z *DeepSeekModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index e7be91543c..e93de49fe4 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -58,6 +58,16 @@ func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiCon return nil, fmt.Errorf("not implemented") } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *DummyModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) +} + func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("not implemented") } @@ -69,3 +79,8 @@ func (z *DummyModel) Balance(apiConfig *APIConfig) (map[string]interface{}, erro func (z *DummyModel) CheckConnection(apiConfig *APIConfig) error { return fmt.Errorf("no such method") } + +// Rerank calculates similarity scores between query and texts +func (z *DummyModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} diff --git a/internal/entity/models/gitee.go b/internal/entity/models/gitee.go index 35cc7ef8ca..2ea88a450a 100644 --- a/internal/entity/models/gitee.go +++ b/internal/entity/models/gitee.go @@ -367,6 +367,21 @@ func (z *GiteeModel) EncodeToEmbedding(modelName *string, texts []string, apiCon return nil, fmt.Errorf("%s, no such method", z.Name()) } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *GiteeModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) +} + +// Rerank calculates similarity scores between query and texts +func (z *GiteeModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} + func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil { diff --git a/internal/entity/models/google.go b/internal/entity/models/google.go index 461416c35f..c0c3b20f7d 100644 --- a/internal/entity/models/google.go +++ b/internal/entity/models/google.go @@ -171,3 +171,25 @@ func (z *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error { return fmt.Errorf("no such method") } + +// Encode encodes a list of texts into embeddings (convenience method) +func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return z.EncodeToEmbedding(modelName, texts, apiConfig, nil) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *GoogleModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + embeddings, err := z.Encode(modelName, []string{query}, apiConfig) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil +} + +// Rerank calculates similarity scores between query and texts +func (z *GoogleModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} diff --git a/internal/entity/models/minimax.go b/internal/entity/models/minimax.go index 011ac4725b..2e512d3392 100644 --- a/internal/entity/models/minimax.go +++ b/internal/entity/models/minimax.go @@ -71,6 +71,16 @@ func (z *MinimaxModel) EncodeToEmbedding(modelName *string, texts []string, apiC return nil, fmt.Errorf("not implemented") } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *MinimaxModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) +} + func (z *MinimaxModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } @@ -112,3 +122,8 @@ func (z *MinimaxModel) CheckConnection(apiConfig *APIConfig) error { return nil } + +// Rerank calculates similarity scores between query and texts +func (z *MinimaxModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} diff --git a/internal/entity/models/moonshot.go b/internal/entity/models/moonshot.go index ab7ba2aeaf..f35558ef8b 100644 --- a/internal/entity/models/moonshot.go +++ b/internal/entity/models/moonshot.go @@ -73,6 +73,16 @@ func (z *MoonshotModel) EncodeToEmbedding(modelName *string, texts []string, api return nil, fmt.Errorf("not implemented") } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return nil, fmt.Errorf("%s, Encode not implemented", z.Name()) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *MoonshotModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name()) +} + func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil { @@ -193,3 +203,8 @@ func (z *MoonshotModel) CheckConnection(apiConfig *APIConfig) error { } return nil } + +// Rerank calculates similarity scores between query and texts +func (z *MoonshotModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index 8edb0e7436..5938d23782 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -56,6 +56,26 @@ func (z *SiliconflowModel) Name() string { return "siliconflow" } + +// SiliconflowRerankRequest represents SILICONFLOW rerank request +type SiliconflowRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN int `json:"top_n"` + ReturnDocuments bool `json:"return_documents"` + MaxChunksPerDoc int `json:"max_chunks_per_doc"` + OverlapTokens int `json:"overlap_tokens"` +} + +// SiliconflowRerankResponse represents SILICONFLOW rerank response +type SiliconflowRerankResponse struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` +} + // Chat sends a message and returns response func (z *SiliconflowModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if message == nil { @@ -363,8 +383,116 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName, message *string, ap } // EncodeToEmbedding encodes a list of texts into embeddings -func (z *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, no such method", z.Name()) +func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { + if len(texts) == 0 { + return [][]float64{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(s.BaseURL[region], "/"), s.URLSuffix.Embedding) + + apiKey := "" + if apiConfig != nil && apiConfig.ApiKey != nil { + apiKey = *apiConfig.ApiKey + } + + embeddings := make([][]float64, len(texts)) + + for i, text := range texts { + reqBody := map[string]interface{}{ + "model": modelName, + "input": text, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok || len(data) == 0 { + return nil, fmt.Errorf("no data in response") + } + + firstData, ok := data[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid data format") + } + + embeddingSlice, ok := firstData["embedding"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid embedding format") + } + + embedding := make([]float64, len(embeddingSlice)) + for j, v := range embeddingSlice { + switch val := v.(type) { + case float64: + embedding[j] = val + case float32: + embedding[j] = float64(val) + default: + return nil, fmt.Errorf("unexpected embedding value type") + } + } + + embeddings[i] = embedding + } + + return embeddings, nil +} + +// Encode encodes a list of texts into embeddings (convenience method) +func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return s.EncodeToEmbedding(modelName, texts, apiConfig, nil) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (s *SiliconflowModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + embeddings, err := s.Encode(modelName, []string{query}, apiConfig) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil } func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) { @@ -435,3 +563,74 @@ func (z *SiliconflowModel) CheckConnection(apiConfig *APIConfig) error { } return nil } + +// Rerank calculates similarity scores between query and texts +func (s *SiliconflowModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + if len(texts) == 0 { + return []float64{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil { + region = *apiConfig.Region + } + + apiKey := "" + if apiConfig != nil && apiConfig.ApiKey != nil { + apiKey = *apiConfig.ApiKey + } + + reqBody := SiliconflowRerankRequest{ + Model: *modelName, + Query: query, + Documents: texts, + TopN: len(texts), + ReturnDocuments: false, + MaxChunksPerDoc: 1024, + OverlapTokens: 80, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(s.BaseURL[region], "/"), s.URLSuffix.Rerank) + + req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body)) + } + + body, _ := io.ReadAll(resp.Body) + + var rerankResp SiliconflowRerankResponse + if err := json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + scores := make([]float64, len(texts)) + for _, result := range rerankResp.Results { + if result.Index >= 0 && result.Index < len(texts) { + scores[result.Index] = result.RelevanceScore + } + } + + return scores, nil +} diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 1163a438e7..0043bef41a 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -1,5 +1,7 @@ package models +import "fmt" + // Message represents a chat message with role type Message struct { Role string @@ -16,8 +18,14 @@ type ModelDriver interface { ChatWithMessages(modelName string, apiKey *string, messages []Message, modelConfig *ChatConfig) (string, error) // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error - // Encode encodes a list of texts into embeddings + // EncodeToEmbedding encodes a list of texts into embeddings EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) + // Encode encodes a list of texts into embeddings (convenience method) + Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) + // EncodeQuery encodes a single query string into embedding (convenience method) + EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) + // Rerank calculates similarity scores between query and texts + Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) // List suppported models ListModels(apiConfig *APIConfig) ([]string, error) @@ -64,3 +72,73 @@ type APIConfig struct { type EmbeddingConfig struct { } + +// EmbeddingModel wraps a ModelDriver with embedding-specific configuration +type EmbeddingModel struct { + ModelDriver ModelDriver + ModelName string + APIConfig *APIConfig +} + +// NewEmbeddingModel creates a new EmbeddingModel +func NewEmbeddingModel(driver ModelDriver, modelName string, apiConfig *APIConfig) *EmbeddingModel { + return &EmbeddingModel{ + ModelDriver: driver, + ModelName: modelName, + APIConfig: apiConfig, + } +} + +// Encode encodes a list of texts into embeddings +func (e *EmbeddingModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return e.ModelDriver.EncodeToEmbedding(modelName, texts, apiConfig, nil) +} + +// EncodeQuery encodes a single query string into embedding +func (e *EmbeddingModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + embeddings, err := e.ModelDriver.Encode(modelName, []string{query}, apiConfig) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil +} + +// RerankModel wraps a ModelDriver with rerank-specific configuration +type RerankModel struct { + ModelDriver ModelDriver + ModelName string + APIConfig *APIConfig +} + +// NewRerankModel creates a new RerankModel +func NewRerankModel(driver ModelDriver, modelName string, apiConfig *APIConfig) *RerankModel { + return &RerankModel{ + ModelDriver: driver, + ModelName: modelName, + APIConfig: apiConfig, + } +} + +// Rerank calculates similarity between query and texts +func (r *RerankModel) Rerank(query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return r.ModelDriver.Rerank(&r.ModelName, query, texts, apiConfig) +} + +// ChatModel wraps a ModelDriver with chat-specific configuration +type ChatModel struct { + ModelDriver ModelDriver + ModelName string + APIConfig *APIConfig +} + +// NewChatModel creates a new ChatModel +func NewChatModel(driver ModelDriver, modelName string, apiConfig *APIConfig) *ChatModel { + return &ChatModel{ + ModelDriver: driver, + ModelName: modelName, + APIConfig: apiConfig, + } +} diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index bf395a7e9c..c041f39152 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -292,7 +292,7 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiCon region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region]) + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(z.BaseURL[region], "/"), z.URLSuffix.Chat) // Build request body with streaming enabled reqBody := map[string]interface{}{ @@ -440,7 +440,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC region = *apiConfig.Region } - url := fmt.Sprintf("%s/embedding", z.BaseURL[region]) + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(z.BaseURL[region], "/"), z.URLSuffix.Embedding) embeddings := make([][]float64, len(texts)) @@ -518,6 +518,23 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC return embeddings, nil } +// Encode encodes a list of texts into embeddings (convenience method) +func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) { + return z.EncodeToEmbedding(modelName, texts, apiConfig, nil) +} + +// EncodeQuery encodes a single query string into embedding (convenience method) +func (z *ZhipuAIModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) { + embeddings, err := z.Encode(modelName, []string{query}, apiConfig) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil +} + func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } @@ -559,3 +576,8 @@ func (z *ZhipuAIModel) CheckConnection(apiConfig *APIConfig) error { return nil } + +// Rerank calculates similarity scores between query and texts +func (z *ZhipuAIModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +} diff --git a/internal/entity/types.go b/internal/entity/types.go index b2f2df2958..8f78dd33f6 100644 --- a/internal/entity/types.go +++ b/internal/entity/types.go @@ -16,6 +16,10 @@ package entity +import ( + "ragflow/internal/entity/models" +) + // ModelType represents the type of model type ModelType string @@ -39,9 +43,9 @@ const ( // EmbeddingModel interface for embedding models type EmbeddingModel interface { // Encode encodes a list of texts into embeddings - Encode(texts []string) ([][]float64, error) + Encode(modelName *string, texts []string, apiConfig *models.APIConfig) ([][]float64, error) // EncodeQuery encodes a single query string into embedding - EncodeQuery(query string) ([]float64, error) + EncodeQuery(modelName *string, query string, apiConfig *models.APIConfig) ([]float64, error) } // ChatModel interface for chat models @@ -54,8 +58,8 @@ type ChatModel interface { // RerankModel interface for rerank models type RerankModel interface { - // Similarity calculates similarity between query and texts - Similarity(query string, texts []string) ([]float64, error) + // Rerank calculates similarity between query and texts + Rerank(query string, texts []string, apiConfig *models.APIConfig) ([]float64, error) } // ModelConfig represents configuration for a model diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 8e4e177042..7c49186f77 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -607,6 +607,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { } modelName := c.Param("model_name") + if modelName != "" { + modelName = strings.TrimPrefix(modelName, "/") + } if modelName == "" { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, diff --git a/internal/router/router.go b/internal/router/router.go index 64123ff0a3..6eca00edc2 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -217,7 +217,7 @@ func (r *Router) Setup(engine *gin.Engine) { provider.PUT("/:provider_name/instances/:instance_name", r.providerHandler.AlterProviderInstance) provider.DELETE("/:provider_name/instances", r.providerHandler.DropProviderInstance) provider.GET("/:provider_name/instances/:instance_name/models", r.providerHandler.ListInstanceModels) - provider.PATCH("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.EnableOrDisableModel) + provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel) provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.ChatToModel) } diff --git a/internal/service/chunk.go b/internal/service/chunk.go index 53f8d7db74..fe9a71ff27 100644 --- a/internal/service/chunk.go +++ b/internal/service/chunk.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "ragflow/internal/entity" + "ragflow/internal/entity/models" "ragflow/internal/server" "strconv" "strings" @@ -40,7 +41,6 @@ import ( type ChunkService struct { docEngine engine.DocEngine engineType server.EngineType - modelProvider ModelProvider embeddingCache *utility.EmbeddingLRU kbDAO *dao.KnowledgebaseDAO userTenantDAO *dao.UserTenantDAO @@ -53,7 +53,6 @@ func NewChunkService() *ChunkService { return &ChunkService{ docEngine: engine.Get(), engineType: cfg.DocEngine.Type, - modelProvider: NewModelProvider(), embeddingCache: utility.NewEmbeddingLRU(1000), // default capacity kbDAO: dao.NewKnowledgebaseDAO(), userTenantDAO: dao.NewUserTenantDAO(), @@ -340,8 +339,8 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( } // Get embedding model for the tenant - var embeddingModel entity.EmbeddingModel - embeddingModel, err = s.modelProvider.GetEmbeddingModel(ctx, tenantIDs[0], embdID) + modelProviderSvc := NewModelProviderService() + embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID) if err != nil { return nil, fmt.Errorf("failed to get embedding model: %w", err) } @@ -350,7 +349,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( zap.String("embdID", embdID)) // Get rerank model if RerankID is specified - var rerankModel nlp.RerankModel + var rerankModel *models.RerankModel var rerankCompositeName string if req.TenantRerankID != nil && *req.TenantRerankID != "" { tenantRerankIDInt, parseErr := strconv.ParseInt(*req.TenantRerankID, 10, 64) @@ -361,19 +360,16 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( if err != nil { return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err) } - rerankModel, err = s.modelProvider.GetRerankModel(ctx, tenantIDs[0], rerankCompositeName) - if err != nil { - return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err) - } } else if req.RerankID != nil && *req.RerankID != "" { - var err error _, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank) if err != nil { return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err) } - rerankModel, err = s.modelProvider.GetRerankModel(ctx, tenantIDs[0], rerankCompositeName) + } + if rerankCompositeName != "" { + rerankModel, err = modelProviderSvc.GetRerankModel(tenantIDs[0], rerankCompositeName) if err != nil { - return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err) + return nil, fmt.Errorf("failed to get rerank model: %w", err) } } diff --git a/internal/service/model_bundle.go b/internal/service/model_bundle.go index 441ee32a04..0f3fc6a65a 100644 --- a/internal/service/model_bundle.go +++ b/internal/service/model_bundle.go @@ -17,26 +17,29 @@ package service import ( - "context" "fmt" "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" ) // ModelBundle provides a unified interface for various model operations // Similar to Python's LLMBundle but with a more generic name type ModelBundle struct { - tenantID string - modelType entity.ModelType - modelName string - model interface{} // underlying model instance + tenantID string + modelType entity.ModelType + modelName string + model interface{} // underlying model instance + apiConfig *modelModule.APIConfig + embeddingConfig *modelModule.EmbeddingConfig } // NewModelBundle creates a new ModelBundle for the given tenant and model type // If modelName is empty, uses the default model for the tenant and type func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...string) (*ModelBundle, error) { bundle := &ModelBundle{ - tenantID: tenantID, - modelType: modelType, + tenantID: tenantID, + modelType: modelType, + embeddingConfig: &modelModule.EmbeddingConfig{}, } // Use provided model name if available @@ -45,26 +48,29 @@ func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...st } // Get model instance based on type - provider := NewModelProvider() + modelProviderSvc := NewModelProviderService() switch modelType { case entity.ModelTypeEmbedding: - embeddingModel, err := provider.GetEmbeddingModel(context.Background(), tenantID, bundle.modelName) + embd, err := modelProviderSvc.GetEmbeddingModel(tenantID, bundle.modelName) if err != nil { return nil, fmt.Errorf("failed to get embedding model: %w", err) } - bundle.model = embeddingModel + bundle.model = embd.ModelDriver + bundle.apiConfig = embd.APIConfig case entity.ModelTypeChat: - chatModel, err := provider.GetChatModel(context.Background(), tenantID, bundle.modelName) + chatMdl, err := modelProviderSvc.GetChatModel(tenantID, bundle.modelName) if err != nil { return nil, fmt.Errorf("failed to get chat model: %w", err) } - bundle.model = chatModel + bundle.model = chatMdl.ModelDriver + bundle.apiConfig = chatMdl.APIConfig case entity.ModelTypeRerank: - rerankModel, err := provider.GetRerankModel(context.Background(), tenantID, bundle.modelName) + rerankMdl, err := modelProviderSvc.GetRerankModel(tenantID, bundle.modelName) if err != nil { return nil, fmt.Errorf("failed to get rerank model: %w", err) } - bundle.model = rerankModel + bundle.model = rerankMdl.ModelDriver + bundle.apiConfig = rerankMdl.APIConfig default: return nil, fmt.Errorf("unsupported model type: %s", modelType) } @@ -84,7 +90,7 @@ func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) { return nil, 0, fmt.Errorf("model is not an embedding model") } - embeddings, err := embeddingModel.Encode(texts) + embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig) if err != nil { return nil, 0, err } @@ -111,7 +117,7 @@ func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) { return nil, 0, fmt.Errorf("model is not an embedding model") } - embedding, err := embeddingModel.EncodeQuery(query) + embedding, err := embeddingModel.EncodeQuery(&b.modelName, query, b.apiConfig) if err != nil { return nil, 0, err } @@ -144,10 +150,10 @@ func (b *ModelBundle) Chat(system string, history []map[string]string, genConf m return response, tokenCount, nil } -// Similarity calculates similarity between query and texts -func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64, error) { +// Rerank calculates similarity between query and texts +func (b *ModelBundle) Rerank(query string, texts []string) ([]float64, int64, error) { if b.modelType != entity.ModelTypeRerank { - return nil, 0, fmt.Errorf("model type %s does not support similarity", b.modelType) + return nil, 0, fmt.Errorf("model type %s does not support rerank", b.modelType) } rerankModel, ok := b.model.(entity.RerankModel) @@ -155,7 +161,7 @@ func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64 return nil, 0, fmt.Errorf("model is not a rerank model") } - similarities, err := rerankModel.Similarity(query, texts) + similarities, err := rerankModel.Rerank(query, texts, b.apiConfig) if err != nil { return nil, 0, err } diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 20ed3fd930..902bc75d37 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -17,45 +17,17 @@ package service import ( - "context" "encoding/json" "errors" "fmt" - "net/http" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" modelModule "ragflow/internal/entity/models" "strings" "time" - - "ragflow/internal/service/models" ) -// ModelProvider provides model instances based on tenant and model type -type ModelProvider interface { - // GetEmbeddingModel returns an embedding model for the given tenant - GetEmbeddingModel(ctx context.Context, tenantID string, modelName string) (entity.EmbeddingModel, error) - // GetChatModel returns a chat model for the given tenant - GetChatModel(ctx context.Context, tenantID string, modelName string) (entity.ChatModel, error) - // GetRerankModel returns a rerank model for the given tenant - GetRerankModel(ctx context.Context, tenantID string, modelName string) (entity.RerankModel, error) -} - -// ModelProviderImpl implements ModelProvider -type ModelProviderImpl struct { - httpClient *http.Client -} - -// NewModelProvider creates a new ModelProvider -func NewModelProvider() *ModelProviderImpl { - return &ModelProviderImpl{ - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, - } -} - // parseModelName parses a composite model name in format "model_name@provider" // Returns modelName and provider separately func parseModelName(compositeName string) (modelName, provider string, err error) { @@ -69,111 +41,6 @@ func parseModelName(compositeName string) (modelName, provider string, err error } } -// GetEmbeddingModel returns an embedding model for the given tenant -func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID string, compositeModelName string) (entity.EmbeddingModel, error) { - // Parse composite model name to extract model name and provider - modelName, provider, err := parseModelName(compositeModelName) - if err != nil { - return nil, err - } - - // Get API key and configuration - embeddingModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName) - if err != nil { - return nil, err - } - - apiKey := embeddingModel.APIKey - if apiKey == nil || *apiKey == "" { - return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName) - } - - // Get API base from TenantLLM if set, otherwise from model provider configuration - apiBase := "" - if embeddingModel.APIBase != nil && *embeddingModel.APIBase != "" { - apiBase = *embeddingModel.APIBase - } else { - providerDAO := dao.NewModelProviderDAO() - providerConfig := providerDAO.GetProviderByName(provider) - if providerConfig == nil || providerConfig.DefaultURL == "" { - return nil, fmt.Errorf("no API base found for provider %s", provider) - } - apiBase = providerConfig.DefaultURL - } - - return models.CreateEmbeddingModel(provider, *apiKey, apiBase, modelName, p.httpClient) -} - -// GetChatModel returns a chat model for the given tenant -func (p *ModelProviderImpl) GetChatModel(ctx context.Context, tenantID string, compositeModelName string) (entity.ChatModel, error) { - // Parse composite model name to extract model name and provider - modelName, provider, err := parseModelName(compositeModelName) - if err != nil { - return nil, err - } - - // Get chat model from database - chatModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName) - if err != nil { - return nil, fmt.Errorf("no chat model found for tenant %s and model %s: %w", tenantID, compositeModelName, err) - } - - apiKey := chatModel.APIKey - if apiKey == nil || *apiKey == "" { - return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName) - } - - // Get API base from TenantLLM if set, otherwise from model provider configuration - apiBase := "" - if chatModel.APIBase != nil && *chatModel.APIBase != "" { - apiBase = *chatModel.APIBase - } else { - providerDAO := dao.NewModelProviderDAO() - providerConfig := providerDAO.GetProviderByName(provider) - if providerConfig == nil || providerConfig.DefaultURL == "" { - return nil, fmt.Errorf("no API base found for provider %s", provider) - } - apiBase = providerConfig.DefaultURL - } - - return models.CreateChatModel(provider, *apiKey, apiBase, modelName, p.httpClient) -} - -// GetRerankModel returns a rerank model for the given tenant -func (p *ModelProviderImpl) GetRerankModel(ctx context.Context, tenantID string, compositeModelName string) (entity.RerankModel, error) { - // Parse composite model name to extract model name and provider - modelName, provider, err := parseModelName(compositeModelName) - if err != nil { - return nil, err - } - - // Get rerank model from database - rerankModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName) - if err != nil { - return nil, fmt.Errorf("no rerank model found for tenant %s and model %s: %w", tenantID, compositeModelName, err) - } - - apiKey := rerankModel.APIKey - if apiKey == nil || *apiKey == "" { - return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName) - } - - // Get API base from TenantLLM if set, otherwise from model provider configuration - apiBase := "" - if rerankModel.APIBase != nil && *rerankModel.APIBase != "" { - apiBase = *rerankModel.APIBase - } else { - providerDAO := dao.NewModelProviderDAO() - providerConfig := providerDAO.GetProviderByName(provider) - if providerConfig == nil || providerConfig.DefaultURL == "" { - return nil, fmt.Errorf("no API base found for provider %s", provider) - } - apiBase = providerConfig.DefaultURL - } - - return models.CreateRerankModel(provider, *apiKey, apiBase, modelName, p.httpClient) -} - func NewModelProviderService() *ModelProviderService { return &ModelProviderService{ modelProviderDAO: dao.NewTenantModelProviderDAO(), @@ -973,3 +840,94 @@ func (m *ModelProviderService) GetModelByName(modelName string, tenantID string) APIKey: *tenantLLM.APIKey, }, nil } + +// GetEmbeddingModel returns an EmbeddingModel wrapper for the given tenant +func (m *ModelProviderService) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) { + driver, modelName, apiConfig, err := m.getModelConfig(tenantID, compositeModelName) + if err != nil { + return nil, err + } + return modelModule.NewEmbeddingModel(driver, modelName, apiConfig), nil +} + +// GetRerankModel returns a RerankModel wrapper for the given tenant +func (m *ModelProviderService) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) { + driver, modelName, apiConfig, err := m.getModelConfig(tenantID, compositeModelName) + if err != nil { + return nil, err + } + return modelModule.NewRerankModel(driver, modelName, apiConfig), nil +} + +// GetChatModel returns a ChatModel wrapper for the given tenant +func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) { + driver, modelName, apiConfig, err := m.getModelConfig(tenantID, compositeModelName) + if err != nil { + return nil, err + } + return modelModule.NewChatModel(driver, modelName, apiConfig), nil +} + +// getModelConfig returns the model driver, model name, and API config for a model +func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, error) { + modelName, providerName, err := parseModelName(compositeModelName) + if err != nil { + return nil, "", nil, err + } + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, "", nil, err + } + if provider == nil { + return nil, "", nil, fmt.Errorf("provider %s not found", providerName) + } + + instanceName := "default_instance" + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, "", nil, err + } + if instance == nil { + return nil, "", nil, fmt.Errorf("instance %s not found for provider %s", instanceName, providerName) + } + + _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, "", nil, fmt.Errorf("provider %s not found", providerName) + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, "", nil, fmt.Errorf("provider %s model %s not found", providerName, modelName) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, "", nil, err + } + region := extra["region"] + + apiConfig := &modelModule.APIConfig{ApiKey: &instance.APIKey, Region: ®ion} + return providerInfo.ModelDriver, modelName, apiConfig, nil + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, "", nil, err + } + region := extra["region"] + + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, "", nil, fmt.Errorf("provider %s not found", providerName) + } + + apiConfig := &modelModule.APIConfig{ApiKey: &instance.APIKey, Region: ®ion} + return providerInfo.ModelDriver, modelName, apiConfig, nil +} diff --git a/internal/service/models/deepseek_model.go b/internal/service/models/deepseek_model.go deleted file mode 100644 index cf6a2f2167..0000000000 --- a/internal/service/models/deepseek_model.go +++ /dev/null @@ -1,33 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "net/http" - "ragflow/internal/entity" -) - -func init() { - RegisterEmbeddingModelFactory("DeepSeek", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &openAIEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/models/factory.go b/internal/service/models/factory.go deleted file mode 100644 index b3ed9c5c76..0000000000 --- a/internal/service/models/factory.go +++ /dev/null @@ -1,119 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "fmt" - "net/http" - "ragflow/internal/entity" - - "sync" -) - -// EmbeddingModelFactory creates an EmbeddingModel instance -type EmbeddingModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel - -// ChatModelFactory creates a ChatModel instance -type ChatModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.ChatModel - -// RerankModelFactory creates a RerankModel instance -type RerankModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.RerankModel - -var ( - embeddingModelFactories = make(map[string]EmbeddingModelFactory) - chatModelFactories = make(map[string]ChatModelFactory) - rerankModelFactories = make(map[string]RerankModelFactory) - factoryMu sync.RWMutex -) - -// RegisterEmbeddingModelFactory registers a factory for a provider name. -// Should be called from init() functions of provider implementations. -func RegisterEmbeddingModelFactory(providerName string, factory EmbeddingModelFactory) { - factoryMu.Lock() - defer factoryMu.Unlock() - embeddingModelFactories[providerName] = factory -} - -// RegisterChatModelFactory registers a factory for a chat provider name. -// Should be called from init() functions of provider implementations. -func RegisterChatModelFactory(providerName string, factory ChatModelFactory) { - factoryMu.Lock() - defer factoryMu.Unlock() - chatModelFactories[providerName] = factory -} - -// RegisterRerankModelFactory registers a factory for a rerank provider name. -// Should be called from init() functions of provider implementations. -func RegisterRerankModelFactory(providerName string, factory RerankModelFactory) { - factoryMu.Lock() - defer factoryMu.Unlock() - rerankModelFactories[providerName] = factory -} - -// GetEmbeddingModelFactory returns the factory for the given provider name. -// Returns nil if not found. -func GetEmbeddingModelFactory(providerName string) EmbeddingModelFactory { - factoryMu.RLock() - defer factoryMu.RUnlock() - return embeddingModelFactories[providerName] -} - -// GetChatModelFactory returns the factory for the given chat provider name. -// Returns nil if not found. -func GetChatModelFactory(providerName string) ChatModelFactory { - factoryMu.RLock() - defer factoryMu.RUnlock() - return chatModelFactories[providerName] -} - -// GetRerankModelFactory returns the factory for the given rerank provider name. -// Returns nil if not found. -func GetRerankModelFactory(providerName string) RerankModelFactory { - factoryMu.RLock() - defer factoryMu.RUnlock() - return rerankModelFactories[providerName] -} - -// CreateEmbeddingModel creates an EmbeddingModel instance for the given provider. -// Returns error if provider not registered. -func CreateEmbeddingModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.EmbeddingModel, error) { - factory := GetEmbeddingModelFactory(providerName) - if factory == nil { - return nil, fmt.Errorf("no embedding model factory registered for provider %s", providerName) - } - return factory(apiKey, apiBase, modelName, httpClient), nil -} - -// CreateChatModel creates a ChatModel instance for the given provider. -// Returns error if provider not registered. -func CreateChatModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.ChatModel, error) { - factory := GetChatModelFactory(providerName) - if factory == nil { - return nil, fmt.Errorf("no chat model factory registered for provider %s", providerName) - } - return factory(apiKey, apiBase, modelName, httpClient), nil -} - -// CreateRerankModel creates a RerankModel instance for the given provider. -// Returns error if provider not registered. -func CreateRerankModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.RerankModel, error) { - factory := GetRerankModelFactory(providerName) - if factory == nil { - return nil, fmt.Errorf("no rerank model factory registered for provider %s", providerName) - } - return factory(apiKey, apiBase, modelName, httpClient), nil -} diff --git a/internal/service/models/gitee_model.go b/internal/service/models/gitee_model.go deleted file mode 100644 index c121db6b99..0000000000 --- a/internal/service/models/gitee_model.go +++ /dev/null @@ -1,127 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "ragflow/internal/entity" - - "strings" -) - -// giteeEmbeddingModel implements EmbeddingModel for GiteeAI API (assumed OpenAI-compatible) -type giteeEmbeddingModel struct { - apiKey string - apiBase string - model string - httpClient *http.Client -} - -// GiteeEmbeddingRequest represents GiteeAI embedding request -type GiteeEmbeddingRequest struct { - Model string `json:"model"` - Input []string `json:"input"` - EncodeFormat string `json:"encode_format"` -} - -// GiteeEmbeddingResponse represents GiteeAI embedding response -type GiteeEmbeddingResponse struct { - Data []struct { - Embedding []float64 `json:"embedding"` - Index int `json:"index"` - } `json:"data"` -} - -// Encode encodes a list of texts into embeddings using GiteeAI API -func (m *giteeEmbeddingModel) Encode(texts []string) ([][]float64, error) { - if len(texts) == 0 { - return [][]float64{}, nil - } - - reqBody := GiteeEmbeddingRequest{ - Model: m.model, - Input: texts, - EncodeFormat: "float", - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", m.apiBase, strings.NewReader(string(jsonData))) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+m.apiKey) - - resp, err := m.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("GiteeAI API error: %s, body: %s", resp.Status, string(body)) - } - - var embeddingResp GiteeEmbeddingResponse - if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - // Sort embeddings by index to ensure correct order - embeddings := make([][]float64, len(texts)) - for _, data := range embeddingResp.Data { - if data.Index < len(embeddings) { - embeddings[data.Index] = data.Embedding - } - } - - return embeddings, nil -} - -// EncodeQuery encodes a single query string into embedding -func (m *giteeEmbeddingModel) EncodeQuery(query string) ([]float64, error) { - embeddings, err := m.Encode([]string{query}) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - -// init registers the GiteeAI embedding model factory -func init() { - RegisterEmbeddingModelFactory("GiteeAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &giteeEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/models/moonshot_model.go b/internal/service/models/moonshot_model.go deleted file mode 100644 index 74d2fec9cc..0000000000 --- a/internal/service/models/moonshot_model.go +++ /dev/null @@ -1,33 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "net/http" - "ragflow/internal/entity" -) - -func init() { - RegisterEmbeddingModelFactory("Moonshot", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &openAIEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/models/openai_api_compatible_model.go b/internal/service/models/openai_api_compatible_model.go deleted file mode 100644 index eff6c839ca..0000000000 --- a/internal/service/models/openai_api_compatible_model.go +++ /dev/null @@ -1,33 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "net/http" - "ragflow/internal/entity" -) - -func init() { - RegisterEmbeddingModelFactory("OpenAI-API-Compatible", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &openAIEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/models/openai_model.go b/internal/service/models/openai_model.go deleted file mode 100644 index 7524a9dd9c..0000000000 --- a/internal/service/models/openai_model.go +++ /dev/null @@ -1,124 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "ragflow/internal/entity" - - "strings" -) - -// openAIEmbeddingModel implements EmbeddingModel for OpenAI API -type openAIEmbeddingModel struct { - apiKey string - apiBase string - model string - httpClient *http.Client -} - -// OpenAIEmbeddingRequest represents OpenAI embedding request -type OpenAIEmbeddingRequest struct { - Model string `json:"model"` - Input []string `json:"input"` -} - -// OpenAIEmbeddingResponse represents OpenAI embedding response -type OpenAIEmbeddingResponse struct { - Data []struct { - Embedding []float64 `json:"embedding"` - Index int `json:"index"` - } `json:"data"` -} - -// Encode encodes a list of texts into embeddings using OpenAI API -func (m *openAIEmbeddingModel) Encode(texts []string) ([][]float64, error) { - if len(texts) == 0 { - return [][]float64{}, nil - } - - reqBody := OpenAIEmbeddingRequest{ - Model: m.model, - Input: texts, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", m.apiBase+"/embeddings", strings.NewReader(string(jsonData))) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+m.apiKey) - - resp, err := m.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("OpenAI API error: %s, body: %s", resp.Status, string(body)) - } - - var embeddingResp OpenAIEmbeddingResponse - if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - // Sort embeddings by index to ensure correct order - embeddings := make([][]float64, len(texts)) - for _, data := range embeddingResp.Data { - if data.Index < len(embeddings) { - embeddings[data.Index] = data.Embedding - } - } - - return embeddings, nil -} - -// EncodeQuery encodes a single query string into embedding -func (m *openAIEmbeddingModel) EncodeQuery(query string) ([]float64, error) { - embeddings, err := m.Encode([]string{query}) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - -// init registers the OpenAI embedding model factory -func init() { - RegisterEmbeddingModelFactory("OpenAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &openAIEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/models/siliconflow_model.go b/internal/service/models/siliconflow_model.go deleted file mode 100644 index 75f89f3525..0000000000 --- a/internal/service/models/siliconflow_model.go +++ /dev/null @@ -1,380 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "ragflow/internal/entity" - - "strings" -) - -// siliconflowEmbeddingModel implements EmbeddingModel for SILICONFLOW API (OpenAI-compatible) -type siliconflowEmbeddingModel struct { - apiKey string - apiBase string - model string - httpClient *http.Client -} - -// siliconflowChatModel implements ChatModel for SILICONFLOW API -type siliconflowChatModel struct { - apiKey string - apiBase string - model string - httpClient *http.Client -} - -// siliconflowRerankModel implements RerankModel for SILICONFLOW API -type siliconflowRerankModel struct { - apiKey string - apiBase string - model string - httpClient *http.Client -} - -// SiliconflowEmbeddingRequest represents SILICONFLOW embedding request -type SiliconflowEmbeddingRequest struct { - Model string `json:"model"` - Input []string `json:"input"` -} - -// SiliconflowEmbeddingResponse represents SILICONFLOW embedding response -type SiliconflowEmbeddingResponse struct { - Data []struct { - Embedding []float64 `json:"embedding"` - Index int `json:"index"` - } `json:"data"` -} - -// SiliconflowChatRequest represents SILICONFLOW chat request -type SiliconflowChatRequest struct { - Model string `json:"model"` - Messages []ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -// SiliconflowChatResponse represents SILICONFLOW chat response -type SiliconflowChatResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Error struct { - Message string `json:"message"` - Code string `json:"code"` - } `json:"error,omitempty"` -} - -// ChatMessage represents a chat message -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// SiliconflowRerankRequest represents SILICONFLOW rerank request -type SiliconflowRerankRequest struct { - Model string `json:"model"` - Query string `json:"query"` - Documents []string `json:"documents"` - TopN int `json:"top_n"` - ReturnDocuments bool `json:"return_documents"` - MaxChunksPerDoc int `json:"max_chunks_per_doc"` - OverlapTokens int `json:"overlap_tokens"` -} - -// SiliconflowRerankResponse represents SILICONFLOW rerank response -type SiliconflowRerankResponse struct { - Results []struct { - Index int `json:"index"` - RelevanceScore float64 `json:"relevance_score"` - } `json:"results"` -} - -// Encode encodes a list of texts into embeddings using SILICONFLOW API -func (m *siliconflowEmbeddingModel) Encode(texts []string) ([][]float64, error) { - if len(texts) == 0 { - return [][]float64{}, nil - } - - reqBody := SiliconflowEmbeddingRequest{ - Model: m.model, - Input: texts, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", m.apiBase+"/embeddings", strings.NewReader(string(jsonData))) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+m.apiKey) - - resp, err := m.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body)) - } - - var embeddingResp SiliconflowEmbeddingResponse - if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - // Sort embeddings by index to ensure correct order - embeddings := make([][]float64, len(texts)) - for _, data := range embeddingResp.Data { - if data.Index < len(embeddings) { - embeddings[data.Index] = data.Embedding - } - } - - return embeddings, nil -} - -// EncodeQuery encodes a single query string into embedding -func (m *siliconflowEmbeddingModel) EncodeQuery(query string) ([]float64, error) { - embeddings, err := m.Encode([]string{query}) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return embeddings[0], nil -} - -// Chat sends a chat message and returns response -func (m *siliconflowChatModel) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, error) { - // Build messages array - var messages []ChatMessage - - // Add system message if provided - if system != "" { - messages = append(messages, ChatMessage{Role: "system", Content: system}) - } - - // Add history messages - for _, msg := range history { - role := msg["role"] - content := msg["content"] - if role != "" && content != "" { - messages = append(messages, ChatMessage{Role: role, Content: content}) - } - } - - // Extract generation config - temperature := 0.7 - if temp, ok := genConf["temperature"].(float64); ok { - temperature = temp - } - maxTokens := 1024 - if mt, ok := genConf["max_tokens"].(int); ok { - maxTokens = mt - } - - // Build request - reqBody := SiliconflowChatRequest{ - Model: m.model, - Messages: messages, - Temperature: temperature, - MaxTokens: maxTokens, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return "", fmt.Errorf("failed to marshal request: %w", err) - } - - // Build URL - append /chat/completions if not already present - url := m.apiBase - if !strings.HasSuffix(url, "/chat/completions") { - if !strings.HasSuffix(url, "/") { - url += "/" - } - url += "chat/completions" - } - - req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData))) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+m.apiKey) - - resp, err := m.httpClient.Do(req) - if err != nil { - return "", fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body)) - } - - var chatResp SiliconflowChatResponse - if err := json.Unmarshal(body, &chatResp); err != nil { - return "", fmt.Errorf("failed to decode response: %w", err) - } - - if chatResp.Error.Message != "" { - return "", fmt.Errorf("chat error: %s", chatResp.Error.Message) - } - - if len(chatResp.Choices) == 0 { - return "", fmt.Errorf("no response choices returned") - } - - return chatResp.Choices[0].Message.Content, nil -} - -// ChatStreamly sends a chat message and streams response -func (m *siliconflowChatModel) ChatStreamly(system string, history []map[string]string, genConf map[string]interface{}) (<-chan string, error) { - // For now, return a simple non-streaming implementation - // Streaming can be implemented later with SSE support - responseChan := make(chan string) - - go func() { - defer close(responseChan) - response, err := m.Chat(system, history, genConf) - if err != nil { - responseChan <- "**ERROR**: " + err.Error() - return - } - responseChan <- response - }() - - return responseChan, nil -} - -// Similarity calculates similarity scores between query and texts using SiliconFlow API -func (m *siliconflowRerankModel) Similarity(query string, texts []string) ([]float64, error) { - if len(texts) == 0 { - return []float64{}, nil - } - - reqBody := SiliconflowRerankRequest{ - Model: m.model, - Query: query, - Documents: texts, - TopN: len(texts), - ReturnDocuments: false, - MaxChunksPerDoc: 1024, - OverlapTokens: 80, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - reqURL := m.apiBase - if !strings.Contains(reqURL, "/rerank") { - if !strings.HasSuffix(reqURL, "/") { - reqURL += "/" - } - reqURL += "rerank" - } - - req, err := http.NewRequest("POST", reqURL, strings.NewReader(string(jsonData))) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+m.apiKey) - - resp, err := m.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body)) - } - - body, _ := io.ReadAll(resp.Body) - - var rerankResp SiliconflowRerankResponse - if err := json.Unmarshal(body, &rerankResp); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - scores := make([]float64, len(texts)) - for _, result := range rerankResp.Results { - if result.Index >= 0 && result.Index < len(texts) { - scores[result.Index] = result.RelevanceScore - } - } - - return scores, nil -} - -// init registers the SILICONFLOW model factories -func init() { - RegisterEmbeddingModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &siliconflowEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) - - RegisterChatModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.ChatModel { - return &siliconflowChatModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) - - RegisterRerankModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.RerankModel { - return &siliconflowRerankModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/models/zhipu_model.go b/internal/service/models/zhipu_model.go deleted file mode 100644 index f674d07d4d..0000000000 --- a/internal/service/models/zhipu_model.go +++ /dev/null @@ -1,33 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package models - -import ( - "net/http" - "ragflow/internal/entity" -) - -func init() { - RegisterEmbeddingModelFactory("ZHIPU-AI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel { - return &openAIEmbeddingModel{ - apiKey: apiKey, - apiBase: apiBase, - model: modelName, - httpClient: httpClient, - } - }) -} diff --git a/internal/service/nlp/reranker.go b/internal/service/nlp/reranker.go index 0ab4d1c5c8..fab55987a4 100644 --- a/internal/service/nlp/reranker.go +++ b/internal/service/nlp/reranker.go @@ -23,18 +23,12 @@ import ( "strings" "ragflow/internal/common" + "ragflow/internal/entity/models" "ragflow/internal/logger" "go.uber.org/zap" ) -// RerankModel defines the interface for reranker models -// This matches model.RerankModel interface -type RerankModel interface { - // Similarity calculates similarity between query and texts - Similarity(query string, texts []string) ([]float64, error) -} - // SearchResult represents the result of a search operation type SearchResult struct { Total int @@ -60,7 +54,7 @@ type SearchResult struct { // - tsim: token similarity scores // - vsim: vector similarity scores func Rerank( - rerankModel RerankModel, + rerankModel *models.RerankModel, chunks []map[string]interface{}, total int, keywords []string, @@ -94,7 +88,7 @@ func Rerank( // RerankByModel performs reranking using a reranker model func RerankByModel( - rerankModel RerankModel, + rerankModel *models.RerankModel, chunks []map[string]interface{}, query string, tkWeight, vtWeight float64, @@ -142,9 +136,9 @@ func RerankByModel( tsim = TokenSimilarity(keywords, insTw, qb) // Get similarity scores from reranker model - modelSim, err := rerankModel.Similarity(query, docs) + modelSim, err := rerankModel.ModelDriver.Rerank(&rerankModel.ModelName, query, docs, rerankModel.APIConfig) if err != nil { - logger.Error("RerankByModel: rerankModel.Similarity failed; falling back to token-only similarity", err) + logger.Error("RerankByModel: rerankModel.Rerank failed; falling back to token-only similarity", err) // If model fails, fall back to token similarity only modelSim = make([]float64, len(tsim)) } diff --git a/internal/service/nlp/retrieval.go b/internal/service/nlp/retrieval.go index 5f6bb8185f..76f6d7d7fc 100644 --- a/internal/service/nlp/retrieval.go +++ b/internal/service/nlp/retrieval.go @@ -20,13 +20,13 @@ import ( "context" "fmt" "math" + "ragflow/internal/engine" + "ragflow/internal/engine/types" + "ragflow/internal/entity/models" "ragflow/internal/logger" "sort" "strings" - "ragflow/internal/engine" - "ragflow/internal/engine/types" - "ragflow/internal/entity" "ragflow/internal/tokenizer" "go.uber.org/zap" @@ -54,8 +54,8 @@ type RetrievalRequest struct { SimilarityThreshold *float64 VectorSimilarityWeight *float64 RankFeature *map[string]float64 - RerankModel RerankModel - EmbeddingModel entity.EmbeddingModel + RerankModel *models.RerankModel + EmbeddingModel *models.EmbeddingModel Aggs *bool Highlight *bool } @@ -384,7 +384,7 @@ type RetrievalSearchRequest struct { SimilarityThreshold float64 RankFeature map[string]float64 Filter map[string]interface{} - EmbeddingModel interface{} + EmbeddingModel *models.EmbeddingModel } type RetrievalSearchResult struct { @@ -489,7 +489,7 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque if similarityForGetVector <= 0 { similarityForGetVector = 0.1 } - matchDense, err := s.GetVector(req.Question, req.EmbeddingModel.(entity.EmbeddingModel), topk, similarityForGetVector) + matchDense, err := s.GetVector(req.Question, req.EmbeddingModel, topk, similarityForGetVector) if err != nil { return nil, fmt.Errorf("GetVector failed: %w", err) } @@ -596,8 +596,8 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque } // GetVector computes query vector and returns MatchDenseExpr for hybrid search -func (s *RetrievalService) GetVector(txt string, embModel entity.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) { - vector, err := embModel.EncodeQuery(txt) +func (s *RetrievalService) GetVector(txt string, embModel *models.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) { + vector, err := embModel.ModelDriver.EncodeQuery(&embModel.ModelName, txt, embModel.APIConfig) if err != nil { return nil, err }