Go: implement provider: Ollama (#14580)

### What problem does this PR solve?

implement `Ollama` provider

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Haruko386
2026-05-06 12:03:58 +08:00
committed by GitHub
parent 28388993a4
commit cd54c08e84
4 changed files with 432 additions and 127 deletions

8
conf/models/ollama.json Normal file
View File

@@ -0,0 +1,8 @@
{
"name": "ollama",
"url_suffix": {
"chat": "chat/completions",
"models": "models"
},
"class": "local"
}

View File

@@ -86,7 +86,7 @@ func (k *MoonshotModel) ChatWithMessages(modelName string, messages []Message, a
"model": modelName,
"messages": apiMessages,
"stream": false,
"temperature": 1,
"temperature": 0.6,
}
if chatModelConfig != nil {

View File

@@ -0,0 +1,423 @@
package models
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"ragflow/internal/common"
"strings"
"time"
)
// OllamaModel implements ModelDriver for Ollama AI
type OllamaModel struct {
BaseURL map[string]string
URLSuffix URLSuffix
httpClient *http.Client
}
// NewOllamaModel creates a new Ollama AI model instance
func NewOllamaModel(baseURL map[string]string, urlSuffix URLSuffix) *OllamaModel {
return &OllamaModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
httpClient: &http.Client{
Timeout: 120 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
},
},
}
}
func (o OllamaModel) NewInstance(baseURL map[string]string) ModelDriver {
return &OllamaModel{
BaseURL: baseURL,
URLSuffix: o.URLSuffix,
httpClient: &http.Client{
Timeout: 120 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
},
},
}
}
func (o OllamaModel) Name() string {
return "ollama"
}
func (o OllamaModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("message is nil")
}
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Chat)
// For qwen/glm models, use async chat endpoint
modelType := strings.Split(modelName, "_")[0]
if modelType == "qwen" || modelType == "glm" {
url = fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.AsyncChat)
}
// Convert messages to API format
apiMessages := make([]map[string]interface{}, len(messages))
for i, msg := range messages {
apiMessages[i] = map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
}
}
// Build request body
reqBody := map[string]interface{}{
"model": modelName,
"messages": apiMessages,
"stream": false,
"temperature": 1,
}
if chatModelConfig != nil {
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Thinking != nil {
if *chatModelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return nil, fmt.Errorf("no choices in response")
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid choice format")
}
messageMap, ok := firstChoice["message"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid message format")
}
content, ok := messageMap["content"].(string)
if !ok {
return nil, fmt.Errorf("invalid content format")
}
thinking, answer := GetThinkingAndAnswer(chatModelConfig.ModelClass, &content)
chatResponse := &ChatResponse{
Answer: answer,
ReasonContent: thinking,
}
return chatResponse, nil
}
func (o OllamaModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
if len(messages) == 0 {
return fmt.Errorf("messages is empty")
}
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Chat)
modelType := strings.Split(modelName, "-")[0]
if modelType == "qwen" || modelType == "glm" {
url = fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.AsyncChat)
}
// Convert messages to API format (supporting multimodal content)
apiMessages := make([]map[string]interface{}, len(messages))
for i, msg := range messages {
apiMessages[i] = map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
}
}
// Build request body with streaming enabled
reqBody := map[string]interface{}{
"model": modelName,
"messages": apiMessages,
"stream": true,
}
if modelConfig.Stream != nil {
reqBody["stream"] = *modelConfig.Stream
}
if modelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *modelConfig.MaxTokens
}
if modelConfig.Temperature != nil {
reqBody["temperature"] = *modelConfig.Temperature
}
if modelConfig.DoSample != nil {
reqBody["do_sample"] = *modelConfig.DoSample
}
if modelConfig.TopP != nil {
reqBody["top_p"] = *modelConfig.TopP
}
if modelConfig.Stop != nil {
reqBody["stop"] = *modelConfig.Stop
}
if modelConfig.Thinking != nil {
if *modelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := o.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// SSE parsing: read line by line
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
common.Info(line)
// SSE data line starts with "data:"
if !strings.HasPrefix(line, "data:") {
continue
}
// Extract JSON after "data:"
data := strings.TrimSpace(line[5:])
// [DONE] marks the end of stream
if data == "[DONE]" {
break
}
// Parse the JSON event
var event map[string]interface{}
if err = json.Unmarshal([]byte(data), &event); err != nil {
continue
}
choices, ok := event["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
continue
}
delta, ok := firstChoice["delta"].(map[string]interface{})
if !ok {
continue
}
reasoningContent, ok := delta["reasoning_content"].(string)
if ok && reasoningContent != "" {
if err := sender(nil, &reasoningContent); err != nil {
return err
}
}
content, ok := delta["content"].(string)
if ok && content != "" {
if err := sender(&content, nil); err != nil {
return err
}
}
finishReason, ok := firstChoice["finish_reason"].(string)
if ok && finishReason != "" {
break
}
}
// Send [DONE] marker for OpenAI compatibility
endOfStream := "[DONE]"
if err = sender(&endOfStream, nil); err != nil {
return err
}
return scanner.Err()
}
func (o OllamaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("no such method")
}
func (o OllamaModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
return nil, fmt.Errorf("no such method")
}
func (o OllamaModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Models)
reqBody := map[string]interface{}{}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
// convert result["data"] to []map[string]interface{}
models := make([]string, 0)
for _, model := range result["data"].([]interface{}) {
modelMap := model.(map[string]interface{})
modelName := modelMap["id"].(string)
models = append(models, modelName)
}
return models, nil
}
func (o OllamaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
return nil, fmt.Errorf("no such method")
}
func (o OllamaModel) CheckConnection(apiConfig *APIConfig) error {
return fmt.Errorf("no such method")
}

View File

@@ -72,132 +72,6 @@ func (z *VllmModel) Name() string {
return "vllm"
}
// Chat sends a message and returns response
func (z *VllmModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
if message == nil {
return nil, fmt.Errorf("message is nil")
}
var region = "default"
if apiConfig.Region != nil {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat)
// I need to get the model type, such as qwen3 is the prefix, the model type will be qwen. glm is the prefix, the model type will be glm. such as the model name: qwen3-0.6b, the model type will be qwen3
// the model name is glm-4.7, the model type will be glm
modelType := strings.Split(*modelName, "-")[0]
if modelType == "qwen" || modelType == "glm" {
url = fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.AsyncChat)
}
// Build request body
reqBody := map[string]interface{}{
"model": modelName,
"messages": []map[string]string{
{"role": "user", "content": *message},
},
"stream": false,
"temperature": 1,
}
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Thinking != nil {
if *chatModelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return nil, fmt.Errorf("no choices in response")
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid choice format")
}
messageMap, ok := firstChoice["message"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid message format")
}
content, ok := messageMap["content"].(string)
if !ok {
return nil, fmt.Errorf("invalid content format")
}
thinking, answer := GetThinkingAndAnswer(chatModelConfig.ModelClass, &content)
chatResponse := &ChatResponse{
Answer: answer,
ReasonContent: thinking,
}
return chatResponse, nil
}
// ChatWithMessages sends multiple messages with roles and returns response
func (z *VllmModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
if len(messages) == 0 {