Files
ragflow/internal/engine/infinity/sql.go
qinling0210 563d855780 Implement OpenAI chat completions in GO (#16177)
### What problem does this PR solve?

Implement OpenAI chat completions in GO

POST /api/v1/openai/<chat_id>/chat/completions

OpenAI chat cli: internal/development.md

### Type of change

- [x] Refactoring
2026-06-18 18:07:27 +08:00

331 lines
9.1 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package infinity
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"

	"go.uber.org/zap"
	"ragflow/internal/common"
)
// Defaults used when shelling out to the psql client.
const (
	// defaultPsqlHost is the host passed to psql when the configured host
	// URI does not yield one.
	defaultPsqlHost = "infinity"
	// defaultPsqlPort is the port passed to psql when no positive port is
	// configured.
	defaultPsqlPort = "5432"
	// defaultPsqlPath is the location probed when psql is not found on PATH.
	defaultPsqlPath = "/usr/bin/psql"
	// psqlTimeout bounds the run time of a single psql invocation.
	psqlTimeout = 10 * time.Second
)
var (
	// whitespaceRe matches a run of one or more spaces and/or backticks.
	// Tabs and newlines are not part of the character class.
	whitespaceRe = regexp.MustCompile("[ `]+")
	// rowCountFooterRe matches the start of a row-count footer line such as
	// "(3 rows)" or "(1 row)".
	rowCountFooterRe = regexp.MustCompile(`^\(\d+ rows?`)
)
// fieldMappingEntry is one entry in infinity_mapping.json, i.e. the value
// stored under a single field name in that file's top-level JSON object.
type fieldMappingEntry struct {
// Type is decoded from the entry's "type" key.
Type string `json:"type"`
// Comment is decoded from the entry's "comment" key. loadFieldMapping
// treats it as a comma-separated list of aliases for the field.
Comment string `json:"comment"`
}
// loadFieldMapping reads conf/<mappingFileName> under projectBaseDir() and
// builds two lookup tables from it:
//
//   - aliasToActual maps every alias listed in an entry's comma-separated
//     "comment" to the field name the entry is stored under;
//   - actualToFirstAlias maps each field name to the first non-empty alias
//     in its comment.
//
// An empty mappingFileName selects "infinity_mapping.json". A missing file is
// not an error: two empty maps are returned. Any other read failure, or
// malformed JSON, is returned wrapped with the file path.
//
// If the same alias is listed under more than one field, the field that wins
// in aliasToActual is unspecified because Go map iteration order is random.
func loadFieldMapping(mappingFileName string) (aliasToActual map[string]string, actualToFirstAlias map[string]string, err error) {
	name := mappingFileName
	if name == "" {
		name = "infinity_mapping.json"
	}
	confPath := filepath.Join(projectBaseDir(), "conf", name)

	raw, readErr := os.ReadFile(confPath)
	switch {
	case readErr == nil:
		// Fall through to parsing.
	case errors.Is(readErr, os.ErrNotExist):
		return map[string]string{}, map[string]string{}, nil
	default:
		return nil, nil, fmt.Errorf("load field mapping %q: %w", confPath, readErr)
	}

	entries := make(map[string]fieldMappingEntry)
	if parseErr := json.Unmarshal(raw, &entries); parseErr != nil {
		return nil, nil, fmt.Errorf("parse field mapping %q: %w", confPath, parseErr)
	}

	aliasToActual = make(map[string]string, len(entries)*2)
	actualToFirstAlias = make(map[string]string, len(entries))
	for actual, entry := range entries {
		// An empty comment splits into a single empty piece, which is skipped
		// below, so no separate guard is needed.
		first := ""
		for _, piece := range strings.Split(entry.Comment, ",") {
			alias := strings.TrimSpace(piece)
			if alias == "" {
				continue
			}
			aliasToActual[alias] = actual
			if first == "" {
				first = alias
			}
		}
		if first != "" {
			actualToFirstAlias[actual] = first
		}
	}
	return aliasToActual, actualToFirstAlias, nil
}
// projectBaseDir returns the directory used as the project root.
//
// The first non-empty value of the RAG_PROJECT_BASE and RAG_DEPLOY_BASE
// environment variables (checked in that order) wins. When neither is set the
// process working directory is returned as-is; no parent-directory walking is
// performed. If the working directory cannot be determined, "." is returned.
func projectBaseDir() string {
	for _, key := range []string{"RAG_PROJECT_BASE", "RAG_DEPLOY_BASE"} {
		if dir := os.Getenv(key); dir != "" {
			return dir
		}
	}
	if cwd, err := os.Getwd(); err == nil {
		return cwd
	}
	return "."
}
// preprocessSQL normalizes a SQL string before alias rewriting: every run
// matched by whitespaceRe (spaces and backticks) becomes a single space, and
// every '%' character is removed.
func preprocessSQL(sql string) string {
	collapsed := whitespaceRe.ReplaceAllString(sql, " ")
	return strings.ReplaceAll(collapsed, "%", "")
}
// selectListRe splits a "SELECT <columns> FROM" span into three groups: the
// SELECT keyword with its trailing whitespace, the column list, and the
// whitespace plus FROM keyword. It is case-insensitive and lets "." match
// newlines. Compiled once instead of on every rewriteFieldAliases call.
var selectListRe = regexp.MustCompile(`(?si)(select\s+)(.+?)(\s+from\b)`)

