Go: implement Embed (embeddings) in Replicate driver (#15073)

### What problem does this PR solve?

`ReplicateModel.Embed` in `internal/entity/models/replicate.go` was a
`"replicate, no such method"` stub. Tracking issue #14736 lists
Replicate's embedding surface as not implemented. This PR wires it up
against Replicate's documented embedding schema.

Until this PR, a tenant who selected a Replicate embedding model got the
sentinel error on every embed call.

Co-authored-by: sxxtony <sxxtony@users.noreply.github.com>
Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
sxxtony
2026-05-21 10:11:56 +03:00
committed by Jin Hai
parent 3e5b11a523
commit 7740ec6c95
2 changed files with 154 additions and 1 deletions

View File

@@ -22,6 +22,13 @@
"model_types": [
"chat"
]
},
{
"name": "replicate/all-mpnet-base-v2:b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
"max_tokens": 384,
"model_types": [
"embedding"
]
}
]
}

View File

@@ -566,8 +566,154 @@ func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error {
return err
}
// replicateEmbedInput shapes the request body for Replicate's standard
// embedding models (e.g. replicate/all-mpnet-base-v2). Per the
// canonical Replicate embedding schema published in the model's
// openapi_schema, the two input fields are:
//
// text — single string to encode (used when len(texts) == 1)
// text_batch — JSON-formatted list of strings (used when len > 1)
//
// `text_batch` is `type: string` in the schema, so the JSON-encoded
// list itself is sent as a string value, NOT as a JSON array. Models
// that use different field names (e.g. nateraw/bge-large-en-v1.5's
// `texts`) are not currently supported by this driver; tenants on
// those should consult Replicate's OpenAPI schema and configure a
// compatible model in conf/models/replicate.json.
func replicateEmbedInput(texts []string) (map[string]interface{}, error) {
switch len(texts) {
case 0:
return nil, fmt.Errorf("replicate: texts is empty")
case 1:
return map[string]interface{}{"text": texts[0]}, nil
default:
encoded, err := json.Marshal(texts)
if err != nil {
return nil, fmt.Errorf("failed to encode text_batch: %w", err)
}
return map[string]interface{}{"text_batch": string(encoded)}, nil
}
}
// replicateEmbedOutputToVectors normalizes Replicate's two observed
// embedding-output shapes into []EmbeddingData aligned with the
// caller's input order:
//
// []{embedding: [floats]} — the documented Embedding schema used
// by replicate/all-mpnet-base-v2
// [][floats] — bare nested array used by some
// community models
//
// The driver rejects mismatched cardinality (output length != input
// length) and non-numeric vector entries rather than silently
// truncate or pad, matching the defensive posture the n1n / CometAPI
// drivers already use.
func replicateEmbedOutputToVectors(output interface{}, n int) ([]EmbeddingData, error) {
outputs, ok := output.([]interface{})
if !ok {
return nil, fmt.Errorf("replicate: expected output to be an array, got %T", output)
}
if len(outputs) != n {
return nil, fmt.Errorf("replicate: expected %d embeddings, got %d", n, len(outputs))
}
vectors := make([]EmbeddingData, n)
for i, item := range outputs {
vec, err := replicateExtractEmbeddingVector(item)
if err != nil {
return nil, fmt.Errorf("replicate: output[%d]: %w", i, err)
}
vectors[i] = EmbeddingData{Embedding: vec, Index: i}
}
return vectors, nil
}
func replicateExtractEmbeddingVector(item interface{}) ([]float64, error) {
switch v := item.(type) {
case []interface{}:
return replicateFloatsFromInterface(v)
case map[string]interface{}:
raw, ok := v["embedding"]
if !ok {
return nil, fmt.Errorf("missing 'embedding' field; got keys %v", replicateKeys(v))
}
arr, ok := raw.([]interface{})
if !ok {
return nil, fmt.Errorf("embedding field is %T, expected array", raw)
}
return replicateFloatsFromInterface(arr)
default:
return nil, fmt.Errorf("unsupported item type %T", item)
}
}
func replicateFloatsFromInterface(arr []interface{}) ([]float64, error) {
floats := make([]float64, len(arr))
for i, v := range arr {
f, ok := v.(float64)
if !ok {
return nil, fmt.Errorf("element %d is %T, expected number", i, v)
}
floats[i] = f
}
return floats, nil
}
func replicateKeys(m map[string]interface{}) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
// Embed turns a list of texts into embedding vectors via Replicate's
// prediction API. Replicate's embedding surface is the same async
// /v1/predictions or /v1/models/{owner}/{name}/predictions endpoint
// the chat path already uses: create a prediction with `Prefer: wait`
// to skip the polling round-trip when possible, fall back to
// waitForPrediction for predictions that don't finish in the wait
// window. The driver targets the canonical Replicate embedding
// schema (input.text / input.text_batch, output is an array of
// {embedding: [floats]} objects); see replicateEmbedInput and
// replicateEmbedOutputToVectors for details.
func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
return nil, fmt.Errorf("%s, no such method", r.Name())
if len(texts) == 0 {
return []EmbeddingData{}, nil
}
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, fmt.Errorf("api key is required")
}
if modelName == nil || strings.TrimSpace(*modelName) == "" {
return nil, fmt.Errorf("model name is required")
}
url, version, err := r.predictionEndpoint(apiConfig, *modelName)
if err != nil {
return nil, err
}
input, err := replicateEmbedInput(texts)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
prediction, err := r.createPrediction(ctx, url, version, input, false, *apiConfig.ApiKey, true)
if err != nil {
return nil, err
}
prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey)
if err != nil {
return nil, err
}
if !replicatePredictionSucceeded(prediction.Status) {
return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status)
}
return replicateEmbedOutputToVectors(prediction.Output, len(texts))
}
func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {