Files
ragflow/internal/common/timer_test.go
qinling0210 563d855780 Implement OpenAI chat completions in GO (#16177)
### What problem does this PR solve?

Implement OpenAI chat completions in Go.

POST /api/v1/openai/<chat_id>/chat/completions

OpenAI chat CLI: see internal/development.md

### Type of change

- [x] Refactoring
2026-06-18 18:07:27 +08:00

225 lines
6.2 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package common
import (
	"encoding/json"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)
// TestTimer_BasicSequentialPhases records two back-to-back phases and checks
// three things. Each phase's duration must land near its sleep, a phase that
// was never entered must read as zero, and Total must cover both phases.
func TestTimer_BasicSequentialPhases(t *testing.T) {
	timer := NewTimer()
	timer.Start()

	timer.Enter(PhaseCheckLLM)
	time.Sleep(5 * time.Millisecond)
	timer.Exit(PhaseCheckLLM)

	timer.Enter(PhaseBindModels)
	time.Sleep(3 * time.Millisecond)
	timer.Exit(PhaseBindModels)

	// expectWithin reports an error when got falls outside [lo, hi]; the upper
	// bound is generous so a slow scheduler does not cause flakes.
	expectWithin := func(name string, got, lo, hi time.Duration, want string) {
		t.Helper()
		if got < lo || got > hi {
			t.Errorf("%s = %v, want %s", name, got, want)
		}
	}
	expectWithin("PhaseCheckLLM", timer.Phase(PhaseCheckLLM), 4*time.Millisecond, 50*time.Millisecond, "~5ms")
	expectWithin("PhaseBindModels", timer.Phase(PhaseBindModels), 2*time.Millisecond, 50*time.Millisecond, "~3ms")

	// A phase that was never entered must read as exactly zero.
	if untouched := timer.Phase(PhaseRetrieval); untouched != 0 {
		t.Errorf("PhaseRetrieval = %v, want 0", untouched)
	}

	// Total spans at least the 5ms + 3ms slept above (minus a little slack).
	if total := timer.Total(); total < 7*time.Millisecond {
		t.Errorf("Total = %v, want >= 7ms", total)
	}
}
// TestTimer_NestedPhasesAddUp enters one phase inside another. The inner phase
// should record only its own span. The outer phase should record its whole
// span, including the time spent in the inner phase.
func TestTimer_NestedPhasesAddUp(t *testing.T) {
	timer := NewTimer()
	timer.Start()

	// Outer span: query refinement, which wraps an LLM call.
	timer.Enter(PhaseQueryRefinement)
	time.Sleep(2 * time.Millisecond)

	// Inner span, fully contained in the outer one.
	timer.Enter(PhaseGenerateAnswer)
	time.Sleep(3 * time.Millisecond)
	timer.Exit(PhaseGenerateAnswer)

	time.Sleep(1 * time.Millisecond)
	timer.Exit(PhaseQueryRefinement)

	// The inner phase sees only its own ~3ms.
	if inner := timer.Phase(PhaseGenerateAnswer); inner < 2*time.Millisecond || inner > 50*time.Millisecond {
		t.Errorf("PhaseGenerateAnswer = %v, want ~3ms", inner)
	}
	// The outer phase is expected to include the nested time (2 + 3 + 1 ≈ 6ms).
	if outer := timer.Phase(PhaseQueryRefinement); outer < 5*time.Millisecond || outer > 50*time.Millisecond {
		t.Errorf("PhaseQueryRefinement = %v, want ~6ms (outer span)", outer)
	}
}
// TestTimer_ExitWithoutEnterIsNoop verifies that calling Exit for a phase that
// was never entered neither panics nor records any time for that phase.
func TestTimer_ExitWithoutEnterIsNoop(t *testing.T) {
	timer := NewTimer()
	timer.Start()

	timer.Exit(PhaseRetrieval) // unmatched Exit: must be ignored

	recorded := timer.Phase(PhaseRetrieval)
	if recorded == 0 {
		return
	}
	t.Errorf("PhaseRetrieval = %v, want 0", recorded)
}
// TestTimer_StartResetsState checks that calling Start on a Timer that already
// recorded a phase clears that phase and makes Total small again.
func TestTimer_StartResetsState(t *testing.T) {
	timer := NewTimer()
	timer.Start()
	timer.Enter(PhaseCheckLLM)
	time.Sleep(2 * time.Millisecond)
	timer.Exit(PhaseCheckLLM)

	// Precondition: there has to be recorded state for the reset to clear.
	if before := timer.Phase(PhaseCheckLLM); before == 0 {
		t.Fatal("precondition: phase must be non-zero before reset")
	}

	timer.Start()

	after := timer.Phase(PhaseCheckLLM)
	if after != 0 {
		t.Errorf("after Start, PhaseCheckLLM = %v, want 0", after)
	}
	elapsed := timer.Total()
	if elapsed > 50*time.Millisecond {
		t.Errorf("after Start, Total = %v, want tiny", elapsed)
	}
}
// TestTimer_ConcurrentAccess has several goroutines record the same phase at
// once. It expects their spans to be summed, roughly 1ms per goroutine. Run it
// with -race to also catch unsynchronized access inside Timer.
func TestTimer_ConcurrentAccess(t *testing.T) {
	const workers = 10

	timer := NewTimer()
	timer.Start()

	var wg sync.WaitGroup
	wg.Add(workers)
	for remaining := workers; remaining > 0; remaining-- {
		go func() {
			defer wg.Done()
			timer.Enter(PhaseRetrieval)
			time.Sleep(time.Millisecond)
			timer.Exit(PhaseRetrieval)
		}()
	}
	wg.Wait()

	if got := timer.Phase(PhaseRetrieval); got < 9*time.Millisecond {
		t.Errorf("PhaseRetrieval = %v, want ~10ms (10 parallel spans)", got)
	}
}
// TestTimer_Report checks the structured report returned by Report. PhasesMs
// must list the known phases, including ones that were never entered, and the
// measured values must be plausible. The report and the Timer itself must
// both encode to JSON with the expected keys.
func TestTimer_Report(t *testing.T) {
	tm := NewTimer()
	tm.Start()
	tm.Enter(PhaseCheckLLM)
	time.Sleep(2 * time.Millisecond)
	tm.Exit(PhaseCheckLLM)
	tm.Enter(PhaseBindModels)
	time.Sleep(1 * time.Millisecond)
	tm.Exit(PhaseBindModels)
	r := tm.Report()
	// Required fields. PhaseGenerateAnswer was never entered, yet the report
	// is expected to list it anyway.
	if _, ok := r.PhasesMs[string(PhaseCheckLLM)]; !ok {
		t.Errorf("Report missing PhaseCheckLLM: %+v", r.PhasesMs)
	}
	if _, ok := r.PhasesMs[string(PhaseBindModels)]; !ok {
		t.Errorf("Report missing PhaseBindModels: %+v", r.PhasesMs)
	}
	if _, ok := r.PhasesMs[string(PhaseGenerateAnswer)]; !ok {
		t.Errorf("Report missing PhaseGenerateAnswer: %+v", r.PhasesMs)
	}
	if r.PhasesMs[string(PhaseCheckLLM)] < 1.0 {
		t.Errorf("Report PhaseCheckLLM_ms = %v, want >= 1.0", r.PhasesMs[string(PhaseCheckLLM)])
	}
	if r.TotalMs < 2.0 {
		t.Errorf("Report TotalMs = %v, want >= 2.0", r.TotalMs)
	}
	// JSON round-trip: decode the encoded report back into its top-level
	// fields. A plain substring search would also accept a nested or quoted
	// occurrence of the key names, so check real object keys instead.
	b, err := json.Marshal(r)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	var decoded map[string]json.RawMessage
	if err := json.Unmarshal(b, &decoded); err != nil {
		t.Fatalf("Unmarshal of report JSON failed: %v: %s", err, b)
	}
	for _, key := range []string{"phases_ms", "total_ms"} {
		if _, ok := decoded[key]; !ok {
			t.Errorf("JSON missing expected key %q: %s", key, b)
		}
	}
	// The measured phase must survive the round-trip as a number.
	if raw, ok := decoded["phases_ms"]; ok {
		var phases map[string]float64
		if err := json.Unmarshal(raw, &phases); err != nil {
			t.Errorf("phases_ms does not decode as a name->number map: %v: %s", err, raw)
		} else if phases[string(PhaseCheckLLM)] < 1.0 {
			t.Errorf("decoded phases_ms[%q] = %v, want >= 1.0", string(PhaseCheckLLM), phases[string(PhaseCheckLLM)])
		}
	}
	// Direct Marshal of the Timer. Its JSON shape is not pinned here beyond
	// containing the phases_ms key.
	b2, err := json.Marshal(tm)
	if err != nil {
		t.Fatalf("Marshal(Timer) failed: %v", err)
	}
	if !strings.Contains(string(b2), `"phases_ms"`) {
		t.Errorf("Timer JSON missing phases_ms: %s", b2)
	}
}
// TestTimer_Markdown checks the human-readable report produced by Markdown.
// It expects the "## Time elapsed:" header and one "- <label>: <value>ms" line
// for each of the six phases plus Total. It also expects the reported Total to
// cover the time slept inside the measured phases.
func TestTimer_Markdown(t *testing.T) {
	tm := NewTimer()
	tm.Start()
	tm.Enter(PhaseCheckLLM)
	time.Sleep(2 * time.Millisecond)
	tm.Exit(PhaseCheckLLM)
	tm.Enter(PhaseRetrieval)
	time.Sleep(5 * time.Millisecond)
	tm.Exit(PhaseRetrieval)
	tm.Enter(PhaseGenerateAnswer)
	time.Sleep(50 * time.Millisecond)
	tm.Exit(PhaseGenerateAnswer)
	md := tm.Markdown()
	// Should start with newline + "## Time elapsed:" header.
	if !strings.HasPrefix(md, "\n## Time elapsed:") {
		t.Errorf("Markdown missing header: %q", md)
	}
	// Should contain all 6 phase labels plus the Total label (7 in all).
	for _, label := range []string{"Check LLM", "Check Langfuse tracer", "Bind models", "Query refinement(LLM)", "Retrieval", "Generate answer", "Total"} {
		if !strings.Contains(md, label+":") {
			t.Errorf("Markdown missing label %q: %q", label, md)
		}
	}
	// Phase durations should be numeric with "ms" suffix.
	mdRE := regexp.MustCompile(`(?m)^\s*-\s+([A-Za-z ()\.]+):\s+([0-9.]+)ms$`)
	matches := mdRE.FindAllStringSubmatch(md, -1)
	if len(matches) < 7 {
		t.Errorf("expected 7 phase lines, found %d in:\n%s", len(matches), md)
	}
	// Total must be a real number that covers the measured phases. They slept
	// 2 + 5 + 50 = 57ms in total, so a value under 50ms means time went
	// missing. The 7ms of slack absorbs display rounding.
	totalRE := regexp.MustCompile(`Total:\s+([0-9.]+)ms`)
	totalMatch := totalRE.FindStringSubmatch(md)
	if len(totalMatch) < 2 {
		t.Fatalf("Markdown missing Total line: %q", md)
	}
	totalMs, err := strconv.ParseFloat(totalMatch[1], 64)
	if err != nil {
		t.Fatalf("Markdown Total %q is not a number: %v", totalMatch[1], err)
	}
	if totalMs < 50 {
		t.Errorf("Markdown Total = %vms, want >= ~57ms (sum of measured phases): %q", totalMs, md)
	}
}
// TestTimer_TotalBeforeStart checks that a freshly constructed Timer reports a
// zero Total when Start has not been called.
func TestTimer_TotalBeforeStart(t *testing.T) {
	fresh := NewTimer()
	// Start() is deliberately never called here.
	elapsed := fresh.Total()
	if elapsed == 0 {
		return
	}
	t.Errorf("Total before Start = %v, want 0", elapsed)
}