mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement Embed (embeddings) in Replicate driver (#15073)
### What problem does this PR solve? `ReplicateModel.Embed` in `internal/entity/models/replicate.go` was a `"replicate, no such method"` stub. Tracking issue #14736 lists Replicate's embedding surface as not implemented. This PR wires it up against Replicate's documented embedding schema. Until this PR, a tenant who selected a Replicate embedding model got the sentinel error on every embed call. Co-authored-by: sxxtony <sxxtony@users.noreply.github.com> Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -22,6 +22,13 @@
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "replicate/all-mpnet-base-v2:b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
|
||||
"max_tokens": 384,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -566,8 +566,154 @@ func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// replicateEmbedInput shapes the request body for Replicate's standard
|
||||
// embedding models (e.g. replicate/all-mpnet-base-v2). Per the
|
||||
// canonical Replicate embedding schema published in the model's
|
||||
// openapi_schema, the two input fields are:
|
||||
//
|
||||
// text — single string to encode (used when len(texts) == 1)
|
||||
// text_batch — JSON-formatted list of strings (used when len > 1)
|
||||
//
|
||||
// `text_batch` is `type: string` in the schema, so the JSON-encoded
|
||||
// list itself is sent as a string value, NOT as a JSON array. Models
|
||||
// that use different field names (e.g. nateraw/bge-large-en-v1.5's
|
||||
// `texts`) are not currently supported by this driver; tenants on
|
||||
// those should consult Replicate's OpenAPI schema and configure a
|
||||
// compatible model in conf/models/replicate.json.
|
||||
func replicateEmbedInput(texts []string) (map[string]interface{}, error) {
|
||||
switch len(texts) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("replicate: texts is empty")
|
||||
case 1:
|
||||
return map[string]interface{}{"text": texts[0]}, nil
|
||||
default:
|
||||
encoded, err := json.Marshal(texts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode text_batch: %w", err)
|
||||
}
|
||||
return map[string]interface{}{"text_batch": string(encoded)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// replicateEmbedOutputToVectors normalizes Replicate's two observed
|
||||
// embedding-output shapes into []EmbeddingData aligned with the
|
||||
// caller's input order:
|
||||
//
|
||||
// []{embedding: [floats]} — the documented Embedding schema used
|
||||
// by replicate/all-mpnet-base-v2
|
||||
// [][floats] — bare nested array used by some
|
||||
// community models
|
||||
//
|
||||
// The driver rejects mismatched cardinality (output length != input
|
||||
// length) and non-numeric vector entries rather than silently
|
||||
// truncate or pad, matching the defensive posture the n1n / CometAPI
|
||||
// drivers already use.
|
||||
func replicateEmbedOutputToVectors(output interface{}, n int) ([]EmbeddingData, error) {
|
||||
outputs, ok := output.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("replicate: expected output to be an array, got %T", output)
|
||||
}
|
||||
if len(outputs) != n {
|
||||
return nil, fmt.Errorf("replicate: expected %d embeddings, got %d", n, len(outputs))
|
||||
}
|
||||
|
||||
vectors := make([]EmbeddingData, n)
|
||||
for i, item := range outputs {
|
||||
vec, err := replicateExtractEmbeddingVector(item)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate: output[%d]: %w", i, err)
|
||||
}
|
||||
vectors[i] = EmbeddingData{Embedding: vec, Index: i}
|
||||
}
|
||||
return vectors, nil
|
||||
}
|
||||
|
||||
func replicateExtractEmbeddingVector(item interface{}) ([]float64, error) {
|
||||
switch v := item.(type) {
|
||||
case []interface{}:
|
||||
return replicateFloatsFromInterface(v)
|
||||
case map[string]interface{}:
|
||||
raw, ok := v["embedding"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'embedding' field; got keys %v", replicateKeys(v))
|
||||
}
|
||||
arr, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("embedding field is %T, expected array", raw)
|
||||
}
|
||||
return replicateFloatsFromInterface(arr)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported item type %T", item)
|
||||
}
|
||||
}
|
||||
|
||||
func replicateFloatsFromInterface(arr []interface{}) ([]float64, error) {
|
||||
floats := make([]float64, len(arr))
|
||||
for i, v := range arr {
|
||||
f, ok := v.(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("element %d is %T, expected number", i, v)
|
||||
}
|
||||
floats[i] = f
|
||||
}
|
||||
return floats, nil
|
||||
}
|
||||
|
||||
func replicateKeys(m map[string]interface{}) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Embed turns a list of texts into embedding vectors via Replicate's
|
||||
// prediction API. Replicate's embedding surface is the same async
|
||||
// /v1/predictions or /v1/models/{owner}/{name}/predictions endpoint
|
||||
// the chat path already uses: create a prediction with `Prefer: wait`
|
||||
// to skip the polling round-trip when possible, fall back to
|
||||
// waitForPrediction for predictions that don't finish in the wait
|
||||
// window. The driver targets the canonical Replicate embedding
|
||||
// schema (input.text / input.text_batch, output is an array of
|
||||
// {embedding: [floats]} objects); see replicateEmbedInput and
|
||||
// replicateEmbedOutputToVectors for details.
|
||||
func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
if len(texts) == 0 {
|
||||
return []EmbeddingData{}, nil
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || strings.TrimSpace(*modelName) == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
url, version, err := r.predictionEndpoint(apiConfig, *modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
input, err := replicateEmbedInput(texts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
prediction, err := r.createPrediction(ctx, url, version, input, false, *apiConfig.ApiKey, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !replicatePredictionSucceeded(prediction.Status) {
|
||||
return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status)
|
||||
}
|
||||
|
||||
return replicateEmbedOutputToVectors(prediction.Output, len(texts))
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
|
||||
Reference in New Issue
Block a user