Files
ragflow/internal/entity/models/model.go
Jin Hai 65afaa1292 Model config: add tools (#16371)
### What problem does this PR solve?

```
{
      "name": "glm-4-flash",
      "max_tokens": 128000,
      "model_types": [
        "chat"
      ],
      "tools": {
        "support": true
      }
}
```

```
RAGFlow(admin)> list provider 'zhipu-ai' models;
+------------+---------------+------------+---------------+----------------+-----------+-----------+
| dimensions | max_dimension | max_tokens | model_type    | name           | thinking  | tools     |
+------------+---------------+------------+---------------+----------------+-----------+-----------+
|            |               | 204800     | [chat]        | glm-5          | supported | supported |
|            |               | 204800     | [chat]        | glm-5-turbo    | supported | supported |
```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-06-26 11:37:51 +08:00

807 lines
23 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
)
// ReasoningSimple represents simple reasoning capability
type ReasoningSimple struct {
Type string `json:"type"`
Enabled bool `json:"enabled"`
Default bool `json:"default"`
}
// ReasoningBudget represents budget-based reasoning capability
type ReasoningBudget struct {
Type string `json:"type"`
Enabled bool `json:"enabled"`
DefaultTokens int `json:"default_tokens"`
TokenRange struct {
Min int `json:"min"`
Max int `json:"max"`
} `json:"token_range"`
}
// ReasoningEffort represents effort-based reasoning capability
type ReasoningEffort struct {
Type string `json:"type"`
Enabled bool `json:"enabled"`
Default string `json:"default"`
Options []string `json:"options"`
}
// Reasoning represents the reasoning capability (can be one of three types)
type Reasoning struct {
Simple *ReasoningSimple `json:"-"`
Budget *ReasoningBudget `json:"-"`
Effort *ReasoningEffort `json:"-"`
RawType string `json:"type"`
}
// Reasoning represents the reasoning capability (can be one of three types)
type ClearReasoningContent struct {
DefaultValue bool `json:"default_value"`
SupportedModels []string `json:"supported_models"`
}
// Reasoning represents the reasoning capability (can be one of three types)
type Thinking struct {
DefaultValue bool `json:"default_value"`
SupportedModels []string `json:"supported_models"`
}
// UnmarshalJSON custom unmarshal for Reasoning
func (r *Reasoning) UnmarshalJSON(data []byte) error {
var temp map[string]interface{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
typeVal, ok := temp["type"].(string)
if !ok {
return fmt.Errorf("reasoning type is required")
}
r.RawType = typeVal
switch typeVal {
case "simple":
var simple ReasoningSimple
dataBytes, _ := json.Marshal(temp)
if err := json.Unmarshal(dataBytes, &simple); err != nil {
return err
}
r.Simple = &simple
case "budget":
var budget ReasoningBudget
dataBytes, _ := json.Marshal(temp)
if err := json.Unmarshal(dataBytes, &budget); err != nil {
return err
}
r.Budget = &budget
case "effort":
var effort ReasoningEffort
dataBytes, _ := json.Marshal(temp)
if err := json.Unmarshal(dataBytes, &effort); err != nil {
return err
}
r.Effort = &effort
default:
return fmt.Errorf("unknown reasoning type: %s", typeVal)
}
return nil
}
// MarshalJSON custom marshal for Reasoning
func (r *Reasoning) MarshalJSON() ([]byte, error) {
switch r.RawType {
case "simple":
if r.Simple != nil {
return json.Marshal(r.Simple)
}
case "budget":
if r.Budget != nil {
return json.Marshal(r.Budget)
}
case "effort":
if r.Effort != nil {
return json.Marshal(r.Effort)
}
}
return nil, fmt.Errorf("invalid reasoning state")
}
// Multimodal represents multimodal capability
type Multimodal struct {
Enabled bool `json:"enabled"`
InputModalities []string `json:"input_modalities,omitempty"`
OutputModalities []string `json:"output_modalities,omitempty"`
}
// Features represents all features of a model
type Features struct {
Multimodal *Multimodal `json:"multimodal,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
ClearThinking *ClearReasoningContent `json:"clear_thinking,omitempty"`
}
type ModelThinking struct {
DefaultValue bool `json:"default_value"`
ClearThinking bool `json:"clear_thinking"`
}
type ModelTools struct {
Support bool `json:"support"`
}
// Model represents a single LLM model
type Model struct {
Name string `json:"name"`
MaxTokens *int `json:"max_tokens"`
ModelTypes []string `json:"model_types"`
Thinking *ModelThinking `json:"thinking"`
Tools *ModelTools `json:"tools"`
Class *string `json:"class"`
MaxDimension *int `json:"max_dimension"` // used by embedding models
Dimensions []int `json:"dimensions"`
Alias []string `json:"alias"`
ModelTypeMap map[string]bool
}
// Provider represents an LLM provider
type Provider struct {
Name string `json:"name"`
URL map[string]string `json:"url"`
URLSuffix URLSuffix `json:"url_suffix"`
Models []*Model `json:"models"`
Features Features `json:"features"`
Class string `json:"class"`
ModelDriver ModelDriver
}
// ProviderManager manages provider and model operations
type ProviderManager struct {
Providers []Provider `json:"model_providers"`
AllModels []Model `json:"all_models"`
Alias2ModelIndex map[string]int `json:"alias2_model_index_map"`
}
// ModelResponse represents the standard response structure
type ModelResponse struct {
Code int `json:"code"`
Data []map[string]interface{} `json:"data"`
Message string `json:"message"`
}
func decodeProviderConfig(data []byte) (Provider, error) {
var provider Provider
if err := json.Unmarshal(data, &provider); err != nil {
return Provider{}, err
}
var rawProvider struct {
URLSuffix json.RawMessage `json:"url_suffix"`
}
if err := json.Unmarshal(data, &rawProvider); err != nil {
return Provider{}, err
}
if len(rawProvider.URLSuffix) == 0 {
return provider, nil
}
decoder := json.NewDecoder(bytes.NewReader(rawProvider.URLSuffix))
decoder.DisallowUnknownFields()
if err := decoder.Decode(&provider.URLSuffix); err != nil {
return Provider{}, err
}
return provider, nil
}
var providerManager *ProviderManager
func GetProviderManager() *ProviderManager {
return providerManager
}
// InitProviderManager creates a new ProviderManager by reading all JSON files from a directory
func InitProviderManager(dirPath string) error {
providers := []Provider{}
// Read all files in the directory
files, err := os.ReadDir(dirPath)
if err != nil {
return fmt.Errorf("error reading directory %s: %w", dirPath, err)
}
modelFactory := NewModelFactory()
// Iterate through all files
for _, file := range files {
// Skip directories
if file.IsDir() {
continue
}
// Only process JSON files
if !strings.HasSuffix(file.Name(), ".json") {
continue
}
// Build full file path
filePath := filepath.Join(dirPath, file.Name())
// Read the file
var data []byte
data, err = os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("error reading file %s: %w", filePath, err)
}
// Parse JSON
var provider Provider
if provider, err = decodeProviderConfig(data); err != nil {
return fmt.Errorf("error parsing JSON from file %s: %w", filePath, err)
}
for _, model := range provider.Models {
// if the prefix of mode.Name is matched with keys of modelSupportThinking
if provider.Class == "" {
pos := strings.Index(model.Name, "-")
if pos >= 0 {
modelClass := model.Name[0:pos]
model.Class = &modelClass
}
} else {
model.Class = &provider.Name
}
model.ModelTypeMap = make(map[string]bool)
for _, modelType := range model.ModelTypes {
model.ModelTypeMap[modelType] = true
}
}
provider.ModelDriver, err = modelFactory.CreateModelDriver(provider.Name, provider.URL, provider.URLSuffix)
if err != nil {
return fmt.Errorf("error creating model driver for provider %s: %w", provider.Name, err)
}
// Add to providers list
providers = append(providers, provider)
}
if len(providers) == 0 {
return fmt.Errorf("no JSON files found in directory %s", dirPath)
}
// Read the file. Use a repo-root-relative path so that go test
// (which sets CWD to the package directory) can still find it.
var data []byte
data, err = os.ReadFile(filepath.Join(findRepoRoot(), "conf", "all_models.json"))
if err != nil {
return fmt.Errorf("error reading file 'conf/all_models.json': %w", err)
}
// Parse JSON
type AllModels struct {
Models []Model `json:"models"`
}
var allModels AllModels
if err = json.Unmarshal(data, &allModels); err != nil {
return fmt.Errorf("error parsing JSON from file 'conf/all_models.json': %w", err)
}
alias2ModelIndex := make(map[string]int)
for idx, model := range allModels.Models {
addModelAlias := func(alias string) error {
alias = strings.TrimSpace(alias)
if alias == "" {
return nil
}
lowerAlias := strings.ToLower(alias)
if existingIdx, ok := alias2ModelIndex[lowerAlias]; ok && existingIdx != idx {
return fmt.Errorf("duplicate alias %q for models %q and %q", alias, allModels.Models[existingIdx].Name, model.Name)
}
alias2ModelIndex[lowerAlias] = idx
return nil
}
if err = addModelAlias(model.Name); err != nil {
return err
}
for _, alias := range model.Alias {
if err = addModelAlias(alias); err != nil {
return err
}
}
}
providerManager = &ProviderManager{
Providers: providers,
AllModels: allModels.Models,
Alias2ModelIndex: alias2ModelIndex,
}
return nil
}
// 1. List all providers
func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) {
var providers []map[string]interface{}
for _, provider := range pm.Providers {
modelTypeSet := make(map[string]struct{})
for _, model := range provider.Models {
for _, modelType := range model.ModelTypes {
modelTypeSet[modelType] = struct{}{}
}
}
// Always emit a non-nil slice. Python's provider_api_service.list_providers
// returns `[]` for empty model_types; clients (e.g. AvailableModels in
// web/src/pages/user-setting/setting-model/components/un-add-model.tsx)
// call .some() / .forEach() on this field, which crashes on null.
modelTypes := make([]string, 0, len(modelTypeSet))
for modelType := range modelTypeSet {
modelTypes = append(modelTypes, modelType)
}
providerData := map[string]interface{}{
"name": provider.Name,
"url": provider.URL,
"model_types": modelTypes,
"url_suffix": provider.URLSuffix,
}
providers = append(providers, providerData)
}
if len(providers) == 0 {
return nil, fmt.Errorf("no providers found")
}
return providers, nil
}
func (pm *ProviderManager) ListAllModels() ([]map[string]interface{}, error) {
var modelList []map[string]interface{}
for _, model := range pm.AllModels {
modelData := map[string]interface{}{
"name": model.Name,
"model_types": model.ModelTypes,
}
if model.Alias != nil {
modelData["alias"] = model.Alias
}
if model.Thinking != nil {
modelData["thinking"] = model.Thinking
}
if model.MaxTokens != nil {
modelData["max_tokens"] = *model.MaxTokens
}
if model.MaxDimension != nil {
modelData["max_dimension"] = *model.MaxDimension
}
if len(model.Dimensions) > 0 {
modelData["dimensions"] = model.Dimensions
}
if model.Thinking != nil {
modelData["thinking"] = "supported"
}
if model.Tools != nil {
modelData["tools"] = "supported"
}
modelList = append(modelList, modelData)
}
if len(modelList) == 0 {
return nil, fmt.Errorf("no models found")
}
return modelList, nil
}
func (pm *ProviderManager) GetModelByNameOrAlias(modelName string) *Model {
lowerModelName := strings.ToLower(modelName)
// Check if it is alias
modelIndex, ok := pm.Alias2ModelIndex[lowerModelName]
if ok {
return &pm.AllModels[modelIndex]
}
return nil
}
// 2. Show specific provider information (including base_url)
func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]interface{}, error) {
provider := pm.FindProvider(providerName)
if provider == nil {
return nil, fmt.Errorf("provider '%s' not found", providerName)
}
providerInfo := map[string]interface{}{
"name": provider.Name,
"base_url": provider.URL,
"total_models": len(provider.Models),
}
return providerInfo, nil
}
// 3. List models under a specific provider
func (pm *ProviderManager) ListModels(providerName string) ([]map[string]interface{}, error) {
provider := pm.FindProvider(providerName)
if provider == nil {
return nil, fmt.Errorf("provider '%s' not found", providerName)
}
modelList := []map[string]interface{}{}
for _, model := range provider.Models {
// Field name "model_type" (singular) matches the IInstanceModel
// contract in web/src/interfaces/database/llm.ts:75 and Python's
// LLM.to_dict() (api/db/db_models.py). The Go port previously
// emitted the plural "model_types", which broke used-model.tsx
// — the view rendered empty because the "model_type" key was
// undefined. IAvailableProvider.model_types (plural) is the
// other interface used by un-add-model.tsx and continues to
// receive the plural form from the dedicated /api/v1/providers/
// handler.
//
// `max_dimension` and `dimensions` are always emitted (nil → null,
// empty slice → null in JSON) to match Python's LLM.to_dict() and
// keep the response shape stable for clients that destructure
// the object.
modelData := map[string]interface{}{
"name": model.Name,
"max_tokens": model.MaxTokens,
"model_type": model.ModelTypes,
"max_dimension": model.MaxDimension,
"dimensions": model.Dimensions,
}
if model.Thinking != nil {
modelData["thinking"] = "supported"
}
if model.Tools != nil {
modelData["tools"] = "supported"
}
modelList = append(modelList, modelData)
}
if len(modelList) == 0 {
return nil, fmt.Errorf("no models found")
}
return modelList, nil
}
func (pm *ProviderManager) GetModelByName(providerName, modelName string) (*Model, error) {
provider := pm.FindProvider(providerName)
if provider == nil {
return nil, fmt.Errorf("provider '%s' not found", providerName)
}
model := pm.findModel(provider, modelName)
if model == nil {
return nil, fmt.Errorf("model '%s' not found", modelName)
}
return model, nil
}
func (pm *ProviderManager) GetModelUrl(providerName, modelName, modelType string) (*string, *string, error) {
provider := pm.FindProvider(providerName)
if provider == nil {
return nil, nil, fmt.Errorf("provider '%s' not found", providerName)
}
model := pm.findModel(provider, modelName)
if model == nil {
return nil, nil, fmt.Errorf("model '%s' not found", modelName)
}
if !model.ModelTypeMap[modelType] {
return nil, nil, fmt.Errorf("model '%s' does not support model type '%s'", modelName, modelType)
}
switch modelType {
case "chat":
url := fmt.Sprintf("%s%s", provider.URL, provider.URLSuffix.Chat)
return &url, nil, nil
case "async_chat":
chatUrl := fmt.Sprintf("%s%s", provider.URL, provider.URLSuffix.AsyncChat)
resultUrl := fmt.Sprintf("%s%s", provider.URL, provider.URLSuffix.AsyncResult)
return &chatUrl, &resultUrl, nil
case "embedding":
url := fmt.Sprintf("%s%s", provider.URL, provider.URLSuffix.Embedding)
return &url, nil, nil
case "rerank":
url := fmt.Sprintf("%s%s", provider.URL, provider.URLSuffix.Rerank)
return &url, nil, nil
default:
return nil, nil, fmt.Errorf("model '%s' does not support model type '%s'", modelName, modelType)
}
}
// 4. Search specific model information with filtering by max_tokens or type
func (pm *ProviderManager) SearchModelInfo(providerName, modelName string, filterBy string, filterValue interface{}) ModelResponse {
resp := ModelResponse{
Code: 0,
Data: []map[string]interface{}{},
Message: "success",
}
provider := pm.FindProvider(providerName)
if provider == nil {
resp.Code = 404
resp.Message = fmt.Sprintf("Provider '%s' not found", providerName)
return resp
}
model := pm.findModel(provider, modelName)
if model == nil {
resp.Code = 404
resp.Message = fmt.Sprintf("Model '%s' not found in provider '%s'", modelName, providerName)
return resp
}
// Apply filters
matchFilter := true
if filterBy != "" && filterValue != nil {
switch filterBy {
case "max_tokens":
if maxVal, ok := filterValue.(int); ok {
if *model.MaxTokens < maxVal {
matchFilter = false
resp.Code = 400
resp.Message = fmt.Sprintf("Model does not meet filter criteria: max_tokens (%d) < %d",
model.MaxTokens, maxVal)
}
}
case "type":
if typeVal, ok := filterValue.(string); ok {
if !containsModelType(model.ModelTypes, typeVal) {
matchFilter = false
resp.Code = 400
resp.Message = fmt.Sprintf("Model does not meet filter criteria: type '%s' not found", typeVal)
}
}
}
}
if matchFilter {
modelData := map[string]interface{}{
"name": model.Name,
"max_tokens": model.MaxTokens,
"model_types": model.ModelTypes,
//"features": getFeaturesMap(model.Features),
}
if filterBy != "" && filterValue != nil {
modelData["filter_applied"] = map[string]interface{}{
"field": filterBy,
"value": filterValue,
}
}
resp.Data = append(resp.Data, modelData)
}
return resp
}
// 5. Display models with specific features
func (pm *ProviderManager) SearchByFeature(featureType string) ModelResponse {
resp := ModelResponse{
Code: 0,
Data: []map[string]interface{}{},
Message: "success",
}
//for _, provider := range pm.Providers {
// for _, model := range provider.Models {
// if modelHasFeature(model.Features, featureType) {
// modelData := map[string]interface{}{
// "provider": provider.Name,
// "name": model.Name,
// "max_tokens": model.MaxTokens,
// "model_types": model.ModelTypes,
// "features": getFeaturesMap(model.Features),
// }
// resp.Data = append(resp.Data, modelData)
// }
// }
//}
if len(resp.Data) == 0 {
resp.Code = 404
resp.Message = fmt.Sprintf("No models found with feature '%s'", featureType)
}
return resp
}
// 6. Display models with specific type
func (pm *ProviderManager) SearchByType(modelType string) ModelResponse {
resp := ModelResponse{
Code: 0,
Data: []map[string]interface{}{},
Message: "success",
}
for _, provider := range pm.Providers {
for _, model := range provider.Models {
if containsModelType(model.ModelTypes, modelType) {
modelData := map[string]interface{}{
"provider": provider.Name,
"name": model.Name,
"max_tokens": model.MaxTokens,
"model_types": model.ModelTypes,
//"features": getFeaturesMap(model.Features),
}
resp.Data = append(resp.Data, modelData)
}
}
}
if len(resp.Data) == 0 {
resp.Code = 404
resp.Message = fmt.Sprintf("No models found with type '%s'", modelType)
}
return resp
}
func GetFeatures(model *Model) []string {
var features []string
if model.Thinking != nil {
features = append(features, "thinking")
}
return features
}
func ConvertToFeaturesMap(model *Model) map[string]interface{} {
featuresMap := make(map[string]interface{})
if model.Thinking != nil {
thinkingMap := map[string]interface{}{
"default_value": model.Thinking.DefaultValue,
"clear_reasoning": model.Thinking.ClearThinking,
}
featuresMap["thinking"] = thinkingMap
}
return featuresMap
}
// Helper: Get features map for response
func getFeaturesMap(features Features) map[string]interface{} {
featuresMap := make(map[string]interface{})
if features.Multimodal != nil && features.Multimodal.Enabled {
multimodalMap := map[string]interface{}{
"enabled": features.Multimodal.Enabled,
"input_modalities": features.Multimodal.InputModalities,
"output_modalities": features.Multimodal.OutputModalities,
}
featuresMap["multimodal"] = multimodalMap
}
if features.Reasoning != nil {
reasoningMap := make(map[string]interface{})
switch features.Reasoning.RawType {
case "simple":
if features.Reasoning.Simple != nil {
reasoningMap["type"] = "simple"
reasoningMap["enabled"] = features.Reasoning.Simple.Enabled
reasoningMap["default"] = features.Reasoning.Simple.Default
}
case "budget":
if features.Reasoning.Budget != nil {
reasoningMap["type"] = "budget"
reasoningMap["enabled"] = features.Reasoning.Budget.Enabled
reasoningMap["default_tokens"] = features.Reasoning.Budget.DefaultTokens
reasoningMap["token_range"] = map[string]int{
"min": features.Reasoning.Budget.TokenRange.Min,
"max": features.Reasoning.Budget.TokenRange.Max,
}
}
case "effort":
if features.Reasoning.Effort != nil {
reasoningMap["type"] = "effort"
reasoningMap["enabled"] = features.Reasoning.Effort.Enabled
reasoningMap["default"] = features.Reasoning.Effort.Default
reasoningMap["options"] = features.Reasoning.Effort.Options
}
}
featuresMap["reasoning"] = reasoningMap
}
return featuresMap
}
// Helper: Check if model has a specific feature
func modelHasFeature(features Features, featureType string) bool {
switch strings.ToLower(featureType) {
case "multimodal":
return features.Multimodal != nil && features.Multimodal.Enabled
case "reasoning":
return features.Reasoning != nil
case "reasoning_simple":
return features.Reasoning != nil && features.Reasoning.RawType == "simple"
case "reasoning_budget":
return features.Reasoning != nil && features.Reasoning.RawType == "budget"
case "reasoning_effort":
return features.Reasoning != nil && features.Reasoning.RawType == "effort"
default:
return false
}
}
// findRepoRoot walks up from CWD until it finds the repo root (marked by
// conf/all_models.json). This makes tests work regardless of the Go test
// binary's CWD (which is set to the package directory by go test).
func findRepoRoot() string {
dir, err := os.Getwd()
if err != nil {
return "."
}
for dir != "/" && dir != "" {
if _, err := os.Stat(filepath.Join(dir, "conf", "all_models.json")); err == nil {
return dir
}
dir = filepath.Dir(dir)
}
return "."
}
// Helper: Find provider by name
func (pm *ProviderManager) FindProvider(name string) *Provider {
for i := range pm.Providers {
if strings.EqualFold(pm.Providers[i].Name, name) {
return &pm.Providers[i]
}
}
return nil
}
// Helper: Find model by name
func (pm *ProviderManager) findModel(provider *Provider, modelName string) *Model {
for i := range provider.Models {
if strings.EqualFold(provider.Models[i].Name, modelName) {
return provider.Models[i]
}
}
return nil
}
// Helper: Check if model types contains target
func containsModelType(types []string, target string) bool {
for _, t := range types {
if strings.EqualFold(t, target) {
return true
}
}
return false
}