// rewriteFieldAliases rewrites alias field names to actual stored names
// in SELECT, WHERE, ORDER BY, GROUP BY, and HAVING clauses.
//
// In every SELECT column list, each token delimited by commas or whitespace
// that exactly equals an alias is replaced by its stored name. For each of the
// four trailing clauses, rewriteFirstAliasAfterKeyword is applied in the order
// listed above. The result is deterministic for a given input. sql is returned
// unchanged when aliasToActual is empty.
func rewriteFieldAliases(sql string, aliasToActual map[string]string) string {
	if len(aliasToActual) == 0 {
		return sql
	}
	sql = selectListRe.ReplaceAllStringFunc(sql, func(m string) string {
		parts := selectListRe.FindStringSubmatch(m)
		return parts[1] + rewriteSelectColumns(parts[2], aliasToActual) + parts[3]
	})
	for _, keyword := range []string{"where", "order by", "group by", "having"} {
		sql = rewriteFirstAliasAfterKeyword(sql, keyword, aliasToActual)
	}
	return sql
}

// isSelectDelimiter reports whether c separates tokens in a SELECT column
// list: a comma or one of the ASCII whitespace bytes matched by regexp `\s`.
func isSelectDelimiter(c byte) bool {
	switch c {
	case ',', ' ', '\t', '\n', '\r', '\f':
		return true
	}
	return false
}

// rewriteSelectColumns replaces every delimiter-separated token of cols that
// is a key of aliasToActual with the mapped name, in a single left-to-right
// pass. Scanning tokens (rather than substituting one regexp per alias) means
// adjacent aliases that share a delimiter ("a,a") are all rewritten, a
// replacement is never rewritten a second time, and the stored name is
// inserted literally. Aliases containing whitespace never match.
func rewriteSelectColumns(cols string, aliasToActual map[string]string) string {
	var b strings.Builder
	b.Grow(len(cols))
	for i := 0; i < len(cols); {
		j := i
		if isSelectDelimiter(cols[i]) {
			for j < len(cols) && isSelectDelimiter(cols[j]) {
				j++
			}
			b.WriteString(cols[i:j])
		} else {
			for j < len(cols) && !isSelectDelimiter(cols[j]) {
				j++
			}
			token := cols[i:j]
			if actual, ok := aliasToActual[token]; ok {
				token = actual
			}
			b.WriteString(token)
		}
		i = j
	}
	return b.String()
}

