mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix(go-models): harden Hunyuan embedding validation (#15249)
## Summary - Validate Hunyuan embedding model name and API key before building requests. - Reuse region-aware base URL validation for embedding requests. - Replace the stale unsupported Embed test with happy-path and validation coverage. ## What changed - Added early Hunyuan Embed validation for missing model names and API keys. - Routed Embed through the same base URL region guard used by the other Hunyuan methods. - Updated Hunyuan tests to configure the embedding suffix and cover Embed success plus invalid inputs. ## Why Hunyuan Embed is implemented, but the existing test still expected it to be unsupported and could panic before returning a normal validation error. This keeps the implemented embedding path aligned with the current driver behavior and prevents nil input panics. Closes #15087 Refs #14736
This commit is contained in:
@@ -440,13 +440,23 @@ func (a *HunyuanModel) Embed(modelName *string, texts []string, apiConfig *APICo
|
||||
if len(texts) == 0 {
|
||||
return []EmbeddingData{}, nil
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Embedding)
|
||||
baseURL, err := a.baseURLForRegion(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, a.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
@@ -458,7 +468,10 @@ func (a *HunyuanModel) Embed(modelName *string, texts []string, apiConfig *APICo
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -496,7 +509,7 @@ func (a *HunyuanModel) Embed(modelName *string, texts []string, apiConfig *APICo
|
||||
return nil, fmt.Errorf("hunyuan embedding response contains no data: %s", string(body))
|
||||
}
|
||||
|
||||
var embeddings []EmbeddingData
|
||||
embeddings := make([]EmbeddingData, 0, len(parsedResponse.Data))
|
||||
for _, dataElem := range parsedResponse.Data {
|
||||
embeddings = append(embeddings, EmbeddingData{
|
||||
Embedding: dataElem.Embedding,
|
||||
|
||||
@@ -75,7 +75,7 @@ func newHunyuanSSEServer(t *testing.T, expectedPath, ssePayload string) *httptes
|
||||
func newHunyuanForTest(baseURL string) *HunyuanModel {
|
||||
return NewHunyuanModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{Chat: "chat/completions", Models: "models"},
|
||||
URLSuffix{Chat: "chat/completions", Embedding: "embeddings", Models: "models"},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -428,11 +428,60 @@ func TestHunyuanBaseURLForRegionUnknown(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHunyuanEmbedReturnsNoSuchMethod(t *testing.T) {
|
||||
model := "x"
|
||||
_, err := newHunyuanForTest("http://unused").Embed(&model, []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Embed: want 'no such method', got %v", err)
|
||||
func TestHunyuanEmbedHappyPath(t *testing.T) {
|
||||
srv := newHunyuanServer(t, http.MethodPost, "/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "hunyuan-embedding" {
|
||||
t.Errorf("model=%v", body["model"])
|
||||
}
|
||||
inputs, ok := body["input"].([]interface{})
|
||||
if !ok || len(inputs) != 2 {
|
||||
t.Errorf("input=%#v", body["input"])
|
||||
http.Error(w, "bad input", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{"embedding": []float64{0.1, 0.2}, "index": 0},
|
||||
{"embedding": []float64{0.3, 0.4}, "index": 1},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "hunyuan-embedding"
|
||||
embeddings, err := newHunyuanForTest(srv.URL).Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
if len(embeddings) != 2 || embeddings[1].Index != 1 || embeddings[1].Embedding[0] != 0.3 {
|
||||
t.Errorf("embeddings=%+v", embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHunyuanEmbedValidatesInputs(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
model := "hunyuan-embedding"
|
||||
|
||||
if embeddings, err := newHunyuanForTest("http://unused").Embed(nil, nil, nil, nil); err != nil || len(embeddings) != 0 {
|
||||
t.Errorf("empty input: embeddings=%+v err=%v", embeddings, err)
|
||||
}
|
||||
if _, err := newHunyuanForTest("http://unused").Embed(nil, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Errorf("nil model: %v", err)
|
||||
}
|
||||
emptyModel := ""
|
||||
if _, err := newHunyuanForTest("http://unused").Embed(&emptyModel, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Errorf("empty model: %v", err)
|
||||
}
|
||||
if _, err := newHunyuanForTest("http://unused").Embed(&model, []string{"x"}, nil, nil); err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("nil api config: %v", err)
|
||||
}
|
||||
if _, err := newHunyuanForTest("http://unused").Embed(&model, []string{"x"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("missing api key: %v", err)
|
||||
}
|
||||
emptyKey := ""
|
||||
if _, err := newHunyuanForTest("http://unused").Embed(&model, []string{"x"}, &APIConfig{ApiKey: &emptyKey}, nil); err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("empty api key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user