From 7740ec6c95f725cb653ca370ebd0ee6942419f2d Mon Sep 17 00:00:00 2001 From: sxxtony <166789813+sxxtony@users.noreply.github.com> Date: Thu, 21 May 2026 10:11:56 +0300 Subject: [PATCH] Go: implement Embed (embeddings) in Replicate driver (#15073) ### What problem does this PR solve? `ReplicateModel.Embed` in `internal/entity/models/replicate.go` was a `"replicate, no such method"` stub. Tracking issue #14736 lists Replicate's embedding surface as not implemented. This PR wires it up against Replicate's documented embedding schema. Until this PR, a tenant who selected a Replicate embedding model got the sentinel error on every embed call. Co-authored-by: sxxtony Co-authored-by: Jin Hai --- conf/models/replicate.json | 7 ++ internal/entity/models/replicate.go | 148 +++++++++++++++++++++++++++- 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/conf/models/replicate.json b/conf/models/replicate.json index 91111351ad..a57ab2db2e 100644 --- a/conf/models/replicate.json +++ b/conf/models/replicate.json @@ -22,6 +22,13 @@ "model_types": [ "chat" ] + }, + { + "name": "replicate/all-mpnet-base-v2:b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + "max_tokens": 384, + "model_types": [ + "embedding" + ] } ] } diff --git a/internal/entity/models/replicate.go b/internal/entity/models/replicate.go index 0757b83250..c1c72ace54 100644 --- a/internal/entity/models/replicate.go +++ b/internal/entity/models/replicate.go @@ -566,8 +566,154 @@ func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error { return err } +// replicateEmbedInput shapes the request body for Replicate's standard +// embedding models (e.g. replicate/all-mpnet-base-v2). 
// Per the canonical Replicate embedding schema published in the model's
// openapi_schema, the request carries exactly one of two input fields:
//
//	text       — the single string to encode (sent when len(texts) == 1)
//	text_batch — a JSON-formatted list of strings (sent when len(texts) > 1)
//
// `text_batch` is declared `type: string` in that schema, so the
// JSON-encoded list is placed in the map as a string value rather than
// as a JSON array. An empty texts slice is rejected with an error.
// Models that use different field names (e.g.
// nateraw/bge-large-en-v1.5's `texts`) are not currently supported by
// this driver; tenants on those should consult Replicate's OpenAPI
// schema and configure a compatible model in
// conf/models/replicate.json.
func replicateEmbedInput(texts []string) (map[string]interface{}, error) {
	// Nothing to encode: report it instead of sending an empty request.
	if len(texts) == 0 {
		return nil, fmt.Errorf("replicate: texts is empty")
	}

	// A lone text goes in the plain `text` field, unencoded.
	if len(texts) == 1 {
		return map[string]interface{}{"text": texts[0]}, nil
	}

	// Two or more texts: serialize the list and ship it as one string.
	batchJSON, marshalErr := json.Marshal(texts)
	if marshalErr != nil {
		return nil, fmt.Errorf("failed to encode text_batch: %w", marshalErr)
	}
	payload := map[string]interface{}{"text_batch": string(batchJSON)}
	return payload, nil
}

// replicateEmbedOutputToVectors normalizes Replicate's two observed
// embedding-output shapes into []EmbeddingData aligned with the
// caller's input order:
//
//	[]{embedding: [floats]} — the documented Embedding schema used
//	                          by replicate/all-mpnet-base-v2
//	[][floats]              — bare nested array used by some
//	                          community models
//
// The driver rejects mismatched cardinality (output length != input
// length) and non-numeric vector entries rather than silently
// truncate or pad, matching the defensive posture the n1n / CometAPI
// drivers already use.
+func replicateEmbedOutputToVectors(output interface{}, n int) ([]EmbeddingData, error) { + outputs, ok := output.([]interface{}) + if !ok { + return nil, fmt.Errorf("replicate: expected output to be an array, got %T", output) + } + if len(outputs) != n { + return nil, fmt.Errorf("replicate: expected %d embeddings, got %d", n, len(outputs)) + } + + vectors := make([]EmbeddingData, n) + for i, item := range outputs { + vec, err := replicateExtractEmbeddingVector(item) + if err != nil { + return nil, fmt.Errorf("replicate: output[%d]: %w", i, err) + } + vectors[i] = EmbeddingData{Embedding: vec, Index: i} + } + return vectors, nil +} + +func replicateExtractEmbeddingVector(item interface{}) ([]float64, error) { + switch v := item.(type) { + case []interface{}: + return replicateFloatsFromInterface(v) + case map[string]interface{}: + raw, ok := v["embedding"] + if !ok { + return nil, fmt.Errorf("missing 'embedding' field; got keys %v", replicateKeys(v)) + } + arr, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("embedding field is %T, expected array", raw) + } + return replicateFloatsFromInterface(arr) + default: + return nil, fmt.Errorf("unsupported item type %T", item) + } +} + +func replicateFloatsFromInterface(arr []interface{}) ([]float64, error) { + floats := make([]float64, len(arr)) + for i, v := range arr { + f, ok := v.(float64) + if !ok { + return nil, fmt.Errorf("element %d is %T, expected number", i, v) + } + floats[i] = f + } + return floats, nil +} + +func replicateKeys(m map[string]interface{}) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// Embed turns a list of texts into embedding vectors via Replicate's +// prediction API. 
Replicate's embedding surface is the same async +// /v1/predictions or /v1/models/{owner}/{name}/predictions endpoint +// the chat path already uses: create a prediction with `Prefer: wait` +// to skip the polling round-trip when possible, fall back to +// waitForPrediction for predictions that don't finish in the wait +// window. The driver targets the canonical Replicate embedding +// schema (input.text / input.text_batch, output is an array of +// {embedding: [floats]} objects); see replicateEmbedInput and +// replicateEmbedOutputToVectors for details. func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { - return nil, fmt.Errorf("%s, no such method", r.Name()) + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || strings.TrimSpace(*modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + + url, version, err := r.predictionEndpoint(apiConfig, *modelName) + if err != nil { + return nil, err + } + + input, err := replicateEmbedInput(texts) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + prediction, err := r.createPrediction(ctx, url, version, input, false, *apiConfig.ApiKey, true) + if err != nil { + return nil, err + } + prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey) + if err != nil { + return nil, err + } + if !replicatePredictionSucceeded(prediction.Status) { + return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status) + } + + return replicateEmbedOutputToVectors(prediction.Output, len(texts)) } func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) 
(*RerankResponse, error) {