mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix(go): implement ListModels and CheckConnection in NVIDIA driver (#14636)
### What problem does this PR solve? The NVIDIA Go driver added in #14623 has a real chat path, but \`ListModels\` and \`CheckConnection\` are stubs that always return \`no such method\`. So: - The model picker cannot auto-populate available NVIDIA NIM model ids. Users have to type the full id by hand (e.g. \`abacusai/dracarys-llama-3.1-70b-instruct\`). - The "Check connection" button always fails for NVIDIA, even when the base URL is reachable and the API key is accepted. NVIDIA NIM is OpenAI-compatible. \`/v1/models\` works with the same Bearer token used for chat. The \`conf/models/nvidia.json\` file already wires the \`models\` url_suffix, so no config change is needed. ### What this PR includes - \`internal/entity/models/nvidia.go\`: - \`ListModels\` now calls \`GET ${BaseURL}/${URLSuffix.Models}\`, parses \`response.data[*].id\`, and returns the list. Same shape as the moonshot, xai, and openai drivers. - \`CheckConnection\` now calls \`ListModels\` and returns its error. Same pattern xai, moonshot, deepseek, aliyun, and gitee already use. \`Balance\`, \`Encode\`, and \`Rerank\` are still stubs in this PR and can be added in follow-ups. No JSON change. No factory change. No interface change. ### How the implementation works - Region resolution falls back to \`default\` when the supplied region is unknown, so a stray region value does not break a valid request. - The Authorization header is only set when \`apiConfig\` and \`ApiKey\` are non-nil and non-empty. This avoids a nil-pointer dereference and lets self-hosted NIM deployments without a key still work. - Non-200 responses propagate the upstream status line and body so the user sees a real error message. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### How was this tested? - \`go build ./internal/entity/models/...\` in a clean go 1.25 image (the go.mod minimum) returns exit 0. - The full method set on \`NvidiaModel\` still matches the \`ModelDriver\` interface. - Pattern parity with the existing xai, moonshot, deepseek, aliyun, gitee, and openai drivers. Closes #14635
This commit is contained in:
@@ -337,14 +337,85 @@ func (n NvidiaModel) Rerank(modelName *string, query string, texts []string, api
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// ListModels calls /v1/models on the configured NVIDIA NIM base URL
|
||||
// and returns the list of available model ids. The endpoint is
|
||||
// OpenAI-compatible, so the parsing follows the same shape used by
|
||||
// the moonshot, xai, and openai drivers.
|
||||
func (n NvidiaModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := n.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = n.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Models)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
}
|
||||
|
||||
resp, err := n.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Nvidia models API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
data, ok := result["data"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid models list format")
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(data))
|
||||
for _, item := range data {
|
||||
m, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, ok := m["id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
models = append(models, id)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (n NvidiaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// CheckConnection verifies that the configured NVIDIA NIM base URL
|
||||
// is reachable and that the API key is accepted, by issuing a
|
||||
// lightweight ListModels call. Mirrors the pattern used by the xai,
|
||||
// moonshot, deepseek, aliyun, and gitee drivers.
|
||||
func (n NvidiaModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return fmt.Errorf("no such method")
|
||||
_, err := n.ListModels(apiConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user