Compare commits

...

4 Commits

Author SHA1 Message Date
Harsh Kashyap
45fc7feab4 fix(common/time_utils): correct None/empty timestamp fallback and ISO 8601 parsing (#16483)
Recovery PR for #16173 after the fork branch was accidentally reset
during rewrite-cleanup.

Cherry-picked onto current `main`:
- fix(common/time_utils): correct fallback timestamp and ISO-8601
normalization
- fix(common/time_utils): preserve zero timestamps and mark regression
tests
- test(common/time_utils): make fallback assertions deterministic

Supersedes closed #16173 — same branch
`Harsh23Kashyap/fix/time-utils-edgecases`, rebuilt per @yuzhichang
recovery steps in
https://github.com/infiniflow/ragflow/pull/16173#issuecomment-4829663835

---------

Co-authored-by: Harsh Kashyap <harshkashyap@Harshs-MacBook-Pro.local>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-30 22:30:44 +08:00
Lynn
b53b693f22 Fix: CI (#16504)
### Summary

Fix race condition in parallel lefthook hooks causing ETXTBSY error
2026-06-30 22:14:11 +08:00
Jack
8e1dc4f308 revert: roll back tests.yml CI changes from PR #16391 (#16505)
## Summary

Two changes to make Go build & run independent of native libraries
(office_oxide, pdfium, pdf_oxide).

## 1. Make native libraries optional (build.sh + Go source)

## 2. Roll back tests.yml CI changes from PR #16391
2026-06-30 21:50:37 +08:00
Yingfeng
5af361ed68 Add spacy based ner and relationship extractor for both python and Go version with equivalent outputs (#16456)
As title
2026-06-30 21:40:24 +08:00
30 changed files with 4182 additions and 55 deletions

View File

@@ -210,14 +210,8 @@ jobs:
-v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \ -v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \
infiniflow/infinity_builder:ubuntu22_clang20 infiniflow/infinity_builder:ubuntu22_clang20
sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp' sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp'
uv sync --python 3.13 --group test --frozen
./build.sh --go ./build.sh --go
- name: Prepare Python test environment
run: |
uv sync --python 3.13 --group test --frozen
uv pip install -e sdk/python
- name: Run Go unit tests - name: Run Go unit tests
# Runs after `./build.sh --go`, which guarantees the C++ static # Runs after `./build.sh --go`, which guarantees the C++ static
# library (librag_tokenizer_c_api.a) is present on disk. The Go # library (librag_tokenizer_c_api.a) is present on disk. The Go
@@ -240,7 +234,10 @@ jobs:
PKGS=$(go list ./... 2>/dev/null \ PKGS=$(go list ./... 2>/dev/null \
| grep -v '/internal/storage$' \ | grep -v '/internal/storage$' \
| grep -v '/internal/tokenizer$' \ | grep -v '/internal/tokenizer$' \
| grep -v '/internal/handler$' || true) | grep -v '/internal/handler$' \
| grep -v '/internal/deepdoc/parser/pdf/pdfium' \
| grep -v '/internal/deepdoc/parser/pdf/pdfoxide' \
| grep -v '/internal/deepdoc/parser/pdf' || true)
if [ -z "$PKGS" ]; then if [ -z "$PKGS" ]; then
./build.sh --test ./build.sh --test
else else
@@ -253,6 +250,11 @@ jobs:
sudo docker pull ubuntu:24.04 sudo docker pull ubuntu:24.04
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} . sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
- name: Prepare Python test environment
run: |
uv sync --python 3.13 --group test --frozen
uv pip install -e sdk/python
- name: Prepare function test environment - name: Prepare function test environment
working-directory: docker working-directory: docker
run: | run: |
@@ -654,14 +656,8 @@ jobs:
-v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \ -v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \
infiniflow/infinity_builder:ubuntu22_clang20 infiniflow/infinity_builder:ubuntu22_clang20
sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp' sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp'
uv sync --python 3.13 --group test --frozen
./build.sh --go ./build.sh --go
- name: Prepare Python test environment
run: |
uv sync --python 3.13 --group test --frozen
uv pip install -e sdk/python
- name: Run Go unit tests - name: Run Go unit tests
# Runs after `./build.sh --go`, which guarantees the C++ static # Runs after `./build.sh --go`, which guarantees the C++ static
# library (librag_tokenizer_c_api.a) is present on disk. The Go # library (librag_tokenizer_c_api.a) is present on disk. The Go
@@ -684,7 +680,10 @@ jobs:
PKGS=$(go list ./... 2>/dev/null \ PKGS=$(go list ./... 2>/dev/null \
| grep -v '/internal/storage$' \ | grep -v '/internal/storage$' \
| grep -v '/internal/tokenizer$' \ | grep -v '/internal/tokenizer$' \
| grep -v '/internal/handler$' || true) | grep -v '/internal/handler$' \
| grep -v '/internal/deepdoc/parser/pdf/pdfium' \
| grep -v '/internal/deepdoc/parser/pdf/pdfoxide' \
| grep -v '/internal/deepdoc/parser/pdf' || true)
if [ -z "$PKGS" ]; then if [ -z "$PKGS" ]; then
./build.sh --test ./build.sh --test
else else
@@ -697,6 +696,11 @@ jobs:
sudo docker pull ubuntu:24.04 sudo docker pull ubuntu:24.04
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} . sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
- name: Prepare Python test environment
run: |
uv sync --python 3.13 --group test --frozen
uv pip install -e sdk/python
- name: Prepare function test environment - name: Prepare function test environment
working-directory: docker working-directory: docker
run: | run: |

View File

@@ -143,7 +143,7 @@ check_office_oxide_deps() {
case "$(uname -s)" in case "$(uname -s)" in
Linux) lib_file="liboffice_oxide.so" ;; Linux) lib_file="liboffice_oxide.so" ;;
Darwin) lib_file="liboffice_oxide.dylib" ;; Darwin) lib_file="liboffice_oxide.dylib" ;;
*) echo -e "${RED}Unsupported OS for office_oxide${NC}"; exit 1 ;; *) echo -e "${RED}Unsupported OS for office_oxide${NC}"; return 1 ;;
esac esac
local lib_path="${OFFICE_OXIDE_PREFIX}/lib/${lib_file}" local lib_path="${OFFICE_OXIDE_PREFIX}/lib/${lib_file}"
@@ -164,14 +164,14 @@ check_office_oxide_deps() {
case "$(uname -m)" in case "$(uname -m)" in
x86_64) asset_name="native-linux-x86_64" ;; x86_64) asset_name="native-linux-x86_64" ;;
aarch64|arm64) asset_name="native-linux-aarch64" ;; aarch64|arm64) asset_name="native-linux-aarch64" ;;
*) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; exit 1 ;; *) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; return 1 ;;
esac esac
;; ;;
Darwin) Darwin)
case "$(uname -m)" in case "$(uname -m)" in
x86_64) asset_name="native-macos-x86_64" ;; x86_64) asset_name="native-macos-x86_64" ;;
aarch64|arm64) asset_name="native-macos-aarch64" ;; aarch64|arm64) asset_name="native-macos-aarch64" ;;
*) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; exit 1 ;; *) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; return 1 ;;
esac esac
;; ;;
esac esac
@@ -182,9 +182,9 @@ check_office_oxide_deps() {
_download_and_extract "$release_url" "${OFFICE_OXIDE_PREFIX}" _download_and_extract "$release_url" "${OFFICE_OXIDE_PREFIX}"
if [ ! -f "$lib_path" ]; then if [ ! -f "$lib_path" ]; then
echo -e "${RED}Error: Failed to install office_oxide native library (missing ${lib_path})${NC}" echo -e "${YELLOW}Warning: Failed to install office_oxide native library (missing ${lib_path})${NC}"
echo " Try: curl -fsSL ${release_url} | tar xzf - -C ${OFFICE_OXIDE_PREFIX}" echo " Try: curl -fsSL ${release_url} | tar xzf - -C ${OFFICE_OXIDE_PREFIX}"
exit 1 return 1
fi fi
echo -e "${GREEN}✓ office_oxide native library installed${NC}" echo -e "${GREEN}✓ office_oxide native library installed${NC}"
@@ -405,6 +405,7 @@ build_go() {
eval "$install_cmd" eval "$install_cmd"
fi fi
check_office_oxide_deps || true
setup_cgo_env setup_cgo_env
local strip_flags=() local strip_flags=()
@@ -445,35 +446,44 @@ build_go() {
echo -e "${GREEN}✓ Go ingestor built successfully: $INGESTOR_BINARY${NC}" echo -e "${GREEN}✓ Go ingestor built successfully: $INGESTOR_BINARY${NC}"
} }
# Configure CGO flags for native libraries (office_oxide, pdfium, pdf_oxide). # Configure CGO flags for native libraries.
# Call before any `go build` / `go test` step that links against these libraries. # setup_cgo_env — base: -I and -L paths only, no -l flags (those live in
# each package's own #cgo LDFLAGS pragma). Safe to call even when native
# libs are absent — just skips the paths that don't exist.
# setup_cgo_env_pdf — pdfium / pdf_oxide -L paths. Non-fatal when libs
# are missing. Only called by run_go_tests.
setup_cgo_env() { setup_cgo_env() {
# ── office_oxide ────────────────────────────────────────────────── # ── office_oxide (header + search path only, no -loffice_oxide) ───
check_office_oxide_deps if [ -f "${OFFICE_OXIDE_PREFIX}/include/office_oxide_c/office_oxide.h" ]; then
export CGO_CFLAGS="-I${OFFICE_OXIDE_PREFIX}/include/office_oxide_c${CGO_CFLAGS:+ $CGO_CFLAGS}" export CGO_CFLAGS="-I${OFFICE_OXIDE_PREFIX}/include/office_oxide_c${CGO_CFLAGS:+ $CGO_CFLAGS}"
export CGO_LDFLAGS="-L${OFFICE_OXIDE_PREFIX}/lib -loffice_oxide -Wl,-rpath,${OFFICE_OXIDE_PREFIX}/lib" fi
export LD_LIBRARY_PATH="${OFFICE_OXIDE_PREFIX}/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" if [ -f "${OFFICE_OXIDE_PREFIX}/lib/liboffice_oxide.so" ] || [ -f "${OFFICE_OXIDE_PREFIX}/lib/liboffice_oxide.dylib" ]; then
export CGO_LDFLAGS="-L${OFFICE_OXIDE_PREFIX}/lib${CGO_LDFLAGS:+ $CGO_LDFLAGS}"
export LD_LIBRARY_PATH="${OFFICE_OXIDE_PREFIX}/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
fi
echo "CGO_CFLAGS: $CGO_CFLAGS"
echo "CGO_LDFLAGS: $CGO_LDFLAGS"
}
setup_cgo_env_pdf() {
# ── pdfium ──────────────────────────────────────────────────────── # ── pdfium ────────────────────────────────────────────────────────
check_pdfium_deps || return 1 check_pdfium_deps || true
if [ -f "${PDFIUM_PREFIX}/libpdfium.so" ]; then if [ -f "${PDFIUM_PREFIX}/libpdfium.so" ]; then
export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDFIUM_PREFIX} -Wl,-rpath,${PDFIUM_PREFIX}" export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDFIUM_PREFIX}"
export LD_LIBRARY_PATH="${PDFIUM_PREFIX}:${LD_LIBRARY_PATH}" export LD_LIBRARY_PATH="${PDFIUM_PREFIX}:${LD_LIBRARY_PATH}"
fi fi
# ── pdf_oxide ───────────────────────────────────────────────────── # ── pdf_oxide ─────────────────────────────────────────────────────
check_pdf_oxide_deps || return 1 check_pdf_oxide_deps || true
if [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.so" ]; then if [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.so" ]; then
export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDF_OXIDE_PREFIX} -lpdf_oxide -Wl,-rpath,${PDF_OXIDE_PREFIX}" export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDF_OXIDE_PREFIX}"
export LD_LIBRARY_PATH="${PDF_OXIDE_PREFIX}:${LD_LIBRARY_PATH}" export LD_LIBRARY_PATH="${PDF_OXIDE_PREFIX}:${LD_LIBRARY_PATH}"
elif [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.a" ]; then elif [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.a" ]; then
export CGO_LDFLAGS="$CGO_LDFLAGS ${PDF_OXIDE_PREFIX}/libpdf_oxide.a" export CGO_LDFLAGS="$CGO_LDFLAGS ${PDF_OXIDE_PREFIX}/libpdf_oxide.a"
fi fi
echo "CGO_CFLAGS: $CGO_CFLAGS" echo "CGO_LDFLAGS (with PDF): $CGO_LDFLAGS"
echo "Exporting CGO_CFLAGS: $CGO_CFLAGS"
echo "CGO_LDFLAGS: $CGO_LDFLAGS"
echo "Exporting CGO_LDFLAGS: $CGO_LDFLAGS"
} }
# Run Go unit tests with the same CGO env as `build_go`. Pass any extra args # Run Go unit tests with the same CGO env as `build_go`. Pass any extra args
@@ -482,7 +492,9 @@ run_go_tests() {
print_section "Running Go tests" print_section "Running Go tests"
cd "$PROJECT_ROOT" cd "$PROJECT_ROOT"
check_office_oxide_deps || true
setup_cgo_env setup_cgo_env
setup_cgo_env_pdf
if [ "$#" -eq 0 ]; then if [ "$#" -eq 0 ]; then
set -- ./... set -- ./...
@@ -522,6 +534,10 @@ run() {
cd "$PROJECT_ROOT" cd "$PROJECT_ROOT"
# Set LD_LIBRARY_PATH for native libraries that were linked at build time.
# Libraries are only in the search path when they were present during build.
setup_cgo_env
# admin_server must be running before ragflow_server, otherwise ragflow_server's # admin_server must be running before ragflow_server, otherwise ragflow_server's
# heartbeats to admin will error out (see internal/development.md). # heartbeats to admin will error out (see internal/development.md).
print_section "Starting admin server (background)" print_section "Starting admin server (background)"

View File

@@ -46,8 +46,9 @@ def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
>>> timestamp_to_date(1704067200000) >>> timestamp_to_date(1704067200000)
'2024-01-01 08:00:00' '2024-01-01 08:00:00'
""" """
if not timestamp: if timestamp is None or timestamp == "":
timestamp = time.time() timestamp = current_timestamp()
logging.debug("timestamp_to_date received empty timestamp; using current_timestamp() fallback")
timestamp = int(timestamp) / 1000 timestamp = int(timestamp) / 1000
time_array = time.localtime(timestamp) time_array = time.localtime(timestamp)
str_date = time.strftime(format_string, time_array) str_date = time.strftime(format_string, time_array)
@@ -144,11 +145,8 @@ def format_iso_8601_to_ymd_hms(time_str: str) -> str:
from dateutil import parser from dateutil import parser
try: try:
if parser.isoparse(time_str): dt = parser.isoparse(time_str)
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00")) return dt.strftime("%Y-%m-%d %H:%M:%S")
return dt.strftime("%Y-%m-%d %H:%M:%S")
else:
return time_str
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))
return time_str return time_str

View File

