mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement provider: Ollama (#14580)
### What problem does this PR solve? implement `Ollama` provider ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --------- Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
8
conf/models/ollama.json
Normal file
8
conf/models/ollama.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"name": "ollama",
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models"
|
||||
},
|
||||
"class": "local"
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func (k *MoonshotModel) ChatWithMessages(modelName string, messages []Message, a
|
||||
"model": modelName,
|
||||
"messages": apiMessages,
|
||||
"stream": false,
|
||||
"temperature": 1,
|
||||
"temperature": 0.6,
|
||||
}
|
||||
|
||||
if chatModelConfig != nil {
|
||||
|
||||
423
internal/entity/models/ollama.go
Normal file
423
internal/entity/models/ollama.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OllamaModel implements ModelDriver for Ollama AI
|
||||
type OllamaModel struct {
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewOllamaModel creates a new Ollama AI model instance
|
||||
func NewOllamaModel(baseURL map[string]string, urlSuffix URLSuffix) *OllamaModel {
|
||||
return &OllamaModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (o OllamaModel) NewInstance(baseURL map[string]string) ModelDriver {
|
||||
return &OllamaModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: o.URLSuffix,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (o OllamaModel) Name() string {
|
||||
return "ollama"
|
||||
}
|
||||
|
||||
func (o OllamaModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Chat)
|
||||
|
||||
// For qwen/glm models, use async chat endpoint
|
||||
modelType := strings.Split(modelName, "_")[0]
|
||||
if modelType == "qwen" || modelType == "glm" {
|
||||
url = fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.AsyncChat)
|
||||
}
|
||||
|
||||
// Convert messages to API format
|
||||
apiMessages := make([]map[string]interface{}, len(messages))
|
||||
for i, msg := range messages {
|
||||
apiMessages[i] = map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": apiMessages,
|
||||
"stream": false,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
choices, ok := result["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid choice format")
|
||||
}
|
||||
|
||||
messageMap, ok := firstChoice["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
content, ok := messageMap["content"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content format")
|
||||
}
|
||||
|
||||
thinking, answer := GetThinkingAndAnswer(chatModelConfig.ModelClass, &content)
|
||||
|
||||
chatResponse := &ChatResponse{
|
||||
Answer: answer,
|
||||
ReasonContent: thinking,
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
func (o OllamaModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
if len(messages) == 0 {
|
||||
return fmt.Errorf("messages is empty")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Chat)
|
||||
modelType := strings.Split(modelName, "-")[0]
|
||||
if modelType == "qwen" || modelType == "glm" {
|
||||
url = fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.AsyncChat)
|
||||
}
|
||||
|
||||
// Convert messages to API format (supporting multimodal content)
|
||||
apiMessages := make([]map[string]interface{}, len(messages))
|
||||
for i, msg := range messages {
|
||||
apiMessages[i] = map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": apiMessages,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
if modelConfig.Stream != nil {
|
||||
reqBody["stream"] = *modelConfig.Stream
|
||||
}
|
||||
|
||||
if modelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *modelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if modelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *modelConfig.Temperature
|
||||
}
|
||||
|
||||
if modelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *modelConfig.DoSample
|
||||
}
|
||||
|
||||
if modelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *modelConfig.TopP
|
||||
}
|
||||
|
||||
if modelConfig.Stop != nil {
|
||||
reqBody["stop"] = *modelConfig.Stop
|
||||
}
|
||||
|
||||
if modelConfig.Thinking != nil {
|
||||
if *modelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// SSE parsing: read line by line
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
common.Info(line)
|
||||
|
||||
// SSE data line starts with "data:"
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON after "data:"
|
||||
data := strings.TrimSpace(line[5:])
|
||||
|
||||
// [DONE] marks the end of stream
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
var event map[string]interface{}
|
||||
if err = json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delta, ok := firstChoice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
reasoningContent, ok := delta["reasoning_content"].(string)
|
||||
if ok && reasoningContent != "" {
|
||||
if err := sender(nil, &reasoningContent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
content, ok := delta["content"].(string)
|
||||
if ok && content != "" {
|
||||
if err := sender(&content, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
finishReason, ok := firstChoice["finish_reason"].(string)
|
||||
if ok && finishReason != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Send [DONE] marker for OpenAI compatibility
|
||||
endOfStream := "[DONE]"
|
||||
if err = sender(&endOfStream, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (o OllamaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
func (o OllamaModel) Rerank(modelName *string, query string, texts []string, apiConfig *APIConfig) ([]float64, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
func (o OllamaModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.Models)
|
||||
|
||||
reqBody := map[string]interface{}{}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
// convert result["data"] to []map[string]interface{}
|
||||
models := make([]string, 0)
|
||||
for _, model := range result["data"].([]interface{}) {
|
||||
modelMap := model.(map[string]interface{})
|
||||
modelName := modelMap["id"].(string)
|
||||
models = append(models, modelName)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (o OllamaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
func (o OllamaModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
return fmt.Errorf("no such method")
|
||||
}
|
||||
@@ -72,132 +72,6 @@ func (z *VllmModel) Name() string {
|
||||
return "vllm"
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *VllmModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if message == nil {
|
||||
return nil, fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat)
|
||||
|
||||
// I need to get the model type, such as qwen3 is the prefix, the model type will be qwen. glm is the prefix, the model type will be glm. such as the model name: qwen3-0.6b, the model type will be qwen3
|
||||
// the model name is glm-4.7, the model type will be glm
|
||||
modelType := strings.Split(*modelName, "-")[0]
|
||||
if modelType == "qwen" || modelType == "glm" {
|
||||
url = fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.AsyncChat)
|
||||
}
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": false,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
choices, ok := result["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid choice format")
|
||||
}
|
||||
|
||||
messageMap, ok := firstChoice["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
content, ok := messageMap["content"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content format")
|
||||
}
|
||||
|
||||
thinking, answer := GetThinkingAndAnswer(chatModelConfig.ModelClass, &content)
|
||||
|
||||
chatResponse := &ChatResponse{
|
||||
Answer: answer,
|
||||
ReasonContent: thinking,
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
// ChatWithMessages sends multiple messages with roles and returns response
|
||||
func (z *VllmModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if len(messages) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user