mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
### What problem does this PR solve? Part of #15853 (provider model-list refactor). Refactors **Ollama** `ListModels` onto the shared `ParseListModel` pattern and fixes two correctness issues: - **Endpoint:** switch the models suffix from `api/ps` (only currently-running models) to `api/tags` (all installed models) — the latter is what a model picker should show. - **Parsing:** Ollama returns `{"models":[{"name","model"}]}`, a non-OpenAI shape. Decode it into a typed struct, map the names into `ModelList`, then enrich through `ParseListModel`. This removes the previous unchecked type assertions (`result["models"].([]interface{})` / `.(map[string]interface{})` / `.(string)`) that **panicked** when the body was missing the `models` array or any field, and adds a fallback to the `model` field when `name` is blank. - Drops the no-op GET request body and a dead base-URL reassignment. #### Drive-by fix Shared gitee_test.go `DSModelList` -> `ModelList` compile fix (renamed in #15900) so the models test package builds; auto-resolves against the sibling #15853 PRs. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring
545 lines
16 KiB
Go
545 lines
16 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package models
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// OllamaModel implements ModelDriver for Ollama AI
|
|
type OllamaModel struct {
|
|
baseModel BaseModel
|
|
}
|
|
|
|
// contentToText flattens a Message.Content value into plain text.
//
// Two shapes are understood:
//   - a plain string, which is returned unchanged (even when empty);
//   - an OpenAI-style multimodal list ([]interface{}) whose elements are
//     objects such as {"type": "text", "text": "..."}; the first element
//     carrying a non-empty string under "text" wins.
//
// Any other shape, or a list without usable text, yields "".
func contentToText(content interface{}) string {
	if plain, isString := content.(string); isString {
		return plain
	}

	parts, isList := content.([]interface{})
	if !isList {
		return ""
	}

	for _, raw := range parts {
		fields, isObject := raw.(map[string]interface{})
		if !isObject {
			continue
		}
		// Skip parts whose "text" is missing, not a string, or empty.
		candidate, isText := fields["text"].(string)
		if isText && candidate != "" {
			return candidate
		}
	}
	return ""
}
|
|
|
|
// NewOllamaModel creates a new Ollama AI model instance
|
|
func NewOllamaModel(baseURL map[string]string, urlSuffix URLSuffix) *OllamaModel {
|
|
return &OllamaModel{
|
|
baseModel: BaseModel{
|
|
BaseURL: baseURL,
|
|
URLSuffix: urlSuffix,
|
|
AllowEmptyAPIKey: true,
|
|
httpClient: NewDriverHTTPClient(),
|
|
},
|
|
}
|
|
}
|
|
|
|
func (o *OllamaModel) NewInstance(baseURL map[string]string) ModelDriver {
|
|
return NewOllamaModel(baseURL, o.baseModel.URLSuffix)
|
|
}
|
|
|
|
func (o *OllamaModel) Name() string {
|
|
return "ollama"
|
|
}
|
|
|
|
func (o *OllamaModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("message is nil")
|
|
}
|
|
|
|
resolvedBaseURL, err := o.baseModel.GetBaseURL(apiConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
url := fmt.Sprintf("%s/%s", resolvedBaseURL, o.baseModel.URLSuffix.Chat)
|
|
|
|
// For qwen/glm models, use async chat endpoint
|
|
modelType := strings.Split(modelName, "_")[0]
|
|
if modelType == "qwen" || modelType == "glm" {
|
|
url = fmt.Sprintf("%s/%s", resolvedBaseURL, o.baseModel.URLSuffix.AsyncChat)
|
|
}
|
|
|
|
// Convert messages to API format
|
|
apiMessages := make([]map[string]interface{}, len(messages))
|
|
for i, msg := range messages {
|
|
apiMessages[i] = map[string]interface{}{
|
|
"role": msg.Role,
|
|
"content": contentToText(msg.Content),
|
|
}
|
|
}
|
|
|
|
// Build request body
|
|
reqBody := map[string]interface{}{
|
|
"model": modelName,
|
|
"messages": apiMessages,
|
|
"stream": false,
|
|
"temperature": 1,
|
|
}
|
|
|
|
if chatModelConfig != nil {
|
|
if chatModelConfig.Stream != nil {
|
|
reqBody["stream"] = *chatModelConfig.Stream
|
|
}
|
|
|
|
if chatModelConfig.MaxTokens != nil {
|
|
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
|
}
|
|
|
|
if chatModelConfig.Temperature != nil {
|
|
reqBody["temperature"] = *chatModelConfig.Temperature
|
|
}
|
|
|
|
if chatModelConfig.TopP != nil {
|
|
reqBody["top_p"] = *chatModelConfig.TopP
|
|
}
|
|
|
|
if chatModelConfig.Stop != nil {
|
|
reqBody["stop"] = *chatModelConfig.Stop
|
|
}
|
|
|
|
if chatModelConfig.Effort != nil && *chatModelConfig.Effort != "" {
|
|
if strings.HasPrefix(strings.ToLower(modelName), "gpt-oss") {
|
|
reqBody["think"] = *chatModelConfig.Effort
|
|
}
|
|
} else if chatModelConfig.Thinking != nil {
|
|
if *chatModelConfig.Thinking {
|
|
reqBody["think"] = true
|
|
}
|
|
}
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := o.baseModel.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Parse response
|
|
var result map[string]interface{}
|
|
if err = json.Unmarshal(body, &result); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
message, ok := result["message"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to parse response: message not found")
|
|
}
|
|
|
|
content, ok := message["content"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to parse response: content not found")
|
|
}
|
|
|
|
reasonContent, _ := message["thinking"].(string)
|
|
|
|
chatResponse := &ChatResponse{
|
|
Answer: &content,
|
|
ReasonContent: &reasonContent,
|
|
}
|
|
|
|
return chatResponse, nil
|
|
}
|
|
|
|
func (o *OllamaModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
|
if len(messages) == 0 {
|
|
return fmt.Errorf("messages is empty")
|
|
}
|
|
|
|
resolvedBaseURL, err := o.baseModel.GetBaseURL(apiConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
url := fmt.Sprintf("%s/%s", resolvedBaseURL, o.baseModel.URLSuffix.Chat)
|
|
modelType := strings.Split(modelName, "-")[0]
|
|
if modelType == "qwen" || modelType == "glm" {
|
|
url = fmt.Sprintf("%s/%s", resolvedBaseURL, o.baseModel.URLSuffix.AsyncChat)
|
|
}
|
|
|
|
// Convert messages to API format (supporting multimodal content)
|
|
apiMessages := make([]map[string]interface{}, len(messages))
|
|
for i, msg := range messages {
|
|
apiMessages[i] = map[string]interface{}{
|
|
"role": msg.Role,
|
|
"content": contentToText(msg.Content),
|
|
}
|
|
}
|
|
|
|
// Build request body with streaming enabled
|
|
reqBody := map[string]interface{}{
|
|
"model": modelName,
|
|
"messages": apiMessages,
|
|
"stream": true,
|
|
}
|
|
|
|
if modelConfig.Stream != nil {
|
|
reqBody["stream"] = *modelConfig.Stream
|
|
}
|
|
|
|
if modelConfig.MaxTokens != nil {
|
|
reqBody["max_tokens"] = *modelConfig.MaxTokens
|
|
}
|
|
|
|
if modelConfig.Temperature != nil {
|
|
reqBody["temperature"] = *modelConfig.Temperature
|
|
}
|
|
|
|
if modelConfig.DoSample != nil {
|
|
reqBody["do_sample"] = *modelConfig.DoSample
|
|
}
|
|
|
|
if modelConfig.TopP != nil {
|
|
reqBody["top_p"] = *modelConfig.TopP
|
|
}
|
|
|
|
if modelConfig.Stop != nil {
|
|
reqBody["stop"] = *modelConfig.Stop
|
|
}
|
|
|
|
if modelConfig.Effort != nil && *modelConfig.Effort != "" {
|
|
if strings.HasPrefix(strings.ToLower(modelName), "gpt-oss") {
|
|
reqBody["think"] = *modelConfig.Effort
|
|
}
|
|
} else if modelConfig.Thinking != nil {
|
|
if *modelConfig.Thinking {
|
|
reqBody["think"] = true
|
|
}
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), streamCallTimeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := o.baseModel.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// SSE parsing: read line by line
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
|
|
// ignore the blank
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
// Parse the JSON event
|
|
var event map[string]interface{}
|
|
if err = json.Unmarshal([]byte(line), &event); err != nil {
|
|
continue
|
|
}
|
|
|
|
if messageMap, ok := event["message"].(map[string]interface{}); ok {
|
|
if reasoningContent, exists := messageMap["thinking"].(string); exists && reasoningContent != "" {
|
|
if err := sender(nil, &reasoningContent); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if content, exists := messageMap["content"].(string); exists && content != "" {
|
|
if err := sender(&content, nil); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if done, ok := event["done"].(bool); ok && done {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Send [DONE] marker for OpenAI compatibility with RAGFlow frontend
|
|
endOfStream := "[DONE]"
|
|
if err := sender(&endOfStream, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
return scanner.Err()
|
|
}
|
|
|
|
func (o *OllamaModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
|
if err := o.baseModel.APIConfigCheck(apiConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(texts) == 0 {
|
|
return []EmbeddingData{}, nil
|
|
}
|
|
|
|
if modelName == nil || *modelName == "" {
|
|
return nil, fmt.Errorf("model name is required")
|
|
}
|
|
|
|
resolvedBaseURL, err := o.baseModel.GetBaseURL(apiConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
baseURL := resolvedBaseURL
|
|
if baseURL == "" {
|
|
baseURL = resolvedBaseURL
|
|
}
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)")
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), o.baseModel.URLSuffix.Embedding)
|
|
|
|
reqBody := map[string]interface{}{
|
|
"model": *modelName,
|
|
"input": texts,
|
|
}
|
|
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
|
reqBody["dimensions"] = embeddingConfig.Dimension
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if auth := BearerAuth(apiConfig); auth != "" {
|
|
req.Header.Set("Authorization", auth)
|
|
}
|
|
|
|
resp, err := o.baseModel.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("Ollama embeddings API error: %s, body: %s", resp.Status, string(body))
|
|
}
|
|
|
|
var embedResp struct {
|
|
Model string `json:"model"`
|
|
Embeddings [][]float64 `json:"embeddings"`
|
|
}
|
|
|
|
if err = json.Unmarshal(body, &embedResp); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
if len(embedResp.Embeddings) == 0 {
|
|
return nil, fmt.Errorf("no embeddings returned")
|
|
}
|
|
|
|
embeddings := make([]EmbeddingData, 0, len(embedResp.Embeddings))
|
|
|
|
for i, emb := range embedResp.Embeddings {
|
|
if len(emb) == 0 {
|
|
return nil, fmt.Errorf("empty embedding at index %d", i)
|
|
}
|
|
|
|
embeddings = append(embeddings, EmbeddingData{
|
|
Embedding: emb,
|
|
Index: i,
|
|
})
|
|
}
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
func (o *OllamaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
|
return nil, fmt.Errorf("no such method")
|
|
}
|
|
|
|
// TranscribeAudio transcribe audio
|
|
func (o *OllamaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
func (o *OllamaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
// AudioSpeech convert text to audio
|
|
func (o *OllamaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
func (o *OllamaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
// OCRFile OCR file
|
|
func (o *OllamaModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
// ParseFile parse file
|
|
func (o *OllamaModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
func (o *OllamaModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse, error) {
|
|
|
|
baseURL, err := o.baseModel.GetBaseURL(apiConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434)")
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), o.baseModel.URLSuffix.Models)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := o.baseModel.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Ollama's /api/tags returns {"models":[{"name":...,"model":...}]}, a shape
|
|
// that differs from the OpenAI list. Decode it into a local struct, map the
|
|
// names into ModelList, then enrich through the shared ParseListModel helper
|
|
// (issue #15853). Using a typed struct also avoids the previous unchecked
|
|
// type assertions, which panicked when "models" was absent or malformed.
|
|
var result struct {
|
|
Models []struct {
|
|
Name string `json:"name"`
|
|
Model string `json:"model"`
|
|
} `json:"models"`
|
|
}
|
|
if err = json.Unmarshal(body, &result); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
modelList := ModelList{Object: "list"}
|
|
for _, m := range result.Models {
|
|
name := strings.TrimSpace(m.Name)
|
|
if name == "" {
|
|
name = strings.TrimSpace(m.Model)
|
|
}
|
|
if name == "" {
|
|
continue
|
|
}
|
|
modelList.Models = append(modelList.Models, DSModel{ID: name})
|
|
}
|
|
|
|
return ParseListModel(modelList), nil
|
|
}
|
|
|
|
func (o *OllamaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
|
return nil, fmt.Errorf("no such method")
|
|
}
|
|
|
|
// CheckConnection verifies that the configured Ollama base URL is reachable
|
|
func (o *OllamaModel) CheckConnection(apiConfig *APIConfig) error {
|
|
_, err := o.ListModels(apiConfig)
|
|
return err
|
|
}
|
|
|
|
func (o *OllamaModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) {
|
|
return nil, fmt.Errorf("%s, no such method", o.Name())
|
|
}
|
|
|
|
func (o *OllamaModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", o.Name())
|
|
}
|