mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-api): add Go support for POST /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks (#16256)
## Summary
Add the Go implementation of `POST
/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks`.
This wires the full create-chunk path in Go:
- router and handler registration
- request/response structs
- chunk creation service logic
- embedding generation
- chunk insert into doc engine
- chunk/token counter increment
- `tag_feas` validation
- `image_base64` decoding and chunk image storage/merge
- unit tests for handler and service
## Testing
Unit tests:
- `/usr/local/go/bin/go test ./internal/handler`
- `/usr/local/go/bin/go test ./internal/service/chunk`
- `/usr/local/go/bin/go test ./internal/service`
- `/usr/local/go/bin/go test ./...`
All passed locally.
Manual curl checks:
- basic text chunk: Go passed
- chunk with `important_keywords` / `questions` / `tag_kwd` /
`tag_feas`: Go passed
- blank content validation: Go matched expected `code=102`
- invalid `image_base64` validation: Go matched expected `code=102`
- image upload and repeated image upload / merge path: Go passed twice
This commit is contained in:
@@ -17,6 +17,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"ragflow/internal/common"
|
"ragflow/internal/common"
|
||||||
@@ -38,6 +39,7 @@ type chunkService interface {
|
|||||||
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
|
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
|
||||||
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
|
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
|
||||||
Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error)
|
Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error)
|
||||||
|
AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error)
|
||||||
StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
|
StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -788,3 +790,131 @@ func (h *ChunkHandler) RemoveChunks(c *gin.Context) {
|
|||||||
"message": "success",
|
"message": "success",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addChunkStringField reads an optional string-valued key from the raw
// add-chunk request body.
//
// It returns "" with a nil error when the key is absent, and an error of
// the form "`<field>` must be a string" when the stored JSON value cannot
// be decoded into a Go string.
func addChunkStringField(rawBody map[string]json.RawMessage, field string) (string, error) {
	var text string
	if payload, present := rawBody[field]; present {
		if json.Unmarshal(payload, &text) != nil {
			return "", fmt.Errorf("`%s` must be a string", field)
		}
	}
	return text, nil
}
|
||||||
|
|
||||||
|
// addChunkStringPtrField reads an optional string-valued key from the raw
// add-chunk request body, preserving the difference between "key absent"
// and "key present".
//
// It returns a nil pointer when the key is absent, a pointer to the decoded
// string when it is present, and an error of the form
// "`<field>` must be a string" when the JSON value cannot be decoded into a
// Go string.
func addChunkStringPtrField(rawBody map[string]json.RawMessage, field string) (*string, error) {
	payload, present := rawBody[field]
	if !present {
		return nil, nil
	}
	text := new(string)
	if json.Unmarshal(payload, text) != nil {
		return nil, fmt.Errorf("`%s` must be a string", field)
	}
	return text, nil
}
|
||||||
|
|
||||||
|
// addChunkStringListField reads an optional list-of-strings key from the raw
// add-chunk request body.
//
// It returns a nil slice when the key is absent. When the JSON value cannot
// be decoded as an array the error text is listMessage; when any element of
// the array is not a string the error text is elementMessage. On success the
// returned slice is non-nil and holds the elements in request order.
func addChunkStringListField(rawBody map[string]json.RawMessage, field, listMessage, elementMessage string) ([]string, error) {
	payload, present := rawBody[field]
	if !present {
		return nil, nil
	}

	var items []interface{}
	if json.Unmarshal(payload, &items) != nil {
		return nil, errors.New(listMessage)
	}

	out := make([]string, 0, len(items))
	for _, item := range items {
		text, isString := item.(string)
		if !isString {
			return nil, errors.New(elementMessage)
		}
		out = append(out, text)
	}
	return out, nil
}
|
||||||
|
|
||||||
|
func addChunkResponseMessage(code common.ErrorCode, err error) string {
|
||||||
|
if code == common.CodeServerError {
|
||||||
|
common.Warn("add chunk failed", zap.String("error", err.Error()))
|
||||||
|
return "Failed to add chunk"
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ChunkHandler) AddChunk(c *gin.Context) {
|
||||||
|
user, errorCode, errorMessage := GetUser(c)
|
||||||
|
if errorCode != common.CodeSuccess {
|
||||||
|
jsonError(c, errorCode, errorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := user.ID
|
||||||
|
datasetID, documentID := strings.TrimSpace(c.Param("dataset_id")), strings.TrimSpace(c.Param("document_id"))
|
||||||
|
|
||||||
|
var rawBody map[string]json.RawMessage
|
||||||
|
if err := json.NewDecoder(c.Request.Body).Decode(&rawBody); err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content, err := addChunkStringField(rawBody, "content")
|
||||||
|
if err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
importantKeywords, err := addChunkStringListField(rawBody, "important_keywords", "`important_keywords` is required to be a list", "`important_keywords` must be a list of strings")
|
||||||
|
if err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
questions, err := addChunkStringListField(rawBody, "questions", "`questions` is required to be a list", "`questions` must be a list of strings")
|
||||||
|
if err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tagKwd, err := addChunkStringListField(rawBody, "tag_kwd", "`tag_kwd` is required to be a list", "`tag_kwd` must be a list of strings")
|
||||||
|
if err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
imageBase64, err := addChunkStringPtrField(rawBody, "image_base64")
|
||||||
|
if err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var tagFeas interface{}
|
||||||
|
if raw, ok := rawBody["tag_feas"]; ok {
|
||||||
|
if err := json.Unmarshal(raw, &tagFeas); err != nil {
|
||||||
|
jsonError(c, common.CodeArgumentError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req := service.AddChunkRequest{
|
||||||
|
DatasetID: datasetID,
|
||||||
|
DocumentID: documentID,
|
||||||
|
Content: content,
|
||||||
|
ImportantKeywords: importantKeywords,
|
||||||
|
Questions: questions,
|
||||||
|
TagKwd: tagKwd,
|
||||||
|
TagFeas: tagFeas,
|
||||||
|
ImageBase64: imageBase64,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.chunkService.AddChunk(&req, userID)
|
||||||
|
if err != nil {
|
||||||
|
if codedErr, ok := err.(service.ErrorCoder); ok {
|
||||||
|
jsonError(c, codedErr.Code(), addChunkResponseMessage(codedErr.Code(), err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jsonError(c, common.CodeServerError, addChunkResponseMessage(common.CodeServerError, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": 0,
|
||||||
|
"data": resp,
|
||||||
|
"message": "success",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
// Only the methods actually called by the test are set; others panic.
|
// Only the methods actually called by the test are set; others panic.
|
||||||
type mockChunkSvc struct {
|
type mockChunkSvc struct {
|
||||||
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||||
|
addChunkFn func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error)
|
||||||
listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
||||||
switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
|
switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
|
||||||
updateChunkFn func(req *service.UpdateChunkRequest, userID string) error
|
updateChunkFn func(req *service.UpdateChunkRequest, userID string) error
|
||||||
@@ -68,6 +69,12 @@ func (m *mockChunkSvc) StopParsing(userID, datasetID string, req service.StopPar
|
|||||||
func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) {
|
func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
func (m *mockChunkSvc) AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
if m.addChunkFn != nil {
|
||||||
|
return m.addChunkFn(req, userID)
|
||||||
|
}
|
||||||
|
return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1"}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) {
|
func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) {
|
||||||
mock := &mockChunkSvc{}
|
mock := &mockChunkSvc{}
|
||||||
@@ -563,3 +570,194 @@ func TestChunkRetrieval_ServiceError(t *testing.T) {
|
|||||||
t.Errorf("internal error details leaked to response: %q", msg)
|
t.Errorf("internal error details leaked to response: %q", msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type addChunkTestError struct {
|
||||||
|
code common.ErrorCode
|
||||||
|
msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e addChunkTestError) Error() string { return e.msg }
|
||||||
|
func (e addChunkTestError) Code() common.ErrorCode { return e.code }
|
||||||
|
|
||||||
|
func TestChunkHandlerAddChunkSuccess(t *testing.T) {
|
||||||
|
mock := &mockChunkSvc{
|
||||||
|
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
if userID != "user1" {
|
||||||
|
t.Fatalf("userID = %q, want user1", userID)
|
||||||
|
}
|
||||||
|
if req.DatasetID != "kb1" || req.DocumentID != "doc1" || req.Content != "chunk body" {
|
||||||
|
t.Fatalf("unexpected request: %#v", req)
|
||||||
|
}
|
||||||
|
return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1", "content": req.Content}}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set("user", &entity.User{ID: "user1"})
|
||||||
|
})
|
||||||
|
h := &ChunkHandler{chunkService: mock}
|
||||||
|
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":"chunk body"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp["code"] != float64(common.CodeSuccess) {
|
||||||
|
t.Fatalf("expected success code, got %v", resp["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkHandlerAddChunkPathIDsOverrideBody(t *testing.T) {
|
||||||
|
mock := &mockChunkSvc{
|
||||||
|
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
if req.DatasetID != "kb1" || req.DocumentID != "doc1" {
|
||||||
|
t.Fatalf("path IDs were not preserved: %#v", req)
|
||||||
|
}
|
||||||
|
if req.Content != "chunk body" {
|
||||||
|
t.Fatalf("unexpected content: %#v", req)
|
||||||
|
}
|
||||||
|
return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1"}}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set("user", &entity.User{ID: "user1"})
|
||||||
|
})
|
||||||
|
h := &ChunkHandler{chunkService: mock}
|
||||||
|
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"dataset_id":"evil-kb","document_id":"evil-doc","content":"chunk body"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkHandlerAddChunkCodedError(t *testing.T) {
|
||||||
|
mock := &mockChunkSvc{
|
||||||
|
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
return nil, addChunkTestError{code: common.CodeDataError, msg: "`content` is required"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set("user", &entity.User{ID: "user1"})
|
||||||
|
})
|
||||||
|
h := &ChunkHandler{chunkService: mock}
|
||||||
|
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":" "}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp["code"] != float64(common.CodeDataError) {
|
||||||
|
t.Fatalf("expected data error code, got %v", resp["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkHandlerAddChunkValidatesListFields(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "important keywords type",
|
||||||
|
body: `{"content":"chunk body","important_keywords":{}}`,
|
||||||
|
wantMsg: "`important_keywords` is required to be a list",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tag kwd element type",
|
||||||
|
body: `{"content":"chunk body","tag_kwd":[1]}`,
|
||||||
|
wantMsg: "`tag_kwd` must be a list of strings",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mock := &mockChunkSvc{
|
||||||
|
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
t.Fatal("service should not be called for invalid request")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set("user", &entity.User{ID: "user1"})
|
||||||
|
})
|
||||||
|
h := &ChunkHandler{chunkService: mock}
|
||||||
|
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(tt.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp["message"] != tt.wantMsg {
|
||||||
|
t.Fatalf("message = %v, want %q", resp["message"], tt.wantMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkHandlerAddChunkHidesServerErrorDetails(t *testing.T) {
|
||||||
|
mock := &mockChunkSvc{
|
||||||
|
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
return nil, addChunkTestError{code: common.CodeServerError, msg: "encode chunk embedding: provider secret"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set("user", &entity.User{ID: "user1"})
|
||||||
|
})
|
||||||
|
h := &ChunkHandler{chunkService: mock}
|
||||||
|
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":"chunk body"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp["message"] != "Failed to add chunk" {
|
||||||
|
t.Fatalf("message = %v, want generic failure", resp["message"])
|
||||||
|
}
|
||||||
|
if strings.Contains(w.Body.String(), "provider secret") {
|
||||||
|
t.Fatalf("server error details leaked: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -367,9 +367,6 @@ func TestEngine_ChainOf100(t *testing.T) {
|
|||||||
if m == nil {
|
if m == nil {
|
||||||
m = map[string]any{}
|
m = map[string]any{}
|
||||||
}
|
}
|
||||||
if m == nil {
|
|
||||||
m = map[string]any{}
|
|
||||||
}
|
|
||||||
m["value"] = i
|
m["value"] = i
|
||||||
return m, nil
|
return m, nil
|
||||||
})
|
})
|
||||||
@@ -380,14 +377,16 @@ func TestEngine_ChainOf100(t *testing.T) {
|
|||||||
|
|
||||||
engine := NewEngine(sg,
|
engine := NewEngine(sg,
|
||||||
WithRecursionLimit(150),
|
WithRecursionLimit(150),
|
||||||
WithMaxConcurrency(4),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result, err := engine.RunSync(context.Background(), map[string]any{"value": "start"})
|
result, err := engine.RunSync(context.Background(), map[string]any{"value": "start"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("RunSync: %v", err)
|
t.Fatalf("RunSync: %v", err)
|
||||||
}
|
}
|
||||||
m := result.(map[string]any)
|
m, ok := result.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected map[string]any result, got %T", result)
|
||||||
|
}
|
||||||
if v, ok := m["value"]; !ok || v.(int) != 99 {
|
if v, ok := m["value"]; !ok || v.(int) != 99 {
|
||||||
t.Fatalf("expected value=99, got %v", m["value"])
|
t.Fatalf("expected value=99, got %v", m["value"])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -310,6 +310,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
|||||||
datasets.GET("/:dataset_id/documents/:document_id", r.documentHandler.DownloadDocument)
|
datasets.GET("/:dataset_id/documents/:document_id", r.documentHandler.DownloadDocument)
|
||||||
datasets.PATCH("/:dataset_id/documents/:document_id", r.documentHandler.UpdateDatasetDocument)
|
datasets.PATCH("/:dataset_id/documents/:document_id", r.documentHandler.UpdateDatasetDocument)
|
||||||
datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments)
|
datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments)
|
||||||
|
datasets.POST("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.AddChunk)
|
||||||
|
|
||||||
// Dataset document chunk
|
// Dataset document chunk
|
||||||
datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunks)
|
datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunks)
|
||||||
|
|||||||
@@ -20,11 +20,17 @@ import (
|
|||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/draw"
|
||||||
|
"image/jpeg"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"ragflow/internal/common"
|
"ragflow/internal/common"
|
||||||
@@ -36,10 +42,15 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
_ "image/gif"
|
||||||
|
_ "image/png"
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"ragflow/internal/dao"
|
"ragflow/internal/dao"
|
||||||
"ragflow/internal/engine"
|
"ragflow/internal/engine"
|
||||||
@@ -56,6 +67,16 @@ const (
|
|||||||
maximumTaskPageNumber = maximumPageNumber * 1000
|
maximumTaskPageNumber = maximumPageNumber * 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// chunkImageMergeLock is one entry of chunkImageMergeLocks: a mutex plus a
// counter.
// NOTE(review): refs presumably counts the goroutines currently holding or
// waiting on mu so the entry can be dropped when it reaches zero — confirm
// against the code that uses it.
type chunkImageMergeLock struct {
	mu   sync.Mutex
	refs int
}

// chunkImageMergeLocks is a package-level table of chunkImageMergeLock
// entries keyed by string. The embedded sync.Mutex is available to guard
// access to the locks map.
// NOTE(review): the key is presumably the stored chunk image identifier —
// confirm against the code that uses it.
var chunkImageMergeLocks = struct {
	sync.Mutex
	locks map[string]*chunkImageMergeLock
}{
	locks: map[string]*chunkImageMergeLock{},
}
|
||||||
|
|
||||||
// ChunkService chunk service
|
// ChunkService chunk service
|
||||||
type ChunkService struct {
|
type ChunkService struct {
|
||||||
docEngine engine.DocEngine
|
docEngine engine.DocEngine
|
||||||
@@ -74,6 +95,12 @@ type ChunkService struct {
|
|||||||
queueParseTasksFunc func(*entity.Document, string, string, int64) error
|
queueParseTasksFunc func(*entity.Document, string, string, int64) error
|
||||||
beginParseDocumentFunc func(string) error
|
beginParseDocumentFunc func(string) error
|
||||||
deleteTasksByDocIDsFunc func([]string) (int64, error)
|
deleteTasksByDocIDsFunc func([]string) (int64, error)
|
||||||
|
getEmbeddingModelFunc func(string, string) (*models.EmbeddingModel, error)
|
||||||
|
incrementChunkStatsFunc func(string, string, int64, int64, float64) error
|
||||||
|
storeChunkImageFunc func(string, string, []byte) error
|
||||||
|
tokenizeFunc func(string) (string, error)
|
||||||
|
fineGrainedTokenizeFunc func(string) (string, error)
|
||||||
|
numTokensFunc func(string) int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChunkService creates chunk service
|
// NewChunkService creates chunk service
|
||||||
@@ -1711,3 +1738,382 @@ func (s *ChunkService) RemoveChunks(req *service.RemoveChunksRequest, userID str
|
|||||||
|
|
||||||
return deletedCount, nil
|
return deletedCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ChunkService) AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||||
|
if s.docEngine == nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: "doc engine not initialized"}
|
||||||
|
}
|
||||||
|
if req == nil {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: "invalid request payload"}
|
||||||
|
}
|
||||||
|
if !s.accessible(req.DatasetID, userID) {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the dataset %s.", req.DatasetID)}
|
||||||
|
}
|
||||||
|
|
||||||
|
kb, err := s.getKnowledgebaseByID(req.DatasetID)
|
||||||
|
if err != nil || kb == nil {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the dataset %s.", req.DatasetID)}
|
||||||
|
}
|
||||||
|
|
||||||
|
doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(req.DocumentID, req.DatasetID)
|
||||||
|
if err != nil || doc == nil {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the document %s.", req.DocumentID)}
|
||||||
|
}
|
||||||
|
|
||||||
|
content := strings.TrimSpace(req.Content)
|
||||||
|
if content == "" {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: "`content` is required"}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tagFeas map[string]float64
|
||||||
|
if req.TagFeas != nil {
|
||||||
|
tagFeas, err = validateTagFeatures(req.TagFeas)
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: "`tag_feas` " + err.Error()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkID := strconv.FormatUint(xxhash.Sum64([]byte(req.Content+req.DocumentID)), 16)
|
||||||
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
||||||
|
contentLtks, err := s.tokenize(req.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize content: %v", err)}
|
||||||
|
}
|
||||||
|
contentSmLtks, err := s.fineGrainedTokenize(contentLtks)
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize content fine-grained: %v", err)}
|
||||||
|
}
|
||||||
|
importantTks, err := s.tokenize(strings.Join(req.ImportantKeywords, " "))
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize important keywords: %v", err)}
|
||||||
|
}
|
||||||
|
questionKwd := filterTrimmedStrings(req.Questions)
|
||||||
|
questionTks, err := s.tokenize(strings.Join(req.Questions, "\n"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize questions: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
docName := ""
|
||||||
|
if doc.Name != nil {
|
||||||
|
docName = *doc.Name
|
||||||
|
}
|
||||||
|
importantKeywords := req.ImportantKeywords
|
||||||
|
if importantKeywords == nil {
|
||||||
|
importantKeywords = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkData := map[string]interface{}{
|
||||||
|
"id": chunkID,
|
||||||
|
"content_with_weight": req.Content,
|
||||||
|
"content_ltks": contentLtks,
|
||||||
|
"content_sm_ltks": contentSmLtks,
|
||||||
|
"important_kwd": importantKeywords,
|
||||||
|
"important_tks": importantTks,
|
||||||
|
"question_kwd": questionKwd,
|
||||||
|
"question_tks": questionTks,
|
||||||
|
"create_time": now.Format("2006-01-02 15:04:05"),
|
||||||
|
"create_timestamp_flt": float64(now.UnixNano()) / float64(time.Second),
|
||||||
|
"kb_id": req.DatasetID,
|
||||||
|
"docnm_kwd": docName,
|
||||||
|
"doc_id": req.DocumentID,
|
||||||
|
}
|
||||||
|
if req.TagKwd != nil {
|
||||||
|
chunkData["tag_kwd"] = req.TagKwd
|
||||||
|
}
|
||||||
|
if tagFeas != nil {
|
||||||
|
chunkData["tag_feas"] = tagFeas
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ImageBase64 != nil {
|
||||||
|
imageBinary, err := decodeChunkImageBase64(*req.ImageBase64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: err.Error()}
|
||||||
|
}
|
||||||
|
if err := s.storeChunkImage(req.DatasetID, chunkID, imageBinary); err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeDataError, message: "Failed to store chunk image"}
|
||||||
|
}
|
||||||
|
chunkData["img_id"] = fmt.Sprintf("%s-%s", req.DatasetID, chunkID)
|
||||||
|
chunkData["doc_type_kwd"] = "image"
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddingModel, err := s.getEmbeddingModel(kb.TenantID, kb.EmbdID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("get embedding model: %v", err)}
|
||||||
|
}
|
||||||
|
embeddingText := req.Content
|
||||||
|
if len(questionKwd) > 0 {
|
||||||
|
embeddingText = strings.Join(questionKwd, "\n")
|
||||||
|
}
|
||||||
|
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{docName, embeddingText}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0})
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("encode chunk embedding: %v", err)}
|
||||||
|
}
|
||||||
|
if len(embeddings) != 2 {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("unexpected embedding count: %d", len(embeddings))}
|
||||||
|
}
|
||||||
|
mergedVec, err := mergeChunkEmbeddings(embeddings[0].Embedding, embeddings[1].Embedding)
|
||||||
|
if err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: err.Error()}
|
||||||
|
}
|
||||||
|
chunkData[fmt.Sprintf("q_%d_vec", len(mergedVec))] = mergedVec
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 600*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if _, err := s.docEngine.InsertChunks(ctx, []map[string]interface{}{chunkData}, indexName, req.DatasetID); err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("insert chunk: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenNum := int64(s.numTokens(req.Content))
|
||||||
|
if err := s.incrementChunkStats(req.DocumentID, req.DatasetID, tokenNum, 1, 0); err != nil {
|
||||||
|
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("increment chunk stats: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
renamedChunk := map[string]interface{}{
|
||||||
|
"id": chunkID,
|
||||||
|
"content": req.Content,
|
||||||
|
"document_id": req.DocumentID,
|
||||||
|
"document": docName,
|
||||||
|
"important_keywords": importantKeywords,
|
||||||
|
"questions": questionKwd,
|
||||||
|
"dataset_id": req.DatasetID,
|
||||||
|
"create_timestamp": chunkData["create_timestamp_flt"],
|
||||||
|
"create_time": chunkData["create_time"],
|
||||||
|
}
|
||||||
|
if req.TagKwd != nil {
|
||||||
|
renamedChunk["tag_kwd"] = req.TagKwd
|
||||||
|
}
|
||||||
|
if imgID, ok := chunkData["img_id"]; ok {
|
||||||
|
renamedChunk["image_id"] = imgID
|
||||||
|
}
|
||||||
|
|
||||||
|
return &service.AddChunkResponse{Chunk: renamedChunk}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type addChunkError struct {
|
||||||
|
code common.ErrorCode
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e addChunkError) Error() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e addChunkError) Code() common.ErrorCode {
|
||||||
|
return e.code
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTagFeatures checks a decoded tag_feas value and converts it to a
// map from tag name to score.
//
// raw must be a map[string]interface{}. Each key is trimmed of surrounding
// whitespace and must be non-empty afterwards. Each value must be a float64,
// float32, int, int8, int16, int32 or int64; floating-point values must be
// neither NaN nor infinite. Any other input produces an error whose text is
// meant to be appended to the field name by the caller.
func validateTagFeatures(raw interface{}) (map[string]float64, error) {
	features, isObject := raw.(map[string]interface{})
	if !isObject {
		return nil, fmt.Errorf("must be an object mapping string tags to finite numeric scores")
	}

	result := make(map[string]float64, len(features))
	for rawKey, rawScore := range features {
		tag := strings.TrimSpace(rawKey)
		if tag == "" {
			return nil, fmt.Errorf("keys must be non-empty strings")
		}

		var score float64
		switch n := rawScore.(type) {
		case float64:
			score = n
		case float32:
			score = float64(n)
		case int:
			score = float64(n)
		case int8:
			score = float64(n)
		case int16:
			score = float64(n)
		case int32:
			score = float64(n)
		case int64:
			score = float64(n)
		default:
			return nil, fmt.Errorf("values must be finite numbers")
		}

		// Converted integers are always finite, so this only ever rejects
		// floating-point inputs.
		if math.IsNaN(score) || math.IsInf(score, 0) {
			return nil, fmt.Errorf("values must be finite numbers")
		}
		result[tag] = score
	}
	return result, nil
}
|
||||||
|
|
||||||
|
// decodeChunkImageBase64 decodes a standard, padded base64 string into raw
// image bytes.
//
// Surrounding whitespace is ignored: the same trimmed value is used both for
// the emptiness check and for decoding, so a valid payload with a stray
// leading or trailing space is accepted. Decoding uses the strict variant of
// the standard encoding. An error is returned when the input is blank, is not
// valid base64, or decodes to zero bytes.
func decodeChunkImageBase64(raw string) ([]byte, error) {
	encoded := strings.TrimSpace(raw)
	if encoded == "" {
		return nil, fmt.Errorf("`image_base64` must be a non-empty string")
	}

	imageBinary, err := base64.StdEncoding.Strict().DecodeString(encoded)
	if err != nil {
		// The decoder's detail (byte offset) is deliberately not exposed.
		return nil, fmt.Errorf("Invalid `image_base64`")
	}
	if len(imageBinary) == 0 {
		return nil, fmt.Errorf("`image_base64` is empty")
	}
	return imageBinary, nil
}
|
||||||
|
|
||||||
|
// mergeChunkEmbeddings blends two equally sized vectors element by element,
// weighting a by 0.1 and b by 0.9. It returns an error when either vector is
// empty or when their lengths differ.
func mergeChunkEmbeddings(a, b []float64) ([]float64, error) {
	size := len(a)
	if size == 0 || size != len(b) {
		return nil, fmt.Errorf("unexpected embedding dimensions")
	}

	merged := make([]float64, 0, size)
	for i, av := range a {
		merged = append(merged, 0.1*av+0.9*b[i])
	}
	return merged, nil
}
|
||||||
|
|
||||||
|
// filterTrimmedStrings returns the whitespace-trimmed form of every entry in
// values, dropping entries that are empty after trimming. The result is always
// a non-nil slice, even when nothing is kept.
func filterTrimmedStrings(values []string) []string {
	kept := make([]string, 0, len(values))
	for i := range values {
		candidate := strings.TrimSpace(values[i])
		if candidate == "" {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
|
||||||
|
|
||||||
|
func (s *ChunkService) tokenize(text string) (string, error) {
|
||||||
|
if s.tokenizeFunc != nil {
|
||||||
|
return s.tokenizeFunc(text)
|
||||||
|
}
|
||||||
|
return tokenizer.Tokenize(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChunkService) fineGrainedTokenize(text string) (string, error) {
|
||||||
|
if s.fineGrainedTokenizeFunc != nil {
|
||||||
|
return s.fineGrainedTokenizeFunc(text)
|
||||||
|
}
|
||||||
|
return tokenizer.FineGrainedTokenize(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChunkService) numTokens(text string) int {
|
||||||
|
if s.numTokensFunc != nil {
|
||||||
|
return s.numTokensFunc(text)
|
||||||
|
}
|
||||||
|
return tokenizer.NumTokensFromString(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChunkService) getEmbeddingModel(tenantID, embdID string) (*models.EmbeddingModel, error) {
|
||||||
|
if s.getEmbeddingModelFunc != nil {
|
||||||
|
return s.getEmbeddingModelFunc(tenantID, embdID)
|
||||||
|
}
|
||||||
|
return service.NewModelProviderService().GetEmbeddingModel(tenantID, embdID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChunkService) incrementChunkStats(docID, kbID string, tokenNum, chunkNum int64, duration float64) error {
|
||||||
|
if s.incrementChunkStatsFunc != nil {
|
||||||
|
return s.incrementChunkStatsFunc(docID, kbID, tokenNum, chunkNum, duration)
|
||||||
|
}
|
||||||
|
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Model(&entity.Document{}).
|
||||||
|
Where("id = ? AND kb_id = ?", docID, kbID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"token_num": gorm.Expr("token_num + ?", tokenNum),
|
||||||
|
"chunk_num": gorm.Expr("chunk_num + ?", chunkNum),
|
||||||
|
"process_duration": gorm.Expr("process_duration + ?", duration),
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("document not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
result = tx.Model(&entity.Knowledgebase{}).
|
||||||
|
Where("id = ?", kbID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"token_num": gorm.Expr("token_num + ?", tokenNum),
|
||||||
|
"chunk_num": gorm.Expr("chunk_num + ?", chunkNum),
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("knowledgebase not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChunkService) storeChunkImage(bucket, chunkID string, imageBinary []byte) error {
|
||||||
|
if s.storeChunkImageFunc != nil {
|
||||||
|
return s.storeChunkImageFunc(bucket, chunkID, imageBinary)
|
||||||
|
}
|
||||||
|
storageImpl := storage.GetStorageFactory().GetStorage()
|
||||||
|
if storageImpl == nil {
|
||||||
|
return fmt.Errorf("storage not initialized")
|
||||||
|
}
|
||||||
|
lockKey := bucket + "/" + chunkID
|
||||||
|
lock := acquireChunkImageMergeLock(lockKey)
|
||||||
|
lock.mu.Lock()
|
||||||
|
defer func() {
|
||||||
|
lock.mu.Unlock()
|
||||||
|
releaseChunkImageMergeLock(lockKey)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !storageImpl.ObjExist(bucket, chunkID) {
|
||||||
|
return storageImpl.Put(bucket, chunkID, imageBinary)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldBinary, err := storageImpl.Get(bucket, chunkID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
oldImage, _, err := image.Decode(bytes.NewReader(oldBinary))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newImage, _, err := image.Decode(bytes.NewReader(imageBinary))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
oldBounds, newBounds := oldImage.Bounds(), newImage.Bounds()
|
||||||
|
width := oldBounds.Dx()
|
||||||
|
if newBounds.Dx() > width {
|
||||||
|
width = newBounds.Dx()
|
||||||
|
}
|
||||||
|
height := oldBounds.Dy() + newBounds.Dy()
|
||||||
|
combined := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||||
|
draw.Draw(combined, combined.Bounds(), &image.Uniform{C: color.White}, image.Point{}, draw.Src)
|
||||||
|
draw.Draw(combined, oldBounds, oldImage, oldBounds.Min, draw.Src)
|
||||||
|
draw.Draw(combined, image.Rect(0, oldBounds.Dy(), newBounds.Dx(), oldBounds.Dy()+newBounds.Dy()), newImage, newBounds.Min, draw.Src)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := jpeg.Encode(&buf, combined, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return storageImpl.Put(bucket, chunkID, buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func acquireChunkImageMergeLock(key string) *chunkImageMergeLock {
|
||||||
|
chunkImageMergeLocks.Lock()
|
||||||
|
defer chunkImageMergeLocks.Unlock()
|
||||||
|
|
||||||
|
lock := chunkImageMergeLocks.locks[key]
|
||||||
|
if lock == nil {
|
||||||
|
lock = &chunkImageMergeLock{}
|
||||||
|
chunkImageMergeLocks.locks[key] = lock
|
||||||
|
}
|
||||||
|
lock.refs++
|
||||||
|
return lock
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseChunkImageMergeLock(key string) {
|
||||||
|
chunkImageMergeLocks.Lock()
|
||||||
|
defer chunkImageMergeLocks.Unlock()
|
||||||
|
|
||||||
|
lock := chunkImageMergeLocks.locks[key]
|
||||||
|
if lock == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lock.refs--
|
||||||
|
if lock.refs == 0 {
|
||||||
|
delete(chunkImageMergeLocks.locks, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,16 +1,23 @@
|
|||||||
package chunk
|
package chunk
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/png"
|
||||||
"ragflow/internal/common"
|
"ragflow/internal/common"
|
||||||
"ragflow/internal/dao"
|
"ragflow/internal/dao"
|
||||||
"ragflow/internal/engine/types"
|
"ragflow/internal/engine/types"
|
||||||
"ragflow/internal/entity"
|
"ragflow/internal/entity"
|
||||||
|
"ragflow/internal/entity/models"
|
||||||
"ragflow/internal/service"
|
"ragflow/internal/service"
|
||||||
|
"ragflow/internal/storage"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -412,6 +419,305 @@ func TestParseQueuesAndBeginsDocument(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddChunkSuccess(t *testing.T) {
|
||||||
|
db := setupChunkTestDB(t)
|
||||||
|
pushChunkTestDB(t, db)
|
||||||
|
userID, datasetID, documentID := "user-1", "kb-1", "doc-1"
|
||||||
|
insertChunkTestKB(t, datasetID, userID)
|
||||||
|
insertChunkTestDoc(t, documentID, datasetID)
|
||||||
|
|
||||||
|
engine := &addChunkTestEngine{}
|
||||||
|
var incrementTokenNum, incrementChunkNum int64
|
||||||
|
svc := &ChunkService{
|
||||||
|
docEngine: engine,
|
||||||
|
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||||
|
documentDAO: dao.NewDocumentDAO(),
|
||||||
|
accessibleFunc: func(datasetIDArg, userIDArg string) bool {
|
||||||
|
return datasetIDArg == datasetID && userIDArg == userID
|
||||||
|
},
|
||||||
|
getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) {
|
||||||
|
return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil
|
||||||
|
},
|
||||||
|
getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) {
|
||||||
|
driver := &stubEmbeddingDriver{
|
||||||
|
embeddings: []models.EmbeddingData{
|
||||||
|
{Embedding: []float64{1, 2}},
|
||||||
|
{Embedding: []float64{3, 4}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
modelName := "embed-1"
|
||||||
|
return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil
|
||||||
|
},
|
||||||
|
incrementChunkStatsFunc: func(docID, kbID string, tokenNum, chunkNum int64, duration float64) error {
|
||||||
|
if docID != documentID || kbID != datasetID || duration != 0 {
|
||||||
|
t.Fatalf("unexpected increment args doc=%s kb=%s duration=%v", docID, kbID, duration)
|
||||||
|
}
|
||||||
|
incrementTokenNum = tokenNum
|
||||||
|
incrementChunkNum = chunkNum
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
tokenizeFunc: func(text string) (string, error) { return text, nil },
|
||||||
|
fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil },
|
||||||
|
numTokensFunc: func(text string) int { return len(text) },
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := svc.AddChunk(&service.AddChunkRequest{
|
||||||
|
DatasetID: datasetID,
|
||||||
|
DocumentID: documentID,
|
||||||
|
Content: "chunk body",
|
||||||
|
ImportantKeywords: []string{"k1"},
|
||||||
|
Questions: []string{" q1 ", ""},
|
||||||
|
TagKwd: []string{"tag1"},
|
||||||
|
TagFeas: map[string]interface{}{"tag1": float64(0.5)},
|
||||||
|
}, userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddChunk() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp == nil || resp.Chunk == nil {
|
||||||
|
t.Fatalf("expected chunk response, got %#v", resp)
|
||||||
|
}
|
||||||
|
if resp.Chunk["dataset_id"] != datasetID || resp.Chunk["document_id"] != documentID {
|
||||||
|
t.Fatalf("unexpected response chunk: %#v", resp.Chunk)
|
||||||
|
}
|
||||||
|
if resp.Chunk["content"] != "chunk body" {
|
||||||
|
t.Fatalf("content = %v, want chunk body", resp.Chunk["content"])
|
||||||
|
}
|
||||||
|
if resp.Chunk["document"] != "doc-1.txt" {
|
||||||
|
t.Fatalf("document = %v, want doc-1.txt", resp.Chunk["document"])
|
||||||
|
}
|
||||||
|
if incrementChunkNum != 1 {
|
||||||
|
t.Fatalf("increment chunk num = %d, want 1", incrementChunkNum)
|
||||||
|
}
|
||||||
|
if incrementTokenNum <= 0 {
|
||||||
|
t.Fatalf("increment token num = %d, want > 0", incrementTokenNum)
|
||||||
|
}
|
||||||
|
if len(engine.insertedChunks) != 1 {
|
||||||
|
t.Fatalf("inserted chunks = %d, want 1", len(engine.insertedChunks))
|
||||||
|
}
|
||||||
|
inserted := engine.insertedChunks[0]
|
||||||
|
if inserted["doc_id"] != documentID || inserted["kb_id"] != datasetID {
|
||||||
|
t.Fatalf("unexpected inserted chunk: %#v", inserted)
|
||||||
|
}
|
||||||
|
if inserted["img_id"] != nil {
|
||||||
|
t.Fatalf("did not expect image id in inserted chunk: %#v", inserted)
|
||||||
|
}
|
||||||
|
vec, ok := inserted["q_2_vec"].([]float64)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected q_2_vec []float64, got %T", inserted["q_2_vec"])
|
||||||
|
}
|
||||||
|
if len(vec) != 2 || vec[0] < 2.7999 || vec[0] > 2.8001 || vec[1] < 3.7999 || vec[1] > 3.8001 {
|
||||||
|
t.Fatalf("vector = %v, want approximately [2.8 3.8]", vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddChunkValidationErrors(t *testing.T) {
|
||||||
|
db := setupChunkTestDB(t)
|
||||||
|
pushChunkTestDB(t, db)
|
||||||
|
insertChunkTestKB(t, "kb-1", "user-1")
|
||||||
|
insertChunkTestDoc(t, "doc-1", "kb-1")
|
||||||
|
|
||||||
|
svc := &ChunkService{
|
||||||
|
docEngine: &addChunkTestEngine{},
|
||||||
|
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||||
|
documentDAO: dao.NewDocumentDAO(),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *service.AddChunkRequest
|
||||||
|
setup func(*ChunkService)
|
||||||
|
wantMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil request",
|
||||||
|
req: nil,
|
||||||
|
wantMsg: "invalid request payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty content",
|
||||||
|
req: &service.AddChunkRequest{DatasetID: "kb-1", DocumentID: "doc-1", Content: " "},
|
||||||
|
setup: func(svc *ChunkService) {
|
||||||
|
svc.accessibleFunc = func(string, string) bool { return true }
|
||||||
|
svc.getKnowledgebaseByIDFunc = func(string) (*entity.Knowledgebase, error) {
|
||||||
|
return &entity.Knowledgebase{ID: "kb-1", TenantID: "user-1", EmbdID: "embed-1"}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantMsg: "`content` is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.setup != nil {
|
||||||
|
tt.setup(svc)
|
||||||
|
}
|
||||||
|
_, err := svc.AddChunk(tt.req, "user-1")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), tt.wantMsg) {
|
||||||
|
t.Fatalf("error = %v, want substring %q", err, tt.wantMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddChunkImageAndTagFeatureValidation(t *testing.T) {
|
||||||
|
db := setupChunkTestDB(t)
|
||||||
|
pushChunkTestDB(t, db)
|
||||||
|
userID, datasetID, documentID := "user-1", "kb-1", "doc-1"
|
||||||
|
insertChunkTestKB(t, datasetID, userID)
|
||||||
|
insertChunkTestDoc(t, documentID, datasetID)
|
||||||
|
|
||||||
|
storeCalls := 0
|
||||||
|
svc := &ChunkService{
|
||||||
|
docEngine: &addChunkTestEngine{},
|
||||||
|
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||||
|
documentDAO: dao.NewDocumentDAO(),
|
||||||
|
accessibleFunc: func(string, string) bool { return true },
|
||||||
|
getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) {
|
||||||
|
return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil
|
||||||
|
},
|
||||||
|
tokenizeFunc: func(text string) (string, error) { return text, nil },
|
||||||
|
fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil },
|
||||||
|
numTokensFunc: func(text string) int { return len(text) },
|
||||||
|
getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) {
|
||||||
|
driver := &stubEmbeddingDriver{
|
||||||
|
embeddings: []models.EmbeddingData{
|
||||||
|
{Embedding: []float64{1, 1}},
|
||||||
|
{Embedding: []float64{1, 1}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
modelName := "embed-1"
|
||||||
|
return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil
|
||||||
|
},
|
||||||
|
incrementChunkStatsFunc: func(string, string, int64, int64, float64) error { return nil },
|
||||||
|
storeChunkImageFunc: func(bucket, chunkID string, imageBinary []byte) error {
|
||||||
|
storeCalls++
|
||||||
|
if bucket != datasetID || chunkID == "" || len(imageBinary) == 0 {
|
||||||
|
t.Fatalf("unexpected store args bucket=%s chunkID=%s len=%d", bucket, chunkID, len(imageBinary))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.AddChunk(&service.AddChunkRequest{
|
||||||
|
DatasetID: datasetID,
|
||||||
|
DocumentID: documentID,
|
||||||
|
Content: "chunk body",
|
||||||
|
ImageBase64: strPtr("not-base64"),
|
||||||
|
}, userID)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "Invalid `image_base64`") {
|
||||||
|
t.Fatalf("expected invalid image error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
validJPEG := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO2pRZ0AAAAASUVORK5CYII="
|
||||||
|
|
||||||
|
resp, err := svc.AddChunk(&service.AddChunkRequest{
|
||||||
|
DatasetID: datasetID,
|
||||||
|
DocumentID: documentID,
|
||||||
|
Content: "chunk body",
|
||||||
|
TagFeas: map[string]interface{}{"tag1": "bad"},
|
||||||
|
}, userID)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "`tag_feas` values must be finite numbers") || resp != nil {
|
||||||
|
t.Fatalf("expected tag_feas validation error, got resp=%#v err=%v", resp, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = svc.AddChunk(&service.AddChunkRequest{
|
||||||
|
DatasetID: datasetID,
|
||||||
|
DocumentID: documentID,
|
||||||
|
Content: "chunk body",
|
||||||
|
TagFeas: map[string]interface{}{"tag1": float64(1)},
|
||||||
|
ImageBase64: strPtr(validJPEG),
|
||||||
|
}, userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddChunk() with image error = %v", err)
|
||||||
|
}
|
||||||
|
if storeCalls != 1 {
|
||||||
|
t.Fatalf("store image calls = %d, want 1", storeCalls)
|
||||||
|
}
|
||||||
|
if _, ok := resp.Chunk["image_id"]; !ok {
|
||||||
|
t.Fatalf("expected image_id in response, got %#v", resp.Chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddChunkIncrementsStatsAfterInsert(t *testing.T) {
|
||||||
|
db := setupChunkTestDB(t)
|
||||||
|
pushChunkTestDB(t, db)
|
||||||
|
userID, datasetID, documentID := "user-1", "kb-1", "doc-1"
|
||||||
|
insertChunkTestKB(t, datasetID, userID)
|
||||||
|
insertChunkTestDoc(t, documentID, datasetID)
|
||||||
|
|
||||||
|
var incrementCalls int
|
||||||
|
engine := &addChunkTestEngine{}
|
||||||
|
svc := &ChunkService{
|
||||||
|
docEngine: engine,
|
||||||
|
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||||
|
documentDAO: dao.NewDocumentDAO(),
|
||||||
|
accessibleFunc: func(string, string) bool { return true },
|
||||||
|
getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) {
|
||||||
|
return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil
|
||||||
|
},
|
||||||
|
getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) {
|
||||||
|
driver := &stubEmbeddingDriver{
|
||||||
|
embeddings: []models.EmbeddingData{
|
||||||
|
{Embedding: []float64{1, 2}},
|
||||||
|
{Embedding: []float64{3, 4}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
modelName := "embed-1"
|
||||||
|
return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil
|
||||||
|
},
|
||||||
|
incrementChunkStatsFunc: func(string, string, int64, int64, float64) error {
|
||||||
|
incrementCalls++
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
tokenizeFunc: func(text string) (string, error) { return text, nil },
|
||||||
|
fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil },
|
||||||
|
numTokensFunc: func(text string) int { return len(text) },
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.AddChunk(&service.AddChunkRequest{
|
||||||
|
DatasetID: datasetID,
|
||||||
|
DocumentID: documentID,
|
||||||
|
Content: "chunk body",
|
||||||
|
}, userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddChunk() error = %v", err)
|
||||||
|
}
|
||||||
|
if incrementCalls != 1 {
|
||||||
|
t.Fatalf("increment calls = %d, want 1", incrementCalls)
|
||||||
|
}
|
||||||
|
if len(engine.insertedChunks) != 1 {
|
||||||
|
t.Fatalf("inserted chunks = %d, want 1", len(engine.insertedChunks))
|
||||||
|
}
|
||||||
|
importantKwd, ok := engine.insertedChunks[0]["important_kwd"].([]string)
|
||||||
|
if !ok || len(importantKwd) != 0 {
|
||||||
|
t.Fatalf("important_kwd = %#v, want empty []string", engine.insertedChunks[0]["important_kwd"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreChunkImageMergesExistingImage(t *testing.T) {
|
||||||
|
oldImage := mustEncodePNG(t, image.Rect(0, 0, 2, 2))
|
||||||
|
newImage := mustEncodePNG(t, image.Rect(0, 0, 1, 1))
|
||||||
|
mockStorage := &chunkImageStorage{
|
||||||
|
exists: true,
|
||||||
|
oldBinary: oldImage,
|
||||||
|
}
|
||||||
|
|
||||||
|
factory := storage.GetStorageFactory()
|
||||||
|
originalStorage := factory.GetStorage()
|
||||||
|
factory.SetStorage(mockStorage)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
factory.SetStorage(originalStorage)
|
||||||
|
})
|
||||||
|
|
||||||
|
svc := &ChunkService{}
|
||||||
|
if err := svc.storeChunkImage("kb-1", "chunk-1", newImage); err != nil {
|
||||||
|
t.Fatalf("storeChunkImage() error = %v", err)
|
||||||
|
}
|
||||||
|
if mockStorage.putCalls != 1 {
|
||||||
|
t.Fatalf("put calls = %d, want 1", mockStorage.putCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setupChunkTestDB(t *testing.T) *gorm.DB {
|
func setupChunkTestDB(t *testing.T) *gorm.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -620,6 +926,117 @@ func (e *parseTestDocEngine) GetChunkIDs([]map[string]interface{}) []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type addChunkTestEngine struct {
|
||||||
|
parseTestDocEngine
|
||||||
|
insertedChunks []map[string]interface{}
|
||||||
|
insertIndex string
|
||||||
|
insertDataset string
|
||||||
|
insertErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *addChunkTestEngine) InsertChunks(_ context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) {
|
||||||
|
e.insertedChunks = chunks
|
||||||
|
e.insertIndex = baseName
|
||||||
|
e.insertDataset = datasetID
|
||||||
|
return nil, e.insertErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// chunkImageStorage is an in-memory storage test double. ObjExist and Get
// answer from fixed fields, Put only counts its calls, and the remaining
// methods return constant values.
type chunkImageStorage struct {
	exists    bool   // value returned by ObjExist
	oldBinary []byte // bytes returned by Get
	putCalls  int    // number of Put invocations so far
}

// Health always reports true.
func (s *chunkImageStorage) Health() bool {
	return true
}

// Put counts the call and discards the data.
func (s *chunkImageStorage) Put(_, _ string, _ []byte, _ ...string) error {
	s.putCalls += 1
	return nil
}

// Get returns the configured oldBinary for every object.
func (s *chunkImageStorage) Get(_, _ string, _ ...string) ([]byte, error) {
	return s.oldBinary, nil
}

// Remove does nothing and reports success.
func (s *chunkImageStorage) Remove(_, _ string, _ ...string) error {
	return nil
}

// ObjExist returns the configured exists flag for every object.
func (s *chunkImageStorage) ObjExist(_, _ string, _ ...string) bool {
	return s.exists
}

// GetPresignedURL returns an empty URL and no error.
func (s *chunkImageStorage) GetPresignedURL(_, _ string, _ time.Duration, _ ...string) (string, error) {
	return "", nil
}

// BucketExists always reports true.
func (s *chunkImageStorage) BucketExists(_ string) bool {
	return true
}

// RemoveBucket does nothing and reports success.
func (s *chunkImageStorage) RemoveBucket(_ string) error {
	return nil
}

// Copy does nothing and reports false.
func (s *chunkImageStorage) Copy(_, _, _, _ string) bool {
	return false
}

// Move does nothing and reports false.
func (s *chunkImageStorage) Move(_, _, _, _ string) bool {
	return false
}
|
||||||
|
|
||||||
|
// mustEncodePNG returns the PNG encoding of an RGBA image covering rect in
// which every pixel is opaque white. It fails the test immediately when
// encoding fails.
func mustEncodePNG(t *testing.T, rect image.Rectangle) []byte {
	t.Helper()

	canvas := image.NewRGBA(rect)
	// Convert once up front; this is the same value Set would store per pixel.
	white := color.RGBAModel.Convert(color.White).(color.RGBA)
	for x := rect.Min.X; x < rect.Max.X; x++ {
		for y := rect.Min.Y; y < rect.Max.Y; y++ {
			canvas.SetRGBA(x, y, white)
		}
	}

	var encoded bytes.Buffer
	err := png.Encode(&encoded, canvas)
	if err != nil {
		t.Fatalf("encode png: %v", err)
	}
	return encoded.Bytes()
}
|
||||||
|
|
||||||
|
type stubEmbeddingDriver struct {
|
||||||
|
embeddings []models.EmbeddingData
|
||||||
|
embedErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stubEmbeddingDriver) NewInstance(map[string]string) models.ModelDriver { return d }
|
||||||
|
func (d *stubEmbeddingDriver) Name() string { return "stub" }
|
||||||
|
func (d *stubEmbeddingDriver) ChatWithMessages(string, []models.Message, *models.APIConfig, *models.ChatConfig) (*models.ChatResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) ChatStreamlyWithSender(string, []models.Message, *models.APIConfig, *models.ChatConfig, func(*string, *string) error) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) Embed(*string, []string, *models.APIConfig, *models.EmbeddingConfig) ([]models.EmbeddingData, error) {
|
||||||
|
return d.embeddings, d.embedErr
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) Rerank(*string, string, []string, *models.APIConfig, *models.RerankConfig) (*models.RerankResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) TranscribeAudio(*string, *string, *models.APIConfig, *models.ASRConfig) (*models.ASRResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) TranscribeAudioWithSender(*string, *string, *models.APIConfig, *models.ASRConfig, func(*string, *string) error) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) AudioSpeech(*string, *string, *models.APIConfig, *models.TTSConfig) (*models.TTSResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) AudioSpeechWithSender(*string, *string, *models.APIConfig, *models.TTSConfig, func(*string, *string) error) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) OCRFile(*string, []byte, *string, *models.APIConfig, *models.OCRConfig) (*models.OCRFileResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) ParseFile(*string, []byte, *string, *models.APIConfig, *models.ParseFileConfig) (*models.ParseFileResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) ListModels(*models.APIConfig) ([]models.ListModelResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) Balance(*models.APIConfig) (map[string]interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) CheckConnection(*models.APIConfig) error { return nil }
|
||||||
|
func (d *stubEmbeddingDriver) ListTasks(*models.APIConfig) ([]models.ListTaskStatus, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (d *stubEmbeddingDriver) ShowTask(string, *models.APIConfig) (*models.TaskResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// strPtr returns a pointer to a fresh copy of v, for populating optional
// string fields in test fixtures.
func strPtr(v string) *string {
	p := new(string)
	*p = v
	return p
}
|
||||||
|
|
||||||
func (e *parseTestDocEngine) KNNScores(context.Context, []map[string]interface{}, []float64, int) (map[string]interface{}, error) {
|
func (e *parseTestDocEngine) KNNScores(context.Context, []map[string]interface{}, []float64, int) (map[string]interface{}, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,6 +95,29 @@ type ParseFileRequest struct {
|
|||||||
DocumentIDs []string `json:"document_ids"`
|
DocumentIDs []string `json:"document_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddChunkRequest is the payload for adding a chunk to a document.
//
// Everything other than the dataset ID, document ID and content is omitted
// from the JSON encoding when it is empty.
type AddChunkRequest struct {
	// DatasetID identifies the dataset (JSON key "dataset_id").
	DatasetID string `json:"dataset_id"`
	// DocumentID identifies the document (JSON key "document_id").
	DocumentID string `json:"document_id"`
	// Content is the chunk text (JSON key "content").
	Content string `json:"content"`
	// ImportantKeywords is an optional keyword list.
	ImportantKeywords []string `json:"important_keywords,omitempty"`
	// Questions is an optional question list.
	Questions []string `json:"questions,omitempty"`
	// TagKwd is an optional tag keyword list.
	TagKwd []string `json:"tag_kwd,omitempty"`
	// TagFeas holds the optional, still-untyped "tag_feas" value exactly as
	// decoded; it is presumably validated by the service before use.
	TagFeas interface{} `json:"tag_feas,omitempty"`
	// ImageBase64 is an optional image payload; nil when not supplied.
	ImageBase64 *string `json:"image_base64,omitempty"`
}
|
||||||
|
|
||||||
|
// AddChunkResponse is the result of adding a chunk. It wraps the created
// chunk's fields under the JSON key "chunk".
type AddChunkResponse struct {
	// Chunk holds the chunk's fields keyed by field name.
	Chunk map[string]interface{} `json:"chunk"`
}
|
||||||
|
|
||||||
|
// ErrorCoder exposes an application error code alongside an error string.
|
||||||
|
type ErrorCoder interface {
|
||||||
|
error
|
||||||
|
Code() common.ErrorCode
|
||||||
|
}
|
||||||
|
|
||||||
// Get retrieves a chunk by ID
|
// Get retrieves a chunk by ID
|
||||||
func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkResponse, error) {
|
func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkResponse, error) {
|
||||||
if s.docEngine == nil {
|
if s.docEngine == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user