@@ -133,6 +133,10 @@ set_target_properties(rag_tokenizer PROPERTIES
add_library(rag_tokenizer_c_api STATIC add_library(rag_tokenizer_c_api STATIC
rag_analyzer_c_api.cpp rag_analyzer_c_api.cpp
rag_analyzer_c_api.h rag_analyzer_c_api.h
thinc_ner.cpp
thinc_ner.h
thinc_parser.cpp
thinc_parser.h
rag_analyzer.cpp rag_analyzer.cpp
rag_analyzer.h rag_analyzer.h
dart_trie.h dart_trie.h

View File

@@ -99,6 +99,31 @@ char* RAGAnalyzer_GetTermTag(RAGAnalyzerHandle handle, const char* term);
// Returns: handle to the new analyzer instance, or NULL on failure // Returns: handle to the new analyzer instance, or NULL on failure
RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle); RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle);
// ---------------------------------------------------------------------------
// Named Entity Recognition (spaCy model inference)
// ---------------------------------------------------------------------------
// Create a ThincNER inference handle.
// model_ner_dir: path to the spaCy model's ner/ component directory
// model_vocab_dir: path to the spaCy model's vocab/ directory (optional, can be NULL)
// Returns: handle, or NULL on failure
RAGAnalyzerHandle ThincNER_Create(const char* model_ner_dir, const char* model_vocab_dir);
// Destroy a ThincNER handle.
void ThincNER_Destroy(RAGAnalyzerHandle handle);
// Run NER on pre-tokenized text.
// tokens_json: JSON array e.g. ["Apple","Inc.","was","founded","by","Steve","Jobs","."]
// Returns JSON array of entities, caller must free with ThincNER_FreeString.
char* ThincNER_Predict(RAGAnalyzerHandle handle, const char* tokens_json);
// Tokenize text using spaCy-compatible rules.
// Returns JSON array of token strings, caller must free with ThincNER_FreeString.
char* ThincNER_Tokenize(const char* text, const char* lang);
// Free a string returned by ThincNER_Predict or ThincNER_Tokenize.
void ThincNER_FreeString(char* ptr);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

610
internal/cpp/thinc_ner.cpp Normal file
View File

@@ -0,0 +1,610 @@
#pragma STDC FP_CONTRACT OFF
#include "thinc_ner.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
// =========================================================================
// JSON parser (minimal)
// =========================================================================
namespace {
// Return `s` without leading/trailing blanks (space, tab, CR, LF).
// A string made only of blanks (or an empty one) yields "".
std::string trim(const std::string& s) {
    static const char* const kBlanks = " \t\r\n";
    const std::string::size_type first = s.find_first_not_of(kBlanks);
    if (first == std::string::npos) {
        return "";
    }
    const std::string::size_type last = s.find_last_not_of(kBlanks);
    return s.substr(first, last - first + 1);
}
// Minimal JSON value. Only the member that matches `type` is meaningful:
//   STR  -> str          NUM -> num
//   ARR  -> arr          OBJ -> obj
//   BOOL -> str holds "true" or "false"
struct JVal {
    enum Type { NUL, OBJ, ARR, STR, NUM, BOOL } type = NUL;
    std::string str;
    std::vector<JVal> arr;
    std::unordered_map<std::string, JVal> obj;
    double num = 0;

    // Object member lookup; returns nullptr when `k` is absent.
    const JVal* get(const std::string& k) const {
        auto it = obj.find(k);
        return it != obj.end() ? &it->second : nullptr;
    }
    int as_int() const { return (int)num; }
    int64_t as_i64() const { return (int64_t)num; }
};

// Minimal, forgiving recursive-descent JSON parser. It never throws and never
// reads past the end of the input; malformed input yields a partial value.
// The parser keeps raw pointers into the string handed to parse(), so that
// string must stay alive for the duration of the call.
struct JParser {
    const char *p, *e;  // cursor and one-past-the-end of the input

    // Skip whitespace, then peek at the current character (0 at end of input).
    char pk() { while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p; return p < e ? *p : 0; }
    // Skip whitespace, then consume and return one character (0 at end of input).
    char nx() { while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p; return p < e ? *p++ : 0; }

    // Parse any value, dispatching on its first character.
    JVal pv() {
        char c = pk();
        if (c == '{') return po();
        if (c == '[') return pa();
        if (c == '"') return ps();
        if (c == 't' || c == 'f') return pb();
        if (c == 'n') { nx(); nx(); nx(); nx(); return JVal{}; }  // consume "null"
        return pn();
    }

    // Parse an object: '{' ("key" ':' value (',' ...)?)* '}'.
    JVal po() {
        JVal v; v.type = JVal::OBJ;
        nx();  // '{'
        while (pk() != '}') {
            auto k = ps();
            nx();  // ':'
            v.obj[k.str] = pv();
            if (pk() == ',') nx(); else break;
        }
        nx();  // '}'
        return v;
    }

    // Parse an array: '[' (value (',' ...)?)* ']'.
    JVal pa() {
        JVal v; v.type = JVal::ARR;
        nx();  // '['
        while (pk() != ']') {
            v.arr.push_back(pv());
            if (pk() == ',') nx(); else break;
        }
        nx();  // ']'
        return v;
    }

    // Read exactly four hex digits at the cursor. On success the cursor is
    // advanced and `out` receives the value; on failure nothing is consumed.
    bool hex4(uint32_t& out) {
        if (e - p < 4) return false;
        uint32_t v = 0;
        for (int i = 0; i < 4; i++) {
            const char c = p[i];
            uint32_t d;
            if (c >= '0' && c <= '9') d = (uint32_t)(c - '0');
            else if (c >= 'a' && c <= 'f') d = (uint32_t)(c - 'a' + 10);
            else if (c >= 'A' && c <= 'F') d = (uint32_t)(c - 'A' + 10);
            else return false;
            v = (v << 4) | d;
        }
        p += 4;
        out = v;
        return true;
    }

    // Append code point `cp` to `s` as UTF-8.
    static void put_utf8(std::string& s, uint32_t cp) {
        if (cp < 0x80) {
            s += (char)cp;
        } else if (cp < 0x800) {
            s += (char)(0xC0 | (cp >> 6));
            s += (char)(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            s += (char)(0xE0 | (cp >> 12));
            s += (char)(0x80 | ((cp >> 6) & 0x3F));
            s += (char)(0x80 | (cp & 0x3F));
        } else {
            s += (char)(0xF0 | (cp >> 18));
            s += (char)(0x80 | ((cp >> 12) & 0x3F));
            s += (char)(0x80 | ((cp >> 6) & 0x3F));
            s += (char)(0x80 | (cp & 0x3F));
        }
    }

    // Parse a string literal, decoding the JSON escapes \n \t \r \b \f and
    // \uXXXX (including surrogate pairs, emitted as UTF-8). Any other escaped
    // character (\" \\ \/ or unknown) is copied through as-is.
    JVal ps() {
        JVal v; v.type = JVal::STR;
        nx();  // opening quote
        while (p < e && *p != '"') {
            const char c = *p++;
            if (c != '\\') { v.str += c; continue; }
            if (p >= e) break;  // dangling backslash at end of input
            const char esc = *p++;
            switch (esc) {
                case 'n': v.str += '\n'; break;
                case 't': v.str += '\t'; break;
                case 'r': v.str += '\r'; break;
                case 'b': v.str += '\b'; break;
                case 'f': v.str += '\f'; break;
                case 'u': {
                    uint32_t cp = 0;
                    if (!hex4(cp)) { v.str += 'u'; break; }  // malformed: keep the letter
                    if (cp >= 0xD800 && cp <= 0xDBFF && e - p >= 6 && p[0] == '\\' && p[1] == 'u') {
                        // High surrogate: try to combine with a following low surrogate.
                        const char* save = p;
                        uint32_t lo = 0;
                        p += 2;
                        if (hex4(lo) && lo >= 0xDC00 && lo <= 0xDFFF) {
                            cp = 0x10000 + ((cp - 0xD800) << 10) + (lo - 0xDC00);
                        } else {
                            p = save;  // not a pair; re-parse the next escape normally
                        }
                    }
                    if (cp >= 0xD800 && cp <= 0xDFFF) cp = 0xFFFD;  // lone surrogate
                    put_utf8(v.str, cp);
                    break;
                }
                default: v.str += esc; break;
            }
        }
        if (p < e) ++p;  // closing quote
        return v;
    }

    // Parse a number: optional '-', digits, optional fraction and exponent.
    // If nothing numeric is present the cursor is left untouched and 0 is returned.
    JVal pn() {
        JVal v; v.type = JVal::NUM;
        auto s = p;
        if (p < e && *p == '-') ++p;
        while (p < e && (*p >= '0' && *p <= '9')) ++p;
        if (p < e && *p == '.') { ++p; while (p < e && (*p >= '0' && *p <= '9')) ++p; }
        if (p < e && (*p == 'e' || *p == 'E')) {
            ++p;
            if (p < e && (*p == '+' || *p == '-')) ++p;
            while (p < e && (*p >= '0' && *p <= '9')) ++p;
        }
        if (s < p) { try { v.num = std::stod(std::string(s, p - s)); } catch (...) { v.num = 0; } }
        return v;
    }

    // Parse `true` / `false`; the spelling is stored in `str`.
    JVal pb() {
        JVal v; v.type = JVal::BOOL;
        if (e - p >= 4 && *p == 't') { v.str = "true"; p += 4; }
        else if (e - p >= 5 && *p == 'f') { v.str = "false"; p += 5; }
        return v;
    }

    // Parse a whole document and return its root value.
    JVal parse(const std::string& j) { p = j.data(); e = p + j.size(); return pv(); }
};
// =========================================================================
// MurmurHash2 64-bit (vocab string→ID, seed=0 matching spaCy StringStore)
// =========================================================================
// MurmurHash2 64-bit over `len` bytes at `key`, mixed with `seed`.
// Full 8-byte words are loaded with memcpy (host byte order); the 1..7
// trailing bytes are folded in byte-by-byte, least-significant byte first.
static uint64_t mh2_64a(const void* key, int len, uint64_t seed) {
    const uint64_t mul = 0xc6a4a7935bd1e995ULL;
    const int shift = 47;

    const uint8_t* cur = (const uint8_t*)key;
    uint64_t hash = seed ^ (uint64_t(len) * mul);

    int remaining = len;
    for (; remaining >= 8; remaining -= 8, cur += 8) {
        uint64_t word;
        memcpy(&word, cur, 8);
        word *= mul;
        word ^= word >> shift;
        word *= mul;
        hash ^= word;
        hash *= mul;
    }

    if (remaining > 0) {
        // Tail: byte i contributes at bit offset 8*i, then one final multiply.
        for (int i = remaining - 1; i >= 0; --i) {
            hash ^= uint64_t(cur[i]) << (8 * i);
        }
        hash *= mul;
    }

    hash ^= hash >> shift;
    hash *= mul;
    hash ^= hash >> shift;
    return hash;
}
// Feature string -> 64-bit ID: the empty string maps to 0, anything else is
// hashed with mh2_64a using seed 0.
static uint64_t hash_feat(const std::string& s) {
    if (s.empty()) {
        return 0;
    }
    return mh2_64a(s.data(), (int)s.size(), 0);
}
// =========================================================================
// MurmurHash3_x64_128 (exact copy from mmh3 package, verified against thinc)
// =========================================================================
// Rotate the 64-bit unsigned value `x` left by `r` bits (callers here use
// 0 < r < 64). Both arguments are parenthesized so that expression arguments
// such as `a | b` or `n + 1` expand with the intended precedence.
#define ROTL64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))
// Fetch the i-th 64-bit word starting at `p`. The copy goes through memcpy so
// the source does not have to be 8-byte aligned.
static uint64_t getblock64(const uint64_t* p, size_t i) {
    uint64_t word;
    memcpy(&word, p + i, 8);
    return word;
}
// Final avalanche mix: two xor-shift/multiply rounds followed by one last
// xor-shift.
static uint64_t fmix64(uint64_t k) {
    static const uint64_t kRoundMul[2] = {0xff51afd7ed558ccdULL, 0xc4ceb9fe1a85ec53ULL};
    for (uint64_t mul : kRoundMul) {
        k ^= k >> 33;
        k *= mul;
    }
    k ^= k >> 33;
    return k;
}
// MurmurHash3 x64 128-bit hash of `len` bytes at `key`.
//   key  : input bytes
//   len  : number of bytes to hash
//   seed : initial value for both 64-bit lanes
//   out  : receives four 32-bit words: low/high halves of h1, then of h2
// The statement order and the switch fall-through below are part of the
// algorithm; do not reorder.
static void mmh3_x64_128(const void* key, int len, uint32_t seed, uint32_t out[4]) {
const uint8_t* data = (const uint8_t*)key;
// Number of complete 16-byte blocks.
int nblocks = len / 16;
uint64_t h1 = seed, h2 = seed;
const uint64_t c1 = 0x87c37b91114253d5ULL;
const uint64_t c2 = 0x4cf5ad432745937fULL;
// Body: each block is two 64-bit words, read through getblock64 (memcpy).
const uint64_t* blocks = (const uint64_t*)(data);
for (int i = 0; i < nblocks; i++) {
uint64_t k1 = getblock64(blocks, i*2+0);
uint64_t k2 = getblock64(blocks, i*2+1);
k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1;
h1 = ROTL64(h1,27); h1 += h2; h1 = h1 * 5 + 0x52dce729;
k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2;
h2 = ROTL64(h2,31); h2 += h1; h2 = h2 * 5 + 0x38495ab5;
}
// Tail: the remaining 0..15 bytes. Every case deliberately falls through to
// the next one; bytes 8..14 accumulate into k2, bytes 0..7 into k1.
const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
uint64_t k1 = 0, k2 = 0;
switch (len & 15) {
case 15: k2 ^= ((uint64_t)tail[14]) << 48;
case 14: k2 ^= ((uint64_t)tail[13]) << 40;
case 13: k2 ^= ((uint64_t)tail[12]) << 32;
case 12: k2 ^= ((uint64_t)tail[11]) << 24;
case 11: k2 ^= ((uint64_t)tail[10]) << 16;
case 10: k2 ^= ((uint64_t)tail[9]) << 8;
case 9: k2 ^= ((uint64_t)tail[8]) << 0;
k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2;
case 8: k1 ^= ((uint64_t)tail[7]) << 56;
case 7: k1 ^= ((uint64_t)tail[6]) << 48;
case 6: k1 ^= ((uint64_t)tail[5]) << 40;
case 5: k1 ^= ((uint64_t)tail[4]) << 32;
case 4: k1 ^= ((uint64_t)tail[3]) << 24;
case 3: k1 ^= ((uint64_t)tail[2]) << 16;
case 2: k1 ^= ((uint64_t)tail[1]) << 8;
case 1: k1 ^= ((uint64_t)tail[0]) << 0;
k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1;
};
// Finalization: fold in the length, cross-mix the lanes, avalanche each lane.
h1 ^= len; h2 ^= len;
h1 += h2; h2 += h1;
h1 = fmix64(h1); h2 = fmix64(h2);
h1 += h2; h2 += h1;
// Split the two 64-bit lanes into four 32-bit output words.
out[0] = (uint32_t)h1; out[1] = (uint32_t)(h1>>32);
out[2] = (uint32_t)h2; out[3] = (uint32_t)(h2>>32);
}
// =========================================================================
// HashEmbed
// =========================================================================
// One hashed-embedding table of n_rows x nO floats (row-major).
// A feature ID is hashed with mmh3_x64_128 into four 32-bit keys; the four
// rows they select are summed into the output vector.
struct HashEmbed {
    int n_rows = 0, nO = 0;
    uint32_t seed = 0;
    std::vector<float> table;

    // Copy r*o floats from `d` into the table.
    // Returns false (leaving the table empty) when a dimension is not
    // positive or `d` is null, instead of computing a bogus copy range.
    bool load(int r, int o, const float* d) {
        n_rows = r;
        nO = o;
        if (r <= 0 || o <= 0 || d == nullptr) {
            table.clear();
            return false;
        }
        table.assign(d, d + (size_t)r * o);
        return !table.empty();
    }

    // Accumulate (+=) the four hashed rows for feature `fid` into out[0..nO).
    // Does nothing when the table was never loaded successfully, which avoids
    // a modulo-by-zero on n_rows and out-of-range reads.
    void embed(uint64_t fid, float* out) const {
        if (n_rows <= 0 || nO <= 0 || table.size() < (size_t)n_rows * nO) return;
        // Serialize the ID as 8 little-endian bytes so the hash input does
        // not depend on host byte order.
        uint8_t in[8];
        for (int i = 0; i < 8; i++) in[i] = (uint8_t)(fid >> (i * 8));
        uint32_t keys[4];
        mmh3_x64_128(in, 8, seed, keys);
        for (int v = 0; v < 4; v++) {
            const size_t row = (size_t)(keys[v] % (uint32_t)n_rows);
            const float* src = table.data() + row * nO;
            for (int i = 0; i < nO; i++) out[i] += src[i];
        }
    }
};
// =========================================================================
// Features — dynamic based on n_embed
// en (6): NORM, PREFIX, SUFFIX, SHAPE, SPACY, IS_SPACE
// zh (5): NORM, PREFIX, SUFFIX, SHAPE, IS_SPACE
// =========================================================================
// NORM feature: hash of the token with its bytes lowercased via tolower
// (locale-dependent; non-ASCII bytes are unchanged in the default "C" locale).
// Each byte is converted to unsigned char before tolower: passing a negative
// plain `char` (any UTF-8 byte >= 0x80 on signed-char platforms) is undefined.
static uint64_t feat_norm(const std::string& t) {
    std::string lo = t;
    std::transform(lo.begin(), lo.end(), lo.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return hash_feat(lo);
}
// UTF-8 aware: get first Unicode codepoint as string
// Return the first UTF-8 encoded character of `s` as a string.
// The width is taken from the lead byte; if the string is shorter than that
// width (truncated sequence) only the first byte is returned.
static std::string utf8_first(const std::string& s) {
    if (s.empty()) {
        return "";
    }
    const unsigned char lead = (unsigned char)s[0];
    size_t width = 1;
    if ((lead & 0xE0) == 0xC0) {
        width = 2;
    } else if ((lead & 0xF0) == 0xE0) {
        width = 3;
    } else if ((lead & 0xF8) == 0xF0) {
        width = 4;
    }
    if (width > s.size()) {
        width = 1;
    }
    return s.substr(0, width);
}
// Count Unicode codepoints in a string
// Count the characters in `s` by walking UTF-8 lead bytes. ASCII bytes, stray
// continuation bytes and invalid lead bytes each count as one character.
static size_t utf8_len(const std::string& s) {
    size_t count = 0;
    size_t i = 0;
    while (i < s.size()) {
        const unsigned char c = (unsigned char)s[i];
        size_t step = 1;
        if ((c & 0xE0) == 0xC0) {
            step = 2;
        } else if ((c & 0xF0) == 0xE0) {
            step = 3;
        } else if ((c & 0xF8) == 0xF0) {
            step = 4;
        }
        i += step;
        ++count;
    }
    return count;
}
// Get suffix: last `count` Unicode codepoints
// Return the trailing `count` characters of `s` (UTF-8 aware). When `s` has
// `count` characters or fewer, the whole string is returned.
static std::string utf8_last(const std::string& s, size_t count) {
    const size_t total = utf8_len(s);
    if (total <= count) {
        return s;
    }
    // Walk past the leading (total - count) characters to find the byte offset
    // where the suffix starts.
    size_t pos = 0;
    for (size_t skipped = 0; skipped < total - count; ++skipped) {
        const unsigned char c = (unsigned char)s[pos];
        size_t step = 1;
        if ((c & 0xE0) == 0xC0) {
            step = 2;
        } else if ((c & 0xF0) == 0xE0) {
            step = 3;
        } else if ((c & 0xF8) == 0xF0) {
            step = 4;
        }
        pos += step;
    }
    return s.substr(pos);
}
// PREFIX feature: hash of the token's first character, lowercased
// (the original comment cites spaCy's `string[:1].lower()`).
// utf8_first already returns "" for an empty token. Bytes are converted to
// unsigned char before tolower, because passing a negative plain `char`
// (UTF-8 bytes >= 0x80 on signed-char platforms) is undefined behaviour.
static uint64_t feat_prefix(const std::string& t) {
    std::string p = utf8_first(t);
    std::transform(p.begin(), p.end(), p.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return hash_feat(p);
}
// SUFFIX feature: hash of the token's last three characters, lowercased
// (the original comment cites spaCy's `string[-3:].lower()`).
// utf8_last returns the whole token when it has three characters or fewer, so
// no separate length check is needed. Bytes are converted to unsigned char
// before tolower, because passing a negative plain `char` (UTF-8 bytes >= 0x80
// on signed-char platforms) is undefined behaviour.
static uint64_t feat_suffix(const std::string& t) {
    std::string s = utf8_last(t, 3);
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return hash_feat(s);
}
// SHAPE feature: hash of a per-byte shape string.
//   byte > 0x7F -> 'x'    uppercase -> 'X'    lowercase -> 'x'
//   digit       -> 'd'    anything else is copied unchanged
// NOTE(review): the mapping is per byte, so one multi-byte character yields
// several 'x' marks, and runs are never shortened — confirm this is the shape
// the reference model was trained with.
static uint64_t feat_shape(const std::string& t) {
    std::string shape;
    for (size_t i = 0; i < t.size(); ++i) {
        const unsigned char c = (unsigned char)t[i];
        char mark;
        if (c > 0x7F) {
            mark = 'x';
        } else if (std::isupper(c)) {
            mark = 'X';
        } else if (std::islower(c)) {
            mark = 'x';
        } else if (std::isdigit(c)) {
            mark = 'd';
        } else {
            mark = (char)c;
        }
        shape += mark;
    }
    return hash_feat(shape);
}
// Extract features based on n_embed count. Returns vector of hash values.
// NER model's tok2vec uses 4 features: NORM, PREFIX, SUFFIX, SHAPE
// (The pipeline's standalone tok2vec uses 6 features including SPACY and IS_SPACE.)
// Feature order matches the HashEmbed table order in the model.
// Build the feature-ID vector for token `t`, one ID per HashEmbed table.
// The first four entries are always NORM, PREFIX, SUFFIX, SHAPE.
//   n_embed == 5 : appends IS_SPACE (constant 0)
//   n_embed >= 6 : appends SPACY (constant 1) then IS_SPACE (constant 0)
// Any other n_embed yields just the four base features.
static std::vector<uint64_t> extract_features(const std::string& t, int n_embed) {
    std::vector<uint64_t> ids;
    ids.reserve(6);
    ids.push_back(feat_norm(t));
    ids.push_back(feat_prefix(t));
    ids.push_back(feat_suffix(t));
    ids.push_back(feat_shape(t));
    if (n_embed >= 6) {
        ids.push_back(1);  // SPACY
        ids.push_back(0);  // IS_SPACE
    } else if (n_embed == 5) {
        ids.push_back(0);  // IS_SPACE (no SPACY slot in 5-embed models)
    }
    return ids;
}
// =========================================================================
// Layers
// =========================================================================
// Kahan compensated dot product: reduces floating-point accumulation error
// for long dot products (e.g. 576 terms in Maxout).
// Dot product of a[0..n) and b[0..n) using Kahan compensated summation: the
// rounding error of each addition is carried in `comp` and subtracted from the
// next term. Returns 0 for n <= 0.
static float kahan_dot(const float* a, const float* b, int n) {
    float total = 0.0f;
    float comp = 0.0f;
    int i = 0;
    while (i < n) {
        const float term = a[i] * b[i] - comp;
        const float next = total + term;
        comp = (next - total) - term;
        total = next;
        ++i;
    }
    return total;
}
// Affine layer: out[o] = b[o] + W[o, :] . in, with W stored row-major as
// nO rows of nI floats.
static void linear(float* out, const float* in, const float* W, const float* b, int nO, int nI) {
    for (int o = 0; o < nO; ++o) {
        const float* row = W + (size_t)o * nI;
        out[o] = b[o] + kahan_dot(row, in, nI);
    }
}
// In-place ReLU over x[0..n): every element that is not strictly greater than
// zero (including NaN) is replaced by 0.
static void relu_inplace(float* x, int n) {
    for (int i = 0; i < n; ++i) {
        if (!(x[i] > 0)) {
            x[i] = 0;
        }
    }
}
// Maxout: y[i] = max_p(b[i,p] + W[i,p,:] @ in)
// Maxout layer: out[o] = max over pieces p of (b[o,p] + W[o,p,:] . in).
// W is laid out as [nO][nP][nI] and b as [nO][nP]. The running maximum starts
// at -1e30, which is also the result when nP <= 0.
static void maxout(float* out, const float* in, const float* W, const float* b, int nO, int nP, int nI) {
    for (int o = 0; o < nO; ++o) {
        const size_t base = (size_t)o * nP;
        float top = -1e30f;
        for (int piece = 0; piece < nP; ++piece) {
            const size_t slot = base + piece;
            const float score = b[slot] + kahan_dot(W + slot * nI, in, nI);
            if (score > top) {
                top = score;
            }
        }
        out[o] = top;
    }
}
// LayerNorm: y = G * (x-mean)/sqrt(var+eps) + b
// Layer normalization over a d-element vector:
//   out[i] = G[i] * (in[i] - mean) / sqrt(var + eps) + b[i]
// where mean and var are the mean and population variance of `in`.
static void layernorm(float* out, const float* in, int d, const float* G, const float* b, float eps) {
    float mean = 0;
    for (int i = 0; i < d; ++i) {
        mean += in[i];
    }
    mean /= d;

    float var = 0;
    for (int i = 0; i < d; ++i) {
        var += (in[i] - mean) * (in[i] - mean);
    }
    var /= d;

    const float inv_std = 1.0f / sqrtf(var + eps);
    for (int i = 0; i < d; ++i) {
        out[i] = G[i] * (in[i] - mean) * inv_std + b[i];
    }
}
// ExpandWindow: for token at index `idx` over all_tokens[n_tokens×dim], produce [t-1, t, t+1]
// Window expansion for token `idx`: writes [prev | current | next] (3*dim
// floats) to `out`, reading rows of `all` (n rows of dim floats). A missing
// neighbour at either end of the sequence is filled with zeros.
static void expand_win(float* out, const float* all, int n, int dim, int idx) {
    const size_t row_bytes = dim * sizeof(float);
    float* prev_slot = out;
    float* cur_slot = out + dim;
    float* next_slot = out + 2 * dim;

    if (idx > 0) {
        memcpy(prev_slot, all + (idx - 1) * dim, row_bytes);
    } else {
        memset(prev_slot, 0, row_bytes);
    }

    memcpy(cur_slot, all + idx * dim, row_bytes);

    if (idx < n - 1) {
        memcpy(next_slot, all + (idx + 1) * dim, row_bytes);
    } else {
        memset(next_slot, 0, row_bytes);
    }
}
// BILUO decoder
// One decoded entity: its text (tokens joined by single spaces), its label,
// the inclusive token range [start, end] and a confidence value.
struct Entity { std::string text,label; int start,end; float conf; };

// Decode per-token BILUO tags ("B-X", "I-X", "L-X", "U-X", "O") into entities.
//   tok : tokens
//   lbl : one tag per token; if the two vectors differ in length only the
//         common prefix is decoded (no out-of-range reads)
// Every entity gets the fixed confidence 0.85.
// Malformed sequences are handled leniently:
//   - "O", empty or ill-formed tags close any open span;
//   - "I-X" with no matching open span starts a new span;
//   - "U-X", "B-X" and a non-matching "L-X" first emit any open span, so no
//     entity is silently dropped;
//   - a span still open at the end of input is emitted.
static std::vector<Entity> decode_biluo(const std::vector<std::string>& tok, const std::vector<std::string>& lbl) {
    const float kConf = 0.85f;
    std::vector<Entity> ents;
    const int n = (int)std::min(tok.size(), lbl.size());

    int st = -1;          // start index of the open span, -1 when none
    std::string et, ex;   // label and accumulated text of the open span

    // Emit the open span (if any) ending at token `last`, then reset.
    auto flush = [&](int last) {
        if (st >= 0) ents.push_back({ex, et, st, last, kConf});
        st = -1;
        et.clear();
        ex.clear();
    };
    // Open a new span at token `i` with label `ty`.
    auto open = [&](int i, const std::string& ty) {
        st = i;
        et = ty;
        ex = tok[i];
    };

    for (int i = 0; i < n; i++) {
        const std::string& l = lbl[i];
        if (l.empty() || l == "O" || l.size() < 3 || l[1] != '-') {
            flush(i - 1);
            continue;
        }
        const char a = l[0];
        const std::string ty = l.substr(2);
        if (a == 'U') {
            flush(i - 1);
            ents.push_back({tok[i], ty, i, i, kConf});
        } else if (a == 'B') {
            flush(i - 1);
            open(i, ty);
        } else if (a == 'I') {
            if (st >= 0 && et == ty) {
                ex += " " + tok[i];
            } else {
                flush(i - 1);
                open(i, ty);
            }
        } else if (a == 'L') {
            if (st >= 0 && et == ty) {
                ex += " " + tok[i];
                flush(i);
            } else {
                // Type mismatch or no open span: keep the pending entity and
                // treat this token as a single-token entity.
                flush(i - 1);
                ents.push_back({tok[i], ty, i, i, kConf});
            }
        }
        // Any other action letter leaves the open span untouched.
    }
    flush(n - 1);
    return ents;
}
// Tokenizer
// Rule-based tokenizer for space-delimited text.
//   - letters, digits and bytes > 127 accumulate into the current word;
//   - a '.' stays inside the word when a word is open and the next byte is a
//     letter (e.g. "a.b");
//   - any other byte closes the current word; non-whitespace ones are emitted
//     as single-character tokens.
static std::vector<std::string> tokenize_en(const std::string& t) {
    std::vector<std::string> out;
    std::string word;
    const size_t len = t.size();

    auto close_word = [&]() {
        if (!word.empty()) {
            out.push_back(word);
            word.clear();
        }
    };

    for (size_t i = 0; i < len; ++i) {
        const unsigned char c = (unsigned char)t[i];
        if (std::isalpha(c) || std::isdigit(c) || c > 127) {
            word += (char)c;
            continue;
        }
        const bool inner_dot = c == '.' && !word.empty() && i + 1 < len &&
                               std::isalpha((unsigned char)t[i + 1]);
        if (inner_dot) {
            word += '.';
            continue;
        }
        close_word();
        if (!std::isspace(c)) {
            out.push_back(std::string(1, (char)c));
        }
    }
    close_word();
    return out;
}
// Character-level tokenizer for mixed CJK/ASCII text.
//   - each non-ASCII character (width taken from its UTF-8 lead byte) becomes
//     its own token;
//   - a run of ASCII letters/digits becomes one token;
//   - other non-whitespace ASCII bytes become single-character tokens;
//   - ASCII whitespace is dropped.
static std::vector<std::string> tokenize_zh(const std::string& t) {
    std::vector<std::string> out;
    const size_t len = t.size();
    size_t i = 0;
    while (i < len) {
        const unsigned char c = (unsigned char)t[i];

        if ((c & 0x80) != 0) {
            size_t width = 1;
            if ((c & 0xE0) == 0xC0) {
                width = 2;
            } else if ((c & 0xF0) == 0xE0) {
                width = 3;
            } else if ((c & 0xF8) == 0xF0) {
                width = 4;
            }
            out.push_back(t.substr(i, width));
            i += width;
            continue;
        }

        if (std::isalpha(c) || std::isdigit(c)) {
            size_t stop = i;
            while (stop < len && (std::isalpha((unsigned char)t[stop]) || std::isdigit((unsigned char)t[stop]))) {
                ++stop;
            }
            out.push_back(t.substr(i, stop - i));
            i = stop;
            continue;
        }

        if (!std::isspace(c)) {
            out.push_back(std::string(1, (char)c));
        }
        ++i;
    }
    return out;
}
} // namespace
// =========================================================================
// State
// =========================================================================
// All model parameters filled in by load(). Each has_* flag records whether
// the corresponding optional component was found in the checkpoint.
struct State {
    // Hashed embedding tables, one per input feature.
    std::vector<HashEmbed> embeds;

    // Post-embed Maxout (defaults: 576 inputs -> 96 outputs, 3 pieces).
    std::vector<float> poW;
    std::vector<float> poB;
    int po_nO = 96;
    int po_nP = 3;
    int po_nI = 576;

    // Post-embed LayerNorm: gain and bias.
    std::vector<float> poG;
    std::vector<float> poB2;
    bool has_poLN = false;

    // Residual encoder: up to 4 blocks, each with weights, biases and a
    // LayerNorm gain/bias pair.
    struct ResBlk {
        bool has = false;
        std::vector<float> W;
        std::vector<float> b;
        std::vector<float> lnG;
        std::vector<float> lnb;
    };
    ResBlk res[4];
    int n_res = 0;

    // NER hidden layer (default output width 64).
    std::vector<float> hW;
    std::vector<float> hB;
    int hO = 64;
    bool has_hid = false;

    // PrecomputableAffine: W_full[nP=3][nO=64][nI=2][nD=64], b_full[nO=64][nI=2].
    // Only feature f=0 is used:
    //   pre_out[p][o] = sum_d W[p][o][0][d] * hid[d] + b[o][0]
    std::vector<float> pW_full;  // flattened [3*64*2*64]
    std::vector<float> pB_full;  // flattened [64*2]
    int p_nP = 3;
    int p_nO = 64;
    int p_nI = 2;
    int p_nD = 64;
    bool has_pre = false;

    // Classifier (hidden -> n_actions) and the label of each action.
    std::vector<float> cW;
    std::vector<float> cB;
    int nAct = 0;
    bool has_cls = false;
    std::vector<std::string> actLbl;
};
// =========================================================================
// Load model
// =========================================================================
static bool load(const std::string& dir, State* s) {
std::ifstream cf(dir+"/model.ckpt"); if(!cf){std::cerr<<"No model.ckpt\n";return false;}
std::stringstream cb;cb<<cf.rdbuf();
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
size_t bz=bf.tellg();bf.seekg(0); std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
auto sl=[&](int64_t o, int64_t c)->std::vector<float>{
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
};
auto ld=[&](const std::string& k, std::vector<float>* v, int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
auto* e=ck.get(k); if(!e)return false;
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
if(!sv||!ov||!cv)return false;
*v=sl(ov->as_i64(),cv->as_i64());
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
return!v->empty();
};
// HashEmbeds — dynamic count (6 for en, 5 for zh, etc.)
for(int ei=0;;ei++){
auto* e=ck.get("embed_"+std::to_string(ei)+"_E"); if(!e)break;
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
if(!sv||!ov||!cv)break;
int rs=sv->arr[0].as_int(),nO=sv->arr[1].as_int();
int64_t expected=(int64_t)rs*nO;
if(cv->as_i64()<expected)break; // count too short → malformed
auto d=sl(ov->as_i64(),cv->as_i64()); if(d.empty())break;
s->embeds.emplace_back(); s->embeds.back().load(rs,nO,d.data());
}
// Seeds
std::ifstream ff(dir+"/feature_config.json");
if(ff){std::stringstream fb;fb<<ff.rdbuf();auto cfg=JParser().parse(fb.str());auto* sa=cfg.get("embed_seeds");
if(sa&&sa->type==JVal::ARR)for(int i=0;i<(int)sa->arr.size()&&i<(int)s->embeds.size();i++)s->embeds[i].seed=(uint32_t)sa->arr[i].as_int();}
int r0=0,r1=0,r2=0;
// Post-embed
if(ld("poW",&s->poW,&r0,&r1,&r2)){s->po_nO=r0;s->po_nP=r1;s->po_nI=r2;ld("poB",&s->poB);}
if(ld("poG",&s->poG)){ld("poB2",&s->poB2);s->has_poLN=true;}
// Residual
for(int ri=0;ri<4;ri++){auto pk="res"+std::to_string(ri);auto& rb=s->res[ri];
if(ld(pk+"W",&rb.W,&r0,&r1,&r2)){ld(pk+"B",&rb.b);ld(pk+"lnG",&rb.lnG);ld(pk+"lnb",&rb.lnb);rb.has=true;s->n_res++;}}
// NER hidden
if(ld("hW",&s->hW,&r0,&r1)){s->hO=r0;ld("hB",&s->hB);s->has_hid=true;}
// PrecomputableAffine: load full 4D W and 2D b
// has_pre is only set when ALL of weight buffer, bias buffer,
// and hidden-dimension match (p_nD == hO) are satisfied.
{
auto* e=ck.get("pW_full"); if(e){
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
if(sv&&ov&&cv){
int nP=sv->arr.size()>=1?sv->arr[0].as_int():1;
int nO=sv->arr.size()>=2?sv->arr[1].as_int():1;
int nI=sv->arr.size()>=3?sv->arr[2].as_int():1;
int nD=sv->arr.size()>=4?sv->arr[3].as_int():1;
s->p_nP=nP; s->p_nO=nO; s->p_nI=nI; s->p_nD=nD;
size_t total = (size_t)nP * nO * nI * nD;
s->pW_full = sl(ov->as_i64(), cv->as_i64());
bool pw_ok = s->pW_full.size() >= total;
// Load bias inside pW_full block to access dimension info
bool pb_ok = false;
if(auto* pb_e=ck.get("pB_full")){
auto pb_ov=pb_e->get("offset"),pb_cv=pb_e->get("count");
if(pb_ov&&pb_cv){
s->pB_full = sl(pb_ov->as_i64(), pb_cv->as_i64());
pb_ok = s->pB_full.size() >= (size_t)nO * nI;
}
}
bool dim_ok = (nD == s->hO);
s->has_pre = pw_ok && pb_ok && dim_ok;
}
}
}
// Classifier
if(ld("cW",&s->cW,&r0,&r1)){s->nAct=r0;ld("cB",&s->cB);s->has_cls=true;}
return!s->embeds.empty();
}
// Fills s->actLbl from the "action_to_label" object in <dir>/labels.json.
// Keys are action ids written as decimal strings and values are label names;
// keys std::stoi cannot parse are ignored. The table is sized to s->nAct when
// that is positive, otherwise to the highest id found plus one, and slots with
// no entry are left as "O". Returns false when the file is missing or the
// object is absent.
static bool load_labels(const std::string& dir, State* s) {
    std::ifstream in(dir + "/labels.json");
    if (!in) return false;
    std::stringstream text;
    text << in.rdbuf();
    auto root = JParser().parse(text.str());
    auto* table = root.get("action_to_label");
    if (!table || table->type != JVal::OBJ) return false;

    // First pass: find the largest action id so the table can be sized.
    int highest = 0;
    for (const auto& entry : table->obj) {
        try {
            const int id = std::stoi(entry.first);
            if (id > highest) highest = id;
        } catch (...) {
        }
    }

    const int count = s->nAct > 0 ? s->nAct : highest + 1;
    s->actLbl.resize(count, "O");

    // Second pass: store every label whose id falls inside the table.
    for (const auto& entry : table->obj) {
        try {
            const int id = std::stoi(entry.first);
            if (id >= 0 && id < count) s->actLbl[id] = entry.second.str;
        } catch (...) {
        }
    }
    return !s->actLbl.empty();
}
// =========================================================================
// C API
// =========================================================================
// Creates an NER inference handle from the model directory `d`.
// The second argument (vocab directory) is accepted but not used.
// Returns nullptr when `d` is null, when the model cannot be loaded, or when
// loading throws (no exception is allowed to cross the C boundary).
// When labels.json is unavailable every action falls back to the label "O";
// the fallback table covers at least 74 actions and always all nAct actions,
// so ThincNER_Predict can index it with any action id.
ThincNERHandle ThincNER_Create(const char* d, const char*) {
    if (!d) return nullptr;  // std::string cannot be built from a null pointer
    State* s = nullptr;
    try {
        s = new State();
        if (!load(d, s)) { delete s; return nullptr; }
        if (!load_labels(d, s)) s->actLbl.resize(s->nAct > 74 ? s->nAct : 74, "O");
        return s;
    } catch (...) {
        delete s;
        return nullptr;
    }
}
// Releases a handle returned by ThincNER_Create; a null handle is a no-op.
void ThincNER_Destroy(ThincNERHandle h) {
    State* state = static_cast<State*>(h);
    delete state;
}
// Runs NER over a JSON array of token strings and returns a JSON array of
// entities ({"text","label","start","end","confidence"}) allocated with
// strdup; release it with ThincNER_FreeString.
//
// Returns "[]" when the handle or input is null, when the input is not a JSON
// array, when it contains no non-empty string tokens, or when the loaded
// post-embed weights are too small for the embedding layout.
//
// Pipeline: HashEmbed concat → Maxout → LayerNorm → residual window blocks →
// linear → PrecomputableAffine + Maxout + classifier with constrained greedy
// BILUO decoding. When the classifier stage is unavailable or inconsistent,
// every token is labelled "O".
char* ThincNER_Predict(ThincNERHandle h, const char* tj) {
    auto* s = (State*)h;
    if (!s || !tj) return strdup("[]");

    // ---- Parse tokens ----
    // Use the JSON parser (not a raw quote scan) so tokens containing escaped
    // quotes or backslashes are decoded correctly. Empty tokens are dropped.
    JVal doc = JParser().parse(std::string(tj));
    if (doc.type != JVal::ARR) return strdup("[]");
    std::vector<std::string> tok;
    for (const auto& v : doc.arr)
        if (v.type == JVal::STR && !v.str.empty()) tok.push_back(v.str);
    int n = (int)tok.size();
    if (!n) return strdup("[]");

    int NE = (int)s->embeds.size();
    // Derive per-embed dimension from loaded tensors (all embed tables share the same nO)
    int D = NE > 0 ? s->embeds[0].nO : 96;
    int EC = NE * D;
    // The post-embed Maxout writes po_nO floats per token into a D-wide row
    // and reads po_nI floats from an EC-wide row; refuse shapes that would
    // step outside those rows or outside the weight buffers.
    if (D <= 0 || s->po_nO <= 0 || s->po_nP <= 0 || s->po_nI <= 0 ||
        s->po_nO > D || s->po_nI > EC ||
        s->poW.size() < (size_t)s->po_nO * s->po_nP * s->po_nI ||
        s->poB.size() < (size_t)s->po_nO * s->po_nP)
        return strdup("[]");

    // ---- Step 1: HashEmbed → concat (NER model: 4×96=384, pipe: 6×96=576) ----
    std::vector<float> emb((size_t)n * EC, 0);
    for (int i = 0; i < n; i++) {
        auto ids = extract_features(tok[i], NE);
        size_t b = (size_t)i * EC;
        // Never read past the feature ids actually produced for this token.
        for (int e = 0; e < NE && e < (int)ids.size(); e++)
            s->embeds[e].embed(ids[e], emb.data() + b + (size_t)e * s->embeds[e].nO);
    }

    // ---- Step 2: Post-embed Maxout (576→96) ----
    std::vector<float> pe((size_t)n * D);
    for (int i = 0; i < n; i++)
        maxout(pe.data() + (size_t)i * D, emb.data() + (size_t)i * EC,
               s->poW.data(), s->poB.data(), s->po_nO, s->po_nP, s->po_nI);

    // ---- Step 3: Post-embed LayerNorm ----
    std::vector<float> pln((size_t)n * D, 0);
    if (s->has_poLN) {
        for (int i = 0; i < n; i++)
            layernorm(pln.data() + (size_t)i * D, pe.data() + (size_t)i * D, D,
                      s->poG.data(), s->poB2.data(), 1e-6f);
    } else {
        pln = pe;
    }

    // ---- Step 4: Residual encoder blocks ----
    std::vector<float> enc = pln;
    for (int ri = 0; ri < s->n_res; ri++) {
        auto& blk = s->res[ri];
        if (!blk.has) continue;
        int wd = D * 3;
        std::vector<float> exp((size_t)n * wd);
        for (int i = 0; i < n; i++) expand_win(exp.data() + (size_t)i * wd, enc.data(), n, D, i);
        std::vector<float> mx((size_t)n * D);
        for (int i = 0; i < n; i++)
            maxout(mx.data() + (size_t)i * D, exp.data() + (size_t)i * wd,
                   blk.W.data(), blk.b.data(), D, 3, wd);
        std::vector<float> ln((size_t)n * D);
        if (!blk.lnG.empty()) {
            for (int i = 0; i < n; i++)
                layernorm(ln.data() + (size_t)i * D, mx.data() + (size_t)i * D, D,
                          blk.lnG.data(), blk.lnb.data(), 1e-6f);
        } else {
            ln = mx;
        }
        for (int i = 0; i < n; i++) {
            float* op = enc.data() + (size_t)i * D;
            for (int j = 0; j < D; j++) op[j] += ln[(size_t)i * D + j];
        }
    }

    // ---- Step 5: NER tok2vec linear (96→64, no ReLU) ----
    // The NER model's layers[0] ends with a bare linear (no activation).
    // This produces the token vectors that feed into the PrecomputableAffine.
    std::vector<float> hid((size_t)n * s->hO);
    if (s->has_hid) {
        for (int i = 0; i < n; i++)
            linear(hid.data() + (size_t)i * s->hO, enc.data() + (size_t)i * D,
                   s->hW.data(), s->hB.data(), s->hO, D);
    } else {
        for (int i = 0; i < n; i++) {
            int c = std::min(D, s->hO);
            memcpy(hid.data() + i * s->hO, enc.data() + i * D, c * 4);
        }
    }

    // ---- Step 6: PrecomputableAffine → Maxout → Classifier → constrained decoding ----
    // Matches the spaCy ParserStepModel's predict_states formula:
    //   unmaxed[o*nP+p] = sum_f W[f,o,p,:] @ hid[ft[f]]   (W = [nF, nO, nP, nD])
    //   unmaxed        += bias[o,p]                        (added once, not nF times)
    //   hid_vec[o]      = max_p unmaxed[o*nP+p]
    //   scores[a]       = cW[a][:] @ hid_vec + cB[a]
    //
    // Feature token indices match spaCy's BiluoPushDown transition system:
    //   f=0: B(0) = buffer front = current token index
    //   f=1: S(0) = stack top = entity_start if in entity, else back-off to B(0)
    //   f=2: S(1) = stack second = back-off to B(0) (stack has ≤1 item in simple case)
    auto label_type = [](const std::string& lbl) -> char { return lbl.empty() ? 'O' : lbl[0]; };
    auto label_etype = [](const std::string& lbl) -> std::string { return lbl.size() < 3 ? "" : lbl.substr(2); };
    // Label for action `a`; actions beyond the label table count as "O".
    auto label_of = [&](int a) -> std::string {
        return (a >= 0 && a < (int)s->actLbl.size()) ? s->actLbl[a] : std::string("O");
    };
    std::vector<std::string> tl(n, "O");

    // Only three feature slots (ft[0..2]) exist, and the classifier buffers
    // must cover nAct rows of p_nO weights.
    const bool can_classify =
        s->has_cls && s->has_pre &&
        s->p_nP >= 1 && s->p_nP <= 3 && s->p_nO > 0 && s->p_nI > 0 && s->p_nD > 0 &&
        s->nAct > 0 &&
        s->cW.size() >= (size_t)s->nAct * s->p_nO &&
        s->cB.size() >= (size_t)s->nAct;
    if (can_classify) {
        int nF = s->p_nP, nO = s->p_nO, nP = s->p_nI, nD = s->p_nD;  // W: [nF, nO, nP, nD]
        const size_t hid_stride = (size_t)s->hO;  // row width of `hid` (== nD when has_pre)
        std::vector<float> unmaxed((size_t)nO * nP, 0);
        std::vector<float> hid_vec(nO, 0);
        std::vector<float> scores(s->nAct, 0);
        int entity_start = -1;  // token index of current B-entity start, -1 = no entity
        for (int i = 0; i < n; i++) {
            // Determine feature token indices from transition state
            // (matches spaCy BiluoPushDown B(0)/S(0)/S(1) mapping)
            int ft[3];
            ft[0] = i;                                   // B(0) = current token
            ft[1] = entity_start >= 0 ? entity_start : i; // S(0) = entity start or back-off
            ft[2] = i;                                   // S(1) = back-off to B(0)

            // PrecomputableAffine: accumulate W[f][o][p][:] @ hid[ft[f]] over features.
            memset(unmaxed.data(), 0, (size_t)nO * nP * sizeof(float));
            for (int f = 0; f < nF; f++) {
                // Rows of `hid` are hO floats wide, not nO.
                const float* hf = hid.data() + (size_t)ft[f] * hid_stride;
                for (int o = 0; o < nO; o++) {
                    for (int p = 0; p < nP; p++) {
                        size_t base = (((size_t)f * nO + o) * nP + p) * nD;
                        float val = 0;
                        for (int d = 0; d < nD; d++) val += s->pW_full[base + d] * hf[d];
                        unmaxed[(size_t)o * nP + p] += val;
                    }
                }
            }
            // Add bias ONCE (not nF times)
            for (int o = 0; o < nO; o++)
                for (int p = 0; p < nP; p++)
                    unmaxed[(size_t)o * nP + p] += s->pB_full[(size_t)o * nP + p];
            // Maxout: hid_vec[o] = max_p unmaxed[o*nP + p]
            for (int o = 0; o < nO; o++) {
                float best = unmaxed[(size_t)o * nP];
                for (int p = 1; p < nP; p++) {
                    float v = unmaxed[(size_t)o * nP + p];
                    if (v > best) best = v;
                }
                hid_vec[o] = best;
            }
            // Classifier: scores = cW @ hid_vec + cB
            linear(scores.data(), hid_vec.data(), s->cW.data(), s->cB.data(), s->nAct, nO);

            // Constrained greedy decoding
            char prev_type = i > 0 ? label_type(tl[i - 1]) : 'O';
            std::string prev_etype = i > 0 ? label_etype(tl[i - 1]) : "";
            int bst = -1;
            float bv = -1e30f;
            for (int a = 0; a < s->nAct; a++) {
                const std::string lbl = label_of(a);
                if (lbl.empty()) continue;
                char ct = label_type(lbl);
                std::string ce = label_etype(lbl);
                bool valid = false;
                if (prev_type == 'O' || prev_type == 'L' || prev_type == 'U') {
                    valid = (ct == 'O' || ct == 'B' || ct == 'U');
                } else if (prev_type == 'B' || prev_type == 'I') {
                    if (ct == 'O') valid = true;
                    else if ((ct == 'I' || ct == 'L') && ce == prev_etype) valid = true;
                }
                if (!valid) continue;
                if (scores[a] > bv) { bv = scores[a]; bst = a; }
            }
            if (bst >= 0) {
                tl[i] = label_of(bst);
                // Update entity_start for next token (BiluoPushDown stack tracking)
                char ct = label_type(tl[i]);
                if (ct == 'B') entity_start = i;
                else if (ct == 'I' || ct == 'L') { /* entity continues, keep entity_start */ }
                else entity_start = -1;  // O or U → no entity open
            }
        }
    }

    // ---- Step 8: BILUO decode ----
    auto ents = decode_biluo(tok, tl);

    // Escapes a string for embedding between double quotes in JSON.
    auto json_escape = [](const std::string& raw) {
        static const char kHex[] = "0123456789abcdef";
        std::string out;
        out.reserve(raw.size() + 2);
        for (unsigned char c : raw) {
            switch (c) {
                case '"':  out += "\\\""; break;
                case '\\': out += "\\\\"; break;
                case '\n': out += "\\n"; break;
                case '\r': out += "\\r"; break;
                case '\t': out += "\\t"; break;
                case '\b': out += "\\b"; break;
                case '\f': out += "\\f"; break;
                default:
                    if (c < 0x20) { out += "\\u00"; out += kHex[c >> 4]; out += kHex[c & 15]; }
                    else out += (char)c;
            }
        }
        return out;
    };

    std::string r = "[";
    for (size_t i = 0; i < ents.size(); i++) {
        if (i) r += ",";
        r += "{\"text\":\"" + json_escape(ents[i].text) +
             "\",\"label\":\"" + json_escape(ents[i].label) +
             "\",\"start\":" + std::to_string(ents[i].start) +
             ",\"end\":" + std::to_string(ents[i].end) +
             ",\"confidence\":" + std::to_string(ents[i].conf) + "}";
    }
    r += "]";
    return strdup(r.c_str());
}
// Releases a string handed out by this API. The strings are allocated with
// strdup, so they are returned to the C allocator; a null pointer is accepted.
void ThincNER_FreeString(char* p) {
    std::free(p);
}
// Tokenizes `t` and returns a JSON array of token strings allocated with
// strdup (release with ThincNER_FreeString). Returns "[]" when `t` is null.
// `l` selects the tokenizer: "zh" and "ja" use the CJK tokenizer; every other
// value, including null, uses the English one (de, fr, es, pt are handled as
// European Latin text).
// Tokens are JSON-escaped, so text containing quotes, backslashes or control
// characters still yields a valid JSON document.
char* ThincNER_Tokenize(const char* t, const char* l) {
    if (!t) return strdup("[]");
    std::string lang = l ? std::string(l) : "";
    bool is_cjk = (lang == "zh" || lang == "ja");
    auto tok = is_cjk ? tokenize_zh(t) : tokenize_en(t);

    // Escapes a string for embedding between double quotes in JSON.
    auto json_escape = [](const std::string& raw) {
        static const char kHex[] = "0123456789abcdef";
        std::string out;
        out.reserve(raw.size() + 2);
        for (unsigned char c : raw) {
            switch (c) {
                case '"':  out += "\\\""; break;
                case '\\': out += "\\\\"; break;
                case '\n': out += "\\n"; break;
                case '\r': out += "\\r"; break;
                case '\t': out += "\\t"; break;
                case '\b': out += "\\b"; break;
                case '\f': out += "\\f"; break;
                default:
                    if (c < 0x20) { out += "\\u00"; out += kHex[c >> 4]; out += kHex[c & 15]; }
                    else out += (char)c;
            }
        }
        return out;
    };

    std::string r = "[";
    for (size_t i = 0; i < tok.size(); i++) {
        if (i) r += ",";
        r += "\"" + json_escape(tok[i]) + "\"";
    }
    r += "]";
    return strdup(r.c_str());
}

45
internal/cpp/thinc_ner.h Normal file
View File

@@ -0,0 +1,45 @@
#ifndef THINC_NER_H
#define THINC_NER_H
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stdbool.h>
// ---------------------------------------------------------------------------
// C API for spaCy model inference (en_core_web_sm / zh_core_web_sm)
//
// Loads model.ckpt + model.bin directly from a spaCy model directory.
// feature_config.json and labels.json in the same directory are read when
// present.
// ---------------------------------------------------------------------------
// Opaque handle to a loaded NER model.
typedef void* ThincNERHandle;
// Create / destroy an inference handle for a single spaCy model.
//   model_ner_dir: path to the model component directory, e.g.
//     "models/en_core_web_sm-3.7.1/ner/"
//     "models/zh_core_web_sm-3.7.1/ner/"
//   model_vocab_dir: accepted for interface compatibility; the current
//     implementation does not read it.
// Returns NULL on failure.
ThincNERHandle ThincNER_Create(const char* model_ner_dir, const char* model_vocab_dir);
// Releases a handle returned by ThincNER_Create.
void ThincNER_Destroy(ThincNERHandle handle);
// Run NER on pre-tokenized text.
//   tokens_json: JSON array of token strings, e.g. ["Apple", "Inc.", "was", ...]
// Returns JSON array of entities:
//   [{"text":"Apple Inc.","label":"ORG","start":0,"end":10,"confidence":0.85}, ...]
// Returns "[]" when handle or tokens_json is NULL or no tokens are found.
// Caller must free with ThincNER_FreeString.
char* ThincNER_Predict(ThincNERHandle handle, const char* tokens_json);
// Free a string returned by ThincNER_Predict or ThincNER_Tokenize.
void ThincNER_FreeString(char* ptr);
// Utility: tokenize text using spaCy-compatible rules.
//   lang: "zh" or "ja" selects the CJK tokenizer; any other value (or NULL)
//     selects the English tokenizer.
// Returns JSON array of token strings ("[]" when text is NULL).
// Caller must free with ThincNER_FreeString.
char* ThincNER_Tokenize(const char* text, const char* lang);
#ifdef __cplusplus
}
#endif
#endif // THINC_NER_H

View File

@@ -0,0 +1,530 @@
#include "thinc_parser.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
// =========================================================================
// Minimal JSON parser (replicated from thinc_ner.cpp)
// =========================================================================
namespace {
// JSON value node. `type` says which payload is meaningful: `obj` for OBJ,
// `arr` for ARR, `str` for STR (and the literal text for BOOL), `num` for NUM.
struct JVal {
    enum Type { NUL, OBJ, ARR, STR, NUM, BOOL } type = NUL;
    std::string str;
    std::vector<JVal> arr;
    std::unordered_map<std::string, JVal> obj;
    double num = 0;
    // Member lookup; nullptr when the key is absent (or this is not an object).
    const JVal* get(const std::string& k) const {
        auto it = obj.find(k);
        return it != obj.end() ? &it->second : nullptr;
    }
    int as_int() const { return (int)num; }
    int64_t as_i64() const { return (int64_t)num; }
};

// Minimal, lenient recursive-descent JSON parser. It never throws on
// malformed input; it returns whatever could be read up to that point.
struct JParser {
    const char *p = nullptr, *e = nullptr;  // cursor and end of input

    // Skips whitespace and peeks at the next character (0 at end of input).
    char pk() {
        while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p;
        return p < e ? *p : 0;
    }
    // Skips whitespace and consumes the next character (0 at end of input).
    char nx() {
        while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p;
        return p < e ? *p++ : 0;
    }
    // Parses any value, dispatching on its first character.
    JVal pv() {
        char c = pk();
        if (c == '{') return po();
        if (c == '[') return pa();
        if (c == '"') return ps();
        if (c == 't' || c == 'f') return pb();
        if (c == 'n') { nx(); nx(); nx(); nx(); return JVal{}; }  // "null"
        return pn();
    }
    // Parses an object.
    JVal po() {
        JVal v;
        v.type = JVal::OBJ;
        nx();  // '{'
        while (p < e && pk() != '}') {
            auto k = ps();
            nx();  // ':'
            v.obj[k.str] = pv();
            if (p < e && pk() == ',') nx();
            else break;
        }
        if (p < e) nx();  // '}'
        return v;
    }
    // Parses an array.
    JVal pa() {
        JVal v;
        v.type = JVal::ARR;
        nx();  // '['
        while (p < e && pk() != ']') {
            v.arr.push_back(pv());
            if (p < e && pk() == ',') nx();
            else break;
        }
        if (p < e) nx();  // ']'
        return v;
    }
    // Reads four hex digits at q; -1 when fewer than four remain before `end`
    // or any of them is not a hex digit.
    static int hex4(const char* q, const char* end) {
        if (end - q < 4) return -1;
        int value = 0;
        for (int i = 0; i < 4; i++) {
            const char c = q[i];
            int digit;
            if (c >= '0' && c <= '9') digit = c - '0';
            else if (c >= 'a' && c <= 'f') digit = c - 'a' + 10;
            else if (c >= 'A' && c <= 'F') digit = c - 'A' + 10;
            else return -1;
            value = value * 16 + digit;
        }
        return value;
    }
    // Appends code point `cp` to `out` as UTF-8; an unpaired surrogate is
    // replaced by U+FFFD.
    static void put_utf8(std::string& out, unsigned cp) {
        if (cp >= 0xD800 && cp <= 0xDFFF) cp = 0xFFFD;
        if (cp < 0x80) {
            out += (char)cp;
        } else if (cp < 0x800) {
            out += (char)(0xC0 | (cp >> 6));
            out += (char)(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            out += (char)(0xE0 | (cp >> 12));
            out += (char)(0x80 | ((cp >> 6) & 0x3F));
            out += (char)(0x80 | (cp & 0x3F));
        } else {
            out += (char)(0xF0 | (cp >> 18));
            out += (char)(0x80 | ((cp >> 12) & 0x3F));
            out += (char)(0x80 | ((cp >> 6) & 0x3F));
            out += (char)(0x80 | (cp & 0x3F));
        }
    }
    // Parses a string. \uXXXX escapes (including surrogate pairs) are decoded
    // to UTF-8 instead of being truncated to a single byte.
    JVal ps() {
        JVal v;
        v.type = JVal::STR;
        nx();  // opening quote
        while (p < e && *p != '"') {
            if (*p != '\\') { v.str += *p++; continue; }
            ++p;  // backslash
            if (p >= e) break;
            switch (*p) {
                case '"': case '\\': case '/': v.str += *p++; break;
                case 'n': v.str += '\n'; ++p; break;
                case 't': v.str += '\t'; ++p; break;
                case 'r': v.str += '\r'; ++p; break;
                case 'b': v.str += '\b'; ++p; break;
                case 'f': v.str += '\f'; ++p; break;
                case 'u': {
                    int cp = hex4(p + 1, e);
                    if (cp < 0) { ++p; break; }  // malformed escape: drop the 'u'
                    p += 5;
                    // A high surrogate followed by \uDC00..\uDFFF forms one code point.
                    if (cp >= 0xD800 && cp <= 0xDBFF && e - p >= 6 && p[0] == '\\' && p[1] == 'u') {
                        const int lo = hex4(p + 2, e);
                        if (lo >= 0xDC00 && lo <= 0xDFFF) {
                            cp = 0x10000 + ((cp - 0xD800) << 10) + (lo - 0xDC00);
                            p += 6;
                        }
                    }
                    put_utf8(v.str, (unsigned)cp);
                    break;
                }
                default: v.str += *p++; break;
            }
        }
        if (p < e) ++p;  // closing quote
        return v;
    }
    // Parses a number; text std::stod cannot convert yields 0.
    JVal pn() {
        JVal v;
        v.type = JVal::NUM;
        auto s = p;
        if (p < e && *p == '-') ++p;
        while (p < e && (*p >= '0' && *p <= '9')) ++p;
        if (p < e && *p == '.') { ++p; while (p < e && (*p >= '0' && *p <= '9')) ++p; }
        if (p < e && (*p == 'e' || *p == 'E')) {
            ++p;
            if (p < e && (*p == '+' || *p == '-')) ++p;
            while (p < e && (*p >= '0' && *p <= '9')) ++p;
        }
        if (s < p) {
            try { v.num = std::stod(std::string(s, p - s)); } catch (...) { v.num = 0; }
        }
        return v;
    }
    // Parses true/false; the literal text is stored in `str`.
    JVal pb() {
        JVal v;
        v.type = JVal::BOOL;
        if (e - p >= 4 && *p == 't') { v.str = "true"; p += 4; }
        else if (e - p >= 5 && *p == 'f') { v.str = "false"; p += 5; }
        return v;
    }
    // Parses `j` and returns its root value. `j` must outlive the call only.
    JVal parse(const std::string& j) {
        p = j.data();
        e = p + j.size();
        return pv();
    }
};
// =========================================================================
// HashEmbed + MurmurHash (copied from thinc_ner.cpp)
// =========================================================================
// 64-bit left rotation (r must be in 1..63).
#define ROTL64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))

// Reads the i-th 64-bit word starting at p without assuming alignment.
static uint64_t getblock64(const uint64_t* p, size_t i) {
    uint64_t word;
    memcpy(&word, p + i, sizeof(word));
    return word;
}

// MurmurHash3 64-bit finalization mix.
static uint64_t fmix64(uint64_t k) {
    k ^= k >> 33;
    k *= 0xff51afd7ed558ccdULL;
    k ^= k >> 33;
    k *= 0xc4ceb9fe1a85ec53ULL;
    k ^= k >> 33;
    return k;
}

// MurmurHash3 x64 128-bit hash of `len` bytes at `key`, written to out[0..3]
// as four 32-bit words (low/high halves of h1, then of h2).
static void mmh3_x64_128(const void* key, int len, uint32_t seed, uint32_t out[4]) {
    const uint64_t c1 = 0x87c37b91114253d5ULL;
    const uint64_t c2 = 0x4cf5ad432745937fULL;
    const uint8_t* bytes = (const uint8_t*)key;
    const int block_count = len / 16;
    uint64_t h1 = seed;
    uint64_t h2 = seed;

    // Per-lane scrambles shared by the body and the tail.
    auto scramble1 = [&](uint64_t k) { k *= c1; k = ROTL64(k, 31); k *= c2; return k; };
    auto scramble2 = [&](uint64_t k) { k *= c2; k = ROTL64(k, 33); k *= c1; return k; };

    // Body: two 64-bit lanes per 16-byte block.
    const uint64_t* words = (const uint64_t*)bytes;
    for (int blk = 0; blk < block_count; ++blk) {
        const uint64_t lane1 = getblock64(words, blk * 2 + 0);
        const uint64_t lane2 = getblock64(words, blk * 2 + 1);
        h1 ^= scramble1(lane1);
        h1 = ROTL64(h1, 27);
        h1 += h2;
        h1 = h1 * 5 + 0x52dce729;
        h2 ^= scramble2(lane2);
        h2 = ROTL64(h2, 31);
        h2 += h1;
        h2 = h2 * 5 + 0x38495ab5;
    }

    // Tail: bytes 8..14 feed the second lane, bytes 0..7 the first.
    const uint8_t* tail = (const uint8_t*)(bytes + block_count * 16);
    const int rem = len & 15;
    uint64_t t1 = 0;
    uint64_t t2 = 0;
    for (int i = rem - 1; i >= 8; --i) t2 ^= ((uint64_t)tail[i]) << ((i - 8) * 8);
    if (rem > 8) h2 ^= scramble2(t2);
    for (int i = (rem < 8 ? rem : 8) - 1; i >= 0; --i) t1 ^= ((uint64_t)tail[i]) << (i * 8);
    if (rem > 0) h1 ^= scramble1(t1);

    // Finalization.
    h1 ^= len;
    h2 ^= len;
    h1 += h2;
    h2 += h1;
    h1 = fmix64(h1);
    h2 = fmix64(h2);
    h1 += h2;
    h2 += h1;
    out[0] = (uint32_t)h1;
    out[1] = (uint32_t)(h1 >> 32);
    out[2] = (uint32_t)h2;
    out[3] = (uint32_t)(h2 >> 32);
}
// MurmurHash2 (64-bit "A" variant) of `len` bytes at `key`.
static uint64_t mh2_64a(const void* key, int len, uint64_t seed) {
    const uint64_t mul = 0xc6a4a7935bd1e995ULL;
    const int shift = 47;
    const uint8_t* cur = (const uint8_t*)key;
    uint64_t h = seed ^ (uint64_t(len) * mul);

    // Whole 8-byte words.
    int left = len;
    for (; left >= 8; left -= 8, cur += 8) {
        uint64_t k;
        memcpy(&k, cur, 8);
        k *= mul;
        k ^= k >> shift;
        k *= mul;
        h ^= k;
        h *= mul;
    }

    // Up to seven trailing bytes, folded in little-endian order.
    if (left > 0) {
        for (int i = left - 1; i >= 0; --i) h ^= uint64_t(cur[i]) << (8 * i);
        h *= mul;
    }

    h ^= h >> shift;
    h *= mul;
    h ^= h >> shift;
    return h;
}

// Hash of a feature string with seed 0; the empty string maps to 0.
static uint64_t hash_feat(const std::string& s) {
    if (s.empty()) return 0;
    return mh2_64a(s.data(), (int)s.size(), 0);
}
struct HashEmbed{
int n_rows=0,nO=0;uint32_t seed=0;std::vector<float> table;
bool load(int r,int o,const float*d){n_rows=r;nO=o;table.assign(d,d+(size_t)r*o);return!table.empty();}
void embed(uint64_t fid,float* out)const{
uint8_t in[8];for(int i=0;i<8;i++)in[i]=(uint8_t)(fid>>(i*8));
uint32_t keys[4];mmh3_x64_128(in,8,seed,keys);
for(int v=0;v<4;v++){int idx=(int)(keys[v]%(uint32_t)n_rows);for(int i=0;i<nO;i++)out[i]+=table[(size_t)idx*nO+i];}
}
};
// =========================================================================
// Layer primitives
// =========================================================================
// Affine layer: out[r] = b[r] + W[r,:] · in, with W stored row-major as
// nO rows of nI weights.
static void linear(float* out, const float* in, const float* W, const float* b, int nO, int nI) {
    for (int row = 0; row < nO; ++row) {
        const float* weights = W + (size_t)row * nI;
        float acc = b[row];
        for (int col = 0; col < nI; ++col) {
            acc += weights[col] * in[col];
        }
        out[row] = acc;
    }
}
// Maxout layer: each output unit is the largest of its nP affine pieces.
// W is laid out as [nO][nP][nI] and b as [nO][nP]. The running maximum starts
// at -1e30f.
static void maxout(float* out, const float* in, const float* W, const float* b, int nO, int nP, int nI) {
    for (int unit = 0; unit < nO; ++unit) {
        float top = -1e30f;
        for (int piece = 0; piece < nP; ++piece) {
            const size_t slot = (size_t)unit * nP + piece;
            const float* weights = W + slot * nI;
            float acc = b[slot];
            for (int col = 0; col < nI; ++col) {
                acc += weights[col] * in[col];
            }
            if (acc > top) top = acc;
        }
        out[unit] = top;
    }
}
// Layer normalization over d values: subtract the mean, divide by
// sqrt(variance + eps) using the population variance, then scale by G and
// shift by b.
static void layernorm(float* out, const float* in, int d, const float* G, const float* b, float eps) {
    float mean = 0;
    for (int i = 0; i < d; ++i) {
        mean += in[i];
    }
    mean /= d;

    float variance = 0;
    for (int i = 0; i < d; ++i) {
        variance += (in[i] - mean) * (in[i] - mean);
    }
    variance /= d;

    const float inv_std = 1.0f / sqrtf(variance + eps);
    for (int i = 0; i < d; ++i) {
        out[i] = G[i] * (in[i] - mean) * inv_std + b[i];
    }
}
// Builds the 3-token context window for token `idx`: out receives
// [previous | current | next], each `dim` floats taken from `all` (n rows of
// `dim`). A neighbour outside the sequence is filled with zeros.
static void expand_win(float* out, const float* all, int n, int dim, int idx) {
    const size_t row_bytes = dim * sizeof(float);
    const int cur = idx * dim;

    // Left neighbour.
    if (idx > 0) {
        memcpy(out, all + (idx - 1) * dim, row_bytes);
    } else {
        memset(out, 0, row_bytes);
    }

    // Token itself.
    memcpy(out + dim, all + cur, row_bytes);

    // Right neighbour.
    float* right = out + 2 * dim;
    if (idx < n - 1) {
        memcpy(right, all + (idx + 1) * dim, row_bytes);
    } else {
        memset(right, 0, row_bytes);
    }
}
// =========================================================================
// Feature extraction — UTF-8 aware, matching spaCy
// =========================================================================
// Byte length of the UTF-8 sequence announced by lead byte `lead`; any byte
// that is not a valid lead byte counts as a single byte.
static size_t u8_seq_len(unsigned char lead) {
    if ((lead & 0x80) == 0) return 1;
    if ((lead & 0xE0) == 0xC0) return 2;
    if ((lead & 0xF0) == 0xE0) return 3;
    if ((lead & 0xF8) == 0xF0) return 4;
    return 1;
}

// First UTF-8 character of s ("" for an empty string). When the sequence is
// cut short by the end of the string, only its first byte is returned.
static std::string u8_first(const std::string& s) {
    if (s.empty()) return "";
    const size_t width = u8_seq_len((unsigned char)s[0]);
    return s.substr(0, width <= s.size() ? width : 1);
}

// Number of UTF-8 characters in s.
static size_t u8_len(const std::string& s) {
    size_t chars = 0;
    size_t at = 0;
    while (at < s.size()) {
        at += u8_seq_len((unsigned char)s[at]);
        ++chars;
    }
    return chars;
}

// Last `count` UTF-8 characters of s; the whole string when it has no more
// than `count` characters.
static std::string u8_last(const std::string& s, size_t count) {
    const size_t total = u8_len(s);
    if (total <= count) return s;
    size_t pos = 0;
    for (size_t skipped = 0; skipped < total - count; ++skipped) {
        pos += u8_seq_len((unsigned char)s[pos]);
    }
    return s.substr(pos);
}
// Computes the hash-embedding feature ids for token `t`, one per table:
//   [0] lower-cased form, [1] lower-cased first character,
//   [2] lower-cased last three characters, [3] word shape,
// followed by constant ids for the remaining tables (1, 0 when n_embed == 6,
// otherwise a single 0).
// The result always holds at least n_embed ids (extra slots are 0) so callers
// may index it with any table index below n_embed.
static std::vector<uint64_t> extract_features(const std::string& t, int n_embed) {
    // Byte-wise lower-casing. Each byte is passed to tolower as an unsigned
    // char: handing it a negative char (any UTF-8 byte >= 0x80 where char is
    // signed) is undefined behavior.
    auto lower = [](std::string s) {
        std::transform(s.begin(), s.end(), s.begin(),
                       [](unsigned char c) { return (char)::tolower(c); });
        return s;
    };
    // Word shape: upper → 'X', lower or non-ASCII byte → 'x', digit → 'd',
    // anything else is kept as is.
    auto shape = [](const std::string& word) {
        std::string sh;
        for (unsigned char c : word) {
            if (c > 0x7F) sh += 'x';
            else if (std::isupper(c)) sh += 'X';
            else if (std::islower(c)) sh += 'x';
            else if (std::isdigit(c)) sh += 'd';
            else sh += c;
        }
        return sh;
    };

    std::vector<uint64_t> ids;
    ids.push_back(hash_feat(lower(t)));
    ids.push_back(hash_feat(lower(t.empty() ? std::string() : u8_first(t))));
    ids.push_back(hash_feat(lower(u8_len(t) >= 3 ? u8_last(t, 3) : t)));
    ids.push_back(hash_feat(shape(t)));
    if (n_embed == 6) {
        ids.push_back(1);
        ids.push_back(0);
    } else {
        ids.push_back(0);
    }
    if (n_embed > 0 && ids.size() < (size_t)n_embed) ids.resize((size_t)n_embed, 0);
    return ids;
}
// =========================================================================
// Tok2vec forward pass (shared with NER)
// =========================================================================
// Pipeline tok2vec: HashEmbed concat → Maxout → optional LayerNorm → up to
// four residual window blocks. forward() always produces 96 floats per token,
// and load() only accepts checkpoints whose tensors are large enough for that
// fixed layout.
struct Tok2vecModel {
    std::vector<HashEmbed> embeds;
    std::vector<float> poW, poB, poG, poB2;
    bool has_poLN = false;
    int po_nO = 96, po_nP = 3, po_nI = 576;
    struct ResBlk { bool has = false; std::vector<float> W, b, lnG, lnb; };
    ResBlk res[4];
    int n_res = 0;

    // Loads <dir>/model.ckpt (JSON index of {shape, offset, count}; offsets
    // and counts are in float32 elements), <dir>/model.bin (float32 payload)
    // and, when present, <dir>/feature_config.json ("embed_seeds").
    // Returns false when a file is missing or malformed, no embedding table
    // is found, the post-embed Maxout tensors are missing, or any loaded
    // tensor is too small for the shapes forward() uses.
    bool load(const std::string& dir) {
        const int D = 96;  // token-vector width assumed by forward()
        std::ifstream cf(dir + "/model.ckpt");
        if (!cf) return false;
        std::stringstream cb;
        cb << cf.rdbuf();
        JVal ck = JParser().parse(cb.str());
        if (ck.type != JVal::OBJ) return false;

        std::ifstream bf(dir + "/model.bin", std::ios::binary | std::ios::ate);
        if (!bf) return false;
        const std::streamoff end_pos = bf.tellg();
        if (end_pos <= 0 || end_pos % 4 != 0) return false;
        const size_t bz = (size_t)end_pos;
        bf.seekg(0);
        std::vector<float> bin(bz / 4);
        if (!bf.read((char*)bin.data(), (std::streamsize)bz)) return false;

        // True when `v` holds at least `need` floats.
        auto fits = [](const std::vector<float>& v, int64_t need) {
            return need >= 0 && (int64_t)v.size() >= need;
        };
        // Copies `c` floats starting at element `o`; empty when out of range.
        auto sl = [&](int64_t o, int64_t c) -> std::vector<float> {
            const int64_t total = (int64_t)bin.size();
            if (o < 0 || c < 0 || o > total || c > total - o) return {};
            return std::vector<float>(bin.begin() + o, bin.begin() + o + c);
        };
        // Loads tensor `k` into *v and reports up to three leading shape
        // dimensions (missing dimensions read as 1).
        auto ld = [&](const std::string& k, std::vector<float>* v,
                      int* r0 = nullptr, int* r1 = nullptr, int* r2 = nullptr) -> bool {
            auto* e = ck.get(k);
            if (!e) return false;
            auto sv = e->get("shape"), ov = e->get("offset"), cv = e->get("count");
            if (!sv || !ov || !cv) return false;
            *v = sl(ov->as_i64(), cv->as_i64());
            if (r0) *r0 = sv->arr.size() >= 1 ? sv->arr[0].as_int() : 1;
            if (r1) *r1 = sv->arr.size() >= 2 ? sv->arr[1].as_int() : 1;
            if (r2) *r2 = sv->arr.size() >= 3 ? sv->arr[2].as_int() : 1;
            return !v->empty();
        };

        // Hash-embedding tables embed_0_E, embed_1_E, ... until one is missing.
        for (int ei = 0;; ei++) {
            auto* e = ck.get("embed_" + std::to_string(ei) + "_E");
            if (!e) break;
            auto sv = e->get("shape"), ov = e->get("offset"), cv = e->get("count");
            if (!sv || !ov || !cv) break;
            if (sv->arr.size() < 2) break;  // need [rows, width]
            int rs = sv->arr[0].as_int(), nO = sv->arr[1].as_int();
            if (rs <= 0 || nO <= 0) break;
            // forward() gives every table a D-wide slot; a wider table would
            // write past it.
            if (nO > D) return false;
            int64_t need = (int64_t)rs * nO;
            if (cv->as_i64() < need) break;
            auto d = sl(ov->as_i64(), cv->as_i64());
            if (d.empty()) break;
            embeds.emplace_back();
            embeds.back().load(rs, nO, d.data());
        }
        if (embeds.empty()) return false;
        const int64_t EC = (int64_t)embeds.size() * D;

        std::ifstream ff(dir + "/feature_config.json");
        if (ff) {
            std::stringstream fb;
            fb << ff.rdbuf();
            auto cfg = JParser().parse(fb.str());
            auto* sa = cfg.get("embed_seeds");
            if (sa && sa->type == JVal::ARR)
                for (int i = 0; i < (int)sa->arr.size() && i < (int)embeds.size(); i++)
                    embeds[i].seed = (uint32_t)sa->arr[i].as_int();
        }

        int r0 = 0, r1 = 0, r2 = 0;
        if (!ld("poW", &poW, &r0, &r1, &r2)) return false;
        po_nO = r0; po_nP = r1; po_nI = r2;
        if (!ld("poB", &poB)) return false;
        // forward() runs this Maxout as D outputs x po_nP pieces x EC inputs.
        if (po_nP <= 0 || po_nP > (1 << 16)) return false;
        if (!fits(poW, (int64_t)D * po_nP * EC) || !fits(poB, (int64_t)D * po_nP)) return false;
        if (ld("poG", &poG)) {
            if (!ld("poB2", &poB2) || !fits(poG, D) || !fits(poB2, D)) return false;
            has_poLN = true;
        }
        for (int ri = 0; ri < 4; ri++) {
            auto pk = "res" + std::to_string(ri);
            auto& rb = res[ri];
            if (!ld(pk + "W", &rb.W, &r0, &r1, &r2)) continue;
            if (!ld(pk + "B", &rb.b)) return false;
            // Residual Maxout: D outputs x 3 pieces x 3*D window inputs.
            if (!fits(rb.W, (int64_t)D * 3 * 3 * D) || !fits(rb.b, (int64_t)D * 3)) return false;
            const bool has_gain = ld(pk + "lnG", &rb.lnG);
            const bool has_bias = ld(pk + "lnb", &rb.lnb);
            if (has_gain && (!has_bias || !fits(rb.lnG, D) || !fits(rb.lnb, D))) return false;
            rb.has = true;
            n_res++;
        }
        return !embeds.empty();
    }

    // Run tok2vec → (n_tokens, 96). `out` must have room for
    // tokens.size() * 96 floats.
    void forward(const std::vector<std::string>& tokens, float* out) {
        int n = (int)tokens.size(), D = 96, NE = (int)embeds.size(), EC = NE * D;
        // Hash embeddings, one D-wide slot per table.
        std::vector<float> emb((size_t)n * EC, 0);
        for (int i = 0; i < n; i++) {
            auto ids = extract_features(tokens[i], NE);
            size_t b = (size_t)i * EC;
            // Never read past the feature ids actually produced for this token.
            for (int e = 0; e < NE && e < (int)ids.size(); e++)
                embeds[e].embed(ids[e], emb.data() + b + (size_t)e * D);
        }
        // Post-embed Maxout (EC → D).
        std::vector<float> pe((size_t)n * D);
        for (int i = 0; i < n; i++)
            maxout(pe.data() + (size_t)i * D, emb.data() + (size_t)i * EC,
                   poW.data(), poB.data(), D, po_nP, EC);
        // Optional LayerNorm.
        std::vector<float> pln((size_t)n * D, 0);
        if (has_poLN) {
            for (int i = 0; i < n; i++)
                layernorm(pln.data() + (size_t)i * D, pe.data() + (size_t)i * D, D,
                          poG.data(), poB2.data(), 1e-6f);
        } else {
            pln = pe;
        }
        // Residual blocks over a 3-token window.
        std::vector<float> enc = pln;
        for (int ri = 0; ri < n_res; ri++) {
            if (!res[ri].has) continue;
            int wd = D * 3;
            std::vector<float> exp((size_t)n * wd);
            for (int i = 0; i < n; i++) expand_win(exp.data() + (size_t)i * wd, enc.data(), n, D, i);
            std::vector<float> mx((size_t)n * D);
            for (int i = 0; i < n; i++)
                maxout(mx.data() + (size_t)i * D, exp.data() + (size_t)i * wd,
                       res[ri].W.data(), res[ri].b.data(), D, 3, wd);
            std::vector<float> ln((size_t)n * D);
            if (!res[ri].lnG.empty()) {
                for (int i = 0; i < n; i++)
                    layernorm(ln.data() + (size_t)i * D, mx.data() + (size_t)i * D, D,
                              res[ri].lnG.data(), res[ri].lnb.data(), 1e-6f);
            } else {
                ln = mx;
            }
            for (int i = 0; i < n; i++) {
                float* op = enc.data() + (size_t)i * D;
                for (int j = 0; j < D; j++) op[j] += ln[(size_t)i * D + j];
            }
        }
        memcpy(out, enc.data(), (size_t)n * D * sizeof(float));
    }
};
// =========================================================================
// Arc-hybrid Parser
// =========================================================================
// Transition-based dependency parser head: a 96→nO hidden layer, a
// precomputed-affine feature table and an action classifier, driven by a
// greedy stack/buffer state machine.
struct ParserModel {
    int nO = 64, nP = 8, nI = 2, n_actions = 0;
    std::vector<float> pW_hid, pb_hid;            // 96→64
    std::vector<float> pW_pre, pb_pre, pad_pre;   // preaffine
    std::vector<float> pW_cls, pb_cls;            // classifier
    std::vector<std::string> move_names;

    // Loads <dir>/model.ckpt + <dir>/model.bin and, when present, the
    // "move_names" array of <dir>/meta.json.
    // Returns false when a file is missing or malformed, a weight tensor is
    // absent, or any tensor (biases and padding included) is too small for
    // the shapes parse() indexes it with.
    bool load(const std::string& dir) {
        std::ifstream cf(dir + "/model.ckpt");
        if (!cf) return false;
        std::stringstream cb;
        cb << cf.rdbuf();
        JVal ck = JParser().parse(cb.str());
        if (ck.type != JVal::OBJ) return false;

        std::ifstream bf(dir + "/model.bin", std::ios::binary | std::ios::ate);
        if (!bf) return false;
        const std::streamoff end_pos = bf.tellg();
        if (end_pos <= 0 || end_pos % 4 != 0) return false;
        const size_t bz = (size_t)end_pos;
        bf.seekg(0);
        std::vector<float> bin(bz / 4);
        if (!bf.read((char*)bin.data(), (std::streamsize)bz)) return false;

        // Copies `c` floats starting at element `o`; empty when out of range.
        auto sl = [&](int64_t o, int64_t c) -> std::vector<float> {
            const int64_t total = (int64_t)bin.size();
            if (o < 0 || c < 0 || o > total || c > total - o) return {};
            return std::vector<float>(bin.begin() + o, bin.begin() + o + c);
        };
        // Loads tensor `k` into *v and reports up to three leading shape
        // dimensions (missing dimensions read as 1).
        auto ld = [&](const std::string& k, std::vector<float>* v,
                      int* r0 = nullptr, int* r1 = nullptr, int* r2 = nullptr) -> bool {
            auto* e = ck.get(k);
            if (!e) return false;
            auto sv = e->get("shape"), ov = e->get("offset"), cv = e->get("count");
            if (!sv || !ov || !cv) return false;
            *v = sl(ov->as_i64(), cv->as_i64());
            if (r0) *r0 = sv->arr.size() >= 1 ? sv->arr[0].as_int() : 1;
            if (r1) *r1 = sv->arr.size() >= 2 ? sv->arr[1].as_int() : 1;
            if (r2) *r2 = sv->arr.size() >= 3 ? sv->arr[2].as_int() : 1;
            return !v->empty();
        };

        int r0 = 0, r1 = 0, r2 = 0;
        if (!ld("pW_hid", &pW_hid, &r0, &r1)) return false;
        nO = r0;
        ld("pb_hid", &pb_hid);
        if (!ld("pW_pre", &pW_pre, &r0, &r1, &r2)) return false;
        nP = r0; nO = r1; nI = r2;
        ld("pb_pre", &pb_pre);
        ld("pad_pre", &pad_pre);
        if (!ld("pW_cls", &pW_cls, &r0, &r1)) return false;
        n_actions = r0;
        ld("pb_cls", &pb_cls);

        std::ifstream mf(dir + "/meta.json");
        if (mf) {
            std::stringstream mb;
            mb << mf.rdbuf();
            auto meta = JParser().parse(mb.str());
            auto* mn = meta.get("move_names");
            if (mn && mn->type == JVal::ARR)
                for (auto& v : mn->arr) move_names.push_back(v.str);
        }

        // Shape sanity. parse() picks feature tokens from a fixed 16-entry
        // table indexed by piece*2 + window, so that index must stay below 16.
        if (nP <= 0 || nP > 8 || nI <= 0 || nI > 16) return false;
        if ((nP - 1) * 2 + (nI - 1) >= 16) return false;
        if (nO <= 0 || nO > (1 << 20) || n_actions <= 0 || n_actions > (1 << 20)) return false;
        // Every buffer parse() indexes must be large enough for those shapes.
        const size_t slab = (size_t)nP * nO * nI;
        if (pW_hid.size() < (size_t)nO * 96 || pb_hid.size() < (size_t)nO) return false;
        if (pW_pre.size() < slab * nO || pb_pre.size() < (size_t)nI * nO) return false;
        if (pad_pre.size() < slab) return false;
        if (pW_cls.size() < (size_t)n_actions * nO || pb_cls.size() < (size_t)n_actions) return false;
        return true;
    }

    // Run parser forward + state machine → (heads, labels).
    // tokvecs holds n_tokens rows of 96 floats. On return out_heads[i] is the
    // head index of token i (-1 when unattached) and out_labels[i] its
    // dependency label. Both vectors are left EMPTY when move_names does not
    // provide exactly one name per action.
    void parse(const float* tokvecs, int n_tokens,
               std::vector<int>& out_heads, std::vector<std::string>& out_labels) {
        // 1. Hidden layer: 96→64
        std::vector<float> hidden((size_t)n_tokens * nO, 0);
        for (int i = 0; i < n_tokens; i++)
            linear(hidden.data() + (size_t)i * nO, tokvecs + (size_t)i * 96,
                   pW_hid.data(), pb_hid.data(), nO, 96);

        // 2. Pre-compute features
        std::vector<float> precomp((size_t)(n_tokens + 1) * nP * nO * nI, 0);
        // Pad token (index 0)
        memcpy(precomp.data(), pad_pre.data(), (size_t)nP * nO * nI * sizeof(float));
        // Real tokens
        for (int i = 0; i < n_tokens; i++) {
            size_t toff = (size_t)(i + 1) * nP * nO * nI;
            for (int p = 0; p < nP; p++) {
                for (int w = 0; w < nI; w++) {
                    size_t base = toff + (size_t)p * nO * nI + (size_t)w;
                    float* out = precomp.data() + base;
                    // W[p][w][o][d] = [nP][nI][nO][nO]
                    for (int o = 0; o < nO; o++) {
                        float s = pb_pre[(size_t)w * nO + o];
                        for (int d = 0; d < nO; d++)
                            s += pW_pre[((size_t)p * nO * nI + (size_t)o * nI + w) * nO + d] *
                                 hidden[(size_t)i * nO + d];
                        out[(size_t)o * nI] = s;
                    }
                }
            }
        }
        // Helper: get feature value for a token at piece p, window w, output dim o.
        // Indices outside the sentence map to the pad row.
        auto feat = [&](int idx, int p, int w, int o) -> float {
            int ri = (idx < 0 || idx >= n_tokens) ? 0 : idx + 1;
            return precomp[(size_t)ri * nP * nO * nI + (size_t)p * nO * nI + (size_t)w + (size_t)o * nI];
        };

        // 3. Arc-hybrid state machine
        out_heads.assign(n_tokens, -1);
        out_labels.assign(n_tokens, "");
        std::vector<int> stack;
        std::vector<int> buffer(n_tokens);
        for (int i = 0; i < n_tokens; i++) buffer[i] = i;
        // Validate move_names covers all actions before indexing
        if ((int)move_names.size() != n_actions) { out_heads.clear(); out_labels.clear(); return; }

        // First / last token currently attached to head `idx` (-1 when none).
        auto leftmost = [&](int idx) -> int {
            for (int i = 0; i < n_tokens; i++) if (out_heads[i] == idx) return i;
            return -1;
        };
        auto rightmost = [&](int idx) -> int {
            int r = -1;
            for (int i = 0; i < n_tokens; i++) if (out_heads[i] == idx) r = i;
            return r;
        };
        std::vector<float> scores(n_actions, 0);
        std::vector<float> feats(nP * nI * nO, 0);
        // The step cap bounds the loop even if the transitions stop making progress.
        for (int step = 0; step < n_tokens * 4 && !(buffer.empty() && stack.size() <= 1); step++) {
            int s0 = stack.empty() ? -1 : stack.back();
            int s1 = stack.size() < 2 ? -1 : stack[stack.size() - 2];
            int s2 = stack.size() < 3 ? -1 : stack[stack.size() - 3];
            int b0 = buffer.empty() ? -1 : buffer[0];
            int b1 = buffer.size() < 2 ? -1 : buffer[1];
            // Build feature indices (same as verified Python implementation)
            int idxs[16] = {s0, s1, b0, s0, s0, leftmost(s0), s0, rightmost(s0),
                            s1, leftmost(s1), s1, rightmost(s1), s2, b1, b0, b1};
            // Gather precomputed features at each (idx, piece, window).
            // NOTE(review): `feats` is filled here but the classifier below
            // reads `hidden` only — confirm whether it should consume `feats`.
            for (int o = 0; o < nO; o++) feats[o] = 0;
            for (int p = 0; p < nP; p++) {
                for (int w = 0; w < nI; w++) {
                    int ti = idxs[p * 2 + w];
                    for (int o = 0; o < nO; o++)
                        feats[(size_t)p * nI * nO + (size_t)w * nO + o] = feat(ti, p, w, o);
                }
            }
            // Classify
            for (int a = 0; a < n_actions; a++) {
                float s = pb_cls[a];
                // Use HIDDEN state directly (64-dim) as classifier input
                int cls_idx = b0 >= 0 ? b0 : (s0 >= 0 ? s0 : 0);
                for (int j = 0; j < nO; j++)
                    s += pW_cls[(size_t)a * nO + j] * hidden[(size_t)cls_idx * nO + j];
                scores[a] = s;
            }
            // Pick best VALID action
            int best = -1;
            float best_sc = -1e30f;
            for (int a = 0; a < n_actions; a++) {
                bool valid = false;
                const std::string& n = move_names[a];
                if (n == "S") valid = !buffer.empty() && (int)stack.size() < n_tokens;
                else if (n == "D") valid = !stack.empty();
                else if (n.size() >= 2 && (n[0] == 'L' || n[0] == 'R')) valid = stack.size() >= 2;
                if (valid && scores[a] > best_sc) { best_sc = scores[a]; best = a; }
            }
            if (best < 0) break;
            const std::string& act = move_names[best];
            if (act == "S") {
                stack.push_back(buffer[0]);
                buffer.erase(buffer.begin());
            } else if (act == "D") {
                stack.pop_back();
            } else if (act.size() >= 2) {
                // The label is whatever follows the two-character action prefix.
                std::string lbl = act.substr(2);
                if (act[0] == 'L') {
                    out_heads[s0] = s1;
                    out_labels[s0] = lbl;
                    stack.erase(stack.end() - 2);
                } else {
                    out_heads[s1] = s0;
                    out_labels[s1] = lbl;
                    stack.pop_back();
                }
            }
        }
    }
};
// =========================================================================
// Combined state
// =========================================================================
// Everything behind a ThincParserHandle.
struct ParserState {
    // Shared token-to-vector encoder.
    Tok2vecModel tok2vec;
    // Dependency parser that consumes the encoder output.
    ParserModel parser;
    // Set once both models have been loaded successfully.
    bool loaded = false;
};
// Everything behind a ThincTaggerHandle.
struct TaggerState {
    // Shared token-to-vector encoder.
    Tok2vecModel tok2vec;
    // Tagger output layer weights, (n_tags, 96).
    std::vector<float> tW;
    // Tagger output layer bias, (n_tags,).
    std::vector<float> tb;
    // Tag names read from meta.json.
    std::vector<std::string> tags;
    // Set once the tagger weights have been loaded.
    bool loaded = false;
};
} // namespace
// =========================================================================
// C API — Parser
// =========================================================================
// Creates a dependency-parser handle.
//   ner_dir:    the model's NER component directory (typically
//               <model_base>/ner or <model_base>/ner/). It is only used to
//               locate the PIPELINE tok2vec in <model_base>/tok2vec, not the
//               NER component's internal embeddings.
//   parser_dir: directory holding the parser's model.ckpt / model.bin.
// Returns nullptr when an argument is null, when either model fails to load,
// or when loading throws (no exception is allowed to cross the C boundary).
ThincParserHandle ThincParser_Create(const char* ner_dir, const char* parser_dir) {
    if (!ner_dir || !parser_dir) return nullptr;
    std::string base(ner_dir);
    // Drop trailing slashes first so that ".../ner/" is recognised as well as ".../ner".
    while (base.size() > 1 && base.back() == '/') base.pop_back();
    if (base.size() >= 4 && base.compare(base.size() - 4, 4, "/ner") == 0) base.resize(base.size() - 4);

    ParserState* s = nullptr;
    try {
        s = new ParserState();
        if (!s->tok2vec.load(base + "/tok2vec")) { delete s; return nullptr; }
        if (!s->parser.load(std::string(parser_dir))) { delete s; return nullptr; }
        s->loaded = true;
        return s;
    } catch (...) {
        delete s;
        return nullptr;
    }
}
// Releases a handle returned by ThincParser_Create; a null handle is a no-op.
void ThincParser_Destroy(ThincParserHandle h) {
    ParserState* state = (ParserState*)h;
    delete state;
}
// Parses a JSON array of token strings and returns a JSON array with one
// object per token: {"text", "head", "dep", "index"}. The result is allocated
// with strdup; release it with ThincParser_FreeString.
// Returns "[]" when the handle is null or not loaded, the input is null or
// not a JSON array, the array is empty, or the parser produced no analysis
// (it clears its outputs when the action names do not match the model).
char* ThincParser_Predict(ThincParserHandle h, const char* tokens_json) {
    auto* s = (ParserState*)h;
    if (!s || !s->loaded || !tokens_json) return strdup("[]");
    auto j = JParser().parse(std::string(tokens_json));
    if (j.type != JVal::ARR) return strdup("[]");
    std::vector<std::string> tokens;
    for (auto& v : j.arr) tokens.push_back(v.str);
    int n = (int)tokens.size();
    if (!n) return strdup("[]");

    // Run tok2vec
    std::vector<float> tokvecs((size_t)n * 96, 0);
    s->tok2vec.forward(tokens, tokvecs.data());
    // Run parser
    std::vector<int> heads;
    std::vector<std::string> labels;
    s->parser.parse(tokvecs.data(), n, heads, labels);
    // parse() may hand back empty vectors; indexing them per token would read
    // out of bounds.
    if ((int)heads.size() != n || (int)labels.size() != n) return strdup("[]");

    // Escapes a string for embedding between double quotes in JSON.
    auto json_escape = [](const std::string& raw) {
        static const char kHex[] = "0123456789abcdef";
        std::string out;
        out.reserve(raw.size() + 2);
        for (unsigned char c : raw) {
            switch (c) {
                case '"':  out += "\\\""; break;
                case '\\': out += "\\\\"; break;
                case '\n': out += "\\n"; break;
                case '\r': out += "\\r"; break;
                case '\t': out += "\\t"; break;
                case '\b': out += "\\b"; break;
                case '\f': out += "\\f"; break;
                default:
                    if (c < 0x20) { out += "\\u00"; out += kHex[c >> 4]; out += kHex[c & 15]; }
                    else out += (char)c;
            }
        }
        return out;
    };

    // Build JSON output
    std::string r = "[";
    for (int i = 0; i < n; i++) {
        if (i) r += ",";
        r += "{\"text\":\"" + json_escape(tokens[i]) + "\",\"head\":" + std::to_string(heads[i]) +
             ",\"dep\":\"" + json_escape(labels[i]) + "\",\"index\":" + std::to_string(i) + "}";
    }
    r += "]";
    return strdup(r.c_str());
}
// Releases a string handed out by the parser API. The strings are allocated
// with strdup, so they are returned to the C allocator; null is accepted.
void ThincParser_FreeString(char* p) {
    std::free(p);
}
// =========================================================================
// C API — Tagger
// =========================================================================
// Creates a POS-tagger handle.
//
//   ner_dir    : path of the NER component directory, typically
//                <model_base>/ner or <model_base>/ner/. A trailing "/ner"
//                component (with or without trailing slashes) is dropped so
//                that the PIPELINE tok2vec in <model_base>/tok2vec is loaded
//                rather than the NER component's internal one.
//   tagger_dir : directory holding model.ckpt (JSON index of tensors),
//                model.bin (raw floats) and optionally meta.json ("tags").
//
// Returns nullptr when an argument is null, tok2vec fails to load,
// model.ckpt / model.bin cannot be read, or an exception is raised.
// Otherwise a handle is returned; its `loaded` flag is set only when the
// weight matrix "tW" and a bias "tb" covering every tag row were read, and
// ThincTagger_Predict answers "[]" while it is false.
ThincTaggerHandle ThincTagger_Create(const char* ner_dir, const char* tagger_dir) {
    if(!ner_dir||!tagger_dir) return nullptr;
    TaggerState* s=nullptr;
    try{
        s=new TaggerState();
        std::string tbase(ner_dir);
        // Normalise "<base>/ner/" to "<base>/ner" so the suffix test below
        // also matches the trailing-slash spelling.
        while(tbase.size()>1 && tbase.back()=='/') tbase.pop_back();
        if(tbase.size()>=4 && tbase.compare(tbase.size()-4,4,"/ner")==0) tbase.resize(tbase.size()-4);
        if(!s->tok2vec.load(tbase+"/tok2vec")){delete s;return nullptr;}

        const std::string tdir(tagger_dir);
        std::ifstream cf(tdir+"/model.ckpt");
        if(!cf){delete s;return nullptr;}
        std::stringstream cb;
        cb<<cf.rdbuf();
        JVal ck=JParser().parse(cb.str());
        if(ck.type!=JVal::OBJ){delete s;return nullptr;}

        std::ifstream bf(tdir+"/model.bin",std::ios::binary|std::ios::ate);
        if(!bf){delete s;return nullptr;}
        const std::streamoff end=bf.tellg();
        if(end<0){delete s;return nullptr;}
        bf.seekg(0);
        // Read whole floats only: reading the raw byte count into a buffer of
        // size/4 floats would overrun it when the size is not a multiple of 4.
        std::vector<float> bin((size_t)end/sizeof(float));
        if(!bin.empty()){
            bf.read((char*)bin.data(),(std::streamsize)(bin.size()*sizeof(float)));
            if(!bf){delete s;return nullptr;}
        }

        // Copies `c` floats starting at float offset `o`; empty when the
        // range is negative or not fully inside the file.
        auto sl=[&](int64_t o,int64_t c)->std::vector<float>{
            const int64_t total=(int64_t)bin.size();
            if(o<0||c<=0||o>total||c>total-o) return {};
            return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
        };
        // Loads the tensor named `k` from the checkpoint index into *v.
        auto ld=[&](const std::string& k,std::vector<float>* v)->bool{
            auto* e=ck.get(k);
            if(!e) return false;
            auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
            if(!sv||!ov||!cv) return false;
            *v=sl(ov->as_i64(),cv->as_i64());
            return !v->empty();
        };
        const bool okW=ld("tW",&s->tW);
        const bool okB=ld("tb",&s->tb);

        // Tag names are optional; a missing meta.json leaves `tags` empty.
        std::ifstream mf(tdir+"/meta.json");
        if(mf){
            std::stringstream mb;
            mb<<mf.rdbuf();
            auto meta=JParser().parse(mb.str());
            auto* tg=meta.get("tags");
            if(tg&&tg->type==JVal::ARR) for(auto& v:tg->arr) s->tags.push_back(v.str);
        }

        // Usable only with at least one full weight row and a bias for each.
        const size_t n_tags=s->tW.size()/96;
        s->loaded=okW&&okB&&n_tags>0&&s->tb.size()>=n_tags;
        return s;
    }catch(...){
        // Never let a C++ exception unwind through the C ABI.
        delete s;
        return nullptr;
    }
}
// Releases a handle returned by ThincTagger_Create; a null handle is a no-op.
void ThincTagger_Destroy(ThincTaggerHandle h) {
    delete static_cast<TaggerState*>(h);
}
// Tags a pre-tokenized sentence.
//
//   h           : handle from ThincTagger_Create.
//   tokens_json : JSON array of token strings.
//
// For every token the highest-scoring row of (tW * tokvec + tb) is selected
// and its tag string emitted. When the tag string contains "POS=" (a
// morphologizer-style label such as "Gender=Masc|Number=Sing|POS=NOUN") only
// the POS value is kept; otherwise the tag is used as-is. An index without a
// tag name yields an empty string.
//
// Returns a malloc'ed JSON array of {"text":..., "tag":..., "index":...}.
// "[]" is returned for a null/unloaded handle, null or non-array input, an
// empty token list, inconsistent weights, or when an exception is raised.
// The caller frees the result with ThincTagger_FreeString.
char* ThincTagger_Predict(ThincTaggerHandle h, const char* tokens_json) {
    auto* s=static_cast<TaggerState*>(h);
    if(!s||!s->loaded||!tokens_json||s->tW.empty()) return strdup("[]");
    // Escapes a string for embedding between JSON double quotes; without this
    // a token such as `"` or `\` would produce invalid JSON.
    auto esc=[](const std::string& in)->std::string{
        static const char* hex="0123456789abcdef";
        std::string out;
        out.reserve(in.size()+2);
        for(unsigned char c:in){
            switch(c){
                case '"':  out+="\\\""; break;
                case '\\': out+="\\\\"; break;
                case '\n': out+="\\n"; break;
                case '\r': out+="\\r"; break;
                case '\t': out+="\\t"; break;
                default:
                    if(c<0x20){out+="\\u00";out+=hex[c>>4];out+=hex[c&0xF];}
                    else out+=(char)c;
            }
        }
        return out;
    };
    // Keeps only the POS value of a "…|POS=XYZ|…" label.
    auto pos_only=[](const std::string& t)->std::string{
        const auto p=t.find("POS=");
        if(p==std::string::npos) return t;
        const auto from=p+4;
        auto to=t.find_first_of("|;",from);
        if(to==std::string::npos) to=t.size();
        return t.substr(from,to-from);
    };
    try{
        auto j=JParser().parse(std::string(tokens_json));
        if(j.type!=JVal::ARR) return strdup("[]");
        std::vector<std::string> tokens;
        tokens.reserve(j.arr.size());
        for(auto& v:j.arr) tokens.push_back(v.str);
        const int n=(int)tokens.size(), n_tags=(int)s->tW.size()/96;
        if(!n||!n_tags) return strdup("[]");
        // The scoring loop reads tb[t] for every row; refuse a short bias
        // instead of reading past its end.
        if(s->tb.size()<(size_t)n_tags) return strdup("[]");
        // Run tok2vec to get 96-dim embeddings, then take the argmax score.
        std::vector<float> tokvecs((size_t)n*96,0);
        s->tok2vec.forward(tokens, tokvecs.data());
        std::vector<int> best_tags(n,0);
        for(int i=0;i<n;i++){
            float best_sc=-1e30f;
            for(int t=0;t<n_tags;t++){
                float sc=s->tb[t];
                for(int k=0;k<96;k++) sc+=s->tW[(size_t)t*96+k]*tokvecs[(size_t)i*96+k];
                if(sc>best_sc){best_sc=sc;best_tags[i]=t;}
            }
        }
        std::string r="[";
        for(int i=0;i<n;i++){
            if(i)r+=",";
            std::string tag=best_tags[i]<(int)s->tags.size()?s->tags[best_tags[i]]:"";
            r+="{\"text\":\""+esc(tokens[i])+"\",\"tag\":\""+esc(pos_only(tag))+"\",\"index\":"+std::to_string(i)+"}";
        }
        r+="]";
        return strdup(r.c_str());
    }catch(...){
        // Never let a C++ exception unwind through the C ABI.
        return strdup("[]");
    }
}
// Frees a string returned by ThincTagger_Predict (allocated with strdup).
void ThincTagger_FreeString(char* p) {
    free(p);
}

View File

@@ -0,0 +1,49 @@
#ifndef THINC_PARSER_H
#define THINC_PARSER_H
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stdbool.h>
// ---------------------------------------------------------------------------
// C API for spaCy dependency parser and POS tagger
// ---------------------------------------------------------------------------
// Handles are opaque pointers owned by the library. Every string returned by
// a *_Predict function is heap-allocated and must be released with the
// matching *_FreeString function.
typedef void* ThincParserHandle;
typedef void* ThincTaggerHandle;
// Parser: create/destroy
// model_ner_dir: path to the NER model directory. The implementation drops a
//                trailing "/ner" component from this path and loads the
//                pipeline tok2vec weights from "<base>/tok2vec".
// model_parser_dir: path to parser model directory.
// Returns NULL when an argument is NULL or a model component fails to load.
ThincParserHandle ThincParser_Create(const char* model_ner_dir, const char* model_parser_dir);
// Releases a handle from ThincParser_Create.
void ThincParser_Destroy(ThincParserHandle handle);
// Run dependency parser on pre-tokenized text.
// tokens_json: JSON array of token strings, e.g. ["Apple", "was", "founded", ...]
// Returns JSON array of token annotations:
// [{"text":"Apple","head":2,"dep":"nsubjpass","index":0}, ...]
// where head is the index of the head token (0-based), -1 for root.
// NOTE(review): the Go consumer added alongside this header detects the root
// as head == index; confirm which root convention the parser actually emits.
// Returns "[]" for an invalid handle, NULL or non-array input, or no tokens.
// Caller must free with ThincParser_FreeString.
char* ThincParser_Predict(ThincParserHandle handle, const char* tokens_json);
// Frees a string returned by ThincParser_Predict.
void ThincParser_FreeString(char* ptr);
// Tagger: create/destroy
// model_ner_dir: path to the NER model directory; as for the parser, the
//                pipeline tok2vec weights are loaded from "<base>/tok2vec"
//                after a trailing "/ner" component is dropped.
// model_tagger_dir: path to tagger model directory
// Returns NULL when an argument is NULL or required model files cannot be
// opened.
ThincTaggerHandle ThincTagger_Create(const char* model_ner_dir, const char* model_tagger_dir);
// Releases a handle from ThincTagger_Create.
void ThincTagger_Destroy(ThincTaggerHandle handle);
// Run POS tagger on pre-tokenized text.
// tokens_json: JSON array of token strings.
// Returns JSON array: [{"text":"Apple","tag":"NNP","index":0}, ...]
// When a model tag contains "POS=..." only that POS value is reported.
// Returns "[]" for an invalid handle, NULL or non-array input, or no tokens.
// Caller must free with ThincTagger_FreeString.
char* ThincTagger_Predict(ThincTaggerHandle handle, const char* tokens_json);
// Frees a string returned by ThincTagger_Predict.
void ThincTagger_FreeString(char* ptr);
#ifdef __cplusplus
}
#endif
#endif // THINC_PARSER_H

View File

@@ -1,3 +1,5 @@
//go:build cgo
// Package pdfium renders PDF pages using the system's libpdfium.so // Package pdfium renders PDF pages using the system's libpdfium.so
// (bundled with pypdfium2). It exists solely to replace pdf_oxide's // (bundled with pypdfium2). It exists solely to replace pdf_oxide's
// RenderPageRaw for use cases where image quality matters for downstream // RenderPageRaw for use cases where image quality matters for downstream

View File

@@ -0,0 +1,788 @@
// Go implementation of dependency-based relation extraction.
// Direct port of Python DepRelationExtractor — semantica-aligned.
// Operates on a dependency tree (heads + labels) independent of parser.
package extractor
import "strings"
// verbLemma maps inflected verb surface forms (lower-case) to the base form
// used as the verb part of depVerbRelations keys. It is a fixed lookup table,
// not a general lemmatizer: forms missing here are passed through unchanged
// by lemma(). Entries exist for English, German, French, Spanish and
// Portuguese only.
var verbLemma = map[string]string{
// English
"founded": "found", "founding": "found",
"works": "work", "working": "work",
"based": "base", "basing": "base",
"located": "locate", "locating": "locate",
"situated": "situate", "situating": "situate",
"acquired": "acquire", "acquiring": "acquire",
"employed": "employ", "employing": "employ",
"hired": "hire", "hiring": "hire",
"born": "bear",
"joined": "join", "joining": "join",
"merged": "merge", "merging": "merge",
"bought": "buy", "buying": "buy",
"created": "create", "creating": "create",
"established": "establish", "establishing": "establish",
"started": "start", "starting": "start",
"led": "lead", "leading": "lead",
"managed": "manage", "managing": "manage",
"headed": "head", "heading": "head",
"ran": "run", "running": "run",
"owned": "own", "owning": "own",
"developed": "develop", "developing": "develop",
"wrote": "write", "written": "write", "writing": "write",
"published": "publish", "publishing": "publish",
"invested": "invest", "investing": "invest",
"partnered": "partner", "partnering": "partner",
"collaborated": "collaborate", "collaborating": "collaborate",
"sets": "set",
// German
"gegründet": "gründen", "gründete": "gründen",
"arbeitet": "arbeiten", "arbeitete": "arbeiten",
"befindet": "befinden",
"liegt": "liegen", "lag": "liegen",
"geboren": "gebären",
"erworben": "erwerben", "erwarb": "erwerben",
"gekauft": "kaufen", "kaufte": "kaufen",
"übernommen": "übernehmen", "übernahm": "übernehmen",
// French
"fondé": "fonder", "fondée": "fonder",
"créé": "créer", "créée": "créer",
"travaille": "travailler",
"employé": "employer", "employée": "employer",
"situé": "situer", "située": "situer",
"né": "naître", "née": "naître",
"acquis": "acquérir",
// Spanish + Portuguese (shared forms)
"fundado": "fundar", "fundada": "fundar",
"creado": "crear", "creada": "crear",
"criado": "criar", "criada": "criar",
"trabaja": "trabajar", "trabalha": "trabalhar",
"ubicado": "ubicar", "ubicada": "ubicar",
"situado": "situar", "situada": "situar",
"localizado": "localizar", "localizada": "localizar",
"sediado": "sediar", "sediada": "sediar",
"nacido": "nacer", "nacida": "nacer",
"nascido": "nascer", "nascida": "nascer",
}
// lemma maps an inflected verb form to its base form via the verbLemma
// table. A form that is not in the table is returned unchanged.
func lemma(w string) string {
	base, known := verbLemma[w]
	if !known {
		return w
	}
	return base
}
// Verb+prep → relation type (multi-language)
// Keys: verbLemma+prep (or verbLemma alone for direct-object relations).
// The "+" separator is the one lookupVerb inserts between verb and prep;
// a key without "+" (e.g. "join") is reached by a lookup with an empty prep.
// Matches Python _VERB_RELATIONS exactly.
var depVerbRelations = map[string]string{
// English
"found+by": "founded_by", "co-found+by": "founded_by",
"establish+by": "founded_by", "create+by": "founded_by",
"set+up": "founded_by", "start+by": "founded_by",
"work+for": "works_for", "employ+by": "works_for",
"hire+by": "works_for", "join": "works_for",
"lead+by": "works_for", "manage+by": "works_for",
"head+by": "works_for", "run+by": "works_for",
"own+by": "owns", "develop+by": "develops",
"write+by": "wrote", "publish+by": "published",
"invest+in": "invests_in", "partner+with": "partners_with",
"collaborate+with": "collaborates_with",
"merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of",
"base+in": "located_in", "locate+in": "located_in",
"situate+in": "located_in", "headquarter+in": "located_in",
"bear+in": "born_in", "bear+on": "born_in",
"acquire+by": "acquired", "buy+by": "acquired",
// German (de)
"gründen+von": "founded_by", "errichten+von": "founded_by",
"arbeiten+für": "works_for", "beschäftigen+bei": "works_for",
"anstellen+bei": "works_for",
"sich+befinden": "located_in", "liegen+in": "located_in",
"sitzen+in": "located_in", "gebären+in": "born_in",
"gebären+am": "born_in",
"erwerben+durch": "acquired", "kaufen+durch": "acquired",
"übernehmen+durch": "acquired",
// French (fr)
"fonder+par": "founded_by", "créer+par": "founded_by",
"établir+par": "founded_by",
"travailler+pour": "works_for", "employer+par": "works_for",
"embaucher+par": "works_for",
"situer+à": "located_in", "baser+à": "located_in",
"implanter+à": "located_in",
"naître+à": "born_in",
"acquérir+par": "acquired", "racheter+par": "acquired",
// Spanish (es)
"fundar+por": "founded_by", "crear+por": "founded_by",
"establecer+por": "founded_by",
"trabajar+para": "works_for", "emplear+por": "works_for",
"contratar+por": "works_for",
"ubicar+en": "located_in", "situar+en": "located_in",
"tener+sede": "located_in",
"nacer+en": "born_in",
"adquirir+por": "acquired", "comprar+por": "acquired",
// Portuguese (pt)
"criar+por": "founded_by",
"estabelecer+por": "founded_by",
"trabalhar+para": "works_for", "empregar+por": "works_for",
"localizar+em": "located_in", "situar+em": "located_in",
"sediar+em": "located_in",
"nascer+em": "born_in",
// Chinese (zh)
"创立+由": "founded_by", "创建+由": "founded_by",
"成立+由": "founded_by", "创办+由": "founded_by",
"设立+由": "founded_by",
"任职+于": "works_for", "就职+于": "works_for",
"工作+在": "works_for", "位于+在": "located_in",
"坐落+在": "located_in", "总部设+在": "located_in",
"出生+在": "born_in", "出生+于": "born_in",
"收购+由": "acquired", "并购+由": "acquired",
// Japanese (ja)
"設立+によって": "founded_by", "創立+によって": "founded_by",
"勤務+で": "works_for", "在籍+で": "works_for",
"位置+に": "located_in", "所在+に": "located_in",
"本社+を": "located_in",
"出生+に": "born_in",
"買収+によって": "acquired",
}
// Copula title patterns — X is [title] of Y → typed relation
// Key: lower-case title keyword; extractCopula matches it as a SUBSTRING of
// the attribute token. Value: every predicate emitted for that title.
var depCopulaTitles = map[string][]string{
"ceo": {"ceo_of", "works_for"}, "cto": {"works_for"},
"cfo": {"works_for"}, "coo": {"works_for"},
"vp": {"works_for"}, "director": {"works_for"},
"manager": {"works_for"}, "engineer": {"works_for"},
"employee": {"works_for"},
"founder": {"founded_by"}, "co-founder": {"founded_by"},
}
// Multi-hop inference rules: A rel1 B + B rel2 C ⇒ A rel3 C
// Outer key is rel1, inner key is rel2 and the value is the inferred rel3.
var multiHopRules = map[string]map[string]string{
"ceo_of": {"is_subsidiary_of": "works_for", "located_in": "works_for"},
"works_for": {"is_subsidiary_of": "works_for"},
"founded_by": {"is_subsidiary_of": "founded_by"},
}
// ---------------------------------------------------------------------------
// Language-specific dependency role mappings
// ---------------------------------------------------------------------------
// Matches Python _LANG_DEP_RULES exactly for all 7 languages.
// roleSpec describes how one semantic role is recognised among the children
// of a verb. getByRole applies it in this priority: a non-empty caseMarker
// selects the case-marker rule, otherwise a non-empty childDep selects the
// two-level (child + grandchild) rule, otherwise the single-label rule.
type roleSpec struct {
dep string // main dependency label (e.g. "nsubj", "obl:agent")
childDep string // child dep for compound rules (e.g. "pobj" for "agent"→"pobj")
caseMarker string // optional case marker text (zh:"由", ja:"によって")
}
// langRolesMap maps language code → semantic role name → roleSpec.
// Role names used by the extraction code are "subj", "pass_subj", "agent",
// "dobj", "prep_obj" and, for German only, "root_verb_child". A role that is
// absent for a language simply yields no matches for that role.
var langRolesMap = map[string]map[string]roleSpec{
"en": {
"pass_subj": {dep: "nsubjpass"},
"subj": {dep: "nsubj"},
"agent": {dep: "agent", childDep: "pobj"},
"dobj": {dep: "dobj"},
"prep_obj": {dep: "prep", childDep: "pobj"},
},
"de": {
"subj": {dep: "sb"},
"agent": {dep: "sbp", childDep: "nk"},
"prep_obj": {dep: "mo", childDep: "nk"},
// German ROOT is aux verb, real verb has dep "oc"
"root_verb_child": {dep: "oc"},
},
"fr": {
"pass_subj": {dep: "nsubj:pass"},
"subj": {dep: "nsubj"},
"agent": {dep: "obl:agent"},
"dobj": {dep: "obj"},
"prep_obj": {dep: "case", childDep: "obl"},
},
"es": {
"subj": {dep: "nsubj"},
"agent": {dep: "obj"},
"prep_obj": {dep: "case", childDep: "obl"},
},
"pt": {
"pass_subj": {dep: "nsubj:pass"},
"subj": {dep: "nsubj"},
"agent": {dep: "obl:agent"},
"dobj": {dep: "obj"},
"prep_obj": {dep: "case", childDep: "obl"},
},
"zh": {
"subj": {dep: "nsubj"},
"agent": {dep: "nmod:prep", caseMarker: "由"},
"dobj": {dep: "dobj"},
"prep_obj": {dep: "case", childDep: "nmod"},
},
"ja": {
"subj": {dep: "nsubj"},
"agent": {dep: "obl", caseMarker: "によって"},
"dobj": {dep: "dobj"},
"prep_obj": {dep: "case", childDep: "obl"},
},
}
// Copula dependency labels per language (attr_deps, prep_deps, obj_deps).
// extractCopula walks root → attrDeps child → prepDeps child → objDeps child
// to find the "X is [title] of Y" pattern.
// Matches Python _extract_copula label sets.
var copulaDeps = map[string]struct {
attrDeps []string
prepDeps []string
objDeps []string
}{
"en": {attrDeps: []string{"attr"}, prepDeps: []string{"prep"}, objDeps: []string{"pobj"}},
"de": {attrDeps: []string{"pred"}, prepDeps: []string{"mo"}, objDeps: []string{"nk"}},
"fr": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
"es": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
"pt": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
"zh": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"nmod"}},
"ja": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
}
// be-verb surface forms (lower-case) used for copula detection.
// Forms are listed for English, German, French, Spanish and Portuguese; no
// Chinese or Japanese forms are present, so isBeVerb is never true for
// tokens of those languages.
// Duplicate strings between languages are normalised into a single entry.
var beVerbs = map[string]bool{
// English
"is": true, "are": true, "was": true, "were": true, "be": true, "been": true, "being": true,
// German
"ist": true, "sind": true, "bin": true, "bist": true, "seid": true,
"war": true, "waren": true, "gewesen": true, "sein": true,
// French
"est": true, "suis": true, "sommes": true, "êtes": true, "sont": true,
"était": true, "étant": true, "être": true,
// Spanish + Portuguese (shared forms). French "es" is not repeated in the
// French group above because it is the same map key as Spanish "es" below.
"es": true, "é": true, "son": true, "são": true,
"está": true, "están": true, "estão": true,
"era": true, "eran": true, "eram": true,
"ser": true, "sido": true, "siendo": true, "sendo": true,
}
// DepToken holds token dependency info from the C++ parser.
// The extraction code uses Index both as the token's identity and as its
// position in the token slice, and treats a token whose Head equals its own
// Index as a sentence root.
type DepToken struct {
Text string `json:"text"` // surface form of the token
Head int `json:"head"` // Index of the head token
Dep string `json:"dep"` // dependency label relative to Head
Index int `json:"index"` // position of the token in the sentence
POS string `json:"pos,omitempty"` // part-of-speech tag, when available
}
// roleResult is the return type for getByRole: one matched entity plus, for
// the prep_obj role, the lower-cased text of the governing preposition token.
type roleResult struct {
entity Entity
prep string // prep lemma for prep_obj role; empty otherwise
}
// ---------------------------------------------------------------------------
// Main entry point
// ---------------------------------------------------------------------------
// DepExtractRelations extracts typed relations from a dependency parse tree.
//
// lang is the language code (en/zh/de/fr/es/pt/ja).
// maxDistance is the max character distance for co-occurrence (0 = use default 100).
//
// Processing steps:
//  1. For every root token, run the verb patterns (extractFromRoot) and, when
//     the root is a be-verb, the copula pattern (extractCopula). A verb with
//     a negation child is skipped. For languages that declare a
//     "root_verb_child" role (German) the patterns run on each such child of
//     the root instead of on the root itself.
//  2. Append co-occurrence relations (always, matching Python
//     DepRelationExtractor.extract()).
//  3. Add multi-hop inferences and drop duplicates.
//
// A token counts as a root when its Head equals its own Index or when its
// Head is negative, so both root encodings are accepted. Tokens whose Index
// is not a valid position in tokens are ignored rather than causing a panic.
func DepExtractRelations(text string, tokens []DepToken, entities []Entity, lang string, maxDistance int) []Relation {
	entityMap := buildEntityMapMulti(entities)
	var relations []Relation

	// German-style handling: ROOT is an auxiliary, the real verb is a child.
	rootVerbSpec, hasRootVerbChild := langRolesMap[lang]["root_verb_child"]

	// extractAt runs the verb and copula patterns on the verb at verbIdx;
	// rootTok is the sentence root used for the be-verb test.
	extractAt := func(verbIdx int, rootTok DepToken) {
		if verbIdx < 0 || verbIdx >= len(tokens) {
			return
		}
		if hasNegation(verbIdx, tokens) {
			return
		}
		relations = append(relations, extractFromRoot(text, verbIdx, tokens, entityMap, lang)...)
		if isBeVerb(rootTok) {
			relations = append(relations, extractCopula(text, verbIdx, tokens, entityMap, lang)...)
		}
	}

	for _, tok := range tokens {
		isRoot := tok.Head == tok.Index || tok.Head < 0
		if !isRoot {
			continue
		}
		if hasRootVerbChild {
			for _, c := range childrenOf(tok.Index, tokens) {
				if c.Dep == rootVerbSpec.dep {
					extractAt(c.Index, tok)
				}
			}
			continue
		}
		// Standard languages: ROOT = main verb
		extractAt(tok.Index, tok)
	}

	// Co-occurrence (always, matching Python DepRelationExtractor.extract())
	if maxDistance <= 0 {
		maxDistance = 100
	}
	relations = append(relations, extractCooccurrence(text, entities, maxDistance)...)
	relations = inferMultiHop(relations)
	relations = dedupRelations(relations)
	return relations
}
// ---------------------------------------------------------------------------
// Negation
// ---------------------------------------------------------------------------
// hasNegation reports whether the token at idx has a direct child labelled
// "neg" or "advmod:neg".
func hasNegation(idx int, tokens []DepToken) bool {
	for i := range tokens {
		child := &tokens[i]
		if child.Head != idx {
			continue
		}
		switch child.Dep {
		case "neg", "advmod:neg":
			return true
		}
	}
	return false
}
// isBeVerb reports whether the token's lower-cased text is listed in beVerbs.
func isBeVerb(tok DepToken) bool {
	surface := strings.ToLower(tok.Text)
	listed := beVerbs[surface]
	return listed
}
// ---------------------------------------------------------------------------
// Multi-hop inference
// ---------------------------------------------------------------------------
// inferMultiHop chains pairs of typed relations (A rel1 B, B rel2 C) through
// multiHopRules and appends every inferred relation (A rel3 C) to rels.
//
// Entities are joined on lower-cased text. "related_to" relations take no
// part in chaining. An inferred relation carries 0.9 times the lower of the
// two source confidences and records how it was derived in Metadata.
func inferMultiHop(rels []Relation) []Relation {
	// Index every typed relation by its lower-cased subject text.
	outgoing := map[string][]Relation{}
	for _, rel := range rels {
		if rel.Predicate != "related_to" {
			subjKey := strings.ToLower(rel.Subject.Text)
			outgoing[subjKey] = append(outgoing[subjKey], rel)
		}
	}

	var derived []Relation
	for _, first := range rels {
		if first.Predicate == "related_to" {
			continue
		}
		rules, hasRules := multiHopRules[first.Predicate]
		if !hasRules {
			continue
		}
		for _, second := range outgoing[strings.ToLower(first.Object.Text)] {
			pred, matched := rules[second.Predicate]
			if !matched {
				continue
			}
			weakest := first.Confidence
			if second.Confidence < weakest {
				weakest = second.Confidence
			}
			derived = append(derived, Relation{
				Subject:    first.Subject,
				Predicate:  pred,
				Object:     second.Object,
				Confidence: weakest * 0.9,
				Metadata: map[string]interface{}{
					"method": "multi_hop",
					"via":    first.Predicate + "→" + second.Predicate,
				},
			})
		}
	}
	return append(rels, derived...)
}
// ---------------------------------------------------------------------------
// Entity map
// ---------------------------------------------------------------------------
// buildEntityMapMulti indexes entities by lower-cased text. Each entity is
// registered under its full text and, when different, under the text with
// trailing punctuation (.,;:!?) and surrounding whitespace removed. Several
// entities may share one key; they are kept in input order.
//
// Empty keys are never registered: an entity whose text is empty or consists
// only of punctuation would otherwise be stored under "", and the substring
// fallback of findEntityInSubtree would then match it against every phrase.
func buildEntityMapMulti(entities []Entity) map[string][]Entity {
	m := make(map[string][]Entity)
	for _, e := range entities {
		if key := strings.ToLower(e.Text); key != "" {
			m[key] = append(m[key], e)
		}
		cleaned := strings.TrimSpace(strings.TrimRight(e.Text, ".,;:!?"))
		if cleaned == "" || cleaned == e.Text {
			continue
		}
		ckey := strings.ToLower(cleaned)
		m[ckey] = append(m[ckey], e)
	}
	return m
}
// findBestEntity returns the first entity registered under key (trimmed and
// lower-cased), or nil when the key is unknown.
func findBestEntity(key string, entityMap map[string][]Entity) *Entity {
	normalized := strings.ToLower(strings.TrimSpace(key))
	if candidates := entityMap[normalized]; len(candidates) > 0 {
		return &candidates[0]
	}
	return nil
}
// ---------------------------------------------------------------------------
// Language-aware role lookup (replaces old getChildEntity/getAgentPobj/getPrepObjs)
// ---------------------------------------------------------------------------
// getByRole returns the (entity, prep) pairs that fill a semantic role below
// the verb at rootIdx. Matches Python DepRelationExtractor._get_by_role.
//
// The role is resolved through langRolesMap; an unknown language falls back
// to the English table and an unknown role yields nil. Depending on the
// roleSpec one of three rules applies to each child carrying spec.dep:
//   - case-marker rule (zh/ja agent): the child's subtree must contain the
//     marker token; the entity is looked up in the child's subtree;
//   - compound rule: each grandchild labelled spec.childDep is searched for
//     an entity; for "prep_obj" the child's lower-cased text is returned as
//     prep;
//   - simple rule: the entity is looked up in the child's subtree.
func getByRole(lang string, role string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity) []roleResult {
	table, known := langRolesMap[lang]
	if !known {
		table = langRolesMap["en"]
	}
	spec, defined := table[role]
	if !defined {
		return nil
	}

	var matches []roleResult
	for _, child := range childrenOf(rootIdx, tokens) {
		// Every rule requires the child to carry the role's main label.
		if child.Dep != spec.dep {
			continue
		}
		switch {
		case spec.caseMarker != "":
			if !hasCaseMarkerInSubtree(child.Index, tokens, spec.caseMarker) {
				continue
			}
			if ent := findEntityInSubtree(child.Index, tokens, entityMap); ent != nil {
				matches = append(matches, roleResult{entity: *ent})
			}
		case spec.childDep != "":
			prepText := strings.ToLower(child.Text)
			for _, grandchild := range childrenOf(child.Index, tokens) {
				if grandchild.Dep != spec.childDep {
					continue
				}
				ent := findEntityInSubtree(grandchild.Index, tokens, entityMap)
				if ent == nil {
					continue
				}
				hit := roleResult{entity: *ent}
				if role == "prep_obj" {
					hit.prep = prepText
				}
				matches = append(matches, hit)
			}
		default:
			if ent := findEntityInSubtree(child.Index, tokens, entityMap); ent != nil {
				matches = append(matches, roleResult{entity: *ent})
			}
		}
	}
	return matches
}
// hasCaseMarkerInSubtree reports whether some token in the subtree rooted at
// idx (idx included) has text exactly equal to marker.
func hasCaseMarkerInSubtree(idx int, tokens []DepToken, marker string) bool {
	for _, member := range collectSubtree(idx, tokens, map[int]bool{}) {
		if member.Text == marker {
			return true
		}
	}
	return false
}
// collectSubtree returns the token at idx followed by all of its descendants
// in depth-first order. visited guards against cycles in the head links: a
// token that was already collected contributes nothing.
func collectSubtree(idx int, tokens []DepToken, visited map[int]bool) []DepToken {
	if visited[idx] {
		return nil
	}
	visited[idx] = true

	members := []DepToken{tokens[idx]}
	for _, child := range childrenOf(idx, tokens) {
		descendants := collectSubtree(child.Index, tokens, visited)
		members = append(members, descendants...)
	}
	return members
}
// ---------------------------------------------------------------------------
// extractFromRoot — passive/active/preposition patterns (language-aware)
// ---------------------------------------------------------------------------
// extractFromRoot applies the verb patterns to the verb at rootIdx and
// returns the relations found (nil when none match).
//
// Roles are resolved with getByRole for lang; only the first entity of the
// subj, pass_subj, dobj and agent roles is used. When an agent is present
// and there is no passive subject, the plain subject is treated as the
// passive subject instead (zh/ja/es style passives).
//
// Patterns, in order:
//   - passive ("X was founded by Y"): first agent marker for which
//     lookupVerb knows verb+marker; confidence 0.90. For "founded_by" and
//     "acquired" the passive subject is the relation subject, otherwise the
//     agent is.
//   - active ("X VERB Y", "X VERB prep Y"): confidence 0.85.
//   - passive with preposition and no agent ("X is based in Y"):
//     confidence 0.85.
//
// text is not used by this function.
func extractFromRoot(text string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity, lang string) []Relation {
	verb := lemma(strings.ToLower(tokens[rootIdx].Text))

	// firstEntity resolves a role to its first matching entity, or nil.
	firstEntity := func(role string) *Entity {
		found := getByRole(lang, role, rootIdx, tokens, entityMap)
		if len(found) == 0 {
			return nil
		}
		return &found[0].entity
	}
	activeSubj := firstEntity("subj")
	passiveSubj := firstEntity("pass_subj")
	directObj := firstEntity("dobj")
	agent := firstEntity("agent")
	prepObjs := getByRole(lang, "prep_obj", rootIdx, tokens, entityMap)

	// An explicit agent without a passive subject turns the plain subject
	// into the passive one and disables the active patterns.
	if agent != nil && passiveSubj == nil && activeSubj != nil {
		passiveSubj, activeSubj = activeSubj, nil
	}

	var found []Relation
	emit := func(subj Entity, pred string, obj Entity, conf float64, meta map[string]interface{}) {
		rel := makeRelation(subj, pred, obj, conf)
		rel.Metadata = meta
		found = append(found, rel)
	}

	// Passive: X was founded/acquired by Y
	if passiveSubj != nil && agent != nil {
		for _, marker := range []string{"by", "von", "par", "por", "durch", "由", "によって"} {
			pred := lookupVerb(verb, marker)
			if pred == "" {
				continue
			}
			subj, obj := *agent, *passiveSubj
			if pred == "founded_by" || pred == "acquired" {
				subj, obj = *passiveSubj, *agent
			}
			emit(subj, pred, obj, 0.90, map[string]interface{}{"method": "passive", "verb": verb})
			break
		}
	}

	// Active: X VERB Y or X VERB prep Y
	if activeSubj != nil {
		if directObj != nil {
			if pred := lookupVerb(verb, ""); pred != "" {
				emit(*activeSubj, pred, *directObj, 0.85, map[string]interface{}{"method": "active", "verb": verb})
			}
		}
		for _, po := range prepObjs {
			if pred := lookupVerb(verb, po.prep); pred != "" {
				emit(*activeSubj, pred, po.entity, 0.85, map[string]interface{}{"method": "active_prep", "verb": verb, "prep": po.prep})
			}
		}
	}

	// Passive with prep ("is based in")
	if passiveSubj != nil && agent == nil && len(prepObjs) > 0 {
		for _, po := range prepObjs {
			pred := lookupVerb(verb, po.prep)
			if pred == "" {
				pred = lookupVerb("be+"+verb, po.prep)
			}
			if pred == "" {
				continue
			}
			emit(*passiveSubj, pred, po.entity, 0.85, map[string]interface{}{"method": "passive_prep", "verb": verb, "prep": po.prep})
		}
	}
	return found
}
// ---------------------------------------------------------------------------
// Copula extraction (language-aware)
// ---------------------------------------------------------------------------
// extractCopula handles "X is [title] of Y": it returns one relation per
// predicate that depCopulaTitles lists for the title, with the first subject
// entity as relation subject and Y as object (confidence 0.88).
//
// The title is the lower-cased text of an attribute child of the verb at
// rootIdx; Y is the entity below that attribute's preposition/object chain,
// using the copulaDeps labels of lang (English labels for unknown
// languages). Only the first object-labelled token under each preposition is
// inspected, and a later complete chain overrides an earlier one. nil is
// returned when there is no subject, no complete chain, or no known title
// keyword contained in the title.
//
// NOTE(review): the keyword table is a map, so when a title contains several
// keywords the one that wins depends on map iteration order.
//
// text is not used by this function.
func extractCopula(text string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity, lang string) []Relation {
	subjects := getByRole(lang, "subj", rootIdx, tokens, entityMap)
	if len(subjects) == 0 {
		return nil
	}
	labels, known := copulaDeps[lang]
	if !known {
		labels = copulaDeps["en"]
	}

	title := ""
	var target *Entity
	for _, attr := range childrenOf(rootIdx, tokens) {
		if !containsDep(attr.Dep, labels.attrDeps) {
			continue
		}
		for _, prep := range childrenOf(attr.Index, tokens) {
			if !containsDep(prep.Dep, labels.prepDeps) {
				continue
			}
			for _, obj := range childrenOf(prep.Index, tokens) {
				if !containsDep(obj.Dep, labels.objDeps) {
					continue
				}
				if ent := findEntityInSubtree(obj.Index, tokens, entityMap); ent != nil {
					target = ent
					title = strings.ToLower(attr.Text)
				}
				// Only the first object under this preposition is considered.
				break
			}
		}
	}
	if title == "" || target == nil {
		return nil
	}

	subject := subjects[0].entity
	var found []Relation
	for keyword, preds := range depCopulaTitles {
		if !strings.Contains(title, keyword) {
			continue
		}
		for _, pred := range preds {
			rel := makeRelation(subject, pred, *target, 0.88)
			rel.Metadata = map[string]interface{}{"method": "copula", "title": title}
			found = append(found, rel)
		}
		break
	}
	return found
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// makeRelation builds a Relation from its parts with an empty, non-nil
// Metadata map; Context is left empty.
func makeRelation(subj Entity, pred string, obj Entity, conf float64) Relation {
	var rel Relation
	rel.Subject = subj
	rel.Predicate = pred
	rel.Object = obj
	rel.Confidence = conf
	rel.Metadata = map[string]interface{}{}
	return rel
}
// lookupVerb returns the relation type registered in depVerbRelations for
// "verb+prep", or for verb alone when prep is empty. The empty string means
// that no relation is registered.
func lookupVerb(verb, prep string) string {
	key := verb
	if prep != "" {
		key = verb + "+" + prep
	}
	return depVerbRelations[key]
}
// containsDep reports whether dep is one of the labels in deps.
func containsDep(dep string, deps []string) bool {
	for i := 0; i < len(deps); i++ {
		if deps[i] == dep {
			return true
		}
	}
	return false
}
// childrenOf returns, in sentence order, the tokens whose Head is idx. A
// token that is its own head (a root) is not reported as its own child.
func childrenOf(idx int, tokens []DepToken) []DepToken {
	var children []DepToken
	for i := range tokens {
		candidate := tokens[i]
		if candidate.Head != idx || candidate.Index == idx {
			continue
		}
		children = append(children, candidate)
	}
	return children
}
// findEntityInSubtree resolves the phrase spanned by the subtree at idx to a
// known entity, or returns nil.
//
// The phrase is built by collectWords and lower-cased. Lookups are tried in
// this order, first hit wins:
//  1. the phrase joined with spaces;
//  2. the phrase joined without spaces (CJK text has no word spaces);
//  3. the part before the first " and ", " or " or ", " (first conjunct);
//  4. substring fallback: any entity key that contains the phrase or is
//     contained in it.
//
// The substring fallback is deterministic: among all matching keys the
// longest one is chosen, ties broken by the lexicographically smallest key.
// (Ranging over the map and returning the first hit would make the result
// depend on Go's unspecified map iteration order.) Empty phrases and empty
// keys never match, because the empty string is a substring of everything.
func findEntityInSubtree(idx int, tokens []DepToken, emap map[string][]Entity) *Entity {
	words := collectWords(idx, tokens, map[int]bool{})
	if len(words) == 0 {
		return nil
	}
	key := strings.ToLower(strings.TrimSpace(strings.Join(words, " ")))
	if key == "" {
		return nil
	}
	if ent := findBestEntity(key, emap); ent != nil {
		return ent
	}
	// For CJK (no spaces), also try joining without spaces
	noSpace := strings.ToLower(strings.TrimSpace(strings.Join(words, "")))
	if noSpace != key {
		if ent := findBestEntity(noSpace, emap); ent != nil {
			return ent
		}
	}
	// Coordinated phrases: fall back to the first conjunct.
	for _, sep := range []string{" and ", " or ", ", "} {
		if strings.Contains(key, sep) {
			candidate := strings.TrimSpace(strings.SplitN(key, sep, 2)[0])
			if ent := findBestEntity(candidate, emap); ent != nil {
				return ent
			}
		}
	}
	// Fuzzy: substring match, picking the best key deterministically.
	bestKey := ""
	for ek, ev := range emap {
		if ek == "" || len(ev) == 0 {
			continue
		}
		if !strings.Contains(ek, key) && !strings.Contains(key, ek) {
			continue
		}
		if bestKey == "" || len(ek) > len(bestKey) || (len(ek) == len(bestKey) && ek < bestKey) {
			bestKey = ek
		}
	}
	if bestKey == "" {
		return nil
	}
	e := emap[bestKey][0]
	return &e
}
// collectWords gathers the word texts of the subtree rooted at idx.
//
// A token labelled prep, punct, det, aux, auxpass, cc, conj or neg
// contributes nothing and its whole subtree is dropped. Children are visited
// in ascending Index order. The token's own text is placed after the words
// of its children when at least one child precedes it in the sentence, and
// before them otherwise. visited guards against cycles in the head links.
func collectWords(idx int, tokens []DepToken, visited map[int]bool) []string {
	if visited[idx] {
		return nil
	}
	visited[idx] = true

	node := tokens[idx]
	switch node.Dep {
	case "prep", "punct", "det", "aux", "auxpass", "cc", "conj", "neg":
		return nil
	}

	children := childrenOf(idx, tokens)
	sortByIndex(children)
	var below []string
	for _, child := range children {
		below = append(below, collectWords(child.Index, tokens, visited)...)
	}

	// children is sorted, so the first one is the only candidate for
	// "some child comes before this token".
	if len(children) > 0 && children[0].Index < idx {
		return append(below, node.Text)
	}
	return append([]string{node.Text}, below...)
}
// sortByIndex sorts ts in place by ascending Index using a simple exchange
// sort (the slices passed in are the children of a single token).
func sortByIndex(ts []DepToken) {
	n := len(ts)
	for lo := 0; lo < n; lo++ {
		for probe := lo + 1; probe < n; probe++ {
			if ts[probe].Index >= ts[lo].Index {
				continue
			}
			ts[lo], ts[probe] = ts[probe], ts[lo]
		}
	}
}
// dedupRelations keeps the first occurrence of every relation and drops
// later ones with the same predicate between the same two entity texts,
// compared case-insensitively and in either direction (A-p-B and B-p-A
// count as duplicates). Input order is preserved.
func dedupRelations(rels []Relation) []Relation {
	kept := map[string]bool{}
	var unique []Relation
	for _, rel := range rels {
		forward := strings.ToLower(rel.Subject.Text + "|" + rel.Predicate + "|" + rel.Object.Text)
		backward := strings.ToLower(rel.Object.Text + "|" + rel.Predicate + "|" + rel.Subject.Text)
		if !kept[forward] && !kept[backward] {
			kept[forward] = true
			unique = append(unique, rel)
		}
	}
	return unique
}

View File

@@ -0,0 +1,414 @@
//
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package extractor provides NER and relation extraction for the ingestion
// pipeline. It wraps the C++ ThincNER engine via cgo and supplements it with
// pure-Go regex-based relation extraction.
//
// The architecture mirrors the Python rag/graphrag/ner package so that both
// code paths produce identical output (verified by test).
package extractor
// #cgo CXXFLAGS: -std=c++20 -I${SRCDIR}/../../..
// #cgo linux LDFLAGS: ${SRCDIR}/../../../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread -lpcre2-8
// #cgo darwin LDFLAGS: ${SRCDIR}/../../../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread -lpcre2-8
//
// #include <stdlib.h>
// #include "../../../cpp/rag_analyzer_c_api.h"
// #include "../../../cpp/thinc_parser.h"
import "C"
import (
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"sync"
"unsafe"
)
// Entity represents an extracted named entity.
type Entity struct {
Text string `json:"text"` // surface text of the entity
Label string `json:"label"` // NER label (e.g. a spaCy label such as "ORG")
StartChar int `json:"start_char"` // start offset in the source text
EndChar int `json:"end_char"` // end offset in the source text
Confidence float64 `json:"confidence"` // extraction confidence score
AppType string `json:"app_type,omitempty"` // application-level type; presumably derived via spacyToAppType — TODO confirm
Metadata map[string]interface{} `json:"metadata,omitempty"` // optional free-form details
}
// Relation represents a typed relation between two entities.
type Relation struct {
Subject Entity `json:"subject"` // entity the relation starts from
Predicate string `json:"predicate"` // relation type, e.g. "works_for" or "related_to"
Object Entity `json:"object"` // entity the relation points to
Confidence float64 `json:"confidence"` // confidence score of the relation
Context string `json:"context,omitempty"` // optional supporting text
Metadata map[string]interface{} `json:"metadata,omitempty"` // extraction details such as "method" and "verb"
}
// ExtractionResult holds the output of a full extraction pass.
type ExtractionResult struct {
Entities []Entity `json:"entities"` // extracted entities
Relations []Relation `json:"relations"` // extracted relations
Language string `json:"language,omitempty"` // language code of the pass
Metadata map[string]interface{} `json:"metadata,omitempty"` // optional extra output (e.g. token info)
}
// Extractor provides NER + relation extraction
// Create instances with NewExtractor, which fills in the defaults.
type Extractor struct {
// mu is a mutex owned by the extractor; what it guards is not visible in
// this part of the file.
mu sync.Mutex
// Language code (en/zh/de/fr/es/pt/ja)
Lang string
// Minimum confidence to include an entity (default 0.0 = all)
ConfidenceThreshold float64
// Include token-level info (POS, dep) in ExtractionResult metadata
IncludeTokens bool
// Max character distance for co-occurrence relations (default 100)
MaxDistance int
}
// spacyToAppType maps a spaCy NER label to the application entity type stored
// in Entity.AppType. Labels missing from this table fall back to the
// lower-cased label (see ExtractEntities).
var spacyToAppType = map[string]string{
"PERSON": "person",
"ORG": "organization",
"GPE": "geo",
"LOC": "geo",
"FAC": "geo",
"EVENT": "event",
"PRODUCT": "category",
"DATE": "event",
"TIME": "event",
"MONEY": "category",
"QUANTITY": "category",
"PERCENT": "category",
"LAW": "category",
"NORP": "category",
"LANGUAGE": "category",
"WORK_OF_ART": "category",
}
// skipLabels lists NER labels whose entities ExtractEntities discards
// outright (purely numeric mentions).
var skipLabels = map[string]bool{
"ORDINAL": true,
"CARDINAL": true,
}
// ModelPredictor runs NER prediction for one loaded model: it takes a JSON
// array of token strings and returns the engine's JSON result.
// It is implemented as a closure over the C handle (see getPredictor) so the
// handle never has to be stored in a Go-visible field.
type ModelPredictor func(tokensJSON string) (string, error)
// Process-wide cache of predictors, keyed by model directory, so each model
// is loaded by the native engine at most once. modelCacheMu guards modelCache.
var (
modelCacheMu sync.Mutex
modelCache = map[string]ModelPredictor{}
)
// langModel maps language codes to spaCy model names. It also defines the set
// of supported languages: NewExtractor replaces any code not listed here
// with "en".
var langModel = map[string]string{
"en": "en_core_web_sm",
"zh": "zh_core_web_sm",
"de": "de_core_news_sm",
"fr": "fr_core_news_sm",
"es": "es_core_news_sm",
"pt": "pt_core_news_sm",
"ja": "ja_core_news_sm",
}
// langFallback maps languages without dedicated regex relation patterns to the
// language whose patterns are used instead. It only affects the regex path;
// dependency-based extraction still receives the original language code.
var langFallback = map[string]string{
"de": "en",
"fr": "en",
"es": "en",
"pt": "en",
"ja": "zh",
}
// NewExtractor builds an Extractor for the given language code.
//
// Supported codes are the keys of langModel (en, zh, de, fr, es, pt, ja).
// Any other value, including the empty string, silently selects "en".
// The returned extractor keeps every entity (threshold 0.0) and uses a
// co-occurrence distance of 100.
func NewExtractor(lang string) *Extractor {
	// The empty string is not a langModel key, so one lookup covers both the
	// "missing" and the "unsupported" case.
	if _, supported := langModel[lang]; !supported {
		lang = "en"
	}
	ext := new(Extractor)
	ext.Lang = lang
	ext.ConfidenceThreshold = 0.0 // include all by default
	ext.MaxDistance = 100
	return ext
}
// Extract runs NER on text and, when extractRelations is true and at least two
// entities were found, relation extraction as well.
//
// The result's Metadata always contains "n_entities" and "model". When
// IncludeTokens is set and tokenization succeeds, "tokens" (text and index per
// token) and "n_tokens" are added. When relations are extracted,
// "n_relations" counts only the typed ones, i.e. every relation whose
// predicate is not "related_to".
//
// An error is returned only when entity extraction fails.
func (e *Extractor) Extract(text string, extractRelations bool) (*ExtractionResult, error) {
	entities, err := e.ExtractEntities(text)
	if err != nil {
		return nil, err
	}

	meta := map[string]interface{}{
		"n_entities": len(entities),
		"model":      langModel[e.Lang],
	}
	result := &ExtractionResult{
		Entities: entities,
		Language: e.Lang,
		Metadata: meta,
	}

	// Optional token listing; any tokenizer or JSON failure simply omits it.
	if e.IncludeTokens {
		if tokensJSON := tokenizeText(text, e.Lang); tokensJSON != "" {
			var rawTokens []string
			if json.Unmarshal([]byte(tokensJSON), &rawTokens) == nil && len(rawTokens) > 0 {
				tokensMeta := make([]map[string]interface{}, 0, len(rawTokens))
				for idx, tok := range rawTokens {
					tokensMeta = append(tokensMeta, map[string]interface{}{
						"text":  tok,
						"index": idx,
					})
				}
				meta["n_tokens"] = len(tokensMeta)
				meta["tokens"] = tokensMeta
			}
		}
	}

	// A relation needs two endpoints, so fewer than two entities means no work.
	if !extractRelations || len(entities) < 2 {
		return result, nil
	}

	result.Relations = e.extractRelations(text, entities)
	typed := 0
	for i := range result.Relations {
		if result.Relations[i].Predicate != "related_to" {
			typed++
		}
	}
	meta["n_relations"] = typed
	return result, nil
}
// extractRelations returns relations between the given entities.
//
// It first tokenizes the text and runs the C++ dependency parser; if that
// yields at least one relation via DepExtractRelations, those are returned.
// In every other case (tokenization failure, parser error, no dependency
// relations) it falls back to the regex extractor.
//
// The dependency path uses e.Lang itself so that de/fr/es/pt/ja get their own
// rules; only the regex fallback goes through langFallback.
func (e *Extractor) extractRelations(text string, entities []Entity) []Relation {
	regexLang := e.Lang
	if fallback, hasFallback := langFallback[e.Lang]; hasFallback {
		regexLang = fallback
	}
	regexOnly := func() []Relation {
		return extractRelationsWithOpts(text, entities, regexLang, e.MaxDistance)
	}

	tokensJSON := tokenizeText(text, e.Lang)
	if tokensJSON == "" {
		return regexOnly()
	}
	var tokens []string
	if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil || len(tokens) == 0 {
		return regexOnly()
	}

	modelDir := e.findModelDir()
	deps, err := ParseTokensWithParser(modelDir+"/ner", modelDir+"/parser", tokens)
	if err != nil || len(deps) == 0 {
		return regexOnly()
	}

	depTokens := make([]DepToken, 0, len(deps))
	for _, d := range deps {
		depTokens = append(depTokens, DepToken{Text: d.Text, Head: d.Head, Dep: d.Dep, Index: d.Index})
	}
	if rels := DepExtractRelations(text, depTokens, entities, e.Lang, e.MaxDistance); len(rels) > 0 {
		return rels
	}

	// Fallback: regex-based extraction
	return regexOnly()
}
// getPredictor returns the cached ModelPredictor for modelDir, creating the
// native ThincNER handle (from modelDir/ner and modelDir/vocab) on first use.
//
// Predictors live in the process-wide modelCache and are shared by every
// Extractor that uses the same model directory. Each predictor therefore
// carries its own mutex to serialise Predict calls on its handle; it does not
// borrow the mutex of whichever Extractor happened to create it, which would
// keep that Extractor reachable for the lifetime of the process.
//
// If the handle cannot be created, a predictor that always fails is returned
// and nothing is cached, so a later call can retry.
func (e *Extractor) getPredictor(modelDir string) ModelPredictor {
	modelCacheMu.Lock()
	defer modelCacheMu.Unlock()
	if cached, ok := modelCache[modelDir]; ok {
		return cached
	}

	cModelDir := C.CString(modelDir + "/ner")
	cModelVocab := C.CString(modelDir + "/vocab")
	handle := C.ThincNER_Create(cModelDir, cModelVocab)
	C.free(unsafe.Pointer(cModelDir))
	C.free(unsafe.Pointer(cModelVocab))

	// Don't cache a nil handle — return a one-shot error predictor instead.
	if handle == nil {
		return func(tokensJSON string) (string, error) {
			return "", fmt.Errorf("ThincNER handle is nil for model dir: %s", modelDir)
		}
	}

	// One lock per native handle, shared by every caller of this predictor.
	var predictMu sync.Mutex
	predictor := func(tokensJSON string) (string, error) {
		cTokensJSON := C.CString(tokensJSON)
		defer C.free(unsafe.Pointer(cTokensJSON))

		predictMu.Lock()
		cResult := C.ThincNER_Predict(handle, cTokensJSON)
		predictMu.Unlock()

		if cResult == nil {
			return "", fmt.Errorf("NER prediction failed")
		}
		defer C.ThincNER_FreeString(cResult)
		return C.GoString(cResult), nil
	}
	modelCache[modelDir] = predictor
	return predictor
}
// ExtractEntities tokenizes text, runs the C++ ThincNER model for e.Lang and
// converts the JSON result into Entity values.
//
// Entities are dropped when their label is in skipLabels or their confidence
// is below e.ConfidenceThreshold. The remaining ones are de-duplicated on
// (lower-cased text, start offset), matching the Python NERExtractor. For zh
// and ja, spaces are removed from the entity text first, because the BILUO
// decoder joins tokens with spaces.
//
// An error is returned when tokenization or prediction fails, or when the
// engine's output is not valid JSON.
func (e *Extractor) ExtractEntities(text string) ([]Entity, error) {
	tokensJSON := tokenizeText(text, e.Lang)
	if tokensJSON == "" {
		return nil, fmt.Errorf("tokenization failed")
	}

	predict := e.getPredictor(e.findModelDir())
	resultJSON, err := predict(tokensJSON)
	if err != nil {
		return nil, err
	}

	// Shape of one element of the engine's JSON array.
	type nerSpan struct {
		Text       string  `json:"text"`
		Label      string  `json:"label"`
		Start      int     `json:"start"`
		End        int     `json:"end"`
		Confidence float64 `json:"confidence"`
	}
	var spans []nerSpan
	if err := json.Unmarshal([]byte(resultJSON), &spans); err != nil {
		return nil, fmt.Errorf("failed to parse NER result: %w", err)
	}

	stripSpaces := false
	switch e.Lang {
	case "zh", "ja":
		stripSpaces = true
	}

	seen := make(map[string]struct{}, len(spans))
	entities := make([]Entity, 0, len(spans))
	for _, span := range spans {
		if skipLabels[span.Label] || span.Confidence < e.ConfidenceThreshold {
			continue
		}

		surface := span.Text
		if stripSpaces {
			surface = strings.ReplaceAll(surface, " ", "")
		}

		dedupKey := strings.ToLower(surface) + "|" + strconv.Itoa(span.Start)
		if _, dup := seen[dedupKey]; dup {
			continue
		}
		seen[dedupKey] = struct{}{}

		appType := spacyToAppType[span.Label]
		if appType == "" {
			appType = strings.ToLower(span.Label)
		}

		entities = append(entities, Entity{
			Text:       surface,
			Label:      span.Label,
			StartChar:  span.Start,
			EndChar:    span.End,
			Confidence: span.Confidence,
			AppType:    appType,
			Metadata:   map[string]interface{}{"source": "thincner"},
		})
	}
	return entities, nil
}
// findModelDir resolves the directory of the spaCy model for e.Lang.
//
// The default location /usr/share/infinity/resource/spacy/<model> wins when it
// exists. Otherwise a non-empty SPACY_MODEL_DIR environment variable is
// returned as-is, and if that is unset the default path is returned even
// though it does not exist.
func (e *Extractor) findModelDir() string {
	modelName := langModel[e.Lang]
	if modelName == "" {
		modelName = "en_core_web_sm"
	}
	defaultDir := "/usr/share/infinity/resource/spacy/" + modelName

	if !dirExists(defaultDir) {
		if override := getenv("SPACY_MODEL_DIR"); override != "" {
			return override
		}
	}
	return defaultDir
}
// dirExists reports whether path can be stat'ed and refers to a directory.
// Any stat error (including "not found") counts as "does not exist".
func dirExists(path string) bool {
	info, err := os.Stat(path)
	if err != nil {
		return false
	}
	return info.IsDir()
}
// tokenizeText tokenizes text with the C++ tokenizer for the given language
// code and returns the result as a JSON array of token strings.
// It returns "" when the native tokenizer returns NULL.
func tokenizeText(text, lang string) string {
// C copies of the Go strings; released when this function returns.
cText := C.CString(text)
cLang := C.CString(lang)
defer C.free(unsafe.Pointer(cText))
defer C.free(unsafe.Pointer(cLang))
cTokens := C.ThincNER_Tokenize(cText, cLang)
if cTokens == nil {
return ""
}
// The result is released through the library's own free function, after
// it has been copied into a Go string below.
defer C.ThincNER_FreeString(cTokens)
return C.GoString(cTokens)
}
func getenv(key string) string {
return os.Getenv(key)
}
// DetectLanguage guesses the language of text from the Unicode blocks of its
// characters. Pure Go, zero dependencies.
//
// Only Han, Hiragana, Katakana and ASCII Latin letters are counted. If more
// than 30% of the counted characters are CJK, the result is "ja" when kana
// outnumber Han characters and "zh" otherwise. Everything else, including
// text with no counted characters at all, yields "en": Latin-script languages
// are not told apart, so callers pass de/fr/es/pt explicitly.
func DetectLanguage(text string) string {
	var hanCount, kanaCount, latinCount int
	for _, r := range text {
		if isHan(r) {
			hanCount++
		} else if isHiragana(r) || isKatakana(r) {
			kanaCount++
		} else if isLatin(r) {
			latinCount++
		}
	}

	cjkCount := hanCount + kanaCount
	total := cjkCount + latinCount
	if total == 0 {
		return "en"
	}
	if float64(cjkCount)/float64(total) <= 0.3 {
		return "en"
	}
	if kanaCount > hanCount {
		return "ja" // Japanese-heavy
	}
	// CJK characters are present and kana do not dominate, so at least one
	// Han character was seen.
	return "zh"
}

// isHan reports whether r lies in the CJK Unified Ideographs block.
func isHan(r rune) bool { return 0x4E00 <= r && r <= 0x9FFF }

// isHiragana reports whether r lies in the Hiragana block.
func isHiragana(r rune) bool { return 0x3040 <= r && r <= 0x309F }

// isKatakana reports whether r lies in the Katakana block.
func isKatakana(r rune) bool { return 0x30A0 <= r && r <= 0x30FF }

// isLatin reports whether r is an ASCII letter.
func isLatin(r rune) bool { return ('A' <= r && r <= 'Z') || ('a' <= r && r <= 'z') }

View File

@@ -0,0 +1,30 @@
//
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package extractor provides NER and relation extraction for the ingestion
// pipeline. It mirrors the Python rag/graphrag/ner package so that both
// code paths produce identical output.
//
// The C++ ThincNER engine (internal/cpp/) loads spaCy model.ckpt+model.bin
// directly for NER inference. Relation extraction first tries the C++
// dependency parser and falls back to pure-Go regex patterns.
//
// Usage:
//
// ext := extractor.NewExtractor("en")
// result, err := ext.Extract("Apple Inc. was founded by Steve Jobs.", true)
// for _, e := range result.Entities { ... }
// for _, r := range result.Relations { ... }
package extractor

View File

@@ -0,0 +1,439 @@
//
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package extractor
import (
"regexp"
"strings"
)
// Building blocks for the English relation patterns below (the Python
// counterpart is MULTILANG_RELATION_PATTERNS).
//
// _entWord matches one entity word: an ASCII letter of either case followed by
// word characters or apostrophes, optional dot-separated segments
// (e.g. "U.S.", "J.K.") and an optional trailing period (e.g. "Inc.", "Corp.").
// _relEntity captures one or more whitespace-separated entity words, lazily.
// _relEntity2 captures one or two entity words.
// Relation keywords in the patterns opt into case-insensitive matching with
// inline (?i:...) groups.
const _entWord = `[A-Za-z][\w']*(?:\.[A-Za-z][\w']*)*\.?`
const _relEntity = `(` + _entWord + `(?:\s+` + _entWord + `)*?)`
const _relEntity2 = `(` + _entWord + `(?:\s+` + _entWord + `){0,1})`
// relationPatterns holds the regex relation patterns per language code.
// Every pattern has exactly two capture groups, which the extractor reads as
// (subject, object) in that order; the paired string is the predicate emitted
// for a match. Only "en" and "zh" have patterns; other languages are routed
// here through langFallback or default to "en".
var relationPatterns = map[string][]relPatternEntry{
"en": {
{regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:founded)\s+(?i:by)\s+` + _relEntity2), "founded_by"},
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:an?\s+)?(?i:co-)?(?i:founder)\s+(?i:of)\s+` + _relEntity2), "founded_by"},
{regexp.MustCompile(_relEntity + `\s+(?i:works)\s+(?i:for)\s+` + _relEntity2), "works_for"},
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:an?\s+)?(?i:employee)\s+(?i:of)\s+` + _relEntity2), "works_for"},
{regexp.MustCompile(_relEntity + `\s+(?i:joined)\s+` + _relEntity2), "works_for"},
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:the\s+)?(?:CEO|CTO|CFO|VP|(?i:director|manager|engineer))\s+(?i:of|at)\s+` + _relEntity2), "works_for"},
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:located|based|headquartered|situated)\s+(?i:in)\s+` + _relEntity2), "located_in"},
{regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:born)\s+(?i:in|on)\s+` + _relEntity2), "born_in"},
{regexp.MustCompile(_relEntity + `\s+(?i:born)\s+(?i:in|on)\s+` + _relEntity2), "born_in"},
{regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:acquired)\s+(?i:by)\s+` + _relEntity2), "acquired"},
{regexp.MustCompile(_relEntity + `\s+(?i:acquired)\s+` + _relEntity2), "acquired"},
// "X is the CEO of Y" also matches the works_for title pattern above, so
// such a sentence produces both a works_for and a ceo_of relation.
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:the\s+)?(?i:CEO)\s+(?i:of)\s+` + _relEntity2), "ceo_of"},
},
// Chinese patterns bound entity length explicitly ({2,4}, {2,6}, {2,10})
// instead of relying on word boundaries.
"zh": {
{regexp.MustCompile(`([\p{Han}\w]{2,6})\s*由\s*([\p{Han}\w]{2,4})\s*(?:创立|创建|成立|创办)`), "founded_by"},
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:创立|创建|成立|创办)(?:\s*了\s*)?([\p{Han}\w]{2,10})`), "founded_by"},
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:是\s*)?([\p{Han}\w]{2,10})\s*(?:创始人|联合创始人)`), "founded_by"},
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:任职于|供职于|工作于|就职于)\s*([\p{Han}\w]{2,10})`), "works_for"},
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:是\s*)?([\p{Han}\w]{2,10})\s*(?:的员工|的雇员)`), "works_for"},
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*(?:位于|坐落于|总部设在|总部位于)\s*([\p{Han}\w]{2,6})`), "located_in"},
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*在\s*([\p{Han}\w]{2,6})`), "located_in"},
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:出生于|生于)\s*([\p{Han}\w]{2,6})`), "born_in"},
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*(?:收购|并购)\s*([\p{Han}\w]{2,10})`), "acquired"},
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*被\s*([\p{Han}\w]{2,10})\s*(?:收购|并购)`), "acquired"},
},
}
// relPatternEntry pairs a compiled relation regex with the predicate that a
// match of that regex produces.
type relPatternEntry struct {
pattern *regexp.Regexp
predicate string
}
// ExtractRelations extracts typed relations between entities using the regex
// patterns for lang, mirroring the pattern-based Python RelationExtractor
// (including its sentence-boundary filtering). Co-occurrence relations use the
// default maximum distance of 100.
func ExtractRelations(text string, entities []Entity, lang string) []Relation {
	const defaultMaxDistance = 100
	return extractRelationsWithOpts(text, entities, lang, defaultMaxDistance)
}
// extractRelationsWithOpts is the internal relation extractor with a
// configurable co-occurrence distance.
//
// It works in three steps:
//  1. Pattern-based typed relations. When any entity carries offsets, the text
//     is split into sentences and each pattern is applied per sentence, so a
//     match can never span a sentence boundary; otherwise the whole text is
//     scanned as one region.
//  2. Co-occurrence "related_to" relations from extractCooccurrence.
//  3. Multi-hop inference and de-duplication (always applied, as in Python).
//
// lang selects the pattern set; unknown languages use the English patterns.
func extractRelationsWithOpts(text string, entities []Entity, lang string, maxDistance int) []Relation {
	patterns, ok := relationPatterns[lang]
	if !ok {
		patterns = relationPatterns["en"]
	}

	// Multimap: lower-cased entity text → all occurrences, so duplicate names
	// at different positions stay distinguishable. A punctuation-stripped
	// alias is added when it differs from the raw text.
	entityMultiMap := make(map[string][]Entity, len(entities))
	hasOffsets := false
	for _, e := range entities {
		key := strings.ToLower(e.Text)
		entityMultiMap[key] = append(entityMultiMap[key], e)

		cleaned := strings.TrimSpace(strings.TrimRight(e.Text, ".,;:!?"))
		if cleaned != e.Text {
			ckey := strings.ToLower(cleaned)
			entityMultiMap[ckey] = append(entityMultiMap[ckey], e)
		}

		if e.StartChar != 0 || e.EndChar != 0 {
			hasOffsets = true
		}
	}

	// Regions scanned by the patterns: sentences when offsets are available,
	// otherwise (or when splitting yields nothing) the full text.
	var regions [][2]int
	if hasOffsets {
		regions = splitSentences(text)
	}
	if len(regions) == 0 {
		regions = [][2]int{{0, len(text)}}
	}

	seen := make(map[string]bool)
	var relations []Relation

	// Phase 1: pattern-based typed relations (pattern-major order).
	for _, entry := range patterns {
		for _, region := range regions {
			relations = appendPatternRelations(relations, seen, text, region, entry, entityMultiMap)
		}
	}

	// Phase 2: co-occurrence (standalone, with sentence boundary check)
	for _, r := range extractCooccurrence(text, entities, maxDistance) {
		key := r.Subject.Text + "|related_to|" + r.Object.Text
		if !seen[key] {
			seen[key] = true
			relations = append(relations, r)
		}
	}

	// Multi-hop inference + dedup (matching Python always applies these)
	relations = inferMultiHop(relations)
	relations = dedupRelations(relations)
	return relations
}

// appendPatternRelations applies one pattern to text[region[0]:region[1]] and
// appends a relation for every match whose two capture groups both resolve to
// known entities and whose (subject, predicate, object) key is not in seen.
func appendPatternRelations(relations []Relation, seen map[string]bool, text string, region [2]int, entry relPatternEntry, entityMultiMap map[string][]Entity) []Relation {
	base := region[0]
	for _, m := range entry.pattern.FindAllStringSubmatchIndex(text[region[0]:region[1]], -1) {
		// m holds [match, group1, group2] index pairs; -1 marks an unmatched group.
		if len(m) < 6 || m[2] < 0 || m[4] < 0 {
			continue
		}
		// Translate region-relative indices to positions in the full text.
		subjStart, subjEnd := base+m[2], base+m[3]
		objStart, objEnd := base+m[4], base+m[5]

		subj := findEntityByText(strings.TrimSpace(text[subjStart:subjEnd]), subjStart, entityMultiMap)
		obj := findEntityByText(strings.TrimSpace(text[objStart:objEnd]), objStart, entityMultiMap)
		if subj.Text == "" || obj.Text == "" {
			continue
		}

		key := subj.Text + "|" + entry.predicate + "|" + obj.Text
		if seen[key] {
			continue
		}
		seen[key] = true

		relations = append(relations, Relation{
			Subject:    subj,
			Predicate:  entry.predicate,
			Object:     obj,
			Confidence: 0.8,
			// Context is cut around the actual match position; searching for
			// the matched string again could land on an earlier occurrence.
			Context: contextWindow(text, base+m[0], base+m[1], 30),
		})
	}
	return relations
}

// contextWindow returns text[start-pad : end+pad], clamped to the text bounds.
// The window is measured in bytes, so a multi-byte character cut at either
// edge is dropped to keep the result valid UTF-8.
func contextWindow(text string, start, end, pad int) string {
	from := start - pad
	if from < 0 {
		from = 0
	}
	to := end + pad
	if to > len(text) {
		to = len(text)
	}
	return strings.ToValidUTF8(text[from:to], "")
}
// extractCooccurrence emits a "related_to" relation (confidence 0.4) for every
// ordered pair of entities (i < j) that start in the same sentence and whose
// gap |start(j) - end(i)| is at most maxDistance.
//
// Sentence filtering only applies when at least one entity carries offsets;
// without offsets, or when the text yields no sentence spans, every pair is
// considered to share a sentence. Fewer than two entities yields nil.
func extractCooccurrence(text string, entities []Entity, maxDistance int) []Relation {
	if len(entities) < 2 {
		return nil
	}

	// Sentence spans are only computed once some entity has real offsets.
	var spans [][2]int
	for idx := range entities {
		if entities[idx].StartChar != 0 || entities[idx].EndChar != 0 {
			spans = splitSentences(text)
			break
		}
	}

	inOneSentence := func(a, b int) bool {
		if len(spans) == 0 {
			return true
		}
		for _, sp := range spans {
			lo, hi := sp[0], sp[1]
			if a >= lo && a < hi && b >= lo && b < hi {
				return true
			}
		}
		return false
	}

	var pairs []Relation
	for i, first := range entities[:len(entities)-1] {
		for _, second := range entities[i+1:] {
			if !inOneSentence(first.StartChar, second.StartChar) {
				continue
			}
			if abs(second.StartChar-first.EndChar) > maxDistance {
				continue
			}
			pairs = append(pairs, Relation{
				Subject:    first,
				Predicate:  "related_to",
				Object:     second,
				Confidence: 0.4,
				Context:    extractContextSimple(text, first, second),
				Metadata:   map[string]interface{}{"method": "cooccurrence"},
			})
		}
	}
	return pairs
}
// findEntityByText resolves a regex capture to a known entity, preferring the
// occurrence closest to matchPos when the name appears more than once.
//
// Lookups are attempted in this order, returning the first hit:
//  1. the capture with surrounding space and trailing punctuation removed;
//  2. the part before the first " and ", " or " or ", ";
//  3. progressively shorter word prefixes (dropping words from the right),
//     e.g. "Google in" → "Google".
//
// The zero Entity is returned when nothing matches.
func findEntityByText(raw string, matchPos int, entityMultiMap map[string][]Entity) Entity {
	const trailingPunct = ".,;:!?"

	name := strings.TrimSpace(raw)
	for name != "" && strings.IndexByte(trailingPunct, name[len(name)-1]) >= 0 {
		name = strings.TrimSpace(name[:len(name)-1])
	}
	if hit := findClosest(name, matchPos, entityMultiMap); hit.Text != "" {
		return hit
	}

	// Captures such as "Larry Page and Sergey Brin": keep the first conjunct.
	lowered := strings.ToLower(name)
	for _, sep := range [...]string{" and ", " or ", ", "} {
		cut := strings.Index(lowered, sep)
		if cut <= 0 {
			continue
		}
		if hit := findClosest(lowered[:cut], matchPos, entityMultiMap); hit.Text != "" {
			return hit
		}
	}

	// Right-to-left word stripping, with punctuation removed from each candidate.
	words := strings.Fields(lowered)
	for n := len(words) - 1; n >= 1; n-- {
		candidate := strings.TrimSpace(strings.TrimRight(strings.Join(words[:n], " "), trailingPunct))
		if hit := findClosest(candidate, matchPos, entityMultiMap); hit.Text != "" {
			return hit
		}
	}
	return Entity{}
}
// findClosest looks name up (case-insensitively) in multiMap and returns the
// occurrence whose span midpoint is nearest to matchPos. If the exact name is
// unknown, the lookup is retried with trailing punctuation removed. The zero
// Entity is returned when there is no occurrence; on ties the earliest entry
// in the list wins.
func findClosest(name string, matchPos int, multiMap map[string][]Entity) Entity {
	candidates := multiMap[strings.ToLower(name)]
	if len(candidates) == 0 {
		if trimmed := strings.TrimSpace(strings.TrimRight(name, ".,;:!?")); trimmed != name {
			candidates = multiMap[strings.ToLower(trimmed)]
		}
	}

	switch len(candidates) {
	case 0:
		return Entity{}
	case 1:
		return candidates[0]
	}

	// |start + end - 2*pos| is twice the distance from the span midpoint to
	// pos, which orders candidates without any division.
	score := func(ent Entity) int {
		return abs(ent.StartChar + ent.EndChar - 2*matchPos)
	}
	winner, winnerScore := candidates[0], score(candidates[0])
	for _, cand := range candidates[1:] {
		if s := score(cand); s < winnerScore {
			winner, winnerScore = cand, s
		}
	}
	return winner
}
func extractContext(text string, matchStr string) string {
idx := strings.Index(text, matchStr)
if idx < 0 {
return ""
}
start := idx - 30
if start < 0 {
start = 0
}
end := idx + len(matchStr) + 30
if end > len(text) {
end = len(text)
}
return text[start:end]
}
// extractContextSimple returns the stretch of text that covers both entities
// plus up to 20 bytes on each side.
//
// Entity offsets are treated as byte indexes into text and clamped to its
// bounds. Offsets that lie entirely outside the text (for example entities
// that were extracted from a different string) yield "" instead of a slice
// panic. Multi-byte characters cut at the window edges are dropped so the
// result is valid UTF-8.
// NOTE(review): if the engine reports rune offsets rather than byte offsets,
// the window will be misplaced for non-ASCII text — confirm.
func extractContextSimple(text string, e1, e2 Entity) string {
	start := min(e1.StartChar, e2.StartChar) - 20
	if start < 0 {
		start = 0
	}
	end := max(e1.EndChar, e2.EndChar) + 20
	if end > len(text) {
		end = len(text)
	}
	if start >= end {
		return ""
	}
	return strings.ToValidUTF8(text[start:end], "")
}
// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
// splitSentences splits text into consecutive sentence spans [start, end)
// expressed as byte offsets. The spans cover the whole text without gaps;
// whitespace after a sentence end belongs to the following span.
//
// It approximates Python's re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text)
// without lookahead (unsupported by RE2):
//   - '!' and '?' always end a sentence.
//   - '.' ends a sentence when only whitespace follows up to the end of the
//     text, or when it is followed by at least one whitespace character and
//     then an ASCII uppercase letter.
//   - Any other '.' is treated as part of an abbreviation or token: "Inc. was"
//     (lowercase follows) and "U.S" (no whitespace between the initials) do
//     not split.
//
// Empty text yields no spans.
func splitSentences(text string) [][2]int {
	isSpace := func(c byte) bool {
		return c == ' ' || c == '\t' || c == '\n' || c == '\r'
	}

	var spans [][2]int
	start := 0
	// Byte-wise scanning is safe: '.', '!' and '?' are ASCII and never occur
	// inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(text); i++ {
		switch text[i] {
		case '!', '?':
			spans = append(spans, [2]int{start, i + 1})
			start = i + 1
		case '.':
			end := i + 1
			next := end
			for next < len(text) && isSpace(text[next]) {
				next++
			}
			sentenceEnd := false
			if next >= len(text) {
				// Nothing but whitespace after the period.
				sentenceEnd = true
			} else if next > end && text[next] >= 'A' && text[next] <= 'Z' {
				// Period + whitespace + uppercase letter.
				sentenceEnd = true
			}
			if sentenceEnd {
				spans = append(spans, [2]int{start, end})
				start = end
			}
		}
	}
	// Remaining text after the last sentence boundary.
	if start < len(text) {
		spans = append(spans, [2]int{start, len(text)})
	}
	return spans
}

View File

@@ -0,0 +1,113 @@
//
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package extractor
import "testing"
// ---------------------------------------------------------------------------
// Test data — 17 English test cases (ground truth from Python+spaCy)
// ---------------------------------------------------------------------------
// EnTestSpec describes one English extraction test case: the input text and
// the minimum set of entities and typed relations expected from it. Both
// expectation lists are lower bounds — extra results are allowed.
type EnTestSpec struct {
name string
text string
wantEntities [][2]string // (text, label) pairs that MUST be found
wantRels []relSpec // typed relations that MUST be found
}
// relSpec is an expected relation given as subject text, predicate and
// object text.
type relSpec struct {
subj string
pred string
obj string
}
// enTests is the English test table. A nil wantEntities or wantRels means the
// case asserts nothing for that category; several cases deliberately record
// what the small English model actually produces rather than the ideal answer
// (see the inline notes).
var enTests = []EnTestSpec{
{name: "founded_by_simple", text: "Apple Inc. was founded by Steve Jobs.",
wantEntities: [][2]string{{"Steve Jobs", "PERSON"}},
wantRels: []relSpec{{"Apple Inc.", "founded_by", "Steve Jobs"}}},
{name: "founded_by_multi", text: "Google was founded by Larry Page and Sergey Brin.",
wantEntities: [][2]string{{"Larry Page", "PERSON"}, {"Sergey Brin", "PERSON"}},
wantRels: []relSpec{{"Google", "founded_by", "Larry Page"}}},
{name: "cofounder_of", text: "Elon Musk is a co-founder of Tesla.",
wantEntities: [][2]string{{"Elon Musk", "PERSON"}, {"Tesla", "ORG"}},
wantRels: []relSpec{{"Elon Musk", "founded_by", "Tesla"}}},
{name: "works_for_simple", text: "John works for Microsoft.",
wantEntities: [][2]string{{"John", "PERSON"}, {"Microsoft", "ORG"}},
wantRels: []relSpec{{"John", "works_for", "Microsoft"}}},
{name: "employee_of", text: "Mary is an employee of Google.",
wantEntities: [][2]string{{"Mary", "PERSON"}, {"Google", "ORG"}},
wantRels: []relSpec{{"Mary", "works_for", "Google"}}},
{name: "joined_company", text: "Sundar Pichai joined Google in 2004.",
wantEntities: [][2]string{{"Sundar Pichai", "PERSON"}, {"Google", "ORG"}},
wantRels: []relSpec{{"Sundar Pichai", "works_for", "Google"}}},
{name: "headquartered_in", text: "The company is headquartered in San Francisco.",
wantEntities: [][2]string{{"San Francisco", "GPE"}},
wantRels: nil},
{name: "based_in", text: "Microsoft is based in Redmond.",
wantEntities: [][2]string{{"Microsoft", "ORG"}, {"Redmond", "GPE"}},
wantRels: []relSpec{{"Microsoft", "located_in", "Redmond"}}},
{name: "born_in", text: "Albert Einstein was born in Germany.",
wantEntities: [][2]string{{"Albert Einstein", "PERSON"}, {"Germany", "GPE"}},
wantRels: []relSpec{{"Albert Einstein", "born_in", "Germany"}}},
{name: "ceo_of", text: "Sundar Pichai is the CEO of Google.",
wantEntities: [][2]string{{"Sundar Pichai", "PERSON"}, {"Google", "ORG"}},
wantRels: []relSpec{{"Sundar Pichai", "works_for", "Google"}, {"Sundar Pichai", "ceo_of", "Google"}}},
{name: "acquired_by", text: "Instagram was acquired by Facebook.",
wantEntities: nil, // en_core_web_sm doesn't tag these
wantRels: nil},
{name: "acquired_active", text: "Facebook acquired Instagram.",
wantEntities: [][2]string{{"Instagram", "PERSON"}}, // en_core_web_sm: Instagram→PERSON
wantRels: nil},
{name: "multi_founded_ceo", text: "Google was founded by Larry Page. Sundar Pichai is the CEO of Google.",
wantEntities: [][2]string{{"Larry Page", "PERSON"}, {"Sundar Pichai", "PERSON"}},
wantRels: nil},
{name: "multi_works_located", text: "John works for Microsoft. Microsoft is based in Redmond.",
wantEntities: [][2]string{{"John", "PERSON"}, {"Microsoft", "ORG"}, {"Redmond", "GPE"}},
wantRels: []relSpec{{"Microsoft", "located_in", "Redmond"}}},
{name: "no_entities", text: "The cat sat on the mat.",
wantEntities: nil,
wantRels: nil},
{name: "org_with_inc", text: "Microsoft Corporation was founded by Bill Gates.",
wantEntities: [][2]string{{"Bill Gates", "PERSON"}},
wantRels: []relSpec{{"Microsoft Corporation", "founded_by", "Bill Gates"}}},
{name: "located_city", text: "The restaurant is located in Paris.",
wantEntities: [][2]string{{"Paris", "GPE"}},
wantRels: nil},
}
// ---------------------------------------------------------------------------
// Fast unit tests (pure Go, no Python dependency)
// ---------------------------------------------------------------------------
// TestDetectLanguage checks the script-based language guess for Latin,
// Chinese and Japanese (mixed kana/kanji and kana-only) inputs.
func TestDetectLanguage(t *testing.T) {
	cases := []struct{ text, want string }{
		{text: "Hello world", want: "en"},
		{text: "你好世界", want: "zh"},
		{text: "こんにちは世界", want: "ja"},
		{text: "阿里巴巴由马云创立", want: "zh"},
		{text: "アップルは", want: "ja"},
	}
	for i := range cases {
		tc := cases[i]
		if got := DetectLanguage(tc.text); got != tc.want {
			t.Errorf("DetectLanguage(%q) = %q, want %q", tc.text, got, tc.want)
		}
	}
}

View File

@@ -0,0 +1,87 @@
package extractor
/*
#include <stdlib.h>
#include "../../../cpp/thinc_parser.h"
*/
import "C"
import (
"encoding/json"
"fmt"
"unsafe"
)
// DepTokenC is the JSON shape of one token as returned by the C++ dependency
// parser: the token text, the head reference, the dependency label and the
// token's own index. It is converted into DepToken by the caller.
// NOTE(review): Head is presumably the index of the head token — confirm
// against the C++ output format.
type DepTokenC struct {
Text string `json:"text"`
Head int `json:"head"`
Dep string `json:"dep"`
Index int `json:"index"`
}
// RunParser runs the C++ dependency parser on tokenized text.
// nerDir: path to the NER model directory.
// parserDir: path to the parser model directory.
// tokensJSON: JSON array of token strings.
// Returns the parser's JSON output (an array of DepTokenC objects), or an
// error when the parser handle cannot be created or prediction fails.
//
// A parser handle is created and destroyed on every call; nothing is cached.
func RunParser(nerDir, parserDir string, tokensJSON string) (string, error) {
// C copies of the arguments, released when this function returns.
cNer := C.CString(nerDir)
cParser := C.CString(parserDir)
cTokens := C.CString(tokensJSON)
defer C.free(unsafe.Pointer(cNer))
defer C.free(unsafe.Pointer(cParser))
defer C.free(unsafe.Pointer(cTokens))
handle := C.ThincParser_Create(cNer, cParser)
if handle == nil {
return "", fmt.Errorf("failed to create ThincParser handle")
}
defer C.ThincParser_Destroy(handle)
cResult := C.ThincParser_Predict(handle, cTokens)
if cResult == nil {
return "", fmt.Errorf("parser prediction failed")
}
// The result string is released through the library's own free function,
// after it has been copied into a Go string.
defer C.ThincParser_FreeString(cResult)
return C.GoString(cResult), nil
}
// RunTagger runs the C++ POS tagger on tokenized text.
// nerDir: path to NER model directory (for tok2vec weights).
// taggerDir: path to tagger model directory.
// tokensJSON: JSON array of token strings.
// Returns the tagger's JSON output, or an error when the tagger handle cannot
// be created or prediction fails.
//
// A tagger handle is created and destroyed on every call; nothing is cached.
func RunTagger(nerDir, taggerDir string, tokensJSON string) (string, error) {
// C copies of the arguments, released when this function returns.
cNer := C.CString(nerDir)
cTagger := C.CString(taggerDir)
cTokens := C.CString(tokensJSON)
defer C.free(unsafe.Pointer(cNer))
defer C.free(unsafe.Pointer(cTagger))
defer C.free(unsafe.Pointer(cTokens))
handle := C.ThincTagger_Create(cNer, cTagger)
if handle == nil {
return "", fmt.Errorf("failed to create ThincTagger handle")
}
defer C.ThincTagger_Destroy(handle)
cResult := C.ThincTagger_Predict(handle, cTokens)
if cResult == nil {
return "", fmt.Errorf("tagger prediction failed")
}
// The result string is released through the library's own free function,
// after it has been copied into a Go string.
defer C.ThincTagger_FreeString(cResult)
return C.GoString(cResult), nil
}
// ParseTokensWithParser serialises tokens to JSON, runs the C++ dependency
// parser on them via RunParser and decodes the result into DepTokenC values.
//
// It returns an error when the tokens cannot be encoded, when the parser
// fails, or when the parser's output is not valid JSON.
func ParseTokensWithParser(nerDir, parserDir string, tokens []string) ([]DepTokenC, error) {
	tokensJSON, err := json.Marshal(tokens)
	if err != nil {
		// Previously ignored; surfaced so a bad payload is never sent to the parser.
		return nil, fmt.Errorf("marshal tokens: %w", err)
	}
	resultJSON, err := RunParser(nerDir, parserDir, string(tokensJSON))
	if err != nil {
		return nil, err
	}
	var tokensC []DepTokenC
	if err := json.Unmarshal([]byte(resultJSON), &tokensC); err != nil {
		return nil, fmt.Errorf("parse result: %w", err)
	}
	return tokensC, nil
}

View File

@@ -1,4 +1,4 @@
//go:build cgo //go:build cgo && office
// //
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -1,4 +1,4 @@
//go:build cgo //go:build cgo && office
// //
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -1,4 +1,4 @@
//go:build !cgo //go:build !cgo || !office
package parser package parser

View File

@@ -1,4 +1,4 @@
//go:build cgo //go:build cgo && office
// //
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -1,4 +1,4 @@
//go:build cgo //go:build cgo && office
// //
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -1,4 +1,4 @@
//go:build cgo //go:build cgo && office
// //
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -1,4 +1,4 @@
//go:build cgo //go:build cgo && office
// //
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -39,9 +39,36 @@ pre-commit:
stage_fixed: true stage_fixed: true
- name: web-prettier - name: web-prettier
glob: "web/**/*.{css,less,json,js,jsx,ts,tsx}" glob: "web/**/*.{css,less,json,js,jsx,ts,tsx}"
run: npx --prefix web prettier --write --ignore-unknown {staged_files} run: |
# Ensure web/node_modules is populated. CI runners don't run `npm install`
# before lefthook, and `npx` without local node_modules fetches the
# latest packages from the registry — which breaks because the pinned
# prettier plugins (prettier-plugin-organize-imports,
# prettier-plugin-packagejson) aren't auto-resolved and ESLint 10
# requires eslint.config.js.
# Use a mkdir-based mutex to prevent parallel npm ci from multiple hooks
# racing (ETXTBSY error on esbuild). mkdir is atomic on both Linux and macOS.
LOCKDIR=/tmp/ragflow-web-npm-lock
while ! mkdir "$LOCKDIR" 2>/dev/null; do sleep 0.5; done
if [ ! -f web/node_modules/.package-lock.json ]; then
echo "==> web/node_modules missing or incomplete; running npm ci --prefix web"
rm -rf web/node_modules
npm ci --prefix web --no-audit --no-fund || { rm -rf "$LOCKDIR"; exit 1; }
fi
rm -rf "$LOCKDIR"
cd web && printf '%s\n' {staged_files} | sed 's|^web/||' | xargs npx prettier --write --ignore-unknown
stage_fixed: true stage_fixed: true
- name: web-eslint - name: web-eslint
glob: "web/**/*.{js,jsx,ts,tsx}" glob: "web/**/*.{js,jsx,ts,tsx}"
run: npx --prefix web eslint --fix {staged_files} run: |
# Same npm ci guard as web-prettier; mkdir mutex serialises concurrent installs.
LOCKDIR=/tmp/ragflow-web-npm-lock
while ! mkdir "$LOCKDIR" 2>/dev/null; do sleep 0.5; done
if [ ! -f web/node_modules/.package-lock.json ]; then
echo "==> web/node_modules missing or incomplete; running npm ci --prefix web"
rm -rf web/node_modules
npm ci --prefix web --no-audit --no-fund || { rm -rf "$LOCKDIR"; exit 1; }
fi
rm -rf "$LOCKDIR"
cd web && printf '%s\n' {staged_files} | sed 's|^web/||' | xargs npx eslint --fix
stage_fixed: true stage_fixed: true

View File

@@ -13,6 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from .graph_extractor import GraphExtractor from .ner_extractor import NERExtractor
from .dep_relation_extractor import DepRelationExtractor
from .types import Entity, ExtractionResult, Relation
__all__ = ["GraphExtractor"] __all__ = [
"NERExtractor",
"DepRelationExtractor",
"Entity",
"Relation",
"ExtractionResult",
]

View File

@@ -0,0 +1,558 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Dependency-based relation extractor — full semantica alignment.
Extracts typed relations using spaCy dependency parse with:
- Multi-hop inference (A→B→C transitivity)
- Negation filtering
- Dynamic confidence scoring
- Multi-occurrence entity matching
"""
from typing import Dict, List, Optional
from .types import Entity, Relation
# Language-specific dependency label mappings
# Keys: pass_subj, subj, agent, dobj, prep_obj — each maps to a dep label
# or a tuple (dep, child_dep) for compound patterns.
# None = no standard mapping (language uses different structure)
_LANG_DEP_RULES: Dict[str, Dict[str, object]] = {
"en": {"pass_subj": "nsubjpass", "subj": "nsubj",
"agent": ("agent", "pobj"),
"dobj": "dobj", "prep_obj": ("prep", "pobj")},
"de": {"subj": "sb",
"agent": ("sbp", "nk"),
"prep_obj": ("mo", "nk"),
"root_verb_child": "oc"}, # German ROOT is aux, real verb is "oc"
"fr": {"pass_subj": "nsubj:pass", "subj": "nsubj",
"agent": "obl:agent",
"dobj": "obj", "prep_obj": ("case", "obl")},
"es": {"subj": "nsubj",
"agent": "obj",
"prep_obj": ("case", "obl")},
"pt": {"pass_subj": "nsubj:pass", "subj": "nsubj",
"agent": "obl:agent",
"dobj": "obj", "prep_obj": ("case", "obl")},
"zh": {"subj": "nsubj",
"agent": ("nmod:prep", None, ""), # case "由" marks agent
"prep_obj": ("case", "nmod")},
"ja": {"subj": "nsubj",
"agent": ("obl", None, "によって"), # "によって" marks agent
"prep_obj": ("case", "obl")},
}
# Multi-hop inference rules: if A rel1 B and B rel2 C then A rel3 C
_MULTI_HOP: Dict[str, Dict[str, str]] = {
"ceo_of": {"is_subsidiary_of": "works_for", "located_in": "works_for"},
"works_for": {"is_subsidiary_of": "works_for"},
"founded_by": {"is_subsidiary_of": "founded_by"},
}
_VERB_RELATIONS: Dict[str, str] = {
# English
"found+by": "founded_by", "co-found+by": "founded_by",
"establish+by": "founded_by", "create+by": "founded_by",
"set+up": "founded_by", "start+by": "founded_by",
"work+for": "works_for", "employ+by": "works_for",
"hire+by": "works_for", "join": "works_for",
"lead+by": "works_for", "manage+by": "works_for",
"head+by": "works_for", "run+by": "works_for",
"own+by": "owns", "develop+by": "develops",
"write+by": "wrote", "publish+by": "published",
"invest+in": "invests_in", "partner+with": "partners_with",
"collaborate+with": "collaborates_with",
"merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of",
"base+in": "located_in", "locate+in": "located_in",
"situate+in": "located_in", "headquarter+in": "located_in",
"bear+in": "born_in", "bear+on": "born_in",
"acquire+by": "acquired", "buy+by": "acquired",
# German (de): spaCy lemmas
"gründen+von": "founded_by", "errichten+von": "founded_by",
"arbeiten+für": "works_for", "beschäftigen+bei": "works_for",
"anstellen+bei": "works_for", "sich+befinden": "located_in",
"liegen+in": "located_in", "sitzen+in": "located_in",
"gebären+in": "born_in", "gebären+am": "born_in",
"erwerben+durch": "acquired", "kaufen+durch": "acquired",
"übernehmen+durch": "acquired",
# French (fr): spaCy lemmas
"fonder+par": "founded_by", "créer+par": "founded_by",
"établir+par": "founded_by",
"travailler+pour": "works_for", "employer+par": "works_for",
"embaucher+par": "works_for",
"situer+à": "located_in", "baser+à": "located_in",
"implanter+à": "located_in",
"naître+à": "born_in",
"acquérir+par": "acquired", "racheter+par": "acquired",
# Spanish + Portuguese (shared lemmas, no duplicate keys)
"fundar+por": "founded_by", "crear+por": "founded_by",
"criar+por": "founded_by",
"establecer+por": "founded_by", "estabelecer+por": "founded_by",
"trabajar+para": "works_for", "trabalhar+para": "works_for",
"emplear+por": "works_for", "empregar+por": "works_for",
"contratar+por": "works_for",
"ubicar+en": "located_in", "situar+en": "located_in",
"localizar+em": "located_in", "situar+em": "located_in",
"sediar+em": "located_in", "tener+sede": "located_in",
"nacer+en": "born_in", "nascer+em": "born_in",
"adquirir+por": "acquired", "comprar+por": "acquired",
# Chinese: verb + "由" (agent marker) or "被" (passive)
"创立+由": "founded_by", "创建+由": "founded_by",
"成立+由": "founded_by", "创办+由": "founded_by",
"设立+由": "founded_by",
"任职+于": "works_for", "就职+于": "works_for",
"工作+在": "works_for", "位于+在": "located_in",
"坐落+在": "located_in", "总部设+在": "located_in",
"出生+在": "born_in", "出生+于": "born_in",
"收购+由": "acquired", "并购+由": "acquired",
# Japanese: verb + "によって" (agent marker)
"設立+によって": "founded_by", "創立+によって": "founded_by",
"勤務+で": "works_for", "在籍+で": "works_for",
"位置+に": "located_in", "所在+に": "located_in",
"本社+を": "located_in",
"出生+に": "born_in",
"買収+によって": "acquired",
}
_COPULA_TITLE_MAP: Dict[str, List[str]] = {
"ceo": ["ceo_of", "works_for"], "cto": ["works_for"],
"cfo": ["works_for"], "coo": ["works_for"],
"vp": ["works_for"], "director": ["works_for"],
"manager": ["works_for"], "engineer": ["works_for"],
"employee": ["works_for"],
"founder": ["founded_by"], "co-founder": ["founded_by"],
}
class DepRelationExtractor:
    """Extract typed relations using a dependency parse — semantica-aligned.

    Relations come from three sources, merged by :meth:`extract`:

    * dependency patterns anchored at each sentence ROOT (passive with agent,
      active with direct/prepositional object, passive with preposition, and
      the English copula pattern),
    * same-sentence entity co-occurrence (generic ``related_to`` edges),
    * multi-hop inference driven by the module-level ``_MULTI_HOP`` table.
    """

    # Agent markers tried, in order, when resolving a passive clause with an
    # explicit agent against the "<verb>+<marker>" keys of _VERB_RELATIONS.
    # "由" is the Chinese agent marker; without it the "<verb>+由" entries can
    # never be matched.
    _AGENT_MARKERS = ("by", "von", "par", "por", "durch", "由", "によって")

    def __init__(self, language: str = "en",
                 confidence_threshold: float = 0.3,
                 max_distance: int = 100):
        """Store the extraction settings.

        Args:
            language: Language code used to select dependency-label rules.
            confidence_threshold: Relations below this confidence are dropped
                by :meth:`extract`.
            max_distance: Maximum character gap between two entities for a
                co-occurrence relation.
        """
        self.language = language
        self.confidence_threshold = confidence_threshold
        self.max_distance = max_distance

    def extract(self, text: str, entities: "List[Entity]",
                doc=None, **options) -> "List[Relation]":
        """Extract relations between ``entities`` found in ``text``.

        Args:
            text: The source text.
            entities: Entities previously extracted from ``text``.
            doc: Optional parsed document; when ``None`` only co-occurrence
                relations (and inferences over them) are produced.
            **options: Accepted for interface compatibility; unused here.

        Returns:
            De-duplicated relations whose confidence is at least
            ``confidence_threshold``.
        """
        found = []
        if doc is not None:
            found = self._extract_with_dep(text, doc, entities)
        found.extend(self._extract_cooccurrence(text, entities))
        found = self._infer_multi_hop(found)
        found = self._deduplicate(found)
        return [r for r in found if r.confidence >= self.confidence_threshold]

    # ------------------------------------------------------------------
    # Multi-hop inference (transitive attribute propagation)
    # ------------------------------------------------------------------
    @staticmethod
    def _infer_multi_hop(relations: "List[Relation]") -> "List[Relation]":
        """Infer transitive relations: A→B and B→C ⇒ A→C.

        Only predicate pairs listed in ``_MULTI_HOP`` produce an inference;
        ``related_to`` edges never take part. The inferred confidence is the
        weaker of the two source confidences scaled by 0.9.

        Returns:
            The input relations followed by the inferred ones.
        """
        by_subj: "Dict[str, List[Relation]]" = {}
        for r in relations:
            if r.predicate == "related_to":
                continue
            by_subj.setdefault(r.subject.text.lower(), []).append(r)

        inferred = []
        for r in relations:
            if r.predicate == "related_to":
                continue
            hop_rules = _MULTI_HOP.get(r.predicate, {})
            for r2 in by_subj.get(r.obj.text.lower(), ()):
                inferred_rel = hop_rules.get(r2.predicate)
                if inferred_rel:
                    inferred.append(Relation(
                        subject=r.subject, predicate=inferred_rel,
                        obj=r2.obj, confidence=min(r.confidence, r2.confidence) * 0.9,
                        metadata={"method": "multi_hop",
                                  "via": f"{r.predicate}{r2.predicate}"},
                    ))
        return relations + inferred

    # ------------------------------------------------------------------
    # Language-aware role mapping
    # ------------------------------------------------------------------
    def _roles(self) -> Dict[str, object]:
        """Return the role → dependency-rule mapping for the current language.

        Falls back to the English rules for languages without an entry.
        """
        return _LANG_DEP_RULES.get(self.language, _LANG_DEP_RULES["en"])

    def _get_by_role(self, root, role: str, entity_map) -> list:
        """Collect entities filling a semantic role under ``root``.

        A rule is either a plain dependency label (matched on ``root``'s
        children) or a tuple ``(parent_dep, child_dep[, case_marker])``:
        the child must carry ``parent_dep``; if ``case_marker`` is non-empty
        some token of the child's subtree must match it; the entity is then
        taken from the child itself (``child_dep is None``) or from its first
        grandchild labelled ``child_dep`` that resolves to an entity.

        Returns:
            A list of ``(entity, prep)`` pairs, where ``prep`` is the
            lower-cased lemma of the matched child for the ``"prep_obj"`` role
            and ``None`` otherwise.
        """
        rule = self._roles().get(role)
        if rule is None:
            return []

        results = []
        for c in root.children:
            dep = c.dep_
            if isinstance(rule, str):
                if dep == rule:
                    ent = self._entity_from_subtree(c, entity_map)
                    if ent:
                        results.append((ent, None))
            elif isinstance(rule, tuple):
                parent_dep, child_dep = rule[0], rule[1]
                # Optional case marker (e.g. "によって" for ja)
                case_marker = rule[2] if len(rule) > 2 else None
                if dep != parent_dep:
                    continue
                if case_marker:
                    has_case = any(
                        gc.lemma_ == case_marker or gc.text == case_marker
                        for gc in c.subtree
                    )
                    if not has_case:
                        continue
                prep = c.lemma_.lower() if role == "prep_obj" else None
                if child_dep is None:
                    ent = self._entity_from_subtree(c, entity_map)
                    if ent:
                        results.append((ent, prep))
                else:
                    for gc in c.children:
                        if gc.dep_ == child_dep:
                            ent = self._entity_from_subtree(gc, entity_map)
                            if ent:
                                results.append((ent, prep))
                                break
        return results

    # ------------------------------------------------------------------
    # Dependency extraction
    # ------------------------------------------------------------------
    def _extract_with_dep(self, text, doc, entities) -> "List[Relation]":
        """Run the dependency patterns over every sentence ROOT of ``doc``."""
        relations = []
        entity_map = self._build_entity_map_multi(entities)
        is_de = self.language == "de"
        for sent in doc.sents:
            for token in sent:
                if token.dep_ != "ROOT":
                    continue
                if is_de:
                    # German: ROOT is the auxiliary, the real verb is its "oc"
                    # child. Arguments attach to the auxiliary, so pass both:
                    # the "oc" child for the lemma, the ROOT for the arguments.
                    for c in token.children:
                        if c.dep_ == "oc":
                            relations.extend(
                                self._extract_from_root(text, c, entity_map, aux_root=token))
                    continue
                relations.extend(self._extract_from_root(text, token, entity_map))
                if token.lemma_ == "be":
                    relations.extend(self._extract_copula(text, token, entity_map))
        return relations

    def _extract_from_root(self, text, root, entity_map, aux_root=None) -> "List[Relation]":
        """Apply the passive/active/prepositional patterns to one verb.

        Args:
            text: Source text (unused by the patterns themselves).
            root: Token providing the verb lemma.
            entity_map: Map built by :meth:`_build_entity_map_multi`.
            aux_root: Optional auxiliary token whose children also carry the
                verb's arguments (used for German).

        Returns:
            The relations matched for this verb; empty when the clause is
            negated.
        """
        relations = []
        # Fall back to the surface text when the lemma is empty (zh, ja).
        verb_lemma = (root.lemma_ or root.text).lower()

        # For languages like German the negation attaches to the auxiliary.
        check = root if aux_root is None else aux_root
        if any(c.dep_ in ("neg", "advmod:neg") for c in check.children):
            return relations

        def first(lst):
            return lst[0][0] if lst else None

        def get_roles(token):
            return (
                first(self._get_by_role(token, "subj", entity_map)),
                first(self._get_by_role(token, "pass_subj", entity_map)),
                first(self._get_by_role(token, "dobj", entity_map)),
                first(self._get_by_role(token, "agent", entity_map)),
                self._get_by_role(token, "prep_obj", entity_map),
                any(c.dep_ == "aux" for c in token.children),
            )

        s1, sp1, d1, a1, p1, h1 = get_roles(root)
        s2, sp2, d2, a2, p2, h2 = (None, None, None, None, [], False)
        if aux_root:
            s2, sp2, d2, a2, p2, h2 = get_roles(aux_root)

        # Merge: roles found on the main verb win, the auxiliary fills gaps.
        nsubj = s1 or s2
        nsubjpass = sp1 or sp2
        dobj = d1 or d2
        agent_entity = a1 or a2
        prep_list = p1 + p2
        has_aux = h1 or h2 or aux_root is not None
        has_explicit_agent = agent_entity is not None

        # Detect passive:
        # - explicit pass_subj (en, fr, pt)
        # - subj + agent + aux (Spanish-style)
        # - subj + agent for languages with an agent marker (zh, ja)
        is_passive_candidate = has_explicit_agent and (has_aux or self.language in ("zh", "ja"))
        effective_nsubjpass = nsubjpass or (nsubj if is_passive_candidate else None)
        effective_nsubj = nsubj if not is_passive_candidate else None

        # Passive: X was founded/acquired by Y
        if effective_nsubjpass and agent_entity:
            rel_type = None
            for marker in self._AGENT_MARKERS:
                rel_type = self._lookup(verb_lemma, marker)
                if rel_type:
                    break
            if rel_type:
                # NOTE(review): every predicate other than founded_by/acquired
                # is emitted as agent → passive subject; confirm that is the
                # intended direction for entries such as "employ+by".
                if rel_type in ("founded_by", "acquired"):
                    subj, obj = effective_nsubjpass, agent_entity
                else:
                    subj, obj = agent_entity, effective_nsubjpass
                relations.append(self._make_rel(subj, rel_type, obj, 0.90, "passive", verb_lemma))

        # Active: X VERB Y or X VERB prep Y
        if effective_nsubj:
            if dobj:
                rt = self._lookup(verb_lemma, None)
                if rt:
                    relations.append(
                        self._make_rel(effective_nsubj, rt, dobj, 0.85, "active", verb_lemma))
            for prep_entity, prep_l in prep_list:
                rt = self._lookup(verb_lemma, prep_l)
                if rt:
                    relations.append(self._make_rel(effective_nsubj, rt, prep_entity, 0.85,
                                                    "active_prep", verb_lemma, prep=prep_l))

        # Passive with prep ("is based in")
        if effective_nsubjpass and prep_list and not agent_entity:
            for prep_entity, prep_l in prep_list:
                rt = self._lookup(verb_lemma, prep_l)
                if not rt:
                    rt = self._lookup("be+" + verb_lemma, prep_l)
                if rt:
                    relations.append(self._make_rel(effective_nsubjpass, rt, prep_entity, 0.85,
                                                    "passive_prep", verb_lemma, prep=prep_l))
        return relations

    @staticmethod
    def _make_rel(subj, pred, obj, conf, method, verb, prep=""):
        """Build a Relation whose metadata records method, verb and (if any) prep."""
        m = {"method": method, "verb": verb}
        if prep:
            m["prep"] = prep
        return Relation(subject=subj, predicate=pred, obj=obj,
                        confidence=conf, metadata=m)

    @staticmethod
    def _already_has(rels, subj, pred, obj) -> bool:
        """Return True if ``rels`` contains the exact (subject, predicate, object) triple."""
        return any(
            r.subject.text == subj.text and r.predicate == pred and r.obj.text == obj.text
            for r in rels
        )

    def _copula_title_and_object(self, root, entity_map):
        """Find the first ``(title_lemma, entity)`` pair under a copula root.

        Walks ``attr``/``pred`` children of ``root``, then their
        ``prep``/``mo``/``case`` children, and returns as soon as any child of
        such a preposition resolves to an entity. Returns ``(None, None)``
        when nothing matches.
        """
        for c in root.children:
            if c.dep_ not in ("attr", "pred"):  # attr=en, pred=de
                continue
            for cc in c.children:
                if cc.dep_ not in ("prep", "mo", "case"):  # en=prep, de=mo, fr=case
                    continue
                for gc in cc.children:  # any child is accepted as the object
                    ent = self._entity_from_subtree(gc, entity_map)
                    if ent:
                        return c.lemma_.lower(), ent
        return None, None

    def _extract_copula(self, text, root, entity_map) -> "List[Relation]":
        """Extract relations from "X is the <title> of Y" constructions.

        The title lemma is matched (by substring) against the keys of
        ``_COPULA_TITLE_MAP``; the first matching key decides which predicates
        are emitted between the subject X and the object Y.

        The first title/object pair found is used. Scanning further would let
        a later prepositional phrase ("... of Apple in California") replace or
        erase the object that belongs to the title.

        NOTE(review): ``founded_by`` is emitted here as (person, org), whereas
        the passive pattern emits (org, person) — confirm the intended
        direction.
        """
        relations = []
        subjs = self._get_by_role(root, "subj", entity_map)
        subj = subjs[0][0] if subjs else None
        if not subj:
            return relations

        title_lemma, prep_obj = self._copula_title_and_object(root, entity_map)
        if not title_lemma or not prep_obj:
            return relations

        for keyword, rel_types in _COPULA_TITLE_MAP.items():
            if keyword in title_lemma:
                for rt in rel_types:
                    relations.append(Relation(
                        subject=subj, predicate=rt, obj=prep_obj,
                        confidence=0.88, context=text,
                        metadata={"method": "copula", "title": title_lemma},
                    ))
                break
        return relations

    # ------------------------------------------------------------------
    # Better entity map: multi-occurrence aware
    # ------------------------------------------------------------------
    @staticmethod
    def _build_entity_map_multi(entities: "List[Entity]") -> "Dict[str, List[Entity]]":
        """Build a lower-cased text → entities map keeping ALL occurrences.

        Each entity is also registered under its text with trailing
        punctuation stripped, unless that leaves an empty string (an empty key
        would match every lookup in the substring fallback of
        :meth:`_entity_from_subtree`).
        """
        result: "Dict[str, List[Entity]]" = {}
        for e in entities:
            key = e.text.lower()
            result.setdefault(key, []).append(e)
            cleaned = e.text.rstrip(".,;:!?").strip().lower()
            if cleaned and cleaned != key:
                result.setdefault(cleaned, []).append(e)
        return result

    @staticmethod
    def _find_best_entity(key: str, entity_map: "Dict[str, List[Entity]]",
                          fallback_text: str = "") -> "Optional[Entity]":
        """Return the best entity registered under ``key``.

        With several candidates, the one whose text equals ``fallback_text``
        (case-insensitively) wins; otherwise the first candidate is returned.
        Returns ``None`` when the key is unknown.
        """
        entries = entity_map.get(key.lower(), [])
        if not entries:
            return None
        if len(entries) == 1:
            return entries[0]
        wanted = fallback_text.lower()
        for e in entries:
            if e.text.lower() == wanted:
                return e
        return entries[0]

    # ------------------------------------------------------------------
    # Argument extraction helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _get_child_entity(token, dep, entity_map):
        """Return the entity under the first child of ``token`` labelled ``dep``."""
        for c in token.children:
            if c.dep_ == dep:
                return DepRelationExtractor._entity_from_subtree(c, entity_map)
        return None

    @staticmethod
    def _get_agent_pobj(root, entity_map):
        """Return the entity under the first ``agent`` → ``pobj`` path below ``root``."""
        for c in root.children:
            if c.dep_ == "agent":
                for gc in c.children:
                    if gc.dep_ == "pobj":
                        return DepRelationExtractor._entity_from_subtree(gc, entity_map)
        return None

    @staticmethod
    def _get_prep_objs(root, entity_map):
        """Return ``(prep_lemma, entity)`` pairs for ``prep`` → ``pobj`` paths below ``root``."""
        results = []
        for c in root.children:
            if c.dep_ == "prep":
                prep_lemma = c.lemma_.lower()
                for gc in c.children:
                    if gc.dep_ == "pobj":
                        ent = DepRelationExtractor._entity_from_subtree(gc, entity_map)
                        if ent:
                            results.append((prep_lemma, ent))
        return results

    @staticmethod
    def _entity_from_subtree(token, entity_map) -> "Optional[Entity]":
        """Match the text covered by ``token``'s subtree against ``entity_map``.

        The covered span is computed from character offsets, ignoring subtree
        tokens labelled as prepositions, punctuation, determiners,
        auxiliaries, conjunctions or conjuncts. Lookup order: exact key, then
        the part before the first " and " / " or " / ", ", then any non-empty
        key that contains, or is contained in, the span text.

        Returns:
            The first entity registered under the matched key, or ``None``.
        """
        min_char = token.idx
        max_char = token.idx + len(token.text)
        for t in token.subtree:
            if t.dep_ not in ("prep", "punct", "det", "aux", "auxpass", "cc", "conj"):
                if t.idx < min_char:
                    min_char = t.idx
                end = t.idx + len(t.text)
                if end > max_char:
                    max_char = end
        key = token.doc.text[min_char:max_char].strip().lower()

        entries = entity_map.get(key, [])
        if not entries:
            for sep in (" and ", " or ", ", "):
                if sep in key:
                    entries = entity_map.get(key.split(sep)[0].strip(), [])
                    if entries:
                        break
        # Substring fallback. Empty strings are skipped on both sides because
        # "" is a substring of everything and would match an arbitrary entity.
        if not entries and key:
            for ek, ev in entity_map.items():
                if ek and (ek in key or key in ek):
                    entries = ev
                    break
        if entries:
            return entries[0]
        return None

    @staticmethod
    def _lookup(verb: str, prep: Optional[str] = None) -> Optional[str]:
        """Look up the relation type for ``verb`` (or ``"verb+prep"`` if ``prep`` is truthy)."""
        if prep:
            return _VERB_RELATIONS.get(f"{verb}+{prep}")
        return _VERB_RELATIONS.get(verb)

    @staticmethod
    def _deduplicate(relations: "List[Relation]") -> "List[Relation]":
        """Drop repeated triples, keeping the first occurrence.

        Triples are compared case-insensitively on entity text, and a triple
        whose subject and object are swapped (same predicate) counts as a
        duplicate of one already kept.
        """
        seen = set()
        result = []
        for r in relations:
            subj_key = r.subject.text.lower()
            obj_key = r.obj.text.lower()
            key = (subj_key, r.predicate, obj_key)
            rev = (obj_key, r.predicate, subj_key)
            if key in seen or rev in seen:
                continue
            seen.add(key)
            result.append(r)
        return result

    # ------------------------------------------------------------------
    # Co-occurrence
    # ------------------------------------------------------------------
    def _extract_cooccurrence(self, text, entities) -> "List[Relation]":
        """Emit ``related_to`` relations for entity pairs sharing a sentence.

        Sentences are approximated with a regular expression that ends a span
        at ``.``, ``!`` or ``?`` followed by whitespace or end of text. A pair
        qualifies when both start offsets fall in the same span and the gap
        between the first entity's end and the second's start does not exceed
        ``max_distance``. The relation context is the covered text padded by
        20 characters on each side.
        """
        if len(entities) < 2:
            return []
        import re as _re
        spans = [(m.start(), m.end())
                 for m in _re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text)]

        def same_sent(c1, c2):
            return any(ss <= c1 < se and ss <= c2 < se for ss, se in spans)

        rels = []
        for i in range(len(entities)):
            for j in range(i + 1, len(entities)):
                e1, e2 = entities[i], entities[j]
                if not same_sent(e1.start_char, e2.start_char):
                    continue
                if abs(e2.start_char - e1.end_char) > self.max_distance:
                    continue
                cs = max(0, min(e1.start_char, e2.start_char) - 20)
                ce = min(len(text), max(e1.end_char, e2.end_char) + 20)
                rels.append(Relation(
                    subject=e1, predicate="related_to", obj=e2,
                    confidence=0.4, context=text[cs:ce],
                    metadata={"method": "cooccurrence"},
                ))
        return rels

View File

@@ -0,0 +1,243 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
NERExtractor — semantica-style full pipeline extraction.
Pipeline: tokenize → tag(POS) → parse(dep) → NER → typed relations
All components share a single spaCy `doc` object (one forward pass).
Output includes:
- Entities (from NER, enriched with POS/dep)
- Typed relations (from dependency patterns)
- Dependency tree (heads + labels per token)
- POS tags per token
Supports 7 languages: en, zh, de, fr, es, pt, ja
"""
import logging
from typing import Any, Dict, List, Optional
import spacy
from spacy import Language
from .dep_relation_extractor import DepRelationExtractor
from .types import Entity, ExtractionResult
# Language → spaCy model
_MODEL_MAP = {
"en": "en_core_web_sm", "zh": "zh_core_web_sm",
"de": "de_core_news_sm", "fr": "fr_core_news_sm",
"es": "es_core_news_sm", "pt": "pt_core_news_sm",
"ja": "ja_core_news_sm",
}
# SpaCy labels to skip from NER output
_SKIP_LABELS = {"ORDINAL", "CARDINAL"}
# Labels by confidence tier (for NER confidence scoring)
_HIGH_CONF = {"PERSON", "ORG", "GPE", "LOC", "DATE"}
_MED_CONF = {"PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "NORP",
"MONEY", "TIME", "PERCENT", "FAC", "QUANTITY"}
class NERExtractor:
    """
    Full semantic extraction pipeline (NER + tagger + parser + relations).

    Usage:
        ext = NERExtractor(language="en")
        result = ext.extract("Apple Inc. was founded by Steve Jobs.")
        # result.entities          → [Entity]
        # result.relations         → [Relation]
        # result.metadata["tokens"] → per-token dicts (text, head, dep, tag, ...)
    """

    # Model cache: model name → nlp (shared by all instances in the process)
    _nlp_cache: Dict[str, "Language"] = {}

    def __init__(
        self,
        language: str = "en",
        spacy_model: Optional[str] = None,
        confidence_threshold: float = 0.3,
    ):
        """Configure the extractor; the model itself is loaded lazily.

        Args:
            language: Language code. Unknown codes fall back to ``"en"``
                unless ``spacy_model`` is given explicitly.
            spacy_model: Optional model name overriding the per-language
                default from ``_MODEL_MAP``.
            confidence_threshold: Minimum confidence for entities, also
                forwarded to the relation extractor.
        """
        if language not in _MODEL_MAP and spacy_model is None:
            language = "en"
        self.language = language
        self.model_name = spacy_model or _MODEL_MAP.get(language, "en_core_web_sm")
        self.confidence_threshold = confidence_threshold
        self._nlp: Optional[Language] = None

    # ------------------------------------------------------------------
    # Model lifecycle
    # ------------------------------------------------------------------
    def _ensure_model(self):
        """Lazy-load the shared spaCy model, keeping all of its pipes.

        Raises:
            Exception: whatever ``spacy.load`` raised, after logging it.
        """
        cached = self._nlp_cache.get(self.model_name)
        if cached is not None:
            self._nlp = cached
            return
        try:
            nlp = spacy.load(self.model_name)
        except Exception as e:
            logging.error("Failed to load spaCy model '%s': %s",
                          self.model_name, e)
            raise
        self._nlp_cache[self.model_name] = nlp
        self._nlp = nlp

    # ------------------------------------------------------------------
    # Main extraction
    # ------------------------------------------------------------------
    def _result_from_doc(self, text, doc, extract_relations, include_tokens,
                         dep_ext=None) -> "ExtractionResult":
        """Turn one parsed ``doc`` into an ExtractionResult.

        Shared by :meth:`extract` and :meth:`extract_batch` so both return the
        same metadata keys. ``dep_ext`` lets a caller reuse one relation
        extractor across documents; one is created on demand otherwise.
        """
        entities = self._extract_entities(doc)

        relations = []
        # Relations need at least two entities to connect.
        if extract_relations and len(entities) >= 2:
            if dep_ext is None:
                dep_ext = DepRelationExtractor(
                    language=self.language,
                    confidence_threshold=self.confidence_threshold,
                )
            relations = dep_ext.extract(text, entities, doc=doc)

        result = ExtractionResult(
            entities=entities,
            relations=relations,
            language=self.language,
        )
        result.metadata = {
            "model": self.model_name,
            "n_tokens": len(doc),
            "n_entities": len(entities),
            # Generic co-occurrence edges are not counted as typed relations.
            "n_relations": len([r for r in relations if r.predicate != "related_to"]),
        }
        if include_tokens:
            result.metadata["tokens"] = self._build_tokens(doc)
        return result

    def extract(
        self,
        text: str,
        extract_relations: bool = True,
        include_tokens: bool = True,
    ) -> "ExtractionResult":
        """Run the full pipeline on ``text`` with a single model pass.

        Args:
            text: Text to analyse.
            extract_relations: Whether to extract relations (requires at
                least two entities).
            include_tokens: Whether to add the per-token list under
                ``metadata["tokens"]``.

        Returns:
            The entities, relations and metadata for ``text``.
        """
        self._ensure_model()
        doc = self._nlp(text)
        return self._result_from_doc(text, doc, extract_relations, include_tokens)

    def extract_batch(
        self,
        texts: List[str],
        extract_relations: bool = True,
        include_tokens: bool = False,
        batch_size: int = 32,
    ) -> "List[ExtractionResult]":
        """Batch extraction using the model's ``pipe()``.

        Each result carries the same metadata keys as :meth:`extract`.

        Args:
            texts: Texts to analyse.
            extract_relations: Whether to extract relations.
            include_tokens: Whether to add ``metadata["tokens"]``.
            batch_size: Batch size forwarded to ``pipe()``.

        Returns:
            One result per input text, in order.
        """
        self._ensure_model()
        # One relation extractor serves the whole batch; it only holds
        # configuration, so there is no need to rebuild it per document.
        dep_ext = None
        if extract_relations:
            dep_ext = DepRelationExtractor(
                language=self.language,
                confidence_threshold=self.confidence_threshold,
            )
        return [
            self._result_from_doc(doc.text, doc, extract_relations, include_tokens,
                                  dep_ext=dep_ext)
            for doc in self._nlp.pipe(texts, batch_size=batch_size)
        ]

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _label_confidence(label: str) -> float:
        """Map an NER label to a fixed confidence tier (0.85 / 0.65 / 0.50)."""
        if label in _HIGH_CONF:
            return 0.85
        if label in _MED_CONF:
            return 0.65
        return 0.50

    def _extract_entities(self, doc) -> "List[Entity]":
        """Collect the entities of ``doc.ents``.

        Labels in ``_SKIP_LABELS`` and entities whose tier confidence is below
        ``confidence_threshold`` are dropped, as are repeats of the same
        (lower-cased text, start offset) pair.
        """
        entities = []
        seen = set()
        for ent in doc.ents:
            if ent.label_ in _SKIP_LABELS:
                continue
            confidence = self._label_confidence(ent.label_)
            if confidence < self.confidence_threshold:
                continue
            key = (ent.text.lower(), ent.start_char)
            if key in seen:
                continue
            seen.add(key)
            entities.append(Entity(
                text=ent.text,
                label=ent.label_,
                start_char=ent.start_char,
                end_char=ent.end_char,
                confidence=confidence,
                metadata={"source": "spacy"},
            ))
        return entities

    @staticmethod
    def _build_tokens(doc) -> List[Dict[str, Any]]:
        """Build the per-token list with POS tags and dependency info."""
        return [
            {
                "text": t.text,
                "tag": t.tag_,
                "dep": t.dep_,
                "head": t.head.i,
                "index": i,
                "lemma": t.lemma_,
                "pos": t.pos_,
            }
            for i, t in enumerate(doc)
        ]

    @staticmethod
    def clear_cache():
        """Clear the NLP model cache (e.g., for testing)."""
        NERExtractor._nlp_cache.clear()
# NOTE: ExtractionResult already declares a `metadata` field, so no patching is needed here.

75
rag/graphrag/ner/types.py Normal file
View File

@@ -0,0 +1,75 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Data types for entity and relation extraction.
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List
@dataclass
class Entity:
    """Extracted entity.

    A plain record describing one entity mention: its surface text, its NER
    label, and the character offsets of the mention.
    """
    # Surface text of the mention.
    text: str
    label: str  # spaCy NER label: PERSON, ORG, GPE, ...
    # Character offset where the mention starts.
    start_char: int
    # Character offset where the mention ends.
    end_char: int
    # Confidence score; defaults to 1.0 when the producer supplies none.
    confidence: float = 1.0
    # Free-form extra information; each instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Relation:
    """Extracted relation between two entities.

    Represents the triple (subject, predicate, obj) together with a
    confidence score and optional context/metadata.
    """
    # Entity on the subject side of the triple.
    subject: Entity
    predicate: str  # relation type: "founded_by", "works_for", ...
    # Entity on the object side of the triple.
    obj: Entity
    # Confidence score; defaults to 1.0 when the producer supplies none.
    confidence: float = 1.0
    context: str = ""  # surrounding text
    # Free-form extra information; each instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ExtractionResult:
    """Result of a full extraction pass.

    Bundles the entities and relations found in one text with the language
    code and a metadata dict.
    """
    # Entities found in the text; empty list by default.
    entities: List[Entity] = field(default_factory=list)
    # Relations found between the entities; empty list by default.
    relations: List[Relation] = field(default_factory=list)
    # Language code of the analysed text.
    language: str = "en"
    # Free-form extra information; each instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
SPACY_TO_APP_ENTITY_TYPE: Dict[str, str] = {
"PERSON": "person",
"ORG": "organization",
"GPE": "geo",
"LOC": "geo",
"FAC": "geo",
"EVENT": "event",
"PRODUCT": "category",
"WORK_OF_ART": "category",
"LAW": "category",
"LANGUAGE": "category",
"NORP": "category",
"MONEY": "category",
"QUANTITY": "category",
"TIME": "event",
"DATE": "event",
"PERCENT": "category",
"CARDINAL": "category",
"ORDINAL": "category",
}
SKIP_SPACY_LABELS = {"ORDINAL", "CARDINAL"}

View File

@@ -17,7 +17,7 @@
import time import time
import datetime import datetime
import pytest import pytest
from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp, datetime_format, delta_seconds from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp, datetime_format, delta_seconds, format_iso_8601_to_ymd_hms
class TestCurrentTimestamp: class TestCurrentTimestamp:
@@ -650,4 +650,65 @@ class TestDeltaSeconds:
# If we're testing on the first day of month # If we're testing on the first day of month
date_string = "2024-01-31 12:00:00" # Use a known past date date_string = "2024-01-31 12:00:00" # Use a known past date
result = delta_seconds(date_string) result = delta_seconds(date_string)
assert result > 0 assert result > 0
@pytest.mark.p2
class TestTimestampToDateCurrentTimeFallback:
"""Regression tests for the None/empty fallback of timestamp_to_date.
The docstring promises "If None or empty, uses current time", but the
fallback assigned ``time.time()`` (seconds) and then divided by 1000 again,
producing a date around 1970-01-21 instead of now. The existing
``test_return_type_always_string`` only checks the return type, so it never
caught this. These tests pin the behaviour by value.
"""
def test_none_uses_current_time(self, monkeypatch):
"""None input must resolve to current_timestamp() fallback."""
fixed_ms = 1704067200123
monkeypatch.setattr("common.time_utils.current_timestamp", lambda: fixed_ms)
assert timestamp_to_date(None) == time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(fixed_ms / 1000)
)
def test_empty_string_uses_current_time(self, monkeypatch):
"""Empty-string input must resolve to current_timestamp() fallback."""
fixed_ms = 1704067200123
monkeypatch.setattr("common.time_utils.current_timestamp", lambda: fixed_ms)
assert timestamp_to_date("") == time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(fixed_ms / 1000)
)
def test_zero_timestamp_is_not_treated_as_empty(self):
"""Zero timestamp should map to Unix epoch, not fallback to current time."""
assert timestamp_to_date(0) == time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(0))
@pytest.mark.p2
class TestFormatIso8601ToYmdHms:
"""Test cases for format_iso_8601_to_ymd_hms function."""
def test_standard_utc_z(self):
"""Standard UTC timestamp with trailing Z."""
assert format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z") == "2024-01-01 12:00:00"
def test_explicit_utc_offset(self):
"""Timestamp with an explicit +00:00 offset."""
assert format_iso_8601_to_ymd_hms("2024-01-01T12:00:00+00:00") == "2024-01-01 12:00:00"
def test_ordinal_date_extended(self):
"""ISO 8601 ordinal date (day-of-year), extended form.
dateutil.isoparse accepts it but datetime.fromisoformat rejects it,
which previously made the function silently return the input unchanged.
"""
assert format_iso_8601_to_ymd_hms("2024-001T12:00:00Z") == "2024-01-01 12:00:00"
def test_ordinal_date_basic(self):
"""ISO 8601 ordinal date (day-of-year), basic form."""
assert format_iso_8601_to_ymd_hms("2024001T120000Z") == "2024-01-01 12:00:00"
def test_invalid_string_returns_original(self):
"""Unparseable input is returned unchanged."""
assert format_iso_8601_to_ymd_hms("not-a-date") == "not-a-date"

View File

@@ -1,7 +1,8 @@
// src/components/ui/modal.tsx
import React, { FC, ReactNode, useCallback, useEffect, useMemo } from 'react';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import * as DialogPrimitive from '@radix-ui/react-dialog'; import * as DialogPrimitive from '@radix-ui/react-dialog';
import { AlertCircle, CheckCircle, Info, Loader, X } from 'lucide-react'; import { AlertCircle, CheckCircle, Info, Loader, X } from 'lucide-react';
import React, { FC, ReactNode, useCallback, useEffect, useMemo } from 'react';
import { createRoot } from 'react-dom/client'; import { createRoot } from 'react-dom/client';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { DialogDescription } from '../dialog'; import { DialogDescription } from '../dialog';