mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Go: implement Embed in GPUStack driver (#15182)
### What problem does this PR solve? The Go GPUStack driver returned a stub error for `Embed()` even though GPUStack exposes OpenAI-compatible embeddings on the **v1-openai** route (not `v1/embeddings`). ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
"name": "GPUStack",
|
||||
"url_suffix": {
|
||||
"chat": "v1/chat/completions",
|
||||
"models": "v1/models"
|
||||
"models": "v1/models",
|
||||
"embedding": "v1-openai/embeddings"
|
||||
},
|
||||
"class": "local"
|
||||
}
|
||||
|
||||
@@ -36,9 +36,8 @@ import (
|
||||
// Chat is served at <base>/v1/chat/completions with the standard
|
||||
// OpenAI wire shape and Bearer auth (the GPUStack server always
|
||||
// requires an API key; see rag/llm/chat_model.py GPUStack route).
|
||||
// /v1 also aliases /v1-openai for chat and embeddings; this driver
|
||||
// uses /v1 to match the Python side and the closed bug report
|
||||
// #13236 ("v1 suffix required in base url").
|
||||
// Chat uses /v1 (see #13236). Embeddings use the v1-openai route per GPUStack
|
||||
// API docs and maintainer review (v1-openai/embeddings, not v1/embeddings).
|
||||
type GPUStackModel struct {
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
@@ -424,8 +423,117 @@ func (g *GPUStackModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GPUStackModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", g.Name())
|
||||
// gpustackEmbeddingData is one element in a GPUStack embeddings response.
|
||||
type gpustackEmbeddingData struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index *int `json:"index"`
|
||||
}
|
||||
|
||||
// gpustackEmbeddingResponse is the JSON body returned by GPUStack embeddings API.
|
||||
type gpustackEmbeddingResponse struct {
|
||||
Data []gpustackEmbeddingData `json:"data"`
|
||||
}
|
||||
|
||||
// Embed requests embedding vectors via GPUStack's v1-openai/embeddings endpoint.
|
||||
func (g *GPUStackModel) Embed(
|
||||
modelName *string,
|
||||
texts []string,
|
||||
apiConfig *APIConfig,
|
||||
embeddingConfig *EmbeddingConfig,
|
||||
) ([]EmbeddingData, error) {
|
||||
if len(texts) == 0 {
|
||||
return []EmbeddingData{}, nil
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || strings.TrimSpace(*modelName) == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
baseURL, err := g.baseURLForRegion(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.URLSuffix.Embedding == "" {
|
||||
return nil, fmt.Errorf("gpustack: embedding URL suffix is not configured")
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, strings.TrimPrefix(g.URLSuffix.Embedding, "/"))
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := g.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("gpustack embeddings API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed gpustackEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([]EmbeddingData, len(texts))
|
||||
filled := make([]bool, len(texts))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index == nil {
|
||||
return nil, fmt.Errorf("gpustack: missing embedding index in response item")
|
||||
}
|
||||
idx := *item.Index
|
||||
if idx < 0 || idx >= len(texts) {
|
||||
return nil, fmt.Errorf("gpustack: embedding response index %d out of range for %d inputs", idx, len(texts))
|
||||
}
|
||||
if filled[idx] {
|
||||
return nil, fmt.Errorf("gpustack: duplicate embedding index %d in response", idx)
|
||||
}
|
||||
if len(item.Embedding) == 0 {
|
||||
return nil, fmt.Errorf("gpustack: empty embedding vector for input index %d", idx)
|
||||
}
|
||||
embeddings[idx] = EmbeddingData{
|
||||
Embedding: item.Embedding,
|
||||
Index: idx,
|
||||
}
|
||||
filled[idx] = true
|
||||
}
|
||||
for i, ok := range filled {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("gpustack: missing embedding for input index %d", i)
|
||||
}
|
||||
}
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (g *GPUStackModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
|
||||
@@ -73,10 +73,12 @@ func newGPUStackSSEServer(t *testing.T, expectedPath, ssePayload string) *httpte
|
||||
func newGPUStackForTest(baseURL string) *GPUStackModel {
|
||||
return NewGPUStackModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{Chat: "v1/chat/completions", Models: "v1/models"},
|
||||
URLSuffix{Chat: "v1/chat/completions", Models: "v1/models", Embedding: "v1-openai/embeddings"},
|
||||
)
|
||||
}
|
||||
|
||||
const gpustackEmbeddingsPath = "/v1-openai/embeddings"
|
||||
|
||||
func TestGPUStackName(t *testing.T) {
|
||||
if got := newGPUStackForTest("http://unused").Name(); got != "gpustack" {
|
||||
t.Errorf("Name()=%q, want %q", got, "gpustack")
|
||||
@@ -451,12 +453,197 @@ func TestGPUStackListModelsRequiresAPIKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedHappyPath verifies request shape and dimensions on v1-openai/embeddings.
|
||||
func TestGPUStackEmbedHappyPath(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "bge-m3" {
|
||||
t.Errorf("model=%v", body["model"])
|
||||
}
|
||||
if body["dimensions"] != float64(512) {
|
||||
t.Errorf("dimensions=%v, want 512", body["dimensions"])
|
||||
}
|
||||
inputs, ok := body["input"].([]interface{})
|
||||
if !ok || len(inputs) != 2 {
|
||||
t.Errorf("input=%v, want 2-element array", body["input"])
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{0.2, 0.2}, "index": 1},
|
||||
{"embedding": []float64{0.1, 0.2}, "index": 0},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
vecs, err := newGPUStackForTest(srv.URL).Embed(
|
||||
&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &EmbeddingConfig{Dimension: 512})
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
if len(vecs) != 2 {
|
||||
t.Fatalf("len(vecs)=%d, want 2", len(vecs))
|
||||
}
|
||||
if vecs[0].Index != 0 || vecs[0].Embedding[0] != 0.1 || vecs[1].Index != 1 || vecs[1].Embedding[0] != 0.2 {
|
||||
t.Errorf("vecs=%+v", vecs)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedReordersByIndex verifies out-of-order response indices are mapped correctly.
|
||||
func TestGPUStackEmbedReordersByIndex(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{2}, "index": 2},
|
||||
{"embedding": []float64{0}, "index": 0},
|
||||
{"embedding": []float64{1}, "index": 1},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
vecs, err := newGPUStackForTest(srv.URL).Embed(
|
||||
&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
for i, v := range vecs {
|
||||
if v.Index != i || v.Embedding[0] != float64(i) {
|
||||
t.Errorf("slot %d = %+v, want Embedding=[%d] Index=%d", i, v, i, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedEmptyInputShortCircuits avoids HTTP when texts is empty.
|
||||
func TestGPUStackEmbedEmptyInputShortCircuits(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("Embed([]) made an unexpected HTTP call")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
vecs, err := newGPUStackForTest(srv.URL).Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Embed([]): %v", err)
|
||||
}
|
||||
if len(vecs) != 0 {
|
||||
t.Errorf("len(vecs)=%d, want 0", len(vecs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedRequiresAPIKey rejects requests without an API key.
|
||||
func TestGPUStackEmbedRequiresAPIKey(t *testing.T) {
|
||||
model := "bge-m3"
|
||||
_, err := newGPUStackForTest("http://unused").Embed(&model, []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("expected api-key error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedRejectsDuplicateIndex errors on duplicate response indices.
|
||||
func TestGPUStackEmbedRejectsDuplicateIndex(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{0.1}, "index": 0},
|
||||
{"embedding": []float64{0.2}, "index": 0},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
_, err := newGPUStackForTest(srv.URL).Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "duplicate") {
|
||||
t.Errorf("expected duplicate-index error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedRejectsOutOfRangeIndex errors when index exceeds input length.
|
||||
func TestGPUStackEmbedRejectsOutOfRangeIndex(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{0.1}, "index": 2},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
_, err := newGPUStackForTest(srv.URL).Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "out of range") {
|
||||
t.Errorf("expected out-of-range error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedRejectsMissingIndex errors when index is omitted from response.
|
||||
func TestGPUStackEmbedRejectsMissingIndex(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{0.1}},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
_, err := newGPUStackForTest(srv.URL).Embed(&model, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "missing embedding index") {
|
||||
t.Errorf("expected missing-index error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedRejectsEmptyVector errors when the API returns a zero-length vector.
|
||||
func TestGPUStackEmbedRejectsEmptyVector(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{}, "index": 0},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
_, err := newGPUStackForTest(srv.URL).Embed(&model, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "empty embedding vector") {
|
||||
t.Errorf("expected empty-vector error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGPUStackEmbedRejectsMissingSlot errors when a response index is never returned.
|
||||
func TestGPUStackEmbedRejectsMissingSlot(t *testing.T) {
|
||||
srv := newGPUStackServer(t, gpustackEmbeddingsPath, func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{0.1}, "index": 0},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "bge-m3"
|
||||
_, err := newGPUStackForTest(srv.URL).Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "missing embedding for input index") {
|
||||
t.Errorf("expected missing-slot error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGPUStackUnsupportedMethods(t *testing.T) {
|
||||
m := newGPUStackForTest("http://unused")
|
||||
model := "x"
|
||||
if _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Embed: %v", err)
|
||||
}
|
||||
if _, err := m.Rerank(&model, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1}); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Rerank: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user