// rewriteFirstAliasAfterKeyword replaces, for each alias, only the first
// word-bounded occurrence found after the first case-insensitive occurrence of
// keyword in sql. The searched region runs from the keyword to the end of the
// string, and alias matching is case-sensitive.
//
// Aliases are processed in sorted order so that the result does not depend on
// Go's randomized map iteration when one alias's stored name is itself an
// alias. sql is returned unchanged when keyword does not occur.
func rewriteFirstAliasAfterKeyword(sql, keyword string, aliasToActual map[string]string) string {
	if len(aliasToActual) == 0 {
		return sql
	}
	keywordRe := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(keyword) + `\b`)
	if keywordRe.FindStringIndex(sql) == nil {
		return sql
	}

	aliases := make([]string, 0, len(aliasToActual))
	for alias := range aliasToActual {
		aliases = append(aliases, alias)
	}
	sort.Strings(aliases)

	for _, alias := range aliases {
		// Re-locate the keyword each round: earlier replacements only touch
		// text after it, but re-checking keeps the per-alias semantics exact.
		kwIdx := keywordRe.FindStringIndex(sql)
		if kwIdx == nil {
			break
		}
		tailStart := kwIdx[1]
		tail := sql[tailStart:]
		// Cheap literal pre-check: a word-bounded match requires the alias
		// text to be present, so most aliases never need a regexp compiled.
		if !strings.Contains(tail, alias) {
			continue
		}
		aliasRe := regexp.MustCompile(`\b` + regexp.QuoteMeta(alias) + `\b`)
		loc := aliasRe.FindStringIndex(tail)
		if loc == nil {
			continue
		}
		sql = sql[:tailStart+loc[0]] + aliasToActual[alias] + sql[tailStart+loc[1]:]
	}
	return sql
}
// psqlResult is the structured parse of a psql table-format output.
type psqlResult struct {
// Columns holds the column names taken from the header line, in output
// order.
Columns []string
// Rows holds one slice of trimmed cell strings per data line. parsePsqlTable
// truncates or pads each row so that it has exactly len(Columns) cells.
Rows [][]string
}
// runPsql shells out to psql and parses the table-format output.
//
// The binary is located with findPsqlBinary and invoked as
// "psql -h host -p port -c sql" with stdout and stderr captured. The call is
// bounded by psqlTimeout on top of whatever deadline ctx already carries.
//
// Errors:
//   - the findPsqlBinary error, unchanged, when no psql binary is available;
//   - "SQL timeout" when the command fails and the context is done (this
//     covers both the timeout and cancellation of the caller's ctx);
//   - "psql command failed" otherwise, carrying psql's stderr or, when stderr
//     is empty (for example the process could not be started), the error
//     reported by the exec package so the cause is never lost.
func runPsql(ctx context.Context, host, port, sql string) (*psqlResult, error) {
	psqlPath, err := findPsqlBinary()
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(ctx, psqlTimeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, psqlPath, "-h", host, "-p", port, "-c", sql)
	common.Debug("executing psql",
		zap.String("path", psqlPath),
		zap.String("host", host),
		zap.String("port", port),
	)

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		if ctx.Err() != nil {
			return nil, fmt.Errorf("SQL timeout\n\nSQL: %s", sql)
		}
		detail := strings.TrimSpace(stderr.String())
		if detail == "" {
			// Nothing on stderr: fall back to the exec error instead of
			// returning a message with an empty cause.
			detail = err.Error()
		}
		return nil, fmt.Errorf("psql command failed: %s\nSQL: %s", detail, sql)
	}
	return parsePsqlTable(stdout.String()), nil
}
// findPsqlBinary resolves the psql executable to run. A "psql" found via PATH
// lookup is preferred; otherwise defaultPsqlPath is returned if os.Stat
// succeeds on it. An error is returned when neither location works.
func findPsqlBinary() (string, error) {
	resolved, lookErr := exec.LookPath("psql")
	if lookErr == nil {
		return resolved, nil
	}
	_, statErr := os.Stat(defaultPsqlPath)
	if statErr != nil {
		return "", fmt.Errorf("psql not found on PATH and not at %q", defaultPsqlPath)
	}
	return defaultPsqlPath, nil
}
// parsePsqlTable parses psql's pipe-delimited output:
//
//	col1 | col2
//	-----+-----
//	val1 | val2
//
// The first line supplies the column names (blank names are dropped). The
// second line is skipped only when it is a separator rule, so a first data
// row that merely contains '-' (a negative number, a date) is kept. Blank
// lines and the "(N rows)" footer are ignored. Every cell is trimmed, and
// each row is truncated or padded with "" to exactly len(Columns) cells.
// Cell values that themselves contain '|' cannot be told apart from column
// breaks. Empty or all-whitespace output yields an empty, non-nil result.
func parsePsqlTable(output string) *psqlResult {
	res := &psqlResult{}
	trimmed := strings.TrimSpace(output)
	if trimmed == "" {
		return res
	}
	lines := strings.Split(trimmed, "\n")

	for _, raw := range strings.Split(lines[0], "|") {
		if col := strings.TrimSpace(raw); col != "" {
			res.Columns = append(res.Columns, col)
		}
	}

	dataStart := 1
	if len(lines) > 1 && isPsqlSeparatorLine(lines[1]) {
		dataStart = 2
	}

	width := len(res.Columns)
	for _, raw := range lines[dataStart:] {
		line := strings.TrimSpace(raw)
		if line == "" || rowCountFooterRe.MatchString(line) {
			continue
		}
		cells := strings.Split(line, "|")
		for j := range cells {
			cells[j] = strings.TrimSpace(cells[j])
		}
		switch {
		case len(cells) > width:
			cells = cells[:width]
		case len(cells) < width:
			// make zero-fills with "", so copying the known cells is enough.
			padded := make([]string, width)
			copy(padded, cells)
			cells = padded
		}
		res.Rows = append(res.Rows, cells)
	}
	return res
}

