mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
### What problem does this PR solve? Implements OpenAI-compatible chat completions in Go: `POST /api/v1/openai/<chat_id>/chat/completions`. For the OpenAI chat CLI, see internal/development.md. ### Type of change - [x] Refactoring
331 lines
9.1 KiB
Go
331 lines
9.1 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package infinity
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"ragflow/internal/common"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Defaults for the psql subprocess used by the SQL retrieval path.
const (
	// defaultPsqlHost is the psql host used when no host can be derived
	// from the configured URI.
	defaultPsqlHost = "infinity"
	// defaultPsqlPort is the psql port used when no positive port is configured.
	defaultPsqlPort = "5432"
	// defaultPsqlPath is the fallback psql location tried when psql is not
	// found on PATH.
	defaultPsqlPath = "/usr/bin/psql"
	// psqlTimeout bounds how long a single psql invocation may run.
	psqlTimeout = 10 * time.Second
)
|
|
|
|
// Precompiled patterns shared by the SQL helpers in this file.
var (
	// whitespaceRe matches runs of spaces and backticks. Despite its name it
	// does not match tabs or newlines.
	whitespaceRe = regexp.MustCompile("[ `]+")

	// rowCountFooterRe matches the start of a "(N rows" / "(1 row" footer line.
	rowCountFooterRe = regexp.MustCompile(`^\(\d+ rows?`)
)
|
|
|
|
// fieldMappingEntry is the value stored under each field name in
// infinity_mapping.json.
type fieldMappingEntry struct {
	// Type is the entry's "type" value.
	Type string `json:"type"`
	// Comment is the entry's "comment" value; loadFieldMapping reads it as a
	// comma-separated list of alias names for the field.
	Comment string `json:"comment"`
}
|
|
|
|
// loadFieldMapping reads infinity_mapping.json and returns alias→actual
|
|
// and actual→firstAlias maps. Silently returns empty maps on missing file.
|
|
func loadFieldMapping(mappingFileName string) (aliasToActual map[string]string, actualToFirstAlias map[string]string, err error) {
|
|
if mappingFileName == "" {
|
|
mappingFileName = "infinity_mapping.json"
|
|
}
|
|
confPath := filepath.Join(projectBaseDir(), "conf", mappingFileName)
|
|
data, err := os.ReadFile(confPath)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return map[string]string{}, map[string]string{}, nil
|
|
}
|
|
return nil, nil, fmt.Errorf("load field mapping %q: %w", confPath, err)
|
|
}
|
|
|
|
fields := map[string]fieldMappingEntry{}
|
|
if err := json.Unmarshal(data, &fields); err != nil {
|
|
return nil, nil, fmt.Errorf("parse field mapping %q: %w", confPath, err)
|
|
}
|
|
|
|
aliasToActual = make(map[string]string, len(fields)*2)
|
|
actualToFirstAlias = make(map[string]string, len(fields))
|
|
for actual, info := range fields {
|
|
if info.Comment == "" {
|
|
continue
|
|
}
|
|
var firstAlias string
|
|
for _, raw := range strings.Split(info.Comment, ",") {
|
|
alias := strings.TrimSpace(raw)
|
|
if alias == "" {
|
|
continue
|
|
}
|
|
aliasToActual[alias] = actual
|
|
if firstAlias == "" {
|
|
firstAlias = alias
|
|
}
|
|
}
|
|
if firstAlias != "" {
|
|
actualToFirstAlias[actual] = firstAlias
|
|
}
|
|
}
|
|
return aliasToActual, actualToFirstAlias, nil
|
|
}
|
|
|
|
// projectBaseDir reports the directory that project-relative paths are
// resolved against. The first non-empty value among the RAG_PROJECT_BASE and
// RAG_DEPLOY_BASE environment variables wins; otherwise the current working
// directory is used, or "." when it cannot be determined.
func projectBaseDir() string {
	for _, key := range [...]string{"RAG_PROJECT_BASE", "RAG_DEPLOY_BASE"} {
		if dir := os.Getenv(key); dir != "" {
			return dir
		}
	}
	// No override configured: use the process working directory as-is. No
	// attempt is made to walk up towards a repository root.
	if cwd, err := os.Getwd(); err == nil {
		return cwd
	}
	return "."
}
|
|
|
|
// preprocessSQL collapses spaces/backticks and strips '%'.
|
|
func preprocessSQL(sql string) string {
|
|
sql = whitespaceRe.ReplaceAllString(sql, " ")
|
|
sql = strings.ReplaceAll(sql, "%", "")
|
|
return sql
|
|
}
|
|
|
|
// rewriteFieldAliases rewrites alias field names to actual stored names
|
|
// in SELECT, WHERE, ORDER BY, GROUP BY, and HAVING clauses.
|
|
func rewriteFieldAliases(sql string, aliasToActual map[string]string) string {
|
|
if len(aliasToActual) == 0 {
|
|
return sql
|
|
}
|
|
selectRe := regexp.MustCompile(`(?si)(select\s+)(.+?)(\s+from\b)`)
|
|
sql = selectRe.ReplaceAllStringFunc(sql, func(m string) string {
|
|
parts := selectRe.FindStringSubmatch(m)
|
|
prefix, cols, suffix := parts[1], parts[2], parts[3]
|
|
for alias, actual := range aliasToActual {
|
|
pat := regexp.MustCompile(`(^|[,\s])` + regexp.QuoteMeta(alias) + `($|[,\s])`)
|
|
cols = pat.ReplaceAllString(cols, "${1}"+actual+"${2}")
|
|
}
|
|
return prefix + cols + suffix
|
|
})
|
|
|
|
clauseAliases := func(sql, keyword string) string {
|
|
return rewriteFirstAliasAfterKeyword(sql, keyword, aliasToActual)
|
|
}
|
|
sql = clauseAliases(sql, "where")
|
|
sql = clauseAliases(sql, "order by")
|
|
sql = clauseAliases(sql, "group by")
|
|
sql = clauseAliases(sql, "having")
|
|
return sql
|
|
}
|
|
|
|
// rewriteFirstAliasAfterKeyword rewrites, for each alias in aliasToActual, the
// first whole-word occurrence of that alias found after the first occurrence
// of keyword in sql, replacing it with the mapped stored name. Text before the
// keyword is never modified.
//
// keyword is matched case-insensitively on word boundaries. The words of a
// multi-word keyword such as "order by" may be separated by any run of
// whitespace (spaces, tabs, newlines), not just a single space.
//
// sql is returned unchanged when keyword does not occur or aliasToActual is
// empty.
//
// NOTE(review): aliases are applied in map iteration order, so if one alias's
// stored name is itself another alias the result can vary between runs.
func rewriteFirstAliasAfterKeyword(sql, keyword string, aliasToActual map[string]string) string {
	if len(aliasToActual) == 0 {
		return sql
	}

	// The keyword pattern does not depend on the alias, so build and search
	// for it once instead of once per map entry.
	words := strings.Fields(keyword)
	for i, w := range words {
		words[i] = regexp.QuoteMeta(w)
	}
	kwRe := regexp.MustCompile(`(?i)\b` + strings.Join(words, `\s+`) + `\b`)
	kwLoc := kwRe.FindStringIndex(sql)
	if kwLoc == nil {
		return sql
	}
	// Replacements only ever touch text at or after clauseStart, so this
	// offset stays valid while sql is rewritten below.
	clauseStart := kwLoc[1]

	for alias, actual := range aliasToActual {
		aliasRe := regexp.MustCompile(`\b` + regexp.QuoteMeta(alias) + `\b`)
		loc := aliasRe.FindStringIndex(sql[clauseStart:])
		if loc == nil {
			continue
		}
		sql = sql[:clauseStart+loc[0]] + actual + sql[clauseStart+loc[1]:]
	}
	return sql
}
|
|
|
|
// psqlResult holds psql's table-format output after parsing.
type psqlResult struct {
	// Columns are the header names, in output order.
	Columns []string
	// Rows holds one slice of cell strings per data row; cells line up
	// positionally with Columns.
	Rows [][]string
}
|
|
|
|
// runPsql shells out to psql and parses the table-format output.
|
|
func runPsql(ctx context.Context, host, port, sql string) (*psqlResult, error) {
|
|
psqlPath, err := findPsqlBinary()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, psqlTimeout)
|
|
defer cancel()
|
|
|
|
cmd := exec.CommandContext(ctx, psqlPath, "-h", host, "-p", port, "-c", sql)
|
|
common.Debug("executing psql",
|
|
zap.String("path", psqlPath),
|
|
zap.String("host", host),
|
|
zap.String("port", port),
|
|
)
|
|
var stdout, stderr bytes.Buffer
|
|
cmd.Stdout = &stdout
|
|
cmd.Stderr = &stderr
|
|
|
|
if err := cmd.Run(); err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil, fmt.Errorf("SQL timeout\n\nSQL: %s", sql)
|
|
}
|
|
return nil, fmt.Errorf("psql command failed: %s\nSQL: %s", strings.TrimSpace(stderr.String()), sql)
|
|
}
|
|
return parsePsqlTable(stdout.String()), nil
|
|
}
|
|
|
|
// findPsqlBinary checks PATH first, then falls back to defaultPsqlPath.
|
|
func findPsqlBinary() (string, error) {
|
|
if path, err := exec.LookPath("psql"); err == nil {
|
|
return path, nil
|
|
}
|
|
if _, err := os.Stat(defaultPsqlPath); err == nil {
|
|
return defaultPsqlPath, nil
|
|
}
|
|
return "", fmt.Errorf("psql not found on PATH and not at %q", defaultPsqlPath)
|
|
}
|
|
|
|
// parsePsqlTable parses psql's pipe-delimited output:
|
|
//
|
|
// col1 | col2
|
|
// -----+-----
|
|
// val1 | val2
|
|
func parsePsqlTable(output string) *psqlResult {
|
|
res := &psqlResult{}
|
|
out := strings.TrimSpace(output)
|
|
if out == "" {
|
|
return res
|
|
}
|
|
lines := strings.Split(out, "\n")
|
|
if len(lines) == 0 {
|
|
return res
|
|
}
|
|
|
|
for _, raw := range strings.Split(lines[0], "|") {
|
|
if col := strings.TrimSpace(raw); col != "" {
|
|
res.Columns = append(res.Columns, col)
|
|
}
|
|
}
|
|
|
|
dataStart := 1
|
|
if len(lines) >= 2 && strings.Contains(lines[1], "-") {
|
|
dataStart = 2
|
|
}
|
|
for i := dataStart; i < len(lines); i++ {
|
|
line := strings.TrimSpace(lines[i])
|
|
if line == "" || rowCountFooterRe.MatchString(line) {
|
|
continue
|
|
}
|
|
cells := strings.Split(line, "|")
|
|
for j := range cells {
|
|
cells[j] = strings.TrimSpace(cells[j])
|
|
}
|
|
switch {
|
|
case len(cells) == len(res.Columns):
|
|
res.Rows = append(res.Rows, cells)
|
|
case len(cells) > len(res.Columns):
|
|
res.Rows = append(res.Rows, cells[:len(res.Columns)])
|
|
default:
|
|
padded := make([]string, len(res.Columns))
|
|
copy(padded, cells)
|
|
for k := len(cells); k < len(res.Columns); k++ {
|
|
padded[k] = ""
|
|
}
|
|
res.Rows = append(res.Rows, padded)
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
// toRowMaps converts psqlResult to a slice of column-keyed maps.
|
|
func toRowMaps(res *psqlResult) []map[string]interface{} {
|
|
if res == nil || len(res.Rows) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]map[string]interface{}, 0, len(res.Rows))
|
|
for _, row := range res.Rows {
|
|
m := make(map[string]interface{}, len(res.Columns))
|
|
for j, col := range res.Columns {
|
|
if j < len(row) {
|
|
m[col] = row[j]
|
|
}
|
|
}
|
|
out = append(out, m)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func resolvePsqlHostPort(hostURI string, postgresPort int) (host, port string) {
|
|
host = defaultPsqlHost
|
|
port = defaultPsqlPort
|
|
if postgresPort > 0 {
|
|
port = strconv.Itoa(postgresPort)
|
|
}
|
|
if hostURI != "" {
|
|
if h, _, ok := strings.Cut(hostURI, ":"); ok && h != "" {
|
|
host = h
|
|
}
|
|
}
|
|
return host, port
|
|
}
|
|
|
|
// RunSQL implements the SQL retrieval path: preprocess, rewrite aliases,
|
|
// run psql subprocess, parse output.
|
|
func (e *infinityEngine) RunSQL(ctx context.Context, tableName string, sqlText string, kbIDs []string, _ string) ([]map[string]interface{}, error) {
|
|
if e == nil || e.client == nil {
|
|
return nil, fmt.Errorf("infinity RunSQL: client not initialized")
|
|
}
|
|
sqlText = strings.TrimSpace(sqlText)
|
|
if sqlText == "" {
|
|
return nil, fmt.Errorf("infinity RunSQL: empty SQL")
|
|
}
|
|
|
|
common.Debug("InfinityConnection.sql get sql", zap.String("sql", sqlText))
|
|
|
|
sqlText = preprocessSQL(sqlText)
|
|
|
|
aliasMap, _, err := loadFieldMapping(e.client.mappingFileName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("infinity RunSQL: %w", err)
|
|
}
|
|
sqlText = rewriteFieldAliases(sqlText, aliasMap)
|
|
|
|
common.Debug("InfinityConnection.sql to execute", zap.String("sql", sqlText))
|
|
|
|
host, port := resolvePsqlHostPort(e.client.hostURI, e.client.postgresPort)
|
|
res, err := runPsql(ctx, host, port, sqlText)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return toRowMaps(res), nil
|
|
}
|