mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-02 16:55:42 +08:00
Refactor model in GO (#14398)
### What problem does this PR solve? Refactor model in GO ### Type of change - [x] Refactoring
This commit is contained in:
@@ -37,6 +37,13 @@
|
||||
"model_types": [
|
||||
"rerank"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen3-Embedding-0.6B",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"chat": "chat/completions",
|
||||
"async_chat": "async/chat/completions",
|
||||
"async_result": "async-result",
|
||||
"embedding": "embedding",
|
||||
"embedding": "embeddings",
|
||||
"rerank": "rerank",
|
||||
"files": "files"
|
||||
},
|
||||
|
||||
@@ -337,6 +337,21 @@ func (z *AliyunModel) EncodeToEmbedding(modelName *string, texts []string, apiCo
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *AliyunModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *AliyunModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
type AliyunModelItem struct {
|
||||
ModelName string `json:"model_name"`
|
||||
BaseCapacity int `json:"base_capacity"`
|
||||
|
||||
@@ -401,6 +401,16 @@ func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, api
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *DeepSeekModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
type DSModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
@@ -476,3 +486,8 @@ func (z *DeepSeekModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *DeepSeekModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
@@ -58,6 +58,16 @@ func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiCon
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *DummyModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -69,3 +79,8 @@ func (z *DummyModel) Balance(apiConfig *APIConfig) (map[string]interface{}, erro
|
||||
func (z *DummyModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *DummyModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
@@ -367,6 +367,21 @@ func (z *GiteeModel) EncodeToEmbedding(modelName *string, texts []string, apiCon
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *GiteeModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *GiteeModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
|
||||
@@ -171,3 +171,25 @@ func (z *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err
|
||||
func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return z.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *GoogleModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := z.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *GoogleModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
@@ -71,6 +71,16 @@ func (z *MinimaxModel) EncodeToEmbedding(modelName *string, texts []string, apiC
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *MinimaxModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *MinimaxModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
@@ -112,3 +122,8 @@ func (z *MinimaxModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *MinimaxModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
@@ -73,6 +73,16 @@ func (z *MoonshotModel) EncodeToEmbedding(modelName *string, texts []string, api
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Encode not implemented", z.Name())
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *MoonshotModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, EncodeQuery not implemented", z.Name())
|
||||
}
|
||||
|
||||
func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
@@ -193,3 +203,8 @@ func (z *MoonshotModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *MoonshotModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
@@ -56,6 +56,26 @@ func (z *SiliconflowModel) Name() string {
|
||||
return "siliconflow"
|
||||
}
|
||||
|
||||
|
||||
// SiliconflowRerankRequest represents SILICONFLOW rerank request
|
||||
type SiliconflowRerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN int `json:"top_n"`
|
||||
ReturnDocuments bool `json:"return_documents"`
|
||||
MaxChunksPerDoc int `json:"max_chunks_per_doc"`
|
||||
OverlapTokens int `json:"overlap_tokens"`
|
||||
}
|
||||
|
||||
// SiliconflowRerankResponse represents SILICONFLOW rerank response
|
||||
type SiliconflowRerankResponse struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *SiliconflowModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if message == nil {
|
||||
@@ -363,8 +383,116 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName, message *string, ap
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
func (s *SiliconflowModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(s.BaseURL[region], "/"), s.URLSuffix.Embedding)
|
||||
|
||||
apiKey := ""
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil {
|
||||
apiKey = *apiConfig.ApiKey
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
for i, text := range texts {
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"input": text,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
data, ok := result["data"].([]interface{})
|
||||
if !ok || len(data) == 0 {
|
||||
return nil, fmt.Errorf("no data in response")
|
||||
}
|
||||
|
||||
firstData, ok := data[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid data format")
|
||||
}
|
||||
|
||||
embeddingSlice, ok := firstData["embedding"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid embedding format")
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddingSlice))
|
||||
for j, v := range embeddingSlice {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
embedding[j] = val
|
||||
case float32:
|
||||
embedding[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type")
|
||||
}
|
||||
}
|
||||
|
||||
embeddings[i] = embedding
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return s.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (s *SiliconflowModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := s.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
@@ -435,3 +563,74 @@ func (z *SiliconflowModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (s *SiliconflowModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return []float64{}, nil
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
apiKey := ""
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil {
|
||||
apiKey = *apiConfig.ApiKey
|
||||
}
|
||||
|
||||
reqBody := SiliconflowRerankRequest{
|
||||
Model: *modelName,
|
||||
Query: query,
|
||||
Documents: texts,
|
||||
TopN: len(texts),
|
||||
ReturnDocuments: false,
|
||||
MaxChunksPerDoc: 1024,
|
||||
OverlapTokens: 80,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(s.BaseURL[region], "/"), s.URLSuffix.Rerank)
|
||||
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var rerankResp SiliconflowRerankResponse
|
||||
if err := json.Unmarshal(body, &rerankResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
scores := make([]float64, len(texts))
|
||||
for _, result := range rerankResp.Results {
|
||||
if result.Index >= 0 && result.Index < len(texts) {
|
||||
scores[result.Index] = result.RelevanceScore
|
||||
}
|
||||
}
|
||||
|
||||
return scores, nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package models
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Message represents a chat message with role
|
||||
type Message struct {
|
||||
Role string
|
||||
@@ -16,8 +18,14 @@ type ModelDriver interface {
|
||||
ChatWithMessages(modelName string, apiKey *string, messages []Message, modelConfig *ChatConfig) (string, error)
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error
|
||||
// Encode encodes a list of texts into embeddings
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error)
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error)
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error)
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error)
|
||||
// List suppported models
|
||||
ListModels(apiConfig *APIConfig) ([]string, error)
|
||||
|
||||
@@ -64,3 +72,73 @@ type APIConfig struct {
|
||||
|
||||
type EmbeddingConfig struct {
|
||||
}
|
||||
|
||||
// EmbeddingModel wraps a ModelDriver with embedding-specific configuration
|
||||
type EmbeddingModel struct {
|
||||
ModelDriver ModelDriver
|
||||
ModelName string
|
||||
APIConfig *APIConfig
|
||||
}
|
||||
|
||||
// NewEmbeddingModel creates a new EmbeddingModel
|
||||
func NewEmbeddingModel(driver ModelDriver, modelName string, apiConfig *APIConfig) *EmbeddingModel {
|
||||
return &EmbeddingModel{
|
||||
ModelDriver: driver,
|
||||
ModelName: modelName,
|
||||
APIConfig: apiConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (e *EmbeddingModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return e.ModelDriver.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
func (e *EmbeddingModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := e.ModelDriver.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// RerankModel wraps a ModelDriver with rerank-specific configuration
|
||||
type RerankModel struct {
|
||||
ModelDriver ModelDriver
|
||||
ModelName string
|
||||
APIConfig *APIConfig
|
||||
}
|
||||
|
||||
// NewRerankModel creates a new RerankModel
|
||||
func NewRerankModel(driver ModelDriver, modelName string, apiConfig *APIConfig) *RerankModel {
|
||||
return &RerankModel{
|
||||
ModelDriver: driver,
|
||||
ModelName: modelName,
|
||||
APIConfig: apiConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Rerank calculates similarity between query and texts
|
||||
func (r *RerankModel) Rerank(query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return r.ModelDriver.Rerank(&r.ModelName, query, texts, apiConfig)
|
||||
}
|
||||
|
||||
// ChatModel wraps a ModelDriver with chat-specific configuration
|
||||
type ChatModel struct {
|
||||
ModelDriver ModelDriver
|
||||
ModelName string
|
||||
APIConfig *APIConfig
|
||||
}
|
||||
|
||||
// NewChatModel creates a new ChatModel
|
||||
func NewChatModel(driver ModelDriver, modelName string, apiConfig *APIConfig) *ChatModel {
|
||||
return &ChatModel{
|
||||
ModelDriver: driver,
|
||||
ModelName: modelName,
|
||||
APIConfig: apiConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,7 +292,7 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiCon
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region])
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(z.BaseURL[region], "/"), z.URLSuffix.Chat)
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
@@ -440,7 +440,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/embedding", z.BaseURL[region])
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(z.BaseURL[region], "/"), z.URLSuffix.Embedding)
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
@@ -518,6 +518,23 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiC
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings (convenience method)
|
||||
func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig) ([][]float64, error) {
|
||||
return z.EncodeToEmbedding(modelName, texts, apiConfig, nil)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding (convenience method)
|
||||
func (z *ZhipuAIModel) EncodeQuery(modelName *string, query string, apiConfig *APIConfig) ([]float64, error) {
|
||||
embeddings, err := z.Encode(modelName, []string{query}, apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
}
|
||||
@@ -559,3 +576,8 @@ func (z *ZhipuAIModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
func (z *ZhipuAIModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("%s, Rerank not implemented", z.Name())
|
||||
}
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// ModelType represents the type of model
|
||||
type ModelType string
|
||||
|
||||
@@ -39,9 +43,9 @@ const (
|
||||
// EmbeddingModel interface for embedding models
|
||||
type EmbeddingModel interface {
|
||||
// Encode encodes a list of texts into embeddings
|
||||
Encode(texts []string) ([][]float64, error)
|
||||
Encode(modelName *string, texts []string, apiConfig *models.APIConfig) ([][]float64, error)
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
EncodeQuery(query string) ([]float64, error)
|
||||
EncodeQuery(modelName *string, query string, apiConfig *models.APIConfig) ([]float64, error)
|
||||
}
|
||||
|
||||
// ChatModel interface for chat models
|
||||
@@ -54,8 +58,8 @@ type ChatModel interface {
|
||||
|
||||
// RerankModel interface for rerank models
|
||||
type RerankModel interface {
|
||||
// Similarity calculates similarity between query and texts
|
||||
Similarity(query string, texts []string) ([]float64, error)
|
||||
// Rerank calculates similarity between query and texts
|
||||
Rerank(query string, texts []string, apiConfig *models.APIConfig) ([]float64, error)
|
||||
}
|
||||
|
||||
// ModelConfig represents configuration for a model
|
||||
|
||||
@@ -607,6 +607,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
|
||||
}
|
||||
|
||||
modelName := c.Param("model_name")
|
||||
if modelName != "" {
|
||||
modelName = strings.TrimPrefix(modelName, "/")
|
||||
}
|
||||
if modelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
|
||||
@@ -217,7 +217,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
provider.PUT("/:provider_name/instances/:instance_name", r.providerHandler.AlterProviderInstance)
|
||||
provider.DELETE("/:provider_name/instances", r.providerHandler.DropProviderInstance)
|
||||
provider.GET("/:provider_name/instances/:instance_name/models", r.providerHandler.ListInstanceModels)
|
||||
provider.PATCH("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.EnableOrDisableModel)
|
||||
provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel)
|
||||
provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.ChatToModel)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/server"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -40,7 +41,6 @@ import (
|
||||
type ChunkService struct {
|
||||
docEngine engine.DocEngine
|
||||
engineType server.EngineType
|
||||
modelProvider ModelProvider
|
||||
embeddingCache *utility.EmbeddingLRU
|
||||
kbDAO *dao.KnowledgebaseDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
@@ -53,7 +53,6 @@ func NewChunkService() *ChunkService {
|
||||
return &ChunkService{
|
||||
docEngine: engine.Get(),
|
||||
engineType: cfg.DocEngine.Type,
|
||||
modelProvider: NewModelProvider(),
|
||||
embeddingCache: utility.NewEmbeddingLRU(1000), // default capacity
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
@@ -340,8 +339,8 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
|
||||
}
|
||||
|
||||
// Get embedding model for the tenant
|
||||
var embeddingModel entity.EmbeddingModel
|
||||
embeddingModel, err = s.modelProvider.GetEmbeddingModel(ctx, tenantIDs[0], embdID)
|
||||
modelProviderSvc := NewModelProviderService()
|
||||
embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
@@ -350,7 +349,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
|
||||
zap.String("embdID", embdID))
|
||||
|
||||
// Get rerank model if RerankID is specified
|
||||
var rerankModel nlp.RerankModel
|
||||
var rerankModel *models.RerankModel
|
||||
var rerankCompositeName string
|
||||
if req.TenantRerankID != nil && *req.TenantRerankID != "" {
|
||||
tenantRerankIDInt, parseErr := strconv.ParseInt(*req.TenantRerankID, 10, 64)
|
||||
@@ -361,19 +360,16 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
|
||||
}
|
||||
rerankModel, err = s.modelProvider.GetRerankModel(ctx, tenantIDs[0], rerankCompositeName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
|
||||
}
|
||||
} else if req.RerankID != nil && *req.RerankID != "" {
|
||||
var err error
|
||||
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
|
||||
}
|
||||
rerankModel, err = s.modelProvider.GetRerankModel(ctx, tenantIDs[0], rerankCompositeName)
|
||||
}
|
||||
if rerankCompositeName != "" {
|
||||
rerankModel, err = modelProviderSvc.GetRerankModel(tenantIDs[0], rerankCompositeName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
|
||||
return nil, fmt.Errorf("failed to get rerank model: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,26 +17,29 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// ModelBundle provides a unified interface for various model operations
|
||||
// Similar to Python's LLMBundle but with a more generic name
|
||||
type ModelBundle struct {
|
||||
tenantID string
|
||||
modelType entity.ModelType
|
||||
modelName string
|
||||
model interface{} // underlying model instance
|
||||
tenantID string
|
||||
modelType entity.ModelType
|
||||
modelName string
|
||||
model interface{} // underlying model instance
|
||||
apiConfig *modelModule.APIConfig
|
||||
embeddingConfig *modelModule.EmbeddingConfig
|
||||
}
|
||||
|
||||
// NewModelBundle creates a new ModelBundle for the given tenant and model type
|
||||
// If modelName is empty, uses the default model for the tenant and type
|
||||
func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...string) (*ModelBundle, error) {
|
||||
bundle := &ModelBundle{
|
||||
tenantID: tenantID,
|
||||
modelType: modelType,
|
||||
tenantID: tenantID,
|
||||
modelType: modelType,
|
||||
embeddingConfig: &modelModule.EmbeddingConfig{},
|
||||
}
|
||||
|
||||
// Use provided model name if available
|
||||
@@ -45,26 +48,29 @@ func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...st
|
||||
}
|
||||
|
||||
// Get model instance based on type
|
||||
provider := NewModelProvider()
|
||||
modelProviderSvc := NewModelProviderService()
|
||||
switch modelType {
|
||||
case entity.ModelTypeEmbedding:
|
||||
embeddingModel, err := provider.GetEmbeddingModel(context.Background(), tenantID, bundle.modelName)
|
||||
embd, err := modelProviderSvc.GetEmbeddingModel(tenantID, bundle.modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
bundle.model = embeddingModel
|
||||
bundle.model = embd.ModelDriver
|
||||
bundle.apiConfig = embd.APIConfig
|
||||
case entity.ModelTypeChat:
|
||||
chatModel, err := provider.GetChatModel(context.Background(), tenantID, bundle.modelName)
|
||||
chatMdl, err := modelProviderSvc.GetChatModel(tenantID, bundle.modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get chat model: %w", err)
|
||||
}
|
||||
bundle.model = chatModel
|
||||
bundle.model = chatMdl.ModelDriver
|
||||
bundle.apiConfig = chatMdl.APIConfig
|
||||
case entity.ModelTypeRerank:
|
||||
rerankModel, err := provider.GetRerankModel(context.Background(), tenantID, bundle.modelName)
|
||||
rerankMdl, err := modelProviderSvc.GetRerankModel(tenantID, bundle.modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model: %w", err)
|
||||
}
|
||||
bundle.model = rerankModel
|
||||
bundle.model = rerankMdl.ModelDriver
|
||||
bundle.apiConfig = rerankMdl.APIConfig
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported model type: %s", modelType)
|
||||
}
|
||||
@@ -84,7 +90,7 @@ func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
|
||||
return nil, 0, fmt.Errorf("model is not an embedding model")
|
||||
}
|
||||
|
||||
embeddings, err := embeddingModel.Encode(texts)
|
||||
embeddings, err := embeddingModel.Encode(&b.modelName, texts, b.apiConfig)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -111,7 +117,7 @@ func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
|
||||
return nil, 0, fmt.Errorf("model is not an embedding model")
|
||||
}
|
||||
|
||||
embedding, err := embeddingModel.EncodeQuery(query)
|
||||
embedding, err := embeddingModel.EncodeQuery(&b.modelName, query, b.apiConfig)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -144,10 +150,10 @@ func (b *ModelBundle) Chat(system string, history []map[string]string, genConf m
|
||||
return response, tokenCount, nil
|
||||
}
|
||||
|
||||
// Similarity calculates similarity between query and texts
|
||||
func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64, error) {
|
||||
// Rerank calculates similarity between query and texts
|
||||
func (b *ModelBundle) Rerank(query string, texts []string) ([]float64, int64, error) {
|
||||
if b.modelType != entity.ModelTypeRerank {
|
||||
return nil, 0, fmt.Errorf("model type %s does not support similarity", b.modelType)
|
||||
return nil, 0, fmt.Errorf("model type %s does not support rerank", b.modelType)
|
||||
}
|
||||
|
||||
rerankModel, ok := b.model.(entity.RerankModel)
|
||||
@@ -155,7 +161,7 @@ func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64
|
||||
return nil, 0, fmt.Errorf("model is not a rerank model")
|
||||
}
|
||||
|
||||
similarities, err := rerankModel.Similarity(query, texts)
|
||||
similarities, err := rerankModel.Rerank(query, texts, b.apiConfig)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -17,45 +17,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/service/models"
|
||||
)
|
||||
|
||||
// ModelProvider provides model instances based on tenant and model type
|
||||
type ModelProvider interface {
|
||||
// GetEmbeddingModel returns an embedding model for the given tenant
|
||||
GetEmbeddingModel(ctx context.Context, tenantID string, modelName string) (entity.EmbeddingModel, error)
|
||||
// GetChatModel returns a chat model for the given tenant
|
||||
GetChatModel(ctx context.Context, tenantID string, modelName string) (entity.ChatModel, error)
|
||||
// GetRerankModel returns a rerank model for the given tenant
|
||||
GetRerankModel(ctx context.Context, tenantID string, modelName string) (entity.RerankModel, error)
|
||||
}
|
||||
|
||||
// ModelProviderImpl implements ModelProvider
|
||||
type ModelProviderImpl struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewModelProvider creates a new ModelProvider
|
||||
func NewModelProvider() *ModelProviderImpl {
|
||||
return &ModelProviderImpl{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parseModelName parses a composite model name in format "model_name@provider"
|
||||
// Returns modelName and provider separately
|
||||
func parseModelName(compositeName string) (modelName, provider string, err error) {
|
||||
@@ -69,111 +41,6 @@ func parseModelName(compositeName string) (modelName, provider string, err error
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmbeddingModel returns an embedding model for the given tenant
|
||||
func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID string, compositeModelName string) (entity.EmbeddingModel, error) {
|
||||
// Parse composite model name to extract model name and provider
|
||||
modelName, provider, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get API key and configuration
|
||||
embeddingModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiKey := embeddingModel.APIKey
|
||||
if apiKey == nil || *apiKey == "" {
|
||||
return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName)
|
||||
}
|
||||
|
||||
// Get API base from TenantLLM if set, otherwise from model provider configuration
|
||||
apiBase := ""
|
||||
if embeddingModel.APIBase != nil && *embeddingModel.APIBase != "" {
|
||||
apiBase = *embeddingModel.APIBase
|
||||
} else {
|
||||
providerDAO := dao.NewModelProviderDAO()
|
||||
providerConfig := providerDAO.GetProviderByName(provider)
|
||||
if providerConfig == nil || providerConfig.DefaultURL == "" {
|
||||
return nil, fmt.Errorf("no API base found for provider %s", provider)
|
||||
}
|
||||
apiBase = providerConfig.DefaultURL
|
||||
}
|
||||
|
||||
return models.CreateEmbeddingModel(provider, *apiKey, apiBase, modelName, p.httpClient)
|
||||
}
|
||||
|
||||
// GetChatModel returns a chat model for the given tenant
|
||||
func (p *ModelProviderImpl) GetChatModel(ctx context.Context, tenantID string, compositeModelName string) (entity.ChatModel, error) {
|
||||
// Parse composite model name to extract model name and provider
|
||||
modelName, provider, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get chat model from database
|
||||
chatModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no chat model found for tenant %s and model %s: %w", tenantID, compositeModelName, err)
|
||||
}
|
||||
|
||||
apiKey := chatModel.APIKey
|
||||
if apiKey == nil || *apiKey == "" {
|
||||
return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName)
|
||||
}
|
||||
|
||||
// Get API base from TenantLLM if set, otherwise from model provider configuration
|
||||
apiBase := ""
|
||||
if chatModel.APIBase != nil && *chatModel.APIBase != "" {
|
||||
apiBase = *chatModel.APIBase
|
||||
} else {
|
||||
providerDAO := dao.NewModelProviderDAO()
|
||||
providerConfig := providerDAO.GetProviderByName(provider)
|
||||
if providerConfig == nil || providerConfig.DefaultURL == "" {
|
||||
return nil, fmt.Errorf("no API base found for provider %s", provider)
|
||||
}
|
||||
apiBase = providerConfig.DefaultURL
|
||||
}
|
||||
|
||||
return models.CreateChatModel(provider, *apiKey, apiBase, modelName, p.httpClient)
|
||||
}
|
||||
|
||||
// GetRerankModel returns a rerank model for the given tenant
|
||||
func (p *ModelProviderImpl) GetRerankModel(ctx context.Context, tenantID string, compositeModelName string) (entity.RerankModel, error) {
|
||||
// Parse composite model name to extract model name and provider
|
||||
modelName, provider, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get rerank model from database
|
||||
rerankModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no rerank model found for tenant %s and model %s: %w", tenantID, compositeModelName, err)
|
||||
}
|
||||
|
||||
apiKey := rerankModel.APIKey
|
||||
if apiKey == nil || *apiKey == "" {
|
||||
return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName)
|
||||
}
|
||||
|
||||
// Get API base from TenantLLM if set, otherwise from model provider configuration
|
||||
apiBase := ""
|
||||
if rerankModel.APIBase != nil && *rerankModel.APIBase != "" {
|
||||
apiBase = *rerankModel.APIBase
|
||||
} else {
|
||||
providerDAO := dao.NewModelProviderDAO()
|
||||
providerConfig := providerDAO.GetProviderByName(provider)
|
||||
if providerConfig == nil || providerConfig.DefaultURL == "" {
|
||||
return nil, fmt.Errorf("no API base found for provider %s", provider)
|
||||
}
|
||||
apiBase = providerConfig.DefaultURL
|
||||
}
|
||||
|
||||
return models.CreateRerankModel(provider, *apiKey, apiBase, modelName, p.httpClient)
|
||||
}
|
||||
|
||||
func NewModelProviderService() *ModelProviderService {
|
||||
return &ModelProviderService{
|
||||
modelProviderDAO: dao.NewTenantModelProviderDAO(),
|
||||
@@ -973,3 +840,94 @@ func (m *ModelProviderService) GetModelByName(modelName string, tenantID string)
|
||||
APIKey: *tenantLLM.APIKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetEmbeddingModel returns an EmbeddingModel wrapper for the given tenant
|
||||
func (m *ModelProviderService) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) {
|
||||
driver, modelName, apiConfig, err := m.getModelConfig(tenantID, compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return modelModule.NewEmbeddingModel(driver, modelName, apiConfig), nil
|
||||
}
|
||||
|
||||
// GetRerankModel returns a RerankModel wrapper for the given tenant
|
||||
func (m *ModelProviderService) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) {
|
||||
driver, modelName, apiConfig, err := m.getModelConfig(tenantID, compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return modelModule.NewRerankModel(driver, modelName, apiConfig), nil
|
||||
}
|
||||
|
||||
// GetChatModel returns a ChatModel wrapper for the given tenant
|
||||
func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) {
|
||||
driver, modelName, apiConfig, err := m.getModelConfig(tenantID, compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return modelModule.NewChatModel(driver, modelName, apiConfig), nil
|
||||
}
|
||||
|
||||
// getModelConfig returns the model driver, model name, and API config for a model
|
||||
func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, error) {
|
||||
modelName, providerName, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
// Check if provider exists
|
||||
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
if provider == nil {
|
||||
return nil, "", nil, fmt.Errorf("provider %s not found", providerName)
|
||||
}
|
||||
|
||||
instanceName := "default_instance"
|
||||
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
if instance == nil {
|
||||
return nil, "", nil, fmt.Errorf("instance %s not found for provider %s", instanceName, providerName)
|
||||
}
|
||||
|
||||
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName)
|
||||
if err != nil {
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
if providerInfo == nil {
|
||||
return nil, "", nil, fmt.Errorf("provider %s not found", providerName)
|
||||
}
|
||||
|
||||
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
return nil, "", nil, fmt.Errorf("provider %s model %s not found", providerName, modelName)
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
region := extra["region"]
|
||||
|
||||
apiConfig := &modelModule.APIConfig{ApiKey: &instance.APIKey, Region: ®ion}
|
||||
return providerInfo.ModelDriver, modelName, apiConfig, nil
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
region := extra["region"]
|
||||
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
if providerInfo == nil {
|
||||
return nil, "", nil, fmt.Errorf("provider %s not found", providerName)
|
||||
}
|
||||
|
||||
apiConfig := &modelModule.APIConfig{ApiKey: &instance.APIKey, Region: ®ion}
|
||||
return providerInfo.ModelDriver, modelName, apiConfig, nil
|
||||
}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("DeepSeek", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"sync"
|
||||
)
|
||||
|
||||
// EmbeddingModelFactory creates an EmbeddingModel instance
|
||||
type EmbeddingModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel
|
||||
|
||||
// ChatModelFactory creates a ChatModel instance
|
||||
type ChatModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.ChatModel
|
||||
|
||||
// RerankModelFactory creates a RerankModel instance
|
||||
type RerankModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.RerankModel
|
||||
|
||||
var (
|
||||
embeddingModelFactories = make(map[string]EmbeddingModelFactory)
|
||||
chatModelFactories = make(map[string]ChatModelFactory)
|
||||
rerankModelFactories = make(map[string]RerankModelFactory)
|
||||
factoryMu sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterEmbeddingModelFactory registers a factory for a provider name.
|
||||
// Should be called from init() functions of provider implementations.
|
||||
func RegisterEmbeddingModelFactory(providerName string, factory EmbeddingModelFactory) {
|
||||
factoryMu.Lock()
|
||||
defer factoryMu.Unlock()
|
||||
embeddingModelFactories[providerName] = factory
|
||||
}
|
||||
|
||||
// RegisterChatModelFactory registers a factory for a chat provider name.
|
||||
// Should be called from init() functions of provider implementations.
|
||||
func RegisterChatModelFactory(providerName string, factory ChatModelFactory) {
|
||||
factoryMu.Lock()
|
||||
defer factoryMu.Unlock()
|
||||
chatModelFactories[providerName] = factory
|
||||
}
|
||||
|
||||
// RegisterRerankModelFactory registers a factory for a rerank provider name.
|
||||
// Should be called from init() functions of provider implementations.
|
||||
func RegisterRerankModelFactory(providerName string, factory RerankModelFactory) {
|
||||
factoryMu.Lock()
|
||||
defer factoryMu.Unlock()
|
||||
rerankModelFactories[providerName] = factory
|
||||
}
|
||||
|
||||
// GetEmbeddingModelFactory returns the factory for the given provider name.
|
||||
// Returns nil if not found.
|
||||
func GetEmbeddingModelFactory(providerName string) EmbeddingModelFactory {
|
||||
factoryMu.RLock()
|
||||
defer factoryMu.RUnlock()
|
||||
return embeddingModelFactories[providerName]
|
||||
}
|
||||
|
||||
// GetChatModelFactory returns the factory for the given chat provider name.
|
||||
// Returns nil if not found.
|
||||
func GetChatModelFactory(providerName string) ChatModelFactory {
|
||||
factoryMu.RLock()
|
||||
defer factoryMu.RUnlock()
|
||||
return chatModelFactories[providerName]
|
||||
}
|
||||
|
||||
// GetRerankModelFactory returns the factory for the given rerank provider name.
|
||||
// Returns nil if not found.
|
||||
func GetRerankModelFactory(providerName string) RerankModelFactory {
|
||||
factoryMu.RLock()
|
||||
defer factoryMu.RUnlock()
|
||||
return rerankModelFactories[providerName]
|
||||
}
|
||||
|
||||
// CreateEmbeddingModel creates an EmbeddingModel instance for the given provider.
|
||||
// Returns error if provider not registered.
|
||||
func CreateEmbeddingModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.EmbeddingModel, error) {
|
||||
factory := GetEmbeddingModelFactory(providerName)
|
||||
if factory == nil {
|
||||
return nil, fmt.Errorf("no embedding model factory registered for provider %s", providerName)
|
||||
}
|
||||
return factory(apiKey, apiBase, modelName, httpClient), nil
|
||||
}
|
||||
|
||||
// CreateChatModel creates a ChatModel instance for the given provider.
|
||||
// Returns error if provider not registered.
|
||||
func CreateChatModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.ChatModel, error) {
|
||||
factory := GetChatModelFactory(providerName)
|
||||
if factory == nil {
|
||||
return nil, fmt.Errorf("no chat model factory registered for provider %s", providerName)
|
||||
}
|
||||
return factory(apiKey, apiBase, modelName, httpClient), nil
|
||||
}
|
||||
|
||||
// CreateRerankModel creates a RerankModel instance for the given provider.
|
||||
// Returns error if provider not registered.
|
||||
func CreateRerankModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.RerankModel, error) {
|
||||
factory := GetRerankModelFactory(providerName)
|
||||
if factory == nil {
|
||||
return nil, fmt.Errorf("no rerank model factory registered for provider %s", providerName)
|
||||
}
|
||||
return factory(apiKey, apiBase, modelName, httpClient), nil
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
// giteeEmbeddingModel implements EmbeddingModel for GiteeAI API (assumed OpenAI-compatible)
|
||||
type giteeEmbeddingModel struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// GiteeEmbeddingRequest represents GiteeAI embedding request
|
||||
type GiteeEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
EncodeFormat string `json:"encode_format"`
|
||||
}
|
||||
|
||||
// GiteeEmbeddingResponse represents GiteeAI embedding response
|
||||
type GiteeEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings using GiteeAI API
|
||||
func (m *giteeEmbeddingModel) Encode(texts []string) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
reqBody := GiteeEmbeddingRequest{
|
||||
Model: m.model,
|
||||
Input: texts,
|
||||
EncodeFormat: "float",
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", m.apiBase, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+m.apiKey)
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("GiteeAI API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var embeddingResp GiteeEmbeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Sort embeddings by index to ensure correct order
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, data := range embeddingResp.Data {
|
||||
if data.Index < len(embeddings) {
|
||||
embeddings[data.Index] = data.Embedding
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
func (m *giteeEmbeddingModel) EncodeQuery(query string) ([]float64, error) {
|
||||
embeddings, err := m.Encode([]string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// init registers the GiteeAI embedding model factory
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("GiteeAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &giteeEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("Moonshot", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("OpenAI-API-Compatible", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
// openAIEmbeddingModel implements EmbeddingModel for OpenAI API
|
||||
type openAIEmbeddingModel struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// OpenAIEmbeddingRequest represents OpenAI embedding request
|
||||
type OpenAIEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
// OpenAIEmbeddingResponse represents OpenAI embedding response
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings using OpenAI API
|
||||
func (m *openAIEmbeddingModel) Encode(texts []string) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
reqBody := OpenAIEmbeddingRequest{
|
||||
Model: m.model,
|
||||
Input: texts,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", m.apiBase+"/embeddings", strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+m.apiKey)
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("OpenAI API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var embeddingResp OpenAIEmbeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Sort embeddings by index to ensure correct order
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, data := range embeddingResp.Data {
|
||||
if data.Index < len(embeddings) {
|
||||
embeddings[data.Index] = data.Embedding
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
func (m *openAIEmbeddingModel) EncodeQuery(query string) ([]float64, error) {
|
||||
embeddings, err := m.Encode([]string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// init registers the OpenAI embedding model factory
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("OpenAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,380 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
// siliconflowEmbeddingModel implements EmbeddingModel for SILICONFLOW API (OpenAI-compatible)
|
||||
type siliconflowEmbeddingModel struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// siliconflowChatModel implements ChatModel for SILICONFLOW API
|
||||
type siliconflowChatModel struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// siliconflowRerankModel implements RerankModel for SILICONFLOW API
|
||||
type siliconflowRerankModel struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// SiliconflowEmbeddingRequest represents SILICONFLOW embedding request
|
||||
type SiliconflowEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
// SiliconflowEmbeddingResponse represents SILICONFLOW embedding response
|
||||
type SiliconflowEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// SiliconflowChatRequest represents SILICONFLOW chat request
|
||||
type SiliconflowChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// SiliconflowChatResponse represents SILICONFLOW chat response
|
||||
type SiliconflowChatResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a chat message
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// SiliconflowRerankRequest represents SILICONFLOW rerank request
|
||||
type SiliconflowRerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN int `json:"top_n"`
|
||||
ReturnDocuments bool `json:"return_documents"`
|
||||
MaxChunksPerDoc int `json:"max_chunks_per_doc"`
|
||||
OverlapTokens int `json:"overlap_tokens"`
|
||||
}
|
||||
|
||||
// SiliconflowRerankResponse represents SILICONFLOW rerank response
|
||||
type SiliconflowRerankResponse struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings using SILICONFLOW API
|
||||
func (m *siliconflowEmbeddingModel) Encode(texts []string) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
reqBody := SiliconflowEmbeddingRequest{
|
||||
Model: m.model,
|
||||
Input: texts,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", m.apiBase+"/embeddings", strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+m.apiKey)
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var embeddingResp SiliconflowEmbeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Sort embeddings by index to ensure correct order
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, data := range embeddingResp.Data {
|
||||
if data.Index < len(embeddings) {
|
||||
embeddings[data.Index] = data.Embedding
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
func (m *siliconflowEmbeddingModel) EncodeQuery(query string) ([]float64, error) {
|
||||
embeddings, err := m.Encode([]string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
// Chat sends a chat message and returns response
|
||||
func (m *siliconflowChatModel) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, error) {
|
||||
// Build messages array
|
||||
var messages []ChatMessage
|
||||
|
||||
// Add system message if provided
|
||||
if system != "" {
|
||||
messages = append(messages, ChatMessage{Role: "system", Content: system})
|
||||
}
|
||||
|
||||
// Add history messages
|
||||
for _, msg := range history {
|
||||
role := msg["role"]
|
||||
content := msg["content"]
|
||||
if role != "" && content != "" {
|
||||
messages = append(messages, ChatMessage{Role: role, Content: content})
|
||||
}
|
||||
}
|
||||
|
||||
// Extract generation config
|
||||
temperature := 0.7
|
||||
if temp, ok := genConf["temperature"].(float64); ok {
|
||||
temperature = temp
|
||||
}
|
||||
maxTokens := 1024
|
||||
if mt, ok := genConf["max_tokens"].(int); ok {
|
||||
maxTokens = mt
|
||||
}
|
||||
|
||||
// Build request
|
||||
reqBody := SiliconflowChatRequest{
|
||||
Model: m.model,
|
||||
Messages: messages,
|
||||
Temperature: temperature,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
// Build URL - append /chat/completions if not already present
|
||||
url := m.apiBase
|
||||
if !strings.HasSuffix(url, "/chat/completions") {
|
||||
if !strings.HasSuffix(url, "/") {
|
||||
url += "/"
|
||||
}
|
||||
url += "chat/completions"
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+m.apiKey)
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var chatResp SiliconflowChatResponse
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if chatResp.Error.Message != "" {
|
||||
return "", fmt.Errorf("chat error: %s", chatResp.Error.Message)
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response choices returned")
|
||||
}
|
||||
|
||||
return chatResp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// ChatStreamly sends a chat message and streams response
|
||||
func (m *siliconflowChatModel) ChatStreamly(system string, history []map[string]string, genConf map[string]interface{}) (<-chan string, error) {
|
||||
// For now, return a simple non-streaming implementation
|
||||
// Streaming can be implemented later with SSE support
|
||||
responseChan := make(chan string)
|
||||
|
||||
go func() {
|
||||
defer close(responseChan)
|
||||
response, err := m.Chat(system, history, genConf)
|
||||
if err != nil {
|
||||
responseChan <- "**ERROR**: " + err.Error()
|
||||
return
|
||||
}
|
||||
responseChan <- response
|
||||
}()
|
||||
|
||||
return responseChan, nil
|
||||
}
|
||||
|
||||
// Similarity calculates similarity scores between query and texts using SiliconFlow API
|
||||
func (m *siliconflowRerankModel) Similarity(query string, texts []string) ([]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return []float64{}, nil
|
||||
}
|
||||
|
||||
reqBody := SiliconflowRerankRequest{
|
||||
Model: m.model,
|
||||
Query: query,
|
||||
Documents: texts,
|
||||
TopN: len(texts),
|
||||
ReturnDocuments: false,
|
||||
MaxChunksPerDoc: 1024,
|
||||
OverlapTokens: 80,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
reqURL := m.apiBase
|
||||
if !strings.Contains(reqURL, "/rerank") {
|
||||
if !strings.HasSuffix(reqURL, "/") {
|
||||
reqURL += "/"
|
||||
}
|
||||
reqURL += "rerank"
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", reqURL, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+m.apiKey)
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("SiliconFlow Rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var rerankResp SiliconflowRerankResponse
|
||||
if err := json.Unmarshal(body, &rerankResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
scores := make([]float64, len(texts))
|
||||
for _, result := range rerankResp.Results {
|
||||
if result.Index >= 0 && result.Index < len(texts) {
|
||||
scores[result.Index] = result.RelevanceScore
|
||||
}
|
||||
}
|
||||
|
||||
return scores, nil
|
||||
}
|
||||
|
||||
// init registers the SILICONFLOW model factories
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &siliconflowEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
|
||||
RegisterChatModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.ChatModel {
|
||||
return &siliconflowChatModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
|
||||
RegisterRerankModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.RerankModel {
|
||||
return &siliconflowRerankModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("ZHIPU-AI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
model: modelName,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -23,18 +23,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/logger"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RerankModel defines the interface for reranker models
|
||||
// This matches model.RerankModel interface
|
||||
type RerankModel interface {
|
||||
// Similarity calculates similarity between query and texts
|
||||
Similarity(query string, texts []string) ([]float64, error)
|
||||
}
|
||||
|
||||
// SearchResult represents the result of a search operation
|
||||
type SearchResult struct {
|
||||
Total int
|
||||
@@ -60,7 +54,7 @@ type SearchResult struct {
|
||||
// - tsim: token similarity scores
|
||||
// - vsim: vector similarity scores
|
||||
func Rerank(
|
||||
rerankModel RerankModel,
|
||||
rerankModel *models.RerankModel,
|
||||
chunks []map[string]interface{},
|
||||
total int,
|
||||
keywords []string,
|
||||
@@ -94,7 +88,7 @@ func Rerank(
|
||||
|
||||
// RerankByModel performs reranking using a reranker model
|
||||
func RerankByModel(
|
||||
rerankModel RerankModel,
|
||||
rerankModel *models.RerankModel,
|
||||
chunks []map[string]interface{},
|
||||
query string,
|
||||
tkWeight, vtWeight float64,
|
||||
@@ -142,9 +136,9 @@ func RerankByModel(
|
||||
tsim = TokenSimilarity(keywords, insTw, qb)
|
||||
|
||||
// Get similarity scores from reranker model
|
||||
modelSim, err := rerankModel.Similarity(query, docs)
|
||||
modelSim, err := rerankModel.ModelDriver.Rerank(&rerankModel.ModelName, query, docs, rerankModel.APIConfig)
|
||||
if err != nil {
|
||||
logger.Error("RerankByModel: rerankModel.Similarity failed; falling back to token-only similarity", err)
|
||||
logger.Error("RerankByModel: rerankModel.Rerank failed; falling back to token-only similarity", err)
|
||||
// If model fails, fall back to token similarity only
|
||||
modelSim = make([]float64, len(tsim))
|
||||
}
|
||||
|
||||
@@ -20,13 +20,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/logger"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/tokenizer"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -54,8 +54,8 @@ type RetrievalRequest struct {
|
||||
SimilarityThreshold *float64
|
||||
VectorSimilarityWeight *float64
|
||||
RankFeature *map[string]float64
|
||||
RerankModel RerankModel
|
||||
EmbeddingModel entity.EmbeddingModel
|
||||
RerankModel *models.RerankModel
|
||||
EmbeddingModel *models.EmbeddingModel
|
||||
Aggs *bool
|
||||
Highlight *bool
|
||||
}
|
||||
@@ -384,7 +384,7 @@ type RetrievalSearchRequest struct {
|
||||
SimilarityThreshold float64
|
||||
RankFeature map[string]float64
|
||||
Filter map[string]interface{}
|
||||
EmbeddingModel interface{}
|
||||
EmbeddingModel *models.EmbeddingModel
|
||||
}
|
||||
|
||||
type RetrievalSearchResult struct {
|
||||
@@ -489,7 +489,7 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque
|
||||
if similarityForGetVector <= 0 {
|
||||
similarityForGetVector = 0.1
|
||||
}
|
||||
matchDense, err := s.GetVector(req.Question, req.EmbeddingModel.(entity.EmbeddingModel), topk, similarityForGetVector)
|
||||
matchDense, err := s.GetVector(req.Question, req.EmbeddingModel, topk, similarityForGetVector)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetVector failed: %w", err)
|
||||
}
|
||||
@@ -596,8 +596,8 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque
|
||||
}
|
||||
|
||||
// GetVector computes query vector and returns MatchDenseExpr for hybrid search
|
||||
func (s *RetrievalService) GetVector(txt string, embModel entity.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) {
|
||||
vector, err := embModel.EncodeQuery(txt)
|
||||
func (s *RetrievalService) GetVector(txt string, embModel *models.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) {
|
||||
vector, err := embModel.ModelDriver.EncodeQuery(&embModel.ModelName, txt, embModel.APIConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user