Files

133 lines
4.4 KiB
Go
Raw Permalink Normal View History

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// withTestTimeouts overrides the shared nonStreamCallTimeout and
// streamCallTimeout variables for the duration of one test so that deadline
// behaviour can be driven with millisecond-scale values. The previous values
// are captured up front and put back by a t.Cleanup hook when the test ends.
//
// NOTE(review): this mutates package-level state, so tests that use it should
// presumably not run in parallel with other tests reading these timeouts.
func withTestTimeouts(t *testing.T, nonStream, stream time.Duration) {
	t.Helper()

	// Remember what was configured before the override is applied.
	prevNonStream, prevStream := nonStreamCallTimeout, streamCallTimeout
	t.Cleanup(func() {
		nonStreamCallTimeout, streamCallTimeout = prevNonStream, prevStream
	})

	nonStreamCallTimeout, streamCallTimeout = nonStream, stream
}
// newTimeoutTestGroq returns a GroqModel whose "default" base URL is baseURL,
// configured with the "chat/completions" and "models" URL suffixes. The tests
// in this file pass an httptest server URL as baseURL.
func newTimeoutTestGroq(baseURL string) *GroqModel {
	endpoints := map[string]string{"default": baseURL}
	suffix := URLSuffix{Chat: "chat/completions", Models: "models"}
	return NewGroqModel(endpoints, suffix)
}
// TestStreamNotTruncatedByNonStreamTimeout is the regression guard for the
// original truncation bug: a streaming chat response that keeps producing
// data for longer than the short non-streaming deadline must be delivered in
// full rather than cut off mid-stream. Streaming is governed by the much
// larger streamCallTimeout, so a server that trickles SSE events well past
// nonStreamCallTimeout must still be read to completion. Should a provider
// ever wrap its stream in nonStreamCallTimeout again, the context would
// cancel part-way through and this test would fail.
func TestStreamNotTruncatedByNonStreamTimeout(t *testing.T) {
	// Six 40ms pauses make the stream last roughly 240ms: far beyond the 60ms
	// non-stream deadline, comfortably within the 10s stream deadline.
	const pause = 40 * time.Millisecond
	withTestTimeouts(t, 60*time.Millisecond, 10*time.Second)

	pieces := []string{"Hello", " ", "streamed", " ", "world"}

	// sseHandler emits one SSE data line per piece, flushing after each so
	// the client observes the chunks spread out over time, then sends [DONE].
	sseHandler := func(w http.ResponseWriter, r *http.Request) {
		flusher, canFlush := w.(http.Flusher)
		if !canFlush {
			t.Errorf("ResponseWriter does not support flushing")
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		for i := range pieces {
			time.Sleep(pause)
			fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n", pieces[i])
			flusher.Flush()
		}
		time.Sleep(pause)
		_, _ = io.WriteString(w, "data: [DONE]\n")
		flusher.Flush()
	}
	server := httptest.NewServer(http.HandlerFunc(sseHandler))
	defer server.Close()

	key := "test-key"
	model := newTimeoutTestGroq(server.URL)

	// collect accumulates every content fragment, skipping nil fragments and
	// the "[DONE]" marker.
	var received strings.Builder
	collect := func(content *string, _ *string) error {
		if content == nil || *content == "[DONE]" {
			return nil
		}
		received.WriteString(*content)
		return nil
	}

	err := model.ChatStreamlyWithSender(
		"llama-3.3-70b-versatile",
		[]Message{{Role: "user", Content: "hi"}},
		&APIConfig{ApiKey: &key},
		nil,
		collect,
	)
	if err != nil {
		t.Fatalf("streaming returned error (stream was truncated): %v", err)
	}

	const want = "Hello streamed world"
	if streamed := received.String(); streamed != want {
		t.Fatalf("streamed content=%q, want %q", streamed, want)
	}
}
// TestNonStreamHonorsShortDeadline ensures dropping the blanket
// http.Client.Timeout did not leave non-streaming calls unbounded: a slow
// server must still trip the per-call nonStreamCallTimeout promptly.
//
// The call has to fail inside a time window. Failing before the configured
// deadline means the error was something other than a timeout (so the test
// would otherwise pass vacuously); failing long after it means the deadline
// was not enforced.
func TestNonStreamHonorsShortDeadline(t *testing.T) {
	const (
		deadline    = 100 * time.Millisecond
		serverDelay = 800 * time.Millisecond // far beyond the non-stream deadline
		upperBound  = 400 * time.Millisecond
	)
	withTestTimeouts(t, deadline, 10*time.Second)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Drain the request body first: the server only watches for a client
		// disconnect (and cancels r.Context()) once the body has been read.
		// Best effort; a read error here does not matter to the test.
		_, _ = io.Copy(io.Discard, r.Body)

		// Stall well past the deadline, but give up as soon as the client
		// goes away so srv.Close, which waits for in-flight handlers, is not
		// held up for the full serverDelay.
		stall := time.NewTimer(serverDelay)
		defer stall.Stop()
		select {
		case <-r.Context().Done():
			return
		case <-stall.C:
		}
		_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"late"}}]}`)
	}))
	defer srv.Close()

	apiKey := "test-key"
	start := time.Now()
	_, err := newTimeoutTestGroq(srv.URL).ChatWithMessages(
		"llama-3.3-70b-versatile",
		[]Message{{Role: "user", Content: "hi"}},
		&APIConfig{ApiKey: &apiKey},
		nil,
	)
	elapsed := time.Since(start)

	if err == nil {
		t.Fatal("expected a deadline error from the slow server, got nil")
	}
	// A deadline can never fire early, so an error returned sooner than the
	// deadline must have a different cause (bad URL, request build failure...).
	if elapsed < deadline {
		t.Fatalf("call failed after %v, before the %v deadline, so the error is not a timeout: %v", elapsed, deadline, err)
	}
	if elapsed > upperBound {
		t.Fatalf("call took %v; expected it to abort near the 100ms deadline", elapsed)
	}
}