mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
45fc7feab4 | ||
|
|
b53b693f22 | ||
|
|
8e1dc4f308 | ||
|
|
5af361ed68 |
32
.github/workflows/tests.yml
vendored
32
.github/workflows/tests.yml
vendored
@@ -210,14 +210,8 @@ jobs:
|
|||||||
-v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \
|
-v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \
|
||||||
infiniflow/infinity_builder:ubuntu22_clang20
|
infiniflow/infinity_builder:ubuntu22_clang20
|
||||||
sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp'
|
sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp'
|
||||||
uv sync --python 3.13 --group test --frozen
|
|
||||||
./build.sh --go
|
./build.sh --go
|
||||||
|
|
||||||
- name: Prepare Python test environment
|
|
||||||
run: |
|
|
||||||
uv sync --python 3.13 --group test --frozen
|
|
||||||
uv pip install -e sdk/python
|
|
||||||
|
|
||||||
- name: Run Go unit tests
|
- name: Run Go unit tests
|
||||||
# Runs after `./build.sh --go`, which guarantees the C++ static
|
# Runs after `./build.sh --go`, which guarantees the C++ static
|
||||||
# library (librag_tokenizer_c_api.a) is present on disk. The Go
|
# library (librag_tokenizer_c_api.a) is present on disk. The Go
|
||||||
@@ -240,7 +234,10 @@ jobs:
|
|||||||
PKGS=$(go list ./... 2>/dev/null \
|
PKGS=$(go list ./... 2>/dev/null \
|
||||||
| grep -v '/internal/storage$' \
|
| grep -v '/internal/storage$' \
|
||||||
| grep -v '/internal/tokenizer$' \
|
| grep -v '/internal/tokenizer$' \
|
||||||
| grep -v '/internal/handler$' || true)
|
| grep -v '/internal/handler$' \
|
||||||
|
| grep -v '/internal/deepdoc/parser/pdf/pdfium' \
|
||||||
|
| grep -v '/internal/deepdoc/parser/pdf/pdfoxide' \
|
||||||
|
| grep -v '/internal/deepdoc/parser/pdf' || true)
|
||||||
if [ -z "$PKGS" ]; then
|
if [ -z "$PKGS" ]; then
|
||||||
./build.sh --test
|
./build.sh --test
|
||||||
else
|
else
|
||||||
@@ -253,6 +250,11 @@ jobs:
|
|||||||
sudo docker pull ubuntu:24.04
|
sudo docker pull ubuntu:24.04
|
||||||
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
||||||
|
|
||||||
|
- name: Prepare Python test environment
|
||||||
|
run: |
|
||||||
|
uv sync --python 3.13 --group test --frozen
|
||||||
|
uv pip install -e sdk/python
|
||||||
|
|
||||||
- name: Prepare function test environment
|
- name: Prepare function test environment
|
||||||
working-directory: docker
|
working-directory: docker
|
||||||
run: |
|
run: |
|
||||||
@@ -654,14 +656,8 @@ jobs:
|
|||||||
-v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \
|
-v "${PWD}/internal/cpp/resource:/usr/share/infinity/resource" \
|
||||||
infiniflow/infinity_builder:ubuntu22_clang20
|
infiniflow/infinity_builder:ubuntu22_clang20
|
||||||
sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp'
|
sudo docker exec "${BUILDER_CONTAINER}" bash -c 'git config --global safe.directory "*" && cd /ragflow && ./build.sh --cpp'
|
||||||
uv sync --python 3.13 --group test --frozen
|
|
||||||
./build.sh --go
|
./build.sh --go
|
||||||
|
|
||||||
- name: Prepare Python test environment
|
|
||||||
run: |
|
|
||||||
uv sync --python 3.13 --group test --frozen
|
|
||||||
uv pip install -e sdk/python
|
|
||||||
|
|
||||||
- name: Run Go unit tests
|
- name: Run Go unit tests
|
||||||
# Runs after `./build.sh --go`, which guarantees the C++ static
|
# Runs after `./build.sh --go`, which guarantees the C++ static
|
||||||
# library (librag_tokenizer_c_api.a) is present on disk. The Go
|
# library (librag_tokenizer_c_api.a) is present on disk. The Go
|
||||||
@@ -684,7 +680,10 @@ jobs:
|
|||||||
PKGS=$(go list ./... 2>/dev/null \
|
PKGS=$(go list ./... 2>/dev/null \
|
||||||
| grep -v '/internal/storage$' \
|
| grep -v '/internal/storage$' \
|
||||||
| grep -v '/internal/tokenizer$' \
|
| grep -v '/internal/tokenizer$' \
|
||||||
| grep -v '/internal/handler$' || true)
|
| grep -v '/internal/handler$' \
|
||||||
|
| grep -v '/internal/deepdoc/parser/pdf/pdfium' \
|
||||||
|
| grep -v '/internal/deepdoc/parser/pdf/pdfoxide' \
|
||||||
|
| grep -v '/internal/deepdoc/parser/pdf' || true)
|
||||||
if [ -z "$PKGS" ]; then
|
if [ -z "$PKGS" ]; then
|
||||||
./build.sh --test
|
./build.sh --test
|
||||||
else
|
else
|
||||||
@@ -697,6 +696,11 @@ jobs:
|
|||||||
sudo docker pull ubuntu:24.04
|
sudo docker pull ubuntu:24.04
|
||||||
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
||||||
|
|
||||||
|
- name: Prepare Python test environment
|
||||||
|
run: |
|
||||||
|
uv sync --python 3.13 --group test --frozen
|
||||||
|
uv pip install -e sdk/python
|
||||||
|
|
||||||
- name: Prepare function test environment
|
- name: Prepare function test environment
|
||||||
working-directory: docker
|
working-directory: docker
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
56
build.sh
56
build.sh
@@ -143,7 +143,7 @@ check_office_oxide_deps() {
|
|||||||
case "$(uname -s)" in
|
case "$(uname -s)" in
|
||||||
Linux) lib_file="liboffice_oxide.so" ;;
|
Linux) lib_file="liboffice_oxide.so" ;;
|
||||||
Darwin) lib_file="liboffice_oxide.dylib" ;;
|
Darwin) lib_file="liboffice_oxide.dylib" ;;
|
||||||
*) echo -e "${RED}Unsupported OS for office_oxide${NC}"; exit 1 ;;
|
*) echo -e "${RED}Unsupported OS for office_oxide${NC}"; return 1 ;;
|
||||||
esac
|
esac
|
||||||
|
|
||||||
local lib_path="${OFFICE_OXIDE_PREFIX}/lib/${lib_file}"
|
local lib_path="${OFFICE_OXIDE_PREFIX}/lib/${lib_file}"
|
||||||
@@ -164,14 +164,14 @@ check_office_oxide_deps() {
|
|||||||
case "$(uname -m)" in
|
case "$(uname -m)" in
|
||||||
x86_64) asset_name="native-linux-x86_64" ;;
|
x86_64) asset_name="native-linux-x86_64" ;;
|
||||||
aarch64|arm64) asset_name="native-linux-aarch64" ;;
|
aarch64|arm64) asset_name="native-linux-aarch64" ;;
|
||||||
*) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; exit 1 ;;
|
*) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; return 1 ;;
|
||||||
esac
|
esac
|
||||||
;;
|
;;
|
||||||
Darwin)
|
Darwin)
|
||||||
case "$(uname -m)" in
|
case "$(uname -m)" in
|
||||||
x86_64) asset_name="native-macos-x86_64" ;;
|
x86_64) asset_name="native-macos-x86_64" ;;
|
||||||
aarch64|arm64) asset_name="native-macos-aarch64" ;;
|
aarch64|arm64) asset_name="native-macos-aarch64" ;;
|
||||||
*) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; exit 1 ;;
|
*) echo -e "${RED}Unsupported arch: $(uname -m)${NC}"; return 1 ;;
|
||||||
esac
|
esac
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
@@ -182,9 +182,9 @@ check_office_oxide_deps() {
|
|||||||
_download_and_extract "$release_url" "${OFFICE_OXIDE_PREFIX}"
|
_download_and_extract "$release_url" "${OFFICE_OXIDE_PREFIX}"
|
||||||
|
|
||||||
if [ ! -f "$lib_path" ]; then
|
if [ ! -f "$lib_path" ]; then
|
||||||
echo -e "${RED}Error: Failed to install office_oxide native library (missing ${lib_path})${NC}"
|
echo -e "${YELLOW}Warning: Failed to install office_oxide native library (missing ${lib_path})${NC}"
|
||||||
echo " Try: curl -fsSL ${release_url} | tar xzf - -C ${OFFICE_OXIDE_PREFIX}"
|
echo " Try: curl -fsSL ${release_url} | tar xzf - -C ${OFFICE_OXIDE_PREFIX}"
|
||||||
exit 1
|
return 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo -e "${GREEN}✓ office_oxide native library installed${NC}"
|
echo -e "${GREEN}✓ office_oxide native library installed${NC}"
|
||||||
@@ -405,6 +405,7 @@ build_go() {
|
|||||||
eval "$install_cmd"
|
eval "$install_cmd"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
check_office_oxide_deps || true
|
||||||
setup_cgo_env
|
setup_cgo_env
|
||||||
|
|
||||||
local strip_flags=()
|
local strip_flags=()
|
||||||
@@ -445,35 +446,44 @@ build_go() {
|
|||||||
echo -e "${GREEN}✓ Go ingestor built successfully: $INGESTOR_BINARY${NC}"
|
echo -e "${GREEN}✓ Go ingestor built successfully: $INGESTOR_BINARY${NC}"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Configure CGO flags for native libraries (office_oxide, pdfium, pdf_oxide).
|
# Configure CGO flags for native libraries.
|
||||||
# Call before any `go build` / `go test` step that links against these libraries.
|
# setup_cgo_env — base: -I and -L paths only, no -l flags (those live in
|
||||||
|
# each package's own #cgo LDFLAGS pragma). Safe to call even when native
|
||||||
|
# libs are absent — just skips the paths that don't exist.
|
||||||
|
# setup_cgo_env_pdf — pdfium / pdf_oxide -L paths. Non-fatal when libs
|
||||||
|
# are missing. Only called by run_go_tests.
|
||||||
setup_cgo_env() {
|
setup_cgo_env() {
|
||||||
# ── office_oxide ──────────────────────────────────────────────────
|
# ── office_oxide (header + search path only, no -loffice_oxide) ───
|
||||||
check_office_oxide_deps
|
if [ -f "${OFFICE_OXIDE_PREFIX}/include/office_oxide_c/office_oxide.h" ]; then
|
||||||
export CGO_CFLAGS="-I${OFFICE_OXIDE_PREFIX}/include/office_oxide_c${CGO_CFLAGS:+ $CGO_CFLAGS}"
|
export CGO_CFLAGS="-I${OFFICE_OXIDE_PREFIX}/include/office_oxide_c${CGO_CFLAGS:+ $CGO_CFLAGS}"
|
||||||
export CGO_LDFLAGS="-L${OFFICE_OXIDE_PREFIX}/lib -loffice_oxide -Wl,-rpath,${OFFICE_OXIDE_PREFIX}/lib"
|
fi
|
||||||
export LD_LIBRARY_PATH="${OFFICE_OXIDE_PREFIX}/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
|
if [ -f "${OFFICE_OXIDE_PREFIX}/lib/liboffice_oxide.so" ] || [ -f "${OFFICE_OXIDE_PREFIX}/lib/liboffice_oxide.dylib" ]; then
|
||||||
|
export CGO_LDFLAGS="-L${OFFICE_OXIDE_PREFIX}/lib${CGO_LDFLAGS:+ $CGO_LDFLAGS}"
|
||||||
|
export LD_LIBRARY_PATH="${OFFICE_OXIDE_PREFIX}/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "CGO_CFLAGS: $CGO_CFLAGS"
|
||||||
|
echo "CGO_LDFLAGS: $CGO_LDFLAGS"
|
||||||
|
}
|
||||||
|
|
||||||
|
setup_cgo_env_pdf() {
|
||||||
# ── pdfium ────────────────────────────────────────────────────────
|
# ── pdfium ────────────────────────────────────────────────────────
|
||||||
check_pdfium_deps || return 1
|
check_pdfium_deps || true
|
||||||
if [ -f "${PDFIUM_PREFIX}/libpdfium.so" ]; then
|
if [ -f "${PDFIUM_PREFIX}/libpdfium.so" ]; then
|
||||||
export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDFIUM_PREFIX} -Wl,-rpath,${PDFIUM_PREFIX}"
|
export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDFIUM_PREFIX}"
|
||||||
export LD_LIBRARY_PATH="${PDFIUM_PREFIX}:${LD_LIBRARY_PATH}"
|
export LD_LIBRARY_PATH="${PDFIUM_PREFIX}:${LD_LIBRARY_PATH}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# ── pdf_oxide ─────────────────────────────────────────────────────
|
# ── pdf_oxide ─────────────────────────────────────────────────────
|
||||||
check_pdf_oxide_deps || return 1
|
check_pdf_oxide_deps || true
|
||||||
if [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.so" ]; then
|
if [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.so" ]; then
|
||||||
export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDF_OXIDE_PREFIX} -lpdf_oxide -Wl,-rpath,${PDF_OXIDE_PREFIX}"
|
export CGO_LDFLAGS="$CGO_LDFLAGS -L${PDF_OXIDE_PREFIX}"
|
||||||
export LD_LIBRARY_PATH="${PDF_OXIDE_PREFIX}:${LD_LIBRARY_PATH}"
|
export LD_LIBRARY_PATH="${PDF_OXIDE_PREFIX}:${LD_LIBRARY_PATH}"
|
||||||
elif [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.a" ]; then
|
elif [ -f "${PDF_OXIDE_PREFIX}/libpdf_oxide.a" ]; then
|
||||||
export CGO_LDFLAGS="$CGO_LDFLAGS ${PDF_OXIDE_PREFIX}/libpdf_oxide.a"
|
export CGO_LDFLAGS="$CGO_LDFLAGS ${PDF_OXIDE_PREFIX}/libpdf_oxide.a"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "CGO_CFLAGS: $CGO_CFLAGS"
|
echo "CGO_LDFLAGS (with PDF): $CGO_LDFLAGS"
|
||||||
echo "Exporting CGO_CFLAGS: $CGO_CFLAGS"
|
|
||||||
echo "CGO_LDFLAGS: $CGO_LDFLAGS"
|
|
||||||
echo "Exporting CGO_LDFLAGS: $CGO_LDFLAGS"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run Go unit tests with the same CGO env as `build_go`. Pass any extra args
|
# Run Go unit tests with the same CGO env as `build_go`. Pass any extra args
|
||||||
@@ -482,7 +492,9 @@ run_go_tests() {
|
|||||||
print_section "Running Go tests"
|
print_section "Running Go tests"
|
||||||
|
|
||||||
cd "$PROJECT_ROOT"
|
cd "$PROJECT_ROOT"
|
||||||
|
check_office_oxide_deps || true
|
||||||
setup_cgo_env
|
setup_cgo_env
|
||||||
|
setup_cgo_env_pdf
|
||||||
|
|
||||||
if [ "$#" -eq 0 ]; then
|
if [ "$#" -eq 0 ]; then
|
||||||
set -- ./...
|
set -- ./...
|
||||||
@@ -522,6 +534,10 @@ run() {
|
|||||||
|
|
||||||
cd "$PROJECT_ROOT"
|
cd "$PROJECT_ROOT"
|
||||||
|
|
||||||
|
# Set LD_LIBRARY_PATH for native libraries that were linked at build time.
|
||||||
|
# Libraries are only in the search path when they were present during build.
|
||||||
|
setup_cgo_env
|
||||||
|
|
||||||
# admin_server must be running before ragflow_server, otherwise ragflow_server's
|
# admin_server must be running before ragflow_server, otherwise ragflow_server's
|
||||||
# heartbeats to admin will error out (see internal/development.md).
|
# heartbeats to admin will error out (see internal/development.md).
|
||||||
print_section "Starting admin server (background)"
|
print_section "Starting admin server (background)"
|
||||||
|
|||||||
@@ -46,8 +46,9 @@ def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
|
|||||||
>>> timestamp_to_date(1704067200000)
|
>>> timestamp_to_date(1704067200000)
|
||||||
'2024-01-01 08:00:00'
|
'2024-01-01 08:00:00'
|
||||||
"""
|
"""
|
||||||
if not timestamp:
|
if timestamp is None or timestamp == "":
|
||||||
timestamp = time.time()
|
timestamp = current_timestamp()
|
||||||
|
logging.debug("timestamp_to_date received empty timestamp; using current_timestamp() fallback")
|
||||||
timestamp = int(timestamp) / 1000
|
timestamp = int(timestamp) / 1000
|
||||||
time_array = time.localtime(timestamp)
|
time_array = time.localtime(timestamp)
|
||||||
str_date = time.strftime(format_string, time_array)
|
str_date = time.strftime(format_string, time_array)
|
||||||
@@ -144,11 +145,8 @@ def format_iso_8601_to_ymd_hms(time_str: str) -> str:
|
|||||||
from dateutil import parser
|
from dateutil import parser
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if parser.isoparse(time_str):
|
dt = parser.isoparse(time_str)
|
||||||
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
else:
|
|
||||||
return time_str
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
return time_str
|
return time_str
|
||||||
|
|||||||
@@ -133,6 +133,10 @@ set_target_properties(rag_tokenizer PROPERTIES
|
|||||||
add_library(rag_tokenizer_c_api STATIC
|
add_library(rag_tokenizer_c_api STATIC
|
||||||
rag_analyzer_c_api.cpp
|
rag_analyzer_c_api.cpp
|
||||||
rag_analyzer_c_api.h
|
rag_analyzer_c_api.h
|
||||||
|
thinc_ner.cpp
|
||||||
|
thinc_ner.h
|
||||||
|
thinc_parser.cpp
|
||||||
|
thinc_parser.h
|
||||||
rag_analyzer.cpp
|
rag_analyzer.cpp
|
||||||
rag_analyzer.h
|
rag_analyzer.h
|
||||||
dart_trie.h
|
dart_trie.h
|
||||||
|
|||||||
@@ -99,6 +99,31 @@ char* RAGAnalyzer_GetTermTag(RAGAnalyzerHandle handle, const char* term);
|
|||||||
// Returns: handle to the new analyzer instance, or NULL on failure
|
// Returns: handle to the new analyzer instance, or NULL on failure
|
||||||
RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle);
|
RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle);
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Named Entity Recognition (spaCy model inference)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Create a ThincNER inference handle.
|
||||||
|
// model_ner_dir: path to the spaCy model's ner/ component directory
|
||||||
|
// model_vocab_dir: path to the spaCy model's vocab/ directory (optional, can be NULL)
|
||||||
|
// Returns: handle, or NULL on failure
|
||||||
|
RAGAnalyzerHandle ThincNER_Create(const char* model_ner_dir, const char* model_vocab_dir);
|
||||||
|
|
||||||
|
// Destroy a ThincNER handle.
|
||||||
|
void ThincNER_Destroy(RAGAnalyzerHandle handle);
|
||||||
|
|
||||||
|
// Run NER on pre-tokenized text.
|
||||||
|
// tokens_json: JSON array e.g. ["Apple","Inc.","was","founded","by","Steve","Jobs","."]
|
||||||
|
// Returns JSON array of entities, caller must free with ThincNER_FreeString.
|
||||||
|
char* ThincNER_Predict(RAGAnalyzerHandle handle, const char* tokens_json);
|
||||||
|
|
||||||
|
// Tokenize text using spaCy-compatible rules.
|
||||||
|
// Returns JSON array of token strings, caller must free with ThincNER_FreeString.
|
||||||
|
char* ThincNER_Tokenize(const char* text, const char* lang);
|
||||||
|
|
||||||
|
// Free a string returned by ThincNER_Predict or ThincNER_Tokenize.
|
||||||
|
void ThincNER_FreeString(char* ptr);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
610
internal/cpp/thinc_ner.cpp
Normal file
610
internal/cpp/thinc_ner.cpp
Normal file
@@ -0,0 +1,610 @@
|
|||||||
|
#pragma STDC FP_CONTRACT OFF
|
||||||
|
|
||||||
|
#include "thinc_ner.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// JSON parser (minimal)
|
||||||
|
// =========================================================================
|
||||||
|
namespace {
|
||||||
|
// Strip leading and trailing spaces, tabs, carriage returns and newlines.
// An empty or all-whitespace input yields an empty string.
std::string trim(const std::string& s) {
    static const char* const kBlank = " \t\r\n";
    const std::string::size_type first = s.find_first_not_of(kBlank);
    if (first == std::string::npos) {
        return "";
    }
    const std::string::size_type last = s.find_last_not_of(kBlank);
    return s.substr(first, last - first + 1);
}
|
||||||
|
// Minimal JSON value. All payload members are always present; `type` says
// which one is meaningful.
struct JVal {
    enum Type { NUL, OBJ, ARR, STR, NUM, BOOL } type = NUL;
    std::string str;                             // STR payload; "true"/"false" for BOOL
    std::vector<JVal> arr;                       // ARR payload
    std::unordered_map<std::string, JVal> obj;   // OBJ payload
    double num = 0;                              // NUM payload

    // Returns the member named `k`, or nullptr when it is absent.
    const JVal* get(const std::string& k) const {
        auto it = obj.find(k);
        return it != obj.end() ? &it->second : nullptr;
    }
    int as_int() const { return (int)num; }
    int64_t as_i64() const { return (int64_t)num; }
};

// Lenient recursive-descent JSON parser. It never throws on malformed input;
// it simply returns whatever it managed to read.
struct JParser {
    const char *p, *e;   // cursor and one-past-end of the text being parsed

    // Advance the cursor past JSON whitespace.
    void skip_ws() { while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p; }

    // Peek at the next non-whitespace character (0 at end of input).
    char pk() { skip_ws(); return p < e ? *p : 0; }

    // Consume and return the next non-whitespace character (0 at end of input).
    char nx() { skip_ws(); return p < e ? *p++ : 0; }

    // Parse any value, dispatching on its first character.
    JVal pv() {
        const char c = pk();
        if (c == '{') return po();
        if (c == '[') return pa();
        if (c == '"') return ps();
        if (c == 't' || c == 'f') return pb();
        if (c == 'n') { nx(); nx(); nx(); nx(); return JVal{}; }   // "null"
        return pn();
    }

    // Parse an object: '{' (string ':' value (',' ...)?)* '}'.
    JVal po() {
        JVal v; v.type = JVal::OBJ;
        nx();                                   // '{'
        while (pk() != '}') {
            auto k = ps();
            nx();                               // ':'
            v.obj[k.str] = pv();
            if (pk() == ',') nx(); else break;
        }
        nx();                                   // '}'
        return v;
    }

    // Parse an array: '[' (value (',' ...)?)* ']'.
    JVal pa() {
        JVal v; v.type = JVal::ARR;
        nx();                                   // '['
        while (pk() != ']') {
            v.arr.push_back(pv());
            if (pk() == ',') nx(); else break;
        }
        nx();                                   // ']'
        return v;
    }

    // Read exactly four hex digits at the cursor. On success stores the value
    // in `out`, advances the cursor and returns true; otherwise leaves the
    // cursor untouched and returns false.
    bool hex4(unsigned& out) {
        if (e - p < 4) return false;
        unsigned v = 0;
        for (int i = 0; i < 4; i++) {
            const char c = p[i];
            unsigned d;
            if (c >= '0' && c <= '9') d = (unsigned)(c - '0');
            else if (c >= 'a' && c <= 'f') d = (unsigned)(c - 'a') + 10;
            else if (c >= 'A' && c <= 'F') d = (unsigned)(c - 'A') + 10;
            else return false;
            v = v * 16 + d;
        }
        p += 4;
        out = v;
        return true;
    }

    // Append code point `cp` to `s` as UTF-8.
    static void put_utf8(std::string& s, unsigned cp) {
        if (cp < 0x80) {
            s += (char)cp;
        } else if (cp < 0x800) {
            s += (char)(0xC0 | (cp >> 6));
            s += (char)(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            s += (char)(0xE0 | (cp >> 12));
            s += (char)(0x80 | ((cp >> 6) & 0x3F));
            s += (char)(0x80 | (cp & 0x3F));
        } else {
            s += (char)(0xF0 | (cp >> 18));
            s += (char)(0x80 | ((cp >> 12) & 0x3F));
            s += (char)(0x80 | ((cp >> 6) & 0x3F));
            s += (char)(0x80 | (cp & 0x3F));
        }
    }

    // Parse a string literal, decoding the JSON escapes (\n \t \r \b \f and
    // \uXXXX including surrogate pairs). Previously the character after a
    // backslash was copied verbatim, which turned "\n" into "n" and "\u003c"
    // into "u003c". Unknown escapes still fall back to the literal character.
    JVal ps() {
        JVal v; v.type = JVal::STR;
        nx();                                   // opening quote
        while (p < e && *p != '"') {
            if (*p != '\\') { v.str += *p++; continue; }
            ++p;                                // skip the backslash
            if (p >= e) break;
            const char c = *p++;
            switch (c) {
                case 'n': v.str += '\n'; break;
                case 't': v.str += '\t'; break;
                case 'r': v.str += '\r'; break;
                case 'b': v.str += '\b'; break;
                case 'f': v.str += '\f'; break;
                case 'u': {
                    unsigned cp = 0;
                    if (!hex4(cp)) { v.str += 'u'; break; }   // malformed: stay lenient
                    if (cp >= 0xD800 && cp <= 0xDBFF) {
                        // High surrogate: combine with a following \uDC00-\uDFFF.
                        const char* save = p;
                        unsigned lo = 0;
                        if (e - p >= 6 && p[0] == '\\' && p[1] == 'u' && (p += 2, hex4(lo))
                            && lo >= 0xDC00 && lo <= 0xDFFF) {
                            cp = 0x10000 + ((cp - 0xD800) << 10) + (lo - 0xDC00);
                        } else {
                            p = save;
                            cp = 0xFFFD;        // lone surrogate -> replacement char
                        }
                    } else if (cp >= 0xDC00 && cp <= 0xDFFF) {
                        cp = 0xFFFD;            // lone low surrogate
                    }
                    put_utf8(v.str, cp);
                    break;
                }
                default: v.str += c; break;     // '"', '\\', '/' and anything unknown
            }
        }
        if (p < e) ++p;                         // closing quote
        return v;
    }

    // Parse a number: -?digits(.digits)?([eE][+-]?digits)?. Yields 0 when no
    // characters match or the text cannot be converted.
    JVal pn() {
        JVal v; v.type = JVal::NUM;
        auto s = p;
        if (p < e && *p == '-') ++p;
        while (p < e && (*p >= '0' && *p <= '9')) ++p;
        if (p < e && *p == '.') { ++p; while (p < e && (*p >= '0' && *p <= '9')) ++p; }
        if (p < e && (*p == 'e' || *p == 'E')) {
            ++p;
            if (p < e && (*p == '+' || *p == '-')) ++p;
            while (p < e && (*p >= '0' && *p <= '9')) ++p;
        }
        if (s < p) { try { v.num = std::stod(std::string(s, p - s)); } catch (...) { v.num = 0; } }
        return v;
    }

    // Parse `true` / `false`; the literal text is stored in `str`.
    JVal pb() {
        JVal v; v.type = JVal::BOOL;
        if (e - p >= 4 && *p == 't') { v.str = "true"; p += 4; }
        else if (e - p >= 5 && *p == 'f') { v.str = "false"; p += 5; }
        return v;
    }

    // Parse the document `j`. The returned value owns its data; `j` only has
    // to outlive this call.
    JVal parse(const std::string& j) { p = j.data(); e = p + j.size(); return pv(); }
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// MurmurHash2 64-bit (vocab string→ID, seed=0 matching spaCy StringStore)
|
||||||
|
// =========================================================================
|
||||||
|
// MurmurHash64A over `len` bytes at `key`. Bytes are consumed as 64-bit words
// in host byte order (memcpy), followed by a little-endian tail of up to seven
// bytes.
static uint64_t mh2_64a(const void* key, int len, uint64_t seed) {
    const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
    const int kShift = 47;

    const uint8_t* cursor = (const uint8_t*)key;
    uint64_t hash = seed ^ (uint64_t(len) * kMul);

    // Body: one mixing round per complete 8-byte word.
    int left = len;
    for (; left >= 8; left -= 8, cursor += 8) {
        uint64_t word;
        memcpy(&word, cursor, 8);
        word *= kMul;
        word ^= word >> kShift;
        word *= kMul;
        hash ^= word;
        hash *= kMul;
    }

    // Tail: fold the remaining 1..7 bytes in, lowest byte in the lowest bits.
    if (left > 0) {
        for (int i = left - 1; i >= 0; --i) {
            hash ^= uint64_t(cursor[i]) << (8 * i);
        }
        hash *= kMul;
    }

    // Finalisation.
    hash ^= hash >> kShift;
    hash *= kMul;
    hash ^= hash >> kShift;
    return hash;
}

// Feature id of a string: 0 for the empty string, otherwise its
// MurmurHash64A with seed 0.
// NOTE(review): spaCy's StringStore appears to hash with seed 1
// (hash_utf8 in spacy/strings.pyx) — confirm that seed 0 is intended here.
static uint64_t hash_feat(const std::string& s) {
    if (s.empty()) {
        return 0;
    }
    return mh2_64a(s.data(), (int)s.size(), 0);
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// MurmurHash3_x64_128 (exact copy from mmh3 package, verified against thinc)
|
||||||
|
// =========================================================================
|
||||||
|
#define ROTL64(x,r) ((x << r) | (x >> (64 - r)))

// Load the i-th 64-bit word starting at `p`; memcpy keeps the read valid for
// unaligned input.
static uint64_t getblock64(const uint64_t* p, size_t i) {
    uint64_t word;
    memcpy(&word, p + i, sizeof(word));
    return word;
}

// MurmurHash3 64-bit finaliser.
static uint64_t fmix64(uint64_t k) {
    k ^= k >> 33;
    k *= 0xff51afd7ed558ccdULL;
    k ^= k >> 33;
    k *= 0xc4ceb9fe1a85ec53ULL;
    return k ^ (k >> 33);
}

// MurmurHash3_x64_128 of `len` bytes at `key`. The 128-bit result is written
// to `out` as four 32-bit words: low/high halves of h1, then low/high halves
// of h2.
static void mmh3_x64_128(const void* key, int len, uint32_t seed, uint32_t out[4]) {
    const uint64_t c1 = 0x87c37b91114253d5ULL;
    const uint64_t c2 = 0x4cf5ad432745937fULL;

    const uint8_t* data = (const uint8_t*)key;
    const int nblocks = len / 16;
    uint64_t h1 = seed;
    uint64_t h2 = seed;

    // Body: two 64-bit lanes per 16-byte block.
    const uint64_t* blocks = (const uint64_t*)(data);
    for (int blk = 0; blk < nblocks; ++blk) {
        uint64_t k1 = getblock64(blocks, blk * 2 + 0);
        uint64_t k2 = getblock64(blocks, blk * 2 + 1);

        k1 *= c1;
        k1 = ROTL64(k1,31);
        k1 *= c2;
        h1 ^= k1;
        h1 = ROTL64(h1,27);
        h1 += h2;
        h1 = h1 * 5 + 0x52dce729;

        k2 *= c2;
        k2 = ROTL64(k2,33);
        k2 *= c1;
        h2 ^= k2;
        h2 = ROTL64(h2,31);
        h2 += h1;
        h2 = h2 * 5 + 0x38495ab5;
    }

    // Tail: bytes 8..14 feed the second lane, bytes 0..7 the first lane, each
    // packed little-endian. The second lane is mixed before the first.
    const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
    const int rest = len & 15;
    uint64_t k1 = 0;
    uint64_t k2 = 0;
    if (rest > 8) {
        for (int i = 8; i < rest; ++i) {
            k2 ^= ((uint64_t)tail[i]) << (8 * (i - 8));
        }
        k2 *= c2;
        k2 = ROTL64(k2,33);
        k2 *= c1;
        h2 ^= k2;
    }
    if (rest > 0) {
        const int upto = rest < 8 ? rest : 8;
        for (int i = 0; i < upto; ++i) {
            k1 ^= ((uint64_t)tail[i]) << (8 * i);
        }
        k1 *= c1;
        k1 = ROTL64(k1,31);
        k1 *= c2;
        h1 ^= k1;
    }

    // Finalisation.
    h1 ^= len;
    h2 ^= len;
    h1 += h2;
    h2 += h1;
    h1 = fmix64(h1);
    h2 = fmix64(h2);
    h1 += h2;
    h2 += h1;

    out[0] = (uint32_t)h1;
    out[1] = (uint32_t)(h1 >> 32);
    out[2] = (uint32_t)h2;
    out[3] = (uint32_t)(h2 >> 32);
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// HashEmbed
|
||||||
|
// =========================================================================
|
||||||
|
struct HashEmbed {
|
||||||
|
int n_rows=0,nO=0; uint32_t seed=0; std::vector<float> table;
|
||||||
|
bool load(int r, int o, const float* d) { n_rows=r;nO=o;table.assign(d,d+(size_t)r*o);return!table.empty(); }
|
||||||
|
void embed(uint64_t fid, float* out) const {
|
||||||
|
uint8_t in[8]; for(int i=0;i<8;i++)in[i]=(uint8_t)(fid>>(i*8));
|
||||||
|
uint32_t keys[4]; mmh3_x64_128(in,8,seed,keys);
|
||||||
|
for(int v=0;v<4;v++){int idx=(int)(keys[v]%(uint32_t)n_rows);for(int i=0;i<nO;i++)out[i]+=table[(size_t)idx*nO+i];}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Features — dynamic based on n_embed
|
||||||
|
// en (6): NORM, PREFIX, SUFFIX, SHAPE, SPACY, IS_SPACE
|
||||||
|
// zh (5): NORM, PREFIX, SUFFIX, SHAPE, IS_SPACE
|
||||||
|
// =========================================================================
|
||||||
|
static uint64_t feat_norm(const std::string& t) {
|
||||||
|
std::string lo=t; std::transform(lo.begin(),lo.end(),lo.begin(),::tolower);
|
||||||
|
return hash_feat(lo);
|
||||||
|
}
|
||||||
|
// Return the first UTF-8 encoded character of `s` (empty for empty input).
// The width comes from the lead byte; if the string is shorter than that
// width, only the first byte is returned.
static std::string utf8_first(const std::string& s) {
    if (s.empty()) {
        return "";
    }
    const unsigned char lead = (unsigned char)s[0];
    size_t width = 1;
    if ((lead & 0xE0) == 0xC0) {
        width = 2;
    } else if ((lead & 0xF0) == 0xE0) {
        width = 3;
    } else if ((lead & 0xF8) == 0xF0) {
        width = 4;
    }
    if (width > s.size()) {
        width = 1;
    }
    return s.substr(0, width);
}
|
||||||
|
// Count the UTF-8 characters in `s`. Each step advances by the width implied
// by the lead byte; a byte that is not a valid lead byte counts as one
// character of width 1.
static size_t utf8_len(const std::string& s) {
    size_t count = 0;
    size_t pos = 0;
    while (pos < s.size()) {
        const unsigned char lead = (unsigned char)s[pos];
        size_t width = 1;
        if ((lead & 0x80) == 0) {
            width = 1;
        } else if ((lead & 0xE0) == 0xC0) {
            width = 2;
        } else if ((lead & 0xF0) == 0xE0) {
            width = 3;
        } else if ((lead & 0xF8) == 0xF0) {
            width = 4;
        }
        pos += width;
        ++count;
    }
    return count;
}
|
||||||
|
// Return the last `count` UTF-8 characters of `s`, or all of `s` when it has
// `count` characters or fewer.
static std::string utf8_last(const std::string& s, size_t count) {
    // Width of the character whose lead byte is `lead` (1 for stray bytes).
    auto width_of = [](unsigned char lead) -> size_t {
        if ((lead & 0x80) == 0) return 1;
        if ((lead & 0xE0) == 0xC0) return 2;
        if ((lead & 0xF0) == 0xE0) return 3;
        if ((lead & 0xF8) == 0xF0) return 4;
        return 1;
    };

    // First pass: how many characters are there in total?
    size_t total = 0;
    for (size_t i = 0; i < s.size(); i += width_of((unsigned char)s[i])) {
        ++total;
    }
    if (total <= count) {
        return s;
    }

    // Second pass: skip the characters that precede the suffix.
    size_t start = 0;
    for (size_t skipped = 0; skipped < total - count; ++skipped) {
        start += width_of((unsigned char)s[start]);
    }
    return s.substr(start);
}
|
||||||
|
static uint64_t feat_prefix(const std::string& t) {
|
||||||
|
// spaCy: string[:1].lower() → hash the lowercased prefix
|
||||||
|
std::string p = t.empty() ? "" : utf8_first(t);
|
||||||
|
std::transform(p.begin(), p.end(), p.begin(), ::tolower);
|
||||||
|
return hash_feat(p);
|
||||||
|
}
|
||||||
|
static uint64_t feat_suffix(const std::string& t) {
|
||||||
|
// spaCy: string[-3:].lower() → hash the lowercased suffix
|
||||||
|
size_t ulen=utf8_len(t);
|
||||||
|
std::string s = ulen>=3 ? utf8_last(t,3) : t;
|
||||||
|
std::transform(s.begin(), s.end(), s.begin(), ::tolower);
|
||||||
|
return hash_feat(s);
|
||||||
|
}
|
||||||
|
static uint64_t feat_shape(const std::string& t) {
|
||||||
|
std::string sh;
|
||||||
|
for(unsigned char c:t){
|
||||||
|
if(c>0x7F)sh+='x'; // CJK → 'x' (matches spaCy zh shape)
|
||||||
|
else if(std::isupper(c))sh+='X';
|
||||||
|
else if(std::islower(c))sh+='x';
|
||||||
|
else if(std::isdigit(c))sh+='d';
|
||||||
|
else sh+=c;
|
||||||
|
}
|
||||||
|
return hash_feat(sh);
|
||||||
|
}
|
||||||
|
// Extract features based on n_embed count. Returns vector of hash values.
|
||||||
|
// NER model's tok2vec uses 4 features: NORM, PREFIX, SUFFIX, SHAPE
|
||||||
|
// (The pipeline's standalone tok2vec uses 6 features including SPACY and IS_SPACE.)
|
||||||
|
// Feature order matches the HashEmbed table order in the model.
|
||||||
|
static std::vector<uint64_t> extract_features(const std::string& t, int n_embed) {
|
||||||
|
std::vector<uint64_t> ids;
|
||||||
|
ids.push_back(feat_norm(t)); // #0: NORM (all models)
|
||||||
|
ids.push_back(feat_prefix(t)); // #1: PREFIX
|
||||||
|
ids.push_back(feat_suffix(t)); // #2: SUFFIX
|
||||||
|
ids.push_back(feat_shape(t)); // #3: SHAPE
|
||||||
|
if(n_embed==5) {
|
||||||
|
ids.push_back(0); // #4: IS_SPACE (zh/ja: 5-embed models, no SPACY)
|
||||||
|
} else if(n_embed>=6) {
|
||||||
|
ids.push_back(1); // #4: SPACY (en/de/fr/es/pt: 6-embed models)
|
||||||
|
ids.push_back(0); // #5: IS_SPACE
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Layers
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
// Dot product of a[0..n) and b[0..n) using Kahan compensated summation, which
// limits accumulated rounding error on long vectors (e.g. the 576-term rows
// in Maxout). Returns 0 when n <= 0.
static float kahan_dot(const float* a, const float* b, int n) {
    float total = 0.0f;
    float comp = 0.0f;   // running compensation for lost low-order bits
    int i = 0;
    while (i < n) {
        const float term = a[i] * b[i] - comp;
        const float next = total + term;
        comp = (next - total) - term;
        total = next;
        ++i;
    }
    return total;
}
|
||||||
|
|
||||||
|
static void linear(float* out, const float* in, const float* W, const float* b, int nO, int nI) {
|
||||||
|
for(int i=0;i<nO;i++)out[i]=b[i]+kahan_dot(W+(size_t)i*nI,in,nI);
|
||||||
|
}
|
||||||
|
// In-place ReLU over x[0..n): every element that is not strictly positive
// (including NaN) becomes 0.
static void relu_inplace(float* x, int n) {
    for (float* it = x; it < x + n; ++it) {
        if (!(*it > 0)) {
            *it = 0;
        }
    }
}
|
||||||
|
|
||||||
|
// Maxout: y[i] = max_p(b[i,p] + W[i,p,:] @ in)
|
||||||
|
static void maxout(float* out, const float* in, const float* W, const float* b, int nO, int nP, int nI) {
|
||||||
|
for(int i=0;i<nO;i++){
|
||||||
|
float best=-1e30f;
|
||||||
|
for(int p=0;p<nP;p++){
|
||||||
|
float s = b[(size_t)i*nP+p] + kahan_dot(W+(((size_t)i*nP+p)*nI), in, nI);
|
||||||
|
if(s>best)best=s;
|
||||||
|
}
|
||||||
|
out[i]=best;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer normalisation over a vector of length d:
//   out[i] = G[i] * (in[i] - mean) / sqrt(var + eps) + b[i]
// where var is the population variance (divided by d).
static void layernorm(float* out, const float* in, int d, const float* G, const float* b, float eps) {
    float mean = 0;
    for (int i = 0; i < d; ++i) {
        mean += in[i];
    }
    mean /= d;

    float var = 0;
    for (int i = 0; i < d; ++i) {
        var += (in[i] - mean) * (in[i] - mean);
    }
    var /= d;

    const float inv_std = 1.0f / sqrtf(var + eps);
    for (int i = 0; i < d; ++i) {
        out[i] = G[i] * (in[i] - mean) * inv_std + b[i];
    }
}
|
||||||
|
|
||||||
|
// ExpandWindow with a window of one: for the token at `idx` in `all`
// (n rows of `dim` floats) write [row idx-1 | row idx | row idx+1] to `out`
// (3*dim floats). A neighbour outside the sequence is zero-filled.
static void expand_win(float* out, const float* all, int n, int dim, int idx) {
    const size_t row_bytes = dim * sizeof(float);
    float* left = out;
    float* centre = out + dim;
    float* right = out + 2 * dim;

    if (idx > 0) {
        memcpy(left, all + (idx - 1) * dim, row_bytes);
    } else {
        memset(left, 0, row_bytes);
    }

    memcpy(centre, all + idx * dim, row_bytes);

    if (idx < n - 1) {
        memcpy(right, all + (idx + 1) * dim, row_bytes);
    } else {
        memset(right, 0, row_bytes);
    }
}
|
||||||
|
|
||||||
|
// BILUO decoder

// One decoded entity: token texts joined with single spaces, the label
// without its BILUO prefix, inclusive token indices, and a confidence (a
// fixed 0.85 for every entity).
struct Entity { std::string text,label; int start,end; float conf; };

// Turn per-token BILUO tags ("B-ORG", "I-ORG", "L-ORG", "U-PER", "O") into
// entity spans. Decoding is lenient: an unterminated span is closed at the
// point where it is interrupted instead of being discarded.
//
// Changes from the previous version:
//   * an "L-" tag whose type differs from the open span used to drop the open
//     span silently; it is now closed first, as the "I-" branch already did;
//   * only min(tok.size(), lbl.size()) positions are read, so a shorter label
//     vector no longer indexes out of bounds.
static std::vector<Entity> decode_biluo(const std::vector<std::string>& tok, const std::vector<std::string>& lbl) {
    const float kConf = 0.85f;
    std::vector<Entity> ents;
    const int n = (int)std::min(tok.size(), lbl.size());

    int st = -1;            // start index of the open span, -1 when none
    std::string et, ex;     // type and accumulated text of the open span

    // Emit the open span (if any) ending at token `end`, then reset the state.
    auto flush = [&](int end) {
        if (st >= 0) ents.push_back({ex, et, st, end, kConf});
        st = -1;
        et.clear();
        ex.clear();
    };

    for (int i = 0; i < n; i++) {
        const std::string& l = lbl[i];
        // "O", an empty tag or a malformed tag ends the open span.
        if (l.empty() || l == "O" || l.size() < 3 || l[1] != '-') {
            flush(i - 1);
            continue;
        }
        const char a = l[0];
        const std::string ty = l.substr(2);
        switch (a) {
            case 'U':
                flush(i - 1);
                ents.push_back({tok[i], ty, i, i, kConf});
                break;
            case 'B':
                flush(i - 1);
                st = i; et = ty; ex = tok[i];
                break;
            case 'I':
                if (st >= 0 && et == ty) {
                    ex += " " + tok[i];
                } else {
                    // Treat a stray or mismatched "I-" as the start of a new span.
                    flush(i - 1);
                    st = i; et = ty; ex = tok[i];
                }
                break;
            case 'L':
                if (st >= 0 && et == ty) {
                    ex += " " + tok[i];
                    flush(i);
                } else {
                    // Close whatever was open, then treat the "L-" as a unit span.
                    flush(i - 1);
                    ents.push_back({tok[i], ty, i, i, kConf});
                }
                break;
            default:
                break;      // unknown prefix: leave the open span untouched
        }
    }
    flush(n - 1);
    return ents;
}
|
||||||
|
|
||||||
|
// Tokenizer

// Splits text into word and punctuation tokens, byte by byte.
//   - Letters, digits and every byte above 127 accumulate into the current
//     word.
//   - A '.' is kept inside a word when the word is non-empty and the next
//     byte is a letter (e.g. "e.g").
//   - Any other byte ends the current word; if it is not whitespace it is
//     emitted as its own one-character token.
static std::vector<std::string> tokenize_en(const std::string& t) {
    std::vector<std::string> tokens;
    std::string word;
    const size_t len = t.size();

    for (size_t pos = 0; pos < len; ++pos) {
        const unsigned char ch = (unsigned char)t[pos];

        if (std::isalpha(ch) || std::isdigit(ch) || ch > 127) {
            word += ch;
            continue;
        }

        const bool inner_dot = ch == '.' && !word.empty() && pos + 1 < len &&
                               std::isalpha((unsigned char)t[pos + 1]);
        if (inner_dot) {
            word += '.';
            continue;
        }

        if (!word.empty()) {
            tokens.push_back(word);
            word.clear();
        }
        if (!std::isspace(ch)) {
            tokens.push_back(std::string(1, (char)ch));
        }
    }

    if (!word.empty()) {
        tokens.push_back(word);
    }
    return tokens;
}
|
||||||
|
// Splits text so that every multi-byte UTF-8 sequence becomes its own token.
//   - A run of ASCII letters/digits is one token.
//   - Any other non-whitespace ASCII byte is a one-character token.
//   - A byte with the high bit set starts a sequence whose length (2-4) is
//     taken from the lead byte; unrecognised lead bytes count as length 1.
//     A sequence truncated by the end of the string is emitted as far as it
//     goes.
static std::vector<std::string> tokenize_zh(const std::string& t) {
    std::vector<std::string> tokens;
    const size_t len = t.size();
    size_t pos = 0;

    while (pos < len) {
        const unsigned char lead = (unsigned char)t[pos];

        if ((lead & 0x80) != 0) {
            size_t width = 1;
            if ((lead & 0xE0) == 0xC0) {
                width = 2;
            } else if ((lead & 0xF0) == 0xE0) {
                width = 3;
            } else if ((lead & 0xF8) == 0xF0) {
                width = 4;
            }
            tokens.push_back(t.substr(pos, width));
            pos += width;
            continue;
        }

        if (std::isalpha(lead) || std::isdigit(lead)) {
            std::string word;
            while (pos < len &&
                   (std::isalpha((unsigned char)t[pos]) || std::isdigit((unsigned char)t[pos]))) {
                word += t[pos];
                ++pos;
            }
            tokens.push_back(word);
            continue;
        }

        if (!std::isspace(lead)) {
            tokens.push_back(std::string(1, (char)lead));
        }
        ++pos;
    }
    return tokens;
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// State
|
||||||
|
// =========================================================================
|
||||||
|
struct State {
|
||||||
|
// HashEmbeds
|
||||||
|
std::vector<HashEmbed> embeds;
|
||||||
|
// Post-embed Maxout (576→96)
|
||||||
|
std::vector<float> poW,poB; int po_nO=96,po_nP=3,po_nI=576;
|
||||||
|
// Post-embed LayerNorm
|
||||||
|
std::vector<float> poG,poB2; bool has_poLN=false;
|
||||||
|
// Residual encoder (4 blocks)
|
||||||
|
struct ResBlk{bool has=false;std::vector<float>W,b,lnG,lnb;};
|
||||||
|
ResBlk res[4]; int n_res=0;
|
||||||
|
// NER hidden (96→64)
|
||||||
|
std::vector<float> hW,hB; int hO=64; bool has_hid=false;
|
||||||
|
// PrecomputableAffine: W_full[nP=3][nO=64][nI=2][nD=64], b_full[nO=64][nI=2]
|
||||||
|
// We use f=0 (first feature only): pre_out[p][o] = sum_d W[p][o][0][d] * hid[d] + b[o][0]
|
||||||
|
std::vector<float> pW_full; // flattened [3*64*2*64]
|
||||||
|
std::vector<float> pB_full; // flattened [64*2]
|
||||||
|
int p_nP=3, p_nO=64, p_nI=2, p_nD=64; bool has_pre=false;
|
||||||
|
// Classifier (64→n_actions)
|
||||||
|
std::vector<float> cW,cB; int nAct=0; bool has_cls=false;
|
||||||
|
std::vector<std::string> actLbl;
|
||||||
|
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Load model
|
||||||
|
// =========================================================================
|
||||||
|
static bool load(const std::string& dir, State* s) {
|
||||||
|
std::ifstream cf(dir+"/model.ckpt"); if(!cf){std::cerr<<"No model.ckpt\n";return false;}
|
||||||
|
std::stringstream cb;cb<<cf.rdbuf();
|
||||||
|
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
|
||||||
|
|
||||||
|
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
|
||||||
|
size_t bz=bf.tellg();bf.seekg(0); std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
|
||||||
|
|
||||||
|
auto sl=[&](int64_t o, int64_t c)->std::vector<float>{
|
||||||
|
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
|
||||||
|
};
|
||||||
|
auto ld=[&](const std::string& k, std::vector<float>* v, int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
|
||||||
|
auto* e=ck.get(k); if(!e)return false;
|
||||||
|
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||||
|
if(!sv||!ov||!cv)return false;
|
||||||
|
*v=sl(ov->as_i64(),cv->as_i64());
|
||||||
|
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||||
|
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||||
|
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||||
|
return!v->empty();
|
||||||
|
};
|
||||||
|
|
||||||
|
// HashEmbeds — dynamic count (6 for en, 5 for zh, etc.)
|
||||||
|
for(int ei=0;;ei++){
|
||||||
|
auto* e=ck.get("embed_"+std::to_string(ei)+"_E"); if(!e)break;
|
||||||
|
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||||
|
if(!sv||!ov||!cv)break;
|
||||||
|
int rs=sv->arr[0].as_int(),nO=sv->arr[1].as_int();
|
||||||
|
int64_t expected=(int64_t)rs*nO;
|
||||||
|
if(cv->as_i64()<expected)break; // count too short → malformed
|
||||||
|
auto d=sl(ov->as_i64(),cv->as_i64()); if(d.empty())break;
|
||||||
|
s->embeds.emplace_back(); s->embeds.back().load(rs,nO,d.data());
|
||||||
|
}
|
||||||
|
// Seeds
|
||||||
|
std::ifstream ff(dir+"/feature_config.json");
|
||||||
|
if(ff){std::stringstream fb;fb<<ff.rdbuf();auto cfg=JParser().parse(fb.str());auto* sa=cfg.get("embed_seeds");
|
||||||
|
if(sa&&sa->type==JVal::ARR)for(int i=0;i<(int)sa->arr.size()&&i<(int)s->embeds.size();i++)s->embeds[i].seed=(uint32_t)sa->arr[i].as_int();}
|
||||||
|
|
||||||
|
int r0=0,r1=0,r2=0;
|
||||||
|
// Post-embed
|
||||||
|
if(ld("poW",&s->poW,&r0,&r1,&r2)){s->po_nO=r0;s->po_nP=r1;s->po_nI=r2;ld("poB",&s->poB);}
|
||||||
|
if(ld("poG",&s->poG)){ld("poB2",&s->poB2);s->has_poLN=true;}
|
||||||
|
// Residual
|
||||||
|
for(int ri=0;ri<4;ri++){auto pk="res"+std::to_string(ri);auto& rb=s->res[ri];
|
||||||
|
if(ld(pk+"W",&rb.W,&r0,&r1,&r2)){ld(pk+"B",&rb.b);ld(pk+"lnG",&rb.lnG);ld(pk+"lnb",&rb.lnb);rb.has=true;s->n_res++;}}
|
||||||
|
// NER hidden
|
||||||
|
if(ld("hW",&s->hW,&r0,&r1)){s->hO=r0;ld("hB",&s->hB);s->has_hid=true;}
|
||||||
|
// PrecomputableAffine: load full 4D W and 2D b
|
||||||
|
// has_pre is only set when ALL of weight buffer, bias buffer,
|
||||||
|
// and hidden-dimension match (p_nD == hO) are satisfied.
|
||||||
|
{
|
||||||
|
auto* e=ck.get("pW_full"); if(e){
|
||||||
|
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||||
|
if(sv&&ov&&cv){
|
||||||
|
int nP=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||||
|
int nO=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||||
|
int nI=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||||
|
int nD=sv->arr.size()>=4?sv->arr[3].as_int():1;
|
||||||
|
s->p_nP=nP; s->p_nO=nO; s->p_nI=nI; s->p_nD=nD;
|
||||||
|
size_t total = (size_t)nP * nO * nI * nD;
|
||||||
|
s->pW_full = sl(ov->as_i64(), cv->as_i64());
|
||||||
|
bool pw_ok = s->pW_full.size() >= total;
|
||||||
|
|
||||||
|
// Load bias inside pW_full block to access dimension info
|
||||||
|
bool pb_ok = false;
|
||||||
|
if(auto* pb_e=ck.get("pB_full")){
|
||||||
|
auto pb_ov=pb_e->get("offset"),pb_cv=pb_e->get("count");
|
||||||
|
if(pb_ov&&pb_cv){
|
||||||
|
s->pB_full = sl(pb_ov->as_i64(), pb_cv->as_i64());
|
||||||
|
pb_ok = s->pB_full.size() >= (size_t)nO * nI;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool dim_ok = (nD == s->hO);
|
||||||
|
s->has_pre = pw_ok && pb_ok && dim_ok;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Classifier
|
||||||
|
if(ld("cW",&s->cW,&r0,&r1)){s->nAct=r0;ld("cB",&s->cB);s->has_cls=true;}
|
||||||
|
|
||||||
|
return!s->embeds.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool load_labels(const std::string& dir, State* s) {
|
||||||
|
std::ifstream f(dir+"/labels.json"); if(!f)return false;
|
||||||
|
std::stringstream b;b<<f.rdbuf();
|
||||||
|
auto d=JParser().parse(b.str()); auto* am=d.get("action_to_label");
|
||||||
|
if(!am||am->type!=JVal::OBJ)return false;
|
||||||
|
int mx=0; for(auto&[k,v]:am->obj){try{int a=std::stoi(k);if(a>mx)mx=a;}catch(...){}};
|
||||||
|
int n=s->nAct>0?s->nAct:mx+1; s->actLbl.resize(n,"O");
|
||||||
|
for(auto&[k,v]:am->obj){try{int a=std::stoi(k);if(a>=0&&a<n)s->actLbl[a]=v.str;}catch(...){}}
|
||||||
|
return!s->actLbl.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// C API
|
||||||
|
// =========================================================================
|
||||||
|
// Creates an inference handle from the model directory `d`.
// The second argument is accepted for interface compatibility and ignored.
// Returns NULL when `d` is NULL, when the model cannot be loaded, or when an
// exception (e.g. allocation failure) occurs; exceptions never cross this C
// boundary. When labels.json is unavailable every action is labelled "O",
// using a table large enough for all classifier actions (at least 74).
ThincNERHandle ThincNER_Create(const char* d, const char*) {
    if (!d) return nullptr;
    State* s = nullptr;
    try {
        s = new State();
        if (!load(d, s)) { delete s; return nullptr; }
        if (!load_labels(d, s)) {
            const int fallback = s->nAct > 74 ? s->nAct : 74;
            s->actLbl.assign((size_t)fallback, "O");
        }
        return s;
    } catch (...) {
        delete s;
        return nullptr;
    }
}
|
||||||
|
// Releases a handle returned by ThincNER_Create. Passing NULL is a no-op.
void ThincNER_Destroy(ThincNERHandle h) {
    State* state = static_cast<State*>(h);
    delete state;
}
|
||||||
|
|
||||||
|
char* ThincNER_Predict(ThincNERHandle h, const char* tj) {
|
||||||
|
auto* s=(State*)h; if(!s)return strdup("[]");
|
||||||
|
if(!tj)return strdup("[]");
|
||||||
|
|
||||||
|
// Parse tokens
|
||||||
|
std::vector<std::string> tok; std::string j(tj); size_t p=0;
|
||||||
|
while((p=j.find('"',p))!=std::string::npos){auto e=j.find('"',p+1);if(e==std::string::npos)break;std::string t=j.substr(p+1,e-p-1);if(!t.empty())tok.push_back(t);p=e+1;}
|
||||||
|
int n=(int)tok.size(); if(!n)return strdup("[]");
|
||||||
|
int NE=(int)s->embeds.size();
|
||||||
|
// Derive per-embed dimension from loaded tensors (all embed tables share the same nO)
|
||||||
|
int D = NE > 0 ? s->embeds[0].nO : 96;
|
||||||
|
int EC = NE * D;
|
||||||
|
|
||||||
|
// ---- Step 1: HashEmbed → concat (NER model: 4×96=384, pipe: 6×96=576) ----
|
||||||
|
std::vector<float> emb((size_t)n*EC,0);
|
||||||
|
for(int i=0;i<n;i++){
|
||||||
|
auto ids=extract_features(tok[i],NE);
|
||||||
|
size_t b=(size_t)i*EC;
|
||||||
|
for(int e=0;e<NE&&e<(int)s->embeds.size();e++)
|
||||||
|
s->embeds[e].embed(ids[e], emb.data()+b + (size_t)e*s->embeds[e].nO);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ---- Step 2: Post-embed Maxout (576→96) ----
|
||||||
|
std::vector<float> pe((size_t)n*D);
|
||||||
|
for(int i=0;i<n;i++)maxout(pe.data()+(size_t)i*D,emb.data()+(size_t)i*EC,s->poW.data(),s->poB.data(),s->po_nO,s->po_nP,s->po_nI);
|
||||||
|
|
||||||
|
// ---- Step 3: Post-embed LayerNorm ----
|
||||||
|
std::vector<float> pln((size_t)n*D,0);
|
||||||
|
if(s->has_poLN){for(int i=0;i<n;i++)layernorm(pln.data()+(size_t)i*D,pe.data()+(size_t)i*D,D,s->poG.data(),s->poB2.data(),1e-6f);}
|
||||||
|
else pln=pe;
|
||||||
|
|
||||||
|
// ---- Step 4: Residual encoder blocks ----
|
||||||
|
std::vector<float> enc=pln;
|
||||||
|
for(int ri=0;ri<s->n_res;ri++){
|
||||||
|
auto& blk=s->res[ri]; if(!blk.has)continue;
|
||||||
|
int wd=D*3;
|
||||||
|
std::vector<float> exp((size_t)n*wd);
|
||||||
|
for(int i=0;i<n;i++)expand_win(exp.data()+(size_t)i*wd,enc.data(),n,D,i);
|
||||||
|
std::vector<float> mx((size_t)n*D);
|
||||||
|
for(int i=0;i<n;i++)maxout(mx.data()+(size_t)i*D,exp.data()+(size_t)i*wd,blk.W.data(),blk.b.data(),D,3,wd);
|
||||||
|
std::vector<float> ln((size_t)n*D);
|
||||||
|
if(!blk.lnG.empty()){for(int i=0;i<n;i++)layernorm(ln.data()+(size_t)i*D,mx.data()+(size_t)i*D,D,blk.lnG.data(),blk.lnb.data(),1e-6f);}
|
||||||
|
else ln=mx;
|
||||||
|
for(int i=0;i<n;i++){float* op=enc.data()+(size_t)i*D;for(int j=0;j<D;j++)op[j]+=ln[(size_t)i*D+j];}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Step 5: NER tok2vec linear (96→64, no ReLU) ----
|
||||||
|
// The NER model's layers[0] ends with a bare linear (no activation).
|
||||||
|
// This produces the 64-dim token vectors that feed into the PrecomputableAffine.
|
||||||
|
std::vector<float> hid((size_t)n*s->hO);
|
||||||
|
if(s->has_hid){for(int i=0;i<n;i++){linear(hid.data()+(size_t)i*s->hO,enc.data()+(size_t)i*D,s->hW.data(),s->hB.data(),s->hO,D);}}
|
||||||
|
else{for(int i=0;i<n;i++){int c=std::min(D,s->hO);memcpy(hid.data()+i*s->hO,enc.data()+i*D,c*4);}}
|
||||||
|
|
||||||
|
// ---- Step 6: PrecomputableAffine → Maxout → Classifier → constrained decoding ----
|
||||||
|
// Matches the spaCy ParserStepModel's predict_states formula:
|
||||||
|
// cached[t][f][o*nP+p] = W[f,o,p,:] @ hid[t] + b[o,p] (f=0..nF-1, W=[nF,nO,nP,nI])
|
||||||
|
// unmaxed[o,nP+p] = sum_f cached[t][f][o*nP+p] (sum over nF features)
|
||||||
|
// unmaxed += bias[o,p] (add bias once, not nF times)
|
||||||
|
// hid_vec[o] = max(unmaxed[o*nP+0], unmaxed[o*nP+1])
|
||||||
|
// scores[a] = cW[a][:] @ hid_vec + cB[a]
|
||||||
|
//
|
||||||
|
// Feature token indices match spaCy's BiluoPushDown transition system:
|
||||||
|
// f=0: B(0) = buffer front = current token index
|
||||||
|
// f=1: S(0) = stack top = entity_start if in entity, else back-off to B(0)
|
||||||
|
// f=2: S(1) = stack second = back-off to B(0) (stack has ≤1 item in simple case)
|
||||||
|
auto label_type = [](const std::string& lbl) -> char { return lbl.empty() ? 'O' : lbl[0]; };
|
||||||
|
auto label_etype = [](const std::string& lbl) -> std::string { return lbl.size()<3?"":lbl.substr(2); };
|
||||||
|
|
||||||
|
std::vector<std::string> tl(n, "O");
|
||||||
|
if(s->has_cls && s->has_pre){
|
||||||
|
int nF=s->p_nP, nO=s->p_nO, nP=s->p_nI, nD=s->p_nD; // W: [nF, nO, nP, nI]
|
||||||
|
std::vector<float> unmaxed((size_t)nO * nP, 0);
|
||||||
|
std::vector<float> hid_vec(nO, 0);
|
||||||
|
std::vector<float> scores(s->nAct, 0);
|
||||||
|
int entity_start = -1; // token index of current B-entity start, -1 = no entity
|
||||||
|
|
||||||
|
for(int i=0;i<n;i++){
|
||||||
|
// Determine feature token indices from transition state
|
||||||
|
// (matches spaCy BiluoPushDown B(0)/S(0)/S(1) mapping)
|
||||||
|
int ft[3];
|
||||||
|
ft[0] = i; // B(0) = current token
|
||||||
|
if(entity_start >= 0) {
|
||||||
|
ft[1] = entity_start; // S(0) = entity start
|
||||||
|
ft[2] = i; // S(1) = back-off to B(0)
|
||||||
|
} else {
|
||||||
|
ft[1] = i; // S(0) = back-off to B(0)
|
||||||
|
ft[2] = i; // S(1) = back-off to B(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrecomputableAffine: pre[f][o][p] = W[f][o][p][:] @ hid[ft[f]] + b[o][p]
|
||||||
|
memset(unmaxed.data(), 0, (size_t)nO * nP * sizeof(float));
|
||||||
|
for(int f=0;f<nF;f++){
|
||||||
|
const float* hf = hid.data() + (size_t)ft[f] * nO;
|
||||||
|
for(int o=0;o<nO;o++){
|
||||||
|
for(int p=0;p<nP;p++){
|
||||||
|
size_t base = (((size_t)f * nO + o) * nP + p) * nD;
|
||||||
|
float val = 0;
|
||||||
|
for(int d=0;d<nD;d++){
|
||||||
|
val += s->pW_full[base + d] * hf[d];
|
||||||
|
}
|
||||||
|
unmaxed[(size_t)o * nP + p] += val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add bias ONCE (not nF times)
|
||||||
|
for(int o=0;o<nO;o++){
|
||||||
|
for(int p=0;p<nP;p++){
|
||||||
|
unmaxed[(size_t)o * nP + p] += s->pB_full[(size_t)o * nP + p];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maxout: hid_vec[o] = max_p unmaxed[o*nP + p]
|
||||||
|
for(int o=0;o<nO;o++){
|
||||||
|
float best = unmaxed[(size_t)o * nP];
|
||||||
|
for(int p=1;p<nP;p++){
|
||||||
|
float v = unmaxed[(size_t)o * nP + p];
|
||||||
|
if(v > best) best = v;
|
||||||
|
}
|
||||||
|
hid_vec[o] = best;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classifier: scores = cW @ hid_vec + cB
|
||||||
|
linear(scores.data(), hid_vec.data(), s->cW.data(), s->cB.data(), s->nAct, nO);
|
||||||
|
|
||||||
|
// Constrained greedy decoding
|
||||||
|
char prev_type = i>0 ? label_type(tl[i-1]) : 'O';
|
||||||
|
std::string prev_etype = i>0 ? label_etype(tl[i-1]) : "";
|
||||||
|
int bst=-1; float bv=-1e30f;
|
||||||
|
for(int a=0;a<s->nAct;a++){
|
||||||
|
const std::string& lbl = (a<(int)s->actLbl.size()) ? s->actLbl[a] : "O";
|
||||||
|
if(lbl.empty()) continue;
|
||||||
|
char ct = label_type(lbl);
|
||||||
|
std::string ce = label_etype(lbl);
|
||||||
|
bool valid=false;
|
||||||
|
if(prev_type=='O'||prev_type=='L'||prev_type=='U')
|
||||||
|
valid = (ct=='O'||ct=='B'||ct=='U');
|
||||||
|
else if(prev_type=='B'||prev_type=='I'){
|
||||||
|
if(ct=='O') valid=true;
|
||||||
|
else if((ct=='I'||ct=='L')&&ce==prev_etype) valid=true;
|
||||||
|
}
|
||||||
|
if(!valid) continue;
|
||||||
|
if(scores[a]>bv){bv=scores[a];bst=a;}
|
||||||
|
}
|
||||||
|
if(bst>=0) {
|
||||||
|
tl[i] = s->actLbl[bst];
|
||||||
|
// Update entity_start for next token (BiluoPushDown stack tracking)
|
||||||
|
char ct = label_type(tl[i]);
|
||||||
|
if(ct == 'B') entity_start = i;
|
||||||
|
else if(ct == 'I' || ct == 'L') { /* entity continues, keep entity_start */ }
|
||||||
|
else entity_start = -1; // O or U → no entity open
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Step 8: BILUO decode ----
|
||||||
|
auto ents=decode_biluo(tok,tl);
|
||||||
|
std::string r="["; for(size_t i=0;i<ents.size();i++){if(i)r+=",";r+="{\"text\":\""+ents[i].text+"\",\"label\":\""+ents[i].label+"\",\"start\":"+std::to_string(ents[i].start)+",\"end\":"+std::to_string(ents[i].end)+",\"confidence\":"+std::to_string(ents[i].conf)+"}";} r+="]";
|
||||||
|
return strdup(r.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases a string returned by this API. Passing NULL is a no-op.
void ThincNER_FreeString(char* p) {
    free(p);
}
|
||||||
|
|
||||||
|
char* ThincNER_Tokenize(const char* t, const char* l) {
|
||||||
|
if(!t)return strdup("[]");
|
||||||
|
std::string lang = l ? std::string(l) : "";
|
||||||
|
// de, fr, es, pt: European Latin -> en tokenizer; zh, ja: CJK -> zh tokenizer
|
||||||
|
bool is_cjk = (lang == "zh" || lang == "ja");
|
||||||
|
auto tok = is_cjk ? tokenize_zh(t) : tokenize_en(t);
|
||||||
|
std::string r="["; for(size_t i=0;i<tok.size();i++){if(i)r+=",";r+="\""+tok[i]+"\"";} r+="]";
|
||||||
|
return strdup(r.c_str());
|
||||||
|
}
|
||||||
45
internal/cpp/thinc_ner.h
Normal file
45
internal/cpp/thinc_ner.h
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
#ifndef THINC_NER_H
#define THINC_NER_H

/* System headers are included before the linkage block so that they are
 * never pulled in inside extern "C". */
#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

// ---------------------------------------------------------------------------
// C API for spaCy model inference (en_core_web_sm / zh_core_web_sm)
//
// Loads model.ckpt + model.bin from a model component directory, plus the
// optional feature_config.json and labels.json found next to them.
// ---------------------------------------------------------------------------

typedef void* ThincNERHandle;

// Create / destroy an inference handle for a single spaCy model.
//   model_ner_dir:   path to the NER component directory, e.g.
//                      "models/en_core_web_sm-3.7.1/ner/"
//                      "models/zh_core_web_sm-3.7.1/ner/"
//   model_vocab_dir: currently ignored by the implementation.
// ThincNER_Create returns NULL on failure.
ThincNERHandle ThincNER_Create(const char* model_ner_dir, const char* model_vocab_dir);
void ThincNER_Destroy(ThincNERHandle handle);

// Run NER on pre-tokenized text.
//   tokens_json: JSON array of token strings, e.g. ["Apple", "Inc.", "was", ...]
// Returns a JSON array of entities, where "start" and "end" are inclusive
// token indices and "text" is the covered tokens joined by single spaces:
//   [{"text":"Apple Inc.","label":"ORG","start":0,"end":1,"confidence":0.850000}, ...]
// Returns "[]" when handle or tokens_json is NULL or no tokens are found.
// Caller must free the result with ThincNER_FreeString.
char* ThincNER_Predict(ThincNERHandle handle, const char* tokens_json);

// Free a string returned by ThincNER_Predict or ThincNER_Tokenize.
void ThincNER_FreeString(char* ptr);

// Utility: tokenize text.
//   lang: "zh" or "ja" selects the CJK tokenizer; any other value, including
//         NULL, selects the English tokenizer.
// Returns a JSON array of token strings ("[]" when text is NULL).
// Caller must free the result with ThincNER_FreeString.
char* ThincNER_Tokenize(const char* text, const char* lang);

#ifdef __cplusplus
}
#endif

#endif // THINC_NER_H
|
||||||
530
internal/cpp/thinc_parser.cpp
Normal file
530
internal/cpp/thinc_parser.cpp
Normal file
@@ -0,0 +1,530 @@
|
|||||||
|
#include "thinc_parser.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Minimal JSON parser (replicated from thinc_ner.cpp)
|
||||||
|
// =========================================================================
|
||||||
|
namespace {
|
||||||
|
// A parsed JSON value. `type` selects which member carries the payload:
// `obj` for OBJ, `arr` for ARR, `str` for STR, `num` for NUM. BOOL values
// store the text "true"/"false" in `str`.
struct JVal {
    enum Type{NUL,OBJ,ARR,STR,NUM,BOOL}type=NUL;
    std::string str; std::vector<JVal> arr; std::unordered_map<std::string,JVal> obj; double num=0;
    // Member lookup; returns nullptr when the key is absent (or this is not an object).
    const JVal* get(const std::string& k)const{auto it=obj.find(k);return it!=obj.end()?&it->second:nullptr;}
    int as_int()const{return(int)num;}int64_t as_i64()const{return(int64_t)num;}
};

// Minimal, lenient recursive-descent JSON parser. Malformed input never
// throws; it produces a best-effort value.
struct JParser {
    const char *p,*e;  // cursor and end of the input buffer

    // Skip whitespace and return the next character without consuming it (0 at end).
    char pk(){while(p<e&&(*p==' '||*p=='\t'||*p=='\n'||*p=='\r'))++p;return p<e?*p:0;}
    // Skip whitespace, then consume and return the next character (0 at end).
    char nx(){while(p<e&&(*p==' '||*p=='\t'||*p=='\n'||*p=='\r'))++p;return p<e?*p++:0;}

    // Append the UTF-8 encoding of code point `cp` to `out`.
    static void put_utf8(std::string& out, unsigned long cp) {
        if (cp < 0x80) {
            out += (char)cp;
        } else if (cp < 0x800) {
            out += (char)(0xC0 | (cp >> 6));
            out += (char)(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            out += (char)(0xE0 | (cp >> 12));
            out += (char)(0x80 | ((cp >> 6) & 0x3F));
            out += (char)(0x80 | (cp & 0x3F));
        } else {
            out += (char)(0xF0 | (cp >> 18));
            out += (char)(0x80 | ((cp >> 12) & 0x3F));
            out += (char)(0x80 | ((cp >> 6) & 0x3F));
            out += (char)(0x80 | (cp & 0x3F));
        }
    }

    // Parse any value, dispatching on its first character.
    JVal pv(){
        char c=pk();
        if(c=='{')return po();
        if(c=='[')return pa();
        if(c=='"')return ps();
        if(c=='t'||c=='f')return pb();
        if(c=='n'){nx();nx();nx();nx();return JVal{};}  // consume "null"
        return pn();
    }

    // Parse an object: { "key": value, ... }
    JVal po(){
        JVal v;v.type=JVal::OBJ;nx();
        while(p<e&&pk()!='}'){
            auto k=ps();
            nx();  // the ':' separator
            v.obj[k.str]=pv();
            if(p<e&&pk()==',')nx();else break;
        }
        if(p<e)nx();
        return v;
    }

    // Parse an array: [ value, ... ]
    JVal pa(){
        JVal v;v.type=JVal::ARR;nx();
        while(p<e&&pk()!=']'){
            v.arr.push_back(pv());
            if(p<e&&pk()==',')nx();else break;
        }
        if(p<e)nx();
        return v;
    }

    // Parse a string literal, decoding escapes. \uXXXX escapes (including
    // surrogate pairs) are stored as UTF-8 rather than truncated to one byte.
    JVal ps(){
        JVal v;v.type=JVal::STR;nx();
        while(p<e&&*p!='"'){
            if(*p!='\\'){v.str+=*p++;continue;}
            ++p;
            if(p>=e)break;
            switch(*p){
                case'"':case'\\':case'/':v.str+=*p++;break;
                case'n':v.str+='\n';++p;break;
                case't':v.str+='\t';++p;break;
                case'r':v.str+='\r';++p;break;
                case'b':v.str+='\b';++p;break;
                case'f':v.str+='\f';++p;break;
                case'u':{
                    if(p+4<e){
                        char tmp[5]={p[1],p[2],p[3],p[4],0};
                        unsigned long cp=strtoul(tmp,nullptr,16);
                        p+=5;
                        // A high surrogate followed by "\uDC00".."\uDFFF" forms one code point.
                        if(cp>=0xD800&&cp<=0xDBFF&&e-p>=6&&p[0]=='\\'&&p[1]=='u'){
                            char lo[5]={p[2],p[3],p[4],p[5],0};
                            unsigned long l=strtoul(lo,nullptr,16);
                            if(l>=0xDC00&&l<=0xDFFF){cp=0x10000+((cp-0xD800)<<10)+(l-0xDC00);p+=6;}
                        }
                        put_utf8(v.str,cp);
                    }else{++p;}
                }break;
                default:v.str+=*p++;break;
            }
        }
        if(p<e)++p;  // closing quote
        return v;
    }

    // Parse a number (optional sign, fraction and exponent); 0 when unparsable.
    JVal pn(){
        JVal v;v.type=JVal::NUM;auto s=p;
        if(p<e&&*p=='-')++p;
        while(p<e&&(*p>='0'&&*p<='9'))++p;
        if(p<e&&*p=='.'){++p;while(p<e&&(*p>='0'&&*p<='9'))++p;}
        if(p<e&&(*p=='e'||*p=='E')){++p;if(p<e&&(*p=='+'||*p=='-'))++p;while(p<e&&(*p>='0'&&*p<='9'))++p;}
        if(s<p){try{v.num=std::stod(std::string(s,p-s));}catch(...){v.num=0;}}
        return v;
    }

    // Parse true/false by length only; the literal text is kept in `str`.
    JVal pb(){
        JVal v;v.type=JVal::BOOL;
        if(e-p>=4&&*p=='t'){v.str="true";p+=4;}
        else if(e-p>=5&&*p=='f'){v.str="false";p+=5;}
        return v;
    }

    // Parse a complete document. `j` must outlive the call.
    JVal parse(const std::string& j){p=j.data();e=p+j.size();return pv();}
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// HashEmbed + MurmurHash (copied from thinc_ner.cpp)
|
||||||
|
// =========================================================================
|
||||||
|
// Rotate the 64-bit value x left by r bits (0 < r < 64).
#define ROTL64(x,r) ((x << r) | (x >> (64 - r)))

// Reads the i-th 64-bit word after p through memcpy.
static uint64_t getblock64(const uint64_t* p, size_t i) {
    uint64_t word;
    memcpy(&word, p + i, sizeof(word));
    return word;
}

// 64-bit finalisation mix used by mmh3_x64_128.
static uint64_t fmix64(uint64_t k) {
    k ^= k >> 33;
    k *= 0xff51afd7ed558ccdULL;
    k ^= k >> 33;
    k *= 0xc4ceb9fe1a85ec53ULL;
    k ^= k >> 33;
    return k;
}

// 128-bit hash of `len` bytes at `key` with the given seed.
// The result is written to out[0..3] as the low/high 32-bit halves of the
// two 64-bit state words (h1 first, then h2).
static void mmh3_x64_128(const void* key, int len, uint32_t seed, uint32_t out[4]) {
    const uint8_t* bytes = (const uint8_t*)key;
    const int block_count = len / 16;
    uint64_t h1 = seed;
    uint64_t h2 = seed;
    uint64_t c1 = 0x87c37b91114253d5ULL;
    uint64_t c2 = 0x4cf5ad432745937fULL;

    // Body: consume the input sixteen bytes at a time.
    const uint64_t* words = (const uint64_t*)bytes;
    for (int blk = 0; blk < block_count; blk++) {
        uint64_t k1 = getblock64(words, blk * 2 + 0);
        uint64_t k2 = getblock64(words, blk * 2 + 1);

        k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
        h1 = ROTL64(h1, 27); h1 += h2; h1 = h1 * 5 + 0x52dce729;

        k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
        h2 = ROTL64(h2, 31); h2 += h1; h2 = h2 * 5 + 0x38495ab5;
    }

    // Tail: the remaining 0..15 bytes. Every case deliberately falls through
    // to the next so that all lower bytes are mixed in as well.
    const uint8_t* tail = (const uint8_t*)(bytes + block_count * 16);
    uint64_t k1 = 0;
    uint64_t k2 = 0;
    switch (len & 15) {
        case 15: k2 ^= ((uint64_t)tail[14]) << 48;  // fall through
        case 14: k2 ^= ((uint64_t)tail[13]) << 40;  // fall through
        case 13: k2 ^= ((uint64_t)tail[12]) << 32;  // fall through
        case 12: k2 ^= ((uint64_t)tail[11]) << 24;  // fall through
        case 11: k2 ^= ((uint64_t)tail[10]) << 16;  // fall through
        case 10: k2 ^= ((uint64_t)tail[9]) << 8;    // fall through
        case 9:
            k2 ^= ((uint64_t)tail[8]) << 0;
            k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
            // fall through
        case 8: k1 ^= ((uint64_t)tail[7]) << 56;    // fall through
        case 7: k1 ^= ((uint64_t)tail[6]) << 48;    // fall through
        case 6: k1 ^= ((uint64_t)tail[5]) << 40;    // fall through
        case 5: k1 ^= ((uint64_t)tail[4]) << 32;    // fall through
        case 4: k1 ^= ((uint64_t)tail[3]) << 24;    // fall through
        case 3: k1 ^= ((uint64_t)tail[2]) << 16;    // fall through
        case 2: k1 ^= ((uint64_t)tail[1]) << 8;     // fall through
        case 1:
            k1 ^= ((uint64_t)tail[0]) << 0;
            k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
    }

    // Finalisation.
    h1 ^= len;
    h2 ^= len;
    h1 += h2;
    h2 += h1;
    h1 = fmix64(h1);
    h2 = fmix64(h2);
    h1 += h2;
    h2 += h1;

    out[0] = (uint32_t)h1;
    out[1] = (uint32_t)(h1 >> 32);
    out[2] = (uint32_t)h2;
    out[3] = (uint32_t)(h2 >> 32);
}
|
||||||
|
|
||||||
|
// 64-bit hash of `len` bytes at `key` with the given seed. Input is consumed
// eight bytes at a time; the remaining 0..7 bytes are folded in by a
// fall-through switch.
static uint64_t mh2_64a(const void* key, int len, uint64_t seed) {
    const uint64_t m = 0xc6a4a7935bd1e995ULL;
    const int r = 47;

    uint64_t h = seed ^ (uint64_t(len) * m);
    const uint8_t* cursor = (const uint8_t*)key;
    int remaining = len;

    for (; remaining >= 8; cursor += 8, remaining -= 8) {
        uint64_t k;
        memcpy(&k, cursor, 8);
        k *= m;
        k ^= k >> r;
        k *= m;
        h ^= k;
        h *= m;
    }

    switch (remaining) {
        case 7: h ^= uint64_t(cursor[6]) << 48;  // fall through
        case 6: h ^= uint64_t(cursor[5]) << 40;  // fall through
        case 5: h ^= uint64_t(cursor[4]) << 32;  // fall through
        case 4: h ^= uint64_t(cursor[3]) << 24;  // fall through
        case 3: h ^= uint64_t(cursor[2]) << 16;  // fall through
        case 2: h ^= uint64_t(cursor[1]) << 8;   // fall through
        case 1:
            h ^= cursor[0];
            h *= m;
            break;
    }

    h ^= h >> r;
    h *= m;
    h ^= h >> r;
    return h;
}

// Feature id of a string: 0 for the empty string, otherwise mh2_64a of its
// bytes with seed 0.
static uint64_t hash_feat(const std::string& s) {
    if (s.empty()) {
        return 0;
    }
    return mh2_64a(s.data(), (int)s.size(), 0);
}
|
||||||
|
|
||||||
|
struct HashEmbed{
|
||||||
|
int n_rows=0,nO=0;uint32_t seed=0;std::vector<float> table;
|
||||||
|
bool load(int r,int o,const float*d){n_rows=r;nO=o;table.assign(d,d+(size_t)r*o);return!table.empty();}
|
||||||
|
void embed(uint64_t fid,float* out)const{
|
||||||
|
uint8_t in[8];for(int i=0;i<8;i++)in[i]=(uint8_t)(fid>>(i*8));
|
||||||
|
uint32_t keys[4];mmh3_x64_128(in,8,seed,keys);
|
||||||
|
for(int v=0;v<4;v++){int idx=(int)(keys[v]%(uint32_t)n_rows);for(int i=0;i<nO;i++)out[i]+=table[(size_t)idx*nO+i];}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Layer primitives
|
||||||
|
// =========================================================================
|
||||||
|
// Affine layer: out[r] = b[r] + sum_c W[r*nI + c] * in[c], for r in [0, nO).
// W is indexed as a row-major [nO][nI] matrix.
static void linear(float* out, const float* in, const float* W, const float* b, int nO, int nI) {
    for (int row = 0; row < nO; ++row) {
        const float* weights = W + (size_t)row * nI;
        float acc = b[row];
        for (int col = 0; col < nI; ++col) {
            acc += weights[col] * in[col];
        }
        out[row] = acc;
    }
}
|
||||||
|
// Maxout layer: each output unit evaluates its nP affine pieces and keeps
// the largest:
//   out[o] = max over p of ( b[o*nP + p] + sum_j W[(o*nP + p)*nI + j] * in[j] )
// W is indexed as a row-major [nO][nP][nI] array and b as [nO][nP].
// An output unit with no pieces (nP <= 0) is left at the -1e30 sentinel.
static void maxout(float* out, const float* in, const float* W, const float* b, int nO, int nP, int nI) {
    for (int o = 0; o < nO; ++o) {
        float winner = -1e30f;
        for (int p = 0; p < nP; ++p) {
            const size_t piece = (size_t)o * nP + p;
            const float* weights = W + piece * nI;
            float acc = b[piece];
            for (int j = 0; j < nI; ++j) {
                acc += weights[j] * in[j];
            }
            if (acc > winner) {
                winner = acc;
            }
        }
        out[o] = winner;
    }
}
|
||||||
|
// Layer normalisation over one vector of length d:
//   out[k] = G[k] * (in[k] - mean) / sqrt(variance + eps) + b[k]
// mean and variance (population variance, divided by d) are accumulated in
// single precision. out may not overlap in, because in is re-read while out
// is written.
static void layernorm(float* out, const float* in, int d, const float* G, const float* b, float eps) {
    float mean = 0;
    for (int k = 0; k < d; ++k) {
        mean += in[k];
    }
    mean /= d;

    float variance = 0;
    for (int k = 0; k < d; ++k) {
        variance += (in[k] - mean) * (in[k] - mean);
    }
    variance /= d;

    const float inv_std = 1.0f / sqrtf(variance + eps);
    for (int k = 0; k < d; ++k) {
        out[k] = G[k] * (in[k] - mean) * inv_std + b[k];
    }
}
|
||||||
|
// Builds the [t-1, t, t+1] context for the token at `idx`.
//   all : n rows of `dim` floats, stored back to back
//   out : receives 3*dim floats (previous row, current row, next row)
// A neighbour that falls outside [0, n) is written as zeros.
// Row offsets are computed in size_t so that idx*dim cannot overflow `int`
// on very long inputs.
static void expand_win(float* out, const float* all, int n, int dim, int idx) {
    const size_t width = (size_t)dim;
    const size_t row_bytes = width * sizeof(float);
    const size_t cur = (size_t)idx * width;

    if (idx > 0) {
        memcpy(out, all + (cur - width), row_bytes);
    } else {
        memset(out, 0, row_bytes);
    }

    memcpy(out + width, all + cur, row_bytes);

    if (idx < n - 1) {
        memcpy(out + 2 * width, all + (cur + width), row_bytes);
    } else {
        memset(out + 2 * width, 0, row_bytes);
    }
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Feature extraction — UTF-8 aware, matching spaCy
|
||||||
|
// =========================================================================
|
||||||
|
// Returns the first UTF-8 character of s: as many bytes as the lead byte
// announces (2, 3 or 4; otherwise 1). When the string is shorter than the
// announced length only the first byte is returned. Empty input yields "".
static std::string u8_first(const std::string& s){
    if (s.empty()) {
        return "";
    }
    const unsigned char lead = (unsigned char)s[0];
    size_t width = 1;
    if ((lead & 0xE0) == 0xC0) {
        width = 2;
    } else if ((lead & 0xF0) == 0xE0) {
        width = 3;
    } else if ((lead & 0xF8) == 0xF0) {
        width = 4;
    }
    if (width > s.size()) {
        width = 1;
    }
    return s.substr(0, width);
}
|
||||||
|
// Counts the UTF-8 characters in s by stepping from lead byte to lead byte.
// The step is the length the lead byte announces (1-4); a byte that is not a
// valid lead counts as one character of length 1.
static size_t u8_len(const std::string& s){
    size_t chars = 0;
    size_t pos = 0;
    while (pos < s.size()) {
        const unsigned char lead = (unsigned char)s[pos];
        size_t step = 1;
        if ((lead & 0x80) == 0) {
            step = 1;
        } else if ((lead & 0xE0) == 0xC0) {
            step = 2;
        } else if ((lead & 0xF0) == 0xE0) {
            step = 3;
        } else if ((lead & 0xF8) == 0xF0) {
            step = 4;
        }
        pos += step;
        ++chars;
    }
    return chars;
}
|
||||||
|
// Returns the last `count` UTF-8 characters of s, or all of s when it has
// `count` characters or fewer. Characters are delimited by the length each
// lead byte announces (1-4; an invalid lead byte counts as length 1).
static std::string u8_last(const std::string& s, size_t count){
    // Byte length announced by a lead byte.
    auto step_of = [](unsigned char lead) -> size_t {
        if ((lead & 0x80) == 0) return 1;
        if ((lead & 0xE0) == 0xC0) return 2;
        if ((lead & 0xF0) == 0xE0) return 3;
        if ((lead & 0xF8) == 0xF0) return 4;
        return 1;
    };

    // First walk: count the characters.
    size_t total = 0;
    for (size_t i = 0; i < s.size(); i += step_of((unsigned char)s[i])) {
        ++total;
    }
    if (total <= count) {
        return s;
    }

    // Second walk: skip the characters that precede the requested suffix.
    size_t pos = 0;
    for (size_t skipped = 0; skipped < total - count; ++skipped) {
        pos += step_of((unsigned char)s[pos]);
    }
    return s.substr(pos);
}
|
||||||
|
static std::vector<uint64_t> extract_features(const std::string& t, int n_embed){
|
||||||
|
auto fn=[&](const std::string& s){return hash_feat(s);};
|
||||||
|
auto fp=[&](const std::string& s){std::string p=s.empty()?"":u8_first(s);std::transform(p.begin(),p.end(),p.begin(),::tolower);return hash_feat(p);};
|
||||||
|
auto fs=[&](const std::string& s){std::string su=u8_len(s)>=3?u8_last(s,3):s;std::transform(su.begin(),su.end(),su.begin(),::tolower);return hash_feat(su);};
|
||||||
|
auto fsh=[&](const std::string& t2){std::string sh;for(unsigned char c:t2){if(c>0x7F)sh+='x';else if(std::isupper(c))sh+='X';else if(std::islower(c))sh+='x';else if(std::isdigit(c))sh+='d';else sh+=c;}return hash_feat(sh);};
|
||||||
|
std::vector<uint64_t> ids;
|
||||||
|
std::string lo=t;std::transform(lo.begin(),lo.end(),lo.begin(),::tolower);
|
||||||
|
ids.push_back(fn(lo));ids.push_back(fp(t));ids.push_back(fs(t));ids.push_back(fsh(t));
|
||||||
|
if(n_embed==6){ids.push_back(1);ids.push_back(0);}else{ids.push_back(0);}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Tok2vec forward pass (shared with NER)
|
||||||
|
// =========================================================================
|
||||||
|
struct Tok2vecModel {
|
||||||
|
std::vector<HashEmbed> embeds;
|
||||||
|
std::vector<float> poW,poB,poG,poB2; bool has_poLN=false;
|
||||||
|
int po_nO=96,po_nP=3,po_nI=576;
|
||||||
|
struct ResBlk{bool has=false;std::vector<float>W,b,lnG,lnb;};
|
||||||
|
ResBlk res[4]; int n_res=0;
|
||||||
|
|
||||||
|
bool load(const std::string& dir) {
|
||||||
|
std::ifstream cf(dir+"/model.ckpt"); if(!cf)return false;
|
||||||
|
std::stringstream cb;cb<<cf.rdbuf();
|
||||||
|
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
|
||||||
|
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
|
||||||
|
size_t bz=bf.tellg();bf.seekg(0); if(bz%4!=0||bz==0)return false;
|
||||||
|
std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
|
||||||
|
auto sl=[&](int64_t o,int64_t c)->std::vector<float>{
|
||||||
|
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
|
||||||
|
};
|
||||||
|
auto ld=[&](const std::string& k, std::vector<float>* v,int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
|
||||||
|
auto* e=ck.get(k);if(!e)return false;
|
||||||
|
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||||
|
if(!sv||!ov||!cv)return false;
|
||||||
|
*v=sl(ov->as_i64(),cv->as_i64());
|
||||||
|
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||||
|
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||||
|
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||||
|
return!v->empty();
|
||||||
|
};
|
||||||
|
for(int ei=0;;ei++){
|
||||||
|
auto* e=ck.get("embed_"+std::to_string(ei)+"_E");if(!e)break;
|
||||||
|
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||||
|
if(!sv||!ov||!cv)break;int rs=sv->arr[0].as_int(),nO=sv->arr[1].as_int();
|
||||||
|
int64_t exp=(int64_t)rs*nO;if(cv->as_i64()<exp)break;
|
||||||
|
auto d=sl(ov->as_i64(),cv->as_i64());if(d.empty())break;
|
||||||
|
embeds.emplace_back();embeds.back().load(rs,nO,d.data());
|
||||||
|
}
|
||||||
|
std::ifstream ff(dir+"/feature_config.json");
|
||||||
|
if(ff){std::stringstream fb;fb<<ff.rdbuf();auto cfg=JParser().parse(fb.str());auto* sa=cfg.get("embed_seeds");
|
||||||
|
if(sa&&sa->type==JVal::ARR)for(int i=0;i<(int)sa->arr.size()&&i<(int)embeds.size();i++)embeds[i].seed=(uint32_t)sa->arr[i].as_int();}
|
||||||
|
int r0=0,r1=0,r2=0;
|
||||||
|
if(!ld("poW",&poW,&r0,&r1,&r2))return false;
|
||||||
|
po_nO=r0;po_nP=r1;po_nI=r2;
|
||||||
|
if(!ld("poB",&poB))return false;
|
||||||
|
if(ld("poG",&poG)){ld("poB2",&poB2);has_poLN=true;}
|
||||||
|
for(int ri=0;ri<4;ri++){auto pk="res"+std::to_string(ri);auto& rb=res[ri];
|
||||||
|
if(ld(pk+"W",&rb.W,&r0,&r1,&r2)){ld(pk+"B",&rb.b);ld(pk+"lnG",&rb.lnG);ld(pk+"lnb",&rb.lnb);rb.has=true;n_res++;}}
|
||||||
|
return!embeds.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run tok2vec → (n_tokens, 96)
|
||||||
|
void forward(const std::vector<std::string>& tokens, float* out) {
|
||||||
|
int n=(int)tokens.size(),D=96,NE=(int)embeds.size(),EC=NE*D;
|
||||||
|
std::vector<float> emb((size_t)n*EC,0);
|
||||||
|
for(int i=0;i<n;i++){
|
||||||
|
auto ids=extract_features(tokens[i],NE);
|
||||||
|
size_t b=(size_t)i*EC;
|
||||||
|
for(int e=0;e<NE;e++)embeds[e].embed(ids[e],emb.data()+b+(size_t)e*D);
|
||||||
|
}
|
||||||
|
std::vector<float> pe((size_t)n*D);
|
||||||
|
for(int i=0;i<n;i++)maxout(pe.data()+(size_t)i*D,emb.data()+(size_t)i*EC,poW.data(),poB.data(),D,po_nP,EC);
|
||||||
|
std::vector<float> pln((size_t)n*D,0);
|
||||||
|
if(has_poLN)for(int i=0;i<n;i++)layernorm(pln.data()+(size_t)i*D,pe.data()+(size_t)i*D,D,poG.data(),poB2.data(),1e-6f);else pln=pe;
|
||||||
|
std::vector<float> enc=pln;
|
||||||
|
for(int ri=0;ri<n_res;ri++){if(!res[ri].has)continue;
|
||||||
|
int wd=D*3;std::vector<float> exp((size_t)n*wd);
|
||||||
|
for(int i=0;i<n;i++)expand_win(exp.data()+(size_t)i*wd,enc.data(),n,D,i);
|
||||||
|
std::vector<float> mx((size_t)n*D);for(int i=0;i<n;i++)maxout(mx.data()+(size_t)i*D,exp.data()+(size_t)i*wd,res[ri].W.data(),res[ri].b.data(),D,3,wd);
|
||||||
|
std::vector<float> ln((size_t)n*D);if(!res[ri].lnG.empty())for(int i=0;i<n;i++)layernorm(ln.data()+(size_t)i*D,mx.data()+(size_t)i*D,D,res[ri].lnG.data(),res[ri].lnb.data(),1e-6f);else ln=mx;
|
||||||
|
for(int i=0;i<n;i++){float* op=enc.data()+(size_t)i*D;for(int j=0;j<D;j++)op[j]+=ln[(size_t)i*D+j];}
|
||||||
|
}
|
||||||
|
memcpy(out,enc.data(),(size_t)n*D*sizeof(float));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Arc-hybrid Parser
|
||||||
|
// =========================================================================
|
||||||
|
struct ParserModel {
|
||||||
|
int nO=64,nP=8,nI=2,n_actions=0;
|
||||||
|
std::vector<float> pW_hid,pb_hid; // 96→64
|
||||||
|
std::vector<float> pW_pre,pb_pre,pad_pre; // preaffine
|
||||||
|
std::vector<float> pW_cls,pb_cls; // classifier
|
||||||
|
std::vector<std::string> move_names;
|
||||||
|
|
||||||
|
bool load(const std::string& dir) {
|
||||||
|
std::ifstream cf(dir+"/model.ckpt"); if(!cf)return false;
|
||||||
|
std::stringstream cb;cb<<cf.rdbuf();
|
||||||
|
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
|
||||||
|
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
|
||||||
|
size_t bz=bf.tellg();bf.seekg(0); if(bz%4!=0||bz==0)return false;
|
||||||
|
std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
|
||||||
|
auto sl=[&](int64_t o,int64_t c)->std::vector<float>{
|
||||||
|
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
|
||||||
|
};
|
||||||
|
auto ld=[&](const std::string& k, std::vector<float>* v,int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
|
||||||
|
auto* e=ck.get(k);if(!e)return false;
|
||||||
|
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||||
|
if(!sv||!ov||!cv)return false;
|
||||||
|
*v=sl(ov->as_i64(),cv->as_i64());
|
||||||
|
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||||
|
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||||
|
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||||
|
return!v->empty();
|
||||||
|
};
|
||||||
|
int r0,r1,r2;
|
||||||
|
if(!ld("pW_hid",&pW_hid,&r0,&r1))return false;
|
||||||
|
nO=r0;ld("pb_hid",&pb_hid);
|
||||||
|
if(!ld("pW_pre",&pW_pre,&r0,&r1,&r2))return false;
|
||||||
|
nP=r0;nO=r1;nI=r2;
|
||||||
|
ld("pb_pre",&pb_pre);ld("pad_pre",&pad_pre);
|
||||||
|
if(!ld("pW_cls",&pW_cls,&r0,&r1))return false;
|
||||||
|
n_actions=r0;ld("pb_cls",&pb_cls);
|
||||||
|
std::ifstream mf(dir+"/meta.json");
|
||||||
|
if(mf){std::stringstream mb;mb<<mf.rdbuf();auto meta=JParser().parse(mb.str());auto* mn=meta.get("move_names");
|
||||||
|
if(mn&&mn->type==JVal::ARR)for(auto& v:mn->arr)move_names.push_back(v.str);}
|
||||||
|
return!pW_hid.empty() && !pW_pre.empty() && !pW_cls.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run parser forward + state machine → (heads, labels)
|
||||||
|
void parse(const float* tokvecs, int n_tokens,
|
||||||
|
std::vector<int>& out_heads, std::vector<std::string>& out_labels) {
|
||||||
|
// 1. Hidden layer: 96→64
|
||||||
|
std::vector<float> hidden((size_t)n_tokens*nO,0);
|
||||||
|
for(int i=0;i<n_tokens;i++) linear(hidden.data()+(size_t)i*nO, tokvecs+(size_t)i*96,
|
||||||
|
pW_hid.data(), pb_hid.data(), nO, 96);
|
||||||
|
|
||||||
|
// 2. Pre-compute features
|
||||||
|
std::vector<float> precomp((size_t)(n_tokens+1)*nP*nO*nI,0);
|
||||||
|
// Pad token (index 0)
|
||||||
|
memcpy(precomp.data(), pad_pre.data(), (size_t)nP*nO*nI*sizeof(float));
|
||||||
|
// Real tokens
|
||||||
|
for(int i=0;i<n_tokens;i++){
|
||||||
|
size_t toff = (size_t)(i+1)*nP*nO*nI;
|
||||||
|
for(int p=0;p<nP;p++){
|
||||||
|
for(int w=0;w<nI;w++){
|
||||||
|
size_t base = toff + (size_t)p*nO*nI + (size_t)w;
|
||||||
|
float* out = precomp.data() + base;
|
||||||
|
// W[p][w][o][d] = [nP][nI][nO][nO]
|
||||||
|
for(int o=0;o<nO;o++){
|
||||||
|
float s = pb_pre[(size_t)w*nO + o];
|
||||||
|
for(int d=0;d<nO;d++) s += pW_pre[((size_t)p*nO*nI + (size_t)o*nI + w)*nO + d] * hidden[(size_t)i*nO + d];
|
||||||
|
out[(size_t)o*nI] = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: get feature value for a token at piece p, window w, output dim o
|
||||||
|
auto feat = [&](int idx, int p, int w, int o) -> float {
|
||||||
|
int ri = (idx < 0 || idx >= n_tokens) ? 0 : idx + 1;
|
||||||
|
return precomp[(size_t)ri*nP*nO*nI + (size_t)p*nO*nI + (size_t)w + (size_t)o*nI];
|
||||||
|
};
|
||||||
|
|
||||||
|
// 3. Arc-hybrid state machine
|
||||||
|
out_heads.assign(n_tokens, -1);
|
||||||
|
out_labels.assign(n_tokens, "");
|
||||||
|
std::vector<int> stack;
|
||||||
|
std::vector<int> buffer(n_tokens);
|
||||||
|
for(int i=0;i<n_tokens;i++) buffer[i]=i;
|
||||||
|
|
||||||
|
// Validate move_names covers all actions before indexing
|
||||||
|
if((int)move_names.size()!=n_actions){out_heads.clear();out_labels.clear();return;}
|
||||||
|
int act_S=-1, act_D=-1;
|
||||||
|
for(int i=0;i<(int)move_names.size();i++){
|
||||||
|
if(move_names[i]=="S") act_S=i;
|
||||||
|
if(move_names[i]=="D") act_D=i;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto leftmost = [&](int idx)->int{
|
||||||
|
for(int i=0;i<n_tokens;i++) if(out_heads[i]==idx) return i;
|
||||||
|
return -1;
|
||||||
|
};
|
||||||
|
auto rightmost = [&](int idx)->int{
|
||||||
|
int r=-1; for(int i=0;i<n_tokens;i++) if(out_heads[i]==idx) r=i;
|
||||||
|
return r;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<float> scores(n_actions,0);
|
||||||
|
std::vector<float> feats(nP*nI*nO,0);
|
||||||
|
|
||||||
|
for(int step=0; step<n_tokens*4 && !(buffer.empty()&&stack.size()<=1); step++){
|
||||||
|
int s0=stack.empty()?-1:stack.back();
|
||||||
|
int s1=stack.size()<2?-1:stack[stack.size()-2];
|
||||||
|
int s2=stack.size()<3?-1:stack[stack.size()-3];
|
||||||
|
int b0=buffer.empty()?-1:buffer[0];
|
||||||
|
int b1=buffer.size()<2?-1:buffer[1];
|
||||||
|
|
||||||
|
// Build feature indices (same as verified Python implementation)
|
||||||
|
int idxs[16]={s0,s1, b0,s0, s0,leftmost(s0), s0,rightmost(s0),
|
||||||
|
s1,leftmost(s1), s1,rightmost(s1), s2,b1, b0,b1};
|
||||||
|
|
||||||
|
// Build feature vector: sum of precomputed features at each (idx, piece, window)
|
||||||
|
for(int o=0;o<nO;o++) feats[o]=0;
|
||||||
|
for(int p=0;p<nP;p++){
|
||||||
|
for(int w=0;w<nI;w++){
|
||||||
|
int ti=idxs[p*2+w];
|
||||||
|
for(int o=0;o<nO;o++){
|
||||||
|
feats[(size_t)p*nI*nO + (size_t)w*nO + o] = feat(ti,p,w,o);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classify
|
||||||
|
for(int a=0;a<n_actions;a++){
|
||||||
|
float s=pb_cls[a];
|
||||||
|
// Use HIDDEN state directly (64-dim) as classifier input
|
||||||
|
int cls_idx = b0 >= 0 ? b0 : (s0 >= 0 ? s0 : 0);
|
||||||
|
for(int j=0;j<nO;j++) s += pW_cls[(size_t)a*nO+j] * hidden[(size_t)cls_idx*nO+j];
|
||||||
|
scores[a]=s;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pick best VALID action
|
||||||
|
int best=-1; float best_sc=-1e30f;
|
||||||
|
for(int a=0;a<n_actions;a++){
|
||||||
|
bool valid=false;
|
||||||
|
const std::string& n=move_names[a];
|
||||||
|
if(n=="S") valid=!buffer.empty() && (int)stack.size()<n_tokens;
|
||||||
|
else if(n=="D") valid=!stack.empty();
|
||||||
|
else if(n.size()>=2 && (n[0]=='L'||n[0]=='R')) valid=stack.size()>=2;
|
||||||
|
if(valid && scores[a]>best_sc){best_sc=scores[a];best=a;}
|
||||||
|
}
|
||||||
|
if(best<0) break;
|
||||||
|
|
||||||
|
const std::string& act=move_names[best];
|
||||||
|
if(act=="S"){stack.push_back(buffer[0]);buffer.erase(buffer.begin());}
|
||||||
|
else if(act=="D"){stack.pop_back();}
|
||||||
|
else if(act.size()>=2){
|
||||||
|
std::string lbl=act.substr(2);
|
||||||
|
if(act[0]=='L'){out_heads[s0]=s1;out_labels[s0]=lbl;stack.erase(stack.end()-2);}
|
||||||
|
else{out_heads[s1]=s0;out_labels[s1]=lbl;stack.pop_back();}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Combined state
|
||||||
|
// =========================================================================
|
||||||
|
struct ParserState {
|
||||||
|
Tok2vecModel tok2vec;
|
||||||
|
ParserModel parser;
|
||||||
|
bool loaded=false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TaggerState {
|
||||||
|
Tok2vecModel tok2vec;
|
||||||
|
std::vector<float> tW,tb; // (n_tags, 96), (n_tags,)
|
||||||
|
std::vector<std::string> tags;
|
||||||
|
bool loaded=false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// C API — Parser
|
||||||
|
// =========================================================================
|
||||||
|
// Creates a parser handle, or returns nullptr when an argument is null or
// either model fails to load.
//
// The tok2vec weights are the PIPELINE tok2vec from <base>/tok2vec/ (not NER's
// internal 4HE), where <base> is `ner_dir` with a trailing "/ner" component
// removed. Trailing slashes are ignored, so "<base>/ner" and "<base>/ner/"
// resolve to the same directory.
ThincParserHandle ThincParser_Create(const char* ner_dir, const char* parser_dir) {
    if (!ner_dir || !parser_dir) return nullptr;

    ParserState* s = nullptr;
    try {
        std::string base(ner_dir);
        while (base.size() > 1 && base.back() == '/') base.pop_back();
        if (base.size() >= 4 && base.compare(base.size() - 4, 4, "/ner") == 0)
            base.resize(base.size() - 4);

        s = new ParserState();
        if (!s->tok2vec.load(base + "/tok2vec") || !s->parser.load(std::string(parser_dir))) {
            delete s;
            return nullptr;
        }
        s->loaded = true;
        return s;
    } catch (...) {
        // Exceptions must not cross the C ABI boundary.
        delete s;
        return nullptr;
    }
}
|
||||||
|
|
||||||
|
// Releases a handle returned by ThincParser_Create; a null handle is a no-op.
void ThincParser_Destroy(ThincParserHandle h) {
    delete static_cast<ParserState*>(h);
}
|
||||||
|
|
||||||
|
// Parses a JSON array of token strings and returns a malloc'ed JSON array of
// {"text","head","dep","index"} objects (free with ThincParser_FreeString).
// Returns "[]" when the handle is missing/unloaded, the input is not a JSON
// array, the array is empty, or the parser could not produce a result.
char* ThincParser_Predict(ThincParserHandle h, const char* tokens_json) {
    auto* s = (ParserState*)h;
    if (!s || !s->loaded || !tokens_json) return strdup("[]");

    // Escape a string for embedding inside a JSON string literal.
    auto esc = [](const std::string& in) -> std::string {
        static const char hex[] = "0123456789abcdef";
        std::string o;
        o.reserve(in.size() + 2);
        for (unsigned char c : in) {
            switch (c) {
                case '"':  o += "\\\""; break;
                case '\\': o += "\\\\"; break;
                case '\n': o += "\\n"; break;
                case '\r': o += "\\r"; break;
                case '\t': o += "\\t"; break;
                default:
                    if (c < 0x20) { o += "\\u00"; o += hex[c >> 4]; o += hex[c & 0xF]; }
                    else o += (char)c;
            }
        }
        return o;
    };

    try {
        auto j = JParser().parse(std::string(tokens_json));
        if (j.type != JVal::ARR) return strdup("[]");
        std::vector<std::string> tokens;
        for (auto& v : j.arr) tokens.push_back(v.str);
        const int n = (int)tokens.size();
        if (!n) return strdup("[]");

        // Run tok2vec
        std::vector<float> tokvecs((size_t)n * 96, 0);
        s->tok2vec.forward(tokens, tokvecs.data());

        // Run parser
        std::vector<int> heads;
        std::vector<std::string> labels;
        s->parser.parse(tokvecs.data(), n, heads, labels);
        // parse() leaves both vectors empty on failure; never index past them.
        if (heads.size() != (size_t)n || labels.size() != (size_t)n) return strdup("[]");

        // Build JSON output
        std::string r = "[";
        for (int i = 0; i < n; i++) {
            if (i) r += ",";
            r += "{\"text\":\"" + esc(tokens[i]) + "\",\"head\":" + std::to_string(heads[i]) +
                 ",\"dep\":\"" + esc(labels[i]) + "\",\"index\":" + std::to_string(i) + "}";
        }
        r += "]";
        return strdup(r.c_str());
    } catch (...) {
        // Exceptions must not cross the C ABI boundary.
        return strdup("[]");
    }
}
|
||||||
|
|
||||||
|
// Frees a string returned by ThincParser_Predict; null is accepted.
void ThincParser_FreeString(char* p) {
    free(p);
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// C API — Tagger
|
||||||
|
// =========================================================================
|
||||||
|
// Creates a tagger handle.
//
// Returns nullptr when an argument is null, the pipeline tok2vec fails to
// load, or the tagger checkpoint index/blob is missing or malformed. When the
// files are readable but the classifier tensors ("tW" and a matching "tb")
// are absent, a handle is still returned with loaded=false and
// ThincTagger_Predict answers "[]".
//
// The tok2vec weights are the PIPELINE tok2vec from <base>/tok2vec/ (6HE, not
// NER's internal 4HE), where <base> is `ner_dir` with a trailing "/ner"
// component removed; trailing slashes are ignored.
ThincTaggerHandle ThincTagger_Create(const char* ner_dir, const char* tagger_dir) {
    if (!ner_dir || !tagger_dir) return nullptr;

    TaggerState* s = nullptr;
    try {
        std::string tbase(ner_dir);
        while (tbase.size() > 1 && tbase.back() == '/') tbase.pop_back();
        if (tbase.size() >= 4 && tbase.compare(tbase.size() - 4, 4, "/ner") == 0)
            tbase.resize(tbase.size() - 4);

        s = new TaggerState();
        if (!s->tok2vec.load(tbase + "/tok2vec")) { delete s; return nullptr; }

        const std::string dir(tagger_dir);
        std::ifstream cf(dir + "/model.ckpt");
        if (!cf) { delete s; return nullptr; }
        std::stringstream cb;
        cb << cf.rdbuf();
        JVal ck = JParser().parse(cb.str());
        if (ck.type != JVal::OBJ) { delete s; return nullptr; }

        std::ifstream bf(dir + "/model.bin", std::ios::binary | std::ios::ate);
        if (!bf) { delete s; return nullptr; }
        const std::streamoff end = bf.tellg();
        // tellg() reports -1 on failure. A size that is not a multiple of 4
        // would make read() write past the float buffer.
        if (end <= 0 || end % 4 != 0) { delete s; return nullptr; }
        bf.seekg(0);
        std::vector<float> bin((size_t)end / 4);
        if (!bf.read((char*)bin.data(), (std::streamsize)end)) { delete s; return nullptr; }

        // Copy `c` floats starting at float offset `o`; empty on any
        // out-of-range request (including negative values).
        auto sl = [&](int64_t o, int64_t c) -> std::vector<float> {
            const int64_t total = (int64_t)bin.size();
            if (o < 0 || c <= 0 || o > total || c > total - o) return {};
            return std::vector<float>(bin.begin() + o, bin.begin() + o + c);
        };
        // Load tensor `k` into *v; false when it is missing or empty.
        auto ld = [&](const std::string& k, std::vector<float>* v) -> bool {
            auto* e = ck.get(k);
            if (!e) return false;
            auto sv = e->get("shape"), ov = e->get("offset"), cv = e->get("count");
            if (!sv || !ov || !cv) return false;
            *v = sl(ov->as_i64(), cv->as_i64());
            return !v->empty();
        };
        ld("tW", &s->tW);
        ld("tb", &s->tb);

        std::ifstream mf(dir + "/meta.json");
        if (mf) {
            std::stringstream mb;
            mb << mf.rdbuf();
            auto meta = JParser().parse(mb.str());
            auto* tg = meta.get("tags");
            if (tg && tg->type == JVal::ARR)
                for (auto& v : tg->arr) s->tags.push_back(v.str);
        }

        // Predict reads tb[t] for every row t of tW (96 floats per row), so the
        // bias must cover all rows before the handle is marked usable.
        s->loaded = !s->tW.empty() && s->tb.size() >= s->tW.size() / 96;
        return s;
    } catch (...) {
        // Exceptions must not cross the C ABI boundary.
        delete s;
        return nullptr;
    }
}
|
||||||
|
|
||||||
|
// Releases a handle returned by ThincTagger_Create; a null handle is a no-op.
void ThincTagger_Destroy(ThincTaggerHandle h) {
    delete static_cast<TaggerState*>(h);
}
|
||||||
|
|
||||||
|
// Tags a JSON array of token strings and returns a malloc'ed JSON array of
// {"text","tag","index"} objects (free with ThincTagger_FreeString).
// Returns "[]" when the handle is missing/unloaded, the input is not a JSON
// array, the array is empty, or the classifier weights are unusable.
char* ThincTagger_Predict(ThincTaggerHandle h, const char* tokens_json) {
    auto* s = (TaggerState*)h;
    if (!s || !s->loaded || !tokens_json || s->tW.empty()) return strdup("[]");

    // Escape a string for embedding inside a JSON string literal.
    auto esc = [](const std::string& in) -> std::string {
        static const char hex[] = "0123456789abcdef";
        std::string o;
        o.reserve(in.size() + 2);
        for (unsigned char c : in) {
            switch (c) {
                case '"':  o += "\\\""; break;
                case '\\': o += "\\\\"; break;
                case '\n': o += "\\n"; break;
                case '\r': o += "\\r"; break;
                case '\t': o += "\\t"; break;
                default:
                    if (c < 0x20) { o += "\\u00"; o += hex[c >> 4]; o += hex[c & 0xF]; }
                    else o += (char)c;
            }
        }
        return o;
    };

    // Strip morphologizer output to just POS
    // (e.g. "Gender=Masc|Number=Sing|POS=NOUN" → "NOUN").
    // For non-morphologizer models the tag string is used as-is.
    auto pos_only = [](const std::string& t) -> std::string {
        const auto p = t.find("POS=");
        if (p == std::string::npos) return t;
        const auto from = p + 4;
        auto to = t.find_first_of("|;", from);
        if (to == std::string::npos) to = t.size();
        return t.substr(from, to - from);
    };

    try {
        auto j = JParser().parse(std::string(tokens_json));
        if (j.type != JVal::ARR) return strdup("[]");
        std::vector<std::string> tokens;
        for (auto& v : j.arr) tokens.push_back(v.str);
        const int n = (int)tokens.size(), n_tags = (int)s->tW.size() / 96;
        if (!n || !n_tags) return strdup("[]");
        // The scoring loop reads tb[t] for every tag row; refuse a short bias.
        if ((int)s->tb.size() < n_tags) return strdup("[]");

        // Run tok2vec to get 96-dim embeddings, then take the argmax tag score.
        std::vector<float> tokvecs((size_t)n * 96, 0);
        s->tok2vec.forward(tokens, tokvecs.data());

        std::vector<int> best_tags(n, 0);
        for (int i = 0; i < n; i++) {
            const float* x = tokvecs.data() + (size_t)i * 96;
            float best_sc = -1e30f;
            for (int t = 0; t < n_tags; t++) {
                float sc = s->tb[t];
                for (int k = 0; k < 96; k++) sc += s->tW[(size_t)t * 96 + k] * x[k];
                if (sc > best_sc) { best_sc = sc; best_tags[i] = t; }
            }
        }

        std::string r = "[";
        for (int i = 0; i < n; i++) {
            if (i) r += ",";
            // Rows without a name in meta.json get an empty tag.
            const std::string tag =
                best_tags[i] < (int)s->tags.size() ? s->tags[best_tags[i]] : std::string();
            r += "{\"text\":\"" + esc(tokens[i]) + "\",\"tag\":\"" + esc(pos_only(tag)) +
                 "\",\"index\":" + std::to_string(i) + "}";
        }
        r += "]";
        return strdup(r.c_str());
    } catch (...) {
        // Exceptions must not cross the C ABI boundary.
        return strdup("[]");
    }
}
|
||||||
|
|
||||||
|
// Frees a string returned by ThincTagger_Predict; null is accepted.
void ThincTagger_FreeString(char* p) {
    free(p);
}
|
||||||
49
internal/cpp/thinc_parser.h
Normal file
49
internal/cpp/thinc_parser.h
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
#ifndef THINC_PARSER_H
#define THINC_PARSER_H

/* System headers are included before the extern "C" block so their
 * declarations are never given C linkage by accident. */
#include <stdbool.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// ---------------------------------------------------------------------------
// C API for spaCy dependency parser and POS tagger
// ---------------------------------------------------------------------------

typedef void* ThincParserHandle;
typedef void* ThincTaggerHandle;

// Parser: create/destroy
// model_ner_dir: path to the NER model directory. If it ends in "/ner" that
//                component is removed and the tok2vec weights are loaded from
//                the sibling "tok2vec" directory.
// model_parser_dir: path to parser model directory.
// Returns NULL if an argument is NULL or a model fails to load.
ThincParserHandle ThincParser_Create(const char* model_ner_dir, const char* model_parser_dir);
// Accepts NULL.
void ThincParser_Destroy(ThincParserHandle handle);

// Run dependency parser on pre-tokenized text.
// tokens_json: JSON array of token strings, e.g. ["Apple", "was", "founded", ...]
// Returns JSON array of token annotations:
//   [{"text":"Apple","head":2,"dep":"nsubjpass","index":0}, ...]
// where head is the index of the head token (0-based), -1 for root.
// Returns "[]" for a NULL/unloaded handle, NULL input, or input that is not a
// non-empty JSON array.
// Caller must free with ThincParser_FreeString.
char* ThincParser_Predict(ThincParserHandle handle, const char* tokens_json);
void ThincParser_FreeString(char* ptr);

// Tagger: create/destroy
// model_ner_dir: path to the NER model directory; resolved to the pipeline
//                tok2vec directory the same way as for the parser.
// model_tagger_dir: path to tagger model directory
// Returns NULL if an argument is NULL or the model files cannot be read.
ThincTaggerHandle ThincTagger_Create(const char* model_ner_dir, const char* model_tagger_dir);
// Accepts NULL.
void ThincTagger_Destroy(ThincTaggerHandle handle);

// Run POS tagger on pre-tokenized text.
// tokens_json: JSON array of token strings.
// Returns JSON array: [{"text":"Apple","tag":"NNP","index":0}, ...]
// Returns "[]" for a NULL/unloaded handle, NULL input, or input that is not a
// non-empty JSON array.
// Caller must free with ThincTagger_FreeString.
char* ThincTagger_Predict(ThincTaggerHandle handle, const char* tokens_json);
void ThincTagger_FreeString(char* ptr);

#ifdef __cplusplus
}
#endif

#endif // THINC_PARSER_H
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build cgo
|
||||||
|
|
||||||
// Package pdfium renders PDF pages using the system's libpdfium.so
|
// Package pdfium renders PDF pages using the system's libpdfium.so
|
||||||
// (bundled with pypdfium2). It exists solely to replace pdf_oxide's
|
// (bundled with pypdfium2). It exists solely to replace pdf_oxide's
|
||||||
// RenderPageRaw for use cases where image quality matters for downstream
|
// RenderPageRaw for use cases where image quality matters for downstream
|
||||||
|
|||||||
788
internal/ingestion/compilation/extractor/dep_relation.go
Normal file
788
internal/ingestion/compilation/extractor/dep_relation.go
Normal file
@@ -0,0 +1,788 @@
|
|||||||
|
// Go implementation of dependency-based relation extraction.
|
||||||
|
// Direct port of Python DepRelationExtractor — semantica-aligned.
|
||||||
|
// Operates on a dependency tree (heads + labels) independent of parser.
|
||||||
|
|
||||||
|
package extractor
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// verbLemma maps an inflected verb surface form to the lemma used when
// building relation keys. It is a closed lookup table: forms that are not
// listed here are left unchanged by lemma. Keys are matched exactly.
var verbLemma = map[string]string{
	// --- English ---
	"founded": "found", "founding": "found",
	"works": "work", "working": "work",
	"based": "base", "basing": "base",
	"located": "locate", "locating": "locate",
	"situated": "situate", "situating": "situate",
	"acquired": "acquire", "acquiring": "acquire",
	"employed": "employ", "employing": "employ",
	"hired": "hire", "hiring": "hire",
	"born": "bear",
	"joined": "join", "joining": "join",
	"merged": "merge", "merging": "merge",
	"bought": "buy", "buying": "buy",
	"created": "create", "creating": "create",
	"established": "establish", "establishing": "establish",
	"started": "start", "starting": "start",
	"led": "lead", "leading": "lead",
	"managed": "manage", "managing": "manage",
	"headed": "head", "heading": "head",
	"ran": "run", "running": "run",
	"owned": "own", "owning": "own",
	"developed": "develop", "developing": "develop",
	"wrote": "write", "written": "write", "writing": "write",
	"published": "publish", "publishing": "publish",
	"invested": "invest", "investing": "invest",
	"partnered": "partner", "partnering": "partner",
	"collaborated": "collaborate", "collaborating": "collaborate",
	"sets": "set",

	// --- German ---
	"gegründet": "gründen", "gründete": "gründen",
	"arbeitet": "arbeiten", "arbeitete": "arbeiten",
	"befindet": "befinden",
	"liegt": "liegen", "lag": "liegen",
	"geboren": "gebären",
	"erworben": "erwerben", "erwarb": "erwerben",
	"gekauft": "kaufen", "kaufte": "kaufen",
	"übernommen": "übernehmen", "übernahm": "übernehmen",

	// --- French ---
	"fondé": "fonder", "fondée": "fonder",
	"créé": "créer", "créée": "créer",
	"travaille": "travailler",
	"employé": "employer", "employée": "employer",
	"situé": "situer", "située": "situer",
	"né": "naître", "née": "naître",
	"acquis": "acquérir",

	// --- Spanish + Portuguese (shared forms) ---
	"fundado": "fundar", "fundada": "fundar",
	"creado": "crear", "creada": "crear",
	"criado": "criar", "criada": "criar",
	"trabaja": "trabajar", "trabalha": "trabalhar",
	"ubicado": "ubicar", "ubicada": "ubicar",
	"situado": "situar", "situada": "situar",
	"localizado": "localizar", "localizada": "localizar",
	"sediado": "sediar", "sediada": "sediar",
	"nacido": "nacer", "nacida": "nacer",
	"nascido": "nascer", "nascida": "nascer",
}
|
||||||
|
|
||||||
|
func lemma(w string) string {
|
||||||
|
if l, ok := verbLemma[w]; ok {
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// depVerbRelations maps a "<verb lemma>+<preposition>" key (or a bare verb
// lemma for direct-object relations) to a relation type, for several
// languages.
// NOTE(review): intended to mirror the Python _VERB_RELATIONS table; that
// table is not visible here, so keep the two in sync by hand.
var depVerbRelations = map[string]string{
	// --- English ---
	"found+by": "founded_by", "co-found+by": "founded_by",
	"establish+by": "founded_by", "create+by": "founded_by",
	"set+up": "founded_by", "start+by": "founded_by",
	"work+for": "works_for", "employ+by": "works_for",
	"hire+by": "works_for", "join": "works_for",
	"lead+by": "works_for", "manage+by": "works_for",
	"head+by": "works_for", "run+by": "works_for",
	"own+by": "owns", "develop+by": "develops",
	"write+by": "wrote", "publish+by": "published",
	"invest+in": "invests_in", "partner+with": "partners_with",
	"collaborate+with": "collaborates_with",
	"merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of",
	"base+in": "located_in", "locate+in": "located_in",
	"situate+in": "located_in", "headquarter+in": "located_in",
	"bear+in": "born_in", "bear+on": "born_in",
	"acquire+by": "acquired", "buy+by": "acquired",

	// --- German (de) ---
	"gründen+von": "founded_by", "errichten+von": "founded_by",
	"arbeiten+für": "works_for", "beschäftigen+bei": "works_for",
	"anstellen+bei": "works_for",
	"sich+befinden": "located_in", "liegen+in": "located_in",
	"sitzen+in": "located_in", "gebären+in": "born_in",
	"gebären+am": "born_in",
	"erwerben+durch": "acquired", "kaufen+durch": "acquired",
	"übernehmen+durch": "acquired",

	// --- French (fr) ---
	"fonder+par": "founded_by", "créer+par": "founded_by",
	"établir+par": "founded_by",
	"travailler+pour": "works_for", "employer+par": "works_for",
	"embaucher+par": "works_for",
	"situer+à": "located_in", "baser+à": "located_in",
	"implanter+à": "located_in",
	"naître+à": "born_in",
	"acquérir+par": "acquired", "racheter+par": "acquired",

	// --- Spanish (es) ---
	"fundar+por": "founded_by", "crear+por": "founded_by",
	"establecer+por": "founded_by",
	"trabajar+para": "works_for", "emplear+por": "works_for",
	"contratar+por": "works_for",
	"ubicar+en": "located_in", "situar+en": "located_in",
	"tener+sede": "located_in",
	"nacer+en": "born_in",
	"adquirir+por": "acquired", "comprar+por": "acquired",

	// --- Portuguese (pt) ---
	"criar+por": "founded_by",
	"estabelecer+por": "founded_by",
	"trabalhar+para": "works_for", "empregar+por": "works_for",
	"localizar+em": "located_in", "situar+em": "located_in",
	"sediar+em": "located_in",
	"nascer+em": "born_in",

	// --- Chinese (zh) ---
	"创立+由": "founded_by", "创建+由": "founded_by",
	"成立+由": "founded_by", "创办+由": "founded_by",
	"设立+由": "founded_by",
	"任职+于": "works_for", "就职+于": "works_for",
	"工作+在": "works_for", "位于+在": "located_in",
	"坐落+在": "located_in", "总部设+在": "located_in",
	"出生+在": "born_in", "出生+于": "born_in",
	"收购+由": "acquired", "并购+由": "acquired",

	// --- Japanese (ja) ---
	"設立+によって": "founded_by", "創立+によって": "founded_by",
	"勤務+で": "works_for", "在籍+で": "works_for",
	"位置+に": "located_in", "所在+に": "located_in",
	"本社+を": "located_in",
	"出生+に": "born_in",
	"買収+によって": "acquired",
}
|
||||||
|
|
||||||
|
// depCopulaTitles lists, for a job-title noun in a copula sentence
// ("X is [title] of Y"), the typed relations that title produces. Keys are
// lower-case.
var depCopulaTitles = map[string][]string{
	"ceo":        {"ceo_of", "works_for"},
	"cto":        {"works_for"},
	"cfo":        {"works_for"},
	"coo":        {"works_for"},
	"vp":         {"works_for"},
	"director":   {"works_for"},
	"manager":    {"works_for"},
	"engineer":   {"works_for"},
	"employee":   {"works_for"},
	"founder":    {"founded_by"},
	"co-founder": {"founded_by"},
}
|
||||||
|
|
||||||
|
// multiHopRules encodes two-step inference: given A rel1 B and B rel2 C,
// multiHopRules[rel1][rel2] names the relation rel3 inferred for A rel3 C.
var multiHopRules = map[string]map[string]string{
	"ceo_of": {
		"is_subsidiary_of": "works_for",
		"located_in":       "works_for",
	},
	"works_for": {
		"is_subsidiary_of": "works_for",
	},
	"founded_by": {
		"is_subsidiary_of": "founded_by",
	},
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Language-specific dependency role mappings
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Matches Python _LANG_DEP_RULES exactly for all 7 languages.
|
||||||
|
|
||||||
|
// roleSpec describes how one grammatical role is located in a dependency
// tree for a given language.
type roleSpec struct {
	// dep is the main dependency label (e.g. "nsubj", "obl:agent").
	dep string
	// childDep is the child label for compound rules
	// (e.g. "pobj" for "agent"→"pobj"); empty when unused.
	childDep string
	// caseMarker is an optional case-marker text (zh:"由", ja:"によって").
	caseMarker string
}
|
||||||
|
|
||||||
|
var langRolesMap = map[string]map[string]roleSpec{
|
||||||
|
"en": {
|
||||||
|
"pass_subj": {dep: "nsubjpass"},
|
||||||
|
"subj": {dep: "nsubj"},
|
||||||
|
"agent": {dep: "agent", childDep: "pobj"},
|
||||||
|
"dobj": {dep: "dobj"},
|
||||||
|
"prep_obj": {dep: "prep", childDep: "pobj"},
|
||||||
|
},
|
||||||
|
"de": {
|
||||||
|
"subj": {dep: "sb"},
|
||||||
|
"agent": {dep: "sbp", childDep: "nk"},
|
||||||
|
"prep_obj": {dep: "mo", childDep: "nk"},
|
||||||
|
// German ROOT is aux verb, real verb has dep "oc"
|
||||||
|
"root_verb_child": {dep: "oc"},
|
||||||
|
},
|
||||||
|
"fr": {
|
||||||
|
"pass_subj": {dep: "nsubj:pass"},
|
||||||
|
"subj": {dep: "nsubj"},
|
||||||
|
"agent": {dep: "obl:agent"},
|
||||||
|
"dobj": {dep: "obj"},
|
||||||
|
"prep_obj": {dep: "case", childDep: "obl"},
|
||||||
|
},
|
||||||
|
"es": {
|
||||||
|
"subj": {dep: "nsubj"},
|
||||||
|
"agent": {dep: "obj"},
|
||||||
|
"prep_obj": {dep: "case", childDep: "obl"},
|
||||||
|
},
|
||||||
|
"pt": {
|
||||||
|
"pass_subj": {dep: "nsubj:pass"},
|
||||||
|
"subj": {dep: "nsubj"},
|
||||||
|
"agent": {dep: "obl:agent"},
|
||||||
|
"dobj": {dep: "obj"},
|
||||||
|
"prep_obj": {dep: "case", childDep: "obl"},
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"subj": {dep: "nsubj"},
|
||||||
|
"agent": {dep: "nmod:prep", caseMarker: "由"},
|
||||||
|
"dobj": {dep: "dobj"},
|
||||||
|
"prep_obj": {dep: "case", childDep: "nmod"},
|
||||||
|
},
|
||||||
|
"ja": {
|
||||||
|
"subj": {dep: "nsubj"},
|
||||||
|
"agent": {dep: "obl", caseMarker: "によって"},
|
||||||
|
"dobj": {dep: "dobj"},
|
||||||
|
"prep_obj": {dep: "case", childDep: "obl"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// copulaDeps lists, per language code, the dependency labels used when
// reading a copula construction: attrDeps for the attribute, prepDeps for
// the preposition/case marker, objDeps for its object.
// NOTE(review): intended to mirror the label sets of the Python
// _extract_copula; that code is not visible here.
var copulaDeps = map[string]struct {
	attrDeps []string
	prepDeps []string
	objDeps  []string
}{
	"en": {
		attrDeps: []string{"attr"},
		prepDeps: []string{"prep"},
		objDeps:  []string{"pobj"},
	},
	"de": {
		attrDeps: []string{"pred"},
		prepDeps: []string{"mo"},
		objDeps:  []string{"nk"},
	},
	"fr": {
		attrDeps: []string{"attr"},
		prepDeps: []string{"case"},
		objDeps:  []string{"obl"},
	},
	"es": {
		attrDeps: []string{"attr"},
		prepDeps: []string{"case"},
		objDeps:  []string{"obl"},
	},
	"pt": {
		attrDeps: []string{"attr"},
		prepDeps: []string{"case"},
		objDeps:  []string{"obl"},
	},
	"zh": {
		attrDeps: []string{"attr"},
		prepDeps: []string{"case"},
		objDeps:  []string{"nmod"},
	},
	"ja": {
		attrDeps: []string{"attr"},
		prepDeps: []string{"case"},
		objDeps:  []string{"obl"},
	},
}
|
||||||
|
|
||||||
|
// beVerbs is the set of lower-cased "to be" surface forms used for copula
// detection. All languages share one set, so a form that occurs in several
// languages is listed once (French "es" is covered by the Spanish entry).
var beVerbs = func() map[string]bool {
	forms := []string{
		// English
		"is", "are", "was", "were", "be", "been", "being",
		// German
		"ist", "sind", "bin", "bist", "seid",
		"war", "waren", "gewesen", "sein",
		// French
		"est", "suis", "sommes", "êtes", "sont",
		"était", "étant", "être",
		// Spanish + Portuguese
		"es", "é", "son", "são",
		"está", "están", "estão",
		"era", "eran", "eram",
		"ser", "sido", "siendo", "sendo",
	}
	set := make(map[string]bool, len(forms))
	for _, form := range forms {
		set[form] = true
	}
	return set
}()
|
||||||
|
|
||||||
|
// DepToken holds token dependency info from the C++ parser.
type DepToken struct {
	// Text is the token's surface form.
	Text string `json:"text"`
	// Head is the Index of the governing token. DepExtractRelations treats a
	// token whose Head equals its own Index as a ROOT.
	Head int `json:"head"`
	// Dep is the dependency label of this token relative to Head.
	Dep string `json:"dep"`
	// Index identifies the token; the helpers in this file also use it to
	// index the token slice directly (tokens[Index]).
	Index int `json:"index"`
	// POS is omitted from JSON when empty.
	POS string `json:"pos,omitempty"`
}
|
||||||
|
|
||||||
|
// roleResult is the return type for getByRole: one resolved entity per
// matching child of the root, plus the adposition text for prep_obj matches.
type roleResult struct {
	entity Entity // entity resolved from the matched token's subtree
	prep   string // prep lemma for prep_obj role; empty otherwise
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Main entry point
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// DepExtractRelations extracts typed relations from a dependency parse tree.
|
||||||
|
// lang is the language code (en/zh/de/fr/es/pt/ja).
|
||||||
|
// maxDistance is the max character distance for co-occurrence (0 = use default 100).
|
||||||
|
func DepExtractRelations(text string, tokens []DepToken, entities []Entity, lang string, maxDistance int) []Relation {
|
||||||
|
entityMap := buildEntityMapMulti(entities)
|
||||||
|
var relations []Relation
|
||||||
|
|
||||||
|
// Detect German-style "oc" handling
|
||||||
|
_, hasRootVerbChild := langRolesMap[lang]["root_verb_child"]
|
||||||
|
|
||||||
|
for _, tok := range tokens {
|
||||||
|
if tok.Head != tok.Index {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// German: ROOT is aux verb; real verb is an "oc" child
|
||||||
|
if hasRootVerbChild {
|
||||||
|
for _, c := range childrenOf(tok.Index, tokens) {
|
||||||
|
if c.Dep == langRolesMap[lang]["root_verb_child"].dep {
|
||||||
|
if hasNegation(c.Index, tokens) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rels := extractFromRoot(text, c.Index, tokens, entityMap, lang)
|
||||||
|
relations = append(relations, rels...)
|
||||||
|
if isBeVerb(tok) {
|
||||||
|
rels := extractCopula(text, c.Index, tokens, entityMap, lang)
|
||||||
|
relations = append(relations, rels...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Standard languages: ROOT = main verb
|
||||||
|
if hasNegation(tok.Index, tokens) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rels := extractFromRoot(text, tok.Index, tokens, entityMap, lang)
|
||||||
|
relations = append(relations, rels...)
|
||||||
|
if isBeVerb(tok) {
|
||||||
|
rels := extractCopula(text, tok.Index, tokens, entityMap, lang)
|
||||||
|
relations = append(relations, rels...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Co-occurrence (always, matching Python DepRelationExtractor.extract())
|
||||||
|
if maxDistance <= 0 {
|
||||||
|
maxDistance = 100
|
||||||
|
}
|
||||||
|
relations = append(relations, extractCooccurrence(text, entities, maxDistance)...)
|
||||||
|
|
||||||
|
relations = inferMultiHop(relations)
|
||||||
|
relations = dedupRelations(relations)
|
||||||
|
return relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Negation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func hasNegation(idx int, tokens []DepToken) bool {
|
||||||
|
for _, t := range tokens {
|
||||||
|
if t.Head == idx && (t.Dep == "neg" || t.Dep == "advmod:neg") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBeVerb(tok DepToken) bool {
|
||||||
|
return beVerbs[strings.ToLower(tok.Text)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Multi-hop inference
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func inferMultiHop(rels []Relation) []Relation {
|
||||||
|
bySubj := make(map[string][]Relation)
|
||||||
|
for _, r := range rels {
|
||||||
|
if r.Predicate == "related_to" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := strings.ToLower(r.Subject.Text)
|
||||||
|
bySubj[key] = append(bySubj[key], r)
|
||||||
|
}
|
||||||
|
|
||||||
|
var inferred []Relation
|
||||||
|
for _, r := range rels {
|
||||||
|
if r.Predicate == "related_to" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
objKey := strings.ToLower(r.Object.Text)
|
||||||
|
if chain, ok := bySubj[objKey]; ok {
|
||||||
|
for _, r2 := range chain {
|
||||||
|
if hopRules, ok := multiHopRules[r.Predicate]; ok {
|
||||||
|
if inferredPred, ok := hopRules[r2.Predicate]; ok {
|
||||||
|
conf := r.Confidence
|
||||||
|
if r2.Confidence < conf {
|
||||||
|
conf = r2.Confidence
|
||||||
|
}
|
||||||
|
conf *= 0.9
|
||||||
|
r := Relation{
|
||||||
|
Subject: r.Subject,
|
||||||
|
Predicate: inferredPred,
|
||||||
|
Object: r2.Object,
|
||||||
|
Confidence: conf,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"method": "multi_hop",
|
||||||
|
"via": r.Predicate + "→" + r2.Predicate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
inferred = append(inferred, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return append(rels, inferred...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Entity map
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func buildEntityMapMulti(entities []Entity) map[string][]Entity {
|
||||||
|
m := make(map[string][]Entity)
|
||||||
|
for _, e := range entities {
|
||||||
|
key := strings.ToLower(e.Text)
|
||||||
|
m[key] = append(m[key], e)
|
||||||
|
cleaned := strings.TrimRight(e.Text, ".,;:!?")
|
||||||
|
cleaned = strings.TrimSpace(cleaned)
|
||||||
|
if cleaned != e.Text {
|
||||||
|
ckey := strings.ToLower(cleaned)
|
||||||
|
m[ckey] = append(m[ckey], e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func findBestEntity(key string, entityMap map[string][]Entity) *Entity {
|
||||||
|
entries := entityMap[strings.ToLower(strings.TrimSpace(key))]
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &entries[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Language-aware role lookup (replaces old getChildEntity/getAgentPobj/getPrepObjs)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// getByRole returns matching (entity, prep) pairs for a semantic role.
|
||||||
|
// Matches Python DepRelationExtractor._get_by_role.
|
||||||
|
func getByRole(lang string, role string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity) []roleResult {
|
||||||
|
roles, ok := langRolesMap[lang]
|
||||||
|
if !ok {
|
||||||
|
roles = langRolesMap["en"]
|
||||||
|
}
|
||||||
|
spec, ok := roles[role]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []roleResult
|
||||||
|
for _, c := range childrenOf(rootIdx, tokens) {
|
||||||
|
if spec.caseMarker != "" {
|
||||||
|
// zh/ja agent: check for case marker in subtree
|
||||||
|
if c.Dep == spec.dep && hasCaseMarkerInSubtree(c.Index, tokens, spec.caseMarker) {
|
||||||
|
if ent := findEntityInSubtree(c.Index, tokens, entityMap); ent != nil {
|
||||||
|
results = append(results, roleResult{entity: *ent})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if spec.childDep != "" {
|
||||||
|
// Compound rule: parent dep + child dep
|
||||||
|
if c.Dep == spec.dep {
|
||||||
|
prepLemma := strings.ToLower(c.Text)
|
||||||
|
for _, gc := range childrenOf(c.Index, tokens) {
|
||||||
|
if gc.Dep == spec.childDep {
|
||||||
|
if ent := findEntityInSubtree(gc.Index, tokens, entityMap); ent != nil {
|
||||||
|
if role == "prep_obj" {
|
||||||
|
results = append(results, roleResult{entity: *ent, prep: prepLemma})
|
||||||
|
} else {
|
||||||
|
results = append(results, roleResult{entity: *ent})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Simple rule: single dep label
|
||||||
|
if c.Dep == spec.dep {
|
||||||
|
if ent := findEntityInSubtree(c.Index, tokens, entityMap); ent != nil {
|
||||||
|
results = append(results, roleResult{entity: *ent})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasCaseMarkerInSubtree checks if any token in the subtree of idx contains marker text.
|
||||||
|
func hasCaseMarkerInSubtree(idx int, tokens []DepToken, marker string) bool {
|
||||||
|
visited := map[int]bool{}
|
||||||
|
subtree := collectSubtree(idx, tokens, visited)
|
||||||
|
for _, t := range subtree {
|
||||||
|
if t.Text == marker {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectSubtree(idx int, tokens []DepToken, visited map[int]bool) []DepToken {
|
||||||
|
if visited[idx] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
visited[idx] = true
|
||||||
|
result := []DepToken{tokens[idx]}
|
||||||
|
for _, c := range childrenOf(idx, tokens) {
|
||||||
|
result = append(result, collectSubtree(c.Index, tokens, visited)...)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// extractFromRoot — passive/active/preposition patterns (language-aware)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func extractFromRoot(text string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity, lang string) []Relation {
|
||||||
|
var relations []Relation
|
||||||
|
root := tokens[rootIdx]
|
||||||
|
verbLemma := lemma(strings.ToLower(root.Text))
|
||||||
|
|
||||||
|
// Get roles using language-aware mapping
|
||||||
|
first := func(lst []roleResult) *Entity {
|
||||||
|
if len(lst) > 0 {
|
||||||
|
return &lst[0].entity
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nsubj := first(getByRole(lang, "subj", rootIdx, tokens, entityMap))
|
||||||
|
nsubjpass := first(getByRole(lang, "pass_subj", rootIdx, tokens, entityMap))
|
||||||
|
dobj := first(getByRole(lang, "dobj", rootIdx, tokens, entityMap))
|
||||||
|
agentList := getByRole(lang, "agent", rootIdx, tokens, entityMap)
|
||||||
|
var agentEntity *Entity
|
||||||
|
if len(agentList) > 0 {
|
||||||
|
agentEntity = &agentList[0].entity
|
||||||
|
}
|
||||||
|
prepList := getByRole(lang, "prep_obj", rootIdx, tokens, entityMap)
|
||||||
|
hasExplicitAgent := agentEntity != nil
|
||||||
|
|
||||||
|
// Detect passive:
|
||||||
|
// - explicit pass_subj (en, fr, pt)
|
||||||
|
// - subj + agent (zh/ja with agent marker, es-style)
|
||||||
|
isPassiveCandidate := hasExplicitAgent
|
||||||
|
|
||||||
|
effectiveNsubjpass := nsubjpass
|
||||||
|
effectiveNsubj := nsubj
|
||||||
|
if isPassiveCandidate {
|
||||||
|
if nsubjpass != nil {
|
||||||
|
effectiveNsubjpass = nsubjpass
|
||||||
|
} else if nsubj != nil {
|
||||||
|
effectiveNsubjpass = nsubj
|
||||||
|
effectiveNsubj = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passive: X was founded/acquired by Y
|
||||||
|
if effectiveNsubjpass != nil && agentEntity != nil {
|
||||||
|
candidates := []string{"by", "von", "par", "por", "durch", "由", "によって"}
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if relType := lookupVerb(verbLemma, candidate); relType != "" {
|
||||||
|
var subj, obj Entity
|
||||||
|
if relType == "founded_by" || relType == "acquired" {
|
||||||
|
subj, obj = *effectiveNsubjpass, *agentEntity
|
||||||
|
} else {
|
||||||
|
subj, obj = *agentEntity, *effectiveNsubjpass
|
||||||
|
}
|
||||||
|
r := makeRelation(subj, relType, obj, 0.90)
|
||||||
|
r.Metadata = map[string]interface{}{"method": "passive", "verb": verbLemma}
|
||||||
|
relations = append(relations, r)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Active: X VERB Y or X VERB prep Y
|
||||||
|
if effectiveNsubj != nil {
|
||||||
|
if dobj != nil {
|
||||||
|
if relType := lookupVerb(verbLemma, ""); relType != "" {
|
||||||
|
r := makeRelation(*effectiveNsubj, relType, *dobj, 0.85)
|
||||||
|
r.Metadata = map[string]interface{}{"method": "active", "verb": verbLemma}
|
||||||
|
relations = append(relations, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, pe := range prepList {
|
||||||
|
if relType := lookupVerb(verbLemma, pe.prep); relType != "" {
|
||||||
|
r := makeRelation(*effectiveNsubj, relType, pe.entity, 0.85)
|
||||||
|
r.Metadata = map[string]interface{}{"method": "active_prep", "verb": verbLemma, "prep": pe.prep}
|
||||||
|
relations = append(relations, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passive with prep ("is based in")
|
||||||
|
if effectiveNsubjpass != nil && len(prepList) > 0 && agentEntity == nil {
|
||||||
|
for _, pe := range prepList {
|
||||||
|
relType := lookupVerb(verbLemma, pe.prep)
|
||||||
|
if relType == "" {
|
||||||
|
relType = lookupVerb("be+"+verbLemma, pe.prep)
|
||||||
|
}
|
||||||
|
if relType != "" {
|
||||||
|
r := makeRelation(*effectiveNsubjpass, relType, pe.entity, 0.85)
|
||||||
|
r.Metadata = map[string]interface{}{"method": "passive_prep", "verb": verbLemma, "prep": pe.prep}
|
||||||
|
relations = append(relations, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Copula extraction (language-aware)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func extractCopula(text string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity, lang string) []Relation {
|
||||||
|
subjList := getByRole(lang, "subj", rootIdx, tokens, entityMap)
|
||||||
|
if len(subjList) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
subj := subjList[0].entity
|
||||||
|
|
||||||
|
cd, ok := copulaDeps[lang]
|
||||||
|
if !ok {
|
||||||
|
cd = copulaDeps["en"]
|
||||||
|
}
|
||||||
|
|
||||||
|
var titleLemma string
|
||||||
|
var prepObj *Entity
|
||||||
|
|
||||||
|
for _, c := range childrenOf(rootIdx, tokens) {
|
||||||
|
if !containsDep(c.Dep, cd.attrDeps) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, cc := range childrenOf(c.Index, tokens) {
|
||||||
|
if !containsDep(cc.Dep, cd.prepDeps) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, gc := range childrenOf(cc.Index, tokens) {
|
||||||
|
if containsDep(gc.Dep, cd.objDeps) {
|
||||||
|
if ent := findEntityInSubtree(gc.Index, tokens, entityMap); ent != nil {
|
||||||
|
prepObj = ent
|
||||||
|
titleLemma = strings.ToLower(c.Text)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if titleLemma == "" || prepObj == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var relations []Relation
|
||||||
|
for keyword, relTypes := range depCopulaTitles {
|
||||||
|
if strings.Contains(titleLemma, keyword) {
|
||||||
|
for _, rt := range relTypes {
|
||||||
|
r := makeRelation(subj, rt, *prepObj, 0.88)
|
||||||
|
r.Metadata = map[string]interface{}{"method": "copula", "title": titleLemma}
|
||||||
|
relations = append(relations, r)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func makeRelation(subj Entity, pred string, obj Entity, conf float64) Relation {
|
||||||
|
return Relation{
|
||||||
|
Subject: subj,
|
||||||
|
Predicate: pred,
|
||||||
|
Object: obj,
|
||||||
|
Confidence: conf,
|
||||||
|
Metadata: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupVerb(verb, prep string) string {
|
||||||
|
if prep != "" {
|
||||||
|
if v, ok := depVerbRelations[verb+"+"+prep]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return depVerbRelations[verb]
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsDep reports whether dep is one of the labels in deps.
func containsDep(dep string, deps []string) bool {
	for i := 0; i < len(deps); i++ {
		if deps[i] == dep {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
func childrenOf(idx int, tokens []DepToken) []DepToken {
|
||||||
|
var kids []DepToken
|
||||||
|
for _, t := range tokens {
|
||||||
|
if t.Head == idx && t.Index != idx {
|
||||||
|
kids = append(kids, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kids
|
||||||
|
}
|
||||||
|
|
||||||
|
func findEntityInSubtree(idx int, tokens []DepToken, emap map[string][]Entity) *Entity {
|
||||||
|
words := collectWords(idx, tokens, map[int]bool{})
|
||||||
|
if len(words) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
text := strings.Join(words, " ")
|
||||||
|
key := strings.ToLower(strings.TrimSpace(text))
|
||||||
|
if ent := findBestEntity(key, emap); ent != nil {
|
||||||
|
return ent
|
||||||
|
}
|
||||||
|
// For CJK (no spaces), also try joining without spaces
|
||||||
|
noSpace := strings.ToLower(strings.TrimSpace(strings.Join(words, "")))
|
||||||
|
if noSpace != key {
|
||||||
|
if ent := findBestEntity(noSpace, emap); ent != nil {
|
||||||
|
return ent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, sep := range []string{" and ", " or ", ", "} {
|
||||||
|
if strings.Contains(key, sep) {
|
||||||
|
candidate := strings.TrimSpace(strings.SplitN(key, sep, 2)[0])
|
||||||
|
if ent := findBestEntity(candidate, emap); ent != nil {
|
||||||
|
return ent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fuzzy: try substring match
|
||||||
|
for ek, ev := range emap {
|
||||||
|
if strings.Contains(ek, key) || strings.Contains(key, ek) {
|
||||||
|
e := ev[0]
|
||||||
|
return &e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectWords(idx int, tokens []DepToken, visited map[int]bool) []string {
|
||||||
|
if visited[idx] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
visited[idx] = true
|
||||||
|
tok := tokens[idx]
|
||||||
|
if tok.Dep == "prep" || tok.Dep == "punct" || tok.Dep == "det" ||
|
||||||
|
tok.Dep == "aux" || tok.Dep == "auxpass" || tok.Dep == "cc" ||
|
||||||
|
tok.Dep == "conj" || tok.Dep == "neg" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var childWords []string
|
||||||
|
kids := childrenOf(idx, tokens)
|
||||||
|
sortByIndex(kids)
|
||||||
|
for _, c := range kids {
|
||||||
|
childWords = append(childWords, collectWords(c.Index, tokens, visited)...)
|
||||||
|
}
|
||||||
|
hasChildBefore := false
|
||||||
|
for _, k := range kids {
|
||||||
|
if k.Index < idx {
|
||||||
|
hasChildBefore = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasChildBefore {
|
||||||
|
return append(childWords, tok.Text)
|
||||||
|
}
|
||||||
|
return append([]string{tok.Text}, childWords...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortByIndex(ts []DepToken) {
|
||||||
|
for i := 0; i < len(ts); i++ {
|
||||||
|
for j := i + 1; j < len(ts); j++ {
|
||||||
|
if ts[j].Index < ts[i].Index {
|
||||||
|
ts[i], ts[j] = ts[j], ts[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupRelations(rels []Relation) []Relation {
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
var result []Relation
|
||||||
|
for _, r := range rels {
|
||||||
|
key := strings.ToLower(r.Subject.Text + "|" + r.Predicate + "|" + r.Object.Text)
|
||||||
|
rev := strings.ToLower(r.Object.Text + "|" + r.Predicate + "|" + r.Subject.Text)
|
||||||
|
if seen[key] || seen[rev] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = true
|
||||||
|
result = append(result, r)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
414
internal/ingestion/compilation/extractor/ner.go
Normal file
414
internal/ingestion/compilation/extractor/ner.go
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Package extractor provides NER and relation extraction for the ingestion
|
||||||
|
// pipeline. It wraps the C++ ThincNER engine via cgo and supplements it with
|
||||||
|
// pure-Go regex-based relation extraction.
|
||||||
|
//
|
||||||
|
// The architecture mirrors the Python rag/graphrag/ner package so that both
|
||||||
|
// code paths produce identical output (verified by test).
|
||||||
|
package extractor
|
||||||
|
|
||||||
|
// #cgo CXXFLAGS: -std=c++20 -I${SRCDIR}/../../..
|
||||||
|
// #cgo linux LDFLAGS: ${SRCDIR}/../../../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread -lpcre2-8
|
||||||
|
// #cgo darwin LDFLAGS: ${SRCDIR}/../../../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread -lpcre2-8
|
||||||
|
//
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include "../../../cpp/rag_analyzer_c_api.h"
|
||||||
|
// #include "../../../cpp/thinc_parser.h"
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Entity represents an extracted named entity.
type Entity struct {
	// Text is the entity's surface text; relation extraction matches entities
	// on its lower-cased form.
	Text string `json:"text"`
	// Label is the NER label reported by the model (e.g. "PERSON", "ORG").
	Label string `json:"label"`
	// StartChar and EndChar locate the entity in the source text.
	// NOTE(review): presumably character offsets with EndChar exclusive; confirm.
	StartChar int `json:"start_char"`
	EndChar   int `json:"end_char"`
	// Confidence is the model's score; ExtractEntities drops entities scoring
	// below Extractor.ConfidenceThreshold.
	Confidence float64 `json:"confidence"`
	// AppType is the application-level type derived from Label.
	AppType string `json:"app_type,omitempty"`
	// Metadata carries optional extra attributes.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}
|
||||||
|
|
||||||
|
// Relation represents a typed relation between two entities.
type Relation struct {
	// Subject and Object are the two related entities.
	Subject Entity `json:"subject"`
	// Predicate names the relation type. "related_to" is treated as untyped:
	// it is excluded from multi-hop inference and from the n_relations count.
	Predicate string `json:"predicate"`
	Object    Entity `json:"object"`
	// Confidence is the score assigned by the rule that produced the relation.
	Confidence float64 `json:"confidence"`
	// Context is optional and omitted from JSON when empty.
	Context string `json:"context,omitempty"`
	// Metadata records how the relation was found (e.g. "method", "verb").
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}
|
||||||
|
|
||||||
|
// ExtractionResult holds the output of a full extraction pass.
type ExtractionResult struct {
	// Entities are the entities returned by ExtractEntities.
	Entities []Entity `json:"entities"`
	// Relations is filled only when relation extraction was requested and at
	// least two entities were found.
	Relations []Relation `json:"relations"`
	// Language is the extractor's language code.
	Language string `json:"language,omitempty"`
	// Metadata holds "n_entities" and "model", plus "n_tokens"/"tokens" when
	// token info was collected and "n_relations" when relations were extracted.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}
|
||||||
|
|
||||||
|
// Extractor provides NER + relation extraction.
type Extractor struct {
	// mu is held around the C.ThincNER_Predict call made by the predictor
	// closure that getPredictor builds.
	mu sync.Mutex
	// Language code (en/zh/de/fr/es/pt/ja)
	Lang string
	// Minimum confidence to include an entity (default 0.0 = all)
	ConfidenceThreshold float64
	// Include token-level info in ExtractionResult metadata. Extract records
	// each token's text and index.
	IncludeTokens bool
	// Max character distance for co-occurrence relations (default 100)
	MaxDistance int
}
|
||||||
|
|
||||||
|
// spaCy NER label → application entity type mapping.
// A label missing from this table falls back to its lower-cased form in
// ExtractEntities.
var spacyToAppType = map[string]string{
	"PERSON":      "person",
	"ORG":         "organization",
	"GPE":         "geo",
	"LOC":         "geo",
	"FAC":         "geo",
	"EVENT":       "event",
	"PRODUCT":     "category",
	"DATE":        "event",
	"TIME":        "event",
	"MONEY":       "category",
	"QUANTITY":    "category",
	"PERCENT":     "category",
	"LAW":         "category",
	"NORP":        "category",
	"LANGUAGE":    "category",
	"WORK_OF_ART": "category",
}
|
||||||
|
|
||||||
|
// skipLabels lists NER labels whose entities ExtractEntities discards.
var skipLabels = map[string]bool{
	"ORDINAL":  true,
	"CARDINAL": true,
}
|
||||||
|
|
||||||
|
// ModelPredictor is a cached predict function for a model path.
// It takes the tokens as a JSON string and returns the NER result as a JSON
// string, or an error when prediction fails.
// Closure captures the C handle to avoid unsafe.Pointer type issues.
type ModelPredictor func(tokensJSON string) (string, error)
|
||||||
|
|
||||||
|
// modelCache holds one ModelPredictor per model directory, shared by all
// Extractor values in the process. modelCacheMu guards it; getPredictor takes
// the lock for every lookup and insertion.
var (
	modelCacheMu sync.Mutex
	modelCache   = map[string]ModelPredictor{}
)
|
||||||
|
|
||||||
|
// langModel maps language codes to spaCy model names.
// Its key set also defines the supported languages: NewExtractor falls back
// to "en" for any code that is not a key here.
var langModel = map[string]string{
	"en": "en_core_web_sm",
	"zh": "zh_core_web_sm",
	"de": "de_core_news_sm",
	"fr": "fr_core_news_sm",
	"es": "es_core_news_sm",
	"pt": "pt_core_news_sm",
	"ja": "ja_core_news_sm",
}
|
||||||
|
|
||||||
|
// langFallback maps languages without dedicated relation patterns to a fallback.
// Only the regex fallback in extractRelations consults it; dependency-based
// extraction is called with the extractor's own language.
var langFallback = map[string]string{
	"de": "en",
	"fr": "en",
	"es": "en",
	"pt": "en",
	"ja": "zh",
}
|
||||||
|
|
||||||
|
// NewExtractor creates a new extractor.
|
||||||
|
// Supported langs: en, zh, de, fr, es, pt, ja.
|
||||||
|
func NewExtractor(lang string) *Extractor {
|
||||||
|
if lang == "" {
|
||||||
|
lang = "en"
|
||||||
|
}
|
||||||
|
if _, ok := langModel[lang]; !ok {
|
||||||
|
lang = "en"
|
||||||
|
}
|
||||||
|
return &Extractor{
|
||||||
|
Lang: lang,
|
||||||
|
ConfidenceThreshold: 0.0, // include all by default
|
||||||
|
MaxDistance: 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract runs NER and optionally relation extraction (dep-based via C++ parser, or regex fallback).
|
||||||
|
func (e *Extractor) Extract(text string, extractRelations bool) (*ExtractionResult, error) {
|
||||||
|
entities, err := e.ExtractEntities(text)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect token info if requested (before entity dedup changes offsets)
|
||||||
|
var tokensMeta []map[string]interface{}
|
||||||
|
if e.IncludeTokens {
|
||||||
|
tokensJSON := tokenizeText(text, e.Lang)
|
||||||
|
if tokensJSON != "" {
|
||||||
|
var rawTokens []string
|
||||||
|
if err := json.Unmarshal([]byte(tokensJSON), &rawTokens); err == nil {
|
||||||
|
for i, t := range rawTokens {
|
||||||
|
tokensMeta = append(tokensMeta, map[string]interface{}{
|
||||||
|
"text": t,
|
||||||
|
"index": i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &ExtractionResult{
|
||||||
|
Entities: entities,
|
||||||
|
Language: e.Lang,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"n_entities": len(entities),
|
||||||
|
"model": langModel[e.Lang],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if len(tokensMeta) > 0 {
|
||||||
|
result.Metadata["n_tokens"] = len(tokensMeta)
|
||||||
|
result.Metadata["tokens"] = tokensMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
if extractRelations && len(entities) >= 2 {
|
||||||
|
relations := e.extractRelations(text, entities)
|
||||||
|
result.Relations = relations
|
||||||
|
nTyped := 0
|
||||||
|
for _, r := range relations {
|
||||||
|
if r.Predicate != "related_to" {
|
||||||
|
nTyped++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.Metadata["n_relations"] = nTyped
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRelations attempts dep-based extraction via C++ parser; falls back to regex.
|
||||||
|
func (e *Extractor) extractRelations(text string, entities []Entity) []Relation {
|
||||||
|
relLang := e.Lang
|
||||||
|
if fb, ok := langFallback[e.Lang]; ok {
|
||||||
|
relLang = fb
|
||||||
|
}
|
||||||
|
// Try dep-based extraction via C++ parser — uses e.Lang (not relLang) so
|
||||||
|
// de/fr/es/pt/ja apply their language-specific DepExtractRelations rules.
|
||||||
|
tokensJSON := tokenizeText(text, e.Lang)
|
||||||
|
if tokensJSON == "" {
|
||||||
|
return extractRelationsWithOpts(text, entities, relLang, e.MaxDistance)
|
||||||
|
}
|
||||||
|
var tokens []string
|
||||||
|
if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil || len(tokens) == 0 {
|
||||||
|
return extractRelationsWithOpts(text, entities, relLang, e.MaxDistance)
|
||||||
|
}
|
||||||
|
modelDir := e.findModelDir()
|
||||||
|
nerDir := modelDir + "/ner"
|
||||||
|
parserDir := modelDir + "/parser"
|
||||||
|
if deps, err := ParseTokensWithParser(nerDir, parserDir, tokens); err == nil && len(deps) > 0 {
|
||||||
|
depTokens := make([]DepToken, len(deps))
|
||||||
|
for i, d := range deps {
|
||||||
|
depTokens[i] = DepToken{Text: d.Text, Head: d.Head, Dep: d.Dep, Index: d.Index}
|
||||||
|
}
|
||||||
|
if rels := DepExtractRelations(text, depTokens, entities, e.Lang, e.MaxDistance); len(rels) > 0 {
|
||||||
|
return rels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: regex-based extraction
|
||||||
|
return extractRelationsWithOpts(text, entities, relLang, e.MaxDistance)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPredictor returns the ModelPredictor for modelDir, creating the C
// ThincNER handle from modelDir+"/ner" and modelDir+"/vocab" on first use and
// caching the predictor in the process-wide modelCache.
//
// When the handle cannot be created, the returned predictor always fails and
// is not cached, so the next call tries to create the handle again.
//
// The cached closure captures the receiver e that created it, so every call
// made through the cache, from any Extractor, locks that one Extractor's mu
// around the C predict call.
func (e *Extractor) getPredictor(modelDir string) ModelPredictor {
	modelCacheMu.Lock()
	defer modelCacheMu.Unlock()
	if p, ok := modelCache[modelDir]; ok {
		return p
	}
	// The C strings are needed only for the duration of ThincNER_Create.
	cModelDir := C.CString(modelDir + "/ner")
	cModelVocab := C.CString(modelDir + "/vocab")
	handle := C.ThincNER_Create(cModelDir, cModelVocab)
	C.free(unsafe.Pointer(cModelDir))
	C.free(unsafe.Pointer(cModelVocab))
	// Don't cache a nil handle — return a one-shot error predictor instead.
	if handle == nil {
		fn := func(tokensJSON string) (string, error) {
			return "", fmt.Errorf("ThincNER handle is nil for model dir: %s", modelDir)
		}
		return fn
	}
	p := func(tokensJSON string) (string, error) {
		// Calls into the shared handle are serialized with e.mu.
		e.mu.Lock()
		cTokensJSON := C.CString(tokensJSON)
		cResult := C.ThincNER_Predict(handle, cTokensJSON)
		e.mu.Unlock()
		C.free(unsafe.Pointer(cTokensJSON))
		if cResult == nil {
			return "", fmt.Errorf("NER prediction failed")
		}
		// The result string is copied into Go memory by C.GoString below and
		// then released through the library's own free function.
		defer C.ThincNER_FreeString(cResult)
		return C.GoString(cResult), nil
	}
	modelCache[modelDir] = p
	return p
}
|
||||||
|
|
||||||
|
// ExtractEntities extracts named entities from text using C++ ThincNER.
|
||||||
|
func (e *Extractor) ExtractEntities(text string) ([]Entity, error) {
|
||||||
|
tokensJSON := tokenizeText(text, e.Lang)
|
||||||
|
if tokensJSON == "" {
|
||||||
|
return nil, fmt.Errorf("tokenization failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
modelDir := e.findModelDir()
|
||||||
|
predict := e.getPredictor(modelDir)
|
||||||
|
|
||||||
|
resultJSON, err := predict(tokensJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawEntities []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Start int `json:"start"`
|
||||||
|
End int `json:"end"`
|
||||||
|
Confidence float64 `json:"confidence"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(resultJSON), &rawEntities); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse NER result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dedup by (text.lower(), start_char) — matching Python NERExtractor
|
||||||
|
// For CJK, strip spaces from entity text (BILUO decoder joins tokens with spaces)
|
||||||
|
isCJK := e.Lang == "zh" || e.Lang == "ja"
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
entities := make([]Entity, 0, len(rawEntities))
|
||||||
|
for _, re := range rawEntities {
|
||||||
|
if skipLabels[re.Label] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if re.Confidence < e.ConfidenceThreshold {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text := re.Text
|
||||||
|
if isCJK {
|
||||||
|
text = strings.ReplaceAll(text, " ", "")
|
||||||
|
}
|
||||||
|
key := strings.ToLower(text) + "|" + strconv.Itoa(re.Start)
|
||||||
|
if seen[key] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = true
|
||||||
|
appType := spacyToAppType[re.Label]
|
||||||
|
if appType == "" {
|
||||||
|
appType = strings.ToLower(re.Label)
|
||||||
|
}
|
||||||
|
entities = append(entities, Entity{
|
||||||
|
Text: text,
|
||||||
|
Label: re.Label,
|
||||||
|
StartChar: re.Start,
|
||||||
|
EndChar: re.End,
|
||||||
|
Confidence: re.Confidence,
|
||||||
|
AppType: appType,
|
||||||
|
Metadata: map[string]interface{}{"source": "thincner"},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return entities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findModelDir locates the spaCy model directory under /usr/share/infinity/resource/spacy.
|
||||||
|
func (e *Extractor) findModelDir() string {
|
||||||
|
modelName := langModel[e.Lang]
|
||||||
|
if modelName == "" {
|
||||||
|
modelName = "en_core_web_sm"
|
||||||
|
}
|
||||||
|
base := "/usr/share/infinity/resource/spacy/" + modelName
|
||||||
|
if dirExists(base) {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
if p := getenv("SPACY_MODEL_DIR"); p != "" {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
func dirExists(path string) bool {
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
return err == nil && info.IsDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenizeText tokenizes text via C++ tokenizer (all languages).
|
||||||
|
// Returns JSON array of token strings.
|
||||||
|
func tokenizeText(text, lang string) string {
|
||||||
|
cText := C.CString(text)
|
||||||
|
cLang := C.CString(lang)
|
||||||
|
defer C.free(unsafe.Pointer(cText))
|
||||||
|
defer C.free(unsafe.Pointer(cLang))
|
||||||
|
|
||||||
|
cTokens := C.ThincNER_Tokenize(cText, cLang)
|
||||||
|
if cTokens == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer C.ThincNER_FreeString(cTokens)
|
||||||
|
return C.GoString(cTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getenv(key string) string {
|
||||||
|
return os.Getenv(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectLanguage guesses the language of text from the Unicode ranges of its
// runes. Pure Go, zero dependencies.
//
// Only Han, Hiragana, Katakana and ASCII Latin letters are counted; every
// other rune is ignored. The result is:
//   - "en" when no counted rune is present, or when CJK runes make up at
//     most 30% of the counted runes;
//   - "ja" when the CJK share exceeds 30% and kana outnumber Han characters;
//   - "zh" when the CJK share exceeds 30% otherwise.
//
// Other Latin-script languages (de/fr/es/pt) are never detected; callers
// specify those explicitly.
func DetectLanguage(text string) string {
	var han, kana, latin int
	for _, r := range text {
		switch {
		case isHan(r):
			han++
		case isHiragana(r), isKatakana(r):
			kana++
		case isLatin(r):
			latin++
		}
	}

	cjk := han + kana
	total := cjk + latin
	if total == 0 {
		return "en"
	}

	if float64(cjk)/float64(total) > 0.3 {
		if kana > han {
			return "ja" // Japanese-heavy
		}
		// cjk > 0 and kana <= han imply han > 0, so this is Han-heavy text.
		return "zh"
	}

	// Latin majority — default to en.
	return "en"
}

// isHan reports whether r lies in U+4E00–U+9FFF.
func isHan(r rune) bool { return r >= 0x4E00 && r <= 0x9FFF }

// isHiragana reports whether r lies in U+3040–U+309F.
func isHiragana(r rune) bool { return r >= 0x3040 && r <= 0x309F }

// isKatakana reports whether r lies in U+30A0–U+30FF.
func isKatakana(r rune) bool { return r >= 0x30A0 && r <= 0x30FF }

// isLatin reports whether r is an ASCII letter (A–Z or a–z).
func isLatin(r rune) bool { return (r >= 0x0041 && r <= 0x005A) || (r >= 0x0061 && r <= 0x007A) }
|
||||||
30
internal/ingestion/compilation/extractor/ner_extractor.go
Normal file
30
internal/ingestion/compilation/extractor/ner_extractor.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Package extractor provides NER and relation extraction for the ingestion
|
||||||
|
// pipeline. It mirrors the Python rag/graphrag/ner package so that both
|
||||||
|
// code paths produce identical output.
|
||||||
|
//
|
||||||
|
// The C++ ThincNER engine (internal/cpp/) loads spaCy model.ckpt+model.bin
|
||||||
|
// directly for NER inference. Relation extraction uses the C++ dependency
// parser when it yields results and otherwise falls back to pure Go regex
// patterns.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// ext := extractor.NewExtractor("en")
|
||||||
|
// result, err := ext.Extract("Apple Inc. was founded by Steve Jobs.", true)
|
||||||
|
// for _, e := range result.Entities { ... }
|
||||||
|
// for _, r := range result.Relations { ... }
|
||||||
|
package extractor
|
||||||
439
internal/ingestion/compilation/extractor/ner_relation.go
Normal file
439
internal/ingestion/compilation/extractor/ner_relation.go
Normal file
@@ -0,0 +1,439 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
package extractor
|
||||||
|
|
||||||
|
import (
	"regexp"
	"strings"
	"unicode/utf8"
)
|
||||||
|
|
||||||
|
// relPatternEntry pairs a compiled relation regex with the predicate it
// produces. Capture group 1 is the subject and capture group 2 the object.
type relPatternEntry struct {
	pattern   *regexp.Regexp
	predicate string
}

// Building blocks for the English entity groups.
//
// _entWord matches one word that starts with an ASCII letter (upper or lower
// case), may contain internal periods between letter-led parts ("U.S.",
// "J.K.") and may end with a period ("Inc.", "Corp.").
const _entWord = `[A-Za-z][\w']*(?:\.[A-Za-z][\w']*)*\.?`

// _relEntity is the subject group: one or more words, matched lazily so the
// subject stops at the first position where the relation keyword follows.
const _relEntity = `(` + _entWord + `(?:\s+` + _entWord + `)*?)`

// _relEntity2 is the object group: one word, optionally followed by a second.
const _relEntity2 = `(` + _entWord + `(?:\s+` + _entWord + `){0,1})`

// relationPatterns holds the multilingual relation patterns, keyed by
// language code — matching Python MULTILANG_RELATION_PATTERNS.
//
// English relation keywords are matched case-insensitively through inline
// (?i:...) groups, except the role acronyms CEO|CTO|CFO|VP in the works_for
// pattern, which are case-sensitive.
var relationPatterns = map[string][]relPatternEntry{
	"en": {
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:founded)\s+(?i:by)\s+` + _relEntity2), predicate: "founded_by"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:an?\s+)?(?i:co-)?(?i:founder)\s+(?i:of)\s+` + _relEntity2), predicate: "founded_by"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:works)\s+(?i:for)\s+` + _relEntity2), predicate: "works_for"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:an?\s+)?(?i:employee)\s+(?i:of)\s+` + _relEntity2), predicate: "works_for"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:joined)\s+` + _relEntity2), predicate: "works_for"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:the\s+)?(?:CEO|CTO|CFO|VP|(?i:director|manager|engineer))\s+(?i:of|at)\s+` + _relEntity2), predicate: "works_for"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:located|based|headquartered|situated)\s+(?i:in)\s+` + _relEntity2), predicate: "located_in"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:born)\s+(?i:in|on)\s+` + _relEntity2), predicate: "born_in"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:born)\s+(?i:in|on)\s+` + _relEntity2), predicate: "born_in"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:acquired)\s+(?i:by)\s+` + _relEntity2), predicate: "acquired"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:acquired)\s+` + _relEntity2), predicate: "acquired"},
		{pattern: regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:the\s+)?(?i:CEO)\s+(?i:of)\s+` + _relEntity2), predicate: "ceo_of"},
	},
	"zh": {
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,6})\s*由\s*([\p{Han}\w]{2,4})\s*(?:创立|创建|成立|创办)`), predicate: "founded_by"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:创立|创建|成立|创办)(?:\s*了\s*)?([\p{Han}\w]{2,10})`), predicate: "founded_by"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:是\s*)?([\p{Han}\w]{2,10})\s*(?:创始人|联合创始人)`), predicate: "founded_by"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:任职于|供职于|工作于|就职于)\s*([\p{Han}\w]{2,10})`), predicate: "works_for"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:是\s*)?([\p{Han}\w]{2,10})\s*(?:的员工|的雇员)`), predicate: "works_for"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,10})\s*(?:位于|坐落于|总部设在|总部位于)\s*([\p{Han}\w]{2,6})`), predicate: "located_in"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,10})\s*在\s*([\p{Han}\w]{2,6})`), predicate: "located_in"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:出生于|生于)\s*([\p{Han}\w]{2,6})`), predicate: "born_in"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,10})\s*(?:收购|并购)\s*([\p{Han}\w]{2,10})`), predicate: "acquired"},
		{pattern: regexp.MustCompile(`([\p{Han}\w]{2,10})\s*被\s*([\p{Han}\w]{2,10})\s*(?:收购|并购)`), predicate: "acquired"},
	},
}
|
||||||
|
|
||||||
|
// ExtractRelations extracts typed relations between entities.
|
||||||
|
// Matches the Python RelationExtractor pattern-based approach,
|
||||||
|
// including cross-sentence filtering via sentence boundary checks.
|
||||||
|
func ExtractRelations(text string, entities []Entity, lang string) []Relation {
|
||||||
|
return extractRelationsWithOpts(text, entities, lang, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRelationsWithOpts is the internal version with configurable max distance.
|
||||||
|
func extractRelationsWithOpts(text string, entities []Entity, lang string, maxDistance int) []Relation {
|
||||||
|
patterns, ok := relationPatterns[lang]
|
||||||
|
if !ok {
|
||||||
|
patterns = relationPatterns["en"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build multimap: entity text → all occurrences (handles duplicate entity names)
|
||||||
|
entityMultiMap := make(map[string][]Entity, len(entities))
|
||||||
|
for _, e := range entities {
|
||||||
|
key := strings.ToLower(e.Text)
|
||||||
|
entityMultiMap[key] = append(entityMultiMap[key], e)
|
||||||
|
// Also add punctuation-stripped version
|
||||||
|
cleaned := strings.TrimRight(e.Text, ".,;:!?")
|
||||||
|
cleaned = strings.TrimSpace(cleaned)
|
||||||
|
if cleaned != e.Text {
|
||||||
|
ckey := strings.ToLower(cleaned)
|
||||||
|
entityMultiMap[ckey] = append(entityMultiMap[ckey], e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build sentence spans (matching Python's sentence splitting regex)
|
||||||
|
hasOffsets := false
|
||||||
|
for _, e := range entities {
|
||||||
|
if e.StartChar != 0 || e.EndChar != 0 {
|
||||||
|
hasOffsets = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var sentenceSpans [][2]int
|
||||||
|
if hasOffsets {
|
||||||
|
sentenceSpans = splitSentences(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
var relations []Relation
|
||||||
|
|
||||||
|
// Phase 1: Pattern-based typed relations
|
||||||
|
// Process each sentence separately to prevent cross-sentence regex matches.
|
||||||
|
// When entities have no offsets, fall back to full-text matching.
|
||||||
|
if hasOffsets && len(sentenceSpans) > 0 {
|
||||||
|
for _, entry := range patterns {
|
||||||
|
for _, sp := range sentenceSpans {
|
||||||
|
sentText := text[sp[0]:sp[1]]
|
||||||
|
matches := entry.pattern.FindAllStringSubmatchIndex(sentText, -1)
|
||||||
|
for _, m := range matches {
|
||||||
|
if len(m) < 6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
subjStart, subjEnd := m[2], m[3]
|
||||||
|
objStart, objEnd := m[4], m[5]
|
||||||
|
if subjStart < 0 || objStart < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Adjust to absolute positions
|
||||||
|
absSubjStart := subjStart + sp[0]
|
||||||
|
absSubjEnd := subjEnd + sp[0]
|
||||||
|
absObjStart := objStart + sp[0]
|
||||||
|
absObjEnd := objEnd + sp[0]
|
||||||
|
subjText := strings.TrimSpace(text[absSubjStart:absSubjEnd])
|
||||||
|
objText := strings.TrimSpace(text[absObjStart:absObjEnd])
|
||||||
|
subj := findEntityByText(subjText, absSubjStart, entityMultiMap)
|
||||||
|
obj := findEntityByText(objText, absObjStart, entityMultiMap)
|
||||||
|
if subj.Text == "" || obj.Text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := subj.Text + "|" + entry.predicate + "|" + obj.Text
|
||||||
|
if seen[key] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = true
|
||||||
|
|
||||||
|
var ctx string
|
||||||
|
absMatchStart := m[0] + sp[0]
|
||||||
|
absMatchEnd := m[1] + sp[0]
|
||||||
|
ctx = extractContext(text, text[absMatchStart:absMatchEnd])
|
||||||
|
relations = append(relations, Relation{
|
||||||
|
Subject: subj,
|
||||||
|
Predicate: entry.predicate,
|
||||||
|
Object: obj,
|
||||||
|
Confidence: 0.8,
|
||||||
|
Context: ctx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No offsets: process full text
|
||||||
|
for _, entry := range patterns {
|
||||||
|
matches := entry.pattern.FindAllStringSubmatchIndex(text, -1)
|
||||||
|
for _, m := range matches {
|
||||||
|
if len(m) < 6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
subjStart, subjEnd := m[2], m[3]
|
||||||
|
objStart, objEnd := m[4], m[5]
|
||||||
|
if subjStart < 0 || objStart < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
subjText := strings.TrimSpace(text[subjStart:subjEnd])
|
||||||
|
objText := strings.TrimSpace(text[objStart:objEnd])
|
||||||
|
subj := findEntityByText(subjText, subjStart, entityMultiMap)
|
||||||
|
obj := findEntityByText(objText, objStart, entityMultiMap)
|
||||||
|
if subj.Text == "" || obj.Text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := subj.Text + "|" + entry.predicate + "|" + obj.Text
|
||||||
|
if seen[key] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = true
|
||||||
|
|
||||||
|
var ctx string
|
||||||
|
if len(m) >= 2 && m[0] >= 0 {
|
||||||
|
ctx = extractContext(text, text[m[0]:m[1]])
|
||||||
|
}
|
||||||
|
relations = append(relations, Relation{
|
||||||
|
Subject: subj,
|
||||||
|
Predicate: entry.predicate,
|
||||||
|
Object: obj,
|
||||||
|
Confidence: 0.8,
|
||||||
|
Context: ctx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: Co-occurrence (standalone, with sentence boundary check)
|
||||||
|
for _, r := range extractCooccurrence(text, entities, maxDistance) {
|
||||||
|
key := r.Subject.Text + "|related_to|" + r.Object.Text
|
||||||
|
if !seen[key] {
|
||||||
|
seen[key] = true
|
||||||
|
relations = append(relations, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-hop inference + dedup (matching Python always applies these)
|
||||||
|
relations = inferMultiHop(relations)
|
||||||
|
relations = dedupRelations(relations)
|
||||||
|
|
||||||
|
return relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCooccurrence generates related_to relations for entity pairs
|
||||||
|
// within maxDistance characters in the same sentence.
|
||||||
|
func extractCooccurrence(text string, entities []Entity, maxDistance int) []Relation {
|
||||||
|
if len(entities) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hasOffsets := false
|
||||||
|
for _, e := range entities {
|
||||||
|
if e.StartChar != 0 || e.EndChar != 0 {
|
||||||
|
hasOffsets = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var sentenceSpans [][2]int
|
||||||
|
if hasOffsets {
|
||||||
|
sentenceSpans = splitSentences(text)
|
||||||
|
}
|
||||||
|
sameSentence := func(c1, c2 int) bool {
|
||||||
|
if !hasOffsets || len(sentenceSpans) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, sp := range sentenceSpans {
|
||||||
|
if sp[0] <= c1 && c1 < sp[1] && sp[0] <= c2 && c2 < sp[1] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var relations []Relation
|
||||||
|
for i := 0; i < len(entities); i++ {
|
||||||
|
for j := i + 1; j < len(entities); j++ {
|
||||||
|
e1, e2 := entities[i], entities[j]
|
||||||
|
if !sameSentence(e1.StartChar, e2.StartChar) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dist := abs(e2.StartChar - e1.EndChar)
|
||||||
|
if dist > maxDistance {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
relations = append(relations, Relation{
|
||||||
|
Subject: e1,
|
||||||
|
Predicate: "related_to",
|
||||||
|
Object: e2,
|
||||||
|
Confidence: 0.4,
|
||||||
|
Context: extractContextSimple(text, e1, e2),
|
||||||
|
Metadata: map[string]interface{}{"method": "cooccurrence"},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// findEntityByText finds the entity occurrence closest to matchPos.
|
||||||
|
// Uses multimap to handle duplicate entity names at different positions.
|
||||||
|
func findEntityByText(raw string, matchPos int, entityMultiMap map[string][]Entity) Entity {
|
||||||
|
text := strings.TrimSpace(raw)
|
||||||
|
// Strip trailing punctuation
|
||||||
|
for len(text) > 0 && strings.ContainsAny(text[len(text)-1:], ".,;:!?") {
|
||||||
|
text = strings.TrimSpace(text[:len(text)-1])
|
||||||
|
}
|
||||||
|
ent := findClosest(text, matchPos, entityMultiMap)
|
||||||
|
if ent.Text != "" {
|
||||||
|
return ent
|
||||||
|
}
|
||||||
|
// Try stripping trailing " and ..." / " or ..." / ", ..."
|
||||||
|
key := strings.ToLower(text)
|
||||||
|
for _, sep := range []string{" and ", " or ", ", "} {
|
||||||
|
if idx := strings.Index(key, sep); idx > 0 {
|
||||||
|
if e := findClosest(key[:idx], matchPos, entityMultiMap); e.Text != "" {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try progressively shorter word sequences (right-to-left word stripping)
|
||||||
|
// Handles cases like "Google in" → try "Google" or "Microsoft. Microsoft" → try "microsoft" (stripped)
|
||||||
|
words := strings.Fields(key)
|
||||||
|
for i := len(words) - 1; i > 0; i-- {
|
||||||
|
candidate := strings.Join(words[:i], " ")
|
||||||
|
// Strip trailing punctuation from candidate before lookup
|
||||||
|
candidate = strings.TrimRight(candidate, ".,;:!?")
|
||||||
|
candidate = strings.TrimSpace(candidate)
|
||||||
|
if e := findClosest(candidate, matchPos, entityMultiMap); e.Text != "" {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Entity{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findClosest returns the entity occurrence closest to matchPos from the multimap.
|
||||||
|
// Also tries stripping trailing punctuation from name if exact match fails.
|
||||||
|
func findClosest(name string, matchPos int, multiMap map[string][]Entity) Entity {
|
||||||
|
entries := multiMap[strings.ToLower(name)]
|
||||||
|
if len(entries) == 0 {
|
||||||
|
// Try with trailing punctuation stripped
|
||||||
|
cleaned := strings.TrimRight(name, ".,;:!?")
|
||||||
|
cleaned = strings.TrimSpace(cleaned)
|
||||||
|
if cleaned != name {
|
||||||
|
entries = multiMap[strings.ToLower(cleaned)]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return Entity{}
|
||||||
|
}
|
||||||
|
if len(entries) == 1 {
|
||||||
|
return entries[0]
|
||||||
|
}
|
||||||
|
// Multiple occurrences: pick the one whose span center is closest to matchPos
|
||||||
|
best := entries[0]
|
||||||
|
bestDist := abs(best.StartChar + best.EndChar - 2*matchPos)
|
||||||
|
for i := 1; i < len(entries); i++ {
|
||||||
|
d := abs(entries[i].StartChar + entries[i].EndChar - 2*matchPos)
|
||||||
|
if d < bestDist {
|
||||||
|
best = entries[i]
|
||||||
|
bestDist = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return best
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractContext(text string, matchStr string) string {
|
||||||
|
idx := strings.Index(text, matchStr)
|
||||||
|
if idx < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
start := idx - 30
|
||||||
|
if start < 0 {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
end := idx + len(matchStr) + 30
|
||||||
|
if end > len(text) {
|
||||||
|
end = len(text)
|
||||||
|
}
|
||||||
|
return text[start:end]
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractContextSimple(text string, e1, e2 Entity) string {
|
||||||
|
start := min(e1.StartChar, e2.StartChar) - 20
|
||||||
|
if start < 0 {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
end := max(e1.EndChar, e2.EndChar) + 20
|
||||||
|
if end > len(text) {
|
||||||
|
end = len(text)
|
||||||
|
}
|
||||||
|
return text[start:end]
|
||||||
|
}
|
||||||
|
|
||||||
|
// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
|
||||||
|
|
||||||
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
||||||
|
|
||||||
|
// max returns the larger of a and b.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|
||||||
|
|
||||||
|
// splitSentences splits text into consecutive byte spans [start, end), one
// per sentence. The spans cover the whole text without gaps; whitespace that
// follows a terminator belongs to the next span.
//
// It approximates Python's re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text).
// Go's RE2 has no lookahead, so boundaries are found by hand:
//   - '!' and '?' always end a sentence.
//   - '.' ends a sentence when only whitespace follows up to the end of the
//     text, or when it is followed by at least one whitespace character
//     (space, tab, newline, carriage return) and then an ASCII uppercase
//     letter.
//   - Any other '.' is kept inside the sentence: abbreviations followed by a
//     lowercase word ("Inc. was"), initials ("U.S."), decimals ("3.14").
//
// An empty text yields no spans.
func splitSentences(text string) [][2]int {
	var spans [][2]int
	start := 0

	// closeAt ends the current sentence just before byte offset end.
	closeAt := func(end int) {
		spans = append(spans, [2]int{start, end})
		start = end
	}

	for i := 0; i < len(text); i++ {
		switch text[i] {
		case '!', '?':
			closeAt(i + 1)
		case '.':
			// Look past the whitespace that follows the period.
			next := i + 1
			for next < len(text) && (text[next] == ' ' || text[next] == '\t' || text[next] == '\n' || text[next] == '\r') {
				next++
			}
			if next >= len(text) {
				// Nothing but whitespace remains: the period ends the text.
				closeAt(i + 1)
				continue
			}
			// Requiring whitespace keeps initials such as "U.S." together.
			hasGap := next > i+1
			if hasGap && text[next] >= 'A' && text[next] <= 'Z' {
				closeAt(i + 1)
			}
		}
	}

	// Remaining text after the last sentence boundary.
	if start < len(text) {
		spans = append(spans, [2]int{start, len(text)})
	}
	return spans
}
|
||||||
113
internal/ingestion/compilation/extractor/ner_test.go
Normal file
113
internal/ingestion/compilation/extractor/ner_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
package extractor
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Test data — 17 English test cases (ground truth from Python+spaCy)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// EnTestSpec describes one English extraction test case: the input text and
// the entities and typed relations that must appear in the output.
type EnTestSpec struct {
	name         string      // test case name
	text         string      // input text to run extraction on
	wantEntities [][2]string // (text, label) pairs that MUST be found
	wantRels     []relSpec   // typed relations that MUST be found
}
|
||||||
|
|
||||||
|
// relSpec is an expected relation, given as a (subject, predicate, object)
// triple of strings.
type relSpec struct {
	subj string // expected subject text
	pred string // expected predicate, e.g. "founded_by"
	obj  string // expected object text
}
|
||||||
|
|
||||||
|
// enTests is the table of English test cases. A nil wantEntities or wantRels
// means the case requires no particular entity or relation to be present.
var enTests = []EnTestSpec{
	{name: "founded_by_simple", text: "Apple Inc. was founded by Steve Jobs.",
		wantEntities: [][2]string{{"Steve Jobs", "PERSON"}},
		wantRels:     []relSpec{{"Apple Inc.", "founded_by", "Steve Jobs"}}},
	{name: "founded_by_multi", text: "Google was founded by Larry Page and Sergey Brin.",
		wantEntities: [][2]string{{"Larry Page", "PERSON"}, {"Sergey Brin", "PERSON"}},
		wantRels:     []relSpec{{"Google", "founded_by", "Larry Page"}}},
	{name: "cofounder_of", text: "Elon Musk is a co-founder of Tesla.",
		wantEntities: [][2]string{{"Elon Musk", "PERSON"}, {"Tesla", "ORG"}},
		wantRels:     []relSpec{{"Elon Musk", "founded_by", "Tesla"}}},
	{name: "works_for_simple", text: "John works for Microsoft.",
		wantEntities: [][2]string{{"John", "PERSON"}, {"Microsoft", "ORG"}},
		wantRels:     []relSpec{{"John", "works_for", "Microsoft"}}},
	{name: "employee_of", text: "Mary is an employee of Google.",
		wantEntities: [][2]string{{"Mary", "PERSON"}, {"Google", "ORG"}},
		wantRels:     []relSpec{{"Mary", "works_for", "Google"}}},
	{name: "joined_company", text: "Sundar Pichai joined Google in 2004.",
		wantEntities: [][2]string{{"Sundar Pichai", "PERSON"}, {"Google", "ORG"}},
		wantRels:     []relSpec{{"Sundar Pichai", "works_for", "Google"}}},
	{name: "headquartered_in", text: "The company is headquartered in San Francisco.",
		wantEntities: [][2]string{{"San Francisco", "GPE"}},
		wantRels:     nil},
	{name: "based_in", text: "Microsoft is based in Redmond.",
		wantEntities: [][2]string{{"Microsoft", "ORG"}, {"Redmond", "GPE"}},
		wantRels:     []relSpec{{"Microsoft", "located_in", "Redmond"}}},
	{name: "born_in", text: "Albert Einstein was born in Germany.",
		wantEntities: [][2]string{{"Albert Einstein", "PERSON"}, {"Germany", "GPE"}},
		wantRels:     []relSpec{{"Albert Einstein", "born_in", "Germany"}}},
	{name: "ceo_of", text: "Sundar Pichai is the CEO of Google.",
		wantEntities: [][2]string{{"Sundar Pichai", "PERSON"}, {"Google", "ORG"}},
		wantRels:     []relSpec{{"Sundar Pichai", "works_for", "Google"}, {"Sundar Pichai", "ceo_of", "Google"}}},
	{name: "acquired_by", text: "Instagram was acquired by Facebook.",
		wantEntities: nil, // en_core_web_sm doesn't tag these
		wantRels:     nil},
	{name: "acquired_active", text: "Facebook acquired Instagram.",
		wantEntities: [][2]string{{"Instagram", "PERSON"}}, // en_core_web_sm: Instagram→PERSON
		wantRels:     nil},
	{name: "multi_founded_ceo", text: "Google was founded by Larry Page. Sundar Pichai is the CEO of Google.",
		wantEntities: [][2]string{{"Larry Page", "PERSON"}, {"Sundar Pichai", "PERSON"}},
		wantRels:     nil},
	{name: "multi_works_located", text: "John works for Microsoft. Microsoft is based in Redmond.",
		wantEntities: [][2]string{{"John", "PERSON"}, {"Microsoft", "ORG"}, {"Redmond", "GPE"}},
		wantRels:     []relSpec{{"Microsoft", "located_in", "Redmond"}}},
	{name: "no_entities", text: "The cat sat on the mat.",
		wantEntities: nil,
		wantRels:     nil},
	{name: "org_with_inc", text: "Microsoft Corporation was founded by Bill Gates.",
		wantEntities: [][2]string{{"Bill Gates", "PERSON"}},
		wantRels:     []relSpec{{"Microsoft Corporation", "founded_by", "Bill Gates"}}},
	{name: "located_city", text: "The restaurant is located in Paris.",
		wantEntities: [][2]string{{"Paris", "GPE"}},
		wantRels:     nil},
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Fast unit tests (pure Go, no Python dependency)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestDetectLanguage(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
text string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"Hello world", "en"},
|
||||||
|
{"你好世界", "zh"},
|
||||||
|
{"こんにちは世界", "ja"},
|
||||||
|
{"阿里巴巴由马云创立", "zh"},
|
||||||
|
{"アップルは", "ja"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := DetectLanguage(tt.text)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("DetectLanguage(%q) = %q, want %q", tt.text, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
87
internal/ingestion/compilation/extractor/parser_go.go
Normal file
87
internal/ingestion/compilation/extractor/parser_go.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package extractor
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "../../../cpp/thinc_parser.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DepTokenC holds the dependency-parse info for one token, with JSON tags
// for decoding the parser's output (mirrors DepToken in dep_relation.go).
type DepTokenC struct {
	Text  string `json:"text"`  // token text
	Head  int    `json:"head"`  // presumably the index of the head token — TODO confirm against the C++ parser
	Dep   string `json:"dep"`   // dependency label
	Index int    `json:"index"` // presumably the token's position in the input — TODO confirm
}
|
||||||
|
|
||||||
|
// RunParser runs the C++ dependency parser on tokenized text.
|
||||||
|
// modelBaseDir: path to model directory containing ner/ and parser/ subdirectories.
|
||||||
|
// tokensJSON: JSON array of token strings.
|
||||||
|
// Returns JSON array of DepTokenC.
|
||||||
|
func RunParser(nerDir, parserDir string, tokensJSON string) (string, error) {
|
||||||
|
cNer := C.CString(nerDir)
|
||||||
|
cParser := C.CString(parserDir)
|
||||||
|
cTokens := C.CString(tokensJSON)
|
||||||
|
defer C.free(unsafe.Pointer(cNer))
|
||||||
|
defer C.free(unsafe.Pointer(cParser))
|
||||||
|
defer C.free(unsafe.Pointer(cTokens))
|
||||||
|
|
||||||
|
handle := C.ThincParser_Create(cNer, cParser)
|
||||||
|
if handle == nil {
|
||||||
|
return "", fmt.Errorf("failed to create ThincParser handle")
|
||||||
|
}
|
||||||
|
defer C.ThincParser_Destroy(handle)
|
||||||
|
|
||||||
|
cResult := C.ThincParser_Predict(handle, cTokens)
|
||||||
|
if cResult == nil {
|
||||||
|
return "", fmt.Errorf("parser prediction failed")
|
||||||
|
}
|
||||||
|
defer C.ThincParser_FreeString(cResult)
|
||||||
|
|
||||||
|
return C.GoString(cResult), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunTagger runs the C++ POS tagger.
|
||||||
|
// nerDir: path to NER model directory (for tok2vec weights).
|
||||||
|
// taggerDir: path to tagger model directory.
|
||||||
|
func RunTagger(nerDir, taggerDir string, tokensJSON string) (string, error) {
|
||||||
|
cNer := C.CString(nerDir)
|
||||||
|
cTagger := C.CString(taggerDir)
|
||||||
|
cTokens := C.CString(tokensJSON)
|
||||||
|
defer C.free(unsafe.Pointer(cNer))
|
||||||
|
defer C.free(unsafe.Pointer(cTagger))
|
||||||
|
defer C.free(unsafe.Pointer(cTokens))
|
||||||
|
|
||||||
|
handle := C.ThincTagger_Create(cNer, cTagger)
|
||||||
|
if handle == nil {
|
||||||
|
return "", fmt.Errorf("failed to create ThincTagger handle")
|
||||||
|
}
|
||||||
|
defer C.ThincTagger_Destroy(handle)
|
||||||
|
|
||||||
|
cResult := C.ThincTagger_Predict(handle, cTokens)
|
||||||
|
if cResult == nil {
|
||||||
|
return "", fmt.Errorf("tagger prediction failed")
|
||||||
|
}
|
||||||
|
defer C.ThincTagger_FreeString(cResult)
|
||||||
|
|
||||||
|
return C.GoString(cResult), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseTokensWithParser runs the C++ parser and returns parsed DepToken slice.
|
||||||
|
func ParseTokensWithParser(nerDir, parserDir string, tokens []string) ([]DepTokenC, error) {
|
||||||
|
tj, _ := json.Marshal(tokens)
|
||||||
|
resultJSON, err := RunParser(nerDir, parserDir, string(tj))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var tokensC []DepTokenC
|
||||||
|
if err := json.Unmarshal([]byte(resultJSON), &tokensC); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse result: %w", err)
|
||||||
|
}
|
||||||
|
return tokensC, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build cgo
|
//go:build cgo && office
|
||||||
|
|
||||||
//
|
//
|
||||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build cgo
|
//go:build cgo && office
|
||||||
|
|
||||||
//
|
//
|
||||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !cgo
|
//go:build !cgo || !office
|
||||||
|
|
||||||
package parser
|
package parser
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build cgo
|
//go:build cgo && office
|
||||||
|
|
||||||
//
|
//
|
||||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build cgo
|
//go:build cgo && office
|
||||||
|
|
||||||
//
|
//
|
||||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build cgo
|
//go:build cgo && office
|
||||||
|
|
||||||
//
|
//
|
||||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build cgo
|
//go:build cgo && office
|
||||||
|
|
||||||
//
|
//
|
||||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
|||||||
31
lefthook.yml
31
lefthook.yml
@@ -39,9 +39,36 @@ pre-commit:
|
|||||||
stage_fixed: true
|
stage_fixed: true
|
||||||
- name: web-prettier
|
- name: web-prettier
|
||||||
glob: "web/**/*.{css,less,json,js,jsx,ts,tsx}"
|
glob: "web/**/*.{css,less,json,js,jsx,ts,tsx}"
|
||||||
run: npx --prefix web prettier --write --ignore-unknown {staged_files}
|
run: |
|
||||||
|
# Ensure web/node_modules is populated. CI runners don't run `npm install`
|
||||||
|
# before lefthook, and `npx` without local node_modules fetches the
|
||||||
|
# latest packages from the registry — which breaks because the pinned
|
||||||
|
# prettier plugins (prettier-plugin-organize-imports,
|
||||||
|
# prettier-plugin-packagejson) aren't auto-resolved and ESLint 10
|
||||||
|
# requires eslint.config.js.
|
||||||
|
# Use a mkdir-based mutex to prevent parallel npm ci from multiple hooks
|
||||||
|
# racing (ETXTBSY error on esbuild). mkdir is atomic on both Linux and macOS.
|
||||||
|
LOCKDIR=/tmp/ragflow-web-npm-lock
|
||||||
|
while ! mkdir "$LOCKDIR" 2>/dev/null; do sleep 0.5; done
|
||||||
|
if [ ! -f web/node_modules/.package-lock.json ]; then
|
||||||
|
echo "==> web/node_modules missing or incomplete; running npm ci --prefix web"
|
||||||
|
rm -rf web/node_modules
|
||||||
|
npm ci --prefix web --no-audit --no-fund || { rm -rf "$LOCKDIR"; exit 1; }
|
||||||
|
fi
|
||||||
|
rm -rf "$LOCKDIR"
|
||||||
|
cd web && printf '%s\n' {staged_files} | sed 's|^web/||' | xargs npx prettier --write --ignore-unknown
|
||||||
stage_fixed: true
|
stage_fixed: true
|
||||||
- name: web-eslint
|
- name: web-eslint
|
||||||
glob: "web/**/*.{js,jsx,ts,tsx}"
|
glob: "web/**/*.{js,jsx,ts,tsx}"
|
||||||
run: npx --prefix web eslint --fix {staged_files}
|
run: |
|
||||||
|
# Same npm ci guard as web-prettier; mkdir mutex serialises concurrent installs.
|
||||||
|
LOCKDIR=/tmp/ragflow-web-npm-lock
|
||||||
|
while ! mkdir "$LOCKDIR" 2>/dev/null; do sleep 0.5; done
|
||||||
|
if [ ! -f web/node_modules/.package-lock.json ]; then
|
||||||
|
echo "==> web/node_modules missing or incomplete; running npm ci --prefix web"
|
||||||
|
rm -rf web/node_modules
|
||||||
|
npm ci --prefix web --no-audit --no-fund || { rm -rf "$LOCKDIR"; exit 1; }
|
||||||
|
fi
|
||||||
|
rm -rf "$LOCKDIR"
|
||||||
|
cd web && printf '%s\n' {staged_files} | sed 's|^web/||' | xargs npx eslint --fix
|
||||||
stage_fixed: true
|
stage_fixed: true
|
||||||
|
|||||||
@@ -13,6 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from .graph_extractor import GraphExtractor
|
from .ner_extractor import NERExtractor
|
||||||
|
from .dep_relation_extractor import DepRelationExtractor
|
||||||
|
from .types import Entity, ExtractionResult, Relation
|
||||||
|
|
||||||
__all__ = ["GraphExtractor"]
|
__all__ = [
|
||||||
|
"NERExtractor",
|
||||||
|
"DepRelationExtractor",
|
||||||
|
"Entity",
|
||||||
|
"Relation",
|
||||||
|
"ExtractionResult",
|
||||||
|
]
|
||||||
|
|||||||
558
rag/graphrag/ner/dep_relation_extractor.py
Normal file
558
rag/graphrag/ner/dep_relation_extractor.py
Normal file
@@ -0,0 +1,558 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
Dependency-based relation extractor — full semantica alignment.
|
||||||
|
|
||||||
|
Extracts typed relations using spaCy dependency parse with:
|
||||||
|
- Multi-hop inference (A→B→C transitivity)
|
||||||
|
- Negation filtering
|
||||||
|
- Dynamic confidence scoring
|
||||||
|
- Multi-occurrence entity matching
|
||||||
|
"""
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .types import Entity, Relation
|
||||||
|
|
||||||
|
# Language-specific dependency label mappings, keyed by language code.
#
# Each inner dict maps a semantic role (pass_subj, subj, agent, dobj,
# prep_obj) to the rule DepRelationExtractor._get_by_role uses to find the
# entity filling that role among the children of the verb token:
#   - "dep"                      : a direct child whose dep_ equals "dep".
#   - ("dep", "child_dep")       : a direct child with dep_ == "dep", then
#                                  that child's child with dep_ == "child_dep".
#   - ("dep", None, "marker")    : a direct child with dep_ == "dep" whose
#                                  subtree contains a token whose lemma_ or
#                                  text equals "marker"; the entity is taken
#                                  from that child itself.
# A role that is absent for a language is simply not extracted for it.
# Languages not listed here fall back to the "en" rules
# (see DepRelationExtractor._roles).
_LANG_DEP_RULES: Dict[str, Dict[str, object]] = {
    "en": {"pass_subj": "nsubjpass", "subj": "nsubj",
           "agent": ("agent", "pobj"),
           "dobj": "dobj", "prep_obj": ("prep", "pobj")},
    "de": {"subj": "sb",
           "agent": ("sbp", "nk"),
           "prep_obj": ("mo", "nk"),
           # German ROOT is aux, real verb is "oc".
           # NOTE(review): no role lookup in this file reads
           # "root_verb_child"; _extract_with_dep hardcodes "oc" instead.
           "root_verb_child": "oc"},
    "fr": {"pass_subj": "nsubj:pass", "subj": "nsubj",
           "agent": "obl:agent",
           "dobj": "obj", "prep_obj": ("case", "obl")},
    "es": {"subj": "nsubj",
           "agent": "obj",
           "prep_obj": ("case", "obl")},
    "pt": {"pass_subj": "nsubj:pass", "subj": "nsubj",
           "agent": "obl:agent",
           "dobj": "obj", "prep_obj": ("case", "obl")},
    "zh": {"subj": "nsubj",
           "agent": ("nmod:prep", None, "由"),  # case "由" marks agent
           "prep_obj": ("case", "nmod")},
    "ja": {"subj": "nsubj",
           "agent": ("obl", None, "によって"),  # "によって" marks agent
           "prep_obj": ("case", "obl")},
}
|
||||||
|
|
||||||
|
# Multi-hop inference rules: if A rel1 B and B rel2 C then A rel3 C.
# Layout: _MULTI_HOP[rel1][rel2] == rel3. Pairs that are not listed produce
# no inferred relation (see DepRelationExtractor._infer_multi_hop).
_MULTI_HOP: Dict[str, Dict[str, str]] = {
    "ceo_of": {"is_subsidiary_of": "works_for", "located_in": "works_for"},
    "works_for": {"is_subsidiary_of": "works_for"},
    "founded_by": {"is_subsidiary_of": "founded_by"},
}
|
||||||
|
|
||||||
|
# Verb pattern → relation predicate.
#
# DepRelationExtractor._lookup builds the key as "<verb>+<prep>" when a
# preposition / case marker is given and as the bare "<verb>" otherwise
# (e.g. "join"). The verb part is the lower-cased lemma of the verb token,
# or its text when the lemma is empty.
# NOTE(review): a few keys ("set+up", "subsidiar+y", "sich+befinden",
# "tener+sede") do not look like verb+preposition pairs — confirm they can
# actually be produced by the extractor.
_VERB_RELATIONS: Dict[str, str] = {
    # English
    "found+by": "founded_by", "co-found+by": "founded_by",
    "establish+by": "founded_by", "create+by": "founded_by",
    "set+up": "founded_by", "start+by": "founded_by",
    "work+for": "works_for", "employ+by": "works_for",
    "hire+by": "works_for", "join": "works_for",
    "lead+by": "works_for", "manage+by": "works_for",
    "head+by": "works_for", "run+by": "works_for",
    "own+by": "owns", "develop+by": "develops",
    "write+by": "wrote", "publish+by": "published",
    "invest+in": "invests_in", "partner+with": "partners_with",
    "collaborate+with": "collaborates_with",
    "merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of",
    "base+in": "located_in", "locate+in": "located_in",
    "situate+in": "located_in", "headquarter+in": "located_in",
    "bear+in": "born_in", "bear+on": "born_in",
    "acquire+by": "acquired", "buy+by": "acquired",
    # German (de): spaCy lemmas
    "gründen+von": "founded_by", "errichten+von": "founded_by",
    "arbeiten+für": "works_for", "beschäftigen+bei": "works_for",
    "anstellen+bei": "works_for", "sich+befinden": "located_in",
    "liegen+in": "located_in", "sitzen+in": "located_in",
    "gebären+in": "born_in", "gebären+am": "born_in",
    "erwerben+durch": "acquired", "kaufen+durch": "acquired",
    "übernehmen+durch": "acquired",
    # French (fr): spaCy lemmas
    "fonder+par": "founded_by", "créer+par": "founded_by",
    "établir+par": "founded_by",
    "travailler+pour": "works_for", "employer+par": "works_for",
    "embaucher+par": "works_for",
    "situer+à": "located_in", "baser+à": "located_in",
    "implanter+à": "located_in",
    "naître+à": "born_in",
    "acquérir+par": "acquired", "racheter+par": "acquired",
    # Spanish + Portuguese (shared lemmas, no duplicate keys)
    "fundar+por": "founded_by", "crear+por": "founded_by",
    "criar+por": "founded_by",
    "establecer+por": "founded_by", "estabelecer+por": "founded_by",
    "trabajar+para": "works_for", "trabalhar+para": "works_for",
    "emplear+por": "works_for", "empregar+por": "works_for",
    "contratar+por": "works_for",
    "ubicar+en": "located_in", "situar+en": "located_in",
    "localizar+em": "located_in", "situar+em": "located_in",
    "sediar+em": "located_in", "tener+sede": "located_in",
    "nacer+en": "born_in", "nascer+em": "born_in",
    "adquirir+por": "acquired", "comprar+por": "acquired",
    # Chinese: verb + "由" (agent marker) or a locative marker
    # ("于", "在"). No "被" (passive) keys are present in this table.
    "创立+由": "founded_by", "创建+由": "founded_by",
    "成立+由": "founded_by", "创办+由": "founded_by",
    "设立+由": "founded_by",
    "任职+于": "works_for", "就职+于": "works_for",
    "工作+在": "works_for", "位于+在": "located_in",
    "坐落+在": "located_in", "总部设+在": "located_in",
    "出生+在": "born_in", "出生+于": "born_in",
    "收购+由": "acquired", "并购+由": "acquired",
    # Japanese: verb + "によって" (agent marker) or a particle ("で", "に", "を")
    "設立+によって": "founded_by", "創立+によって": "founded_by",
    "勤務+で": "works_for", "在籍+で": "works_for",
    "位置+に": "located_in", "所在+に": "located_in",
    "本社+を": "located_in",
    "出生+に": "born_in",
    "買収+によって": "acquired",
}
|
||||||
|
|
||||||
|
# Job-title keyword → relation predicates emitted for copular sentences of
# the form "<subject> is <title> <prep> <entity>".
# DepRelationExtractor._extract_copula tests each keyword with a substring
# check against the lower-cased lemma of the title token, in insertion
# order, so a short keyword listed earlier can match before a longer one
# listed later (e.g. "founder" precedes "co-founder").
_COPULA_TITLE_MAP: Dict[str, List[str]] = {
    "ceo": ["ceo_of", "works_for"], "cto": ["works_for"],
    "cfo": ["works_for"], "coo": ["works_for"],
    "vp": ["works_for"], "director": ["works_for"],
    "manager": ["works_for"], "engineer": ["works_for"],
    "employee": ["works_for"],
    "founder": ["founded_by"], "co-founder": ["founded_by"],
}
|
||||||
|
|
||||||
|
|
||||||
|
class DepRelationExtractor:
    """Extract typed relations using dependency parse — semantica-aligned.

    ``extract`` combines three sources of relations:

    * rule-based relations read off a dependency parse (only when a parsed
      ``doc`` is supplied),
    * low-confidence ``related_to`` relations for entities that co-occur in
      the same sentence,
    * relations inferred by chaining two extracted relations through
      ``_MULTI_HOP``.

    Tokens are only accessed through ``children``, ``subtree``, ``dep_``,
    ``lemma_``, ``text``, ``idx`` and ``doc.text``; ``doc`` through ``sents``.
    """

    def __init__(self, language: str = "en",
                 confidence_threshold: float = 0.3,
                 max_distance: int = 100):
        """Store the extractor settings.

        Args:
            language: Key into ``_LANG_DEP_RULES``; unknown codes fall back
                to the ``"en"`` rules.
            confidence_threshold: ``extract`` drops relations whose
                confidence is below this value.
            max_distance: Maximum character gap between the end of one
                entity and the start of the next for a co-occurrence
                relation to be emitted.
        """
        self.language = language
        self.confidence_threshold = confidence_threshold
        self.max_distance = max_distance

    def extract(self, text: str, entities: List[Entity],
                doc=None, **options) -> List[Relation]:
        """Return the relations found between ``entities`` in ``text``.

        Args:
            text: The raw text the entities were extracted from.
            entities: Entities with ``text``, ``start_char`` and ``end_char``.
            doc: Optional parsed document. Without it only co-occurrence
                relations (and anything inferred from them) are produced.
            **options: Accepted for interface compatibility; not read here.

        Returns:
            De-duplicated relations with confidence at or above
            ``confidence_threshold``.
        """
        semantica_rels: List[Relation] = []
        if doc is not None:
            semantica_rels = self._extract_with_dep(text, doc, entities)
        semantica_rels.extend(self._extract_cooccurrence(text, entities))
        semantica_rels = self._infer_multi_hop(semantica_rels)
        semantica_rels = self._deduplicate(semantica_rels)
        return [r for r in semantica_rels
                if r.confidence >= self.confidence_threshold]

    # ------------------------------------------------------------------
    # Multi-hop inference (property propagation)
    # ------------------------------------------------------------------

    @staticmethod
    def _infer_multi_hop(relations: List[Relation]) -> List[Relation]:
        """Infer transitive relations: A→B→C ⇒ A→C.

        ``related_to`` relations never take part. Two relations chain when
        the first one's object text equals (case-insensitively) the second
        one's subject text and ``_MULTI_HOP`` has an entry for the predicate
        pair. The inferred confidence is 0.9 × the weaker of the two.

        Returns:
            The input relations followed by the inferred ones.
        """
        by_subj: Dict[str, List[Relation]] = {}
        for r in relations:
            if r.predicate == "related_to":
                continue
            by_subj.setdefault(r.subject.text.lower(), []).append(r)

        inferred = []
        for r in relations:
            if r.predicate == "related_to":
                continue
            hop_rules = _MULTI_HOP.get(r.predicate)
            if not hop_rules:
                continue
            for r2 in by_subj.get(r.obj.text.lower(), ()):
                inferred_rel = hop_rules.get(r2.predicate)
                if not inferred_rel:
                    continue
                inferred.append(Relation(
                    subject=r.subject, predicate=inferred_rel,
                    obj=r2.obj,
                    confidence=min(r.confidence, r2.confidence) * 0.9,
                    metadata={"method": "multi_hop",
                              "via": f"{r.predicate}→{r2.predicate}"},
                ))
        return relations + inferred

    # ------------------------------------------------------------------
    # Dependency extraction
    # ------------------------------------------------------------------

    # ------------------------------------------------------------------
    # Language-aware role mapping
    # ------------------------------------------------------------------

    def _roles(self) -> Dict[str, object]:
        """Get role → dep rule mapping for current language.

        Rule values are strings or tuples, hence ``object``. Languages
        missing from ``_LANG_DEP_RULES`` use the ``"en"`` rules.
        """
        return _LANG_DEP_RULES.get(self.language, _LANG_DEP_RULES["en"])

    def _get_by_role(self, root, role: str, entity_map) -> list:
        """Get entities for a semantic role (language-aware).

        Args:
            root: Token whose direct children are inspected.
            role: Role name looked up in the current language's rules.
            entity_map: Mapping built by ``_build_entity_map_multi``.

        Returns:
            A list of ``(entity, prep)`` tuples. ``prep`` is the lower-cased
            lemma of the matched child when ``role == "prep_obj"`` and the
            rule is a tuple, otherwise ``None``. Empty when the language has
            no rule for ``role``.
        """
        rule = self._roles().get(role)
        if rule is None:
            return []
        results = []

        for c in root.children:
            dep = c.dep_
            if isinstance(rule, str):
                if dep == rule:
                    ent = self._entity_from_subtree(c, entity_map)
                    if ent:
                        results.append((ent, None))
            elif isinstance(rule, tuple):
                parent_dep, child_dep = rule[0], rule[1]
                # Optional case marker (e.g., "由" for zh, "によって" for ja)
                case_marker = rule[2] if len(rule) > 2 else None
                if dep != parent_dep:
                    continue
                if case_marker:
                    # The marker may sit anywhere in the child's subtree.
                    has_case = any(
                        gc.lemma_ == case_marker or gc.text == case_marker
                        for gc in c.subtree
                    )
                    if not has_case:
                        continue
                prep = c.lemma_.lower() if role == "prep_obj" else None
                if child_dep is None:
                    ent = self._entity_from_subtree(c, entity_map)
                    if ent:
                        results.append((ent, prep))
                else:
                    for gc in c.children:
                        if gc.dep_ == child_dep:
                            ent = self._entity_from_subtree(gc, entity_map)
                            if ent:
                                results.append((ent, prep))
                                break
        return results

    def _extract_with_dep(self, text, doc, entities) -> List[Relation]:
        """Run the verb and copula rules on every ROOT token of ``doc``.

        For German the ROOT token is treated as an auxiliary: each of its
        ``"oc"`` children supplies the verb lemma while the ROOT itself is
        passed along as ``aux_root`` so its arguments are considered too.
        """
        relations = []
        entity_map = self._build_entity_map_multi(entities)
        is_de = self.language == "de"

        for sent in doc.sents:
            for token in sent:
                if token.dep_ != "ROOT":
                    continue

                if is_de:
                    # German: args attach to aux (ROOT), not main verb (oc).
                    # Pass both: root aux for args, oc for verb lemma.
                    for c in token.children:
                        if c.dep_ == "oc":
                            relations.extend(self._extract_from_root(
                                text, c, entity_map, aux_root=token))
                    continue

                relations.extend(
                    self._extract_from_root(text, token, entity_map))
                if token.lemma_ == "be":
                    relations.extend(
                        self._extract_copula(text, token, entity_map))

        return relations

    def _extract_from_root(self, text, root, entity_map,
                           aux_root=None) -> List[Relation]:
        """Extract passive / active / prepositional relations for one verb.

        Args:
            text: Source text (not read by this method).
            root: Verb token supplying the lemma and, primarily, the
                arguments.
            entity_map: Mapping built by ``_build_entity_map_multi``.
            aux_root: Optional auxiliary token whose children are used for
                the negation check and as a fallback source of arguments.

        Returns:
            Relations found for this verb; empty when the checked token has
            a negation child.
        """
        relations = []
        # Fall back to text when lemma is empty (zh, ja have no lemmatizer).
        verb_lemma = (root.lemma_ or root.text).lower()
        # For languages like German where args attach to the aux verb.
        check = root if aux_root is None else aux_root

        # Negated clauses yield no relations at all.
        if any(c.dep_ in ("neg", "advmod:neg") for c in check.children):
            return relations

        def first(lst):
            return lst[0][0] if lst else None

        def get_roles(token):
            return (
                first(self._get_by_role(token, "subj", entity_map)),
                first(self._get_by_role(token, "pass_subj", entity_map)),
                first(self._get_by_role(token, "dobj", entity_map)),
                first(self._get_by_role(token, "agent", entity_map)),
                self._get_by_role(token, "prep_obj", entity_map),
                any(c.dep_ == "aux" for c in token.children),
            )

        s1, sp1, d1, a1, p1, h1 = get_roles(root)
        s2, sp2, d2, a2, p2, h2 = (None, None, None, None, [], False)
        if aux_root:
            s2, sp2, d2, a2, p2, h2 = get_roles(aux_root)

        # Merge: roles found on the main verb win; aux fills the gaps.
        nsubj = s1 or s2
        nsubjpass = sp1 or sp2
        dobj = d1 or d2
        agent_entity = a1 or a2
        prep_list = p1 + p2
        has_aux = h1 or h2 or aux_root is not None
        has_explicit_agent = agent_entity is not None

        # Detect passive:
        # - explicit pass_subj (en, fr, pt)
        # - subj + agent + aux (Spanish-style)
        # - subj + agent for languages with agent marker (zh, ja)
        is_passive_candidate = has_explicit_agent and (
            has_aux or self.language in ("zh", "ja"))

        effective_nsubjpass = nsubjpass or (
            nsubj if is_passive_candidate else None)
        effective_nsubj = nsubj if not is_passive_candidate else None

        # Passive: X was founded/acquired by Y
        if effective_nsubjpass and agent_entity:
            # Try language-appropriate prepositions/case markers in order
            # and keep the first one the verb table knows.
            candidates = ("by", "von", "par", "por", "durch", "由", "によって")
            rel_type = None
            for candidate in candidates:
                rel_type = self._lookup(verb_lemma, candidate)
                if rel_type:
                    break
            if rel_type:
                if rel_type in ("founded_by", "acquired"):
                    subj, obj = effective_nsubjpass, agent_entity
                else:
                    # Other predicates read agent → patient.
                    subj, obj = agent_entity, effective_nsubjpass
                relations.append(self._make_rel(
                    subj, rel_type, obj, 0.90, "passive", verb_lemma))

        # Active: X VERB Y or X VERB prep Y
        if effective_nsubj:
            if dobj:
                rt = self._lookup(verb_lemma, None)
                if rt:
                    relations.append(self._make_rel(
                        effective_nsubj, rt, dobj, 0.85,
                        "active", verb_lemma))
            for prep_entity, prep_l in prep_list:
                rt = self._lookup(verb_lemma, prep_l)
                if rt:
                    relations.append(self._make_rel(
                        effective_nsubj, rt, prep_entity, 0.85,
                        "active_prep", verb_lemma, prep=prep_l))

        # Passive with prep ("is based in")
        if effective_nsubjpass and prep_list and not agent_entity:
            for prep_entity, prep_l in prep_list:
                rt = self._lookup(verb_lemma, prep_l)
                if not rt:
                    rt = self._lookup("be+" + verb_lemma, prep_l)
                if rt:
                    relations.append(self._make_rel(
                        effective_nsubjpass, rt, prep_entity, 0.85,
                        "passive_prep", verb_lemma, prep=prep_l))

        return relations

    @staticmethod
    def _make_rel(subj, pred, obj, conf, method, verb, prep=""):
        """Build a ``Relation`` whose metadata records method, verb and prep.

        ``prep`` is only stored in the metadata when it is non-empty.
        """
        m = {"method": method, "verb": verb}
        if prep:
            m["prep"] = prep
        return Relation(subject=subj, predicate=pred, obj=obj,
                        confidence=conf, metadata=m)

    @staticmethod
    def _already_has(rels, subj, pred, obj) -> bool:
        """Return True if ``rels`` holds a relation with the same triple.

        Subject and object are compared by exact (case-sensitive) text.
        """
        return any(
            r.subject.text == subj.text
            and r.predicate == pred
            and r.obj.text == obj.text
            for r in rels
        )

    def _extract_copula(self, text, root, entity_map) -> List[Relation]:
        """Extract title relations from "<subj> is <title> <prep> <entity>".

        The subject comes from the language's ``subj`` rule. The title and
        its object come from the first attr/pred child of ``root`` that has
        a prep-like child leading to a known entity. The first keyword of
        ``_COPULA_TITLE_MAP`` contained in the title lemma decides which
        predicates are emitted (confidence 0.88, ``context=text``).
        """
        subjs = self._get_by_role(root, "subj", entity_map)
        subj = subjs[0][0] if subjs else None
        if not subj:
            return []

        title_lemma, prep_obj = self._find_copula_title(root, entity_map)
        if not title_lemma or not prep_obj:
            return []

        relations = []
        for keyword, rel_types in _COPULA_TITLE_MAP.items():
            if keyword in title_lemma:
                for rt in rel_types:
                    relations.append(Relation(
                        subject=subj, predicate=rt, obj=prep_obj,
                        confidence=0.88, context=text,
                        metadata={"method": "copula", "title": title_lemma},
                    ))
                # Only the first matching keyword applies.
                break
        return relations

    def _find_copula_title(self, root, entity_map):
        """Return ``(title_lemma, entity)`` for the first title/object pair.

        Walks attr/pred children of ``root`` (attr=en, pred=de), then their
        prep-like children (en=prep, de=mo, fr=case), and accepts any child
        of those as the object as long as it resolves to an entity. The
        search stops at the first hit so that a later child can neither
        overwrite nor erase the pair already found. Returns ``(None, None)``
        when nothing matches.
        """
        for title in root.children:
            if title.dep_ not in ("attr", "pred"):
                continue
            for prep in title.children:
                if prep.dep_ not in ("prep", "mo", "case"):
                    continue
                for gc in prep.children:
                    ent = self._entity_from_subtree(gc, entity_map)
                    if ent:
                        return title.lemma_.lower(), ent
        return None, None

    # ------------------------------------------------------------------
    # Better entity map: multi-occurrence aware
    # ------------------------------------------------------------------

    @staticmethod
    def _build_entity_map_multi(
            entities: List[Entity]) -> Dict[str, List[Entity]]:
        """Build entity map that keeps ALL occurrences per name.

        Each entity is registered under its lower-cased text and, when
        different, under that text with trailing ``.,;:!?`` and surrounding
        whitespace removed. A cleaned key that ends up empty
        (punctuation-only entity) is not registered: an empty key would be
        a substring of every lookup key in ``_entity_from_subtree``.
        """
        result: Dict[str, List[Entity]] = {}
        for e in entities:
            key = e.text.lower()
            result.setdefault(key, []).append(e)
            cleaned = e.text.rstrip(".,;:!?").strip().lower()
            if cleaned and cleaned != key:
                result.setdefault(cleaned, []).append(e)
        return result

    @staticmethod
    def _find_best_entity(key: str, entity_map: Dict[str, List[Entity]],
                          fallback_text: str = "") -> Optional[Entity]:
        """Find the best entity registered under ``key``.

        With several candidates, prefer the one whose text equals
        ``fallback_text`` case-insensitively; otherwise return the first.
        Returns ``None`` when ``key`` is unknown.
        """
        entries = entity_map.get(key.lower(), [])
        if not entries:
            return None
        if len(entries) == 1:
            return entries[0]
        wanted = fallback_text.lower()
        for e in entries:
            if e.text.lower() == wanted:
                return e
        return entries[0]

    # ------------------------------------------------------------------
    # Argument extraction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_child_entity(token, dep, entity_map):
        """Resolve the first child of ``token`` with dep label ``dep``.

        Only the first such child is tried; the result may be ``None``.
        """
        for c in token.children:
            if c.dep_ == dep:
                return DepRelationExtractor._entity_from_subtree(
                    c, entity_map)
        return None

    @staticmethod
    def _get_agent_pobj(root, entity_map):
        """Resolve the first ``agent`` → ``pobj`` grandchild of ``root``."""
        for c in root.children:
            if c.dep_ == "agent":
                for gc in c.children:
                    if gc.dep_ == "pobj":
                        return DepRelationExtractor._entity_from_subtree(
                            gc, entity_map)
        return None

    @staticmethod
    def _get_prep_objs(root, entity_map):
        """Return ``(prep_lemma, entity)`` for each ``prep`` → ``pobj`` pair.

        Note the tuple order is the reverse of ``_get_by_role``'s.
        """
        results = []
        for c in root.children:
            if c.dep_ == "prep":
                prep_lemma = c.lemma_.lower()
                for gc in c.children:
                    if gc.dep_ == "pobj":
                        ent = DepRelationExtractor._entity_from_subtree(
                            gc, entity_map)
                        if ent:
                            results.append((prep_lemma, ent))
        return results

    @staticmethod
    def _entity_from_subtree(token, entity_map) -> Optional[Entity]:
        """Match token's subtree against entity map.

        The candidate string is the slice of ``token.doc.text`` spanning the
        token and every subtree token that is not a preposition,
        punctuation, determiner, auxiliary, conjunction or conjunct. Lookup
        order: exact key, then the part before the first " and " / " or " /
        ", " separator, then the first map key that contains, or is
        contained in, the candidate. Empty strings never take part in the
        substring step.

        Returns:
            The first entity registered under the matching key, or ``None``.
        """
        min_char = token.idx
        max_char = token.idx + len(token.text)
        for t in token.subtree:
            if t.dep_ in ("prep", "punct", "det", "aux", "auxpass",
                          "cc", "conj"):
                continue
            if t.idx < min_char:
                min_char = t.idx
            end = t.idx + len(t.text)
            if end > max_char:
                max_char = end
        text = token.doc.text[min_char:max_char].strip()
        key = text.lower()

        entries = entity_map.get(key, [])
        if not entries:
            # Conjunction handling: retry with the first conjunct only.
            for sep in (" and ", " or ", ", "):
                if sep in key:
                    entries = entity_map.get(key.split(sep)[0].strip(), [])
                    if entries:
                        break
        if not entries and key:
            # Fuzzy fallback; "" is a substring of everything, so skip it.
            for ek, ev in entity_map.items():
                if ek and (ek in key or key in ek):
                    entries = ev
                    break
        if entries:
            return entries[0]
        return None

    @staticmethod
    def _lookup(verb: str, prep: Optional[str] = None) -> Optional[str]:
        """Look up the predicate for ``verb`` (plus ``prep`` when given).

        The key is ``"<verb>+<prep>"`` when ``prep`` is truthy, else
        ``verb`` alone. Returns ``None`` for unknown keys.
        """
        if prep:
            return _VERB_RELATIONS.get(f"{verb}+{prep}")
        return _VERB_RELATIONS.get(verb)

    @staticmethod
    def _deduplicate(relations: List[Relation]) -> List[Relation]:
        """Drop repeated triples, keeping the first occurrence.

        Triples are compared on lower-cased subject/object text. A relation
        is also dropped when the same predicate was already kept with
        subject and object swapped.
        """
        seen = set()
        result = []
        for r in relations:
            key = (r.subject.text.lower(), r.predicate, r.obj.text.lower())
            rev = (r.obj.text.lower(), r.predicate, r.subject.text.lower())
            if key in seen or rev in seen:
                continue
            seen.add(key)
            result.append(r)
        return result

    # ------------------------------------------------------------------
    # Co-occurrence
    # ------------------------------------------------------------------

    def _extract_cooccurrence(self, text, entities) -> List[Relation]:
        """Emit ``related_to`` for entity pairs sharing a sentence.

        Sentences are found with a regex that requires a ``.``, ``!`` or
        ``?`` followed by whitespace or end of string. A pair qualifies when
        both start offsets fall in the same sentence span and the gap
        between the first entity's end and the second's start is at most
        ``max_distance`` characters. Each relation has confidence 0.4 and a
        context window of 20 characters around the pair.

        NOTE(review): text without such a terminator (including text that
        only uses CJK punctuation) yields no spans and therefore no
        co-occurrence relations — confirm this is intended.
        """
        if len(entities) < 2:
            return []
        import re as _re
        spans = [(m.start(), m.end())
                 for m in _re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text)]

        def same_sent(c1, c2):
            return any(ss <= c1 < se and ss <= c2 < se for ss, se in spans)

        rels = []
        for i, e1 in enumerate(entities):
            for e2 in entities[i + 1:]:
                if not same_sent(e1.start_char, e2.start_char):
                    continue
                if abs(e2.start_char - e1.end_char) > self.max_distance:
                    continue
                cs = max(0, min(e1.start_char, e2.start_char) - 20)
                ce = min(len(text), max(e1.end_char, e2.end_char) + 20)
                rels.append(Relation(
                    subject=e1, predicate="related_to", obj=e2,
                    confidence=0.4, context=text[cs:ce],
                    metadata={"method": "cooccurrence"},
                ))
        return rels
|
||||||
243
rag/graphrag/ner/ner_extractor.py
Normal file
243
rag/graphrag/ner/ner_extractor.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
NERExtractor — semantica-style full pipeline extraction.
|
||||||
|
|
||||||
|
Pipeline: tokenize → tag(POS) → parse(dep) → NER → typed relations
|
||||||
|
|
||||||
|
All components share a single spaCy `doc` object (one forward pass).
|
||||||
|
|
||||||
|
Output includes:
|
||||||
|
- Entities (from NER, enriched with POS/dep)
|
||||||
|
- Typed relations (from dependency patterns)
|
||||||
|
- Dependency tree (heads + labels per token)
|
||||||
|
- POS tags per token
|
||||||
|
|
||||||
|
Supports 7 languages: en, zh, de, fr, es, pt, ja
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import spacy
|
||||||
|
from spacy import Language
|
||||||
|
|
||||||
|
from .dep_relation_extractor import DepRelationExtractor
|
||||||
|
from .types import Entity, ExtractionResult
|
||||||
|
|
||||||
|
# Two-letter language code → name of the spaCy model to load for it.
_MODEL_MAP = {
    "en": "en_core_web_sm",
    "zh": "zh_core_web_sm",
    "de": "de_core_news_sm",
    "fr": "fr_core_news_sm",
    "es": "es_core_news_sm",
    "pt": "pt_core_news_sm",
    "ja": "ja_core_news_sm",
}

# spaCy entity labels that are dropped from the NER output.
_SKIP_LABELS = {"ORDINAL", "CARDINAL"}

# Entity labels grouped by confidence tier (used for NER confidence scoring).
_HIGH_CONF = {"PERSON", "ORG", "GPE", "LOC", "DATE"}
_MED_CONF = {
    "PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "NORP",
    "MONEY", "TIME", "PERCENT", "FAC", "QUANTITY",
}
|
||||||
|
|
||||||
|
|
||||||
|
class NERExtractor:
    """
    Full semantic extraction pipeline (NER + tagger + parser + relations).

    Usage:
        ext = NERExtractor(language="en")
        result = ext.extract("Apple Inc. was founded by Steve Jobs.")

        # result.entities           → [Entity]
        # result.relations          → [Relation]
        # result.metadata["tokens"] → [dict] (text, tag, dep, head, index, lemma, pos)
    """

    # Model cache: spaCy model name → loaded pipeline. Being a class
    # attribute, it is shared by every NERExtractor instance.
    _nlp_cache: Dict[str, Language] = {}

    def __init__(
        self,
        language: str = "en",
        spacy_model: Optional[str] = None,
        confidence_threshold: float = 0.3,
    ):
        """Configure the extractor; the spaCy model is loaded lazily on first use.

        Args:
            language: Language code used to pick a model from ``_MODEL_MAP``.
                An unknown code falls back to ``"en"`` unless ``spacy_model``
                is given explicitly.
            spacy_model: Explicit spaCy model name; overrides the model that
                would be chosen from ``language``.
            confidence_threshold: Entities scoring below this value are
                dropped; it is also forwarded to ``DepRelationExtractor``.
        """
        if language not in _MODEL_MAP and spacy_model is None:
            # Make the fallback visible: text in an unsupported language will
            # be processed with the English model.
            logging.warning(
                "Unsupported language '%s' and no spacy_model given; "
                "falling back to 'en'", language,
            )
            language = "en"
        self.language = language
        self.model_name = spacy_model or _MODEL_MAP.get(language, "en_core_web_sm")
        self.confidence_threshold = confidence_threshold
        self._nlp: Optional[Language] = None

    # ------------------------------------------------------------------
    # Model lifecycle
    # ------------------------------------------------------------------

    def _ensure_model(self):
        """Bind ``self._nlp`` to the shared spaCy pipeline, loading it if needed.

        The model is loaded with ``spacy.load(self.model_name)`` without
        excluding any pipes, so the tagger/parser/NER components shipped with
        the model stay available for dependency-based relation extraction.

        Raises:
            Exception: Whatever ``spacy.load`` raises; it is logged and
                re-raised unchanged.
        """
        if self.model_name in self._nlp_cache:
            self._nlp = self._nlp_cache[self.model_name]
            return
        try:
            nlp = spacy.load(self.model_name)
            self._nlp_cache[self.model_name] = nlp
            self._nlp = nlp
        except Exception as e:
            logging.error("Failed to load spaCy model '%s': %s",
                          self.model_name, e)
            raise

    # ------------------------------------------------------------------
    # Main extraction
    # ------------------------------------------------------------------

    def extract(
        self,
        text: str,
        extract_relations: bool = True,
        include_tokens: bool = True,
    ) -> ExtractionResult:
        """Run the full pipeline on a single text.

        Args:
            text: Text to analyse.
            extract_relations: When true and at least two entities were found,
                run ``DepRelationExtractor`` on the parsed document.
            include_tokens: When true, attach per-token dicts under
                ``result.metadata["tokens"]``.

        Returns:
            An ``ExtractionResult`` whose ``metadata`` holds ``model``,
            ``n_tokens``, ``n_entities``, ``n_relations`` and optionally
            ``tokens``.
        """
        self._ensure_model()
        # Single forward pass through spaCy; everything else reuses this doc.
        doc = self._nlp(text)
        return self._build_result(doc, text, extract_relations, include_tokens)

    def extract_batch(
        self,
        texts: List[str],
        extract_relations: bool = True,
        include_tokens: bool = False,
        batch_size: int = 32,
    ) -> List[ExtractionResult]:
        """Batch extraction using spaCy's nlp.pipe() for efficiency.

        Each result carries the same metadata keys as ``extract`` produces
        (``model``, ``n_tokens``, ``n_entities``, ``n_relations`` and, when
        ``include_tokens`` is true, ``tokens``).

        Args:
            texts: Texts to analyse.
            extract_relations: See ``extract``.
            include_tokens: See ``extract``; defaults to ``False`` here.
            batch_size: Forwarded to ``nlp.pipe``.

        Returns:
            One ``ExtractionResult`` per document yielded by ``nlp.pipe``.
        """
        self._ensure_model()
        return [
            self._build_result(doc, doc.text, extract_relations, include_tokens)
            for doc in self._nlp.pipe(texts, batch_size=batch_size)
        ]

    def _build_result(
        self,
        doc,
        text: str,
        extract_relations: bool,
        include_tokens: bool,
    ) -> ExtractionResult:
        """Turn one parsed spaCy doc into an ExtractionResult (shared by extract/extract_batch)."""
        entities = self._extract_entities(doc)

        # Typed relations need at least two entities to connect.
        relations = []
        if extract_relations and len(entities) >= 2:
            dep_ext = DepRelationExtractor(
                language=self.language,
                confidence_threshold=self.confidence_threshold,
            )
            relations = dep_ext.extract(text, entities, doc=doc)

        metadata = {
            "model": self.model_name,
            "n_tokens": len(doc),
            "n_entities": len(entities),
            # Generic "related_to" relations are excluded from the count.
            "n_relations": sum(1 for r in relations if r.predicate != "related_to"),
        }
        if include_tokens:
            metadata["tokens"] = self._build_tokens(doc)

        return ExtractionResult(
            entities=entities,
            relations=relations,
            language=self.language,
            metadata=metadata,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _label_confidence(label: str) -> float:
        """Return the fixed confidence score assigned to an entity label's tier."""
        if label in _HIGH_CONF:
            return 0.85
        if label in _MED_CONF:
            return 0.65
        return 0.50

    def _extract_entities(self, doc) -> List[Entity]:
        """Convert ``doc.ents`` into ``Entity`` records.

        Skips labels listed in ``_SKIP_LABELS``, entities whose label-based
        confidence is below ``self.confidence_threshold``, and repeats of the
        same (lower-cased text, start offset) pair.
        """
        entities = []
        seen = set()
        for ent in doc.ents:
            if ent.label_ in _SKIP_LABELS:
                continue
            confidence = self._label_confidence(ent.label_)
            if confidence < self.confidence_threshold:
                continue
            key = (ent.text.lower(), ent.start_char)
            if key in seen:
                continue
            seen.add(key)
            entities.append(Entity(
                text=ent.text,
                label=ent.label_,
                start_char=ent.start_char,
                end_char=ent.end_char,
                confidence=confidence,
                metadata={"source": "spacy"},
            ))
        return entities

    @staticmethod
    def _build_tokens(doc) -> List[Dict[str, Any]]:
        """Build token list with POS tags and dependency info.

        ``head`` is the index of the token's syntactic head within the doc;
        ``index`` is the token's own position.
        """
        return [
            {
                "text": t.text,
                "tag": t.tag_,
                "dep": t.dep_,
                "head": t.head.i,
                "index": i,
                "lemma": t.lemma_,
                "pos": t.pos_,
            }
            for i, t in enumerate(doc)
        ]

    @staticmethod
    def clear_cache():
        """Clear the NLP model cache (e.g., for testing)."""
        NERExtractor._nlp_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: ExtractionResult (see .types) already declares a `metadata` field, so no patching is required here.
|
||||||
|
|
||||||
75
rag/graphrag/ner/types.py
Normal file
75
rag/graphrag/ner/types.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
Data types for entity and relation extraction.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Entity:
    """A single entity mention found in a text.

    Attributes:
        text: Surface form of the mention.
        label: spaCy NER label (PERSON, ORG, GPE, ...).
        start_char: Character offset where the mention starts.
        end_char: Character offset where the mention ends.
        confidence: Score attached to the mention; defaults to ``1.0``.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    text: str
    label: str
    start_char: int
    end_char: int
    confidence: float = 1.0
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Relation:
    """A directed, typed link between two extracted entities.

    Attributes:
        subject: Entity the relation starts from.
        predicate: Relation type, e.g. "founded_by" or "works_for".
        obj: Entity the relation points to.
        confidence: Score attached to the relation; defaults to ``1.0``.
        context: Surrounding text of the relation; defaults to ``""``.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    subject: Entity
    predicate: str
    obj: Entity
    confidence: float = 1.0
    context: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ExtractionResult:
    """Everything produced by one extraction pass over a text.

    Attributes:
        entities: Extracted entities; empty list by default.
        relations: Extracted relations; empty list by default.
        language: Language code of the pass; defaults to ``"en"``.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    entities: List[Entity] = field(default_factory=list)
    relations: List[Relation] = field(default_factory=list)
    language: str = "en"
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each spaCy NER label to one of the application's entity types
# ("person", "organization", "geo", "event" or "category").
SPACY_TO_APP_ENTITY_TYPE: Dict[str, str] = {
    "PERSON": "person",
    "ORG": "organization",
    # Place-like labels all collapse to "geo".
    "GPE": "geo",
    "LOC": "geo",
    "FAC": "geo",
    "EVENT": "event",
    "PRODUCT": "category",
    "WORK_OF_ART": "category",
    "LAW": "category",
    "LANGUAGE": "category",
    "NORP": "category",
    "MONEY": "category",
    "QUANTITY": "category",
    # Temporal labels are mapped to "event".
    "TIME": "event",
    "DATE": "event",
    "PERCENT": "category",
    "CARDINAL": "category",
    "ORDINAL": "category",
}

# Labels named as skippable; both still have an entry in the mapping above.
# NOTE(review): presumably consulted by callers to drop these labels from NER
# output -- not referenced within this module, confirm against callers.
SKIP_SPACY_LABELS = {"ORDINAL", "CARDINAL"}
|
||||||
@@ -17,7 +17,7 @@
|
|||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
import pytest
|
import pytest
|
||||||
from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp, datetime_format, delta_seconds
|
from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp, datetime_format, delta_seconds, format_iso_8601_to_ymd_hms
|
||||||
|
|
||||||
|
|
||||||
class TestCurrentTimestamp:
|
class TestCurrentTimestamp:
|
||||||
@@ -650,4 +650,65 @@ class TestDeltaSeconds:
|
|||||||
# If we're testing on the first day of month
|
# If we're testing on the first day of month
|
||||||
date_string = "2024-01-31 12:00:00" # Use a known past date
|
date_string = "2024-01-31 12:00:00" # Use a known past date
|
||||||
result = delta_seconds(date_string)
|
result = delta_seconds(date_string)
|
||||||
assert result > 0
|
assert result > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.p2
class TestTimestampToDateCurrentTimeFallback:
    """Value-level regression tests for the None/empty fallback of timestamp_to_date.

    The function's docstring promises "If None or empty, uses current time",
    but the fallback assigned ``time.time()`` (seconds) and then divided by
    1000 again, producing a date around 1970-01-21 instead of now. The
    existing ``test_return_type_always_string`` only checks the return type,
    so it never caught this. These tests pin the behaviour by value.
    """

    def test_none_uses_current_time(self, monkeypatch):
        """None input must resolve to current_timestamp() fallback."""
        frozen_ms = 1704067200123
        monkeypatch.setattr("common.time_utils.current_timestamp", lambda: frozen_ms)
        actual = timestamp_to_date(None)
        expected = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(frozen_ms / 1000))
        assert actual == expected

    def test_empty_string_uses_current_time(self, monkeypatch):
        """Empty-string input must resolve to current_timestamp() fallback."""
        frozen_ms = 1704067200123
        monkeypatch.setattr("common.time_utils.current_timestamp", lambda: frozen_ms)
        actual = timestamp_to_date("")
        expected = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(frozen_ms / 1000))
        assert actual == expected

    def test_zero_timestamp_is_not_treated_as_empty(self):
        """Zero timestamp should map to Unix epoch, not fallback to current time."""
        actual = timestamp_to_date(0)
        epoch = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(0))
        assert actual == epoch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.p2
class TestFormatIso8601ToYmdHms:
    """Checks for format_iso_8601_to_ymd_hms across ISO 8601 input variants."""

    def test_standard_utc_z(self):
        """Standard UTC timestamp with trailing Z."""
        formatted = format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
        assert formatted == "2024-01-01 12:00:00"

    def test_explicit_utc_offset(self):
        """Timestamp with an explicit +00:00 offset."""
        formatted = format_iso_8601_to_ymd_hms("2024-01-01T12:00:00+00:00")
        assert formatted == "2024-01-01 12:00:00"

    def test_ordinal_date_extended(self):
        """ISO 8601 ordinal date (day-of-year), extended form.

        dateutil.isoparse accepts it but datetime.fromisoformat rejects it,
        which previously made the function silently return the input unchanged.
        """
        formatted = format_iso_8601_to_ymd_hms("2024-001T12:00:00Z")
        assert formatted == "2024-01-01 12:00:00"

    def test_ordinal_date_basic(self):
        """ISO 8601 ordinal date (day-of-year), basic form."""
        formatted = format_iso_8601_to_ymd_hms("2024001T120000Z")
        assert formatted == "2024-01-01 12:00:00"

    def test_invalid_string_returns_original(self):
        """Unparseable input is returned unchanged."""
        unparseable = "not-a-date"
        assert format_iso_8601_to_ymd_hms(unparseable) == "not-a-date"
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
|
// src/components/ui/modal.tsx
|
||||||
|
import React, { FC, ReactNode, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { cn } from '@/lib/utils';
|
import { cn } from '@/lib/utils';
|
||||||
import * as DialogPrimitive from '@radix-ui/react-dialog';
|
import * as DialogPrimitive from '@radix-ui/react-dialog';
|
||||||
import { AlertCircle, CheckCircle, Info, Loader, X } from 'lucide-react';
|
import { AlertCircle, CheckCircle, Info, Loader, X } from 'lucide-react';
|
||||||
import React, { FC, ReactNode, useCallback, useEffect, useMemo } from 'react';
|
|
||||||
import { createRoot } from 'react-dom/client';
|
import { createRoot } from 'react-dom/client';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { DialogDescription } from '../dialog';
|
import { DialogDescription } from '../dialog';
|
||||||
|
|||||||
Reference in New Issue
Block a user