Files
ragflow/internal/entity/models/bedrock_test.go
Jack 2f99d52fb5 fix(ci): re-enable Go tests and fix compilation errors after ListModels signature change (#15862)
## Summary

This PR re-enables the Go test steps in CI that were previously
commented out, and fixes all compilation errors that have accumulated in
`internal/entity/models/` since the `ListModels` return type was changed
from `[]string` to `[]ListModelResponse`.

## Changes

### CI (`.github/workflows/tests.yml`)
- Re-enable **Prepare test resources** step (clones resource repo with
WordNet data)
- Re-enable **Test Go packages** step (runs `go test ./internal/...`)
- Fix resource path race condition by using
`/tmp/resource-${GITHUB_RUN_ID}` instead of `/tmp/resource`
- Exclude `/cli` package from Go tests (contains `main` redeclarations)

### Test fixes (16 model provider test files)
All errors were caused by the upstream change from `[]string` to
`[]ListModelResponse` in the `ListModels` interface:

- Add `joinModelNames` test helper to extract `.Name` from
`[]ListModelResponse` slices
- `strings.Join(models, ",")` → `joinModelNames(models, ",")` (11 files)
- `ids[i] != "..."` → `ids[i].Name != "..."` (cometapi, mistral)
- `got[i] != want[i]` → `got[i].Name != want[i]` (bedrock)
- `[]string` return types → `[]ListModelResponse` (google)

### Pre-existing bugs in model_test.go
Bugs introduced by the upstream `entity/` → `entity/models/` directory
rename:

- Add missing `pm := GetProviderManager()` calls in 3 test functions
- Fix `InitProviderManager` signature (`_, err :=` → `err :=`)
- Fix `MaxTokens` `*int` dereference (6 comparisons)
- Fix `readProviderConfig` relative path (3 levels up instead of 2)

### model.go
- Add `findRepoRoot()` to make `conf/all_models.json` resolution work
from any CWD, fixing `TestSiliconFlowProviderConfigLoadsLatestProModels`

### Test validation

```bash
go build ./internal/...      # 
go test ./internal/entity/models/... -count=1  #  all pass
```

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-09 21:12:15 +08:00

868 lines
30 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
)
// validBedrockKey returns a JSON API-key blob using the access_key_secret
// auth mode. We use static credentials in tests so SigV4 has well-known
// material to sign with, and our httptest server simply accepts any
// signature rather than re-verifying it.
func validBedrockKey() string {
return `{"auth_mode":"access_key_secret","bedrock_region":"us-east-1","bedrock_ak":"AKIATEST","bedrock_sk":"secret-test"}`
}
// newBedrockForTest constructs a BedrockModel whose runtime and
// control endpoints are both overridden to point at the supplied
// httptest base URL. The override map keys ("us-east-1" and
// "control:us-east-1") match the lookups in bedrockRuntimeURL and
// bedrockControlURL respectively.
func newBedrockForTest(baseURL string) *BedrockModel {
return NewBedrockModel(
map[string]string{
"us-east-1": baseURL,
"control:us-east-1": baseURL,
},
URLSuffix{Chat: "converse", Models: "foundation-models"},
)
}
func TestBedrockName(t *testing.T) {
if got := newBedrockForTest("http://unused").Name(); got != "bedrock" {
t.Errorf("Name()=%q, want %q", got, "bedrock")
}
}
func TestParseBedrockKeyRejectsEmpty(t *testing.T) {
if _, err := parseBedrockKey(""); err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("empty key: want api-key error, got %v", err)
}
if _, err := parseBedrockKey(" "); err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("whitespace key: want api-key error, got %v", err)
}
}
func TestParseBedrockKeyRejectsNonJSON(t *testing.T) {
if _, err := parseBedrockKey("not-json"); err == nil || !strings.Contains(err.Error(), "JSON object") {
t.Errorf("non-JSON key: want JSON-object error, got %v", err)
}
}
func TestParseBedrockKeyRejectsMissingAuthMode(t *testing.T) {
if _, err := parseBedrockKey(`{"bedrock_region":"us-east-1"}`); err == nil || !strings.Contains(err.Error(), "auth_mode") {
t.Errorf("missing auth_mode: want explicit error, got %v", err)
}
}
func TestParseBedrockKeyRejectsUnknownAuthMode(t *testing.T) {
if _, err := parseBedrockKey(`{"auth_mode":"oauth"}`); err == nil || !strings.Contains(err.Error(), "unsupported auth_mode") {
t.Errorf("unknown auth_mode: want unsupported error, got %v", err)
}
}
func TestParseBedrockKeyAccessKeySecretValidates(t *testing.T) {
// Both AK and SK must be present; one without the other is rejected
// so a misconfigured tenant fails fast instead of producing an
// unsigned request the server then rejects opaquely.
missingSK := `{"auth_mode":"access_key_secret","bedrock_region":"us-east-1","bedrock_ak":"AKIATEST"}`
if _, err := parseBedrockKey(missingSK); err == nil || !strings.Contains(err.Error(), "bedrock_ak and bedrock_sk") {
t.Errorf("missing SK: want ak/sk error, got %v", err)
}
missingAK := `{"auth_mode":"access_key_secret","bedrock_region":"us-east-1","bedrock_sk":"secret"}`
if _, err := parseBedrockKey(missingAK); err == nil || !strings.Contains(err.Error(), "bedrock_ak and bedrock_sk") {
t.Errorf("missing AK: want ak/sk error, got %v", err)
}
}
func TestParseBedrockKeyIAMRoleRequiresARN(t *testing.T) {
if _, err := parseBedrockKey(`{"auth_mode":"iam_role","bedrock_region":"us-east-1"}`); err == nil || !strings.Contains(err.Error(), "aws_role_arn") {
t.Errorf("iam_role no ARN: want aws_role_arn error, got %v", err)
}
}
func TestParseBedrockKeyAssumeRoleAcceptsBareConfig(t *testing.T) {
// assume_role intentionally delegates to the default credential
// chain, so parseBedrockKey must accept a blob with no AK/SK/ARN.
if _, err := parseBedrockKey(`{"auth_mode":"assume_role","bedrock_region":"us-east-1"}`); err != nil {
t.Errorf("assume_role: want no error, got %v", err)
}
}
func TestResolveBedrockRegionPrefersAPIConfig(t *testing.T) {
key := &bedrockKey{Region: "us-east-1"}
override := "eu-west-1"
got, err := resolveBedrockRegion(&APIConfig{Region: &override}, key)
if err != nil || got != "eu-west-1" {
t.Errorf("got region=%q err=%v, want eu-west-1", got, err)
}
}
func TestResolveBedrockRegionFallsBackToKey(t *testing.T) {
key := &bedrockKey{Region: "us-east-1"}
got, err := resolveBedrockRegion(&APIConfig{}, key)
if err != nil || got != "us-east-1" {
t.Errorf("got region=%q err=%v, want us-east-1", got, err)
}
}
func TestResolveBedrockRegionRequiresOne(t *testing.T) {
key := &bedrockKey{}
if _, err := resolveBedrockRegion(&APIConfig{}, key); err == nil || !strings.Contains(err.Error(), "region is required") {
t.Errorf("no region: want region-required error, got %v", err)
}
}
func TestBuildConverseRequestExtractsSystem(t *testing.T) {
req, err := buildConverseRequest([]Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
}, nil)
if err != nil {
t.Fatalf("build: %v", err)
}
if len(req.System) != 1 || req.System[0].Text != "You are helpful." {
t.Errorf("system block wrong: %+v", req.System)
}
if len(req.Messages) != 1 || req.Messages[0].Role != "user" {
t.Errorf("messages wrong: %+v", req.Messages)
}
}
func TestBuildConverseRequestRejectsInterleavedSystem(t *testing.T) {
_, err := buildConverseRequest([]Message{
{Role: "user", Content: "Hi"},
{Role: "system", Content: "Mid-conversation system prompt"},
}, nil)
if err == nil || !strings.Contains(err.Error(), "system messages must come before") {
t.Errorf("interleaved system: want explicit error, got %v", err)
}
}
func TestBuildConverseRequestRejectsUnsupportedRole(t *testing.T) {
_, err := buildConverseRequest([]Message{{Role: "function", Content: "x"}}, nil)
if err == nil || !strings.Contains(err.Error(), "unsupported role") {
t.Errorf("unsupported role: want explicit error, got %v", err)
}
}
func TestBuildConverseRequestRejectsEmpty(t *testing.T) {
_, err := buildConverseRequest(nil, nil)
if err == nil || !strings.Contains(err.Error(), "messages is empty") {
t.Errorf("empty: want messages-empty error, got %v", err)
}
}
func TestBuildConverseRequestRejectsOnlySystem(t *testing.T) {
_, err := buildConverseRequest([]Message{{Role: "system", Content: "x"}}, nil)
if err == nil || !strings.Contains(err.Error(), "no user/assistant") {
t.Errorf("only system: want no-turns error, got %v", err)
}
}
func TestBuildConverseRequestRejectsMultimodalForNow(t *testing.T) {
// The text-only path is the only one this PR ships. A multimodal
// content array must fail loudly so the operator gets a clear
// migration path rather than a silently-truncated request.
_, err := buildConverseRequest([]Message{{Role: "user", Content: []interface{}{
map[string]interface{}{"type": "text", "text": "Hi"},
}}}, nil)
if err == nil || !strings.Contains(err.Error(), "only string Message.Content") {
t.Errorf("multimodal: want text-only error, got %v", err)
}
}
func TestMapChatConfigToInferenceForwardsAllFields(t *testing.T) {
mt := 4096
temp := 0.5
topP := 0.9
stop := []string{"END"}
inf := mapChatConfigToInference(&ChatConfig{
MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop,
})
if inf == nil {
t.Fatal("expected non-nil inferenceConfig")
}
if inf.MaxTokens == nil || *inf.MaxTokens != 4096 {
t.Errorf("maxTokens=%v", inf.MaxTokens)
}
if inf.Temperature == nil || *inf.Temperature != 0.5 {
t.Errorf("temperature=%v", inf.Temperature)
}
if inf.TopP == nil || *inf.TopP != 0.9 {
t.Errorf("topP=%v", inf.TopP)
}
if len(inf.StopSequences) != 1 || inf.StopSequences[0] != "END" {
t.Errorf("stopSequences=%v", inf.StopSequences)
}
}
func TestMapChatConfigToInferenceReturnsNilWhenUnset(t *testing.T) {
if got := mapChatConfigToInference(nil); got != nil {
t.Errorf("nil cfg: want nil, got %+v", got)
}
if got := mapChatConfigToInference(&ChatConfig{}); got != nil {
t.Errorf("empty cfg: want nil, got %+v", got)
}
}
func TestExtractAnswerConcatenatesContentBlocks(t *testing.T) {
resp := &bedrockConverseResponse{}
resp.Output.Message.Content = []bedrockContentBlock{
{Text: "Hello "},
{Text: "world"},
}
if got := extractAnswer(resp); got != "Hello world" {
t.Errorf("extractAnswer=%q", got)
}
}
func TestBedrockRuntimeURLUsesOverride(t *testing.T) {
b := NewBedrockModel(map[string]string{"us-east-1": "https://proxy.example.com/bedrock"}, URLSuffix{})
got := b.bedrockRuntimeURL("us-east-1", "anthropic.claude-3-haiku-20240307-v1:0", "converse")
want := "https://proxy.example.com/bedrock/model/anthropic.claude-3-haiku-20240307-v1:0/converse"
if got != want {
t.Errorf("override URL=%q, want %q", got, want)
}
}
func TestBedrockRuntimeURLFallsBackToAWS(t *testing.T) {
b := NewBedrockModel(nil, URLSuffix{})
got := b.bedrockRuntimeURL("eu-west-1", "amazon.nova-lite-v1:0", "converse")
want := "https://bedrock-runtime.eu-west-1.amazonaws.com/model/amazon.nova-lite-v1:0/converse"
if got != want {
t.Errorf("default URL=%q, want %q", got, want)
}
}
func TestBedrockControlURLFallsBackToAWS(t *testing.T) {
b := NewBedrockModel(nil, URLSuffix{})
got := b.bedrockControlURL("us-west-2", "foundation-models")
want := "https://bedrock.us-west-2.amazonaws.com/foundation-models"
if got != want {
t.Errorf("control URL=%q, want %q", got, want)
}
}
func TestJoinBedrockPathHandlesTrailingSlashes(t *testing.T) {
got := joinBedrockPath("https://x.example.com/", "model", "/m/", "converse")
want := "https://x.example.com/model/m/converse"
if got != want {
t.Errorf("joinBedrockPath=%q, want %q", got, want)
}
}
func TestExtractBedrockDeltaTextHappyPath(t *testing.T) {
got, err := extractBedrockDeltaText([]byte(`{"delta":{"text":"hi"},"contentBlockIndex":0}`))
if err != nil {
t.Fatalf("extract: %v", err)
}
if got != "hi" {
t.Errorf("text=%q", got)
}
}
func TestExtractBedrockDeltaTextSkipsEmpty(t *testing.T) {
// A toolUse-only delta has no text field; the helper must return
// "" with no error so the streaming loop simply skips the frame.
got, err := extractBedrockDeltaText([]byte(`{"delta":{"toolUse":{}},"contentBlockIndex":0}`))
if err != nil {
t.Fatalf("extract: %v", err)
}
if got != "" {
t.Errorf("expected empty text, got %q", got)
}
}
func TestExtractBedrockDeltaTextRejectsMalformed(t *testing.T) {
_, err := extractBedrockDeltaText([]byte(`{not-json}`))
if err == nil || !strings.Contains(err.Error(), "invalid contentBlockDelta") {
t.Errorf("malformed: want explicit error, got %v", err)
}
}
// newBedrockServer returns an httptest.Server that asserts the
// request method, path, and Authorization header before delegating
// to the supplied handler. We accept any "AWS4-HMAC-SHA256 ..."
// header rather than re-verify the signature: SigV4 correctness is
// the SDK's responsibility and is covered by its own tests.
func newBedrockServer(t *testing.T, wantMethod, wantPath string, handler http.HandlerFunc) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != wantMethod {
t.Errorf("method=%s, want %s", r.Method, wantMethod)
}
if r.URL.Path != wantPath {
t.Errorf("path=%s, want %s", r.URL.Path, wantPath)
}
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "AWS4-HMAC-SHA256 ") {
t.Errorf("Authorization=%q, want SigV4 prefix", auth)
}
handler(w, r)
}))
}
func TestBedrockChatHappyPath(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/anthropic.claude-3-haiku-20240307-v1:0/converse",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockConverseRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
if len(body.Messages) != 1 || body.Messages[0].Role != "user" {
t.Errorf("messages wrong: %+v", body.Messages)
}
resp := bedrockConverseResponse{}
resp.Output.Message.Role = "assistant"
resp.Output.Message.Content = []bedrockContentBlock{{Text: "pong"}}
resp.StopReason = "end_turn"
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
resp, err := m.ChatWithMessages("anthropic.claude-3-haiku-20240307-v1:0",
[]Message{{Role: "user", Content: "ping"}},
&APIConfig{ApiKey: &key}, nil)
if err != nil {
t.Fatalf("Chat: %v", err)
}
if resp.Answer == nil || resp.ReasonContent == nil {
t.Fatalf("answer/reason must be non-nil pointers, got %v / %v", resp.Answer, resp.ReasonContent)
}
if *resp.Answer != "pong" {
t.Errorf("answer=%q want pong", *resp.Answer)
}
if *resp.ReasonContent != "" {
t.Errorf("reason=%q want empty", *resp.ReasonContent)
}
}
func TestBedrockChatRequiresAPIKey(t *testing.T) {
m := newBedrockForTest("http://unused")
_, err := m.ChatWithMessages("m", []Message{{Role: "user", Content: "x"}}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("want api-key error, got %v", err)
}
}
func TestBedrockChatRequiresModelID(t *testing.T) {
m := newBedrockForTest("http://unused")
key := validBedrockKey()
_, err := m.ChatWithMessages("", []Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key}, nil)
if err == nil || !strings.Contains(err.Error(), "model id is required") {
t.Errorf("want model-required error, got %v", err)
}
}
func TestBedrockChatPropagatesHTTPError(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/m/converse",
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"InvalidSignatureException"}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
_, err := m.ChatWithMessages("m",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key}, nil)
if err == nil || !strings.Contains(err.Error(), "401") {
t.Errorf("want 401 propagated, got %v", err)
}
if err != nil && !strings.Contains(err.Error(), "InvalidSignatureException") {
t.Errorf("want body included in error, got %v", err)
}
}
func TestBedrockListModelsParsesCatalog(t *testing.T) {
srv := newBedrockServer(t, http.MethodGet,
"/foundation-models",
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"modelSummaries": [
{"modelId":"anthropic.claude-3-haiku-20240307-v1:0"},
{"modelId":"amazon.nova-lite-v1:0"},
{"modelId":""}
]
}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
got, err := m.ListModels(&APIConfig{ApiKey: &key})
if err != nil {
t.Fatalf("ListModels: %v", err)
}
// The empty-modelId summary is filtered so a malformed AWS
// response never leaks an empty string up to the UI dropdown.
want := []string{
"anthropic.claude-3-haiku-20240307-v1:0",
"amazon.nova-lite-v1:0",
}
if len(got) != len(want) {
t.Fatalf("len(got)=%d want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i].Name != want[i] {
t.Errorf("got[%d]=%s want %q", i, got[i].Name, want[i])
}
}
}
func TestBedrockCheckConnectionDelegates(t *testing.T) {
srv := newBedrockServer(t, http.MethodGet,
"/foundation-models",
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
if err := m.CheckConnection(&APIConfig{ApiKey: &key}); err == nil || !strings.Contains(err.Error(), "403") {
t.Errorf("want 403 surfaced via ListModels, got %v", err)
}
}
// encodeBedrockEventFrames builds an in-memory event-stream that
// looks like a real Bedrock Converse-Stream response: a messageStart
// lifecycle event, two contentBlockDelta payloads, and a messageStop
// terminator. We use the same encoder the SDK uses on the wire so
// the decode path is exercised end-to-end.
func encodeBedrockEventFrames(t *testing.T, events []struct {
eventType string
messageType string
payload []byte
}) []byte {
t.Helper()
var buf bytes.Buffer
enc := eventstream.NewEncoder()
for _, e := range events {
msg := eventstream.Message{
Headers: eventstream.Headers{
{Name: ":event-type", Value: eventstream.StringValue(e.eventType)},
},
Payload: e.payload,
}
if e.messageType != "" {
msg.Headers = append(msg.Headers, eventstream.Header{
Name: ":message-type",
Value: eventstream.StringValue(e.messageType),
})
} else {
msg.Headers = append(msg.Headers, eventstream.Header{
Name: ":message-type",
Value: eventstream.StringValue("event"),
})
}
if err := enc.Encode(&buf, msg); err != nil {
t.Fatalf("encode event-stream frame: %v", err)
}
}
return buf.Bytes()
}
func TestBedrockStreamDecodesContentDeltas(t *testing.T) {
frames := encodeBedrockEventFrames(t, []struct {
eventType string
messageType string
payload []byte
}{
{eventType: "messageStart", payload: []byte(`{"role":"assistant"}`)},
{eventType: "contentBlockDelta", payload: []byte(`{"delta":{"text":"Hello"},"contentBlockIndex":0}`)},
{eventType: "contentBlockDelta", payload: []byte(`{"delta":{"text":" world"},"contentBlockIndex":0}`)},
{eventType: "messageStop", payload: []byte(`{"stopReason":"end_turn"}`)},
})
srv := newBedrockServer(t, http.MethodPost,
"/model/m/converse-stream",
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
_, _ = w.Write(frames)
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
var chunks []string
sawDone := false
err := m.ChatStreamlyWithSender("m",
[]Message{{Role: "user", Content: "hi"}},
&APIConfig{ApiKey: &key}, nil,
func(c *string, _ *string) error {
if c == nil {
return nil
}
if *c == "[DONE]" {
sawDone = true
return nil
}
chunks = append(chunks, *c)
return nil
})
if err != nil {
t.Fatalf("stream: %v", err)
}
if got := strings.Join(chunks, ""); got != "Hello world" {
t.Errorf("chunks=%q want %q", got, "Hello world")
}
if !sawDone {
t.Error("expected [DONE] sentinel")
}
}
func TestBedrockStreamSurfacesException(t *testing.T) {
frames := encodeBedrockEventFrames(t, []struct {
eventType string
messageType string
payload []byte
}{
{eventType: "contentBlockDelta", payload: []byte(`{"delta":{"text":"partial"},"contentBlockIndex":0}`)},
{
eventType: "throttlingException",
messageType: "exception",
payload: []byte(`{"message":"Too many requests"}`),
},
})
srv := newBedrockServer(t, http.MethodPost,
"/model/m/converse-stream",
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
_, _ = w.Write(frames)
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
err := m.ChatStreamlyWithSender("m",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key}, nil,
func(*string, *string) error { return nil })
if err == nil || !strings.Contains(err.Error(), "exception") {
t.Errorf("want exception surfaced, got %v", err)
}
}
func TestBedrockStreamFailsWithoutTerminal(t *testing.T) {
// Connection closed cleanly after a delta but before messageStop.
// This used to be silently treated as success, masking truncated
// answers; the driver must now surface a "stream ended before
// messageStop" error so the caller can retry or alert.
frames := encodeBedrockEventFrames(t, []struct {
eventType string
messageType string
payload []byte
}{
{eventType: "contentBlockDelta", payload: []byte(`{"delta":{"text":"half"},"contentBlockIndex":0}`)},
})
srv := newBedrockServer(t, http.MethodPost,
"/model/m/converse-stream",
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
_, _ = w.Write(frames)
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
err := m.ChatStreamlyWithSender("m",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key}, nil,
func(*string, *string) error { return nil })
if err == nil || !strings.Contains(err.Error(), "stream ended before") {
t.Errorf("want truncation error, got %v", err)
}
}
func TestBedrockStreamRejectsExplicitFalse(t *testing.T) {
m := newBedrockForTest("http://unused")
key := validBedrockKey()
stream := false
err := m.ChatStreamlyWithSender("m",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key},
&ChatConfig{Stream: &stream},
func(*string, *string) error { return nil })
if err == nil || !strings.Contains(err.Error(), "stream must be true") {
t.Errorf("want stream-true guard, got %v", err)
}
}
func TestBedrockStreamRequiresSender(t *testing.T) {
m := newBedrockForTest("http://unused")
key := validBedrockKey()
err := m.ChatStreamlyWithSender("m",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key}, nil, nil)
if err == nil || !strings.Contains(err.Error(), "sender is required") {
t.Errorf("want sender-required error, got %v", err)
}
}
func TestLookupBedrockEventHeader(t *testing.T) {
headers := eventstream.Headers{
{Name: ":event-type", Value: eventstream.StringValue("contentBlockDelta")},
{Name: ":message-type", Value: eventstream.StringValue("event")},
}
if got := lookupBedrockEventHeader(headers, ":event-type"); got != "contentBlockDelta" {
t.Errorf("event-type=%q", got)
}
if got := lookupBedrockEventHeader(headers, ":nonexistent"); got != "" {
t.Errorf("nonexistent header=%q want empty", got)
}
}
func TestBedrockTitanEmbedHappyPath(t *testing.T) {
var seenInputs []string
srv := newBedrockServer(t, http.MethodPost,
"/model/amazon.titan-embed-text-v2:0/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockTitanEmbeddingRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
seenInputs = append(seenInputs, body.InputText)
if body.Dimensions == nil || *body.Dimensions != 256 {
t.Errorf("dimensions=%v, want 256", body.Dimensions)
}
w.Header().Set("Content-Type", "application/json")
if body.InputText == "alpha" {
_, _ = w.Write([]byte(`{"embedding":[0.1,0.2]}`))
} else {
_, _ = w.Write([]byte(`{"embedding":[0.3,0.4]}`))
}
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "amazon.titan-embed-text-v2:0"
got, err := m.Embed(&model, []string{"alpha", "beta"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 256})
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(seenInputs) != 2 || seenInputs[0] != "alpha" || seenInputs[1] != "beta" {
t.Fatalf("seen inputs=%v", seenInputs)
}
if len(got) != 2 {
t.Fatalf("len(got)=%d want 2", len(got))
}
if got[0].Index != 0 || got[0].Embedding[0] != 0.1 || got[1].Index != 1 || got[1].Embedding[0] != 0.3 {
t.Errorf("embeddings=%+v", got)
}
}
func TestBedrockTitanV1OmitsDimension(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/amazon.titan-embed-text-v1/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
if strings.Contains(string(raw), "dimensions") {
t.Errorf("Titan v1 body must not include dimensions: %s", string(raw))
}
_, _ = w.Write([]byte(`{"embedding":[0.1,0.2]}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "amazon.titan-embed-text-v1"
if _, err := m.Embed(&model, []string{"alpha"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 256}); err != nil {
t.Fatalf("Embed: %v", err)
}
}
func TestBedrockCohereEmbedHappyPath(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/cohere.embed-english-v3/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockCohereEmbeddingRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
if len(body.Texts) != 2 || body.Texts[0] != "first" || body.Texts[1] != "second" {
t.Errorf("texts=%v", body.Texts)
}
if body.InputType != "search_document" {
t.Errorf("input_type=%q want search_document", body.InputType)
}
if body.OutputDimension != nil {
t.Errorf("v3 output_dimension=%v, want omitted", *body.OutputDimension)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"embeddings":[[1,2],[3,4]]}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "cohere.embed-english-v3"
got, err := m.Embed(&model, []string{"first", "second"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 128})
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(got) != 2 || got[0].Index != 0 || got[0].Embedding[0] != 1 || got[1].Index != 1 || got[1].Embedding[0] != 3 {
t.Errorf("embeddings=%+v", got)
}
}
func TestBedrockCohereV4ForwardsDimensionAndParsesTypedResponse(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/cohere.embed-v4:0/invoke",
func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
var body bedrockCohereEmbeddingRequest
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal body: %v", err)
return
}
if body.OutputDimension == nil || *body.OutputDimension != 512 {
t.Errorf("output_dimension=%v, want 512", body.OutputDimension)
}
_, _ = w.Write([]byte(`{"embeddings":{"float":[[0.5,0.6]]}}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "cohere.embed-v4:0"
got, err := m.Embed(&model, []string{"first"}, &APIConfig{ApiKey: &key}, &EmbeddingConfig{Dimension: 512})
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(got) != 1 || got[0].Index != 0 || got[0].Embedding[0] != 0.5 {
t.Errorf("embeddings=%+v", got)
}
}
func TestBedrockEmbedShortCircuitsEmptyInput(t *testing.T) {
m := newBedrockForTest("http://unused")
got, err := m.Embed(nil, nil, nil, nil)
if err != nil {
t.Fatalf("Embed empty: %v", err)
}
if len(got) != 0 {
t.Errorf("len(got)=%d want 0", len(got))
}
}
func TestBedrockEmbedRequiresAPIKeyAndModel(t *testing.T) {
m := newBedrockForTest("http://unused")
model := "x"
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("Embed: want api-key error, got %v", err)
}
key := validBedrockKey()
blank := " "
if _, err := m.Embed(&blank, []string{"a"}, &APIConfig{ApiKey: &key}, nil); err == nil || !strings.Contains(err.Error(), "model name is required") {
t.Errorf("Embed: want model-required error, got %v", err)
}
}
func TestBedrockEmbedRejectsUnsupportedModel(t *testing.T) {
m := newBedrockForTest("http://unused")
key := validBedrockKey()
model := "anthropic.claude-3-haiku-20240307-v1:0"
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{ApiKey: &key}, nil); err == nil || !strings.Contains(err.Error(), "unsupported embedding model") {
t.Errorf("Embed: want unsupported-model error, got %v", err)
}
}
func TestBedrockEmbedPropagatesHTTPError(t *testing.T) {
srv := newBedrockServer(t, http.MethodPost,
"/model/amazon.titan-embed-text-v2:0/invoke",
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"message":"bad input"}`))
})
defer srv.Close()
m := newBedrockForTest(srv.URL)
key := validBedrockKey()
model := "amazon.titan-embed-text-v2:0"
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{ApiKey: &key}, nil); err == nil || !strings.Contains(err.Error(), "400") || !strings.Contains(err.Error(), "bad input") {
t.Errorf("Embed: want HTTP error with body, got %v", err)
}
}
func TestBedrockRerankReturnsNoSuchMethod(t *testing.T) {
m := newBedrockForTest("http://unused")
model := "x"
if _, err := m.Rerank(&model, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1}); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Rerank: want no-such-method, got %v", err)
}
}
func TestBedrockBalanceReturnsNoSuchMethod(t *testing.T) {
m := newBedrockForTest("http://unused")
if _, err := m.Balance(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Balance: want no-such-method, got %v", err)
}
}
func TestBedrockAudioOCRReturnNoSuchMethod(t *testing.T) {
m := newBedrockForTest("http://unused")
model := "x"
if _, err := m.TranscribeAudio(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("TranscribeAudio: want no-such-method, got %v", err)
}
if _, err := m.AudioSpeech(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("AudioSpeech: want no-such-method, got %v", err)
}
if _, err := m.OCRFile(&model, nil, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("OCRFile: want no-such-method, got %v", err)
}
if _, err := m.ParseFile(&model, nil, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("ParseFile: want no-such-method, got %v", err)
}
if _, err := m.ListTasks(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("ListTasks: want no-such-method, got %v", err)
}
if _, err := m.ShowTask("t", &APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("ShowTask: want no-such-method, got %v", err)
}
}
// Compile-time check that BedrockModel satisfies the ModelDriver
// contract. Any missing or mis-typed method shows up here as a build
// error instead of a confusing runtime nil-method-set panic.
var _ ModelDriver = (*BedrockModel)(nil)