From a237a89b90932c41b17a39a7ad109edc626e7590 Mon Sep 17 00:00:00 2001 From: Jack Date: Fri, 5 Jun 2026 10:11:14 +0800 Subject: [PATCH] feat: add QueryRewrite prompt builder and response parser (#15669) QueryRewrite prompt builder and response parser. Zero external dependencies. ### Functions - `BuildQueryRewritePrompt`: Renders `minirag_query2kwd` prompt with query and type pool - `ParseQueryRewriteResponse`: Parses LLM JSON response with fallback for markdown and extra text ### Testing ``` === RUN TestBuildQueryRewritePrompt --- PASS === RUN TestParseQueryRewriteResponse_ValidJSON --- PASS === RUN TestParseQueryRewriteResponse_MarkdownBlock --- PASS === RUN TestParseQueryRewriteResponse_ExtraText --- PASS === RUN TestParseQueryRewriteResponse_Invalid --- PASS === RUN TestParseQueryRewriteResponse_EmptyEntities --- PASS ``` Co-authored-by: Claude Opus 4.8 --- internal/common/kg_query_rewrite.go | 172 +++++++++++++++++++++++ internal/common/kg_query_rewrite_test.go | 112 +++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 internal/common/kg_query_rewrite.go create mode 100644 internal/common/kg_query_rewrite_test.go diff --git a/internal/common/kg_query_rewrite.go b/internal/common/kg_query_rewrite.go new file mode 100644 index 0000000000..8bee01e596 --- /dev/null +++ b/internal/common/kg_query_rewrite.go @@ -0,0 +1,172 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// QueryRewriteResult holds the keywords extracted from an LLM query-rewrite
// response.
type QueryRewriteResult struct {
	// TypeKeywords is decoded from the "answer_type_keywords" JSON key.
	TypeKeywords []string `json:"answer_type_keywords"`
	// Entities is decoded from the "entities_from_query" JSON key.
	Entities []string `json:"entities_from_query"`
}

// queryRewritePromptTmpl is the system prompt template for query rewriting.
// Matches Python: rag/graphrag/query_analyze_prompt.py::PROMPTS["minirag_query2kwd"]
//
// The literal placeholders {query} and {TYPE_POOL} are substituted by
// BuildQueryRewritePrompt; every other brace is plain prompt text.
//
// NOTE(review): the instructions say "three keys" although only two are
// listed. The wording is kept because this text is documented as mirroring
// the Python prompt; confirm against upstream before changing it.
const queryRewritePromptTmpl = `---Role---

You are a helpful assistant tasked with identifying both answer-type and low-level keywords in the user's query.

---Goal---

Given the query, list both answer-type and low-level keywords.
answer_type_keywords focus on the type of the answer to the certain query, while low-level keywords focus on specific entities, details, or concrete terms.
The answer_type_keywords must be selected from Answer type pool.
This pool is in the form of a dictionary, where the key represents the Type you should choose from and the value represents the example samples.

---Instructions---

- Output the keywords in JSON format.
- The JSON should have three keys:
  - "answer_type_keywords" for the types of the answer. In this list, the types with the highest likelihood should be placed at the forefront. No more than 3.
  - "entities_from_query" for specific entities or details. It must be extracted from the query.
######################
-Examples-
######################
Example 1:

Query: "How does international trade influence global economic stability?"
Answer type pool: {
 'PERSONAL LIFE': ['FAMILY TIME', 'HOME MAINTENANCE'],
 'STRATEGY': ['MARKETING PLAN', 'BUSINESS EXPANSION'],
 'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
 'PERSON': ['JANE DOE', 'JOHN SMITH'],
 'FOOD': ['PASTA', 'SUSHI'],
 'EMOTION': ['HAPPINESS', 'ANGER'],
 'PERSONAL EXPERIENCE': ['TRAVEL ABROAD', 'STUDYING ABROAD'],
 'INTERACTION': ['TEAM MEETING', 'NETWORKING EVENT'],
 'BEVERAGE': ['COFFEE', 'TEA'],
 'PLAN': ['ANNUAL BUDGET', 'PROJECT TIMELINE'],
 'GEO': ['NEW YORK CITY', 'SOUTH AFRICA'],
 'GEAR': ['CAMPING TENT', 'CYCLING HELMET'],
 'EMOJI': ['🎉', '🚀'],
 'BEHAVIOR': ['POSITIVE FEEDBACK', 'NEGATIVE CRITICISM'],
 'TONE': ['FORMAL', 'INFORMAL'],
 'LOCATION': ['DOWNTOWN', 'SUBURBS']
}
################
Output:
{
  "answer_type_keywords": ["STRATEGY","PERSONAL LIFE"],
  "entities_from_query": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}
#############################
Example 2:

Query: "Where is the capital of the United States?"
Answer type pool: {
 'ORGANIZATION': ['GREENPEACE', 'RED CROSS'],
 'PERSONAL LIFE': ['DAILY WORKOUT', 'HOME COOKING'],
 'STRATEGY': ['FINANCIAL INVESTMENT', 'BUSINESS EXPANSION'],
 'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
 'PERSON': ['ALBERTA SMITH', 'BENJAMIN JONES'],
 'FOOD': ['PASTA CARBONARA', 'SUSHI PLATTER'],
 'EMOTION': ['HAPPINESS', 'SADNESS'],
 'PERSONAL EXPERIENCE': ['TRAVEL ADVENTURE', 'BOOK CLUB'],
 'INTERACTION': ['TEAM BUILDING', 'NETWORKING MEETUP'],
 'BEVERAGE': ['LATTE', 'GREEN TEA'],
 'PLAN': ['WEIGHT LOSS', 'CAREER DEVELOPMENT'],
 'GEO': ['PARIS', 'NEW YORK'],
 'GEAR': ['CAMERA', 'HEADPHONES'],
 'EMOJI': ['🏢', '🌍'],
 'BEHAVIOR': ['POSITIVE THINKING', 'STRESS MANAGEMENT'],
 'TONE': ['FRIENDLY', 'PROFESSIONAL'],
 'LOCATION': ['DOWNTOWN', 'SUBURBS']
}
################
Output:
{
  "answer_type_keywords": ["LOCATION"],
  "entities_from_query": ["capital of the United States", "Washington", "New York"]
}
#############################

-Real Data-
######################
Query: {query}
Answer type pool:{TYPE_POOL}
######################
Output:
`

// BuildQueryRewritePrompt builds the system prompt for query rewrite.
//
// question replaces the {query} placeholder and ty2entsJSON replaces the
// {TYPE_POOL} placeholder of queryRewritePromptTmpl. Both values are inserted
// verbatim. strings.Replacer substitutes in a single pass over the template,
// so placeholder-looking text inside question or ty2entsJSON is not expanded
// again.
func BuildQueryRewritePrompt(question string, ty2entsJSON string) string {
	r := strings.NewReplacer(
		"{query}", question,
		"{TYPE_POOL}", ty2entsJSON,
	)
	return r.Replace(queryRewritePromptTmpl)
}

// ParseQueryRewriteResponse parses the LLM response and returns structured
// keywords.
//
// Strategies are tried in order and the first success wins:
//  1. the whole response is a JSON document;
//  2. the content of the first ``` fenced block (an optional "json" language
//     tag, in any letter case, is stripped);
//  3. the first JSON object that decodes cleanly, scanning from each '{' in
//     the response, so surrounding prose may itself contain braces.
//
// This is best-effort extraction only; malformed JSON is not repaired. When
// every strategy fails the result is nil and the error is the one produced by
// strategy 1.
func ParseQueryRewriteResponse(response string) (*QueryRewriteResult, error) {
	// Keep the first error under its own name: it is what callers receive
	// when all fallbacks fail.
	result, firstErr := tryParseJSON(response)
	if firstErr == nil {
		return result, nil
	}

	cleaned := strings.TrimSpace(response)

	if code, ok := extractFencedCode(cleaned); ok {
		if fenced, err := tryParseJSON(code); err == nil {
			return fenced, nil
		}
	}

	if embedded, ok := decodeFirstRewriteObject(cleaned); ok {
		return embedded, nil
	}

	return nil, firstErr
}

// extractFencedCode returns the trimmed content of the first ``` ... ``` block
// in s with a leading "json" language tag removed. ok is false when s has no
// complete fenced block.
func extractFencedCode(s string) (code string, ok bool) {
	const fence = "```"
	open := strings.Index(s, fence)
	if open < 0 {
		return "", false
	}
	rest := s[open+len(fence):]
	closing := strings.Index(rest, fence)
	if closing < 0 {
		return "", false
	}
	code = strings.TrimSpace(rest[:closing])
	if len(code) >= 4 && strings.EqualFold(code[:4], "json") {
		code = code[4:]
	}
	return strings.TrimSpace(code), true
}

// decodeFirstRewriteObject tries every '{' in s, left to right, as the start
// of a JSON object and returns the first one that decodes into a
// QueryRewriteResult. json.Decoder stops at the end of the first value, so
// text following the object (even text containing braces) is ignored.
func decodeFirstRewriteObject(s string) (*QueryRewriteResult, bool) {
	for i := strings.Index(s, "{"); i >= 0; {
		// Fresh value per attempt so a failed decode cannot leak partial data.
		var result QueryRewriteResult
		if err := json.NewDecoder(strings.NewReader(s[i:])).Decode(&result); err == nil {
			return &result, true
		}
		next := strings.Index(s[i+1:], "{")
		if next < 0 {
			break
		}
		i += next + 1
	}
	return nil, false
}

// tryParseJSON attempts to parse a JSON string into QueryRewriteResult.
// It returns a nil result together with the json.Unmarshal error on failure.
func tryParseJSON(data string) (*QueryRewriteResult, error) {
	var result QueryRewriteResult
	if err := json.Unmarshal([]byte(data), &result); err != nil {
		return nil, err
	}
	return &result, nil
}
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import ( + "testing" +) + +func TestBuildQueryRewritePrompt(t *testing.T) { + question := "What is the capital of France?" + ty2ents := `{"LOCATION": ["Paris", "London"]}` + prompt := BuildQueryRewritePrompt(question, ty2ents) + if !contains(prompt, question) { + t.Error("expected question in prompt") + } + if !contains(prompt, ty2ents) { + t.Error("expected type pool in prompt") + } + if contains(prompt, "{query}") { + t.Error("placeholder {query} should have been replaced") + } + if contains(prompt, "{TYPE_POOL}") { + t.Error("placeholder {TYPE_POOL} should have been replaced") + } +} + +func TestParseQueryRewriteResponse_ValidJSON(t *testing.T) { + resp := `{ + "answer_type_keywords": ["LOCATION", "ORGANIZATION"], + "entities_from_query": ["France", "Paris", "Capital"] + }` + result, err := ParseQueryRewriteResponse(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.TypeKeywords) != 2 || result.TypeKeywords[0] != "LOCATION" { + t.Errorf("expected [LOCATION ORGANIZATION], got %v", result.TypeKeywords) + } + if len(result.Entities) != 3 || result.Entities[0] != "France" { + t.Errorf("expected [France Paris Capital], got %v", result.Entities) + } +} + +func TestParseQueryRewriteResponse_MarkdownBlock(t *testing.T) { + resp := "Here is the result:\n```json\n{\n\t\"answer_type_keywords\": [\"DATE\"],\n\t\"entities_from_query\": [\"SpaceX\", \"launch\"]\n}\n```" + result, err := ParseQueryRewriteResponse(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if 
len(result.TypeKeywords) != 1 || result.TypeKeywords[0] != "DATE" { + t.Errorf("expected [DATE], got %v", result.TypeKeywords) + } +} + +func TestParseQueryRewriteResponse_ExtraText(t *testing.T) { + resp := `Some text before +{ + "answer_type_keywords": ["PERSON"], + "entities_from_query": ["Einstein"] +} +Some text after` + result, err := ParseQueryRewriteResponse(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.Entities) != 1 || result.Entities[0] != "Einstein" { + t.Errorf("expected [Einstein], got %v", result.Entities) + } +} + +func TestParseQueryRewriteResponse_Invalid(t *testing.T) { + resp := "This is not valid JSON" + _, err := ParseQueryRewriteResponse(resp) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestParseQueryRewriteResponse_EmptyEntities(t *testing.T) { + resp := `{"answer_type_keywords": ["LOCATION"], "entities_from_query": []}` + result, err := ParseQueryRewriteResponse(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.Entities) != 0 { + t.Errorf("expected empty entities, got %v", result.Entities) + } +} + +// contains checks if a string contains a substring. +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}