diff --git a/conf/models/huggingface.json b/conf/models/huggingface.json index c46ab4a46b..f1a7d942fb 100644 --- a/conf/models/huggingface.json +++ b/conf/models/huggingface.json @@ -1,7 +1,7 @@ { "name": "HuggingFace", "url": { - "default": "https://router.huggingface.co/v1/" + "default": "https://router.huggingface.co/v1" }, "url-suffix": { "chat": "chat/completions", diff --git a/internal/entity/models/huggingface.go b/internal/entity/models/huggingface.go index 1dad00a565..8684aedca1 100644 --- a/internal/entity/models/huggingface.go +++ b/internal/entity/models/huggingface.go @@ -26,12 +26,6 @@ func NewHuggingFaceModel(baseURL map[string]string, urlSuffix URLSuffix) *Huggin URLSuffix: urlSuffix, httpClient: &http.Client{ Timeout: 120 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 10, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - DisableCompression: false, - }, }, } } @@ -41,12 +35,6 @@ func (h *HuggingFaceModel) NewInstance(baseURL map[string]string) ModelDriver { URLSuffix: h.URLSuffix, httpClient: &http.Client{ Timeout: 120 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 10, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - DisableCompression: false, - }, }, } } @@ -204,7 +192,7 @@ func (h *HuggingFaceModel) ChatStreamlyWithSender(modelName string, messages []M region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", h.BaseURL[region]) + url := fmt.Sprintf("%s/%s", h.BaseURL[region], h.URLSuffix.Chat) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -356,6 +344,11 @@ func (h *HuggingFaceModel) Embed(modelName *string, texts []string, apiConfig *A return []EmbeddingData{}, nil } + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + if modelName == nil || *modelName == "" { return nil, fmt.Errorf("model name is required") } @@ -373,7 +366,7 @@ func (h *HuggingFaceModel) Embed(modelName *string, texts []string, apiConfig *A return nil, err } - url := fmt.Sprintf("https://router.huggingface.co/hf-inference/models/%s", *modelName) + url := fmt.Sprintf("%s/%s/%s", h.BaseURL[region], h.URLSuffix.Embedding, *modelName) req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil {