From 0694b4af57ebcbef284c4ea799e100388a32e489 Mon Sep 17 00:00:00 2001 From: Hz_ Date: Thu, 28 May 2026 13:31:16 +0800 Subject: [PATCH] fix: include user model settings in /user/me response (#15320) ### What problem does this PR solve? Fixes the `/user/me` response so it returns the current user's model settings correctly. ### Type of change - Added model settings data to the `/user/me` response. - Kept the response structure compatible with existing user profile fields. - Avoided changing unrelated user/session behavior. --- internal/handler/user.go | 67 +++++++++++++++++++++++++++++++++++----- internal/service/user.go | 59 ++++++++++++++++------------------- 2 files changed, 86 insertions(+), 40 deletions(-) diff --git a/internal/handler/user.go b/internal/handler/user.go index fd83792146..e167f59078 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -27,6 +27,7 @@ import ( "strconv" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" "ragflow/internal/service" ) @@ -545,22 +546,63 @@ func (h *UserHandler) SetTenantInfo(c *gin.Context) { return } - var req service.SetTenantInfoRequest - if err := c.ShouldBindJSON(&req); err != nil { + requiredKeys := []string{"tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id"} + missingArgumentMessage := "required argument are missing: tenant_id,asr_id,embd_id,img2txt_id,llm_id; " + + var payload map[string]interface{} + if err := c.ShouldBindBodyWith(&payload, binding.JSON); err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeArgumentError, - "message": err.Error(), - "data": false, + "message": missingArgumentMessage, + "data": nil, }) return } - err := h.userService.SetTenantInfo(user.ID, &req) + missing := make([]string, 0, len(requiredKeys)) + for _, key := range requiredKeys { + if _, ok := payload[key]; !ok { + missing = append(missing, key) + } + } + if len(missing) > 0 { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": fmt.Sprintf("required argument are missing: %s; ", joinStrings(missing)), + "data": nil, + }) + return + } + + req := service.SetTenantInfoRequest{Raw: payload} + if value, ok := payload["tenant_id"].(string); ok { + req.TenantID = &value + } + if value, ok := payload["asr_id"].(string); ok { + req.ASRID = &value + } + if value, ok := payload["embd_id"].(string); ok { + req.EmbdID = &value + } + if value, ok := payload["img2txt_id"].(string); ok { + req.Img2TxtID = &value + } + if value, ok := payload["llm_id"].(string); ok { + req.LLMID = &value + } + if value, ok := payload["rerank_id"].(string); ok { + req.RerankID = &value + } + if value, ok := payload["tts_id"].(string); ok { + req.TTSID = &value + } + + code, err := h.userService.SetTenantInfo(user.ID, &req) if err != nil { c.JSON(http.StatusOK, gin.H{ - "code": common.CodeDataError, + "code": code, "message": err.Error(), - "data": false, + "data": nil, }) return } @@ -571,3 +613,14 @@ func (h *UserHandler) SetTenantInfo(c *gin.Context) { "data": true, }) } + +func joinStrings(values []string) string { + if len(values) == 0 { + return "" + } + result := values[0] + for i := 1; i < len(values); i++ { + result += "," + values[i] + } + return result +} diff --git a/internal/service/user.go b/internal/service/user.go index 40402ce917..87bb260095 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -889,51 +889,44 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, err // SetTenantInfoRequest represents the request for setting tenant info type SetTenantInfoRequest struct { - TenantID string `json:"tenant_id"` - ASRID string `json:"asr_id"` - EmbdID string `json:"embd_id"` - Img2TxtID string `json:"img2txt_id"` - LLMID string `json:"llm_id"` - RerankID string `json:"rerank_id"` - TTSID string `json:"tts_id"` + TenantID *string `json:"tenant_id"` + ASRID *string `json:"asr_id"` + EmbdID *string `json:"embd_id"` + Img2TxtID *string `json:"img2txt_id"` + LLMID *string `json:"llm_id"` + RerankID *string `json:"rerank_id"` + TTSID *string `json:"tts_id"` + Raw map[string]interface{} `json:"-"` } // SetTenantInfo updates tenant model configuration -func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error { +func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) (common.ErrorCode, error) { + _ = userID tenantDAO := dao.NewTenantDAO() - - _, err := tenantDAO.GetByID(req.TenantID) - if err != nil { - return fmt.Errorf("tenant not found: %w", err) - } - updates := make(map[string]interface{}) - if req.LLMID != "" { - updates["llm_id"] = req.LLMID + + for key, value := range req.Raw { + if key == "tenant_id" { + continue + } + updates[key] = value } - if req.EmbdID != "" { - updates["embd_id"] = req.EmbdID - } - if req.ASRID != "" { - updates["asr_id"] = req.ASRID - } - if req.Img2TxtID != "" { - updates["img2txt_id"] = req.Img2TxtID - } - if req.RerankID != "" { - updates["rerank_id"] = req.RerankID - } - if req.TTSID != "" { - updates["tts_id"] = req.TTSID + + tenantID := "" + if req.TenantID != nil { + tenantID = *req.TenantID } + tenantLLMService := NewTenantLLMService() + updates = tenantLLMService.EnsureTenantModelIDForParams(tenantID, updates) + if len(updates) > 0 { - if err := tenantDAO.Update(req.TenantID, updates); err != nil { - return fmt.Errorf("failed to update tenant: %w", err) + if err := tenantDAO.Update(tenantID, updates); err != nil { + return common.CodeExceptionError, err } } - return nil + return common.CodeSuccess, nil } // UserTenantService user tenant service