Go: implement Bedrock embeddings (#15543)

### What problem does this PR solve?

Fixes #15542.

AWS Bedrock support for the Go model provider layer was added in #15166,
but embedding support was intentionally left out of scope and
`BedrockModel.Embed(...)` still returned the `no such method` sentinel.
This PR implements Bedrock text embeddings under the umbrella provider
tracker #14736.

### What this PR includes

- `internal/entity/models/bedrock.go`: implement
`BedrockModel.Embed(...)` through Bedrock Runtime `InvokeModel` with
existing SigV4 auth, region resolution, and runtime URL helpers.
- Titan embeddings: supports `amazon.titan-embed-text-v1` and
`amazon.titan-embed-text-v2:0`; v2 forwards `EmbeddingConfig.Dimension`
as `dimensions` when provided, while v1 keeps the payload minimal.
- Cohere embeddings: supports `cohere.embed-english-v3`,
`cohere.embed-multilingual-v3`, and `cohere.embed-v4:0`; batches input
texts and maps returned vectors to RAGFlow `EmbeddingData` in input
order.
- `conf/models/bedrock.json`: adds the `embedding` URL suffix (`invoke`)
and Bedrock embedding model entries.
- `internal/entity/models/bedrock_test.go`: adds unit tests for Titan,
Cohere, typed Cohere responses, validation, empty input, unsupported
models, and HTTP error propagation.

Reference docs:

- Bedrock InvokeModel API:
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html
- Titan Text Embeddings:
https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
- Cohere Embed models on Bedrock:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

### How was this tested?

- [x] `jq empty conf/models/bedrock.json`
- [x] `git diff --check`
- [x] `go test ./internal/entity/models/... -run Bedrock -count=1`
- [x] `go test ./internal/entity/models/... -run '^$' -count=1`
- [x] `go test ./internal/entity/models/... -run Bedrock -race -count=1`

Note: `go test ./internal/entity/models/... -count=1` currently fails in
unrelated existing Astraflow coverage
(`TestAstraflowEmbedReturnsNoSuchMethod` panics in
`internal/entity/models/astraflow.go`). The Bedrock-specific tests and
compile-only package check pass.
This commit is contained in:
tmimmanuel
2026-06-04 19:26:32 -10:00
committed by GitHub
parent b8db200757
commit f78ef328bb
3 changed files with 395 additions and 8 deletions

View File

@@ -45,6 +45,7 @@ const (
defaultBedrockChatSuffix = "converse"
defaultBedrockStreamSuffix = "converse-stream"
defaultBedrockListModelsSuffix = "foundation-models"
defaultBedrockEmbeddingSuffix = "invoke"
bedrockStreamSuffixSuffix = "-stream"
)
@@ -309,6 +310,14 @@ func (b *BedrockModel) modelsSuffix() string {
return defaultBedrockListModelsSuffix
}
// embeddingSuffix resolves the operation segment used for embedding
// calls: the configured URLSuffix.Embedding when it is non-empty,
// otherwise defaultBedrockEmbeddingSuffix.
func (b *BedrockModel) embeddingSuffix() string {
	suffix := b.baseModel.URLSuffix.Embedding
	if suffix == "" {
		suffix = defaultBedrockEmbeddingSuffix
	}
	return suffix
}
// bedrockRuntimeURL builds the per-region runtime endpoint URL for a
// given Bedrock operation. Bedrock paths are deployment-style:
// {host}/model/{modelId}/{op}. Any user-supplied override in BaseURL
@@ -836,11 +845,184 @@ func (b *BedrockModel) CheckConnection(apiConfig *APIConfig) error {
return err
}
// Bedrock does not expose embeddings through the Converse API; the
// embeddings surface is per-model (Titan, Cohere), so each model family
// has its own InvokeModel request and response body.

// bedrockTitanEmbeddingRequest is the JSON body sent for Titan text
// embedding models.
type bedrockTitanEmbeddingRequest struct {
	// InputText is the single text to embed.
	InputText string `json:"inputText"`
	// Dimensions is left out of the JSON entirely when nil.
	Dimensions *int `json:"dimensions,omitempty"`
}
// bedrockTitanEmbeddingResponse captures the one field read from a Titan
// embedding response: the vector stored under the "embedding" key.
type bedrockTitanEmbeddingResponse struct {
	// Embedding is the vector for the single inputText of the request.
	Embedding []float64 `json:"embedding"`
}
// bedrockCohereEmbeddingRequest is the JSON body sent for Cohere embed
// models.
type bedrockCohereEmbeddingRequest struct {
	// Texts is the batch of inputs embedded by one call.
	Texts []string `json:"texts"`
	// InputType is serialized as "input_type".
	InputType string `json:"input_type"`
	// OutputDimension is left out of the JSON entirely when nil.
	OutputDimension *int `json:"output_dimension,omitempty"`
}
// bedrockCohereEmbeddingResponse keeps the "embeddings" value of a Cohere
// response undecoded so that decodeCohereEmbeddingVectors can accept
// either of the shapes it understands.
type bedrockCohereEmbeddingResponse struct {
	// Embeddings is the raw JSON found under the "embeddings" key.
	Embeddings json.RawMessage `json:"embeddings"`
}
// Embed sends text embedding requests through Bedrock Runtime
// InvokeModel. Titan's embedding API accepts one inputText per call,
// while Cohere accepts a texts batch and returns vectors in input
// order.
//
// An empty texts slice short-circuits to an empty, non-nil result before
// any other argument is validated. Otherwise apiConfig must carry an API
// key and modelName must be non-blank; the trimmed model name is routed
// by prefix ("amazon.titan-embed-text-" or "cohere.embed-") and any other
// model is rejected with an "unsupported embedding model" error.
func (b *BedrockModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
	// Nothing to embed: succeed without touching credentials or network.
	if len(texts) == 0 {
		return []EmbeddingData{}, nil
	}
	if apiConfig == nil || apiConfig.ApiKey == nil {
		return nil, fmt.Errorf("api key is required")
	}
	if modelName == nil || strings.TrimSpace(*modelName) == "" {
		return nil, fmt.Errorf("model name is required")
	}
	modelID := strings.TrimSpace(*modelName)

	key, err := parseBedrockKey(*apiConfig.ApiKey)
	if err != nil {
		return nil, err
	}
	region, err := resolveBedrockRegion(apiConfig, key)
	if err != nil {
		return nil, err
	}

	// One deadline covers credential resolution and every InvokeModel
	// call issued for this Embed invocation.
	ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
	defer cancel()

	creds, err := resolveBedrockCredentials(ctx, key, region)
	if err != nil {
		return nil, err
	}

	switch {
	case strings.HasPrefix(modelID, "amazon.titan-embed-text-"):
		return b.embedTitan(ctx, modelID, texts, region, creds, embeddingConfig)
	case strings.HasPrefix(modelID, "cohere.embed-"):
		return b.embedCohere(ctx, modelID, texts, region, creds, embeddingConfig)
	default:
		return nil, fmt.Errorf("bedrock: unsupported embedding model %q", modelID)
	}
}
// invokeEmbeddingModel marshals body to JSON, POSTs it to the runtime URL
// built from region, modelID and the embedding suffix, signs the request
// with creds, and returns the raw response body. Any status other than
// 200 OK is turned into an error that includes the status code and the
// response body text.
func (b *BedrockModel) invokeEmbeddingModel(ctx context.Context, modelID string, body interface{}, region string, creds awssdk.Credentials) ([]byte, error) {
	payload, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("bedrock: marshal embedding request: %w", err)
	}

	endpoint := b.bedrockRuntimeURL(region, modelID, b.embeddingSuffix())
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("bedrock: build embedding request: %w", err)
	}
	for _, header := range []string{"Content-Type", "Accept"} {
		httpReq.Header.Set(header, "application/json")
	}

	// Headers are set before signing, as in the original statement order.
	if signErr := signBedrockRequest(ctx, httpReq, payload, creds, bedrockRuntimeService, region); signErr != nil {
		return nil, signErr
	}

	httpResp, err := b.baseModel.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("bedrock: send embedding request: %w", err)
	}
	defer httpResp.Body.Close()

	// The body is read before the status check so failures can echo it.
	data, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, fmt.Errorf("bedrock: read embedding response: %w", err)
	}
	if httpResp.StatusCode == http.StatusOK {
		return data, nil
	}
	return nil, fmt.Errorf("bedrock: embedding request failed with status %d: %s", httpResp.StatusCode, string(data))
}
// embedTitan embeds texts one at a time, issuing one InvokeModel call per
// input, and returns the vectors tagged with their input index. The
// configured dimension is forwarded only when it is positive and modelID
// starts with "amazon.titan-embed-text-v2". The first failed call,
// unparsable response, or empty vector aborts the whole operation.
func (b *BedrockModel) embedTitan(ctx context.Context, modelID string, texts []string, region string, creds awssdk.Credentials, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
	// The dimension choice is the same for every input, so decide it once.
	var dims *int
	if embeddingConfig != nil && embeddingConfig.Dimension > 0 && strings.HasPrefix(modelID, "amazon.titan-embed-text-v2") {
		dims = &embeddingConfig.Dimension
	}

	out := make([]EmbeddingData, 0, len(texts))
	for idx := range texts {
		payload := bedrockTitanEmbeddingRequest{
			InputText:  texts[idx],
			Dimensions: dims,
		}
		raw, err := b.invokeEmbeddingModel(ctx, modelID, payload, region, creds)
		if err != nil {
			return nil, err
		}

		var decoded bedrockTitanEmbeddingResponse
		if err := json.Unmarshal(raw, &decoded); err != nil {
			return nil, fmt.Errorf("bedrock: parse Titan embedding response: %w", err)
		}
		if len(decoded.Embedding) == 0 {
			return nil, fmt.Errorf("bedrock: Titan embedding response missing embedding for input index %d", idx)
		}
		out = append(out, EmbeddingData{Embedding: decoded.Embedding, Index: idx})
	}
	return out, nil
}
// embedCohere embeds texts with a Cohere embed model through InvokeModel
// and returns one EmbeddingData per input, with Index set to the input's
// position in texts.
//
// Inputs are sent in consecutive batches of at most maxTextsPerCall so
// that large inputs are not rejected by the per-call texts limit; inputs
// that fit in one batch produce exactly one request, as before. Every
// request uses input_type "search_document". The configured dimension is
// forwarded as output_dimension only when it is positive and modelID
// starts with "cohere.embed-v4".
//
// An error is returned as soon as any batch fails, cannot be parsed,
// returns a vector count different from the batch size, or contains an
// empty vector; no partial result is returned.
func (b *BedrockModel) embedCohere(ctx context.Context, modelID string, texts []string, region string, creds awssdk.Credentials, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
	// Bedrock documents a maximum of 96 texts per Cohere Embed call.
	const maxTextsPerCall = 96

	// The dimension choice is the same for every batch, so decide it once.
	var outputDimension *int
	if embeddingConfig != nil && embeddingConfig.Dimension > 0 && strings.HasPrefix(modelID, "cohere.embed-v4") {
		outputDimension = &embeddingConfig.Dimension
	}

	embeddings := make([]EmbeddingData, 0, len(texts))
	for start := 0; start < len(texts); start += maxTextsPerCall {
		end := start + maxTextsPerCall
		if end > len(texts) {
			end = len(texts)
		}
		batch := texts[start:end]

		req := bedrockCohereEmbeddingRequest{
			Texts:           batch,
			InputType:       "search_document",
			OutputDimension: outputDimension,
		}
		respBody, err := b.invokeEmbeddingModel(ctx, modelID, req, region, creds)
		if err != nil {
			return nil, err
		}

		var parsed bedrockCohereEmbeddingResponse
		if err := json.Unmarshal(respBody, &parsed); err != nil {
			return nil, fmt.Errorf("bedrock: parse Cohere embedding response: %w", err)
		}
		vectors, err := decodeCohereEmbeddingVectors(parsed.Embeddings)
		if err != nil {
			return nil, err
		}
		if len(vectors) != len(batch) {
			return nil, fmt.Errorf("bedrock: Cohere returned %d embeddings for %d inputs", len(vectors), len(batch))
		}

		for i, vector := range vectors {
			// Report and record positions relative to the full input,
			// not to the current batch.
			inputIndex := start + i
			if len(vector) == 0 {
				return nil, fmt.Errorf("bedrock: Cohere embedding response missing embedding for input index %d", inputIndex)
			}
			embeddings = append(embeddings, EmbeddingData{
				Embedding: vector,
				Index:     inputIndex,
			})
		}
	}
	return embeddings, nil
}
// decodeCohereEmbeddingVectors extracts float vectors from the raw
// "embeddings" value of a Cohere response. Two shapes are accepted: a
// bare array of vectors, or an object whose "float" key holds the array
// of vectors. Empty input, JSON matching neither shape, and an object
// without a "float" key each produce an error.
func decodeCohereEmbeddingVectors(raw json.RawMessage) ([][]float64, error) {
	if len(raw) == 0 {
		return nil, fmt.Errorf("bedrock: Cohere embedding response missing embeddings")
	}

	// First try the bare-array shape.
	var flat [][]float64
	if json.Unmarshal(raw, &flat) == nil {
		return flat, nil
	}

	// Otherwise expect an object keyed by embedding type.
	var typed map[string][][]float64
	if err := json.Unmarshal(raw, &typed); err != nil {
		return nil, fmt.Errorf("bedrock: parse Cohere embeddings: %w", err)
	}
	if floats, found := typed["float"]; found {
		return floats, nil
	}
	return nil, fmt.Errorf("bedrock: Cohere embedding response missing float embeddings")
}
// Rerank is not exposed by Bedrock.

