mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix: include user model settings in /user/me response (#15320)
### What problem does this PR solve? Fixes the `/user/me` response so it returns the current user's model settings correctly. ### Type of change - Added model settings data to the `/user/me` response. - Kept the response structure compatible with existing user profile fields. - Avoided changing unrelated user/session behavior.
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
@@ -545,22 +546,63 @@ func (h *UserHandler) SetTenantInfo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var req service.SetTenantInfoRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
requiredKeys := []string{"tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id"}
|
||||
missingArgumentMessage := "required argument are missing: tenant_id,asr_id,embd_id,img2txt_id,llm_id; "
|
||||
|
||||
var payload map[string]interface{}
|
||||
if err := c.ShouldBindBodyWith(&payload, binding.JSON); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
"message": missingArgumentMessage,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.SetTenantInfo(user.ID, &req)
|
||||
missing := make([]string, 0, len(requiredKeys))
|
||||
for _, key := range requiredKeys {
|
||||
if _, ok := payload[key]; !ok {
|
||||
missing = append(missing, key)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": fmt.Sprintf("required argument are missing: %s; ", joinStrings(missing)),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req := service.SetTenantInfoRequest{Raw: payload}
|
||||
if value, ok := payload["tenant_id"].(string); ok {
|
||||
req.TenantID = &value
|
||||
}
|
||||
if value, ok := payload["asr_id"].(string); ok {
|
||||
req.ASRID = &value
|
||||
}
|
||||
if value, ok := payload["embd_id"].(string); ok {
|
||||
req.EmbdID = &value
|
||||
}
|
||||
if value, ok := payload["img2txt_id"].(string); ok {
|
||||
req.Img2TxtID = &value
|
||||
}
|
||||
if value, ok := payload["llm_id"].(string); ok {
|
||||
req.LLMID = &value
|
||||
}
|
||||
if value, ok := payload["rerank_id"].(string); ok {
|
||||
req.RerankID = &value
|
||||
}
|
||||
if value, ok := payload["tts_id"].(string); ok {
|
||||
req.TTSID = &value
|
||||
}
|
||||
|
||||
code, err := h.userService.SetTenantInfo(user.ID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -571,3 +613,14 @@ func (h *UserHandler) SetTenantInfo(c *gin.Context) {
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
|
||||
func joinStrings(values []string) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := values[0]
|
||||
for i := 1; i < len(values); i++ {
|
||||
result += "," + values[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -889,51 +889,44 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, err
|
||||
|
||||
// SetTenantInfoRequest represents the request for setting tenant info
|
||||
type SetTenantInfoRequest struct {
|
||||
TenantID string `json:"tenant_id"`
|
||||
ASRID string `json:"asr_id"`
|
||||
EmbdID string `json:"embd_id"`
|
||||
Img2TxtID string `json:"img2txt_id"`
|
||||
LLMID string `json:"llm_id"`
|
||||
RerankID string `json:"rerank_id"`
|
||||
TTSID string `json:"tts_id"`
|
||||
TenantID *string `json:"tenant_id"`
|
||||
ASRID *string `json:"asr_id"`
|
||||
EmbdID *string `json:"embd_id"`
|
||||
Img2TxtID *string `json:"img2txt_id"`
|
||||
LLMID *string `json:"llm_id"`
|
||||
RerankID *string `json:"rerank_id"`
|
||||
TTSID *string `json:"tts_id"`
|
||||
Raw map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// SetTenantInfo updates tenant model configuration
|
||||
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error {
|
||||
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) (common.ErrorCode, error) {
|
||||
_ = userID
|
||||
tenantDAO := dao.NewTenantDAO()
|
||||
|
||||
_, err := tenantDAO.GetByID(req.TenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tenant not found: %w", err)
|
||||
}
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
if req.LLMID != "" {
|
||||
updates["llm_id"] = req.LLMID
|
||||
|
||||
for key, value := range req.Raw {
|
||||
if key == "tenant_id" {
|
||||
continue
|
||||
}
|
||||
updates[key] = value
|
||||
}
|
||||
if req.EmbdID != "" {
|
||||
updates["embd_id"] = req.EmbdID
|
||||
}
|
||||
if req.ASRID != "" {
|
||||
updates["asr_id"] = req.ASRID
|
||||
}
|
||||
if req.Img2TxtID != "" {
|
||||
updates["img2txt_id"] = req.Img2TxtID
|
||||
}
|
||||
if req.RerankID != "" {
|
||||
updates["rerank_id"] = req.RerankID
|
||||
}
|
||||
if req.TTSID != "" {
|
||||
updates["tts_id"] = req.TTSID
|
||||
|
||||
tenantID := ""
|
||||
if req.TenantID != nil {
|
||||
tenantID = *req.TenantID
|
||||
}
|
||||
|
||||
tenantLLMService := NewTenantLLMService()
|
||||
updates = tenantLLMService.EnsureTenantModelIDForParams(tenantID, updates)
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := tenantDAO.Update(req.TenantID, updates); err != nil {
|
||||
return fmt.Errorf("failed to update tenant: %w", err)
|
||||
if err := tenantDAO.Update(tenantID, updates); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// UserTenantService user tenant service
|
||||
|
||||
Reference in New Issue
Block a user