// isPsqlSeparatorLine reports whether line is a header rule: after trimming
// it is made up only of '-', '+' and spaces and contains at least one '-'.
func isPsqlSeparatorLine(line string) bool {
	line = strings.TrimSpace(line)
	sawDash := false
	for i := 0; i < len(line); i++ {
		switch line[i] {
		case '-':
			sawDash = true
		case '+', ' ':
			// Column junctions and padding are allowed in a rule.
		default:
			return false
		}
	}
	return sawDash
}
// toRowMaps turns a parsed psql table into one map per row, keyed by column
// name with the cell string as value. Cells beyond the last column are
// ignored, and columns without a matching cell are left out of that row's
// map. A nil result or a result without rows yields nil.
func toRowMaps(res *psqlResult) []map[string]interface{} {
	if res == nil {
		return nil
	}
	rowCount := len(res.Rows)
	if rowCount == 0 {
		return nil
	}

	width := len(res.Columns)
	records := make([]map[string]interface{}, rowCount)
	for i, cells := range res.Rows {
		record := make(map[string]interface{}, width)
		for idx, value := range cells {
			if idx >= width {
				break
			}
			record[res.Columns[idx]] = value
		}
		records[i] = record
	}
	return records
}
// resolvePsqlHostPort picks the host and port handed to psql.
//
// The host is the part of hostURI before its first ':' when hostURI contains
// a ':' and that part is non-empty; otherwise defaultPsqlHost is used. In
// particular a hostURI without any ':' is ignored, and a bracketed IPv6
// literal is not understood because the split happens at the first colon.
// The port is postgresPort when it is positive, else defaultPsqlPort; any
// port inside hostURI is never used.
func resolvePsqlHostPort(hostURI string, postgresPort int) (host, port string) {
	host, port = defaultPsqlHost, defaultPsqlPort
	if idx := strings.Index(hostURI, ":"); idx > 0 {
		host = hostURI[:idx]
	}
	if postgresPort > 0 {
		port = strconv.Itoa(postgresPort)
	}
	return host, port
}
// RunSQL executes sqlText through the psql client and returns one map per
// result row, keyed by column name.
//
// Steps: trim the statement, normalize it with preprocessSQL, rewrite field
// aliases using the mapping file configured on the client, resolve the psql
// host and port from the client settings, run psql, and convert the parsed
// table with toRowMaps. The tableName and kbIDs arguments, as well as the
// final unnamed string argument, are not used by this implementation.
//
// It fails when the engine or its client is nil, when sqlText is blank, when
// the field mapping cannot be loaded, or when psql fails.
func (e *infinityEngine) RunSQL(ctx context.Context, tableName string, sqlText string, kbIDs []string, _ string) ([]map[string]interface{}, error) {
	if e == nil || e.client == nil {
		return nil, fmt.Errorf("infinity RunSQL: client not initialized")
	}

	query := strings.TrimSpace(sqlText)
	if query == "" {
		return nil, fmt.Errorf("infinity RunSQL: empty SQL")
	}
	common.Debug("InfinityConnection.sql get sql", zap.String("sql", query))

	aliasToActual, _, mappingErr := loadFieldMapping(e.client.mappingFileName)
	if mappingErr != nil {
		return nil, fmt.Errorf("infinity RunSQL: %w", mappingErr)
	}
	query = rewriteFieldAliases(preprocessSQL(query), aliasToActual)
	common.Debug("InfinityConnection.sql to execute", zap.String("sql", query))

	host, port := resolvePsqlHostPort(e.client.hostURI, e.client.postgresPort)
	table, runErr := runPsql(ctx, host, port, query)
	if runErr != nil {
		return nil, runErr
	}
	return toRowMaps(table), nil
}