View File

@@ -646,11 +646,180 @@ func TestLookupBedrockEventHeader(t *testing.T) {
}
}
func TestBedrockEmbedReturnsNoSuchMethod(t *testing.T) {
func TestBedrockTitanEmbedHappyPath(t *testing.T) {
var seenInputs []string
srv := newBedrockServer(t, http.MethodPost,
"/model/amazon.titan-embed-text-v2:0/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockTitanEmbeddingRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
seenInputs = append(seenInputs, body.InputText)
if body.Dimensions == nil || *body.Dimensions != 256 {
t.Errorf("dimensions=%v, want 256", body.Dimensions)
}
w.Header().Set("Content-Type", "application/json")
if body.InputText == "alpha" {
_, _ = w.Write([]byte(`{"embedding":[0.1,0.2]}`))
} else {
_, _ = w.Write([]byte(`{"embedding":[0.3,0.4]}`))
}
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "amazon.titan-embed-text-v2:0"
got, err := m.Embed(&model, []string{"alpha", "beta"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 256})
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(seenInputs) != 2 || seenInputs[0] != "alpha" || seenInputs[1] != "beta" {
t.Fatalf("seen inputs=%v", seenInputs)
}
if len(got) != 2 {
t.Fatalf("len(got)=%d want 2", len(got))
}
if got[0].Index != 0 || got[0].Embedding[0] != 0.1 || got[1].Index != 1 || got[1].Embedding[0] != 0.3 {
t.Errorf("embeddings=%+v", got)
}
}
// TestBedrockTitanV1OmitsDimension checks that a configured dimension is
// not sent to Titan v1: the request body must not mention "dimensions".
func TestBedrockTitanV1OmitsDimension(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		payload, _ := io.ReadAll(r.Body)
		if text := string(payload); strings.Contains(text, "dimensions") {
			t.Errorf("Titan v1 body must not include dimensions: %s", text)
		}
		_, _ = w.Write([]byte(`{"embedding":[0.1,0.2]}`))
	}
	srv := newBedrockServer(t, http.MethodPost,
		"/model/amazon.titan-embed-text-v1/invoke",
		handler)
	defer srv.Close()

	m := newBedrockForTest(srv.URL)
	apiKey := validBedrockKey()
	modelID := "amazon.titan-embed-text-v1"
	_, err := m.Embed(&modelID, []string{"alpha"}, &APIConfig{ApiKey: &apiKey}, &EmbeddingConfig{Dimension: 256})
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
}
func TestBedrockCohereEmbedHappyPath(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/cohere.embed-english-v3/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockCohereEmbeddingRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
if len(body.Texts) != 2 || body.Texts[0] != "first" || body.Texts[1] != "second" {
t.Errorf("texts=%v", body.Texts)
}
if body.InputType != "search_document" {
t.Errorf("input_type=%q want search_document", body.InputType)
}
if body.OutputDimension != nil {
t.Errorf("v3 output_dimension=%v, want omitted", *body.OutputDimension)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"embeddings":[[1,2],[3,4]]}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "cohere.embed-english-v3"
got, err := m.Embed(&model, []string{"first", "second"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 128})
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(got) != 2 || got[0].Index != 0 || got[0].Embedding[0] != 1 || got[1].Index != 1 || got[1].Embedding[0] != 3 {
t.Errorf("embeddings=%+v", got)
}
}
func TestBedrockCohereV4ForwardsDimensionAndParsesTypedResponse(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/cohere.embed-v4:0/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockCohereEmbeddingRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
if body.OutputDimension == nil || *body.OutputDimension != 512 {
t.Errorf("output_dimension=%v, want 512", body.OutputDimension)
}
_, _ = w.Write([]byte(`{"embeddings":{"float":[[0.5,0.6]]}}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "cohere.embed-v4:0"
got, err := m.Embed(&model, []string{"first"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 512})
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(got) != 1 || got[0].Index != 0 || got[0].Embedding[0] != 0.5 {
t.Errorf("embeddings=%+v", got)
}
}
// TestBedrockEmbedShortCircuitsEmptyInput checks that Embed with no texts
// succeeds with an empty result even when every other argument is nil.
func TestBedrockEmbedShortCircuitsEmptyInput(t *testing.T) {
	m := newBedrockForTest("http://unused")
	vectors, err := m.Embed(nil, nil, nil, nil)
	if err != nil {
		t.Fatalf("Embed empty: %v", err)
	}
	if n := len(vectors); n != 0 {
		t.Errorf("len(got)=%d want 0", n)
	}
}
func TestBedrockEmbedRequiresAPIKeyAndModel(t *testing.T) {
m := newBedrockForTest("http://unused")
model := "x"
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Embed: want no-such-method, got %v", err)
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("Embed: want api-key error, got %v", err)
}
key := validBedrockKey()
blank := " "
if _, err := m.Embed(&blank, []string{"a"}, &APIConfig{ApiKey: &key}, nil); err == nil || !strings.Contains(err.Error(), "model name is required") {
t.Errorf("Embed: want model-required error, got %v", err)
}
}
// TestBedrockEmbedRejectsUnsupportedModel checks that a model ID outside
// the Titan and Cohere embedding prefixes yields an "unsupported
// embedding model" error.
func TestBedrockEmbedRejectsUnsupportedModel(t *testing.T) {
	m := newBedrockForTest("http://unused")
	apiKey := validBedrockKey()
	modelID := "anthropic.claude-3-haiku-20240307-v1:0"

	_, err := m.Embed(&modelID, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
	rejected := err != nil && strings.Contains(err.Error(), "unsupported embedding model")
	if !rejected {
		t.Errorf("Embed: want unsupported-model error, got %v", err)
	}
}
// TestBedrockEmbedPropagatesHTTPError checks that a non-200 response
// surfaces as an error carrying both the status code and the response
// body text.
func TestBedrockEmbedPropagatesHTTPError(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte(`{"message":"bad input"}`))
	}
	srv := newBedrockServer(t, http.MethodPost,
		"/model/amazon.titan-embed-text-v2:0/invoke",
		handler)
	defer srv.Close()

	m := newBedrockForTest(srv.URL)
	apiKey := validBedrockKey()
	modelID := "amazon.titan-embed-text-v2:0"

	_, err := m.Embed(&modelID, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
	propagated := err != nil &&
		strings.Contains(err.Error(), "400") &&
		strings.Contains(err.Error(), "bad input")
	if !propagated {
		t.Errorf("Embed: want HTTP error with body, got %v", err)
	}
}