mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
Go: refactor model API to accept model id (#15999)
### What problem does this PR solve? Not not only model_name@instance_name@provider_name is acceptable, but also model_id is acceptable. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -1301,14 +1302,20 @@ func (p *Parser) parseAdminSetDefault() (*Command, error) {
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("set_default_model")
|
||||
cmd.Params["model_type"] = modelType
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -464,13 +465,11 @@ func (c *CLI) SetDefaultModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
names := strings.Split(compositeModelName, "/")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'provider/instance/model'")
|
||||
var err error
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[0]
|
||||
instanceName = names[1]
|
||||
modelName = names[2]
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model_type": modelType,
|
||||
@@ -480,7 +479,6 @@ func (c *CLI) SetDefaultModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var resp *Response
|
||||
var err error
|
||||
switch c.Config.CLIMode {
|
||||
case AdminMode:
|
||||
resp, err = c.AdminServerClient.Request("PATCH", "/admin/models", "web", nil, payload)
|
||||
@@ -560,7 +558,7 @@ func (c *CLI) ListDefaultModels(cmd *Command) (ResponseIf, error) {
|
||||
case AdminMode:
|
||||
resp, err = c.AdminServerClient.Request("GET", "/admin/models", "web", nil, nil)
|
||||
case APIMode:
|
||||
resp, err = c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("GET", "/models", "web", nil, nil)
|
||||
resp, err = c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("GET", "/models/default", "web", nil, nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid server type")
|
||||
}
|
||||
@@ -573,7 +571,7 @@ func (c *CLI) ListDefaultModels(cmd *Command) (ResponseIf, error) {
|
||||
return nil, fmt.Errorf("failed to list default models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
var result ModelsResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list default models: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
@@ -374,6 +374,7 @@ type CurrentModel struct {
|
||||
Provider string
|
||||
Instance string
|
||||
Model string
|
||||
ModelID string
|
||||
}
|
||||
|
||||
// httpClientAdapter adapts HTTPClient to ce.HTTPClientInterface
|
||||
|
||||
@@ -57,6 +57,36 @@ func (r *CommonResponse) PrintOut() {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelsResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string][]map[string]interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
OutputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *ModelsResponse) Type() string {
|
||||
return "models"
|
||||
}
|
||||
|
||||
func (r *ModelsResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *ModelsResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.OutputFormat = format
|
||||
}
|
||||
|
||||
func (r *ModelsResponse) PrintOut() {
|
||||
if r.Code == 0 {
|
||||
models := r.Data["models"]
|
||||
PrintTableSimpleByFormat(models, r.OutputFormat)
|
||||
} else {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type CommonDataResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/ingestion"
|
||||
"ragflow/internal/ingestion/parser"
|
||||
"ragflow/internal/utility"
|
||||
@@ -1632,23 +1633,34 @@ func (c *CLI) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
|
||||
formattedMessages := []map[string]interface{}{}
|
||||
@@ -1773,12 +1785,16 @@ func (c *CLI) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
url := "/chat/completions"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"messages": formattedMessages,
|
||||
"stream": stream,
|
||||
"thinking": thinking,
|
||||
"messages": formattedMessages,
|
||||
"stream": stream,
|
||||
"thinking": thinking,
|
||||
}
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
if thinking {
|
||||
@@ -1864,12 +1880,12 @@ func (c *CLI) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list instance models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
return nil, fmt.Errorf("failed to chat model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result NonStreamResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list instance models: invalid JSON (%w)", err)
|
||||
return nil, fmt.Errorf("failed to chat model: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
@@ -1889,23 +1905,35 @@ func (c *CLI) EmbedUserText(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
|
||||
texts, ok := cmd.Params["texts"].([]string)
|
||||
@@ -1919,11 +1947,15 @@ func (c *CLI) EmbedUserText(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"texts": texts,
|
||||
"dimension": dimension,
|
||||
"texts": texts,
|
||||
"dimension": dimension,
|
||||
}
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
url := "/embeddings"
|
||||
@@ -1956,23 +1988,35 @@ func (c *CLI) RerankUserDocument(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
|
||||
query, ok := cmd.Params["query"].(string)
|
||||
@@ -1991,12 +2035,16 @@ func (c *CLI) RerankUserDocument(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
}
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
url := "/rerank"
|
||||
@@ -2029,23 +2077,35 @@ func (c *CLI) TTSUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
|
||||
text, ok := cmd.Params["text"].(string)
|
||||
@@ -2059,10 +2119,14 @@ func (c *CLI) TTSUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
//}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"text": text,
|
||||
"text": text,
|
||||
}
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
ttsConfigPayload := make(map[string]interface{})
|
||||
@@ -2221,23 +2285,35 @@ func (c *CLI) ASRUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
|
||||
audioFile, ok := cmd.Params["audio_file"].(string)
|
||||
@@ -2246,10 +2322,15 @@ func (c *CLI) ASRUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"file": audioFile,
|
||||
"file": audioFile,
|
||||
}
|
||||
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
asrConfigPayload := make(map[string]interface{})
|
||||
@@ -2308,28 +2389,38 @@ func (c *CLI) OCRUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
var filename string
|
||||
var fileURL string
|
||||
var ok bool
|
||||
var fileContent []byte
|
||||
|
||||
filename, ok = cmd.Params["file"].(string)
|
||||
@@ -2347,10 +2438,14 @@ func (c *CLI) OCRUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
payload := map[string]interface{}{}
|
||||
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
if fileContent != nil {
|
||||
@@ -2390,28 +2485,39 @@ func (c *CLI) ParseFileUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
var err error
|
||||
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'model@instance@provider'")
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerName = names[2]
|
||||
instanceName = names[1]
|
||||
modelName = names[0]
|
||||
} else if c.CurrentModel != nil {
|
||||
}
|
||||
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
if c.CurrentModel == nil {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
// Use current model if set
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
} else {
|
||||
return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first")
|
||||
if c.CurrentModel.ModelID != "" {
|
||||
modelID = c.CurrentModel.ModelID
|
||||
} else {
|
||||
providerName = c.CurrentModel.Provider
|
||||
instanceName = c.CurrentModel.Instance
|
||||
modelName = c.CurrentModel.Model
|
||||
}
|
||||
}
|
||||
|
||||
var filename string
|
||||
var fileURL string
|
||||
var ok bool
|
||||
var fileContent []byte
|
||||
|
||||
filename, ok = cmd.Params["file"].(string)
|
||||
@@ -2434,10 +2540,14 @@ func (c *CLI) ParseFileUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
payload := map[string]interface{}{}
|
||||
|
||||
if modelID == "" {
|
||||
payload["provider_name"] = providerName
|
||||
payload["instance_name"] = instanceName
|
||||
payload["model_name"] = modelName
|
||||
} else {
|
||||
payload["model_id"] = modelID
|
||||
}
|
||||
|
||||
if fileContent != nil {
|
||||
@@ -2664,20 +2774,30 @@ func (c *CLI) UseModel(cmd *Command) (ResponseIf, error) {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
var modelName, instanceName, providerName string
|
||||
var err error
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if !ok || compositeModelName == "" {
|
||||
return nil, fmt.Errorf("model identifier not provided")
|
||||
if ok {
|
||||
modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
names := strings.Split(compositeModelName, "@")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model identifier must be in format 'model@instance@provider'")
|
||||
modelID, ok := cmd.Params["model_id"].(string)
|
||||
if !ok {
|
||||
modelID = ""
|
||||
}
|
||||
|
||||
if modelID == "" && compositeModelName == "" {
|
||||
return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first")
|
||||
}
|
||||
|
||||
c.CurrentModel = &CurrentModel{
|
||||
Provider: names[2],
|
||||
Instance: names[1],
|
||||
Model: names[0],
|
||||
Provider: providerName,
|
||||
Instance: instanceName,
|
||||
Model: modelName,
|
||||
ModelID: modelID,
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
|
||||
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -2285,7 +2286,7 @@ func (p *Parser) parseSetVariable() (*Command, error) {
|
||||
func (p *Parser) parseSetDefault() (*Command, error) {
|
||||
p.nextToken() // consume DEFAULT
|
||||
|
||||
var modelType, compositeModelName string
|
||||
var modelType, modelNameOrID string
|
||||
var err error
|
||||
|
||||
switch p.curToken.Type {
|
||||
@@ -2313,12 +2314,12 @@ func (p *Parser) parseSetDefault() (*Command, error) {
|
||||
}
|
||||
p.nextToken() // pass MODEL
|
||||
|
||||
// Format: 'provider/instance/model' or just 'message'
|
||||
// Format: 'model@instance@provider' or just 'message'
|
||||
if p.curToken.Type != TokenQuotedString {
|
||||
return nil, fmt.Errorf("expected quoted string with format provider/instance/model")
|
||||
return nil, fmt.Errorf("expected quoted string with format model@instance@provider")
|
||||
}
|
||||
|
||||
compositeModelName, err = p.parseQuotedString()
|
||||
modelNameOrID, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2326,7 +2327,14 @@ func (p *Parser) parseSetDefault() (*Command, error) {
|
||||
|
||||
cmd := NewCommand("set_default_model")
|
||||
cmd.Params["model_type"] = modelType
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
@@ -3024,7 +3032,7 @@ func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
p.nextToken() // consume CHAT
|
||||
|
||||
var err error
|
||||
var compositeModelName string = ""
|
||||
var modelNameOrID string = ""
|
||||
var messages []string
|
||||
var images []string
|
||||
var videos []string
|
||||
@@ -3038,11 +3046,11 @@ optionsLoop:
|
||||
switch p.curToken.Type {
|
||||
case TokenWith:
|
||||
p.nextToken()
|
||||
// 'model@instance@provider'
|
||||
if compositeModelName != "" {
|
||||
return nil, fmt.Errorf("model name is already set")
|
||||
// 'model@instance@provider' or model ID
|
||||
if modelNameOrID != "" {
|
||||
return nil, fmt.Errorf("model name or ID is already set")
|
||||
}
|
||||
compositeModelName, err = p.parseQuotedString()
|
||||
modelNameOrID, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3182,7 +3190,13 @@ optionsLoop:
|
||||
}
|
||||
cmd := NewCommand("chat_to_model")
|
||||
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
cmd.Params["messages"] = messages
|
||||
cmd.Params["images"] = images
|
||||
cmd.Params["videos"] = videos
|
||||
@@ -3276,7 +3290,7 @@ textLoop:
|
||||
}
|
||||
p.nextToken() // consume WITH
|
||||
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3306,7 +3320,13 @@ textLoop:
|
||||
}
|
||||
|
||||
cmd := NewCommand("embed_user_text")
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
cmd.Params["texts"] = texts
|
||||
if dimension > 0 {
|
||||
cmd.Params["dimension"] = dimension
|
||||
@@ -3356,7 +3376,7 @@ documentLoop:
|
||||
}
|
||||
p.nextToken() // consume WITH
|
||||
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3374,7 +3394,13 @@ documentLoop:
|
||||
p.nextToken()
|
||||
|
||||
cmd := NewCommand("rarank_user_document")
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
cmd.Params["query"] = query
|
||||
cmd.Params["documents"] = documents
|
||||
cmd.Params["top_n"] = topN
|
||||
@@ -3389,7 +3415,7 @@ func (p *Parser) parseASRCommand() (*Command, error) {
|
||||
}
|
||||
p.nextToken() // consume WITH
|
||||
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3407,7 +3433,13 @@ func (p *Parser) parseASRCommand() (*Command, error) {
|
||||
p.nextToken()
|
||||
|
||||
cmd := NewCommand("asr_user_command")
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
cmd.Params["audio_file"] = audioFile
|
||||
|
||||
for p.curToken.Type != TokenEOF && p.curToken.Type != TokenSemicolon {
|
||||
@@ -3445,9 +3477,18 @@ func (p *Parser) parseTTSCommand() (*Command, error) {
|
||||
if p.curToken.Type != TokenQuotedString && p.curToken.Type != TokenIdentifier {
|
||||
return nil, fmt.Errorf("expect model name after 'with'")
|
||||
}
|
||||
cmd.Params["composite_model_name"] = strings.Trim(p.curToken.Value, "\"'")
|
||||
|
||||
modelNameOrID := strings.Trim(p.curToken.Value, "\"'")
|
||||
p.nextToken()
|
||||
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
|
||||
if p.curToken.Type != TokenText {
|
||||
return nil, fmt.Errorf("expect 'text' parameter")
|
||||
}
|
||||
@@ -3509,7 +3550,7 @@ func (p *Parser) parseOCRCommand() (*Command, error) {
|
||||
}
|
||||
p.nextToken() // consume WITH
|
||||
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3540,7 +3581,18 @@ func (p *Parser) parseOCRCommand() (*Command, error) {
|
||||
return nil, fmt.Errorf("expected FILE or URL")
|
||||
}
|
||||
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
@@ -3548,7 +3600,7 @@ func (p *Parser) parseOCRCommand() (*Command, error) {
|
||||
func (p *Parser) parseModelParseCommand() (*Command, error) {
|
||||
p.nextToken() // consume WITH
|
||||
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3579,7 +3631,18 @@ func (p *Parser) parseModelParseCommand() (*Command, error) {
|
||||
return nil, fmt.Errorf("expected FILE or URL")
|
||||
}
|
||||
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
@@ -3711,7 +3774,7 @@ func (p *Parser) parseUseCommand() (*Command, error) {
|
||||
func (p *Parser) parseUseModel() (*Command, error) {
|
||||
p.nextToken() // consume MODEL
|
||||
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
modelNameOrID, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected model identifier in format 'model@instance@provider': %w", err)
|
||||
}
|
||||
@@ -3723,7 +3786,19 @@ func (p *Parser) parseUseModel() (*Command, error) {
|
||||
}
|
||||
|
||||
cmd := NewCommand("use_model")
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
|
||||
if common.IsCompositeModelName(modelNameOrID) {
|
||||
cmd.Params["composite_model_name"] = modelNameOrID
|
||||
} else if common.IsUUID(modelNameOrID) {
|
||||
cmd.Params["model_id"] = modelNameOrID
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID)
|
||||
}
|
||||
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,11 @@
|
||||
|
||||
package common
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PtrString formats a pointer value as a string for debug/log output.
|
||||
// Returns "<nil>" for nil pointers.
|
||||
@@ -26,3 +30,45 @@ func PtrString[T any](p *T) string {
|
||||
}
|
||||
return fmt.Sprintf("%v", *p)
|
||||
}
|
||||
|
||||
// composite model name format: model_name@instance_name@provider_name
|
||||
func IsCompositeModelName(modelName string) bool {
|
||||
parts := strings.Split(modelName, "@")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func IsUUID(uuid string) bool {
|
||||
// only lower case letters and numbers, length is 32
|
||||
if len(uuid) != 32 {
|
||||
return false
|
||||
}
|
||||
uuidRegex := regexp.MustCompile(`^[a-z0-9]+$`)
|
||||
if uuidRegex.MatchString(uuid) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractCompositeName splits a composite model name into three parts.
|
||||
// Returns (modelName, instanceName, providerName, true) on success,
|
||||
// or ("", "", "", false) if the name is not a valid composite name.
|
||||
func ExtractCompositeName(modelName string) (string, string, string, error) {
|
||||
parts := strings.Split(modelName, "@")
|
||||
if len(parts) != 3 {
|
||||
return "", "", "", fmt.Errorf("invalid model name format")
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p == "" {
|
||||
return "", "", "", fmt.Errorf("invalid model name format")
|
||||
}
|
||||
}
|
||||
return parts[0], parts[1], parts[2], nil
|
||||
}
|
||||
|
||||
@@ -897,6 +897,7 @@ type ChatToModelRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Thinking bool `json:"thinking"`
|
||||
@@ -915,28 +916,38 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1005,8 +1016,8 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
messages[i] = models.Message{Role: role, Content: content}
|
||||
}
|
||||
|
||||
// Stream response using sender function (best performance, no channel)
|
||||
errorCode, err := h.modelProviderService.ChatToModelStreamWithSender(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, messages, &apiConfig, &chatConfig, sender)
|
||||
// Stream response using sender function (the best performance, no channel)
|
||||
errorCode, err := h.modelProviderService.ChatToModelStreamWithSender(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, messages, &apiConfig, &chatConfig, sender)
|
||||
|
||||
if errorCode != common.CodeSuccess {
|
||||
c.SSEvent("error", err.Error())
|
||||
@@ -1026,7 +1037,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
content := msg["content"]
|
||||
messages[i] = models.Message{Role: role, Content: content}
|
||||
}
|
||||
response, errorCode, err = h.modelProviderService.ChatToModelWithMessages(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, messages, &apiConfig, &chatConfig)
|
||||
response, errorCode, err = h.modelProviderService.ChatToModelWithMessages(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, messages, &apiConfig, &chatConfig)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1047,6 +1058,7 @@ type EmbedTextRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
Texts []string `json:"texts"`
|
||||
Dimension int `json:"dimension"`
|
||||
}
|
||||
@@ -1062,28 +1074,38 @@ func (h *ProviderHandler) EmbedText(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1102,8 +1124,7 @@ func (h *ProviderHandler) EmbedText(c *gin.Context) {
|
||||
var errorCode common.ErrorCode
|
||||
var err error
|
||||
|
||||
response, errorCode, err = h.modelProviderService.EmbedText(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Texts, &apiConfig, &embeddingConfig)
|
||||
|
||||
response, errorCode, err = h.modelProviderService.EmbedText(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Texts, &apiConfig, &embeddingConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
@@ -1123,6 +1144,7 @@ type RerankDocumentRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN int `json:"top_n"`
|
||||
@@ -1139,28 +1161,38 @@ func (h *ProviderHandler) RerankDocument(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1179,8 +1211,7 @@ func (h *ProviderHandler) RerankDocument(c *gin.Context) {
|
||||
var errorCode common.ErrorCode
|
||||
var err error
|
||||
|
||||
response, errorCode, err = h.modelProviderService.RerankDocument(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Query, req.Documents, &apiConfig, &rerankConfig)
|
||||
|
||||
response, errorCode, err = h.modelProviderService.RerankDocument(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Query, req.Documents, &apiConfig, &rerankConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
@@ -1200,6 +1231,7 @@ type TranscribeAudioRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
File *string `json:"file"`
|
||||
Language []string `json:"language"`
|
||||
Prompt int `json:"prompt"`
|
||||
@@ -1218,28 +1250,38 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1287,9 +1329,8 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream response using sender function (best performance, no channel)
|
||||
errorCode, err := h.modelProviderService.TranscribeAudioStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig, sender)
|
||||
|
||||
// Stream response using sender function ( the best performance, no channel)
|
||||
errorCode, err := h.modelProviderService.TranscribeAudioStream(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.File, &apiConfig, &asrConfig, sender)
|
||||
if errorCode != common.CodeSuccess {
|
||||
c.SSEvent("error", err.Error())
|
||||
}
|
||||
@@ -1301,8 +1342,7 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) {
|
||||
var errorCode common.ErrorCode
|
||||
var err error
|
||||
|
||||
response, errorCode, err = h.modelProviderService.TranscribeAudio(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig)
|
||||
|
||||
response, errorCode, err = h.modelProviderService.TranscribeAudio(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.File, &apiConfig, &asrConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
@@ -1322,6 +1362,7 @@ type AudioSpeechRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
Text *string `json:"text"`
|
||||
Stream bool `json:"stream"`
|
||||
TTSConfig *models.TTSConfig `json:"tts_config"`
|
||||
@@ -1338,28 +1379,38 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1407,9 +1458,8 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream response using sender function (best performance, no channel)
|
||||
errorCode, err := h.modelProviderService.AudioSpeechStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig, sender)
|
||||
|
||||
// Stream response using sender function ( the best performance, no channel)
|
||||
errorCode, err := h.modelProviderService.AudioSpeechStream(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Text, &apiConfig, &ttsConfig, sender)
|
||||
if errorCode != common.CodeSuccess {
|
||||
c.SSEvent("error", err.Error())
|
||||
}
|
||||
@@ -1421,8 +1471,7 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) {
|
||||
var errorCode common.ErrorCode
|
||||
var err error
|
||||
|
||||
response, errorCode, err = h.modelProviderService.AudioSpeech(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig)
|
||||
|
||||
response, errorCode, err = h.modelProviderService.AudioSpeech(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Text, &apiConfig, &ttsConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
@@ -1442,6 +1491,7 @@ type OCRFileRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
Content []byte `json:"content"`
|
||||
URL *string `json:"url"`
|
||||
}
|
||||
@@ -1457,28 +1507,38 @@ func (h *ProviderHandler) OCRFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1495,8 +1555,7 @@ func (h *ProviderHandler) OCRFile(c *gin.Context) {
|
||||
var errorCode common.ErrorCode
|
||||
var err error
|
||||
|
||||
response, errorCode, err = h.modelProviderService.OCRFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Content, req.URL, &apiConfig, &OCRConfig)
|
||||
|
||||
response, errorCode, err = h.modelProviderService.OCRFile(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Content, req.URL, &apiConfig, &OCRConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
@@ -1516,6 +1575,7 @@ type ParseFileRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelID *string `json:"model_id"`
|
||||
Content []byte `json:"content"`
|
||||
URL *string `json:"url"`
|
||||
}
|
||||
@@ -1531,28 +1591,38 @@ func (h *ProviderHandler) ParseFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.ModelID == nil {
|
||||
if req.ProviderName == nil || *req.ProviderName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.InstanceName == nil || *req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
if req.ModelName == nil || *req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if *req.ModelID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model ID is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
@@ -1569,8 +1639,7 @@ func (h *ProviderHandler) ParseFile(c *gin.Context) {
|
||||
var errorCode common.ErrorCode
|
||||
var err error
|
||||
|
||||
response, errorCode, err = h.modelProviderService.ParseFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Content, req.URL, &apiConfig, &parseFileConfig)
|
||||
|
||||
response, errorCode, err = h.modelProviderService.ParseFile(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Content, req.URL, &apiConfig, &parseFileConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
|
||||
@@ -399,6 +399,9 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// provider handler because that's where the
|
||||
// modelProviderService is wired.
|
||||
model.GET("/", r.providerHandler.ListTenantAddedModels)
|
||||
|
||||
// TODO: list default models?
|
||||
//model.GET("/", r.tenantHandler.GetModels)
|
||||
model.PATCH("/", r.tenantHandler.SetModels)
|
||||
// Tenant default-model selection (used by the agent
|
||||
// page's useFetchDefaultModels hook). Mirrors the
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user