mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Add Replicate chat provider (#14958)
## What - Add Replicate as a chat provider backed by the documented predictions API - Register Replicate in the Go model factory and provider config - Support non-streaming chat through sync predictions, polling fallback, streaming through `urls.stream`, model listing, and connection checks ## Notes - Uses `POST /v1/predictions` with Replicate model identifiers in `version`, which supports official and community model identifiers - Maps RAGFlow messages into Replicate prompt-shaped inputs (`prompt`, optional `system_prompt`) and forwards common documented LLM inputs: `max_new_tokens`, `temperature`, `top_p` - Preserves whitespace in SSE output chunks and emits RAGFlow `[DONE]` at stream completion ## Tests - `go test -vet=off -run TestReplicate -count=1 ./internal/entity/models` - `go test -vet=off -count=1 ./internal/entity/models` Refs #14736
This commit is contained in:
27
conf/models/replicate.json
Normal file
27
conf/models/replicate.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"name": "Replicate",
|
||||
"url": {
|
||||
"default": "https://api.replicate.com"
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "v1/predictions",
|
||||
"models": "v1/models"
|
||||
},
|
||||
"class": "replicate",
|
||||
"models": [
|
||||
{
|
||||
"name": "meta/meta-llama-3-70b-instruct",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "meta/meta-llama-3-8b-instruct",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"chat"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -93,6 +93,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
|
||||
return NewLongCatModel(baseURL, urlSuffix), nil
|
||||
case "novita":
|
||||
return NewNovitaModel(baseURL, urlSuffix), nil
|
||||
case "replicate":
|
||||
return NewReplicateModel(baseURL, urlSuffix), nil
|
||||
case "voyage":
|
||||
return NewVoyageModel(baseURL, urlSuffix), nil
|
||||
case "paddleocr":
|
||||
|
||||
611
internal/entity/models/replicate.go
Normal file
611
internal/entity/models/replicate.go
Normal file
@@ -0,0 +1,611 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const replicatePollInterval = time.Second
|
||||
|
||||
type ReplicateModel struct {
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewReplicateModel(baseURL map[string]string, urlSuffix URLSuffix) *ReplicateModel {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.MaxIdleConns = 100
|
||||
transport.MaxIdleConnsPerHost = 10
|
||||
transport.IdleConnTimeout = 90 * time.Second
|
||||
transport.DisableCompression = false
|
||||
transport.ResponseHeaderTimeout = 60 * time.Second
|
||||
|
||||
return &ReplicateModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) NewInstance(baseURL map[string]string) ModelDriver {
|
||||
return NewReplicateModel(baseURL, r.URLSuffix)
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) Name() string {
|
||||
return "replicate"
|
||||
}
|
||||
|
||||
type replicatePredictionURLs struct {
|
||||
Get string `json:"get"`
|
||||
Stream string `json:"stream"`
|
||||
}
|
||||
|
||||
type replicatePrediction struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Output interface{} `json:"output"`
|
||||
Error interface{} `json:"error"`
|
||||
URLs replicatePredictionURLs `json:"urls"`
|
||||
}
|
||||
|
||||
type replicateModelsResponse struct {
|
||||
Results []struct {
|
||||
Owner string `json:"owner"`
|
||||
Name string `json:"name"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
type replicateSSEEvent struct {
|
||||
event string
|
||||
data string
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) baseURLForRegion(region string) (string, error) {
|
||||
base, ok := r.BaseURL[region]
|
||||
if !ok || base == "" {
|
||||
return "", fmt.Errorf("replicate: no base URL configured for region %q", region)
|
||||
}
|
||||
return strings.TrimSuffix(base, "/"), nil
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) {
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := r.baseURLForRegion(region)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", baseURL, suffix), nil
|
||||
}
|
||||
|
||||
func replicateUsesVersionEndpoint(modelName string) bool {
|
||||
name := strings.TrimSpace(modelName)
|
||||
return !strings.Contains(name, "/") || strings.Contains(name, ":")
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) predictionEndpoint(apiConfig *APIConfig, modelName string) (string, string, error) {
|
||||
if replicateUsesVersionEndpoint(modelName) {
|
||||
endpoint, err := r.endpoint(apiConfig, r.URLSuffix.Chat)
|
||||
return endpoint, modelName, err
|
||||
}
|
||||
|
||||
parts := strings.Split(modelName, "/")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", fmt.Errorf("replicate: official model name must be owner/name")
|
||||
}
|
||||
|
||||
modelsPrefix := strings.TrimSuffix(r.URLSuffix.Models, "models")
|
||||
if modelsPrefix == "" {
|
||||
modelsPrefix = "v1/"
|
||||
}
|
||||
officialSuffix := fmt.Sprintf("%smodels/%s/%s/predictions",
|
||||
modelsPrefix,
|
||||
url.PathEscape(parts[0]),
|
||||
url.PathEscape(parts[1]),
|
||||
)
|
||||
endpoint, err := r.endpoint(apiConfig, officialSuffix)
|
||||
return endpoint, "", err
|
||||
}
|
||||
|
||||
func replicateMessageContent(content interface{}) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
default:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Sprint(v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
|
||||
func replicatePromptFromMessages(messages []Message) (string, string) {
|
||||
var systemParts []string
|
||||
var promptParts []string
|
||||
nonSystemCount := 0
|
||||
for _, msg := range messages {
|
||||
content := replicateMessageContent(msg.Content)
|
||||
if msg.Role == "system" {
|
||||
systemParts = append(systemParts, content)
|
||||
continue
|
||||
}
|
||||
nonSystemCount++
|
||||
if nonSystemCount == 1 && msg.Role == "user" && len(messages) == len(systemParts)+1 {
|
||||
promptParts = append(promptParts, content)
|
||||
continue
|
||||
}
|
||||
promptParts = append(promptParts, fmt.Sprintf("%s: %s", msg.Role, content))
|
||||
}
|
||||
return strings.Join(promptParts, "\n"), strings.Join(systemParts, "\n\n")
|
||||
}
|
||||
|
||||
func replicateInputFromMessages(messages []Message, chatModelConfig *ChatConfig) map[string]interface{} {
|
||||
prompt, systemPrompt := replicatePromptFromMessages(messages)
|
||||
input := map[string]interface{}{
|
||||
"prompt": prompt,
|
||||
}
|
||||
if systemPrompt != "" {
|
||||
input["system_prompt"] = systemPrompt
|
||||
}
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
input["max_new_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
if chatModelConfig.Temperature != nil {
|
||||
input["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
if chatModelConfig.TopP != nil {
|
||||
input["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
// Replicate model inputs are model-specific. Forward only the
|
||||
// common prompt-model fields above; Stop is intentionally
|
||||
// omitted because upstream behavior is undefined for many
|
||||
// hosted models.
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func replicateOutputToString(output interface{}) (string, error) {
|
||||
switch v := output.(type) {
|
||||
case nil:
|
||||
return "", nil
|
||||
case string:
|
||||
return v, nil
|
||||
case []interface{}:
|
||||
var b strings.Builder
|
||||
for _, item := range v {
|
||||
text, err := replicateOutputToString(item)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b.WriteString(text)
|
||||
}
|
||||
return b.String(), nil
|
||||
case map[string]interface{}:
|
||||
raw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(raw), nil
|
||||
default:
|
||||
return fmt.Sprint(v), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) createPrediction(ctx context.Context, url string, version string, input map[string]interface{}, stream bool, apiKey string, preferWait bool) (*replicatePrediction, error) {
|
||||
body := map[string]interface{}{
|
||||
"input": input,
|
||||
"stream": stream,
|
||||
}
|
||||
if version != "" {
|
||||
body["version"] = version
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
if preferWait {
|
||||
req.Header.Set("Prefer", "wait=60")
|
||||
}
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var prediction replicatePrediction
|
||||
if err = json.Unmarshal(bodyBytes, &prediction); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
if prediction.Error != nil {
|
||||
return nil, fmt.Errorf("replicate: upstream error: %v", prediction.Error)
|
||||
}
|
||||
return &prediction, nil
|
||||
}
|
||||
|
||||
func replicatePredictionDone(status string) bool {
|
||||
return replicatePredictionSucceeded(status) || status == "failed" || status == "canceled"
|
||||
}
|
||||
|
||||
func replicatePredictionSucceeded(status string) bool {
|
||||
return status == "successful"
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) getPrediction(ctx context.Context, url string, apiKey string) (*replicatePrediction, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var prediction replicatePrediction
|
||||
if err = json.Unmarshal(body, &prediction); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
if prediction.Error != nil {
|
||||
return nil, fmt.Errorf("replicate: upstream error: %v", prediction.Error)
|
||||
}
|
||||
return &prediction, nil
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) waitForPrediction(ctx context.Context, prediction *replicatePrediction, apiKey string) (*replicatePrediction, error) {
|
||||
if prediction == nil {
|
||||
return nil, fmt.Errorf("replicate: empty prediction response")
|
||||
}
|
||||
if replicatePredictionDone(prediction.Status) {
|
||||
return prediction, nil
|
||||
}
|
||||
if prediction.URLs.Get == "" {
|
||||
return nil, fmt.Errorf("replicate: prediction is %q and no polling URL was returned", prediction.Status)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(replicatePollInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("replicate: prediction did not finish before timeout: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
next, err := r.getPrediction(ctx, prediction.URLs.Get, apiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if replicatePredictionDone(next.Status) {
|
||||
return next, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("messages is empty")
|
||||
}
|
||||
|
||||
url, version, err := r.predictionEndpoint(apiConfig, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
prediction, err := r.createPrediction(ctx, url, version, replicateInputFromMessages(messages, chatModelConfig), false, *apiConfig.ApiKey, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !replicatePredictionSucceeded(prediction.Status) {
|
||||
return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status)
|
||||
}
|
||||
|
||||
answer, err := replicateOutputToString(prediction.Output)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse prediction output: %w", err)
|
||||
}
|
||||
reasonContent := ""
|
||||
return &ChatResponse{Answer: &answer, ReasonContent: &reasonContent}, nil
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
if sender == nil {
|
||||
return fmt.Errorf("sender is required")
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return fmt.Errorf("api key is required")
|
||||
}
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
return fmt.Errorf("model name is required")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return fmt.Errorf("messages is empty")
|
||||
}
|
||||
if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream {
|
||||
return fmt.Errorf("stream must be true in ChatStreamlyWithSender")
|
||||
}
|
||||
|
||||
url, version, err := r.predictionEndpoint(apiConfig, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prediction, err := r.createPrediction(context.Background(), url, version, replicateInputFromMessages(messages, chatModelConfig), true, *apiConfig.ApiKey, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if prediction.URLs.Stream == "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
answer, err := replicateOutputToString(prediction.Output)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse prediction output: %w", err)
|
||||
}
|
||||
if answer != "" {
|
||||
if err := sender(&answer, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
endOfStream := "[DONE]"
|
||||
return sender(&endOfStream, nil)
|
||||
}
|
||||
|
||||
return r.readPredictionStream(prediction.URLs.Stream, *apiConfig.ApiKey, sender)
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) readPredictionStream(url string, apiKey string, sender func(*string, *string) error) error {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
current := replicateSSEEvent{}
|
||||
sawDone := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
done, err := dispatchReplicateSSEEvent(current, sender)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
sawDone = true
|
||||
break
|
||||
}
|
||||
current = replicateSSEEvent{}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
current.event = strings.TrimSpace(line[6:])
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
if current.data != "" {
|
||||
current.data += "\n"
|
||||
}
|
||||
data := line[5:]
|
||||
if strings.HasPrefix(data, " ") {
|
||||
data = data[1:]
|
||||
}
|
||||
current.data += data
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("failed to scan response body: %w", err)
|
||||
}
|
||||
if !sawDone && (current.event != "" || current.data != "") {
|
||||
done, err := dispatchReplicateSSEEvent(current, sender)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sawDone = done
|
||||
}
|
||||
if !sawDone {
|
||||
return fmt.Errorf("replicate: stream ended before done event")
|
||||
}
|
||||
|
||||
endOfStream := "[DONE]"
|
||||
return sender(&endOfStream, nil)
|
||||
}
|
||||
|
||||
func dispatchReplicateSSEEvent(event replicateSSEEvent, sender func(*string, *string) error) (bool, error) {
|
||||
switch event.event {
|
||||
case "output", "":
|
||||
if event.data == "" {
|
||||
return false, nil
|
||||
}
|
||||
return false, sender(&event.data, nil)
|
||||
case "error":
|
||||
return false, fmt.Errorf("replicate: upstream stream error: %s", event.data)
|
||||
case "done":
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
|
||||
url, err := r.endpoint(apiConfig, r.URLSuffix.Models)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result replicateModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(result.Results))
|
||||
for _, model := range result.Results {
|
||||
if model.Owner != "" && model.Name != "" {
|
||||
models = append(models, fmt.Sprintf("%s/%s", model.Owner, model.Name))
|
||||
}
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
_, err := r.ListModels(apiConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", r.Name())
|
||||
}
|
||||
321
internal/entity/models/replicate_test.go
Normal file
321
internal/entity/models/replicate_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newReplicateForTest(baseURL string) *ReplicateModel {
|
||||
return NewReplicateModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{Chat: "v1/predictions", Models: "v1/models"},
|
||||
)
|
||||
}
|
||||
|
||||
func TestReplicateName(t *testing.T) {
|
||||
if got := newReplicateForTest("http://unused").Name(); got != "replicate" {
|
||||
t.Errorf("Name()=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateFactory(t *testing.T) {
|
||||
driver, err := NewModelFactory().CreateModelDriver("Replicate", map[string]string{"default": "http://unused"}, URLSuffix{})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateModelDriver: %v", err)
|
||||
}
|
||||
if _, ok := driver.(*ReplicateModel); !ok {
|
||||
t.Fatalf("driver type=%T, want *ReplicateModel", driver)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicatePromptFromMessages(t *testing.T) {
|
||||
prompt, system := replicatePromptFromMessages([]Message{
|
||||
{Role: "system", Content: "be terse"},
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi"},
|
||||
{Role: "user", Content: map[string]interface{}{"text": "again"}},
|
||||
})
|
||||
if system != "be terse" {
|
||||
t.Errorf("system=%q", system)
|
||||
}
|
||||
want := "user: hello\nassistant: hi\nuser: {\"text\":\"again\"}"
|
||||
if prompt != want {
|
||||
t.Errorf("prompt=%q want %q", prompt, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicatePredictionEndpoint(t *testing.T) {
|
||||
m := newReplicateForTest("https://api.example.test")
|
||||
|
||||
endpoint, version, err := m.predictionEndpoint(&APIConfig{}, "meta/meta-llama-3-70b-instruct")
|
||||
if err != nil {
|
||||
t.Fatalf("official endpoint: %v", err)
|
||||
}
|
||||
if endpoint != "https://api.example.test/v1/models/meta/meta-llama-3-70b-instruct/predictions" {
|
||||
t.Errorf("official endpoint=%q", endpoint)
|
||||
}
|
||||
if version != "" {
|
||||
t.Errorf("official version=%q want empty", version)
|
||||
}
|
||||
|
||||
endpoint, version, err = m.predictionEndpoint(&APIConfig{}, "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
|
||||
if err != nil {
|
||||
t.Fatalf("version endpoint: %v", err)
|
||||
}
|
||||
if endpoint != "https://api.example.test/v1/predictions" {
|
||||
t.Errorf("version endpoint=%q", endpoint)
|
||||
}
|
||||
if version != "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" {
|
||||
t.Errorf("version=%q", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateOfficialChatHappyPath(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models/meta/meta-llama-3-70b-instruct/predictions" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q", got)
|
||||
}
|
||||
if got := r.Header.Get("Prefer"); got != "wait=60" {
|
||||
t.Errorf("Prefer=%q", got)
|
||||
}
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Errorf("body: %v", err)
|
||||
return
|
||||
}
|
||||
if body["version"] != nil {
|
||||
t.Errorf("official model requests must not send version=%v", body["version"])
|
||||
}
|
||||
if body["stream"] != false {
|
||||
t.Errorf("stream=%v", body["stream"])
|
||||
}
|
||||
input := body["input"].(map[string]interface{})
|
||||
if input["prompt"] != "hello" {
|
||||
t.Errorf("prompt=%v", input["prompt"])
|
||||
}
|
||||
if input["system_prompt"] != "be helpful" {
|
||||
t.Errorf("system_prompt=%v", input["system_prompt"])
|
||||
}
|
||||
if input["max_new_tokens"] != float64(128) {
|
||||
t.Errorf("max_new_tokens=%v", input["max_new_tokens"])
|
||||
}
|
||||
// Stop is deliberately filtered out because Replicate model
|
||||
// inputs are model-specific and upstream support is undefined.
|
||||
if input["stop"] != nil {
|
||||
t.Errorf("unexpected stop=%v", input["stop"])
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "successful",
|
||||
"output": []string{"hel", "lo"},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
maxTokens := 128
|
||||
stop := []string{"END"}
|
||||
resp, err := newReplicateForTest(srv.URL).ChatWithMessages(
|
||||
"meta/meta-llama-3-70b-instruct",
|
||||
[]Message{{Role: "system", Content: "be helpful"}, {Role: "user", Content: "hello"}},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&ChatConfig{MaxTokens: &maxTokens, Stop: &stop},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatWithMessages: %v", err)
|
||||
}
|
||||
if *resp.Answer != "hello" {
|
||||
t.Errorf("Answer=%q", *resp.Answer)
|
||||
}
|
||||
if *resp.ReasonContent != "" {
|
||||
t.Errorf("ReasonContent=%q", *resp.ReasonContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateCommunityChatUsesVersionEndpoint(t *testing.T) {
|
||||
const version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/predictions" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
}
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Errorf("body: %v", err)
|
||||
return
|
||||
}
|
||||
if body["version"] != version {
|
||||
t.Errorf("version=%v", body["version"])
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "successful",
|
||||
"output": "ok",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
resp, err := newReplicateForTest(srv.URL).ChatWithMessages(
|
||||
version,
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
&APIConfig{ApiKey: &apiKey}, nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatWithMessages: %v", err)
|
||||
}
|
||||
if *resp.Answer != "ok" {
|
||||
t.Errorf("Answer=%q", *resp.Answer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateChatPollsUntilSucceeded(t *testing.T) {
|
||||
var getCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q", got)
|
||||
}
|
||||
switch r.URL.Path {
|
||||
case "/v1/models/meta/meta-llama-3-70b-instruct/predictions":
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "processing",
|
||||
"urls": map[string]string{
|
||||
"get": "http://" + r.Host + "/v1/predictions/p1",
|
||||
},
|
||||
})
|
||||
case "/v1/predictions/p1":
|
||||
getCount++
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "successful",
|
||||
"output": "done",
|
||||
})
|
||||
default:
|
||||
t.Errorf("unexpected path=%s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
resp, err := newReplicateForTest(srv.URL).ChatWithMessages(
|
||||
"meta/meta-llama-3-70b-instruct",
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
&APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatWithMessages: %v", err)
|
||||
}
|
||||
if getCount != 1 {
|
||||
t.Errorf("getCount=%d", getCount)
|
||||
}
|
||||
if *resp.Answer != "done" {
|
||||
t.Errorf("Answer=%q", *resp.Answer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateStreamHappyPath(t *testing.T) {
|
||||
var streamURL string
|
||||
streamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept=%q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, "event: output\n")
|
||||
_, _ = io.WriteString(w, "data: Hello\n\n")
|
||||
_, _ = io.WriteString(w, "event: output\n")
|
||||
_, _ = io.WriteString(w, "data: world\n\n")
|
||||
_, _ = io.WriteString(w, "event: done\n")
|
||||
_, _ = io.WriteString(w, "data: {}\n\n")
|
||||
}))
|
||||
defer streamSrv.Close()
|
||||
streamURL = streamSrv.URL
|
||||
|
||||
apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models/meta/meta-llama-3-70b-instruct/predictions" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
}
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Errorf("body: %v", err)
|
||||
return
|
||||
}
|
||||
if body["stream"] != true {
|
||||
t.Errorf("stream=%v", body["stream"])
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "starting",
|
||||
"urls": map[string]string{
|
||||
"stream": streamURL,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer apiSrv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
var chunks []string
|
||||
err := newReplicateForTest(apiSrv.URL).ChatStreamlyWithSender(
|
||||
"meta/meta-llama-3-70b-instruct",
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
&APIConfig{ApiKey: &apiKey}, nil,
|
||||
func(c *string, _ *string) error {
|
||||
if c != nil {
|
||||
chunks = append(chunks, *c)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStreamlyWithSender: %v", err)
|
||||
}
|
||||
if strings.Join(chunks, "") != "Hello world[DONE]" {
|
||||
t.Errorf("chunks=%q", strings.Join(chunks, ""))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateListModelsAndCheckConnection(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"results": []map[string]string{
|
||||
{"owner": "meta", "name": "meta-llama-3-70b-instruct"},
|
||||
{"owner": "replicate", "name": "hello-world"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := newReplicateForTest(srv.URL)
|
||||
models, err := model.ListModels(&APIConfig{ApiKey: &apiKey})
|
||||
if err != nil {
|
||||
t.Fatalf("ListModels: %v", err)
|
||||
}
|
||||
if strings.Join(models, ",") != "meta/meta-llama-3-70b-instruct,replicate/hello-world" {
|
||||
t.Errorf("models=%v", models)
|
||||
}
|
||||
if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil {
|
||||
t.Fatalf("CheckConnection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateUnsupportedMethods(t *testing.T) {
|
||||
m := newReplicateForTest("http://unused")
|
||||
if _, err := m.Embed(nil, nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Embed error=%v", err)
|
||||
}
|
||||
if _, err := m.Rerank(nil, "", nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Rerank error=%v", err)
|
||||
}
|
||||
if _, err := m.Balance(nil); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Balance error=%v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user