diff --git a/conf/models/replicate.json b/conf/models/replicate.json new file mode 100644 index 0000000000..91111351ad --- /dev/null +++ b/conf/models/replicate.json @@ -0,0 +1,27 @@ +{ + "name": "Replicate", + "url": { + "default": "https://api.replicate.com" + }, + "url_suffix": { + "chat": "v1/predictions", + "models": "v1/models" + }, + "class": "replicate", + "models": [ + { + "name": "meta/meta-llama-3-70b-instruct", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "meta/meta-llama-3-8b-instruct", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + } + ] +} diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index e3719d3474..d8ccc9f766 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -93,6 +93,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewLongCatModel(baseURL, urlSuffix), nil case "novita": return NewNovitaModel(baseURL, urlSuffix), nil + case "replicate": + return NewReplicateModel(baseURL, urlSuffix), nil case "voyage": return NewVoyageModel(baseURL, urlSuffix), nil case "paddleocr": diff --git a/internal/entity/models/replicate.go b/internal/entity/models/replicate.go new file mode 100644 index 0000000000..0757b83250 --- /dev/null +++ b/internal/entity/models/replicate.go @@ -0,0 +1,611 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const replicatePollInterval = time.Second + +type ReplicateModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewReplicateModel(baseURL map[string]string, urlSuffix URLSuffix) *ReplicateModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 90 * time.Second // must exceed the 60s "Prefer: wait=60" hold sent by createPrediction + + return &ReplicateModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (r *ReplicateModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewReplicateModel(baseURL, r.URLSuffix) +} + +func (r *ReplicateModel) Name() string { + return "replicate" +} + +type replicatePredictionURLs struct { + Get string `json:"get"` + Stream string `json:"stream"` +} + +type replicatePrediction struct { + ID string `json:"id"` + Status string `json:"status"` + Output interface{} `json:"output"` + Error interface{} `json:"error"` + URLs replicatePredictionURLs `json:"urls"` +} + +type replicateModelsResponse struct { + Results []struct { + Owner string `json:"owner"` + Name string `json:"name"` + } `json:"results"` +} + +type replicateSSEEvent struct { + event string + data string +} + +func (r *ReplicateModel) baseURLForRegion(region string) (string, error) { + base, ok := r.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("replicate: no base URL configured for region %q", region) + } + return strings.TrimSuffix(base, "/"), nil +} + +func (r *ReplicateModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) { +
region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := r.baseURLForRegion(region) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", baseURL, suffix), nil +} + +func replicateUsesVersionEndpoint(modelName string) bool { + name := strings.TrimSpace(modelName) + return !strings.Contains(name, "/") || strings.Contains(name, ":") +} + +func (r *ReplicateModel) predictionEndpoint(apiConfig *APIConfig, modelName string) (string, string, error) { + if replicateUsesVersionEndpoint(modelName) { + endpoint, err := r.endpoint(apiConfig, r.URLSuffix.Chat) + return endpoint, modelName, err + } + + parts := strings.Split(modelName, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", fmt.Errorf("replicate: official model name must be owner/name") + } + + modelsPrefix := strings.TrimSuffix(r.URLSuffix.Models, "models") + if modelsPrefix == "" { + modelsPrefix = "v1/" + } + officialSuffix := fmt.Sprintf("%smodels/%s/%s/predictions", + modelsPrefix, + url.PathEscape(parts[0]), + url.PathEscape(parts[1]), + ) + endpoint, err := r.endpoint(apiConfig, officialSuffix) + return endpoint, "", err +} + +func replicateMessageContent(content interface{}) string { + switch v := content.(type) { + case string: + return v + default: + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprint(v) + } + return string(b) + } +} + +func replicatePromptFromMessages(messages []Message) (string, string) { + var systemParts []string + var promptParts []string + nonSystemCount := 0 + for _, msg := range messages { + content := replicateMessageContent(msg.Content) + if msg.Role == "system" { + systemParts = append(systemParts, content) + continue + } + nonSystemCount++ + if nonSystemCount == 1 && msg.Role == "user" && len(messages) == len(systemParts)+1 { + promptParts = append(promptParts, content) + continue + } + promptParts = append(promptParts, 
fmt.Sprintf("%s: %s", msg.Role, content)) + } + return strings.Join(promptParts, "\n"), strings.Join(systemParts, "\n\n") +} + +func replicateInputFromMessages(messages []Message, chatModelConfig *ChatConfig) map[string]interface{} { + prompt, systemPrompt := replicatePromptFromMessages(messages) + input := map[string]interface{}{ + "prompt": prompt, + } + if systemPrompt != "" { + input["system_prompt"] = systemPrompt + } + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + input["max_new_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + input["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + input["top_p"] = *chatModelConfig.TopP + } + // Replicate model inputs are model-specific. Forward only the + // common prompt-model fields above; Stop is intentionally + // omitted because upstream behavior is undefined for many + // hosted models. + } + return input +} + +func replicateOutputToString(output interface{}) (string, error) { + switch v := output.(type) { + case nil: + return "", nil + case string: + return v, nil + case []interface{}: + var b strings.Builder + for _, item := range v { + text, err := replicateOutputToString(item) + if err != nil { + return "", err + } + b.WriteString(text) + } + return b.String(), nil + case map[string]interface{}: + raw, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(raw), nil + default: + return fmt.Sprint(v), nil + } +} + +func (r *ReplicateModel) createPrediction(ctx context.Context, url string, version string, input map[string]interface{}, stream bool, apiKey string, preferWait bool) (*replicatePrediction, error) { + body := map[string]interface{}{ + "input": input, + "stream": stream, + } + if version != "" { + body["version"] = version + } + + jsonData, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := 
http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + if preferWait { + req.Header.Set("Prefer", "wait=60") + } + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var prediction replicatePrediction + if err = json.Unmarshal(bodyBytes, &prediction); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if prediction.Error != nil { + return nil, fmt.Errorf("replicate: upstream error: %v", prediction.Error) + } + return &prediction, nil +} + +func replicatePredictionDone(status string) bool { + return replicatePredictionSucceeded(status) || status == "failed" || status == "canceled" +} + +func replicatePredictionSucceeded(status string) bool { + return status == "succeeded" || status == "successful" // Replicate reports "succeeded"; "successful" is tolerated for compatibility +} + +func (r *ReplicateModel) getPrediction(ctx context.Context, url string, apiKey string) (*replicatePrediction, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed
to read response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var prediction replicatePrediction + if err = json.Unmarshal(body, &prediction); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if prediction.Error != nil { + return nil, fmt.Errorf("replicate: upstream error: %v", prediction.Error) + } + return &prediction, nil +} + +func (r *ReplicateModel) waitForPrediction(ctx context.Context, prediction *replicatePrediction, apiKey string) (*replicatePrediction, error) { + if prediction == nil { + return nil, fmt.Errorf("replicate: empty prediction response") + } + if replicatePredictionDone(prediction.Status) { + return prediction, nil + } + if prediction.URLs.Get == "" { + return nil, fmt.Errorf("replicate: prediction is %q and no polling URL was returned", prediction.Status) + } + + ticker := time.NewTicker(replicatePollInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("replicate: prediction did not finish before timeout: %w", ctx.Err()) + case <-ticker.C: + next, err := r.getPrediction(ctx, prediction.URLs.Get, apiKey) + if err != nil { + return nil, err + } + if replicatePredictionDone(next.Status) { + return next, nil + } + } + } +} + +func (r *ReplicateModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + url, version, err := r.predictionEndpoint(apiConfig, modelName) + if err != nil { + return nil, err + } + + ctx, cancel := 
context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + prediction, err := r.createPrediction(ctx, url, version, replicateInputFromMessages(messages, chatModelConfig), false, *apiConfig.ApiKey, true) + if err != nil { + return nil, err + } + prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey) + if err != nil { + return nil, err + } + if !replicatePredictionSucceeded(prediction.Status) { + return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status) + } + + answer, err := replicateOutputToString(prediction.Output) + if err != nil { + return nil, fmt.Errorf("failed to parse prediction output: %w", err) + } + reasonContent := "" + return &ChatResponse{Answer: &answer, ReasonContent: &reasonContent}, nil +} + +func (r *ReplicateModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + url, version, err := r.predictionEndpoint(apiConfig, modelName) + if err != nil { + return err + } + + prediction, err := r.createPrediction(context.Background(), url, version, replicateInputFromMessages(messages, chatModelConfig), true, *apiConfig.ApiKey, false) + if err != nil { + return err + } + if prediction.URLs.Stream == "" { + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + prediction, err = r.waitForPrediction(ctx, prediction, 
*apiConfig.ApiKey) + if err != nil { + return err + } + answer, err := replicateOutputToString(prediction.Output) + if err != nil { + return fmt.Errorf("failed to parse prediction output: %w", err) + } + if answer != "" { + if err := sender(&answer, nil); err != nil { + return err + } + } + endOfStream := "[DONE]" + return sender(&endOfStream, nil) + } + + return r.readPredictionStream(prediction.URLs.Stream, *apiConfig.ApiKey, sender) +} + +func (r *ReplicateModel) readPredictionStream(url string, apiKey string, sender func(*string, *string) error) error { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Accept", "text/event-stream") + + resp, err := r.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + current := replicateSSEEvent{} + sawDone := false + for scanner.Scan() { + line := scanner.Text() + if line == "" { + done, err := dispatchReplicateSSEEvent(current, sender) + if err != nil { + return err + } + if done { + sawDone = true + break + } + current = replicateSSEEvent{} + continue + } + if strings.HasPrefix(line, "event:") { + current.event = strings.TrimSpace(line[6:]) + } + if strings.HasPrefix(line, "data:") { + if current.data != "" { + current.data += "\n" + } + data := line[5:] + if strings.HasPrefix(data, " ") { + data = data[1:] + } + current.data += data + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawDone && (current.event != "" || 
current.data != "") { + done, err := dispatchReplicateSSEEvent(current, sender) + if err != nil { + return err + } + sawDone = done + } + if !sawDone { + return fmt.Errorf("replicate: stream ended before done event") + } + + endOfStream := "[DONE]" + return sender(&endOfStream, nil) +} + +func dispatchReplicateSSEEvent(event replicateSSEEvent, sender func(*string, *string) error) (bool, error) { + switch event.event { + case "output", "": + if event.data == "" { + return false, nil + } + return false, sender(&event.data, nil) + case "error": + return false, fmt.Errorf("replicate: upstream stream error: %s", event.data) + case "done": + return true, nil + default: + return false, nil + } +} + +func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + url, err := r.endpoint(apiConfig, r.URLSuffix.Models) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result replicateModelsResponse + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, 
len(result.Results)) + for _, model := range result.Results { + if model.Owner != "" && model.Name != "" { + models = append(models, fmt.Sprintf("%s/%s", model.Owner, model.Name)) + } + } + return models, nil +} + +func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error { + _, err := r.ListModels(apiConfig) + return err +} + +func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no 
such method", r.Name()) +} + +func (r *ReplicateModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} diff --git a/internal/entity/models/replicate_test.go b/internal/entity/models/replicate_test.go new file mode 100644 index 0000000000..d9eb1efd6b --- /dev/null +++ b/internal/entity/models/replicate_test.go @@ -0,0 +1,321 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newReplicateForTest(baseURL string) *ReplicateModel { + return NewReplicateModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "v1/predictions", Models: "v1/models"}, + ) +} + +func TestReplicateName(t *testing.T) { + if got := newReplicateForTest("http://unused").Name(); got != "replicate" { + t.Errorf("Name()=%q", got) + } +} + +func TestReplicateFactory(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("Replicate", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("CreateModelDriver: %v", err) + } + if _, ok := driver.(*ReplicateModel); !ok { + t.Fatalf("driver type=%T, want *ReplicateModel", driver) + } +} + +func TestReplicatePromptFromMessages(t *testing.T) { + prompt, system := replicatePromptFromMessages([]Message{ + {Role: "system", Content: "be terse"}, + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + {Role: "user", Content: map[string]interface{}{"text": "again"}}, + }) + if system != "be terse" { + t.Errorf("system=%q", system) + } + want := "user: 
hello\nassistant: hi\nuser: {\"text\":\"again\"}" + if prompt != want { + t.Errorf("prompt=%q want %q", prompt, want) + } +} + +func TestReplicatePredictionEndpoint(t *testing.T) { + m := newReplicateForTest("https://api.example.test") + + endpoint, version, err := m.predictionEndpoint(&APIConfig{}, "meta/meta-llama-3-70b-instruct") + if err != nil { + t.Fatalf("official endpoint: %v", err) + } + if endpoint != "https://api.example.test/v1/models/meta/meta-llama-3-70b-instruct/predictions" { + t.Errorf("official endpoint=%q", endpoint) + } + if version != "" { + t.Errorf("official version=%q want empty", version) + } + + endpoint, version, err = m.predictionEndpoint(&APIConfig{}, "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") + if err != nil { + t.Fatalf("version endpoint: %v", err) + } + if endpoint != "https://api.example.test/v1/predictions" { + t.Errorf("version endpoint=%q", endpoint) + } + if version != "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" { + t.Errorf("version=%q", version) + } +} + +func TestReplicateOfficialChatHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/meta/meta-llama-3-70b-instruct/predictions" { + t.Errorf("path=%s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + } + if got := r.Header.Get("Prefer"); got != "wait=60" { + t.Errorf("Prefer=%q", got) + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("body: %v", err) + return + } + if body["version"] != nil { + t.Errorf("official model requests must not send version=%v", body["version"]) + } + if body["stream"] != false { + t.Errorf("stream=%v", body["stream"]) + } + input := body["input"].(map[string]interface{}) + if input["prompt"] != "hello" { + 
t.Errorf("prompt=%v", input["prompt"]) + } + if input["system_prompt"] != "be helpful" { + t.Errorf("system_prompt=%v", input["system_prompt"]) + } + if input["max_new_tokens"] != float64(128) { + t.Errorf("max_new_tokens=%v", input["max_new_tokens"]) + } + // Stop is deliberately filtered out because Replicate model + // inputs are model-specific and upstream support is undefined. + if input["stop"] != nil { + t.Errorf("unexpected stop=%v", input["stop"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "successful", + "output": []string{"hel", "lo"}, + }) + })) + defer srv.Close() + + apiKey := "test-key" + maxTokens := 128 + stop := []string{"END"} + resp, err := newReplicateForTest(srv.URL).ChatWithMessages( + "meta/meta-llama-3-70b-instruct", + []Message{{Role: "system", Content: "be helpful"}, {Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &maxTokens, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "hello" { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "" { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +func TestReplicateCommunityChatUsesVersionEndpoint(t *testing.T) { + const version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/predictions" { + t.Errorf("path=%s", r.URL.Path) + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("body: %v", err) + return + } + if body["version"] != version { + t.Errorf("version=%v", body["version"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "successful", + "output": "ok", + }) + })) + defer srv.Close() + + apiKey := "test-key" + resp, err := newReplicateForTest(srv.URL).ChatWithMessages( + version, + 
[]Message{{Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, nil, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "ok" { + t.Errorf("Answer=%q", *resp.Answer) + } +} + +func TestReplicateChatPollsUntilSucceeded(t *testing.T) { + var getCount int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + } + switch r.URL.Path { + case "/v1/models/meta/meta-llama-3-70b-instruct/predictions": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "processing", + "urls": map[string]string{ + "get": "http://" + r.Host + "/v1/predictions/p1", + }, + }) + case "/v1/predictions/p1": + getCount++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "successful", + "output": "done", + }) + default: + t.Errorf("unexpected path=%s", r.URL.Path) + } + })) + defer srv.Close() + + apiKey := "test-key" + resp, err := newReplicateForTest(srv.URL).ChatWithMessages( + "meta/meta-llama-3-70b-instruct", + []Message{{Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if getCount != 1 { + t.Errorf("getCount=%d", getCount) + } + if *resp.Answer != "done" { + t.Errorf("Answer=%q", *resp.Answer) + } +} + +func TestReplicateStreamHappyPath(t *testing.T) { + var streamURL string + streamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept=%q", got) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, "event: output\n") + _, _ = io.WriteString(w, "data: Hello\n\n") + _, _ = io.WriteString(w, "event: output\n") + _, _ = io.WriteString(w, "data: world\n\n") + _, _ = io.WriteString(w, "event: done\n") + _, _ = io.WriteString(w, "data: {}\n\n") + 
})) + defer streamSrv.Close() + streamURL = streamSrv.URL + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/meta/meta-llama-3-70b-instruct/predictions" { + t.Errorf("path=%s", r.URL.Path) + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("body: %v", err) + return + } + if body["stream"] != true { + t.Errorf("stream=%v", body["stream"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "starting", + "urls": map[string]string{ + "stream": streamURL, + }, + }) + })) + defer apiSrv.Close() + + apiKey := "test-key" + var chunks []string + err := newReplicateForTest(apiSrv.URL).ChatStreamlyWithSender( + "meta/meta-llama-3-70b-instruct", + []Message{{Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, _ *string) error { + if c != nil { + chunks = append(chunks, *c) + } + return nil + }) + if err != nil { + t.Fatalf("ChatStreamlyWithSender: %v", err) + } + if strings.Join(chunks, "") != "Hello world[DONE]" { + t.Errorf("chunks=%q", strings.Join(chunks, "")) + } +} + +func TestReplicateListModelsAndCheckConnection(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Errorf("path=%s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]string{ + {"owner": "meta", "name": "meta-llama-3-70b-instruct"}, + {"owner": "replicate", "name": "hello-world"}, + }, + }) + })) + defer srv.Close() + + apiKey := "test-key" + model := newReplicateForTest(srv.URL) + models, err := model.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if strings.Join(models, ",") != 
"meta/meta-llama-3-70b-instruct,replicate/hello-world" { + t.Errorf("models=%v", models) + } + if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Fatalf("CheckConnection: %v", err) + } +} + +func TestReplicateUnsupportedMethods(t *testing.T) { + m := newReplicateForTest("http://unused") + if _, err := m.Embed(nil, nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed error=%v", err) + } + if _, err := m.Rerank(nil, "", nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank error=%v", err) + } + if _, err := m.Balance(nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance error=%v", err) + } +}