Go: implement Embed in GPUStack driver (#15182)

### What problem does this PR solve?

The Go GPUStack driver returned a stub error for `Embed()` even though
GPUStack exposes OpenAI-compatible embeddings on the **v1-openai** route
(not `v1/embeddings`).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
glorydavid03023
2026-05-31 22:22:43 -05:00
committed by GitHub
parent 2d7044b57e
commit 3774916060
3 changed files with 306 additions and 10 deletions

View File

@@ -2,7 +2,8 @@
"name": "GPUStack",
"url_suffix": {
"chat": "v1/chat/completions",
"models": "v1/models"
"models": "v1/models",
"embedding": "v1-openai/embeddings"
},
"class": "local"
}

View File

@@ -36,9 +36,8 @@ import (
// Chat is served at <base>/v1/chat/completions with the standard
// OpenAI wire shape and Bearer auth (the GPUStack server always
// requires an API key; see rag/llm/chat_model.py GPUStack route).
// /v1 also aliases /v1-openai for chat and embeddings; this driver
// uses /v1 to match the Python side and the closed bug report
// #13236 ("v1 suffix required in base url").
// Chat uses /v1 (see #13236). Embeddings use the v1-openai route per GPUStack
// API docs and maintainer review (v1-openai/embeddings, not v1/embeddings).
type GPUStackModel struct {
BaseURL map[string]string
URLSuffix URLSuffix
@@ -424,8 +423,117 @@ func (g *GPUStackModel) CheckConnection(apiConfig *APIConfig) error {
return err
}
func (g *GPUStackModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
return nil, fmt.Errorf("%s, no such method", g.Name())
// gpustackEmbeddingData is one element of the "data" array in a GPUStack
// embeddings response.
//
// Index is a pointer so that an element that omits "index" decodes to nil
// rather than to 0; Embed relies on that to reject such elements instead of
// silently treating them as the vector for the first input.
type gpustackEmbeddingData struct {
Embedding []float64 `json:"embedding"`
Index *int `json:"index"`
}
// gpustackEmbeddingResponse is the JSON body returned by the GPUStack
// embeddings API, as far as this driver decodes it. Only the "data" array is
// mapped to a field; any other top-level keys in the body are dropped during
// json.Unmarshal.
type gpustackEmbeddingResponse struct {
Data []gpustackEmbeddingData `json:"data"`
}
// Embed returns one embedding vector per entry of texts, fetched from the
// endpoint formed by the region's base URL and g.URLSuffix.Embedding.
//
// An empty texts slice produces an empty, non-nil result without any HTTP
// request. Otherwise the call needs a non-empty API key and a non-blank model
// name, picks the base URL for apiConfig.Region (or "default" when the region
// is unset or empty), and POSTs a JSON body with "model" and "input".
// "dimensions" is included only when embeddingConfig has a positive Dimension.
//
// The returned slice is ordered by the "index" each response item reports,
// not by the order in which items arrive. An item with a missing,
// out-of-range or repeated index, an item with an empty vector, or an input
// that never receives a vector all cause an error; no partial result is
// returned.
func (g *GPUStackModel) Embed(
	modelName *string,
	texts []string,
	apiConfig *APIConfig,
	embeddingConfig *EmbeddingConfig,
) ([]EmbeddingData, error) {
	want := len(texts)
	if want == 0 {
		return []EmbeddingData{}, nil
	}

	// Argument checks run in a fixed order so callers always see the same
	// error for the same combination of bad inputs.
	switch {
	case apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "":
		return nil, fmt.Errorf("api key is required")
	case modelName == nil || strings.TrimSpace(*modelName) == "":
		return nil, fmt.Errorf("model name is required")
	}

	region := "default"
	if r := apiConfig.Region; r != nil && *r != "" {
		region = *r
	}
	baseURL, err := g.baseURLForRegion(region)
	if err != nil {
		return nil, err
	}

	suffix := g.URLSuffix.Embedding
	if suffix == "" {
		return nil, fmt.Errorf("gpustack: embedding URL suffix is not configured")
	}
	// Strip a leading slash from the suffix so the join never yields "//".
	endpoint := fmt.Sprintf("%s/%s", baseURL, strings.TrimPrefix(suffix, "/"))

	payload := map[string]interface{}{
		"input": texts,
		"model": *modelName,
	}
	if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
		payload["dimensions"] = embeddingConfig.Dimension
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
	req.Header.Set("Content-Type", "application/json")

	resp, err := g.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so a non-200 reply can be
	// quoted in the error.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("gpustack embeddings API error: %s, body: %s", resp.Status, string(raw))
	}

	var decoded gpustackEmbeddingResponse
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// Place every vector into the slot named by its index; seen tracks which
	// slots have been written so duplicates and gaps can be detected.
	out := make([]EmbeddingData, want)
	seen := make([]bool, want)
	for i := range decoded.Data {
		entry := decoded.Data[i]
		if entry.Index == nil {
			return nil, fmt.Errorf("gpustack: missing embedding index in response item")
		}
		pos := *entry.Index
		switch {
		case pos < 0 || pos >= want:
			return nil, fmt.Errorf("gpustack: embedding response index %d out of range for %d inputs", pos, want)
		case seen[pos]:
			return nil, fmt.Errorf("gpustack: duplicate embedding index %d in response", pos)
		case len(entry.Embedding) == 0:
			return nil, fmt.Errorf("gpustack: empty embedding vector for input index %d", pos)
		}
		seen[pos] = true
		out[pos] = EmbeddingData{
			Index:     pos,
			Embedding: entry.Embedding,
		}
	}

	for pos := 0; pos < want; pos++ {
		if !seen[pos] {
			return nil, fmt.Errorf("gpustack: missing embedding for input index %d", pos)
		}
	}
	return out, nil
}
func (g *GPUStackModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {

View File

@@ -73,10 +73,12 @@ func newGPUStackSSEServer(t *testing.T, expectedPath, ssePayload string) *httpte
func newGPUStackForTest(baseURL string) *GPUStackModel {
return NewGPUStackModel(
map[string]string{"default": baseURL},
URLSuffix{Chat: "v1/chat/completions", Models: "v1/models"},
URLSuffix{Chat: "v1/chat/completions", Models: "v1/models", Embedding: "v1-openai/embeddings"},
)
}
// gpustackEmbeddingsPath is the request path handed to newGPUStackServer by
// the Embed tests; it is the test URLSuffix.Embedding value
// ("v1-openai/embeddings") with a leading slash.
const gpustackEmbeddingsPath = "/v1-openai/embeddings"
func TestGPUStackName(t *testing.T) {
if got := newGPUStackForTest("http://unused").Name(); got != "gpustack" {
t.Errorf("Name()=%q, want %q", got, "gpustack")
@@ -451,12 +453,197 @@ func TestGPUStackListModelsRequiresAPIKey(t *testing.T) {
}
}
// TestGPUStackEmbedHappyPath checks the request body Embed sends (model name,
// dimensions, two inputs) and that the returned vectors follow input order
// even though the fake server lists index 1 before index 0.
func TestGPUStackEmbedHappyPath(t *testing.T) {
	handler := func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		if got := body["model"]; got != "bge-m3" {
			t.Errorf("model=%v", got)
		}
		// JSON numbers decode to float64 when the target is interface{}.
		if got := body["dimensions"]; got != float64(512) {
			t.Errorf("dimensions=%v, want 512", got)
		}
		if list, isList := body["input"].([]interface{}); !isList || len(list) != 2 {
			t.Errorf("input=%v, want 2-element array", body["input"])
		}
		reply := map[string]interface{}{
			"data": []map[string]interface{}{
				{"embedding": []float64{0.2, 0.2}, "index": 1},
				{"embedding": []float64{0.1, 0.2}, "index": 0},
			},
		}
		_ = json.NewEncoder(w).Encode(reply)
	}
	server := newGPUStackServer(t, gpustackEmbeddingsPath, handler)
	defer server.Close()

	key, model := "test-key", "bge-m3"
	driver := newGPUStackForTest(server.URL)
	got, err := driver.Embed(
		&model,
		[]string{"a", "b"},
		&APIConfig{ApiKey: &key},
		&EmbeddingConfig{Dimension: 512},
	)
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	if n := len(got); n != 2 {
		t.Fatalf("len(vecs)=%d, want 2", n)
	}

	inOrder := got[0].Index == 0 && got[0].Embedding[0] == 0.1 &&
		got[1].Index == 1 && got[1].Embedding[0] == 0.2
	if !inOrder {
		t.Errorf("vecs=%+v", got)
	}
}
// TestGPUStackEmbedReordersByIndex verifies that response items arriving as
// index 2, 0, 1 are returned in slot order 0, 1, 2.
//
// The result length is asserted before the per-slot loop: without that check
// an empty or short result would make the loop vacuous and the test would
// pass without verifying anything.
func TestGPUStackEmbedReordersByIndex(t *testing.T) {
	srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"embedding": []float64{2}, "index": 2},
				{"embedding": []float64{0}, "index": 0},
				{"embedding": []float64{1}, "index": 1},
			},
		})
	})
	defer srv.Close()

	apiKey := "test-key"
	model := "bge-m3"
	inputs := []string{"a", "b", "c"}
	vecs, err := newGPUStackForTest(srv.URL).Embed(&model, inputs, &APIConfig{ApiKey: &apiKey}, nil)
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	if len(vecs) != len(inputs) {
		t.Fatalf("len(vecs)=%d, want %d", len(vecs), len(inputs))
	}
	for i, v := range vecs {
		// Guard the element access so a malformed slot is reported as a
		// failure rather than a panic.
		if v.Index != i || len(v.Embedding) == 0 || v.Embedding[0] != float64(i) {
			t.Errorf("slot %d = %+v, want Embedding=[%d] Index=%d", i, v, i, i)
		}
	}
}
// TestGPUStackEmbedEmptyInputShortCircuits confirms that Embed with no texts
// returns an empty result and never contacts the server: the handler fails
// the test on any request it receives.
func TestGPUStackEmbedEmptyInputShortCircuits(t *testing.T) {
	failOnRequest := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("Embed([]) made an unexpected HTTP call")
		w.WriteHeader(http.StatusInternalServerError)
	})
	server := httptest.NewServer(failOnRequest)
	defer server.Close()

	key, model := "test-key", "bge-m3"
	got, err := newGPUStackForTest(server.URL).Embed(&model, []string{}, &APIConfig{ApiKey: &key}, nil)
	if err != nil {
		t.Fatalf("Embed([]): %v", err)
	}
	if n := len(got); n != 0 {
		t.Errorf("len(vecs)=%d, want 0", n)
	}
}
// TestGPUStackEmbedRequiresAPIKey expects Embed to fail with an
// "api key is required" error when the APIConfig carries no key.
func TestGPUStackEmbedRequiresAPIKey(t *testing.T) {
	model := "bge-m3"
	driver := newGPUStackForTest("http://unused")

	_, err := driver.Embed(&model, []string{"a"}, &APIConfig{}, nil)

	rejected := err != nil && strings.Contains(err.Error(), "api key is required")
	if !rejected {
		t.Errorf("expected api-key error, got %v", err)
	}
}
// TestGPUStackEmbedRejectsDuplicateIndex serves a response in which two items
// both claim index 0 and expects Embed to fail with an error mentioning
// "duplicate".
func TestGPUStackEmbedRejectsDuplicateIndex(t *testing.T) {
	reply := map[string]interface{}{
		"data": []map[string]interface{}{
			{"embedding": []float64{0.1}, "index": 0},
			{"embedding": []float64{0.2}, "index": 0},
		},
	}
	server := newGPUStackServer(t, gpustackEmbeddingsPath,
		func(_ *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
			_ = json.NewEncoder(w).Encode(reply)
		})
	defer server.Close()

	key, model := "test-key", "bge-m3"
	_, err := newGPUStackForTest(server.URL).Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "duplicate") {
		return
	}
	t.Errorf("expected duplicate-index error, got %v", err)
}
// TestGPUStackEmbedRejectsOutOfRangeIndex serves a single item with index 2
// for a two-input request and expects an "out of range" error.
func TestGPUStackEmbedRejectsOutOfRangeIndex(t *testing.T) {
	respond := func(_ *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		items := []map[string]interface{}{
			{"embedding": []float64{0.1}, "index": 2},
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{"data": items})
	}
	server := newGPUStackServer(t, gpustackEmbeddingsPath, respond)
	defer server.Close()

	key, model := "test-key", "bge-m3"
	driver := newGPUStackForTest(server.URL)
	_, err := driver.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &key}, nil)

	rejected := err != nil && strings.Contains(err.Error(), "out of range")
	if !rejected {
		t.Errorf("expected out-of-range error, got %v", err)
	}
}
// TestGPUStackEmbedRejectsMissingIndex serves an item that has an embedding
// but no "index" key and expects a "missing embedding index" error.
func TestGPUStackEmbedRejectsMissingIndex(t *testing.T) {
	reply := map[string]interface{}{
		"data": []map[string]interface{}{
			{"embedding": []float64{0.1}},
		},
	}
	server := newGPUStackServer(t, gpustackEmbeddingsPath,
		func(_ *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
			_ = json.NewEncoder(w).Encode(reply)
		})
	defer server.Close()

	key, model := "test-key", "bge-m3"
	_, err := newGPUStackForTest(server.URL).Embed(&model, []string{"a"}, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "missing embedding index") {
		return
	}
	t.Errorf("expected missing-index error, got %v", err)
}
// TestGPUStackEmbedRejectsEmptyVector serves an item whose embedding array is
// empty and expects an "empty embedding vector" error.
func TestGPUStackEmbedRejectsEmptyVector(t *testing.T) {
	respond := func(_ *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		items := []map[string]interface{}{
			{"embedding": []float64{}, "index": 0},
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{"data": items})
	}
	server := newGPUStackServer(t, gpustackEmbeddingsPath, respond)
	defer server.Close()

	key, model := "test-key", "bge-m3"
	driver := newGPUStackForTest(server.URL)
	_, err := driver.Embed(&model, []string{"a"}, &APIConfig{ApiKey: &key}, nil)

	rejected := err != nil && strings.Contains(err.Error(), "empty embedding vector")
	if !rejected {
		t.Errorf("expected empty-vector error, got %v", err)
	}
}
// TestGPUStackEmbedRejectsMissingSlot serves a vector for index 0 only while
// two inputs were sent, and expects Embed to report the input that received
// no vector.
func TestGPUStackEmbedRejectsMissingSlot(t *testing.T) {
	reply := map[string]interface{}{
		"data": []map[string]interface{}{
			{"embedding": []float64{0.1}, "index": 0},
		},
	}
	server := newGPUStackServer(t, gpustackEmbeddingsPath,
		func(_ *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
			_ = json.NewEncoder(w).Encode(reply)
		})
	defer server.Close()

	key, model := "test-key", "bge-m3"
	_, err := newGPUStackForTest(server.URL).Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "missing embedding for input index") {
		return
	}
	t.Errorf("expected missing-slot error, got %v", err)
}
func TestGPUStackUnsupportedMethods(t *testing.T) {
m := newGPUStackForTest("http://unused")
model := "x"
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Embed: %v", err)
}
if _, err := m.Rerank(&model, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1}); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Rerank: %v", err)
}