mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 01:29:35 +08:00
Add spacy based ner and relationship extractor for both python and Go version with equivalent outputs (#16456)
As title
This commit is contained in:
@@ -133,6 +133,10 @@ set_target_properties(rag_tokenizer PROPERTIES
|
||||
add_library(rag_tokenizer_c_api STATIC
|
||||
rag_analyzer_c_api.cpp
|
||||
rag_analyzer_c_api.h
|
||||
thinc_ner.cpp
|
||||
thinc_ner.h
|
||||
thinc_parser.cpp
|
||||
thinc_parser.h
|
||||
rag_analyzer.cpp
|
||||
rag_analyzer.h
|
||||
dart_trie.h
|
||||
|
||||
@@ -99,6 +99,31 @@ char* RAGAnalyzer_GetTermTag(RAGAnalyzerHandle handle, const char* term);
|
||||
// Returns: handle to the new analyzer instance, or NULL on failure
|
||||
RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Named Entity Recognition (spaCy model inference)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Create a ThincNER inference handle.
|
||||
// model_ner_dir: path to the spaCy model's ner/ component directory
|
||||
// model_vocab_dir: path to the spaCy model's vocab/ directory (optional, may be NULL;
//                  the current implementation accepts it but does not read it)
|
||||
// Returns: handle, or NULL on failure
|
||||
RAGAnalyzerHandle ThincNER_Create(const char* model_ner_dir, const char* model_vocab_dir);
|
||||
|
||||
// Destroy a ThincNER handle.
|
||||
void ThincNER_Destroy(RAGAnalyzerHandle handle);
|
||||
|
||||
// Run NER on pre-tokenized text.
|
||||
// tokens_json: JSON array e.g. ["Apple","Inc.","was","founded","by","Steve","Jobs","."]
|
||||
// Returns JSON array of entities, caller must free with ThincNER_FreeString.
|
||||
char* ThincNER_Predict(RAGAnalyzerHandle handle, const char* tokens_json);
|
||||
|
||||
// Tokenize text using spaCy-compatible rules.
|
||||
// Returns JSON array of token strings, caller must free with ThincNER_FreeString.
|
||||
char* ThincNER_Tokenize(const char* text, const char* lang);
|
||||
|
||||
// Free a string returned by ThincNER_Predict or ThincNER_Tokenize.
|
||||
void ThincNER_FreeString(char* ptr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
610
internal/cpp/thinc_ner.cpp
Normal file
610
internal/cpp/thinc_ner.cpp
Normal file
@@ -0,0 +1,610 @@
|
||||
#pragma STDC FP_CONTRACT OFF
|
||||
|
||||
#include "thinc_ner.h"
|
||||
|
||||
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
|
||||
|
||||
// =========================================================================
|
||||
// JSON parser (minimal)
|
||||
// =========================================================================
|
||||
namespace {
|
||||
// Returns `s` without leading/trailing blanks (space, tab, CR, LF).
// An empty or all-blank input yields an empty string.
std::string trim(const std::string& s) {
    static const char* const kBlank = " \t\r\n";
    const std::string::size_type first = s.find_first_not_of(kBlank);
    if (first == std::string::npos) {
        return "";
    }
    const std::string::size_type last = s.find_last_not_of(kBlank);
    return s.substr(first, last - first + 1);
}
|
||||
// Minimal JSON value: a tagged union kept as separate members for simplicity.
//   STR  -> `str`
//   NUM  -> `num`
//   BOOL -> `str` holds "true" or "false"
//   ARR  -> `arr`
//   OBJ  -> `obj`
//   NUL  -> no payload
struct JVal {
    enum Type { NUL, OBJ, ARR, STR, NUM, BOOL } type = NUL;
    std::string str;
    std::vector<JVal> arr;
    std::unordered_map<std::string, JVal> obj;
    double num = 0;

    // Look up an object member; returns nullptr when the key is absent.
    const JVal* get(const std::string& k) const {
        auto it = obj.find(k);
        return it != obj.end() ? &it->second : nullptr;
    }
    int as_int() const { return (int)num; }
    int64_t as_i64() const { return (int64_t)num; }
};

// Minimal, lenient recursive-descent JSON parser.
// It never throws on malformed input; it returns whatever it managed to read.
// `p` is the cursor and `e` is one-past-the-end of the buffer being parsed.
struct JParser {
    const char* p = nullptr;
    const char* e = nullptr;

    // Peek the next non-whitespace character (0 at end of input).
    char pk() {
        while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p;
        return p < e ? *p : 0;
    }
    // Consume and return the next non-whitespace character (0 at end of input).
    char nx() {
        while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p;
        return p < e ? *p++ : 0;
    }

    // Parse any value, dispatching on its first character.
    JVal pv() {
        const char c = pk();
        if (c == '{') return po();
        if (c == '[') return pa();
        if (c == '"') return ps();
        if (c == 't' || c == 'f') return pb();
        if (c == 'n') {  // "null": skip its four characters
            nx(); nx(); nx(); nx();
            return JVal{};
        }
        return pn();
    }

    // Parse an object: { "key": value, ... }
    JVal po() {
        JVal v;
        v.type = JVal::OBJ;
        nx();  // '{'
        while (pk() != '}') {
            auto k = ps();
            nx();  // ':'
            v.obj[k.str] = pv();
            if (pk() == ',') nx(); else break;
        }
        nx();  // '}'
        return v;
    }

    // Parse an array: [ value, ... ]
    JVal pa() {
        JVal v;
        v.type = JVal::ARR;
        nx();  // '['
        while (pk() != ']') {
            v.arr.push_back(pv());
            if (pk() == ',') nx(); else break;
        }
        nx();  // ']'
        return v;
    }

    // Read exactly four hex digits at the cursor; advances only on success.
    bool hex4(unsigned& out) {
        if (e - p < 4) return false;
        unsigned acc = 0;
        for (int i = 0; i < 4; ++i) {
            const char c = p[i];
            acc <<= 4;
            if (c >= '0' && c <= '9') acc |= (unsigned)(c - '0');
            else if (c >= 'a' && c <= 'f') acc |= (unsigned)(c - 'a' + 10);
            else if (c >= 'A' && c <= 'F') acc |= (unsigned)(c - 'A' + 10);
            else return false;
        }
        p += 4;
        out = acc;
        return true;
    }

    // Append code point `cp` to `out` as UTF-8.
    static void put_utf8(std::string& out, unsigned cp) {
        if (cp < 0x80) {
            out += (char)cp;
        } else if (cp < 0x800) {
            out += (char)(0xC0 | (cp >> 6));
            out += (char)(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            out += (char)(0xE0 | (cp >> 12));
            out += (char)(0x80 | ((cp >> 6) & 0x3F));
            out += (char)(0x80 | (cp & 0x3F));
        } else {
            out += (char)(0xF0 | (cp >> 18));
            out += (char)(0x80 | ((cp >> 12) & 0x3F));
            out += (char)(0x80 | ((cp >> 6) & 0x3F));
            out += (char)(0x80 | (cp & 0x3F));
        }
    }

    // Parse a string literal, decoding JSON escapes.
    // Previously every escape was copied verbatim minus the backslash, so
    // "\n" became 'n' and "\u00e9" became "u00e9". Unknown escapes still fall
    // back to the escaped character itself, which covers \" \\ and \/.
    JVal ps() {
        JVal v;
        v.type = JVal::STR;
        nx();  // opening quote
        while (p < e && *p != '"') {
            if (*p != '\\') {
                v.str += *p++;
                continue;
            }
            if (++p >= e) break;  // dangling backslash at end of input
            const char esc = *p++;
            switch (esc) {
                case 'n': v.str += '\n'; break;
                case 't': v.str += '\t'; break;
                case 'r': v.str += '\r'; break;
                case 'b': v.str += '\b'; break;
                case 'f': v.str += '\f'; break;
                case 'u': {
                    unsigned cp = 0;
                    if (!hex4(cp)) {  // malformed \u: keep the old lenient output
                        v.str += 'u';
                        break;
                    }
                    if (cp >= 0xD800 && cp <= 0xDBFF) {
                        // High surrogate: combine with a following \uDC00-\uDFFF.
                        const char* mark = p;
                        unsigned lo = 0;
                        if (e - p >= 6 && p[0] == '\\' && p[1] == 'u' && (p += 2, hex4(lo)) &&
                            lo >= 0xDC00 && lo <= 0xDFFF) {
                            cp = 0x10000 + ((cp - 0xD800) << 10) + (lo - 0xDC00);
                        } else {
                            p = mark;
                            cp = 0xFFFD;  // unpaired surrogate -> replacement character
                        }
                    } else if (cp >= 0xDC00 && cp <= 0xDFFF) {
                        cp = 0xFFFD;
                    }
                    put_utf8(v.str, cp);
                    break;
                }
                default: v.str += esc; break;
            }
        }
        if (p < e) ++p;  // closing quote
        return v;
    }

    // Parse a number: optional sign, digits, optional fraction and exponent.
    // Anything std::stod rejects yields 0.
    JVal pn() {
        JVal v;
        v.type = JVal::NUM;
        auto s = p;
        if (p < e && *p == '-') ++p;
        while (p < e && (*p >= '0' && *p <= '9')) ++p;
        if (p < e && *p == '.') {
            ++p;
            while (p < e && (*p >= '0' && *p <= '9')) ++p;
        }
        if (p < e && (*p == 'e' || *p == 'E')) {
            ++p;
            if (p < e && (*p == '+' || *p == '-')) ++p;
            while (p < e && (*p >= '0' && *p <= '9')) ++p;
        }
        if (s < p) {
            try { v.num = std::stod(std::string(s, p - s)); } catch (...) { v.num = 0; }
        }
        return v;
    }

    // Parse true/false; the literal's spelling is stored in `str`.
    JVal pb() {
        JVal v;
        v.type = JVal::BOOL;
        if (e - p >= 4 && *p == 't') { v.str = "true"; p += 4; }
        else if (e - p >= 5 && *p == 'f') { v.str = "false"; p += 5; }
        return v;
    }

    // Parse a whole document. `j` must outlive the call (the cursor points into it).
    JVal parse(const std::string& j) {
        p = j.data();
        e = p + j.size();
        return pv();
    }
};
|
||||
|
||||
// =========================================================================
|
||||
// MurmurHash2 64-bit (vocab string→ID, seed=0 matching spaCy StringStore)
|
||||
// =========================================================================
|
||||
// MurmurHash2, 64-bit variant "A": hashes `len` bytes at `key` with `seed`.
// Whole 8-byte words are mixed first; the 1..7 leftover bytes are folded in
// little-endian order and mixed once.
static uint64_t mh2_64a(const void* key, int len, uint64_t seed) {
    const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
    const int kShift = 47;

    const uint8_t* cursor = (const uint8_t*)key;
    uint64_t h = seed ^ (uint64_t(len) * kMul);

    int left = len;
    for (; left >= 8; left -= 8, cursor += 8) {
        uint64_t k;
        memcpy(&k, cursor, 8);  // unaligned-safe load
        k *= kMul;
        k ^= k >> kShift;
        k *= kMul;
        h ^= k;
        h *= kMul;
    }

    if (left > 0) {
        for (int b = left - 1; b >= 0; --b) {
            h ^= uint64_t(cursor[b]) << (8 * b);
        }
        h *= kMul;
    }

    h ^= h >> kShift;
    h *= kMul;
    h ^= h >> kShift;
    return h;
}
|
||||
// Feature-string -> 64-bit id. The empty string maps to id 0; everything else
// is MurmurHash64A over the raw bytes with seed 0.
// NOTE(review): spaCy's StringStore is believed to hash with seed 1 — confirm
// that seed 0 reproduces the ids the exported model was trained with.
static uint64_t hash_feat(const std::string& s) {
    if (s.empty()) {
        return 0;
    }
    return mh2_64a(s.data(), (int)s.size(), 0);
}
|
||||
|
||||
// =========================================================================
|
||||
// MurmurHash3_x64_128 (exact copy from mmh3 package, verified against thinc)
|
||||
// =========================================================================
|
||||
// 64-bit rotate-left; callers pass 0 < r < 64.
#define ROTL64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))

// Load the i-th 64-bit word starting at p without assuming alignment.
static uint64_t getblock64(const uint64_t* p, size_t i) {
    uint64_t word;
    memcpy(&word, p + i, 8);
    return word;
}

// MurmurHash3 64-bit finalizer (avalanche step).
static uint64_t fmix64(uint64_t k) {
    k ^= k >> 33;
    k *= 0xff51afd7ed558ccdULL;
    k ^= k >> 33;
    k *= 0xc4ceb9fe1a85ec53ULL;
    k ^= k >> 33;
    return k;
}

// MurmurHash3_x64_128 over `len` bytes at `key`.
// The 128-bit digest is written to `out` as four 32-bit words:
// low/high halves of h1, then low/high halves of h2.
static void mmh3_x64_128(const void* key, int len, uint32_t seed, uint32_t out[4]) {
    const uint64_t c1 = 0x87c37b91114253d5ULL;
    const uint64_t c2 = 0x4cf5ad432745937fULL;

    const uint8_t* data = (const uint8_t*)key;
    const int nblocks = len / 16;
    uint64_t h1 = seed;
    uint64_t h2 = seed;

    // Body: 16 bytes per round.
    const uint64_t* blocks = (const uint64_t*)(data);
    for (int blk = 0; blk < nblocks; ++blk) {
        uint64_t k1 = getblock64(blocks, blk * 2 + 0);
        uint64_t k2 = getblock64(blocks, blk * 2 + 1);

        k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
        h1 = ROTL64(h1, 27); h1 += h2; h1 = h1 * 5 + 0x52dce729;

        k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
        h2 = ROTL64(h2, 31); h2 += h1; h2 = h2 * 5 + 0x38495ab5;
    }

    // Tail: up to 15 leftover bytes; bytes 8..14 feed k2, bytes 0..7 feed k1.
    const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
    const int rem = len & 15;
    uint64_t k1 = 0;
    uint64_t k2 = 0;
    if (rem > 8) {
        for (int b = rem - 1; b >= 8; --b) {
            k2 ^= ((uint64_t)tail[b]) << ((b - 8) * 8);
        }
        k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
    }
    if (rem > 0) {
        const int top = rem < 8 ? rem : 8;
        for (int b = top - 1; b >= 0; --b) {
            k1 ^= ((uint64_t)tail[b]) << (b * 8);
        }
        k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
    }

    // Finalization.
    h1 ^= len;
    h2 ^= len;
    h1 += h2;
    h2 += h1;
    h1 = fmix64(h1);
    h2 = fmix64(h2);
    h1 += h2;
    h2 += h1;

    out[0] = (uint32_t)h1;
    out[1] = (uint32_t)(h1 >> 32);
    out[2] = (uint32_t)h2;
    out[3] = (uint32_t)(h2 >> 32);
}
|
||||
|
||||
// =========================================================================
|
||||
// HashEmbed
|
||||
// =========================================================================
|
||||
struct HashEmbed {
|
||||
int n_rows=0,nO=0; uint32_t seed=0; std::vector<float> table;
|
||||
bool load(int r, int o, const float* d) { n_rows=r;nO=o;table.assign(d,d+(size_t)r*o);return!table.empty(); }
|
||||
void embed(uint64_t fid, float* out) const {
|
||||
uint8_t in[8]; for(int i=0;i<8;i++)in[i]=(uint8_t)(fid>>(i*8));
|
||||
uint32_t keys[4]; mmh3_x64_128(in,8,seed,keys);
|
||||
for(int v=0;v<4;v++){int idx=(int)(keys[v]%(uint32_t)n_rows);for(int i=0;i<nO;i++)out[i]+=table[(size_t)idx*nO+i];}
|
||||
}
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Features — dynamic based on n_embed
|
||||
// en (6): NORM, PREFIX, SUFFIX, SHAPE, SPACY, IS_SPACE
|
||||
// zh (5): NORM, PREFIX, SUFFIX, SHAPE, IS_SPACE
|
||||
// =========================================================================
|
||||
static uint64_t feat_norm(const std::string& t) {
|
||||
std::string lo=t; std::transform(lo.begin(),lo.end(),lo.begin(),::tolower);
|
||||
return hash_feat(lo);
|
||||
}
|
||||
// UTF-8 aware: get first Unicode codepoint as string
|
||||
// Returns the bytes of the first UTF-8 sequence in `s`.
// The length comes from the lead byte alone; if the string is shorter than the
// announced length, only the first byte is returned.
static std::string utf8_first(const std::string& s) {
    if (s.empty()) {
        return "";
    }
    const unsigned char lead = (unsigned char)s.front();
    size_t width = 1;
    if ((lead >> 5) == 0x06) {
        width = 2;
    } else if ((lead >> 4) == 0x0E) {
        width = 3;
    } else if ((lead >> 3) == 0x1E) {
        width = 4;
    }
    if (width > s.size()) {
        width = 1;
    }
    return s.substr(0, width);
}
|
||||
// Count Unicode codepoints in a string
|
||||
// Number of UTF-8 sequences in `s`, counted by lead byte.
// Stray continuation bytes and invalid lead bytes count as one unit each.
static size_t utf8_len(const std::string& s) {
    size_t units = 0;
    size_t at = 0;
    while (at < s.size()) {
        const unsigned char lead = (unsigned char)s[at];
        size_t step = 1;
        if ((lead & 0x80) != 0) {
            if ((lead & 0xE0) == 0xC0) step = 2;
            else if ((lead & 0xF0) == 0xE0) step = 3;
            else if ((lead & 0xF8) == 0xF0) step = 4;
        }
        at += step;
        ++units;
    }
    return units;
}
|
||||
// Get suffix: last `count` Unicode codepoints
|
||||
// Returns the last `count` UTF-8 sequences of `s`, or all of `s` when it holds
// `count` or fewer. Sequence boundaries are derived from lead bytes only.
static std::string utf8_last(const std::string& s, size_t count) {
    // bounds[k] is the byte offset reached after stepping over k sequences.
    std::vector<size_t> bounds(1, 0);
    size_t at = 0;
    while (at < s.size()) {
        const unsigned char lead = (unsigned char)s[at];
        size_t step = 1;
        if ((lead & 0xE0) == 0xC0) step = 2;
        else if ((lead & 0xF0) == 0xE0) step = 3;
        else if ((lead & 0xF8) == 0xF0) step = 4;
        at += step;
        bounds.push_back(at);
    }
    const size_t total = bounds.size() - 1;
    if (total <= count) {
        return s;
    }
    return s.substr(bounds[total - count]);
}
|
||||
static uint64_t feat_prefix(const std::string& t) {
|
||||
// spaCy: string[:1].lower() → hash the lowercased prefix
|
||||
std::string p = t.empty() ? "" : utf8_first(t);
|
||||
std::transform(p.begin(), p.end(), p.begin(), ::tolower);
|
||||
return hash_feat(p);
|
||||
}
|
||||
static uint64_t feat_suffix(const std::string& t) {
|
||||
// spaCy: string[-3:].lower() → hash the lowercased suffix
|
||||
size_t ulen=utf8_len(t);
|
||||
std::string s = ulen>=3 ? utf8_last(t,3) : t;
|
||||
std::transform(s.begin(), s.end(), s.begin(), ::tolower);
|
||||
return hash_feat(s);
|
||||
}
|
||||
static uint64_t feat_shape(const std::string& t) {
|
||||
std::string sh;
|
||||
for(unsigned char c:t){
|
||||
if(c>0x7F)sh+='x'; // CJK → 'x' (matches spaCy zh shape)
|
||||
else if(std::isupper(c))sh+='X';
|
||||
else if(std::islower(c))sh+='x';
|
||||
else if(std::isdigit(c))sh+='d';
|
||||
else sh+=c;
|
||||
}
|
||||
return hash_feat(sh);
|
||||
}
|
||||
// Extract features based on n_embed count. Returns vector of hash values.
|
||||
// NER model's tok2vec uses 4 features: NORM, PREFIX, SUFFIX, SHAPE
|
||||
// (The pipeline's standalone tok2vec uses 6 features including SPACY and IS_SPACE.)
|
||||
// Feature order matches the HashEmbed table order in the model.
|
||||
// Builds the per-token feature ids in embedding-table order.
// Always emits NORM, PREFIX, SUFFIX, SHAPE. With exactly 5 tables a constant
// IS_SPACE=0 follows; with 6 or more, SPACY=1 then IS_SPACE=0 follow.
// The result therefore has 4, 5 or 6 entries regardless of larger n_embed.
static std::vector<uint64_t> extract_features(const std::string& t, int n_embed) {
    std::vector<uint64_t> ids;
    ids.reserve(6);
    ids.push_back(feat_norm(t));    // #0 NORM
    ids.push_back(feat_prefix(t));  // #1 PREFIX
    ids.push_back(feat_suffix(t));  // #2 SUFFIX
    ids.push_back(feat_shape(t));   // #3 SHAPE

    if (n_embed >= 6) {
        ids.push_back(1);  // #4 SPACY
        ids.push_back(0);  // #5 IS_SPACE
    } else if (n_embed == 5) {
        ids.push_back(0);  // #4 IS_SPACE (no SPACY table)
    }
    return ids;
}
|
||||
|
||||
// =========================================================================
|
||||
// Layers
|
||||
// =========================================================================
|
||||
|
||||
// Kahan compensated dot product: reduces floating-point accumulation error
|
||||
// for long dot products (e.g. 576 terms in Maxout).
|
||||
// Dot product of a[0..n) and b[0..n) with Kahan compensated summation.
// `carry` holds the rounding error of the previous addition and is subtracted
// from the next term.
static float kahan_dot(const float* a, const float* b, int n) {
    float total = 0.0f;
    float carry = 0.0f;
    int k = 0;
    while (k < n) {
        const float term = a[k] * b[k] - carry;
        const float next = total + term;
        carry = (next - total) - term;
        total = next;
        ++k;
    }
    return total;
}
|
||||
|
||||
static void linear(float* out, const float* in, const float* W, const float* b, int nO, int nI) {
|
||||
for(int i=0;i<nO;i++)out[i]=b[i]+kahan_dot(W+(size_t)i*nI,in,nI);
|
||||
}
|
||||
// In-place ReLU: every element that is not strictly positive becomes 0.
static void relu_inplace(float* x, int n) {
    for (float* cur = x; cur < x + (n > 0 ? n : 0); ++cur) {
        if (!(*cur > 0)) {
            *cur = 0;
        }
    }
}
|
||||
|
||||
// Maxout: y[i] = max_p(b[i,p] + W[i,p,:] @ in)
|
||||
static void maxout(float* out, const float* in, const float* W, const float* b, int nO, int nP, int nI) {
|
||||
for(int i=0;i<nO;i++){
|
||||
float best=-1e30f;
|
||||
for(int p=0;p<nP;p++){
|
||||
float s = b[(size_t)i*nP+p] + kahan_dot(W+(((size_t)i*nP+p)*nI), in, nI);
|
||||
if(s>best)best=s;
|
||||
}
|
||||
out[i]=best;
|
||||
}
|
||||
}
|
||||
|
||||
// LayerNorm: y = G * (x-mean)/sqrt(var+eps) + b
|
||||
// Layer normalization over a single vector of length d:
//   out[i] = G[i] * (in[i] - mean) / sqrt(var + eps) + b[i]
// using the biased (divide-by-d) variance.
static void layernorm(float* out, const float* in, int d, const float* G, const float* b, float eps) {
    float mean = 0;
    for (int i = 0; i < d; ++i) {
        mean += in[i];
    }
    mean /= d;

    float var = 0;
    for (int i = 0; i < d; ++i) {
        const float centered = in[i] - mean;
        var += centered * centered;
    }
    var /= d;

    const float inv_std = 1.0f / sqrtf(var + eps);
    for (int i = 0; i < d; ++i) {
        const float centered = in[i] - mean;
        out[i] = G[i] * centered * inv_std + b[i];
    }
}
|
||||
|
||||
// ExpandWindow: for token at index `idx` over all_tokens[n_tokens×dim], produce [t-1, t, t+1]
|
||||
// Window expansion with radius 1: writes [row idx-1 | row idx | row idx+1] of
// the n x dim matrix `all` into out[0 .. 3*dim). Neighbours outside the matrix
// are zero-filled.
static void expand_win(float* out, const float* all, int n, int dim, int idx) {
    const size_t row_bytes = dim * sizeof(float);
    float* const left = out;
    float* const centre = out + dim;
    float* const right = out + 2 * dim;

    if (idx > 0) {
        memcpy(left, all + (idx - 1) * dim, row_bytes);
    } else {
        memset(left, 0, row_bytes);
    }

    memcpy(centre, all + idx * dim, row_bytes);

    if (idx < n - 1) {
        memcpy(right, all + (idx + 1) * dim, row_bytes);
    } else {
        memset(right, 0, row_bytes);
    }
}
|
||||
|
||||
// BILUO decoder
|
||||
// One decoded entity. `start` and `end` are inclusive token indices; `text`
// is the covered tokens joined by single spaces.
struct Entity {
    std::string text;
    std::string label;
    int start;
    int end;
    float conf;
};

// Turns per-token BILUO tags ("B-ORG", "I-ORG", "L-ORG", "U-ORG", "O") into
// entity spans. Decoding is lenient:
//   - "O", empty and malformed tags close any open entity;
//   - an "I-" tag with no matching open entity starts a new one;
//   - an "L-" tag with no matching open entity yields a single-token entity
//     and discards whatever entity was open;
//   - an entity still open at the end is closed on the last token.
// Every entity gets the fixed confidence 0.85. `lbl` must have at least
// tok.size() entries.
static std::vector<Entity> decode_biluo(const std::vector<std::string>& tok, const std::vector<std::string>& lbl) {
    const float kConf = 0.85f;
    const int count = (int)tok.size();

    std::vector<Entity> found;
    int open_at = -1;  // start token of the entity being built, -1 if none
    std::string open_type;
    std::string open_text;

    auto close_open = [&](int last) {
        if (open_at >= 0) {
            found.push_back({open_text, open_type, open_at, last, kConf});
        }
    };
    auto forget = [&]() {
        open_at = -1;
        open_type.clear();
        open_text.clear();
    };
    auto begin_at = [&](int at, const std::string& type) {
        open_at = at;
        open_type = type;
        open_text = tok[at];
    };

    for (int i = 0; i < count; ++i) {
        const std::string& tag = lbl[i];
        const bool outside = tag.empty() || tag == "O";
        if (outside || tag.size() < 3 || tag[1] != '-') {
            close_open(i - 1);
            forget();
            continue;
        }

        const std::string type = tag.substr(2);
        const bool extends_open = open_at >= 0 && open_type == type;
        switch (tag[0]) {
            case 'U':
                close_open(i - 1);
                forget();
                found.push_back({tok[i], type, i, i, kConf});
                break;
            case 'B':
                close_open(i - 1);
                begin_at(i, type);
                break;
            case 'I':
                if (extends_open) {
                    open_text += " " + tok[i];
                } else {
                    close_open(i - 1);
                    begin_at(i, type);
                }
                break;
            case 'L':
                if (extends_open) {
                    open_text += " " + tok[i];
                    close_open(i);
                } else {
                    found.push_back({tok[i], type, i, i, kConf});
                }
                forget();
                break;
            default:
                break;  // unknown prefix: leave state untouched
        }
    }

    close_open(count - 1);
    return found;
}
|
||||
|
||||
// Tokenizer
|
||||
// Byte-oriented tokenizer for Latin-script text.
//   - ASCII letters, digits and every byte above 127 extend the current word;
//   - a '.' stays inside a word when a letter follows immediately ("U.S.A");
//   - any other non-space byte ends the word and becomes its own token;
//   - whitespace only ends the word.
static std::vector<std::string> tokenize_en(const std::string& t) {
    std::vector<std::string> out;
    std::string word;
    auto flush = [&]() {
        if (!word.empty()) {
            out.push_back(word);
            word.clear();
        }
    };

    const size_t len = t.size();
    for (size_t i = 0; i < len; ++i) {
        const unsigned char c = (unsigned char)t[i];
        if (std::isalpha(c) || std::isdigit(c) || c > 127) {
            word.push_back((char)c);
            continue;
        }
        const bool inner_dot = c == '.' && !word.empty() && i + 1 < len &&
                               std::isalpha((unsigned char)t[i + 1]);
        if (inner_dot) {
            word.push_back('.');
            continue;
        }
        flush();
        if (!std::isspace(c)) {
            out.push_back(std::string(1, (char)c));
        }
    }
    flush();
    return out;
}
|
||||
// Tokenizer for CJK text: every non-ASCII UTF-8 sequence is its own token,
// runs of ASCII letters/digits form one token, other non-space ASCII bytes are
// single-byte tokens, and whitespace is dropped. Sequence length is taken from
// the lead byte; a truncated sequence at the end yields the bytes available.
static std::vector<std::string> tokenize_zh(const std::string& t) {
    std::vector<std::string> out;
    const size_t len = t.size();
    auto is_word_byte = [&](size_t at) {
        const unsigned char b = (unsigned char)t[at];
        return std::isalpha(b) || std::isdigit(b);
    };

    size_t i = 0;
    while (i < len) {
        const unsigned char c = (unsigned char)t[i];
        if ((c & 0x80) != 0) {
            size_t width = 1;
            if ((c & 0xE0) == 0xC0) width = 2;
            else if ((c & 0xF0) == 0xE0) width = 3;
            else if ((c & 0xF8) == 0xF0) width = 4;
            out.push_back(t.substr(i, width));
            i += width;
            continue;
        }
        if (is_word_byte(i)) {
            size_t stop = i;
            while (stop < len && is_word_byte(stop)) {
                ++stop;
            }
            out.push_back(t.substr(i, stop - i));
            i = stop;
            continue;
        }
        if (!std::isspace(c)) {
            out.push_back(std::string(1, (char)c));
        }
        ++i;
    }
    return out;
}
|
||||
} // namespace
|
||||
|
||||
// =========================================================================
|
||||
// State
|
||||
// =========================================================================
|
||||
struct State {
|
||||
// HashEmbeds
|
||||
std::vector<HashEmbed> embeds;
|
||||
// Post-embed Maxout (576→96)
|
||||
std::vector<float> poW,poB; int po_nO=96,po_nP=3,po_nI=576;
|
||||
// Post-embed LayerNorm
|
||||
std::vector<float> poG,poB2; bool has_poLN=false;
|
||||
// Residual encoder (4 blocks)
|
||||
struct ResBlk{bool has=false;std::vector<float>W,b,lnG,lnb;};
|
||||
ResBlk res[4]; int n_res=0;
|
||||
// NER hidden (96→64)
|
||||
std::vector<float> hW,hB; int hO=64; bool has_hid=false;
|
||||
// PrecomputableAffine: W_full[nP=3][nO=64][nI=2][nD=64], b_full[nO=64][nI=2]
|
||||
// We use f=0 (first feature only): pre_out[p][o] = sum_d W[p][o][0][d] * hid[d] + b[o][0]
|
||||
std::vector<float> pW_full; // flattened [3*64*2*64]
|
||||
std::vector<float> pB_full; // flattened [64*2]
|
||||
int p_nP=3, p_nO=64, p_nI=2, p_nD=64; bool has_pre=false;
|
||||
// Classifier (64→n_actions)
|
||||
std::vector<float> cW,cB; int nAct=0; bool has_cls=false;
|
||||
std::vector<std::string> actLbl;
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Load model
|
||||
// =========================================================================
|
||||
static bool load(const std::string& dir, State* s) {
|
||||
std::ifstream cf(dir+"/model.ckpt"); if(!cf){std::cerr<<"No model.ckpt\n";return false;}
|
||||
std::stringstream cb;cb<<cf.rdbuf();
|
||||
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
|
||||
|
||||
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
|
||||
size_t bz=bf.tellg();bf.seekg(0); std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
|
||||
|
||||
auto sl=[&](int64_t o, int64_t c)->std::vector<float>{
|
||||
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
|
||||
};
|
||||
auto ld=[&](const std::string& k, std::vector<float>* v, int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
|
||||
auto* e=ck.get(k); if(!e)return false;
|
||||
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||
if(!sv||!ov||!cv)return false;
|
||||
*v=sl(ov->as_i64(),cv->as_i64());
|
||||
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||
return!v->empty();
|
||||
};
|
||||
|
||||
// HashEmbeds — dynamic count (6 for en, 5 for zh, etc.)
|
||||
for(int ei=0;;ei++){
|
||||
auto* e=ck.get("embed_"+std::to_string(ei)+"_E"); if(!e)break;
|
||||
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||
if(!sv||!ov||!cv)break;
|
||||
int rs=sv->arr[0].as_int(),nO=sv->arr[1].as_int();
|
||||
int64_t expected=(int64_t)rs*nO;
|
||||
if(cv->as_i64()<expected)break; // count too short → malformed
|
||||
auto d=sl(ov->as_i64(),cv->as_i64()); if(d.empty())break;
|
||||
s->embeds.emplace_back(); s->embeds.back().load(rs,nO,d.data());
|
||||
}
|
||||
// Seeds
|
||||
std::ifstream ff(dir+"/feature_config.json");
|
||||
if(ff){std::stringstream fb;fb<<ff.rdbuf();auto cfg=JParser().parse(fb.str());auto* sa=cfg.get("embed_seeds");
|
||||
if(sa&&sa->type==JVal::ARR)for(int i=0;i<(int)sa->arr.size()&&i<(int)s->embeds.size();i++)s->embeds[i].seed=(uint32_t)sa->arr[i].as_int();}
|
||||
|
||||
int r0=0,r1=0,r2=0;
|
||||
// Post-embed
|
||||
if(ld("poW",&s->poW,&r0,&r1,&r2)){s->po_nO=r0;s->po_nP=r1;s->po_nI=r2;ld("poB",&s->poB);}
|
||||
if(ld("poG",&s->poG)){ld("poB2",&s->poB2);s->has_poLN=true;}
|
||||
// Residual
|
||||
for(int ri=0;ri<4;ri++){auto pk="res"+std::to_string(ri);auto& rb=s->res[ri];
|
||||
if(ld(pk+"W",&rb.W,&r0,&r1,&r2)){ld(pk+"B",&rb.b);ld(pk+"lnG",&rb.lnG);ld(pk+"lnb",&rb.lnb);rb.has=true;s->n_res++;}}
|
||||
// NER hidden
|
||||
if(ld("hW",&s->hW,&r0,&r1)){s->hO=r0;ld("hB",&s->hB);s->has_hid=true;}
|
||||
// PrecomputableAffine: load full 4D W and 2D b
|
||||
// has_pre is only set when ALL of weight buffer, bias buffer,
|
||||
// and hidden-dimension match (p_nD == hO) are satisfied.
|
||||
{
|
||||
auto* e=ck.get("pW_full"); if(e){
|
||||
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||
if(sv&&ov&&cv){
|
||||
int nP=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||
int nO=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||
int nI=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||
int nD=sv->arr.size()>=4?sv->arr[3].as_int():1;
|
||||
s->p_nP=nP; s->p_nO=nO; s->p_nI=nI; s->p_nD=nD;
|
||||
size_t total = (size_t)nP * nO * nI * nD;
|
||||
s->pW_full = sl(ov->as_i64(), cv->as_i64());
|
||||
bool pw_ok = s->pW_full.size() >= total;
|
||||
|
||||
// Load bias inside pW_full block to access dimension info
|
||||
bool pb_ok = false;
|
||||
if(auto* pb_e=ck.get("pB_full")){
|
||||
auto pb_ov=pb_e->get("offset"),pb_cv=pb_e->get("count");
|
||||
if(pb_ov&&pb_cv){
|
||||
s->pB_full = sl(pb_ov->as_i64(), pb_cv->as_i64());
|
||||
pb_ok = s->pB_full.size() >= (size_t)nO * nI;
|
||||
}
|
||||
}
|
||||
|
||||
bool dim_ok = (nD == s->hO);
|
||||
s->has_pre = pw_ok && pb_ok && dim_ok;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Classifier
|
||||
if(ld("cW",&s->cW,&r0,&r1)){s->nAct=r0;ld("cB",&s->cB);s->has_cls=true;}
|
||||
|
||||
return!s->embeds.empty();
|
||||
}
|
||||
|
||||
static bool load_labels(const std::string& dir, State* s) {
|
||||
std::ifstream f(dir+"/labels.json"); if(!f)return false;
|
||||
std::stringstream b;b<<f.rdbuf();
|
||||
auto d=JParser().parse(b.str()); auto* am=d.get("action_to_label");
|
||||
if(!am||am->type!=JVal::OBJ)return false;
|
||||
int mx=0; for(auto&[k,v]:am->obj){try{int a=std::stoi(k);if(a>mx)mx=a;}catch(...){}};
|
||||
int n=s->nAct>0?s->nAct:mx+1; s->actLbl.resize(n,"O");
|
||||
for(auto&[k,v]:am->obj){try{int a=std::stoi(k);if(a>=0&&a<n)s->actLbl[a]=v.str;}catch(...){}}
|
||||
return!s->actLbl.empty();
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// C API
|
||||
// =========================================================================
|
||||
// Creates an inference handle from the model component directory `d`.
// The second argument (vocab directory) is accepted for API compatibility and
// is not read. Returns NULL when `d` is NULL, the model cannot be loaded, or
// any C++ exception occurs — exceptions must not cross the C ABI.
// When labels.json is missing, every action falls back to "O"; the table is
// sized to the classifier's action count (74 if that is unknown) so label
// lookups by action id stay in range.
ThincNERHandle ThincNER_Create(const char* d, const char*) {
    if (d == nullptr) return nullptr;
    try {
        std::unique_ptr<State> s(new State());
        if (!load(d, s.get())) return nullptr;
        if (!load_labels(d, s.get())) {
            const size_t fallback = s->nAct > 0 ? (size_t)s->nAct : 74;
            s->actLbl.resize(fallback, "O");
        }
        return s.release();
    } catch (...) {
        return nullptr;
    }
}
|
||||
// Releases a handle obtained from ThincNER_Create. NULL is accepted.
void ThincNER_Destroy(ThincNERHandle h) {
    State* state = (State*)h;
    delete state;
}
|
||||
|
||||
char* ThincNER_Predict(ThincNERHandle h, const char* tj) {
|
||||
auto* s=(State*)h; if(!s)return strdup("[]");
|
||||
if(!tj)return strdup("[]");
|
||||
|
||||
// Parse tokens
|
||||
std::vector<std::string> tok; std::string j(tj); size_t p=0;
|
||||
while((p=j.find('"',p))!=std::string::npos){auto e=j.find('"',p+1);if(e==std::string::npos)break;std::string t=j.substr(p+1,e-p-1);if(!t.empty())tok.push_back(t);p=e+1;}
|
||||
int n=(int)tok.size(); if(!n)return strdup("[]");
|
||||
int NE=(int)s->embeds.size();
|
||||
// Derive per-embed dimension from loaded tensors (all embed tables share the same nO)
|
||||
int D = NE > 0 ? s->embeds[0].nO : 96;
|
||||
int EC = NE * D;
|
||||
|
||||
// ---- Step 1: HashEmbed → concat (NER model: 4×96=384, pipe: 6×96=576) ----
|
||||
std::vector<float> emb((size_t)n*EC,0);
|
||||
for(int i=0;i<n;i++){
|
||||
auto ids=extract_features(tok[i],NE);
|
||||
size_t b=(size_t)i*EC;
|
||||
for(int e=0;e<NE&&e<(int)s->embeds.size();e++)
|
||||
s->embeds[e].embed(ids[e], emb.data()+b + (size_t)e*s->embeds[e].nO);
|
||||
}
|
||||
|
||||
|
||||
// ---- Step 2: Post-embed Maxout (576→96) ----
|
||||
std::vector<float> pe((size_t)n*D);
|
||||
for(int i=0;i<n;i++)maxout(pe.data()+(size_t)i*D,emb.data()+(size_t)i*EC,s->poW.data(),s->poB.data(),s->po_nO,s->po_nP,s->po_nI);
|
||||
|
||||
// ---- Step 3: Post-embed LayerNorm ----
|
||||
std::vector<float> pln((size_t)n*D,0);
|
||||
if(s->has_poLN){for(int i=0;i<n;i++)layernorm(pln.data()+(size_t)i*D,pe.data()+(size_t)i*D,D,s->poG.data(),s->poB2.data(),1e-6f);}
|
||||
else pln=pe;
|
||||
|
||||
// ---- Step 4: Residual encoder blocks ----
|
||||
std::vector<float> enc=pln;
|
||||
for(int ri=0;ri<s->n_res;ri++){
|
||||
auto& blk=s->res[ri]; if(!blk.has)continue;
|
||||
int wd=D*3;
|
||||
std::vector<float> exp((size_t)n*wd);
|
||||
for(int i=0;i<n;i++)expand_win(exp.data()+(size_t)i*wd,enc.data(),n,D,i);
|
||||
std::vector<float> mx((size_t)n*D);
|
||||
for(int i=0;i<n;i++)maxout(mx.data()+(size_t)i*D,exp.data()+(size_t)i*wd,blk.W.data(),blk.b.data(),D,3,wd);
|
||||
std::vector<float> ln((size_t)n*D);
|
||||
if(!blk.lnG.empty()){for(int i=0;i<n;i++)layernorm(ln.data()+(size_t)i*D,mx.data()+(size_t)i*D,D,blk.lnG.data(),blk.lnb.data(),1e-6f);}
|
||||
else ln=mx;
|
||||
for(int i=0;i<n;i++){float* op=enc.data()+(size_t)i*D;for(int j=0;j<D;j++)op[j]+=ln[(size_t)i*D+j];}
|
||||
}
|
||||
|
||||
// ---- Step 5: NER tok2vec linear (96→64, no ReLU) ----
|
||||
// The NER model's layers[0] ends with a bare linear (no activation).
|
||||
// This produces the 64-dim token vectors that feed into the PrecomputableAffine.
|
||||
std::vector<float> hid((size_t)n*s->hO);
|
||||
if(s->has_hid){for(int i=0;i<n;i++){linear(hid.data()+(size_t)i*s->hO,enc.data()+(size_t)i*D,s->hW.data(),s->hB.data(),s->hO,D);}}
|
||||
else{for(int i=0;i<n;i++){int c=std::min(D,s->hO);memcpy(hid.data()+i*s->hO,enc.data()+i*D,c*4);}}
|
||||
|
||||
// ---- Step 6: PrecomputableAffine → Maxout → Classifier → constrained decoding ----
|
||||
// Matches the spaCy ParserStepModel's predict_states formula:
|
||||
// cached[t][f][o*nP+p] = W[f,o,p,:] @ hid[t] + b[o,p] (f=0..nF-1, W=[nF,nO,nP,nI])
|
||||
// unmaxed[o,nP+p] = sum_f cached[t][f][o*nP+p] (sum over nF features)
|
||||
// unmaxed += bias[o,p] (add bias once, not nF times)
|
||||
// hid_vec[o] = max(unmaxed[o*nP+0], unmaxed[o*nP+1])
|
||||
// scores[a] = cW[a][:] @ hid_vec + cB[a]
|
||||
//
|
||||
// Feature token indices match spaCy's BiluoPushDown transition system:
|
||||
// f=0: B(0) = buffer front = current token index
|
||||
// f=1: S(0) = stack top = entity_start if in entity, else back-off to B(0)
|
||||
// f=2: S(1) = stack second = back-off to B(0) (stack has ≤1 item in simple case)
|
||||
auto label_type = [](const std::string& lbl) -> char { return lbl.empty() ? 'O' : lbl[0]; };
|
||||
auto label_etype = [](const std::string& lbl) -> std::string { return lbl.size()<3?"":lbl.substr(2); };
|
||||
|
||||
std::vector<std::string> tl(n, "O");
|
||||
if(s->has_cls && s->has_pre){
|
||||
int nF=s->p_nP, nO=s->p_nO, nP=s->p_nI, nD=s->p_nD; // W: [nF, nO, nP, nI]
|
||||
std::vector<float> unmaxed((size_t)nO * nP, 0);
|
||||
std::vector<float> hid_vec(nO, 0);
|
||||
std::vector<float> scores(s->nAct, 0);
|
||||
int entity_start = -1; // token index of current B-entity start, -1 = no entity
|
||||
|
||||
for(int i=0;i<n;i++){
|
||||
// Determine feature token indices from transition state
|
||||
// (matches spaCy BiluoPushDown B(0)/S(0)/S(1) mapping)
|
||||
int ft[3];
|
||||
ft[0] = i; // B(0) = current token
|
||||
if(entity_start >= 0) {
|
||||
ft[1] = entity_start; // S(0) = entity start
|
||||
ft[2] = i; // S(1) = back-off to B(0)
|
||||
} else {
|
||||
ft[1] = i; // S(0) = back-off to B(0)
|
||||
ft[2] = i; // S(1) = back-off to B(0)
|
||||
}
|
||||
|
||||
// PrecomputableAffine: pre[f][o][p] = W[f][o][p][:] @ hid[ft[f]] + b[o][p]
|
||||
memset(unmaxed.data(), 0, (size_t)nO * nP * sizeof(float));
|
||||
for(int f=0;f<nF;f++){
|
||||
const float* hf = hid.data() + (size_t)ft[f] * nO;
|
||||
for(int o=0;o<nO;o++){
|
||||
for(int p=0;p<nP;p++){
|
||||
size_t base = (((size_t)f * nO + o) * nP + p) * nD;
|
||||
float val = 0;
|
||||
for(int d=0;d<nD;d++){
|
||||
val += s->pW_full[base + d] * hf[d];
|
||||
}
|
||||
unmaxed[(size_t)o * nP + p] += val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add bias ONCE (not nF times)
|
||||
for(int o=0;o<nO;o++){
|
||||
for(int p=0;p<nP;p++){
|
||||
unmaxed[(size_t)o * nP + p] += s->pB_full[(size_t)o * nP + p];
|
||||
}
|
||||
}
|
||||
|
||||
// Maxout: hid_vec[o] = max_p unmaxed[o*nP + p]
|
||||
for(int o=0;o<nO;o++){
|
||||
float best = unmaxed[(size_t)o * nP];
|
||||
for(int p=1;p<nP;p++){
|
||||
float v = unmaxed[(size_t)o * nP + p];
|
||||
if(v > best) best = v;
|
||||
}
|
||||
hid_vec[o] = best;
|
||||
}
|
||||
|
||||
// Classifier: scores = cW @ hid_vec + cB
|
||||
linear(scores.data(), hid_vec.data(), s->cW.data(), s->cB.data(), s->nAct, nO);
|
||||
|
||||
// Constrained greedy decoding
|
||||
char prev_type = i>0 ? label_type(tl[i-1]) : 'O';
|
||||
std::string prev_etype = i>0 ? label_etype(tl[i-1]) : "";
|
||||
int bst=-1; float bv=-1e30f;
|
||||
for(int a=0;a<s->nAct;a++){
|
||||
const std::string& lbl = (a<(int)s->actLbl.size()) ? s->actLbl[a] : "O";
|
||||
if(lbl.empty()) continue;
|
||||
char ct = label_type(lbl);
|
||||
std::string ce = label_etype(lbl);
|
||||
bool valid=false;
|
||||
if(prev_type=='O'||prev_type=='L'||prev_type=='U')
|
||||
valid = (ct=='O'||ct=='B'||ct=='U');
|
||||
else if(prev_type=='B'||prev_type=='I'){
|
||||
if(ct=='O') valid=true;
|
||||
else if((ct=='I'||ct=='L')&&ce==prev_etype) valid=true;
|
||||
}
|
||||
if(!valid) continue;
|
||||
if(scores[a]>bv){bv=scores[a];bst=a;}
|
||||
}
|
||||
if(bst>=0) {
|
||||
tl[i] = s->actLbl[bst];
|
||||
// Update entity_start for next token (BiluoPushDown stack tracking)
|
||||
char ct = label_type(tl[i]);
|
||||
if(ct == 'B') entity_start = i;
|
||||
else if(ct == 'I' || ct == 'L') { /* entity continues, keep entity_start */ }
|
||||
else entity_start = -1; // O or U → no entity open
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Step 8: BILUO decode ----
|
||||
auto ents=decode_biluo(tok,tl);
|
||||
std::string r="["; for(size_t i=0;i<ents.size();i++){if(i)r+=",";r+="{\"text\":\""+ents[i].text+"\",\"label\":\""+ents[i].label+"\",\"start\":"+std::to_string(ents[i].start)+",\"end\":"+std::to_string(ents[i].end)+",\"confidence\":"+std::to_string(ents[i].conf)+"}";} r+="]";
|
||||
return strdup(r.c_str());
|
||||
}
|
||||
|
||||
// Releases a string returned by this API. NULL is accepted.
void ThincNER_FreeString(char* p) {
    if (p != nullptr) {
        free(p);
    }
}
|
||||
|
||||
char* ThincNER_Tokenize(const char* t, const char* l) {
|
||||
if(!t)return strdup("[]");
|
||||
std::string lang = l ? std::string(l) : "";
|
||||
// de, fr, es, pt: European Latin -> en tokenizer; zh, ja: CJK -> zh tokenizer
|
||||
bool is_cjk = (lang == "zh" || lang == "ja");
|
||||
auto tok = is_cjk ? tokenize_zh(t) : tokenize_en(t);
|
||||
std::string r="["; for(size_t i=0;i<tok.size();i++){if(i)r+=",";r+="\""+tok[i]+"\"";} r+="]";
|
||||
return strdup(r.c_str());
|
||||
}
|
||||
45
internal/cpp/thinc_ner.h
Normal file
45
internal/cpp/thinc_ner.h
Normal file
@@ -0,0 +1,45 @@
|
||||
#ifndef THINC_NER_H
#define THINC_NER_H

// Standard headers are pulled in before the extern "C" block so that their
// declarations are never wrapped in C linkage by this header.
#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

// ---------------------------------------------------------------------------
// C API for spaCy model inference (en_core_web_sm / zh_core_web_sm)
//
// Loads model.ckpt + model.bin directly from a spaCy model directory.
// ---------------------------------------------------------------------------

// Opaque handle to one loaded NER model.
typedef void* ThincNERHandle;

// Create an inference handle for a single spaCy model.
//   model_ner_dir:   path to the model's ner component directory, e.g.
//                      "models/en_core_web_sm-3.7.1/ner/"
//                      "models/zh_core_web_sm-3.7.1/ner/"
//   model_vocab_dir: path to the model's vocab directory.
// Returns NULL on failure. Release the handle with ThincNER_Destroy.
ThincNERHandle ThincNER_Create(const char* model_ner_dir, const char* model_vocab_dir);

// Destroy a handle returned by ThincNER_Create.
void ThincNER_Destroy(ThincNERHandle handle);

// Run NER on pre-tokenized text.
//   tokens_json: JSON array of token strings, e.g. ["Apple", "Inc.", "was", ...]
// Returns JSON array of entities:
//   [{"text":"Apple Inc.","label":"ORG","start":0,"end":10,"confidence":0.85}, ...]
// Caller must free with ThincNER_FreeString.
char* ThincNER_Predict(ThincNERHandle handle, const char* tokens_json);

// Free a string returned by ThincNER_Predict or ThincNER_Tokenize.
void ThincNER_FreeString(char* ptr);

// Utility: tokenize text using spaCy-compatible rules.
//   lang: language code selecting the tokenizer.
// Returns JSON array of token strings.
// Caller must free with ThincNER_FreeString.
char* ThincNER_Tokenize(const char* text, const char* lang);

#ifdef __cplusplus
}
#endif

#endif // THINC_NER_H
|
||||
530
internal/cpp/thinc_parser.cpp
Normal file
530
internal/cpp/thinc_parser.cpp
Normal file
@@ -0,0 +1,530 @@
|
||||
#include "thinc_parser.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
|
||||
|
||||
// =========================================================================
|
||||
// Minimal JSON parser (replicated from thinc_ner.cpp)
|
||||
// =========================================================================
|
||||
namespace {
|
||||
// One node of a parsed JSON document.
//   type: which of the payload fields below is meaningful.
//   str:  STR payload; for BOOL it holds the literal text "true" / "false".
//   arr:  ARR elements.   obj: OBJ members.   num: NUM payload.
struct JVal {
    enum Type { NUL, OBJ, ARR, STR, NUM, BOOL } type = NUL;
    std::string str;
    std::vector<JVal> arr;
    std::unordered_map<std::string, JVal> obj;
    double num = 0;

    // Member lookup; returns nullptr when the key is absent.
    const JVal* get(const std::string& k) const {
        auto it = obj.find(k);
        return it != obj.end() ? &it->second : nullptr;
    }
    int as_int() const { return (int)num; }
    int64_t as_i64() const { return (int64_t)num; }
};

// Small, tolerant recursive-descent JSON reader. It never reads past `e`;
// malformed input yields a best-effort tree rather than an error.
struct JParser {
    const char* p = nullptr;  // read cursor
    const char* e = nullptr;  // one past the last input byte

    void skip_ws() { while (p < e && (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r')) ++p; }

    // Peek at the next non-whitespace byte (0 at end of input).
    char pk() { skip_ws(); return p < e ? *p : 0; }
    // Consume and return the next non-whitespace byte (0 at end of input).
    char nx() { skip_ws(); return p < e ? *p++ : 0; }

    // Consume `word` if the input continues with exactly that literal.
    bool lit(const char* word, size_t n) {
        if ((size_t)(e - p) >= n && memcmp(p, word, n) == 0) { p += n; return true; }
        return false;
    }

    // Decode four hex digits at q; false if fewer than four remain or any is not hex.
    bool hex4(const char* q, unsigned& out) const {
        if (e - q < 4) return false;
        unsigned v = 0;
        for (int i = 0; i < 4; i++) {
            const char c = q[i];
            unsigned d;
            if (c >= '0' && c <= '9') d = (unsigned)(c - '0');
            else if (c >= 'a' && c <= 'f') d = (unsigned)(c - 'a') + 10;
            else if (c >= 'A' && c <= 'F') d = (unsigned)(c - 'A') + 10;
            else return false;
            v = (v << 4) | d;
        }
        out = v;
        return true;
    }

    // Append code point cp to s as UTF-8; unpaired surrogates become U+FFFD.
    static void put_utf8(std::string& s, unsigned cp) {
        if (cp >= 0xD800 && cp <= 0xDFFF) cp = 0xFFFD;
        if (cp < 0x80) {
            s += (char)cp;
        } else if (cp < 0x800) {
            s += (char)(0xC0 | (cp >> 6));
            s += (char)(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            s += (char)(0xE0 | (cp >> 12));
            s += (char)(0x80 | ((cp >> 6) & 0x3F));
            s += (char)(0x80 | (cp & 0x3F));
        } else {
            s += (char)(0xF0 | (cp >> 18));
            s += (char)(0x80 | ((cp >> 12) & 0x3F));
            s += (char)(0x80 | ((cp >> 6) & 0x3F));
            s += (char)(0x80 | (cp & 0x3F));
        }
    }

    // Parse any value, dispatching on its first non-whitespace byte.
    JVal pv() {
        switch (pk()) {
            case '{': return po();
            case '[': return pa();
            case '"': return ps();
            case 't':
            case 'f': return pb();
            case 'n': lit("null", 4); return JVal{};
            default:  return pn();
        }
    }

    // Parse an object: { "key": value, ... }
    JVal po() {
        JVal v; v.type = JVal::OBJ;
        nx();  // '{'
        while (p < e && pk() != '}') {
            JVal k = ps();
            nx();  // ':'
            JVal member = pv();
            v.obj[k.str] = member;
            if (p < e && pk() == ',') nx(); else break;
        }
        if (p < e) nx();  // '}'
        return v;
    }

    // Parse an array: [ value, ... ]
    JVal pa() {
        JVal v; v.type = JVal::ARR;
        nx();  // '['
        while (p < e && pk() != ']') {
            v.arr.push_back(pv());
            if (p < e && pk() == ',') nx(); else break;
        }
        if (p < e) nx();  // ']'
        return v;
    }

    // Parse a string, decoding backslash escapes. \uXXXX (including surrogate
    // pairs) is emitted as UTF-8 rather than truncated to one byte.
    JVal ps() {
        JVal v; v.type = JVal::STR;
        nx();  // opening quote
        while (p < e && *p != '"') {
            if (*p != '\\') { v.str += *p++; continue; }
            ++p;  // backslash
            if (p >= e) break;
            switch (*p) {
                case 'n': v.str += '\n'; ++p; break;
                case 't': v.str += '\t'; ++p; break;
                case 'r': v.str += '\r'; ++p; break;
                case 'b': v.str += '\b'; ++p; break;
                case 'f': v.str += '\f'; ++p; break;
                case 'u': {
                    unsigned cp = 0;
                    if (!hex4(p + 1, cp)) { ++p; break; }  // malformed: drop the 'u' only
                    p += 5;
                    // A high surrogate followed by \uDC00..\uDFFF forms one code point.
                    if (cp >= 0xD800 && cp <= 0xDBFF && e - p >= 6 && p[0] == '\\' && p[1] == 'u') {
                        unsigned lo = 0;
                        if (hex4(p + 2, lo) && lo >= 0xDC00 && lo <= 0xDFFF) {
                            cp = 0x10000 + ((cp - 0xD800) << 10) + (lo - 0xDC00);
                            p += 6;
                        }
                    }
                    put_utf8(v.str, cp);
                    break;
                }
                default: v.str += *p++; break;  // '"', '\\', '/' and unknown escapes
            }
        }
        if (p < e) ++p;  // closing quote
        return v;
    }

    // Parse a number: optional sign, digits, optional fraction and exponent.
    JVal pn() {
        JVal v; v.type = JVal::NUM;
        const char* s = p;
        auto digits = [&] { while (p < e && *p >= '0' && *p <= '9') ++p; };
        if (p < e && *p == '-') ++p;
        digits();
        if (p < e && *p == '.') { ++p; digits(); }
        if (p < e && (*p == 'e' || *p == 'E')) {
            ++p;
            if (p < e && (*p == '+' || *p == '-')) ++p;
            digits();
        }
        if (s < p) {
            try { v.num = std::stod(std::string(s, p - s)); } catch (...) { v.num = 0; }
        }
        return v;
    }

    // Parse true / false; str stays empty when neither literal matches.
    JVal pb() {
        JVal v; v.type = JVal::BOOL;
        if (lit("true", 4)) v.str = "true";
        else if (lit("false", 5)) v.str = "false";
        return v;
    }

    // Entry point: parse one JSON value from j.
    JVal parse(const std::string& j) { p = j.data(); e = p + j.size(); return pv(); }
};
|
||||
|
||||
// =========================================================================
|
||||
// HashEmbed + MurmurHash (copied from thinc_ner.cpp)
|
||||
// =========================================================================
|
||||
// Rotate the 64-bit value x left by r bits. Arguments are parenthesized so
// the macro stays correct when passed expressions.
#define ROTL64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))

// Load the i-th 64-bit word starting at p via memcpy, so no aligned access is assumed.
static uint64_t getblock64(const uint64_t* p, size_t i) {
    uint64_t word;
    memcpy(&word, p + i, 8);
    return word;
}

// MurmurHash3 64-bit finalization mix.
static uint64_t fmix64(uint64_t k) {
    k ^= k >> 33;
    k *= 0xff51afd7ed558ccdULL;
    k ^= k >> 33;
    k *= 0xc4ceb9fe1a85ec53ULL;
    k ^= k >> 33;
    return k;
}

// MurmurHash3 x64 128-bit hash.
//   key/len: input bytes.   seed: 32-bit seed.
//   out:     receives the 128-bit result as four 32-bit words
//            (low/high halves of h1, then low/high halves of h2).
static void mmh3_x64_128(const void* key, int len, uint32_t seed, uint32_t out[4]) {
    const uint8_t* data = (const uint8_t*)key;
    const int nblocks = len / 16;
    const uint64_t c1 = 0x87c37b91114253d5ULL, c2 = 0x4cf5ad432745937fULL;
    uint64_t h1 = seed, h2 = seed;

    // Body: consume the input 16 bytes at a time.
    const uint64_t* blocks = (const uint64_t*)data;
    for (int i = 0; i < nblocks; i++) {
        uint64_t k1 = getblock64(blocks, i * 2 + 0);
        uint64_t k2 = getblock64(blocks, i * 2 + 1);
        k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
        h1 = ROTL64(h1, 27); h1 += h2; h1 = h1 * 5 + 0x52dce729;
        k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
        h2 = ROTL64(h2, 31); h2 += h1; h2 = h2 * 5 + 0x38495ab5;
    }

    // Tail: the remaining 0..15 bytes. Bytes 8..14 feed k2, bytes 0..7 feed k1,
    // each placed little-endian within its word (explicit loops replace the
    // fall-through switch of the reference implementation).
    const uint8_t* tail = data + nblocks * 16;
    const int rem = len & 15;
    if (rem > 8) {
        uint64_t k2 = 0;
        for (int i = rem - 1; i >= 8; --i) k2 ^= ((uint64_t)tail[i]) << ((i - 8) * 8);
        k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
    }
    if (rem > 0) {
        uint64_t k1 = 0;
        const int top = rem < 8 ? rem : 8;
        for (int i = top - 1; i >= 0; --i) k1 ^= ((uint64_t)tail[i]) << (i * 8);
        k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
    }

    // Finalization.
    h1 ^= len; h2 ^= len;
    h1 += h2; h2 += h1;
    h1 = fmix64(h1); h2 = fmix64(h2);
    h1 += h2; h2 += h1;

    out[0] = (uint32_t)h1; out[1] = (uint32_t)(h1 >> 32);
    out[2] = (uint32_t)h2; out[3] = (uint32_t)(h2 >> 32);
}
|
||||
|
||||
// MurmurHash2, 64-bit variant A: hashes `len` bytes at `key` with `seed`.
static uint64_t mh2_64a(const void* key, int len, uint64_t seed) {
    const uint64_t mul = 0xc6a4a7935bd1e995ULL;
    const int shift = 47;

    uint64_t hash = seed ^ (uint64_t(len) * mul);
    const uint8_t* bytes = (const uint8_t*)key;
    int left = len;

    // Whole 8-byte words.
    for (; left >= 8; bytes += 8, left -= 8) {
        uint64_t word;
        memcpy(&word, bytes, 8);
        word *= mul;
        word ^= word >> shift;
        word *= mul;
        hash ^= word;
        hash *= mul;
    }

    // Remaining 1..7 bytes are XORed in little-endian position, then mixed once.
    if (left > 0) {
        for (int i = left - 1; i >= 0; --i) hash ^= uint64_t(bytes[i]) << (8 * i);
        hash *= mul;
    }

    hash ^= hash >> shift;
    hash *= mul;
    hash ^= hash >> shift;
    return hash;
}

// Feature id of a string: 0 for the empty string, otherwise its
// mh2_64a hash with seed 0.
static uint64_t hash_feat(const std::string& s) {
    if (s.empty()) return 0;
    return mh2_64a(s.data(), (int)s.size(), 0);
}
|
||||
|
||||
struct HashEmbed{
|
||||
int n_rows=0,nO=0;uint32_t seed=0;std::vector<float> table;
|
||||
bool load(int r,int o,const float*d){n_rows=r;nO=o;table.assign(d,d+(size_t)r*o);return!table.empty();}
|
||||
void embed(uint64_t fid,float* out)const{
|
||||
uint8_t in[8];for(int i=0;i<8;i++)in[i]=(uint8_t)(fid>>(i*8));
|
||||
uint32_t keys[4];mmh3_x64_128(in,8,seed,keys);
|
||||
for(int v=0;v<4;v++){int idx=(int)(keys[v]%(uint32_t)n_rows);for(int i=0;i<nO;i++)out[i]+=table[(size_t)idx*nO+i];}
|
||||
}
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Layer primitives
|
||||
// =========================================================================
|
||||
// Dense layer: out[r] = b[r] + sum_c W[r*nI + c] * in[c], for r in [0, nO).
// W is row-major with nO rows of nI values.
static void linear(float* out, const float* in, const float* W, const float* b, int nO, int nI) {
    for (int row = 0; row < nO; ++row) {
        const float* w_row = W + (size_t)row * nI;
        float acc = b[row];
        for (int col = 0; col < nI; ++col) {
            acc += w_row[col] * in[col];
        }
        out[row] = acc;
    }
}
|
||||
// Maxout layer: each of the nO outputs is the maximum over nP affine "pieces".
// Piece (u, k) uses bias b[u*nP + k] and weight row W[(u*nP + k)*nI ...].
// The running maximum starts at -1e30f.
static void maxout(float* out, const float* in, const float* W, const float* b, int nO, int nP, int nI) {
    for (int unit = 0; unit < nO; ++unit) {
        float top = -1e30f;
        for (int piece = 0; piece < nP; ++piece) {
            const size_t slot = (size_t)unit * nP + piece;
            const float* w_row = W + slot * nI;
            float acc = b[slot];
            for (int col = 0; col < nI; ++col) {
                acc += w_row[col] * in[col];
            }
            if (acc > top) top = acc;
        }
        out[unit] = top;
    }
}
|
||||
// Layer normalization over one d-dimensional vector:
//   out[k] = G[k] * (in[k] - mean) / sqrt(var + eps) + b[k]
// where var is the population variance (divided by d).
static void layernorm(float* out, const float* in, int d, const float* G, const float* b, float eps) {
    float mean = 0;
    for (int k = 0; k < d; ++k) mean += in[k];
    mean /= d;

    float var = 0;
    for (int k = 0; k < d; ++k) var += (in[k] - mean) * (in[k] - mean);
    var /= d;

    const float inv_std = 1.0f / sqrtf(var + eps);
    for (int k = 0; k < d; ++k) {
        out[k] = G[k] * (in[k] - mean) * inv_std + b[k];
    }
}
|
||||
// Build the window-1 context for token `idx`: out receives
// [previous | current | next], each `dim` floats, taken from the row-major
// (n x dim) matrix `all`. A missing neighbour at either edge is zero-filled.
static void expand_win(float* out, const float* all, int n, int dim, int idx) {
    const float* current = all + idx * dim;

    if (idx > 0) {
        std::copy_n(all + (idx - 1) * dim, dim, out);
    } else {
        std::fill_n(out, dim, 0.0f);
    }

    std::copy_n(current, dim, out + dim);

    if (idx < n - 1) {
        std::copy_n(all + (idx + 1) * dim, dim, out + 2 * dim);
    } else {
        std::fill_n(out + 2 * dim, dim, 0.0f);
    }
}
|
||||
|
||||
// =========================================================================
|
||||
// Feature extraction — UTF-8 aware, matching spaCy
|
||||
// =========================================================================
|
||||
// Return the first UTF-8 encoded character of s ("" for an empty string).
// The width is read from the lead byte; if s is shorter than that width,
// only the first byte is returned.
static std::string u8_first(const std::string& s) {
    if (s.empty()) return "";

    const unsigned char lead = (unsigned char)s[0];
    size_t width = 1;
    if ((lead & 0xE0) == 0xC0) {
        width = 2;
    } else if ((lead & 0xF0) == 0xE0) {
        width = 3;
    } else if ((lead & 0xF8) == 0xF0) {
        width = 4;
    }

    if (width > s.size()) width = 1;
    return s.substr(0, width);
}
|
||||
// Byte width of the UTF-8 sequence introduced by lead byte c.
// Continuation and invalid lead bytes count as one byte.
static size_t u8_seq_len(unsigned char c) {
    if ((c & 0x80) == 0) return 1;
    if ((c & 0xE0) == 0xC0) return 2;
    if ((c & 0xF0) == 0xE0) return 3;
    if ((c & 0xF8) == 0xF0) return 4;
    return 1;
}

// Number of UTF-8 characters in s, stepping by lead-byte width.
static size_t u8_len(const std::string& s) {
    size_t n = 0;
    for (size_t i = 0; i < s.size(); i += u8_seq_len((unsigned char)s[i])) ++n;
    return n;
}

// Return the last `count` UTF-8 characters of s, or all of s when it has
// `count` characters or fewer.
static std::string u8_last(const std::string& s, size_t count) {
    const size_t ul = u8_len(s);
    if (ul <= count) return s;

    // Skip the leading (ul - count) characters.
    size_t pos = 0;
    for (size_t i = 0; i < ul - count && pos < s.size(); i++) {
        pos += u8_seq_len((unsigned char)s[pos]);
    }
    // A truncated final sequence can step past the end; clamp so substr()
    // cannot throw std::out_of_range.
    if (pos > s.size()) pos = s.size();
    return s.substr(pos);
}
|
||||
static std::vector<uint64_t> extract_features(const std::string& t, int n_embed){
|
||||
auto fn=[&](const std::string& s){return hash_feat(s);};
|
||||
auto fp=[&](const std::string& s){std::string p=s.empty()?"":u8_first(s);std::transform(p.begin(),p.end(),p.begin(),::tolower);return hash_feat(p);};
|
||||
auto fs=[&](const std::string& s){std::string su=u8_len(s)>=3?u8_last(s,3):s;std::transform(su.begin(),su.end(),su.begin(),::tolower);return hash_feat(su);};
|
||||
auto fsh=[&](const std::string& t2){std::string sh;for(unsigned char c:t2){if(c>0x7F)sh+='x';else if(std::isupper(c))sh+='X';else if(std::islower(c))sh+='x';else if(std::isdigit(c))sh+='d';else sh+=c;}return hash_feat(sh);};
|
||||
std::vector<uint64_t> ids;
|
||||
std::string lo=t;std::transform(lo.begin(),lo.end(),lo.begin(),::tolower);
|
||||
ids.push_back(fn(lo));ids.push_back(fp(t));ids.push_back(fs(t));ids.push_back(fsh(t));
|
||||
if(n_embed==6){ids.push_back(1);ids.push_back(0);}else{ids.push_back(0);}
|
||||
return ids;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Tok2vec forward pass (shared with NER)
|
||||
// =========================================================================
|
||||
struct Tok2vecModel {
|
||||
std::vector<HashEmbed> embeds;
|
||||
std::vector<float> poW,poB,poG,poB2; bool has_poLN=false;
|
||||
int po_nO=96,po_nP=3,po_nI=576;
|
||||
struct ResBlk{bool has=false;std::vector<float>W,b,lnG,lnb;};
|
||||
ResBlk res[4]; int n_res=0;
|
||||
|
||||
bool load(const std::string& dir) {
|
||||
std::ifstream cf(dir+"/model.ckpt"); if(!cf)return false;
|
||||
std::stringstream cb;cb<<cf.rdbuf();
|
||||
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
|
||||
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
|
||||
size_t bz=bf.tellg();bf.seekg(0); if(bz%4!=0||bz==0)return false;
|
||||
std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
|
||||
auto sl=[&](int64_t o,int64_t c)->std::vector<float>{
|
||||
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
|
||||
};
|
||||
auto ld=[&](const std::string& k, std::vector<float>* v,int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
|
||||
auto* e=ck.get(k);if(!e)return false;
|
||||
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||
if(!sv||!ov||!cv)return false;
|
||||
*v=sl(ov->as_i64(),cv->as_i64());
|
||||
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||
return!v->empty();
|
||||
};
|
||||
for(int ei=0;;ei++){
|
||||
auto* e=ck.get("embed_"+std::to_string(ei)+"_E");if(!e)break;
|
||||
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||
if(!sv||!ov||!cv)break;int rs=sv->arr[0].as_int(),nO=sv->arr[1].as_int();
|
||||
int64_t exp=(int64_t)rs*nO;if(cv->as_i64()<exp)break;
|
||||
auto d=sl(ov->as_i64(),cv->as_i64());if(d.empty())break;
|
||||
embeds.emplace_back();embeds.back().load(rs,nO,d.data());
|
||||
}
|
||||
std::ifstream ff(dir+"/feature_config.json");
|
||||
if(ff){std::stringstream fb;fb<<ff.rdbuf();auto cfg=JParser().parse(fb.str());auto* sa=cfg.get("embed_seeds");
|
||||
if(sa&&sa->type==JVal::ARR)for(int i=0;i<(int)sa->arr.size()&&i<(int)embeds.size();i++)embeds[i].seed=(uint32_t)sa->arr[i].as_int();}
|
||||
int r0=0,r1=0,r2=0;
|
||||
if(!ld("poW",&poW,&r0,&r1,&r2))return false;
|
||||
po_nO=r0;po_nP=r1;po_nI=r2;
|
||||
if(!ld("poB",&poB))return false;
|
||||
if(ld("poG",&poG)){ld("poB2",&poB2);has_poLN=true;}
|
||||
for(int ri=0;ri<4;ri++){auto pk="res"+std::to_string(ri);auto& rb=res[ri];
|
||||
if(ld(pk+"W",&rb.W,&r0,&r1,&r2)){ld(pk+"B",&rb.b);ld(pk+"lnG",&rb.lnG);ld(pk+"lnb",&rb.lnb);rb.has=true;n_res++;}}
|
||||
return!embeds.empty();
|
||||
}
|
||||
|
||||
// Run tok2vec → (n_tokens, 96)
|
||||
void forward(const std::vector<std::string>& tokens, float* out) {
|
||||
int n=(int)tokens.size(),D=96,NE=(int)embeds.size(),EC=NE*D;
|
||||
std::vector<float> emb((size_t)n*EC,0);
|
||||
for(int i=0;i<n;i++){
|
||||
auto ids=extract_features(tokens[i],NE);
|
||||
size_t b=(size_t)i*EC;
|
||||
for(int e=0;e<NE;e++)embeds[e].embed(ids[e],emb.data()+b+(size_t)e*D);
|
||||
}
|
||||
std::vector<float> pe((size_t)n*D);
|
||||
for(int i=0;i<n;i++)maxout(pe.data()+(size_t)i*D,emb.data()+(size_t)i*EC,poW.data(),poB.data(),D,po_nP,EC);
|
||||
std::vector<float> pln((size_t)n*D,0);
|
||||
if(has_poLN)for(int i=0;i<n;i++)layernorm(pln.data()+(size_t)i*D,pe.data()+(size_t)i*D,D,poG.data(),poB2.data(),1e-6f);else pln=pe;
|
||||
std::vector<float> enc=pln;
|
||||
for(int ri=0;ri<n_res;ri++){if(!res[ri].has)continue;
|
||||
int wd=D*3;std::vector<float> exp((size_t)n*wd);
|
||||
for(int i=0;i<n;i++)expand_win(exp.data()+(size_t)i*wd,enc.data(),n,D,i);
|
||||
std::vector<float> mx((size_t)n*D);for(int i=0;i<n;i++)maxout(mx.data()+(size_t)i*D,exp.data()+(size_t)i*wd,res[ri].W.data(),res[ri].b.data(),D,3,wd);
|
||||
std::vector<float> ln((size_t)n*D);if(!res[ri].lnG.empty())for(int i=0;i<n;i++)layernorm(ln.data()+(size_t)i*D,mx.data()+(size_t)i*D,D,res[ri].lnG.data(),res[ri].lnb.data(),1e-6f);else ln=mx;
|
||||
for(int i=0;i<n;i++){float* op=enc.data()+(size_t)i*D;for(int j=0;j<D;j++)op[j]+=ln[(size_t)i*D+j];}
|
||||
}
|
||||
memcpy(out,enc.data(),(size_t)n*D*sizeof(float));
|
||||
}
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Arc-hybrid Parser
|
||||
// =========================================================================
|
||||
struct ParserModel {
|
||||
int nO=64,nP=8,nI=2,n_actions=0;
|
||||
std::vector<float> pW_hid,pb_hid; // 96→64
|
||||
std::vector<float> pW_pre,pb_pre,pad_pre; // preaffine
|
||||
std::vector<float> pW_cls,pb_cls; // classifier
|
||||
std::vector<std::string> move_names;
|
||||
|
||||
bool load(const std::string& dir) {
|
||||
std::ifstream cf(dir+"/model.ckpt"); if(!cf)return false;
|
||||
std::stringstream cb;cb<<cf.rdbuf();
|
||||
JVal ck=JParser().parse(cb.str()); if(ck.type!=JVal::OBJ)return false;
|
||||
std::ifstream bf(dir+"/model.bin",std::ios::binary|std::ios::ate); if(!bf)return false;
|
||||
size_t bz=bf.tellg();bf.seekg(0); if(bz%4!=0||bz==0)return false;
|
||||
std::vector<float> bin(bz/4); bf.read((char*)bin.data(),bz);
|
||||
auto sl=[&](int64_t o,int64_t c)->std::vector<float>{
|
||||
if(o+c>(int64_t)bin.size())return{}; return std::vector<float>(bin.begin()+o,bin.begin()+o+c);
|
||||
};
|
||||
auto ld=[&](const std::string& k, std::vector<float>* v,int* r0=nullptr,int* r1=nullptr,int* r2=nullptr)->bool{
|
||||
auto* e=ck.get(k);if(!e)return false;
|
||||
auto sv=e->get("shape"),ov=e->get("offset"),cv=e->get("count");
|
||||
if(!sv||!ov||!cv)return false;
|
||||
*v=sl(ov->as_i64(),cv->as_i64());
|
||||
if(r0)*r0=sv->arr.size()>=1?sv->arr[0].as_int():1;
|
||||
if(r1)*r1=sv->arr.size()>=2?sv->arr[1].as_int():1;
|
||||
if(r2)*r2=sv->arr.size()>=3?sv->arr[2].as_int():1;
|
||||
return!v->empty();
|
||||
};
|
||||
int r0,r1,r2;
|
||||
if(!ld("pW_hid",&pW_hid,&r0,&r1))return false;
|
||||
nO=r0;ld("pb_hid",&pb_hid);
|
||||
if(!ld("pW_pre",&pW_pre,&r0,&r1,&r2))return false;
|
||||
nP=r0;nO=r1;nI=r2;
|
||||
ld("pb_pre",&pb_pre);ld("pad_pre",&pad_pre);
|
||||
if(!ld("pW_cls",&pW_cls,&r0,&r1))return false;
|
||||
n_actions=r0;ld("pb_cls",&pb_cls);
|
||||
std::ifstream mf(dir+"/meta.json");
|
||||
if(mf){std::stringstream mb;mb<<mf.rdbuf();auto meta=JParser().parse(mb.str());auto* mn=meta.get("move_names");
|
||||
if(mn&&mn->type==JVal::ARR)for(auto& v:mn->arr)move_names.push_back(v.str);}
|
||||
return!pW_hid.empty() && !pW_pre.empty() && !pW_cls.empty();
|
||||
}
|
||||
|
||||
// Run parser forward + state machine → (heads, labels)
|
||||
void parse(const float* tokvecs, int n_tokens,
|
||||
std::vector<int>& out_heads, std::vector<std::string>& out_labels) {
|
||||
// 1. Hidden layer: 96→64
|
||||
std::vector<float> hidden((size_t)n_tokens*nO,0);
|
||||
for(int i=0;i<n_tokens;i++) linear(hidden.data()+(size_t)i*nO, tokvecs+(size_t)i*96,
|
||||
pW_hid.data(), pb_hid.data(), nO, 96);
|
||||
|
||||
// 2. Pre-compute features
|
||||
std::vector<float> precomp((size_t)(n_tokens+1)*nP*nO*nI,0);
|
||||
// Pad token (index 0)
|
||||
memcpy(precomp.data(), pad_pre.data(), (size_t)nP*nO*nI*sizeof(float));
|
||||
// Real tokens
|
||||
for(int i=0;i<n_tokens;i++){
|
||||
size_t toff = (size_t)(i+1)*nP*nO*nI;
|
||||
for(int p=0;p<nP;p++){
|
||||
for(int w=0;w<nI;w++){
|
||||
size_t base = toff + (size_t)p*nO*nI + (size_t)w;
|
||||
float* out = precomp.data() + base;
|
||||
// W[p][w][o][d] = [nP][nI][nO][nO]
|
||||
for(int o=0;o<nO;o++){
|
||||
float s = pb_pre[(size_t)w*nO + o];
|
||||
for(int d=0;d<nO;d++) s += pW_pre[((size_t)p*nO*nI + (size_t)o*nI + w)*nO + d] * hidden[(size_t)i*nO + d];
|
||||
out[(size_t)o*nI] = s;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: get feature value for a token at piece p, window w, output dim o
|
||||
auto feat = [&](int idx, int p, int w, int o) -> float {
|
||||
int ri = (idx < 0 || idx >= n_tokens) ? 0 : idx + 1;
|
||||
return precomp[(size_t)ri*nP*nO*nI + (size_t)p*nO*nI + (size_t)w + (size_t)o*nI];
|
||||
};
|
||||
|
||||
// 3. Arc-hybrid state machine
|
||||
out_heads.assign(n_tokens, -1);
|
||||
out_labels.assign(n_tokens, "");
|
||||
std::vector<int> stack;
|
||||
std::vector<int> buffer(n_tokens);
|
||||
for(int i=0;i<n_tokens;i++) buffer[i]=i;
|
||||
|
||||
// Validate move_names covers all actions before indexing
|
||||
if((int)move_names.size()!=n_actions){out_heads.clear();out_labels.clear();return;}
|
||||
int act_S=-1, act_D=-1;
|
||||
for(int i=0;i<(int)move_names.size();i++){
|
||||
if(move_names[i]=="S") act_S=i;
|
||||
if(move_names[i]=="D") act_D=i;
|
||||
}
|
||||
|
||||
auto leftmost = [&](int idx)->int{
|
||||
for(int i=0;i<n_tokens;i++) if(out_heads[i]==idx) return i;
|
||||
return -1;
|
||||
};
|
||||
auto rightmost = [&](int idx)->int{
|
||||
int r=-1; for(int i=0;i<n_tokens;i++) if(out_heads[i]==idx) r=i;
|
||||
return r;
|
||||
};
|
||||
|
||||
std::vector<float> scores(n_actions,0);
|
||||
std::vector<float> feats(nP*nI*nO,0);
|
||||
|
||||
for(int step=0; step<n_tokens*4 && !(buffer.empty()&&stack.size()<=1); step++){
|
||||
int s0=stack.empty()?-1:stack.back();
|
||||
int s1=stack.size()<2?-1:stack[stack.size()-2];
|
||||
int s2=stack.size()<3?-1:stack[stack.size()-3];
|
||||
int b0=buffer.empty()?-1:buffer[0];
|
||||
int b1=buffer.size()<2?-1:buffer[1];
|
||||
|
||||
// Build feature indices (same as verified Python implementation)
|
||||
int idxs[16]={s0,s1, b0,s0, s0,leftmost(s0), s0,rightmost(s0),
|
||||
s1,leftmost(s1), s1,rightmost(s1), s2,b1, b0,b1};
|
||||
|
||||
// Build feature vector: sum of precomputed features at each (idx, piece, window)
|
||||
for(int o=0;o<nO;o++) feats[o]=0;
|
||||
for(int p=0;p<nP;p++){
|
||||
for(int w=0;w<nI;w++){
|
||||
int ti=idxs[p*2+w];
|
||||
for(int o=0;o<nO;o++){
|
||||
feats[(size_t)p*nI*nO + (size_t)w*nO + o] = feat(ti,p,w,o);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Classify
|
||||
for(int a=0;a<n_actions;a++){
|
||||
float s=pb_cls[a];
|
||||
// Use HIDDEN state directly (64-dim) as classifier input
|
||||
int cls_idx = b0 >= 0 ? b0 : (s0 >= 0 ? s0 : 0);
|
||||
for(int j=0;j<nO;j++) s += pW_cls[(size_t)a*nO+j] * hidden[(size_t)cls_idx*nO+j];
|
||||
scores[a]=s;
|
||||
}
|
||||
|
||||
// Pick best VALID action
|
||||
int best=-1; float best_sc=-1e30f;
|
||||
for(int a=0;a<n_actions;a++){
|
||||
bool valid=false;
|
||||
const std::string& n=move_names[a];
|
||||
if(n=="S") valid=!buffer.empty() && (int)stack.size()<n_tokens;
|
||||
else if(n=="D") valid=!stack.empty();
|
||||
else if(n.size()>=2 && (n[0]=='L'||n[0]=='R')) valid=stack.size()>=2;
|
||||
if(valid && scores[a]>best_sc){best_sc=scores[a];best=a;}
|
||||
}
|
||||
if(best<0) break;
|
||||
|
||||
const std::string& act=move_names[best];
|
||||
if(act=="S"){stack.push_back(buffer[0]);buffer.erase(buffer.begin());}
|
||||
else if(act=="D"){stack.pop_back();}
|
||||
else if(act.size()>=2){
|
||||
std::string lbl=act.substr(2);
|
||||
if(act[0]=='L'){out_heads[s0]=s1;out_labels[s0]=lbl;stack.erase(stack.end()-2);}
|
||||
else{out_heads[s1]=s0;out_labels[s1]=lbl;stack.pop_back();}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Combined state
|
||||
// =========================================================================
|
||||
struct ParserState {
|
||||
Tok2vecModel tok2vec;
|
||||
ParserModel parser;
|
||||
bool loaded=false;
|
||||
};
|
||||
|
||||
struct TaggerState {
|
||||
Tok2vecModel tok2vec;
|
||||
std::vector<float> tW,tb; // (n_tags, 96), (n_tags,)
|
||||
std::vector<std::string> tags;
|
||||
bool loaded=false;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// =========================================================================
|
||||
// C API — Parser
|
||||
// =========================================================================
|
||||
// Create a dependency-parser handle.
//
//   ner_dir:    the model's ner directory (<model_base>/ner, with or without
//               trailing slashes). It is only used to locate the pipeline
//               tok2vec weights at <model_base>/tok2vec.
//   parser_dir: the parser component directory.
//
// Returns NULL when an argument is NULL, when either component fails to load,
// or when loading throws. Exceptions are caught here because they must not
// propagate through the C ABI.
ThincParserHandle ThincParser_Create(const char* ner_dir, const char* parser_dir) {
    if (!ner_dir || !parser_dir) return nullptr;
    try {
        // Load PIPELINE tok2vec from <base>/tok2vec/ subdirectory (not NER's internal 4HE).
        std::string base(ner_dir);
        // Drop trailing slashes first so both ".../ner" and ".../ner/" resolve
        // to the same model base.
        while (base.size() > 1 && base.back() == '/') base.pop_back();
        if (base.size() >= 4 && base.compare(base.size() - 4, 4, "/ner") == 0)
            base.resize(base.size() - 4);

        std::unique_ptr<ParserState> s(new ParserState());
        if (!s->tok2vec.load(base + "/tok2vec")) return nullptr;
        if (!s->parser.load(std::string(parser_dir))) return nullptr;
        s->loaded = true;
        return s.release();
    } catch (...) {
        return nullptr;
    }
}
|
||||
|
||||
// Destroys a handle returned by ThincParser_Create; NULL is accepted.
void ThincParser_Destroy(ThincParserHandle h) {
    delete static_cast<ParserState*>(h);
}
|
||||
|
||||
// Run the dependency parser on pre-tokenized text.
//
//   h:           handle from ThincParser_Create.
//   tokens_json: JSON array of token strings.
//
// Returns a strdup()-allocated JSON array with one object per token:
//   {"text":..., "head":<index or -1>, "dep":..., "index":<i>}
// Returns "[]" for a NULL/unloaded handle, NULL or non-array input, an empty
// token list, a parser that produced no result, or any internal exception
// (exceptions must not cross the C ABI). Free with ThincParser_FreeString.
char* ThincParser_Predict(ThincParserHandle h, const char* tokens_json) {
    auto* s = static_cast<ParserState*>(h);
    if (!s || !s->loaded || !tokens_json) return strdup("[]");

    // Token text and labels are arbitrary strings; escape them so the output
    // is always valid JSON.
    auto append_escaped = [](std::string& dst, const std::string& src) {
        static const char kHex[] = "0123456789abcdef";
        for (unsigned char c : src) {
            switch (c) {
                case '"':  dst += "\\\""; break;
                case '\\': dst += "\\\\"; break;
                case '\n': dst += "\\n";  break;
                case '\r': dst += "\\r";  break;
                case '\t': dst += "\\t";  break;
                case '\b': dst += "\\b";  break;
                case '\f': dst += "\\f";  break;
                default:
                    if (c < 0x20) {
                        dst += "\\u00";
                        dst += kHex[c >> 4];
                        dst += kHex[c & 0x0F];
                    } else {
                        dst += static_cast<char>(c);
                    }
            }
        }
    };

    try {
        const JVal j = JParser().parse(std::string(tokens_json));
        if (j.type != JVal::ARR) return strdup("[]");
        std::vector<std::string> tokens;
        tokens.reserve(j.arr.size());
        for (const auto& v : j.arr) tokens.push_back(v.str);
        const int n = (int)tokens.size();
        if (!n) return strdup("[]");

        // Run tok2vec
        std::vector<float> tokvecs((size_t)n * 96, 0);
        s->tok2vec.forward(tokens, tokvecs.data());

        // Run parser
        std::vector<int> heads;
        std::vector<std::string> labels;
        s->parser.parse(tokvecs.data(), n, heads, labels);
        // parse() clears its outputs when the model is inconsistent; never
        // index them unless they cover every token.
        if ((int)heads.size() != n || (int)labels.size() != n) return strdup("[]");

        // Build JSON output
        std::string r = "[";
        for (int i = 0; i < n; i++) {
            if (i) r += ",";
            r += "{\"text\":\"";
            append_escaped(r, tokens[i]);
            r += "\",\"head\":" + std::to_string(heads[i]) + ",\"dep\":\"";
            append_escaped(r, labels[i]);
            r += "\",\"index\":" + std::to_string(i) + "}";
        }
        r += "]";
        return strdup(r.c_str());
    } catch (...) {
        return strdup("[]");
    }
}
|
||||
|
||||
// Releases a string returned by ThincParser_Predict (allocated with strdup).
// Passing NULL is a no-op.
void ThincParser_FreeString(char* p) {
    free(p);
}
|
||||
|
||||
// =========================================================================
|
||||
// C API — Tagger
|
||||
// =========================================================================
|
||||
// Create a POS-tagger handle.
//
//   ner_dir:    the model's ner directory (<model_base>/ner, with or without
//               trailing slashes). It is only used to locate the pipeline
//               tok2vec weights at <model_base>/tok2vec.
//   tagger_dir: the tagger component directory (model.ckpt, model.bin and
//               meta.json with a "tags" array).
//
// Returns NULL when an argument is NULL, when tok2vec or the tagger files
// cannot be opened/parsed, or when loading throws (exceptions must not cross
// the C ABI). If the files open but the classifier tensors "tW"/"tb" are
// missing or inconsistent, a handle is still returned with loaded == false,
// and ThincTagger_Predict then answers "[]".
ThincTaggerHandle ThincTagger_Create(const char* ner_dir, const char* tagger_dir) {
    if (!ner_dir || !tagger_dir) return nullptr;
    try {
        // Load PIPELINE tok2vec from <base>/tok2vec/ (6HE, not NER's internal 4HE)
        std::string tbase(ner_dir);
        // Drop trailing slashes first so both ".../ner" and ".../ner/" resolve
        // to the same model base.
        while (tbase.size() > 1 && tbase.back() == '/') tbase.pop_back();
        if (tbase.size() >= 4 && tbase.compare(tbase.size() - 4, 4, "/ner") == 0)
            tbase.resize(tbase.size() - 4);

        std::unique_ptr<TaggerState> s(new TaggerState());
        if (!s->tok2vec.load(tbase + "/tok2vec")) return nullptr;

        const std::string tdir(tagger_dir);
        std::ifstream cf(tdir + "/model.ckpt");
        if (!cf) return nullptr;
        std::stringstream cb;
        cb << cf.rdbuf();
        JVal ck = JParser().parse(cb.str());
        if (ck.type != JVal::OBJ) return nullptr;

        std::ifstream bf(tdir + "/model.bin", std::ios::binary | std::ios::ate);
        if (!bf) return nullptr;
        const std::streamoff end = bf.tellg();
        if (end < 0) return nullptr;
        bf.seekg(0);
        // Read whole floats only: reading the raw byte count into a buffer of
        // size/4 floats would overflow it when the size is not a multiple of 4.
        std::vector<float> bin((size_t)end / 4);
        if (!bin.empty() &&
            !bf.read((char*)bin.data(), (std::streamsize)(bin.size() * sizeof(float))))
            return nullptr;

        // Copy `c` floats starting at float offset `o`; empty when out of range.
        auto sl = [&](int64_t o, int64_t c) -> std::vector<float> {
            const int64_t total = (int64_t)bin.size();
            if (o < 0 || c <= 0 || o > total || c > total - o) return {};
            return std::vector<float>(bin.begin() + o, bin.begin() + o + c);
        };
        auto ld = [&](const std::string& k, std::vector<float>* v) -> bool {
            auto* e = ck.get(k);
            if (!e) return false;
            auto sv = e->get("shape"), ov = e->get("offset"), cv = e->get("count");
            if (!sv || !ov || !cv) return false;
            *v = sl(ov->as_i64(), cv->as_i64());
            return !v->empty();
        };
        ld("tW", &s->tW);
        ld("tb", &s->tb);

        std::ifstream mf(tdir + "/meta.json");
        if (mf) {
            std::stringstream mb;
            mb << mf.rdbuf();
            auto meta = JParser().parse(mb.str());
            auto* tg = meta.get("tags");
            if (tg && tg->type == JVal::ARR)
                for (auto& v : tg->arr) s->tags.push_back(v.str);
        }

        // Predict reads tW.size()/96 weight rows and one bias per row, so the
        // model counts as loaded only when the bias vector covers every row.
        const size_t n_tags = s->tW.size() / 96;
        s->loaded = n_tags > 0 && s->tb.size() >= n_tags;
        return s.release();
    } catch (...) {
        return nullptr;
    }
}
|
||||
|
||||
// Destroys a handle returned by ThincTagger_Create; NULL is accepted.
void ThincTagger_Destroy(ThincTaggerHandle h) {
    delete static_cast<TaggerState*>(h);
}
|
||||
|
||||
// Run the POS tagger on pre-tokenized text.
//
//   h:           handle from ThincTagger_Create.
//   tokens_json: JSON array of token strings.
//
// Each token's tag is the argmax of a linear classifier over its 96-float
// tok2vec vector. Returns a strdup()-allocated JSON array with one object per
// token: {"text":..., "tag":..., "index":<i>}.
// Returns "[]" for a NULL/unloaded handle, NULL or non-array input, an empty
// token list, inconsistent classifier weights, or any internal exception
// (exceptions must not cross the C ABI). Free with ThincTagger_FreeString.
char* ThincTagger_Predict(ThincTaggerHandle h, const char* tokens_json) {
    auto* s = static_cast<TaggerState*>(h);
    if (!s || !s->loaded || !tokens_json || s->tW.empty()) return strdup("[]");

    // Token text and tag names are arbitrary strings; escape them so the
    // output is always valid JSON.
    auto append_escaped = [](std::string& dst, const std::string& src) {
        static const char kHex[] = "0123456789abcdef";
        for (unsigned char c : src) {
            switch (c) {
                case '"':  dst += "\\\""; break;
                case '\\': dst += "\\\\"; break;
                case '\n': dst += "\\n";  break;
                case '\r': dst += "\\r";  break;
                case '\t': dst += "\\t";  break;
                case '\b': dst += "\\b";  break;
                case '\f': dst += "\\f";  break;
                default:
                    if (c < 0x20) {
                        dst += "\\u00";
                        dst += kHex[c >> 4];
                        dst += kHex[c & 0x0F];
                    } else {
                        dst += static_cast<char>(c);
                    }
            }
        }
    };

    // Strip morphologizer output to just POS (e.g. "Gender=Masc|Number=Sing|POS=NOUN" → "NOUN")
    // For non-morphologizer models the tag string is used as-is.
    auto pos_only = [](const std::string& t) -> std::string {
        const auto p = t.find("POS=");
        if (p == std::string::npos) return t;
        const auto first = p + 4;
        auto last = t.find_first_of("|;", first);
        if (last == std::string::npos) last = t.size();
        return t.substr(first, last - first);
    };

    try {
        const JVal j = JParser().parse(std::string(tokens_json));
        if (j.type != JVal::ARR) return strdup("[]");
        std::vector<std::string> tokens;
        tokens.reserve(j.arr.size());
        for (const auto& v : j.arr) tokens.push_back(v.str);
        const int n = (int)tokens.size(), n_tags = (int)s->tW.size() / 96;
        if (!n || !n_tags) return strdup("[]");
        // One bias per weight row is required; never read past the end of tb.
        if (s->tb.size() < (size_t)n_tags) return strdup("[]");

        // Run tok2vec to get 96-dim embeddings, then take the argmax score.
        std::vector<float> tokvecs((size_t)n * 96, 0);
        s->tok2vec.forward(tokens, tokvecs.data());

        std::vector<int> best_tags(n, 0);
        for (int i = 0; i < n; i++) {
            const float* x = tokvecs.data() + (size_t)i * 96;
            float best_sc = -1e30f;
            for (int t = 0; t < n_tags; t++) {
                const float* w = s->tW.data() + (size_t)t * 96;
                float sc = s->tb[t];
                for (int k = 0; k < 96; k++) sc += w[k] * x[k];
                if (sc > best_sc) { best_sc = sc; best_tags[i] = t; }
            }
        }

        std::string r = "[";
        for (int i = 0; i < n; i++) {
            if (i) r += ",";
            // A row without a name in meta.json is reported with an empty tag.
            const std::string tag =
                best_tags[i] < (int)s->tags.size() ? s->tags[best_tags[i]] : "";
            r += "{\"text\":\"";
            append_escaped(r, tokens[i]);
            r += "\",\"tag\":\"";
            append_escaped(r, pos_only(tag));
            r += "\",\"index\":" + std::to_string(i) + "}";
        }
        r += "]";
        return strdup(r.c_str());
    } catch (...) {
        return strdup("[]");
    }
}
|
||||
|
||||
// Releases a string returned by ThincTagger_Predict (allocated with strdup).
// Passing NULL is a no-op.
void ThincTagger_FreeString(char* p) {
    free(p);
}
|
||||
49
internal/cpp/thinc_parser.h
Normal file
49
internal/cpp/thinc_parser.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef THINC_PARSER_H
#define THINC_PARSER_H

// Standard headers are pulled in before the extern "C" block so that their
// declarations are never wrapped in C linkage by this header.
#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

// ---------------------------------------------------------------------------
// C API for spaCy dependency parser and POS tagger
// ---------------------------------------------------------------------------

// Opaque handles to a loaded parser / tagger.
typedef void* ThincParserHandle;
typedef void* ThincTaggerHandle;

// Parser: create/destroy
//   model_ner_dir:    path to the model's ner directory. The implementation
//                     strips a trailing "/ner" and loads the pipeline tok2vec
//                     weights from the sibling tok2vec directory.
//   model_parser_dir: path to parser model directory.
// Returns NULL on failure.
ThincParserHandle ThincParser_Create(const char* model_ner_dir, const char* model_parser_dir);
void ThincParser_Destroy(ThincParserHandle handle);

// Run dependency parser on pre-tokenized text.
//   tokens_json: JSON array of token strings, e.g. ["Apple", "was", "founded", ...]
// Returns JSON array of token annotations:
//   [{"text":"Apple","head":2,"dep":"nsubjpass","index":0}, ...]
// where head is the index of the head token (0-based), -1 for root.
// Caller must free with ThincParser_FreeString.
char* ThincParser_Predict(ThincParserHandle handle, const char* tokens_json);
void ThincParser_FreeString(char* ptr);

// Tagger: create/destroy
//   model_ner_dir:    path to the model's ner directory; used the same way as
//                     in ThincParser_Create to locate the tok2vec weights.
//   model_tagger_dir: path to tagger model directory
// Returns NULL on failure.
ThincTaggerHandle ThincTagger_Create(const char* model_ner_dir, const char* model_tagger_dir);
void ThincTagger_Destroy(ThincTaggerHandle handle);

// Run POS tagger on pre-tokenized text.
//   tokens_json: JSON array of token strings.
// Returns JSON array: [{"text":"Apple","tag":"NNP","index":0}, ...]
// Caller must free with ThincTagger_FreeString.
char* ThincTagger_Predict(ThincTaggerHandle handle, const char* tokens_json);
void ThincTagger_FreeString(char* ptr);

#ifdef __cplusplus
}
#endif

#endif // THINC_PARSER_H
|
||||
788
internal/ingestion/compilation/extractor/dep_relation.go
Normal file
788
internal/ingestion/compilation/extractor/dep_relation.go
Normal file
@@ -0,0 +1,788 @@
|
||||
// Go implementation of dependency-based relation extraction.
|
||||
// Direct port of Python DepRelationExtractor — semantica-aligned.
|
||||
// Operates on a dependency tree (heads + labels) independent of parser.
|
||||
|
||||
package extractor
|
||||
|
||||
import "strings"
|
||||
|
||||
// Verb lemmatization — multi-language
//
// verbLemma maps an inflected verb surface form to its base form. It is a
// closed lookup table consulted by lemma(): a form that is not listed here is
// returned unchanged. Keys are lower-case; extractFromRoot lower-cases the
// token text before the lookup.
var verbLemma = map[string]string{
	// English
	"founded": "found", "founding": "found",
	"works": "work", "working": "work",
	"based": "base", "basing": "base",
	"located": "locate", "locating": "locate",
	"situated": "situate", "situating": "situate",
	"acquired": "acquire", "acquiring": "acquire",
	"employed": "employ", "employing": "employ",
	"hired": "hire", "hiring": "hire",
	"born": "bear",
	"joined": "join", "joining": "join",
	"merged": "merge", "merging": "merge",
	"bought": "buy", "buying": "buy",
	"created": "create", "creating": "create",
	"established": "establish", "establishing": "establish",
	"started": "start", "starting": "start",
	"led": "lead", "leading": "lead",
	"managed": "manage", "managing": "manage",
	"headed": "head", "heading": "head",
	"ran": "run", "running": "run",
	"owned": "own", "owning": "own",
	"developed": "develop", "developing": "develop",
	"wrote": "write", "written": "write", "writing": "write",
	"published": "publish", "publishing": "publish",
	"invested": "invest", "investing": "invest",
	"partnered": "partner", "partnering": "partner",
	"collaborated": "collaborate", "collaborating": "collaborate",
	"sets": "set",
	// German
	"gegründet": "gründen", "gründete": "gründen",
	"arbeitet": "arbeiten", "arbeitete": "arbeiten",
	"befindet": "befinden",
	"liegt": "liegen", "lag": "liegen",
	"geboren": "gebären",
	"erworben": "erwerben", "erwarb": "erwerben",
	"gekauft": "kaufen", "kaufte": "kaufen",
	"übernommen": "übernehmen", "übernahm": "übernehmen",
	// French
	"fondé": "fonder", "fondée": "fonder",
	"créé": "créer", "créée": "créer",
	"travaille": "travailler",
	"employé": "employer", "employée": "employer",
	"situé": "situer", "située": "situer",
	"né": "naître", "née": "naître",
	"acquis": "acquérir",
	// Spanish + Portuguese (shared forms)
	"fundado": "fundar", "fundada": "fundar",
	"creado": "crear", "creada": "crear",
	"criado": "criar", "criada": "criar",
	"trabaja": "trabajar", "trabalha": "trabalhar",
	"ubicado": "ubicar", "ubicada": "ubicar",
	"situado": "situar", "situada": "situar",
	"localizado": "localizar", "localizada": "localizar",
	"sediado": "sediar", "sediada": "sediar",
	"nacido": "nacer", "nacida": "nacer",
	"nascido": "nascer", "nascida": "nascer",
}
|
||||
|
||||
func lemma(w string) string {
|
||||
if l, ok := verbLemma[w]; ok {
|
||||
return l
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// Verb+prep → relation type (multi-language)
// Keys: verbLemma+prep (or verbLemma alone for direct-object relations).
// Matches Python _VERB_RELATIONS exactly.
//
// lookupVerb builds the key as verb + "+" + prep when a preposition/marker is
// supplied and uses the bare verb otherwise; the value is the relation
// predicate emitted by extractFromRoot.
var depVerbRelations = map[string]string{
	// English
	"found+by": "founded_by", "co-found+by": "founded_by",
	"establish+by": "founded_by", "create+by": "founded_by",
	"set+up": "founded_by", "start+by": "founded_by",
	"work+for": "works_for", "employ+by": "works_for",
	"hire+by": "works_for", "join": "works_for",
	"lead+by": "works_for", "manage+by": "works_for",
	"head+by": "works_for", "run+by": "works_for",
	"own+by": "owns", "develop+by": "develops",
	"write+by": "wrote", "publish+by": "published",
	"invest+in": "invests_in", "partner+with": "partners_with",
	"collaborate+with": "collaborates_with",
	"merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of",
	"base+in": "located_in", "locate+in": "located_in",
	"situate+in": "located_in", "headquarter+in": "located_in",
	"bear+in": "born_in", "bear+on": "born_in",
	"acquire+by": "acquired", "buy+by": "acquired",
	// German (de)
	"gründen+von": "founded_by", "errichten+von": "founded_by",
	"arbeiten+für": "works_for", "beschäftigen+bei": "works_for",
	"anstellen+bei": "works_for",
	"sich+befinden": "located_in", "liegen+in": "located_in",
	"sitzen+in": "located_in", "gebären+in": "born_in",
	"gebären+am": "born_in",
	"erwerben+durch": "acquired", "kaufen+durch": "acquired",
	"übernehmen+durch": "acquired",
	// French (fr)
	"fonder+par": "founded_by", "créer+par": "founded_by",
	"établir+par": "founded_by",
	"travailler+pour": "works_for", "employer+par": "works_for",
	"embaucher+par": "works_for",
	"situer+à": "located_in", "baser+à": "located_in",
	"implanter+à": "located_in",
	"naître+à": "born_in",
	"acquérir+par": "acquired", "racheter+par": "acquired",
	// Spanish (es)
	"fundar+por": "founded_by", "crear+por": "founded_by",
	"establecer+por": "founded_by",
	"trabajar+para": "works_for", "emplear+por": "works_for",
	"contratar+por": "works_for",
	"ubicar+en": "located_in", "situar+en": "located_in",
	"tener+sede": "located_in",
	"nacer+en": "born_in",
	"adquirir+por": "acquired", "comprar+por": "acquired",
	// Portuguese (pt)
	"criar+por": "founded_by",
	"estabelecer+por": "founded_by",
	"trabalhar+para": "works_for", "empregar+por": "works_for",
	"localizar+em": "located_in", "situar+em": "located_in",
	"sediar+em": "located_in",
	"nascer+em": "born_in",
	// Chinese (zh)
	"创立+由": "founded_by", "创建+由": "founded_by",
	"成立+由": "founded_by", "创办+由": "founded_by",
	"设立+由": "founded_by",
	"任职+于": "works_for", "就职+于": "works_for",
	"工作+在": "works_for", "位于+在": "located_in",
	"坐落+在": "located_in", "总部设+在": "located_in",
	"出生+在": "born_in", "出生+于": "born_in",
	"收购+由": "acquired", "并购+由": "acquired",
	// Japanese (ja)
	"設立+によって": "founded_by", "創立+によって": "founded_by",
	"勤務+で": "works_for", "在籍+で": "works_for",
	"位置+に": "located_in", "所在+に": "located_in",
	"本社+を": "located_in",
	"出生+に": "born_in",
	"買収+によって": "acquired",
}
|
||||
|
||||
// Copula title patterns — X is [title] of Y → typed relation
//
// Keys are lower-case title keywords; extractCopula tests them as substrings
// of the lower-cased attribute token and emits one relation per listed
// predicate for the keyword it selects.
var depCopulaTitles = map[string][]string{
	"ceo": {"ceo_of", "works_for"}, "cto": {"works_for"},
	"cfo": {"works_for"}, "coo": {"works_for"},
	"vp": {"works_for"}, "director": {"works_for"},
	"manager": {"works_for"}, "engineer": {"works_for"},
	"employee": {"works_for"},
	"founder": {"founded_by"}, "co-founder": {"founded_by"},
}
|
||||
|
||||
// Multi-hop inference rules: A rel1 B + B rel2 C ⇒ A rel3 C
//
// The outer key is rel1 (predicate of the first hop), the inner key is rel2
// (predicate of the second hop) and the value is rel3, the predicate of the
// relation inferMultiHop derives between A and C.
var multiHopRules = map[string]map[string]string{
	"ceo_of":     {"is_subsidiary_of": "works_for", "located_in": "works_for"},
	"works_for":  {"is_subsidiary_of": "works_for"},
	"founded_by": {"is_subsidiary_of": "founded_by"},
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Language-specific dependency role mappings
|
||||
// ---------------------------------------------------------------------------
|
||||
// Matches Python _LANG_DEP_RULES exactly for all 7 languages.
|
||||
|
||||
// roleSpec describes how one semantic role is located among a verb's
// dependents. getByRole interprets it in this order: a non-empty caseMarker
// selects the marker rule, otherwise a non-empty childDep selects the
// two-level (dep → childDep) rule, otherwise dep alone must match.
type roleSpec struct {
	dep        string // main dependency label (e.g. "nsubj", "obl:agent")
	childDep   string // child dep for compound rules (e.g. "pobj" for "agent"→"pobj")
	caseMarker string // optional case marker text (zh:"由", ja:"によって")
}
|
||||
|
||||
// langRolesMap maps a language code to its role table: role name
// ("subj", "pass_subj", "agent", "dobj", "prep_obj", "root_verb_child")
// → the dependency labels that realise that role in the language's parser
// label scheme. A role missing from a language's table yields no matches in
// getByRole; an unknown language falls back to the "en" table there.
var langRolesMap = map[string]map[string]roleSpec{
	"en": {
		"pass_subj": {dep: "nsubjpass"},
		"subj":      {dep: "nsubj"},
		"agent":     {dep: "agent", childDep: "pobj"},
		"dobj":      {dep: "dobj"},
		"prep_obj":  {dep: "prep", childDep: "pobj"},
	},
	"de": {
		"subj":     {dep: "sb"},
		"agent":    {dep: "sbp", childDep: "nk"},
		"prep_obj": {dep: "mo", childDep: "nk"},
		// German ROOT is aux verb, real verb has dep "oc"
		"root_verb_child": {dep: "oc"},
	},
	"fr": {
		"pass_subj": {dep: "nsubj:pass"},
		"subj":      {dep: "nsubj"},
		"agent":     {dep: "obl:agent"},
		"dobj":      {dep: "obj"},
		"prep_obj":  {dep: "case", childDep: "obl"},
	},
	"es": {
		"subj":     {dep: "nsubj"},
		"agent":    {dep: "obj"},
		"prep_obj": {dep: "case", childDep: "obl"},
	},
	"pt": {
		"pass_subj": {dep: "nsubj:pass"},
		"subj":      {dep: "nsubj"},
		"agent":     {dep: "obl:agent"},
		"dobj":      {dep: "obj"},
		"prep_obj":  {dep: "case", childDep: "obl"},
	},
	"zh": {
		"subj":     {dep: "nsubj"},
		"agent":    {dep: "nmod:prep", caseMarker: "由"},
		"dobj":     {dep: "dobj"},
		"prep_obj": {dep: "case", childDep: "nmod"},
	},
	"ja": {
		"subj":     {dep: "nsubj"},
		"agent":    {dep: "obl", caseMarker: "によって"},
		"dobj":     {dep: "dobj"},
		"prep_obj": {dep: "case", childDep: "obl"},
	},
}
|
||||
|
||||
// Copula dependency labels per language (attr_deps, prep_deps, obj_deps).
// Matches Python _extract_copula label sets.
//
// extractCopula walks root → attrDeps child → prepDeps child → objDeps child;
// languages not listed here fall back to the "en" entry.
var copulaDeps = map[string]struct {
	attrDeps []string // labels of the title/attribute token under the copula root
	prepDeps []string // labels of the preposition/case token under the attribute
	objDeps  []string // labels of the object token under the preposition
}{
	"en": {attrDeps: []string{"attr"}, prepDeps: []string{"prep"}, objDeps: []string{"pobj"}},
	"de": {attrDeps: []string{"pred"}, prepDeps: []string{"mo"}, objDeps: []string{"nk"}},
	"fr": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
	"es": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
	"pt": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
	"zh": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"nmod"}},
	"ja": {attrDeps: []string{"attr"}, prepDeps: []string{"case"}, objDeps: []string{"obl"}},
}
|
||||
|
||||
// be-verb surface forms across the supported languages (used for copula
// detection). Duplicate strings between languages are normalised into a
// single entry. Keys are lower-case; isBeVerb lower-cases the token text
// before the lookup. Only English, German, French, Spanish and Portuguese
// forms are listed below.
var beVerbs = map[string]bool{
	// English
	"is": true, "are": true, "was": true, "were": true, "be": true, "been": true, "being": true,
	// German
	"ist": true, "sind": true, "bin": true, "bist": true, "seid": true,
	"war": true, "waren": true, "gewesen": true, "sein": true,
	// French
	"est": true, "suis": true, "sommes": true, "êtes": true, "sont": true,
	"était": true, "étant": true, "être": true,
	// Spanish + Portuguese (shared forms; French "es" is not repeated in the
	// French group because a map literal cannot hold the same key twice)
	"es": true, "é": true, "son": true, "são": true,
	"está": true, "están": true, "estão": true,
	"era": true, "eran": true, "eram": true,
	"ser": true, "sido": true, "siendo": true, "sendo": true,
}
|
||||
|
||||
// DepToken holds token dependency info from the C++ parser.
//
// The functions in this file use Index as a position into the token slice
// and compare Head against other tokens' Index values, so Index is expected
// to equal the token's position in the slice.
type DepToken struct {
	Text  string `json:"text"`          // token surface text
	Head  int    `json:"head"`          // Index of the head token
	Dep   string `json:"dep"`           // dependency label relative to Head
	Index int    `json:"index"`         // position of this token in the sentence
	POS   string `json:"pos,omitempty"` // part-of-speech tag; not read in this file
}
|
||||
|
||||
// roleResult is the return type for getByRole: one entity that fills the
// requested role, plus the governing preposition when the role is prep_obj.
type roleResult struct {
	entity Entity
	prep   string // prep lemma for prep_obj role; empty otherwise
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// DepExtractRelations extracts typed relations from a dependency parse tree.
|
||||
// lang is the language code (en/zh/de/fr/es/pt/ja).
|
||||
// maxDistance is the max character distance for co-occurrence (0 = use default 100).
|
||||
func DepExtractRelations(text string, tokens []DepToken, entities []Entity, lang string, maxDistance int) []Relation {
|
||||
entityMap := buildEntityMapMulti(entities)
|
||||
var relations []Relation
|
||||
|
||||
// Detect German-style "oc" handling
|
||||
_, hasRootVerbChild := langRolesMap[lang]["root_verb_child"]
|
||||
|
||||
for _, tok := range tokens {
|
||||
if tok.Head != tok.Index {
|
||||
continue
|
||||
}
|
||||
// German: ROOT is aux verb; real verb is an "oc" child
|
||||
if hasRootVerbChild {
|
||||
for _, c := range childrenOf(tok.Index, tokens) {
|
||||
if c.Dep == langRolesMap[lang]["root_verb_child"].dep {
|
||||
if hasNegation(c.Index, tokens) {
|
||||
continue
|
||||
}
|
||||
rels := extractFromRoot(text, c.Index, tokens, entityMap, lang)
|
||||
relations = append(relations, rels...)
|
||||
if isBeVerb(tok) {
|
||||
rels := extractCopula(text, c.Index, tokens, entityMap, lang)
|
||||
relations = append(relations, rels...)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Standard languages: ROOT = main verb
|
||||
if hasNegation(tok.Index, tokens) {
|
||||
continue
|
||||
}
|
||||
rels := extractFromRoot(text, tok.Index, tokens, entityMap, lang)
|
||||
relations = append(relations, rels...)
|
||||
if isBeVerb(tok) {
|
||||
rels := extractCopula(text, tok.Index, tokens, entityMap, lang)
|
||||
relations = append(relations, rels...)
|
||||
}
|
||||
}
|
||||
|
||||
// Co-occurrence (always, matching Python DepRelationExtractor.extract())
|
||||
if maxDistance <= 0 {
|
||||
maxDistance = 100
|
||||
}
|
||||
relations = append(relations, extractCooccurrence(text, entities, maxDistance)...)
|
||||
|
||||
relations = inferMultiHop(relations)
|
||||
relations = dedupRelations(relations)
|
||||
return relations
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Negation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func hasNegation(idx int, tokens []DepToken) bool {
|
||||
for _, t := range tokens {
|
||||
if t.Head == idx && (t.Dep == "neg" || t.Dep == "advmod:neg") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isBeVerb(tok DepToken) bool {
|
||||
return beVerbs[strings.ToLower(tok.Text)]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-hop inference
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func inferMultiHop(rels []Relation) []Relation {
|
||||
bySubj := make(map[string][]Relation)
|
||||
for _, r := range rels {
|
||||
if r.Predicate == "related_to" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(r.Subject.Text)
|
||||
bySubj[key] = append(bySubj[key], r)
|
||||
}
|
||||
|
||||
var inferred []Relation
|
||||
for _, r := range rels {
|
||||
if r.Predicate == "related_to" {
|
||||
continue
|
||||
}
|
||||
objKey := strings.ToLower(r.Object.Text)
|
||||
if chain, ok := bySubj[objKey]; ok {
|
||||
for _, r2 := range chain {
|
||||
if hopRules, ok := multiHopRules[r.Predicate]; ok {
|
||||
if inferredPred, ok := hopRules[r2.Predicate]; ok {
|
||||
conf := r.Confidence
|
||||
if r2.Confidence < conf {
|
||||
conf = r2.Confidence
|
||||
}
|
||||
conf *= 0.9
|
||||
r := Relation{
|
||||
Subject: r.Subject,
|
||||
Predicate: inferredPred,
|
||||
Object: r2.Object,
|
||||
Confidence: conf,
|
||||
Metadata: map[string]interface{}{
|
||||
"method": "multi_hop",
|
||||
"via": r.Predicate + "→" + r2.Predicate,
|
||||
},
|
||||
}
|
||||
inferred = append(inferred, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return append(rels, inferred...)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Entity map
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func buildEntityMapMulti(entities []Entity) map[string][]Entity {
|
||||
m := make(map[string][]Entity)
|
||||
for _, e := range entities {
|
||||
key := strings.ToLower(e.Text)
|
||||
m[key] = append(m[key], e)
|
||||
cleaned := strings.TrimRight(e.Text, ".,;:!?")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
if cleaned != e.Text {
|
||||
ckey := strings.ToLower(cleaned)
|
||||
m[ckey] = append(m[ckey], e)
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func findBestEntity(key string, entityMap map[string][]Entity) *Entity {
|
||||
entries := entityMap[strings.ToLower(strings.TrimSpace(key))]
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &entries[0]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Language-aware role lookup (replaces old getChildEntity/getAgentPobj/getPrepObjs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// getByRole returns matching (entity, prep) pairs for a semantic role.
|
||||
// Matches Python DepRelationExtractor._get_by_role.
|
||||
func getByRole(lang string, role string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity) []roleResult {
|
||||
roles, ok := langRolesMap[lang]
|
||||
if !ok {
|
||||
roles = langRolesMap["en"]
|
||||
}
|
||||
spec, ok := roles[role]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []roleResult
|
||||
for _, c := range childrenOf(rootIdx, tokens) {
|
||||
if spec.caseMarker != "" {
|
||||
// zh/ja agent: check for case marker in subtree
|
||||
if c.Dep == spec.dep && hasCaseMarkerInSubtree(c.Index, tokens, spec.caseMarker) {
|
||||
if ent := findEntityInSubtree(c.Index, tokens, entityMap); ent != nil {
|
||||
results = append(results, roleResult{entity: *ent})
|
||||
}
|
||||
}
|
||||
} else if spec.childDep != "" {
|
||||
// Compound rule: parent dep + child dep
|
||||
if c.Dep == spec.dep {
|
||||
prepLemma := strings.ToLower(c.Text)
|
||||
for _, gc := range childrenOf(c.Index, tokens) {
|
||||
if gc.Dep == spec.childDep {
|
||||
if ent := findEntityInSubtree(gc.Index, tokens, entityMap); ent != nil {
|
||||
if role == "prep_obj" {
|
||||
results = append(results, roleResult{entity: *ent, prep: prepLemma})
|
||||
} else {
|
||||
results = append(results, roleResult{entity: *ent})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Simple rule: single dep label
|
||||
if c.Dep == spec.dep {
|
||||
if ent := findEntityInSubtree(c.Index, tokens, entityMap); ent != nil {
|
||||
results = append(results, roleResult{entity: *ent})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// hasCaseMarkerInSubtree checks if any token in the subtree of idx contains marker text.
|
||||
func hasCaseMarkerInSubtree(idx int, tokens []DepToken, marker string) bool {
|
||||
visited := map[int]bool{}
|
||||
subtree := collectSubtree(idx, tokens, visited)
|
||||
for _, t := range subtree {
|
||||
if t.Text == marker {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func collectSubtree(idx int, tokens []DepToken, visited map[int]bool) []DepToken {
|
||||
if visited[idx] {
|
||||
return nil
|
||||
}
|
||||
visited[idx] = true
|
||||
result := []DepToken{tokens[idx]}
|
||||
for _, c := range childrenOf(idx, tokens) {
|
||||
result = append(result, collectSubtree(c.Index, tokens, visited)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractFromRoot — passive/active/preposition patterns (language-aware)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func extractFromRoot(text string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity, lang string) []Relation {
|
||||
var relations []Relation
|
||||
root := tokens[rootIdx]
|
||||
verbLemma := lemma(strings.ToLower(root.Text))
|
||||
|
||||
// Get roles using language-aware mapping
|
||||
first := func(lst []roleResult) *Entity {
|
||||
if len(lst) > 0 {
|
||||
return &lst[0].entity
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
nsubj := first(getByRole(lang, "subj", rootIdx, tokens, entityMap))
|
||||
nsubjpass := first(getByRole(lang, "pass_subj", rootIdx, tokens, entityMap))
|
||||
dobj := first(getByRole(lang, "dobj", rootIdx, tokens, entityMap))
|
||||
agentList := getByRole(lang, "agent", rootIdx, tokens, entityMap)
|
||||
var agentEntity *Entity
|
||||
if len(agentList) > 0 {
|
||||
agentEntity = &agentList[0].entity
|
||||
}
|
||||
prepList := getByRole(lang, "prep_obj", rootIdx, tokens, entityMap)
|
||||
hasExplicitAgent := agentEntity != nil
|
||||
|
||||
// Detect passive:
|
||||
// - explicit pass_subj (en, fr, pt)
|
||||
// - subj + agent (zh/ja with agent marker, es-style)
|
||||
isPassiveCandidate := hasExplicitAgent
|
||||
|
||||
effectiveNsubjpass := nsubjpass
|
||||
effectiveNsubj := nsubj
|
||||
if isPassiveCandidate {
|
||||
if nsubjpass != nil {
|
||||
effectiveNsubjpass = nsubjpass
|
||||
} else if nsubj != nil {
|
||||
effectiveNsubjpass = nsubj
|
||||
effectiveNsubj = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Passive: X was founded/acquired by Y
|
||||
if effectiveNsubjpass != nil && agentEntity != nil {
|
||||
candidates := []string{"by", "von", "par", "por", "durch", "由", "によって"}
|
||||
for _, candidate := range candidates {
|
||||
if relType := lookupVerb(verbLemma, candidate); relType != "" {
|
||||
var subj, obj Entity
|
||||
if relType == "founded_by" || relType == "acquired" {
|
||||
subj, obj = *effectiveNsubjpass, *agentEntity
|
||||
} else {
|
||||
subj, obj = *agentEntity, *effectiveNsubjpass
|
||||
}
|
||||
r := makeRelation(subj, relType, obj, 0.90)
|
||||
r.Metadata = map[string]interface{}{"method": "passive", "verb": verbLemma}
|
||||
relations = append(relations, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Active: X VERB Y or X VERB prep Y
|
||||
if effectiveNsubj != nil {
|
||||
if dobj != nil {
|
||||
if relType := lookupVerb(verbLemma, ""); relType != "" {
|
||||
r := makeRelation(*effectiveNsubj, relType, *dobj, 0.85)
|
||||
r.Metadata = map[string]interface{}{"method": "active", "verb": verbLemma}
|
||||
relations = append(relations, r)
|
||||
}
|
||||
}
|
||||
for _, pe := range prepList {
|
||||
if relType := lookupVerb(verbLemma, pe.prep); relType != "" {
|
||||
r := makeRelation(*effectiveNsubj, relType, pe.entity, 0.85)
|
||||
r.Metadata = map[string]interface{}{"method": "active_prep", "verb": verbLemma, "prep": pe.prep}
|
||||
relations = append(relations, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Passive with prep ("is based in")
|
||||
if effectiveNsubjpass != nil && len(prepList) > 0 && agentEntity == nil {
|
||||
for _, pe := range prepList {
|
||||
relType := lookupVerb(verbLemma, pe.prep)
|
||||
if relType == "" {
|
||||
relType = lookupVerb("be+"+verbLemma, pe.prep)
|
||||
}
|
||||
if relType != "" {
|
||||
r := makeRelation(*effectiveNsubjpass, relType, pe.entity, 0.85)
|
||||
r.Metadata = map[string]interface{}{"method": "passive_prep", "verb": verbLemma, "prep": pe.prep}
|
||||
relations = append(relations, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return relations
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Copula extraction (language-aware)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func extractCopula(text string, rootIdx int, tokens []DepToken, entityMap map[string][]Entity, lang string) []Relation {
|
||||
subjList := getByRole(lang, "subj", rootIdx, tokens, entityMap)
|
||||
if len(subjList) == 0 {
|
||||
return nil
|
||||
}
|
||||
subj := subjList[0].entity
|
||||
|
||||
cd, ok := copulaDeps[lang]
|
||||
if !ok {
|
||||
cd = copulaDeps["en"]
|
||||
}
|
||||
|
||||
var titleLemma string
|
||||
var prepObj *Entity
|
||||
|
||||
for _, c := range childrenOf(rootIdx, tokens) {
|
||||
if !containsDep(c.Dep, cd.attrDeps) {
|
||||
continue
|
||||
}
|
||||
for _, cc := range childrenOf(c.Index, tokens) {
|
||||
if !containsDep(cc.Dep, cd.prepDeps) {
|
||||
continue
|
||||
}
|
||||
for _, gc := range childrenOf(cc.Index, tokens) {
|
||||
if containsDep(gc.Dep, cd.objDeps) {
|
||||
if ent := findEntityInSubtree(gc.Index, tokens, entityMap); ent != nil {
|
||||
prepObj = ent
|
||||
titleLemma = strings.ToLower(c.Text)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if titleLemma == "" || prepObj == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var relations []Relation
|
||||
for keyword, relTypes := range depCopulaTitles {
|
||||
if strings.Contains(titleLemma, keyword) {
|
||||
for _, rt := range relTypes {
|
||||
r := makeRelation(subj, rt, *prepObj, 0.88)
|
||||
r.Metadata = map[string]interface{}{"method": "copula", "title": titleLemma}
|
||||
relations = append(relations, r)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return relations
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func makeRelation(subj Entity, pred string, obj Entity, conf float64) Relation {
|
||||
return Relation{
|
||||
Subject: subj,
|
||||
Predicate: pred,
|
||||
Object: obj,
|
||||
Confidence: conf,
|
||||
Metadata: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func lookupVerb(verb, prep string) string {
|
||||
if prep != "" {
|
||||
if v, ok := depVerbRelations[verb+"+"+prep]; ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return depVerbRelations[verb]
|
||||
}
|
||||
|
||||
// containsDep reports whether dep equals one of the labels in deps.
func containsDep(dep string, deps []string) bool {
	for i := 0; i < len(deps); i++ {
		if deps[i] == dep {
			return true
		}
	}
	return false
}
|
||||
|
||||
func childrenOf(idx int, tokens []DepToken) []DepToken {
|
||||
var kids []DepToken
|
||||
for _, t := range tokens {
|
||||
if t.Head == idx && t.Index != idx {
|
||||
kids = append(kids, t)
|
||||
}
|
||||
}
|
||||
return kids
|
||||
}
|
||||
|
||||
func findEntityInSubtree(idx int, tokens []DepToken, emap map[string][]Entity) *Entity {
|
||||
words := collectWords(idx, tokens, map[int]bool{})
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
text := strings.Join(words, " ")
|
||||
key := strings.ToLower(strings.TrimSpace(text))
|
||||
if ent := findBestEntity(key, emap); ent != nil {
|
||||
return ent
|
||||
}
|
||||
// For CJK (no spaces), also try joining without spaces
|
||||
noSpace := strings.ToLower(strings.TrimSpace(strings.Join(words, "")))
|
||||
if noSpace != key {
|
||||
if ent := findBestEntity(noSpace, emap); ent != nil {
|
||||
return ent
|
||||
}
|
||||
}
|
||||
for _, sep := range []string{" and ", " or ", ", "} {
|
||||
if strings.Contains(key, sep) {
|
||||
candidate := strings.TrimSpace(strings.SplitN(key, sep, 2)[0])
|
||||
if ent := findBestEntity(candidate, emap); ent != nil {
|
||||
return ent
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fuzzy: try substring match
|
||||
for ek, ev := range emap {
|
||||
if strings.Contains(ek, key) || strings.Contains(key, ek) {
|
||||
e := ev[0]
|
||||
return &e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func collectWords(idx int, tokens []DepToken, visited map[int]bool) []string {
|
||||
if visited[idx] {
|
||||
return nil
|
||||
}
|
||||
visited[idx] = true
|
||||
tok := tokens[idx]
|
||||
if tok.Dep == "prep" || tok.Dep == "punct" || tok.Dep == "det" ||
|
||||
tok.Dep == "aux" || tok.Dep == "auxpass" || tok.Dep == "cc" ||
|
||||
tok.Dep == "conj" || tok.Dep == "neg" {
|
||||
return nil
|
||||
}
|
||||
var childWords []string
|
||||
kids := childrenOf(idx, tokens)
|
||||
sortByIndex(kids)
|
||||
for _, c := range kids {
|
||||
childWords = append(childWords, collectWords(c.Index, tokens, visited)...)
|
||||
}
|
||||
hasChildBefore := false
|
||||
for _, k := range kids {
|
||||
if k.Index < idx {
|
||||
hasChildBefore = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasChildBefore {
|
||||
return append(childWords, tok.Text)
|
||||
}
|
||||
return append([]string{tok.Text}, childWords...)
|
||||
}
|
||||
|
||||
func sortByIndex(ts []DepToken) {
|
||||
for i := 0; i < len(ts); i++ {
|
||||
for j := i + 1; j < len(ts); j++ {
|
||||
if ts[j].Index < ts[i].Index {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dedupRelations(rels []Relation) []Relation {
|
||||
seen := make(map[string]bool)
|
||||
var result []Relation
|
||||
for _, r := range rels {
|
||||
key := strings.ToLower(r.Subject.Text + "|" + r.Predicate + "|" + r.Object.Text)
|
||||
rev := strings.ToLower(r.Object.Text + "|" + r.Predicate + "|" + r.Subject.Text)
|
||||
if seen[key] || seen[rev] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
}
|
||||
414
internal/ingestion/compilation/extractor/ner.go
Normal file
414
internal/ingestion/compilation/extractor/ner.go
Normal file
@@ -0,0 +1,414 @@
|
||||
//
|
||||
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package extractor provides NER and relation extraction for the ingestion
|
||||
// pipeline. It wraps the C++ ThincNER engine via cgo and supplements it with
|
||||
// pure-Go regex-based relation extraction.
|
||||
//
|
||||
// The architecture mirrors the Python rag/graphrag/ner package so that both
|
||||
// code paths produce identical output (verified by test).
|
||||
package extractor
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++20 -I${SRCDIR}/../../..
|
||||
// #cgo linux LDFLAGS: ${SRCDIR}/../../../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread -lpcre2-8
|
||||
// #cgo darwin LDFLAGS: ${SRCDIR}/../../../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread -lpcre2-8
|
||||
//
|
||||
// #include <stdlib.h>
|
||||
// #include "../../../cpp/rag_analyzer_c_api.h"
|
||||
// #include "../../../cpp/thinc_parser.h"
|
||||
import "C"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Entity represents an extracted named entity.
type Entity struct {
	Text       string                 `json:"text"`               // entity surface text
	Label      string                 `json:"label"`              // NER label (e.g. the spaCy labels keyed in spacyToAppType)
	StartChar  int                    `json:"start_char"`         // presumably the start offset in the source text — TODO confirm unit (bytes vs runes)
	EndChar    int                    `json:"end_char"`           // presumably the end offset in the source text — TODO confirm unit and inclusivity
	Confidence float64                `json:"confidence"`         // score attached to the entity
	AppType    string                 `json:"app_type,omitempty"` // application-level type (see spacyToAppType values); omitted from JSON when empty
	Metadata   map[string]interface{} `json:"metadata,omitempty"` // free-form extra data; omitted from JSON when empty
}
|
||||
|
||||
// Relation represents a typed relation between two entities.
|
||||
type Relation struct {
|
||||
Subject Entity `json:"subject"`
|
||||
Predicate string `json:"predicate"`
|
||||
Object Entity `json:"object"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
Context string `json:"context,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ExtractionResult holds the output of a full extraction pass.
|
||||
type ExtractionResult struct {
|
||||
Entities []Entity `json:"entities"`
|
||||
Relations []Relation `json:"relations"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Extractor runs NER and relation extraction for one language.
type Extractor struct {
	// mu is held around the ThincNER_Predict call made by predictors that
	// this extractor created.
	mu sync.Mutex

	// Lang is the language code (en/zh/de/fr/es/pt/ja).
	Lang string

	// ConfidenceThreshold is the minimum confidence an entity needs to be
	// kept; the default 0.0 keeps everything.
	ConfidenceThreshold float64

	// IncludeTokens attaches per-token info (token text and index) to the
	// ExtractionResult metadata.
	IncludeTokens bool

	// MaxDistance is the maximum character distance for co-occurrence
	// relations (default 100).
	MaxDistance int
}
|
||||
|
||||
// spacyToAppType translates a spaCy NER label into the application's entity
// type. Entries are grouped by the type they map to.
var spacyToAppType = map[string]string{
	// people and organizations
	"PERSON": "person",
	"ORG":    "organization",

	// places
	"GPE": "geo",
	"LOC": "geo",
	"FAC": "geo",

	// events and temporal expressions
	"EVENT": "event",
	"DATE":  "event",
	"TIME":  "event",

	// everything else is a generic category
	"PRODUCT":     "category",
	"MONEY":       "category",
	"QUANTITY":    "category",
	"PERCENT":     "category",
	"LAW":         "category",
	"NORP":        "category",
	"LANGUAGE":    "category",
	"WORK_OF_ART": "category",
}
|
||||
|
||||
// skipLabels lists the NER labels that are dropped from the entity output.
var skipLabels = map[string]bool{
	"CARDINAL": true,
	"ORDINAL":  true,
}
|
||||
|
||||
// ModelPredictor runs NER for one loaded model: it takes the tokens as a JSON
// array string and returns the entities as a JSON string. The C handle lives
// inside the closure, so callers never handle an unsafe.Pointer.
type ModelPredictor func(tokensJSON string) (string, error)
|
||||
|
||||
var (
|
||||
modelCacheMu sync.Mutex
|
||||
modelCache = map[string]ModelPredictor{}
|
||||
)
|
||||
|
||||
// langModel gives the spaCy model name for every supported language code.
var langModel = map[string]string{
	// *_core_web_sm models
	"en": "en_core_web_sm",
	"zh": "zh_core_web_sm",

	// *_core_news_sm models
	"de": "de_core_news_sm",
	"es": "es_core_news_sm",
	"fr": "fr_core_news_sm",
	"ja": "ja_core_news_sm",
	"pt": "pt_core_news_sm",
}
|
||||
|
||||
// langFallback names the language whose regex relation patterns are used for
// a language that has no patterns of its own.
var langFallback = map[string]string{
	// fall back to the English patterns
	"de": "en",
	"es": "en",
	"fr": "en",
	"pt": "en",

	// falls back to the Chinese patterns
	"ja": "zh",
}
|
||||
|
||||
// NewExtractor creates a new extractor.
|
||||
// Supported langs: en, zh, de, fr, es, pt, ja.
|
||||
func NewExtractor(lang string) *Extractor {
|
||||
if lang == "" {
|
||||
lang = "en"
|
||||
}
|
||||
if _, ok := langModel[lang]; !ok {
|
||||
lang = "en"
|
||||
}
|
||||
return &Extractor{
|
||||
Lang: lang,
|
||||
ConfidenceThreshold: 0.0, // include all by default
|
||||
MaxDistance: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract runs NER and optionally relation extraction (dep-based via C++ parser, or regex fallback).
|
||||
func (e *Extractor) Extract(text string, extractRelations bool) (*ExtractionResult, error) {
|
||||
entities, err := e.ExtractEntities(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Collect token info if requested (before entity dedup changes offsets)
|
||||
var tokensMeta []map[string]interface{}
|
||||
if e.IncludeTokens {
|
||||
tokensJSON := tokenizeText(text, e.Lang)
|
||||
if tokensJSON != "" {
|
||||
var rawTokens []string
|
||||
if err := json.Unmarshal([]byte(tokensJSON), &rawTokens); err == nil {
|
||||
for i, t := range rawTokens {
|
||||
tokensMeta = append(tokensMeta, map[string]interface{}{
|
||||
"text": t,
|
||||
"index": i,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := &ExtractionResult{
|
||||
Entities: entities,
|
||||
Language: e.Lang,
|
||||
Metadata: map[string]interface{}{
|
||||
"n_entities": len(entities),
|
||||
"model": langModel[e.Lang],
|
||||
},
|
||||
}
|
||||
if len(tokensMeta) > 0 {
|
||||
result.Metadata["n_tokens"] = len(tokensMeta)
|
||||
result.Metadata["tokens"] = tokensMeta
|
||||
}
|
||||
|
||||
if extractRelations && len(entities) >= 2 {
|
||||
relations := e.extractRelations(text, entities)
|
||||
result.Relations = relations
|
||||
nTyped := 0
|
||||
for _, r := range relations {
|
||||
if r.Predicate != "related_to" {
|
||||
nTyped++
|
||||
}
|
||||
}
|
||||
result.Metadata["n_relations"] = nTyped
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// extractRelations attempts dep-based extraction via C++ parser; falls back to regex.
|
||||
func (e *Extractor) extractRelations(text string, entities []Entity) []Relation {
|
||||
relLang := e.Lang
|
||||
if fb, ok := langFallback[e.Lang]; ok {
|
||||
relLang = fb
|
||||
}
|
||||
// Try dep-based extraction via C++ parser — uses e.Lang (not relLang) so
|
||||
// de/fr/es/pt/ja apply their language-specific DepExtractRelations rules.
|
||||
tokensJSON := tokenizeText(text, e.Lang)
|
||||
if tokensJSON == "" {
|
||||
return extractRelationsWithOpts(text, entities, relLang, e.MaxDistance)
|
||||
}
|
||||
var tokens []string
|
||||
if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil || len(tokens) == 0 {
|
||||
return extractRelationsWithOpts(text, entities, relLang, e.MaxDistance)
|
||||
}
|
||||
modelDir := e.findModelDir()
|
||||
nerDir := modelDir + "/ner"
|
||||
parserDir := modelDir + "/parser"
|
||||
if deps, err := ParseTokensWithParser(nerDir, parserDir, tokens); err == nil && len(deps) > 0 {
|
||||
depTokens := make([]DepToken, len(deps))
|
||||
for i, d := range deps {
|
||||
depTokens[i] = DepToken{Text: d.Text, Head: d.Head, Dep: d.Dep, Index: d.Index}
|
||||
}
|
||||
if rels := DepExtractRelations(text, depTokens, entities, e.Lang, e.MaxDistance); len(rels) > 0 {
|
||||
return rels
|
||||
}
|
||||
}
|
||||
// Fallback: regex-based extraction
|
||||
return extractRelationsWithOpts(text, entities, relLang, e.MaxDistance)
|
||||
}
|
||||
|
||||
func (e *Extractor) getPredictor(modelDir string) ModelPredictor {
|
||||
modelCacheMu.Lock()
|
||||
defer modelCacheMu.Unlock()
|
||||
if p, ok := modelCache[modelDir]; ok {
|
||||
return p
|
||||
}
|
||||
cModelDir := C.CString(modelDir + "/ner")
|
||||
cModelVocab := C.CString(modelDir + "/vocab")
|
||||
handle := C.ThincNER_Create(cModelDir, cModelVocab)
|
||||
C.free(unsafe.Pointer(cModelDir))
|
||||
C.free(unsafe.Pointer(cModelVocab))
|
||||
// Don't cache a nil handle — return a one-shot error predictor instead.
|
||||
if handle == nil {
|
||||
fn := func(tokensJSON string) (string, error) {
|
||||
return "", fmt.Errorf("ThincNER handle is nil for model dir: %s", modelDir)
|
||||
}
|
||||
return fn
|
||||
}
|
||||
p := func(tokensJSON string) (string, error) {
|
||||
e.mu.Lock()
|
||||
cTokensJSON := C.CString(tokensJSON)
|
||||
cResult := C.ThincNER_Predict(handle, cTokensJSON)
|
||||
e.mu.Unlock()
|
||||
C.free(unsafe.Pointer(cTokensJSON))
|
||||
if cResult == nil {
|
||||
return "", fmt.Errorf("NER prediction failed")
|
||||
}
|
||||
defer C.ThincNER_FreeString(cResult)
|
||||
return C.GoString(cResult), nil
|
||||
}
|
||||
modelCache[modelDir] = p
|
||||
return p
|
||||
}
|
||||
|
||||
// ExtractEntities extracts named entities from text using C++ ThincNER.
|
||||
func (e *Extractor) ExtractEntities(text string) ([]Entity, error) {
|
||||
tokensJSON := tokenizeText(text, e.Lang)
|
||||
if tokensJSON == "" {
|
||||
return nil, fmt.Errorf("tokenization failed")
|
||||
}
|
||||
|
||||
modelDir := e.findModelDir()
|
||||
predict := e.getPredictor(modelDir)
|
||||
|
||||
resultJSON, err := predict(tokensJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawEntities []struct {
|
||||
Text string `json:"text"`
|
||||
Label string `json:"label"`
|
||||
Start int `json:"start"`
|
||||
End int `json:"end"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resultJSON), &rawEntities); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse NER result: %w", err)
|
||||
}
|
||||
|
||||
// Dedup by (text.lower(), start_char) — matching Python NERExtractor
|
||||
// For CJK, strip spaces from entity text (BILUO decoder joins tokens with spaces)
|
||||
isCJK := e.Lang == "zh" || e.Lang == "ja"
|
||||
seen := make(map[string]bool)
|
||||
entities := make([]Entity, 0, len(rawEntities))
|
||||
for _, re := range rawEntities {
|
||||
if skipLabels[re.Label] {
|
||||
continue
|
||||
}
|
||||
if re.Confidence < e.ConfidenceThreshold {
|
||||
continue
|
||||
}
|
||||
text := re.Text
|
||||
if isCJK {
|
||||
text = strings.ReplaceAll(text, " ", "")
|
||||
}
|
||||
key := strings.ToLower(text) + "|" + strconv.Itoa(re.Start)
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
appType := spacyToAppType[re.Label]
|
||||
if appType == "" {
|
||||
appType = strings.ToLower(re.Label)
|
||||
}
|
||||
entities = append(entities, Entity{
|
||||
Text: text,
|
||||
Label: re.Label,
|
||||
StartChar: re.Start,
|
||||
EndChar: re.End,
|
||||
Confidence: re.Confidence,
|
||||
AppType: appType,
|
||||
Metadata: map[string]interface{}{"source": "thincner"},
|
||||
})
|
||||
}
|
||||
return entities, nil
|
||||
}
|
||||
|
||||
// findModelDir locates the spaCy model directory under /usr/share/infinity/resource/spacy.
|
||||
func (e *Extractor) findModelDir() string {
|
||||
modelName := langModel[e.Lang]
|
||||
if modelName == "" {
|
||||
modelName = "en_core_web_sm"
|
||||
}
|
||||
base := "/usr/share/infinity/resource/spacy/" + modelName
|
||||
if dirExists(base) {
|
||||
return base
|
||||
}
|
||||
if p := getenv("SPACY_MODEL_DIR"); p != "" {
|
||||
return p
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func dirExists(path string) bool {
|
||||
info, err := os.Stat(path)
|
||||
return err == nil && info.IsDir()
|
||||
}
|
||||
|
||||
// tokenizeText tokenizes text via C++ tokenizer (all languages).
|
||||
// Returns JSON array of token strings.
|
||||
func tokenizeText(text, lang string) string {
|
||||
cText := C.CString(text)
|
||||
cLang := C.CString(lang)
|
||||
defer C.free(unsafe.Pointer(cText))
|
||||
defer C.free(unsafe.Pointer(cLang))
|
||||
|
||||
cTokens := C.ThincNER_Tokenize(cText, cLang)
|
||||
if cTokens == nil {
|
||||
return ""
|
||||
}
|
||||
defer C.ThincNER_FreeString(cTokens)
|
||||
return C.GoString(cTokens)
|
||||
}
|
||||
|
||||
func getenv(key string) string {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
|
||||
// DetectLanguage detects text language based on Unicode ranges.
|
||||
// Pure Go, zero dependencies.
|
||||
func DetectLanguage(text string) string {
|
||||
han, hira, kata, latin := 0, 0, 0, 0
|
||||
for _, r := range text {
|
||||
switch {
|
||||
case isHan(r):
|
||||
han++
|
||||
case isHiragana(r):
|
||||
hira++
|
||||
case isKatakana(r):
|
||||
kata++
|
||||
case isLatin(r):
|
||||
latin++
|
||||
}
|
||||
}
|
||||
total := han + hira + kata + latin
|
||||
if total == 0 {
|
||||
return "en"
|
||||
}
|
||||
// CJK majority
|
||||
if float64(han+hira+kata)/float64(total) > 0.3 {
|
||||
if hira+kata > han {
|
||||
return "ja" // Japanese-heavy
|
||||
}
|
||||
if han > 0 {
|
||||
return "zh" // Han-heavy → Chinese
|
||||
}
|
||||
return "en"
|
||||
}
|
||||
// Latin majority — default to en (user specifies de/fr/es/pt explicitly)
|
||||
return "en"
|
||||
}
|
||||
|
||||
// isHan reports whether r is in U+4E00..U+9FFF.
func isHan(r rune) bool { return 0x4E00 <= r && r <= 0x9FFF }
|
||||
// isHiragana reports whether r is in U+3040..U+309F.
func isHiragana(r rune) bool { return 0x3040 <= r && r <= 0x309F }
|
||||
// isKatakana reports whether r is in U+30A0..U+30FF.
func isKatakana(r rune) bool { return 0x30A0 <= r && r <= 0x30FF }
|
||||
// isLatin reports whether r is an ASCII letter (A-Z or a-z).
func isLatin(r rune) bool { return ('A' <= r && r <= 'Z') || ('a' <= r && r <= 'z') }
|
||||
30
internal/ingestion/compilation/extractor/ner_extractor.go
Normal file
30
internal/ingestion/compilation/extractor/ner_extractor.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//
|
||||
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package extractor provides NER and relation extraction for the ingestion
|
||||
// pipeline. It mirrors the Python rag/graphrag/ner package so that both
|
||||
// code paths produce identical output.
|
||||
//
|
||||
// The C++ ThincNER engine (internal/cpp/) loads spaCy model.ckpt+model.bin
|
||||
// directly for NER inference. Relation extraction first tries dependency-based
// rules (via the C++ parser) and falls back to pure-Go regex patterns.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// ext := extractor.NewExtractor("en")
|
||||
// result, err := ext.Extract("Apple Inc. was founded by Steve Jobs.", true)
|
||||
// for _, e := range result.Entities { ... }
|
||||
// for _, r := range result.Relations { ... }
|
||||
package extractor
|
||||
439
internal/ingestion/compilation/extractor/ner_relation.go
Normal file
439
internal/ingestion/compilation/extractor/ner_relation.go
Normal file
@@ -0,0 +1,439 @@
|
||||
//
|
||||
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package extractor
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Multilingual relation patterns — matching Python MULTILANG_RELATION_PATTERNS.
|
||||
// Entity groups accept words starting with any ASCII letter ([A-Za-z]); relation keywords use (?i) inline.
|
||||
// _entWord: uppercase-start word with optional trailing period (e.g. "Inc.", "Corp.")
|
||||
// Periods between initials are also supported (e.g. "U.S.", "J.K.")
|
||||
const _entWord = `[A-Za-z][\w']*(?:\.[A-Za-z][\w']*)*\.?`
|
||||
const _relEntity = `(` + _entWord + `(?:\s+` + _entWord + `)*?)`
|
||||
const _relEntity2 = `(` + _entWord + `(?:\s+` + _entWord + `){0,1})`
|
||||
|
||||
var relationPatterns = map[string][]relPatternEntry{
|
||||
"en": {
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:founded)\s+(?i:by)\s+` + _relEntity2), "founded_by"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:an?\s+)?(?i:co-)?(?i:founder)\s+(?i:of)\s+` + _relEntity2), "founded_by"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:works)\s+(?i:for)\s+` + _relEntity2), "works_for"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:an?\s+)?(?i:employee)\s+(?i:of)\s+` + _relEntity2), "works_for"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:joined)\s+` + _relEntity2), "works_for"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:the\s+)?(?:CEO|CTO|CFO|VP|(?i:director|manager|engineer))\s+(?i:of|at)\s+` + _relEntity2), "works_for"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:located|based|headquartered|situated)\s+(?i:in)\s+` + _relEntity2), "located_in"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:born)\s+(?i:in|on)\s+` + _relEntity2), "born_in"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:born)\s+(?i:in|on)\s+` + _relEntity2), "born_in"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:was)\s+(?i:acquired)\s+(?i:by)\s+` + _relEntity2), "acquired"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:acquired)\s+` + _relEntity2), "acquired"},
|
||||
{regexp.MustCompile(_relEntity + `\s+(?i:is)\s+(?i:the\s+)?(?i:CEO)\s+(?i:of)\s+` + _relEntity2), "ceo_of"},
|
||||
},
|
||||
"zh": {
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,6})\s*由\s*([\p{Han}\w]{2,4})\s*(?:创立|创建|成立|创办)`), "founded_by"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:创立|创建|成立|创办)(?:\s*了\s*)?([\p{Han}\w]{2,10})`), "founded_by"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:是\s*)?([\p{Han}\w]{2,10})\s*(?:创始人|联合创始人)`), "founded_by"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:任职于|供职于|工作于|就职于)\s*([\p{Han}\w]{2,10})`), "works_for"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:是\s*)?([\p{Han}\w]{2,10})\s*(?:的员工|的雇员)`), "works_for"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*(?:位于|坐落于|总部设在|总部位于)\s*([\p{Han}\w]{2,6})`), "located_in"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*在\s*([\p{Han}\w]{2,6})`), "located_in"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,4})\s*(?:出生于|生于)\s*([\p{Han}\w]{2,6})`), "born_in"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*(?:收购|并购)\s*([\p{Han}\w]{2,10})`), "acquired"},
|
||||
{regexp.MustCompile(`([\p{Han}\w]{2,10})\s*被\s*([\p{Han}\w]{2,10})\s*(?:收购|并购)`), "acquired"},
|
||||
},
|
||||
}
|
||||
|
||||
type relPatternEntry struct {
|
||||
pattern *regexp.Regexp
|
||||
predicate string
|
||||
}
|
||||
|
||||
// ExtractRelations extracts typed relations between entities.
|
||||
// Matches the Python RelationExtractor pattern-based approach,
|
||||
// including cross-sentence filtering via sentence boundary checks.
|
||||
func ExtractRelations(text string, entities []Entity, lang string) []Relation {
|
||||
return extractRelationsWithOpts(text, entities, lang, 100)
|
||||
}
|
||||
|
||||
// extractRelationsWithOpts is the internal version with configurable max distance.
|
||||
func extractRelationsWithOpts(text string, entities []Entity, lang string, maxDistance int) []Relation {
|
||||
patterns, ok := relationPatterns[lang]
|
||||
if !ok {
|
||||
patterns = relationPatterns["en"]
|
||||
}
|
||||
|
||||
// Build multimap: entity text → all occurrences (handles duplicate entity names)
|
||||
entityMultiMap := make(map[string][]Entity, len(entities))
|
||||
for _, e := range entities {
|
||||
key := strings.ToLower(e.Text)
|
||||
entityMultiMap[key] = append(entityMultiMap[key], e)
|
||||
// Also add punctuation-stripped version
|
||||
cleaned := strings.TrimRight(e.Text, ".,;:!?")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
if cleaned != e.Text {
|
||||
ckey := strings.ToLower(cleaned)
|
||||
entityMultiMap[ckey] = append(entityMultiMap[ckey], e)
|
||||
}
|
||||
}
|
||||
|
||||
// Build sentence spans (matching Python's sentence splitting regex)
|
||||
hasOffsets := false
|
||||
for _, e := range entities {
|
||||
if e.StartChar != 0 || e.EndChar != 0 {
|
||||
hasOffsets = true
|
||||
break
|
||||
}
|
||||
}
|
||||
var sentenceSpans [][2]int
|
||||
if hasOffsets {
|
||||
sentenceSpans = splitSentences(text)
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
var relations []Relation
|
||||
|
||||
// Phase 1: Pattern-based typed relations
|
||||
// Process each sentence separately to prevent cross-sentence regex matches.
|
||||
// When entities have no offsets, fall back to full-text matching.
|
||||
if hasOffsets && len(sentenceSpans) > 0 {
|
||||
for _, entry := range patterns {
|
||||
for _, sp := range sentenceSpans {
|
||||
sentText := text[sp[0]:sp[1]]
|
||||
matches := entry.pattern.FindAllStringSubmatchIndex(sentText, -1)
|
||||
for _, m := range matches {
|
||||
if len(m) < 6 {
|
||||
continue
|
||||
}
|
||||
subjStart, subjEnd := m[2], m[3]
|
||||
objStart, objEnd := m[4], m[5]
|
||||
if subjStart < 0 || objStart < 0 {
|
||||
continue
|
||||
}
|
||||
// Adjust to absolute positions
|
||||
absSubjStart := subjStart + sp[0]
|
||||
absSubjEnd := subjEnd + sp[0]
|
||||
absObjStart := objStart + sp[0]
|
||||
absObjEnd := objEnd + sp[0]
|
||||
subjText := strings.TrimSpace(text[absSubjStart:absSubjEnd])
|
||||
objText := strings.TrimSpace(text[absObjStart:absObjEnd])
|
||||
subj := findEntityByText(subjText, absSubjStart, entityMultiMap)
|
||||
obj := findEntityByText(objText, absObjStart, entityMultiMap)
|
||||
if subj.Text == "" || obj.Text == "" {
|
||||
continue
|
||||
}
|
||||
key := subj.Text + "|" + entry.predicate + "|" + obj.Text
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
|
||||
var ctx string
|
||||
absMatchStart := m[0] + sp[0]
|
||||
absMatchEnd := m[1] + sp[0]
|
||||
ctx = extractContext(text, text[absMatchStart:absMatchEnd])
|
||||
relations = append(relations, Relation{
|
||||
Subject: subj,
|
||||
Predicate: entry.predicate,
|
||||
Object: obj,
|
||||
Confidence: 0.8,
|
||||
Context: ctx,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No offsets: process full text
|
||||
for _, entry := range patterns {
|
||||
matches := entry.pattern.FindAllStringSubmatchIndex(text, -1)
|
||||
for _, m := range matches {
|
||||
if len(m) < 6 {
|
||||
continue
|
||||
}
|
||||
subjStart, subjEnd := m[2], m[3]
|
||||
objStart, objEnd := m[4], m[5]
|
||||
if subjStart < 0 || objStart < 0 {
|
||||
continue
|
||||
}
|
||||
subjText := strings.TrimSpace(text[subjStart:subjEnd])
|
||||
objText := strings.TrimSpace(text[objStart:objEnd])
|
||||
subj := findEntityByText(subjText, subjStart, entityMultiMap)
|
||||
obj := findEntityByText(objText, objStart, entityMultiMap)
|
||||
if subj.Text == "" || obj.Text == "" {
|
||||
continue
|
||||
}
|
||||
key := subj.Text + "|" + entry.predicate + "|" + obj.Text
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
|
||||
var ctx string
|
||||
if len(m) >= 2 && m[0] >= 0 {
|
||||
ctx = extractContext(text, text[m[0]:m[1]])
|
||||
}
|
||||
relations = append(relations, Relation{
|
||||
Subject: subj,
|
||||
Predicate: entry.predicate,
|
||||
Object: obj,
|
||||
Confidence: 0.8,
|
||||
Context: ctx,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Co-occurrence (standalone, with sentence boundary check)
|
||||
for _, r := range extractCooccurrence(text, entities, maxDistance) {
|
||||
key := r.Subject.Text + "|related_to|" + r.Object.Text
|
||||
if !seen[key] {
|
||||
seen[key] = true
|
||||
relations = append(relations, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Multi-hop inference + dedup (matching Python always applies these)
|
||||
relations = inferMultiHop(relations)
|
||||
relations = dedupRelations(relations)
|
||||
|
||||
return relations
|
||||
}
|
||||
|
||||
// extractCooccurrence generates related_to relations for entity pairs
|
||||
// within maxDistance characters in the same sentence.
|
||||
func extractCooccurrence(text string, entities []Entity, maxDistance int) []Relation {
|
||||
if len(entities) < 2 {
|
||||
return nil
|
||||
}
|
||||
hasOffsets := false
|
||||
for _, e := range entities {
|
||||
if e.StartChar != 0 || e.EndChar != 0 {
|
||||
hasOffsets = true
|
||||
break
|
||||
}
|
||||
}
|
||||
var sentenceSpans [][2]int
|
||||
if hasOffsets {
|
||||
sentenceSpans = splitSentences(text)
|
||||
}
|
||||
sameSentence := func(c1, c2 int) bool {
|
||||
if !hasOffsets || len(sentenceSpans) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, sp := range sentenceSpans {
|
||||
if sp[0] <= c1 && c1 < sp[1] && sp[0] <= c2 && c2 < sp[1] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
var relations []Relation
|
||||
for i := 0; i < len(entities); i++ {
|
||||
for j := i + 1; j < len(entities); j++ {
|
||||
e1, e2 := entities[i], entities[j]
|
||||
if !sameSentence(e1.StartChar, e2.StartChar) {
|
||||
continue
|
||||
}
|
||||
dist := abs(e2.StartChar - e1.EndChar)
|
||||
if dist > maxDistance {
|
||||
continue
|
||||
}
|
||||
relations = append(relations, Relation{
|
||||
Subject: e1,
|
||||
Predicate: "related_to",
|
||||
Object: e2,
|
||||
Confidence: 0.4,
|
||||
Context: extractContextSimple(text, e1, e2),
|
||||
Metadata: map[string]interface{}{"method": "cooccurrence"},
|
||||
})
|
||||
}
|
||||
}
|
||||
return relations
|
||||
}
|
||||
|
||||
// findEntityByText finds the entity occurrence closest to matchPos.
|
||||
// Uses multimap to handle duplicate entity names at different positions.
|
||||
func findEntityByText(raw string, matchPos int, entityMultiMap map[string][]Entity) Entity {
|
||||
text := strings.TrimSpace(raw)
|
||||
// Strip trailing punctuation
|
||||
for len(text) > 0 && strings.ContainsAny(text[len(text)-1:], ".,;:!?") {
|
||||
text = strings.TrimSpace(text[:len(text)-1])
|
||||
}
|
||||
ent := findClosest(text, matchPos, entityMultiMap)
|
||||
if ent.Text != "" {
|
||||
return ent
|
||||
}
|
||||
// Try stripping trailing " and ..." / " or ..." / ", ..."
|
||||
key := strings.ToLower(text)
|
||||
for _, sep := range []string{" and ", " or ", ", "} {
|
||||
if idx := strings.Index(key, sep); idx > 0 {
|
||||
if e := findClosest(key[:idx], matchPos, entityMultiMap); e.Text != "" {
|
||||
return e
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try progressively shorter word sequences (right-to-left word stripping)
|
||||
// Handles cases like "Google in" → try "Google" or "Microsoft. Microsoft" → try "microsoft" (stripped)
|
||||
words := strings.Fields(key)
|
||||
for i := len(words) - 1; i > 0; i-- {
|
||||
candidate := strings.Join(words[:i], " ")
|
||||
// Strip trailing punctuation from candidate before lookup
|
||||
candidate = strings.TrimRight(candidate, ".,;:!?")
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if e := findClosest(candidate, matchPos, entityMultiMap); e.Text != "" {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return Entity{}
|
||||
}
|
||||
|
||||
// findClosest returns the entity occurrence closest to matchPos from the multimap.
|
||||
// Also tries stripping trailing punctuation from name if exact match fails.
|
||||
func findClosest(name string, matchPos int, multiMap map[string][]Entity) Entity {
|
||||
entries := multiMap[strings.ToLower(name)]
|
||||
if len(entries) == 0 {
|
||||
// Try with trailing punctuation stripped
|
||||
cleaned := strings.TrimRight(name, ".,;:!?")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
if cleaned != name {
|
||||
entries = multiMap[strings.ToLower(cleaned)]
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return Entity{}
|
||||
}
|
||||
if len(entries) == 1 {
|
||||
return entries[0]
|
||||
}
|
||||
// Multiple occurrences: pick the one whose span center is closest to matchPos
|
||||
best := entries[0]
|
||||
bestDist := abs(best.StartChar + best.EndChar - 2*matchPos)
|
||||
for i := 1; i < len(entries); i++ {
|
||||
d := abs(entries[i].StartChar + entries[i].EndChar - 2*matchPos)
|
||||
if d < bestDist {
|
||||
best = entries[i]
|
||||
bestDist = d
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func extractContext(text string, matchStr string) string {
|
||||
idx := strings.Index(text, matchStr)
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
start := idx - 30
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
end := idx + len(matchStr) + 30
|
||||
if end > len(text) {
|
||||
end = len(text)
|
||||
}
|
||||
return text[start:end]
|
||||
}
|
||||
|
||||
func extractContextSimple(text string, e1, e2 Entity) string {
|
||||
start := min(e1.StartChar, e2.StartChar) - 20
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
end := max(e1.EndChar, e2.EndChar) + 20
|
||||
if end > len(text) {
|
||||
end = len(text)
|
||||
}
|
||||
return text[start:end]
|
||||
}
|
||||
|
||||
// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
|
||||
|
||||
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
|
||||
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
|
||||
|
||||
// splitSentences splits text into sentence spans [start, end).
|
||||
// Matches Python's: re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text)
|
||||
// Go RE2 lacks lookahead, so this manually identifies sentence boundaries:
|
||||
// - Periods followed by uppercase letter or end-of-string are sentence ends.
|
||||
// - Periods followed by lowercase letter are abbreviations (e.g., "Inc."), not sentence ends.
|
||||
// - ! and ? are always sentence-ending.
|
||||
func splitSentences(text string) [][2]int {
|
||||
var spans [][2]int
|
||||
start := 0
|
||||
for i := 0; i < len(text); {
|
||||
ch := text[i]
|
||||
if ch == '!' || ch == '?' {
|
||||
end := i + 1
|
||||
spans = append(spans, [2]int{start, end})
|
||||
start = end
|
||||
i = end
|
||||
continue
|
||||
}
|
||||
if ch == '.' {
|
||||
// Check if this period is a sentence end or abbreviation
|
||||
// Sentence end: period followed by space(s) + uppercase or end-of-string
|
||||
// Abbreviation: period followed by space(s) + lowercase
|
||||
end := i + 1
|
||||
next := end
|
||||
for next < len(text) && text[next] == ' ' {
|
||||
next++
|
||||
}
|
||||
if next >= len(text) {
|
||||
// Period at end of text = sentence end
|
||||
spans = append(spans, [2]int{start, end})
|
||||
start = end
|
||||
i = end
|
||||
continue
|
||||
}
|
||||
if text[next] >= 'A' && text[next] <= 'Z' {
|
||||
// Period + space + uppercase = sentence end
|
||||
spans = append(spans, [2]int{start, end})
|
||||
start = end
|
||||
i = end
|
||||
continue
|
||||
}
|
||||
// Lowercase after period = abbreviation, not sentence end
|
||||
i = end
|
||||
continue
|
||||
}
|
||||
i++
|
||||
}
|
||||
// Remaining text after last sentence boundary
|
||||
if start < len(text) {
|
||||
spans = append(spans, [2]int{start, len(text)})
|
||||
}
|
||||
if len(spans) == 0 && len(text) > 0 {
|
||||
spans = append(spans, [2]int{0, len(text)})
|
||||
}
|
||||
return spans
|
||||
}
|
||||
113
internal/ingestion/compilation/extractor/ner_test.go
Normal file
113
internal/ingestion/compilation/extractor/ner_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
//
|
||||
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package extractor
|
||||
|
||||
import "testing"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test data — 21 English test cases (ground truth from Python+spaCy)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type EnTestSpec struct {
|
||||
name string
|
||||
text string
|
||||
wantEntities [][2]string // (text, label) pairs that MUST be found
|
||||
wantRels []relSpec // typed relations that MUST be found
|
||||
}
|
||||
|
||||
// relSpec is an expected relation written as a (subject, predicate, object)
// triple of strings.
type relSpec struct {
	subj string // subject entity text
	pred string // predicate name
	obj  string // object entity text
}
|
||||
|
||||
var enTests = []EnTestSpec{
|
||||
{name: "founded_by_simple", text: "Apple Inc. was founded by Steve Jobs.",
|
||||
wantEntities: [][2]string{{"Steve Jobs", "PERSON"}},
|
||||
wantRels: []relSpec{{"Apple Inc.", "founded_by", "Steve Jobs"}}},
|
||||
{name: "founded_by_multi", text: "Google was founded by Larry Page and Sergey Brin.",
|
||||
wantEntities: [][2]string{{"Larry Page", "PERSON"}, {"Sergey Brin", "PERSON"}},
|
||||
wantRels: []relSpec{{"Google", "founded_by", "Larry Page"}}},
|
||||
{name: "cofounder_of", text: "Elon Musk is a co-founder of Tesla.",
|
||||
wantEntities: [][2]string{{"Elon Musk", "PERSON"}, {"Tesla", "ORG"}},
|
||||
wantRels: []relSpec{{"Elon Musk", "founded_by", "Tesla"}}},
|
||||
{name: "works_for_simple", text: "John works for Microsoft.",
|
||||
wantEntities: [][2]string{{"John", "PERSON"}, {"Microsoft", "ORG"}},
|
||||
wantRels: []relSpec{{"John", "works_for", "Microsoft"}}},
|
||||
{name: "employee_of", text: "Mary is an employee of Google.",
|
||||
wantEntities: [][2]string{{"Mary", "PERSON"}, {"Google", "ORG"}},
|
||||
wantRels: []relSpec{{"Mary", "works_for", "Google"}}},
|
||||
{name: "joined_company", text: "Sundar Pichai joined Google in 2004.",
|
||||
wantEntities: [][2]string{{"Sundar Pichai", "PERSON"}, {"Google", "ORG"}},
|
||||
wantRels: []relSpec{{"Sundar Pichai", "works_for", "Google"}}},
|
||||
{name: "headquartered_in", text: "The company is headquartered in San Francisco.",
|
||||
wantEntities: [][2]string{{"San Francisco", "GPE"}},
|
||||
wantRels: nil},
|
||||
{name: "based_in", text: "Microsoft is based in Redmond.",
|
||||
wantEntities: [][2]string{{"Microsoft", "ORG"}, {"Redmond", "GPE"}},
|
||||
wantRels: []relSpec{{"Microsoft", "located_in", "Redmond"}}},
|
||||
{name: "born_in", text: "Albert Einstein was born in Germany.",
|
||||
wantEntities: [][2]string{{"Albert Einstein", "PERSON"}, {"Germany", "GPE"}},
|
||||
wantRels: []relSpec{{"Albert Einstein", "born_in", "Germany"}}},
|
||||
{name: "ceo_of", text: "Sundar Pichai is the CEO of Google.",
|
||||
wantEntities: [][2]string{{"Sundar Pichai", "PERSON"}, {"Google", "ORG"}},
|
||||
wantRels: []relSpec{{"Sundar Pichai", "works_for", "Google"}, {"Sundar Pichai", "ceo_of", "Google"}}},
|
||||
{name: "acquired_by", text: "Instagram was acquired by Facebook.",
|
||||
wantEntities: nil, // en_core_web_sm doesn't tag these
|
||||
wantRels: nil},
|
||||
{name: "acquired_active", text: "Facebook acquired Instagram.",
|
||||
wantEntities: [][2]string{{"Instagram", "PERSON"}}, // en_core_web_sm: Instagram→PERSON
|
||||
wantRels: nil},
|
||||
{name: "multi_founded_ceo", text: "Google was founded by Larry Page. Sundar Pichai is the CEO of Google.",
|
||||
wantEntities: [][2]string{{"Larry Page", "PERSON"}, {"Sundar Pichai", "PERSON"}},
|
||||
wantRels: nil},
|
||||
{name: "multi_works_located", text: "John works for Microsoft. Microsoft is based in Redmond.",
|
||||
wantEntities: [][2]string{{"John", "PERSON"}, {"Microsoft", "ORG"}, {"Redmond", "GPE"}},
|
||||
wantRels: []relSpec{{"Microsoft", "located_in", "Redmond"}}},
|
||||
{name: "no_entities", text: "The cat sat on the mat.",
|
||||
wantEntities: nil,
|
||||
wantRels: nil},
|
||||
{name: "org_with_inc", text: "Microsoft Corporation was founded by Bill Gates.",
|
||||
wantEntities: [][2]string{{"Bill Gates", "PERSON"}},
|
||||
wantRels: []relSpec{{"Microsoft Corporation", "founded_by", "Bill Gates"}}},
|
||||
{name: "located_city", text: "The restaurant is located in Paris.",
|
||||
wantEntities: [][2]string{{"Paris", "GPE"}},
|
||||
wantRels: nil},
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fast unit tests (pure Go, no Python dependency)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestDetectLanguage(t *testing.T) {
|
||||
tests := []struct {
|
||||
text string
|
||||
want string
|
||||
}{
|
||||
{"Hello world", "en"},
|
||||
{"你好世界", "zh"},
|
||||
{"こんにちは世界", "ja"},
|
||||
{"阿里巴巴由马云创立", "zh"},
|
||||
{"アップルは", "ja"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := DetectLanguage(tt.text)
|
||||
if got != tt.want {
|
||||
t.Errorf("DetectLanguage(%q) = %q, want %q", tt.text, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
87
internal/ingestion/compilation/extractor/parser_go.go
Normal file
87
internal/ingestion/compilation/extractor/parser_go.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package extractor
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "../../../cpp/thinc_parser.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// DepToken holds dependency parse info for one token (mirrors Go dep_relation.go)
|
||||
type DepTokenC struct {
|
||||
Text string `json:"text"`
|
||||
Head int `json:"head"`
|
||||
Dep string `json:"dep"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// RunParser runs the C++ dependency parser on tokenized text.
|
||||
// modelBaseDir: path to model directory containing ner/ and parser/ subdirectories.
|
||||
// tokensJSON: JSON array of token strings.
|
||||
// Returns JSON array of DepTokenC.
|
||||
func RunParser(nerDir, parserDir string, tokensJSON string) (string, error) {
|
||||
cNer := C.CString(nerDir)
|
||||
cParser := C.CString(parserDir)
|
||||
cTokens := C.CString(tokensJSON)
|
||||
defer C.free(unsafe.Pointer(cNer))
|
||||
defer C.free(unsafe.Pointer(cParser))
|
||||
defer C.free(unsafe.Pointer(cTokens))
|
||||
|
||||
handle := C.ThincParser_Create(cNer, cParser)
|
||||
if handle == nil {
|
||||
return "", fmt.Errorf("failed to create ThincParser handle")
|
||||
}
|
||||
defer C.ThincParser_Destroy(handle)
|
||||
|
||||
cResult := C.ThincParser_Predict(handle, cTokens)
|
||||
if cResult == nil {
|
||||
return "", fmt.Errorf("parser prediction failed")
|
||||
}
|
||||
defer C.ThincParser_FreeString(cResult)
|
||||
|
||||
return C.GoString(cResult), nil
|
||||
}
|
||||
|
||||
// RunTagger runs the C++ POS tagger.
|
||||
// nerDir: path to NER model directory (for tok2vec weights).
|
||||
// taggerDir: path to tagger model directory.
|
||||
func RunTagger(nerDir, taggerDir string, tokensJSON string) (string, error) {
|
||||
cNer := C.CString(nerDir)
|
||||
cTagger := C.CString(taggerDir)
|
||||
cTokens := C.CString(tokensJSON)
|
||||
defer C.free(unsafe.Pointer(cNer))
|
||||
defer C.free(unsafe.Pointer(cTagger))
|
||||
defer C.free(unsafe.Pointer(cTokens))
|
||||
|
||||
handle := C.ThincTagger_Create(cNer, cTagger)
|
||||
if handle == nil {
|
||||
return "", fmt.Errorf("failed to create ThincTagger handle")
|
||||
}
|
||||
defer C.ThincTagger_Destroy(handle)
|
||||
|
||||
cResult := C.ThincTagger_Predict(handle, cTokens)
|
||||
if cResult == nil {
|
||||
return "", fmt.Errorf("tagger prediction failed")
|
||||
}
|
||||
defer C.ThincTagger_FreeString(cResult)
|
||||
|
||||
return C.GoString(cResult), nil
|
||||
}
|
||||
|
||||
// ParseTokensWithParser runs the C++ parser and returns parsed DepToken slice.
|
||||
func ParseTokensWithParser(nerDir, parserDir string, tokens []string) ([]DepTokenC, error) {
|
||||
tj, _ := json.Marshal(tokens)
|
||||
resultJSON, err := RunParser(nerDir, parserDir, string(tj))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tokensC []DepTokenC
|
||||
if err := json.Unmarshal([]byte(resultJSON), &tokensC); err != nil {
|
||||
return nil, fmt.Errorf("parse result: %w", err)
|
||||
}
|
||||
return tokensC, nil
|
||||
}
|
||||
@@ -13,6 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from .graph_extractor import GraphExtractor
|
||||
from .ner_extractor import NERExtractor
|
||||
from .dep_relation_extractor import DepRelationExtractor
|
||||
from .types import Entity, ExtractionResult, Relation
|
||||
|
||||
__all__ = ["GraphExtractor"]
|
||||
__all__ = [
|
||||
"NERExtractor",
|
||||
"DepRelationExtractor",
|
||||
"Entity",
|
||||
"Relation",
|
||||
"ExtractionResult",
|
||||
]
|
||||
|
||||
558
rag/graphrag/ner/dep_relation_extractor.py
Normal file
558
rag/graphrag/ner/dep_relation_extractor.py
Normal file
@@ -0,0 +1,558 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Dependency-based relation extractor — full semantica alignment.
|
||||
|
||||
Extracts typed relations using spaCy dependency parse with:
|
||||
- Multi-hop inference (A→B→C transitivity)
|
||||
- Negation filtering
|
||||
- Dynamic confidence scoring
|
||||
- Multi-occurrence entity matching
|
||||
"""
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .types import Entity, Relation
|
||||
|
||||
# Language-specific dependency label mappings
|
||||
# Keys: pass_subj, subj, agent, dobj, prep_obj — each maps to a dep label
|
||||
# or a tuple (dep, child_dep) for compound patterns.
|
||||
# None = no standard mapping (language uses different structure)
|
||||
_LANG_DEP_RULES: Dict[str, Dict[str, object]] = {
|
||||
"en": {"pass_subj": "nsubjpass", "subj": "nsubj",
|
||||
"agent": ("agent", "pobj"),
|
||||
"dobj": "dobj", "prep_obj": ("prep", "pobj")},
|
||||
"de": {"subj": "sb",
|
||||
"agent": ("sbp", "nk"),
|
||||
"prep_obj": ("mo", "nk"),
|
||||
"root_verb_child": "oc"}, # German ROOT is aux, real verb is "oc"
|
||||
"fr": {"pass_subj": "nsubj:pass", "subj": "nsubj",
|
||||
"agent": "obl:agent",
|
||||
"dobj": "obj", "prep_obj": ("case", "obl")},
|
||||
"es": {"subj": "nsubj",
|
||||
"agent": "obj",
|
||||
"prep_obj": ("case", "obl")},
|
||||
"pt": {"pass_subj": "nsubj:pass", "subj": "nsubj",
|
||||
"agent": "obl:agent",
|
||||
"dobj": "obj", "prep_obj": ("case", "obl")},
|
||||
"zh": {"subj": "nsubj",
|
||||
"agent": ("nmod:prep", None, "由"), # case "由" marks agent
|
||||
"prep_obj": ("case", "nmod")},
|
||||
"ja": {"subj": "nsubj",
|
||||
"agent": ("obl", None, "によって"), # "によって" marks agent
|
||||
"prep_obj": ("case", "obl")},
|
||||
}
|
||||
|
||||
# Multi-hop inference rules: if A rel1 B and B rel2 C then A rel3 C
|
||||
_MULTI_HOP: Dict[str, Dict[str, str]] = {
|
||||
"ceo_of": {"is_subsidiary_of": "works_for", "located_in": "works_for"},
|
||||
"works_for": {"is_subsidiary_of": "works_for"},
|
||||
"founded_by": {"is_subsidiary_of": "founded_by"},
|
||||
}
|
||||
|
||||
_VERB_RELATIONS: Dict[str, str] = {
|
||||
# English
|
||||
"found+by": "founded_by", "co-found+by": "founded_by",
|
||||
"establish+by": "founded_by", "create+by": "founded_by",
|
||||
"set+up": "founded_by", "start+by": "founded_by",
|
||||
"work+for": "works_for", "employ+by": "works_for",
|
||||
"hire+by": "works_for", "join": "works_for",
|
||||
"lead+by": "works_for", "manage+by": "works_for",
|
||||
"head+by": "works_for", "run+by": "works_for",
|
||||
"own+by": "owns", "develop+by": "develops",
|
||||
"write+by": "wrote", "publish+by": "published",
|
||||
"invest+in": "invests_in", "partner+with": "partners_with",
|
||||
"collaborate+with": "collaborates_with",
|
||||
"merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of",
|
||||
"base+in": "located_in", "locate+in": "located_in",
|
||||
"situate+in": "located_in", "headquarter+in": "located_in",
|
||||
"bear+in": "born_in", "bear+on": "born_in",
|
||||
"acquire+by": "acquired", "buy+by": "acquired",
|
||||
# German (de): spaCy lemmas
|
||||
"gründen+von": "founded_by", "errichten+von": "founded_by",
|
||||
"arbeiten+für": "works_for", "beschäftigen+bei": "works_for",
|
||||
"anstellen+bei": "works_for", "sich+befinden": "located_in",
|
||||
"liegen+in": "located_in", "sitzen+in": "located_in",
|
||||
"gebären+in": "born_in", "gebären+am": "born_in",
|
||||
"erwerben+durch": "acquired", "kaufen+durch": "acquired",
|
||||
"übernehmen+durch": "acquired",
|
||||
# French (fr): spaCy lemmas
|
||||
"fonder+par": "founded_by", "créer+par": "founded_by",
|
||||
"établir+par": "founded_by",
|
||||
"travailler+pour": "works_for", "employer+par": "works_for",
|
||||
"embaucher+par": "works_for",
|
||||
"situer+à": "located_in", "baser+à": "located_in",
|
||||
"implanter+à": "located_in",
|
||||
"naître+à": "born_in",
|
||||
"acquérir+par": "acquired", "racheter+par": "acquired",
|
||||
# Spanish + Portuguese (shared lemmas, no duplicate keys)
|
||||
"fundar+por": "founded_by", "crear+por": "founded_by",
|
||||
"criar+por": "founded_by",
|
||||
"establecer+por": "founded_by", "estabelecer+por": "founded_by",
|
||||
"trabajar+para": "works_for", "trabalhar+para": "works_for",
|
||||
"emplear+por": "works_for", "empregar+por": "works_for",
|
||||
"contratar+por": "works_for",
|
||||
"ubicar+en": "located_in", "situar+en": "located_in",
|
||||
"localizar+em": "located_in", "situar+em": "located_in",
|
||||
"sediar+em": "located_in", "tener+sede": "located_in",
|
||||
"nacer+en": "born_in", "nascer+em": "born_in",
|
||||
"adquirir+por": "acquired", "comprar+por": "acquired",
|
||||
# Chinese: verb + "由" (agent marker) or "被" (passive)
|
||||
"创立+由": "founded_by", "创建+由": "founded_by",
|
||||
"成立+由": "founded_by", "创办+由": "founded_by",
|
||||
"设立+由": "founded_by",
|
||||
"任职+于": "works_for", "就职+于": "works_for",
|
||||
"工作+在": "works_for", "位于+在": "located_in",
|
||||
"坐落+在": "located_in", "总部设+在": "located_in",
|
||||
"出生+在": "born_in", "出生+于": "born_in",
|
||||
"收购+由": "acquired", "并购+由": "acquired",
|
||||
# Japanese: verb + "によって" (agent marker)
|
||||
"設立+によって": "founded_by", "創立+によって": "founded_by",
|
||||
"勤務+で": "works_for", "在籍+で": "works_for",
|
||||
"位置+に": "located_in", "所在+に": "located_in",
|
||||
"本社+を": "located_in",
|
||||
"出生+に": "born_in",
|
||||
"買収+によって": "acquired",
|
||||
}
|
||||
|
||||
_COPULA_TITLE_MAP: Dict[str, List[str]] = {
|
||||
"ceo": ["ceo_of", "works_for"], "cto": ["works_for"],
|
||||
"cfo": ["works_for"], "coo": ["works_for"],
|
||||
"vp": ["works_for"], "director": ["works_for"],
|
||||
"manager": ["works_for"], "engineer": ["works_for"],
|
||||
"employee": ["works_for"],
|
||||
"founder": ["founded_by"], "co-founder": ["founded_by"],
|
||||
}
|
||||
|
||||
|
||||
class DepRelationExtractor:
|
||||
"""Extract typed relations using dependency parse — semantica-aligned."""
|
||||
|
||||
def __init__(self, language: str = "en",
|
||||
confidence_threshold: float = 0.3,
|
||||
max_distance: int = 100):
|
||||
self.language = language
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.max_distance = max_distance
|
||||
|
||||
def extract(self, text: str, entities: List[Entity],
|
||||
doc=None, **options) -> List[Relation]:
|
||||
semantica_rels = []
|
||||
if doc is not None:
|
||||
semantica_rels = self._extract_with_dep(text, doc, entities)
|
||||
semantica_rels.extend(self._extract_cooccurrence(text, entities))
|
||||
semantica_rels = self._infer_multi_hop(semantica_rels)
|
||||
semantica_rels = self._deduplicate(semantica_rels)
|
||||
return [r for r in semantica_rels if r.confidence >= self.confidence_threshold]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Multi-hop inference (属性传递)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _infer_multi_hop(relations: List[Relation]) -> List[Relation]:
|
||||
"""Infer transitive relations: A→B→C ⇒ A→C."""
|
||||
by_subj: Dict[str, List[Relation]] = {}
|
||||
for r in relations:
|
||||
if r.predicate == "related_to":
|
||||
continue
|
||||
by_subj.setdefault(r.subject.text.lower(), []).append(r)
|
||||
|
||||
inferred = []
|
||||
for r in relations:
|
||||
if r.predicate == "related_to":
|
||||
continue
|
||||
obj_key = r.obj.text.lower()
|
||||
if obj_key in by_subj:
|
||||
for r2 in by_subj[obj_key]:
|
||||
if r2.predicate in _MULTI_HOP.get(r.predicate, {}):
|
||||
inferred_rel = _MULTI_HOP[r.predicate][r2.predicate]
|
||||
if inferred_rel:
|
||||
inferred.append(Relation(
|
||||
subject=r.subject, predicate=inferred_rel,
|
||||
obj=r2.obj, confidence=min(r.confidence, r2.confidence) * 0.9,
|
||||
metadata={"method": "multi_hop",
|
||||
"via": f"{r.predicate}→{r2.predicate}"},
|
||||
))
|
||||
return relations + inferred
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Dependency extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Language-aware role mapping
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _roles(self) -> Dict[str, object]:
    """Return the role → dependency-rule mapping for ``self.language``.

    Falls back to the English rules when the language has no entry in
    ``_LANG_DEP_RULES``. A rule value is either a dependency label (str) or
    a tuple, which is why the value type is ``object`` rather than ``str``.
    """
    return _LANG_DEP_RULES.get(self.language, _LANG_DEP_RULES["en"])
|
||||
|
||||
# Collects (entity, prep) pairs for the children of `root` that fill the
# semantic role `role` under the rule table returned by self._roles().
# An unknown role yields an empty list.
def _get_by_role(self, root, role: str, entity_map) -> list:
|
||||
"""Get entities for a semantic role (language-aware). Returns [(Entity, prep?)]"""
|
||||
rule = self._roles().get(role)
|
||||
if rule is None:
|
||||
return []
|
||||
results = []
|
||||
|
||||
for c in root.children:
|
||||
dep = c.dep_
|
||||
# A str rule matches a direct child by its dependency label; the paired
# prep value is always None in that case.
if isinstance(rule, str):
|
||||
if dep == rule:
|
||||
ent = self._entity_from_subtree(c, entity_map)
|
||||
if ent:
|
||||
results.append((ent, None))
|
||||
# A tuple rule is (parent_dep, child_dep) with an optional third element,
# the case marker.
elif isinstance(rule, tuple):
|
||||
parent_dep, child_dep = rule[0], rule[1]
|
||||
# Check optional case marker (e.g., "由" for zh, "によって" for ja)
|
||||
case_marker = rule[2] if len(rule) > 2 else None
|
||||
if dep == parent_dep:
|
||||
if case_marker:
|
||||
# Check whether any token in c's subtree carries the expected case
# marker, as lemma or as surface text.
|
||||
has_case = any(
|
||||
gc.lemma_ == case_marker or gc.text == case_marker
|
||||
for gc in c.subtree
|
||||
)
|
||||
if not has_case:
|
||||
continue
|
||||
# child_dep None: the entity is resolved from c itself.
if child_dep is None:
|
||||
ent = self._entity_from_subtree(c, entity_map)
|
||||
if ent:
|
||||
# The lower-cased lemma of c is reported only for the prep_obj role.
results.append((ent, c.lemma_.lower() if role == "prep_obj" else None))
|
||||
else:
|
||||
# Otherwise the entity is resolved from a child of c labelled child_dep.
for gc in c.children:
|
||||
if gc.dep_ == child_dep:
|
||||
ent = self._entity_from_subtree(gc, entity_map)
|
||||
if ent:
|
||||
prep = c.lemma_.lower() if role == "prep_obj" else None
|
||||
results.append((ent, prep))
|
||||
break
|
||||
return results
|
||||
|
||||
def _extract_with_dep(self, text, doc, entities) -> List[Relation]:
|
||||
relations = []
|
||||
entity_map = self._build_entity_map_multi(entities)
|
||||
is_de = self.language == "de"
|
||||
|
||||
for sent in doc.sents:
|
||||
for token in sent:
|
||||
# German: ROOT is aux verb, real verb is "oc" child
|
||||
if is_de:
|
||||
if token.dep_ != "ROOT":
|
||||
continue
|
||||
for c in token.children:
|
||||
if c.dep_ == "oc":
|
||||
# German: args attach to aux (ROOT), not main verb (oc)
|
||||
# Pass both: root aux for args, oc for verb lemma
|
||||
relations.extend(self._extract_from_root(text, c, entity_map, aux_root=token))
|
||||
continue
|
||||
|
||||
if token.dep_ != "ROOT":
|
||||
continue
|
||||
relations.extend(self._extract_from_root(text, token, entity_map))
|
||||
if token.lemma_ == "be":
|
||||
relations.extend(self._extract_copula(text, token, entity_map))
|
||||
|
||||
return relations
|
||||
|
||||
def _extract_from_root(self, text, root, entity_map, aux_root=None) -> List[Relation]:
|
||||
relations = []
|
||||
# Fall back to text when lemma is empty (zh, ja don't have lemmatizers)
|
||||
verb_lemma = (root.lemma_ or root.text).lower()
|
||||
# For languages like German where args attach to aux verb
|
||||
check = root if aux_root is None else aux_root
|
||||
|
||||
# Negation
|
||||
if any(c.dep_ in ("neg", "advmod:neg") for c in check.children):
|
||||
return relations
|
||||
|
||||
# Extract roles (check both the main verb and optional aux parent)
|
||||
def first(lst):
|
||||
return lst[0][0] if lst else None
|
||||
def get_roles(token):
|
||||
return (
|
||||
first(self._get_by_role(token, "subj", entity_map)),
|
||||
first(self._get_by_role(token, "pass_subj", entity_map)),
|
||||
first(self._get_by_role(token, "dobj", entity_map)),
|
||||
first(self._get_by_role(token, "agent", entity_map)),
|
||||
self._get_by_role(token, "prep_obj", entity_map),
|
||||
any(c.dep_ == "aux" for c in token.children),
|
||||
)
|
||||
|
||||
s1, sp1, d1, a1, p1, h1 = get_roles(root)
|
||||
s2, sp2, d2, a2, p2, h2 = (None, None, None, None, [], False)
|
||||
if aux_root:
|
||||
s2, sp2, d2, a2, p2, h2 = get_roles(aux_root)
|
||||
|
||||
# Merge: prefer found roles from aux if main verb lacks them
|
||||
nsubj = s1 or s2
|
||||
nsubjpass = sp1 or sp2
|
||||
dobj = d1 or d2
|
||||
agent_entity = a1 or a2
|
||||
prep_list = p1 + p2
|
||||
has_aux = h1 or h2 or aux_root is not None
|
||||
has_explicit_agent = agent_entity is not None
|
||||
|
||||
# Detect passive:
|
||||
# - explicit pass_subj (en, fr, pt)
|
||||
# - subj + agent + aux (Spanish-style)
|
||||
# - subj + agent for languages with agent marker (zh, ja)
|
||||
is_passive_candidate = has_explicit_agent and (has_aux or self.language in ("zh", "ja"))
|
||||
|
||||
effective_nsubjpass = nsubjpass or (nsubj if is_passive_candidate else None)
|
||||
effective_nsubj = nsubj if not is_passive_candidate else None
|
||||
|
||||
# Passive: X was founded/acquired by Y
|
||||
if effective_nsubjpass and agent_entity:
|
||||
prep = ""
|
||||
# Try language-appropriate prepositions/case markers
|
||||
candidates = ("by", "von", "par", "por", "durch", "由", "によって")
|
||||
for candidate in candidates:
|
||||
if self._lookup(verb_lemma, candidate):
|
||||
prep = candidate
|
||||
break
|
||||
rel_type = self._lookup(verb_lemma, prep) if prep else None
|
||||
if rel_type:
|
||||
if rel_type in ("founded_by", "acquired"):
|
||||
subj, obj = effective_nsubjpass, agent_entity
|
||||
else:
|
||||
subj, obj = agent_entity, effective_nsubjpass
|
||||
relations.append(self._make_rel(subj, rel_type, obj, 0.90, "passive", verb_lemma))
|
||||
|
||||
# Active: X VERB Y or X VERB prep Y
|
||||
if effective_nsubj:
|
||||
if dobj:
|
||||
rt = self._lookup(verb_lemma, None)
|
||||
if rt:
|
||||
relations.append(self._make_rel(effective_nsubj, rt, dobj, 0.85, "active", verb_lemma))
|
||||
for prep_entity, prep_l in prep_list:
|
||||
rt = self._lookup(verb_lemma, prep_l)
|
||||
if rt:
|
||||
relations.append(self._make_rel(effective_nsubj, rt, prep_entity, 0.85,
|
||||
"active_prep", verb_lemma, prep=prep_l))
|
||||
|
||||
# Passive with prep ("is based in")
|
||||
if effective_nsubjpass and prep_list and not agent_entity:
|
||||
for prep_entity, prep_l in prep_list:
|
||||
rt = self._lookup(verb_lemma, prep_l)
|
||||
if not rt:
|
||||
rt = self._lookup("be+" + verb_lemma, prep_l)
|
||||
if rt:
|
||||
relations.append(self._make_rel(effective_nsubjpass, rt, prep_entity, 0.85,
|
||||
"passive_prep", verb_lemma, prep=prep_l))
|
||||
|
||||
return relations
|
||||
|
||||
@staticmethod
def _make_rel(subj, pred, obj, conf, method, verb, prep=""):
    """Build a Relation whose metadata records how it was extracted.

    The metadata always holds "method" and "verb"; "prep" is added only when
    ``prep`` is non-empty.
    """
    extras = {"prep": prep} if prep else {}
    return Relation(
        subject=subj,
        predicate=pred,
        obj=obj,
        confidence=conf,
        metadata={"method": method, "verb": verb, **extras},
    )
|
||||
|
||||
@staticmethod
|
||||
def _already_has(rels, subj, pred, obj) -> bool:
|
||||
for r in rels:
|
||||
if r.subject.text == subj.text and r.predicate == pred and r.obj.text == obj.text:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_copula(self, text, root, entity_map) -> List[Relation]:
    """Extract title-based relations from a copular clause.

    Handles sentences shaped like "X is the CEO of Y": the subject of
    ``root`` is X, an attribute child ("attr" for en, "pred" for de) carries
    the title, and a preposition-like child of that attribute ("prep", "mo"
    or "case") leads to the entity Y.

    The first (title, entity) pair found wins. The search stops there, so a
    later modifier that does not resolve to an entity can no longer wipe out
    a match that was already found.

    Returns one Relation per relation type mapped to the first keyword of
    ``_COPULA_TITLE_MAP`` contained in the title lemma, each with X as
    subject and Y as object. Returns an empty list when there is no subject,
    no (title, entity) pair, or no matching keyword.

    NOTE(review): for "founder" titles this yields (person, founded_by,
    organisation), while the passive pattern in _extract_from_root yields
    (organisation, founded_by, person) - confirm which direction is intended.
    """
    subjs = self._get_by_role(root, "subj", entity_map)
    subj = subjs[0][0] if subjs else None
    if not subj:
        return []

    attr_deps = ("attr", "pred")        # attr=en, pred=de
    prep_deps = ("prep", "mo", "case")  # en=prep, de=mo, fr=case

    def first_title_match():
        """Return (title_lemma, entity) for the first resolvable object."""
        for attr in root.children:
            if attr.dep_ not in attr_deps:
                continue
            for prep in attr.children:
                if prep.dep_ not in prep_deps:
                    continue
                # Any child of the preposition is accepted as the object.
                for candidate in prep.children:
                    entity = self._entity_from_subtree(candidate, entity_map)
                    if entity:
                        return attr.lemma_.lower(), entity
        return None, None

    title_lemma, prep_obj = first_title_match()
    if not title_lemma or not prep_obj:
        return []

    relations = []
    for keyword, rel_types in _COPULA_TITLE_MAP.items():
        if keyword not in title_lemma:
            continue
        for rel_type in rel_types:
            relations.append(Relation(
                subject=subj, predicate=rel_type, obj=prep_obj,
                confidence=0.88, context=text,
                metadata={"method": "copula", "title": title_lemma},
            ))
        break  # only the first matching keyword applies
    return relations
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Better entity map: multi-occurrence aware
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _build_entity_map_multi(entities: List[Entity]) -> Dict[str, List[Entity]]:
|
||||
"""Build entity map that keeps ALL occurrences per name."""
|
||||
result: Dict[str, List[Entity]] = {}
|
||||
for e in entities:
|
||||
key = e.text.lower()
|
||||
result.setdefault(key, []).append(e)
|
||||
cleaned = e.text.rstrip(".,;:!?").strip().lower()
|
||||
if cleaned != key:
|
||||
result.setdefault(cleaned, []).append(e)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _find_best_entity(key: str, entity_map: Dict[str, List[Entity]],
|
||||
fallback_text: str = "") -> Optional[Entity]:
|
||||
"""Find the best entity match. If multiple, prefer the one whose
|
||||
text is an exact match for fallback_text, or the first one."""
|
||||
entries = entity_map.get(key.lower(), [])
|
||||
if not entries:
|
||||
return None
|
||||
if len(entries) == 1:
|
||||
return entries[0]
|
||||
# Prefer exact text match
|
||||
for e in entries:
|
||||
if e.text.lower() == fallback_text.lower():
|
||||
return e
|
||||
return entries[0]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Argument extraction helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _get_child_entity(token, dep, entity_map):
|
||||
for c in token.children:
|
||||
if c.dep_ == dep:
|
||||
return DepRelationExtractor._entity_from_subtree(c, entity_map)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_agent_pobj(root, entity_map):
|
||||
for c in root.children:
|
||||
if c.dep_ == "agent":
|
||||
for gc in c.children:
|
||||
if gc.dep_ == "pobj":
|
||||
return DepRelationExtractor._entity_from_subtree(gc, entity_map)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_prep_objs(root, entity_map):
|
||||
results = []
|
||||
for c in root.children:
|
||||
if c.dep_ == "prep":
|
||||
prep_lemma = c.lemma_.lower()
|
||||
for gc in c.children:
|
||||
if gc.dep_ == "pobj":
|
||||
ent = DepRelationExtractor._entity_from_subtree(gc, entity_map)
|
||||
if ent:
|
||||
results.append((prep_lemma, ent))
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _entity_from_subtree(token, entity_map) -> Optional[Entity]:
|
||||
"""Match token's subtree against entity map. Uses character positions
|
||||
for conjunction handling."""
|
||||
min_char = token.idx
|
||||
max_char = token.idx + len(token.text)
|
||||
for t in token.subtree:
|
||||
if t.dep_ not in ("prep", "punct", "det", "aux", "auxpass", "cc", "conj"):
|
||||
if t.idx < min_char:
|
||||
min_char = t.idx
|
||||
end = t.idx + len(t.text)
|
||||
if end > max_char:
|
||||
max_char = end
|
||||
text = token.doc.text[min_char:max_char].strip()
|
||||
key = text.lower()
|
||||
# Try multi-map lookup
|
||||
entries = entity_map.get(key, [])
|
||||
if not entries:
|
||||
for sep in (" and ", " or ", ", "):
|
||||
if sep in key:
|
||||
entries = entity_map.get(key.split(sep)[0].strip(), [])
|
||||
if entries:
|
||||
break
|
||||
if not entries:
|
||||
for ek, ev in entity_map.items():
|
||||
if ek in key or key in ek:
|
||||
entries = ev
|
||||
break
|
||||
if entries:
|
||||
return entries[0]
|
||||
return None
|
||||
|
||||
@staticmethod
def _lookup(verb: str, prep: Optional[str] = None) -> Optional[str]:
    """Look up the relation type for a verb lemma and optional preposition.

    The key is "<verb>+<prep>" when ``prep`` is non-empty, else the bare
    verb. Returns None when ``_VERB_RELATIONS`` has no such key.
    """
    key = f"{verb}+{prep}" if prep else verb
    return _VERB_RELATIONS.get(key)
|
||||
|
||||
@staticmethod
|
||||
def _deduplicate(relations: List[Relation]) -> List[Relation]:
|
||||
seen = set()
|
||||
result = []
|
||||
for r in relations:
|
||||
key = (r.subject.text.lower(), r.predicate, r.obj.text.lower())
|
||||
rev = (r.obj.text.lower(), r.predicate, r.subject.text.lower())
|
||||
if key in seen or rev in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
result.append(r)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Co-occurrence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_cooccurrence(self, text, entities) -> List[Relation]:
    """Emit a weak "related_to" relation for entity pairs in the same sentence.

    Sentence boundaries are the ends of the regex matches below (a run of
    ".", "!" or "?" followed by whitespace or end of text). Spans are made
    contiguous, and any text left after the last boundary forms a final
    span, so a text without terminal punctuation still counts as one
    sentence instead of producing no spans at all.

    A pair (entities[i], entities[j]) with i < j is kept when both start
    offsets fall in the same span and the gap from the first entity's end to
    the second entity's start is at most ``self.max_distance`` characters.
    Each relation has confidence 0.4 and, as context, the text around both
    entities padded by 20 characters on each side.

    Returns an empty list when fewer than two entities are given.
    """
    if len(entities) < 2:
        return []
    import re as _re

    # Build contiguous sentence spans from the boundary positions.
    spans = []
    cursor = 0
    for m in _re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text):
        spans.append((cursor, m.end()))
        cursor = m.end()
    # Remaining text after the last sentence boundary (or the whole text
    # when no boundary was found).
    if cursor < len(text):
        spans.append((cursor, len(text)))

    def same_sent(c1, c2):
        return any(ss <= c1 < se and ss <= c2 < se for ss, se in spans)

    rels = []
    for i in range(len(entities)):
        for j in range(i + 1, len(entities)):
            e1, e2 = entities[i], entities[j]
            if not same_sent(e1.start_char, e2.start_char):
                continue
            if abs(e2.start_char - e1.end_char) > self.max_distance:
                continue
            cs = max(0, min(e1.start_char, e2.start_char) - 20)
            ce = min(len(text), max(e1.end_char, e2.end_char) + 20)
            rels.append(Relation(
                subject=e1, predicate="related_to", obj=e2,
                confidence=0.4, context=text[cs:ce],
                metadata={"method": "cooccurrence"},
            ))
    return rels
|
||||
243
rag/graphrag/ner/ner_extractor.py
Normal file
243
rag/graphrag/ner/ner_extractor.py
Normal file
@@ -0,0 +1,243 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
NERExtractor — semantica-style full pipeline extraction.
|
||||
|
||||
Pipeline: tokenize → tag(POS) → parse(dep) → NER → typed relations
|
||||
|
||||
All components share a single spaCy `doc` object (one forward pass).
|
||||
|
||||
Output includes:
|
||||
- Entities (from NER, enriched with POS/dep)
|
||||
- Typed relations (from dependency patterns)
|
||||
- Dependency tree (heads + labels per token)
|
||||
- POS tags per token
|
||||
|
||||
Supports 7 languages: en, zh, de, fr, es, pt, ja
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import spacy
|
||||
from spacy import Language
|
||||
|
||||
from .dep_relation_extractor import DepRelationExtractor
|
||||
from .types import Entity, ExtractionResult
|
||||
|
||||
# Language → spaCy model
|
||||
_MODEL_MAP = {
|
||||
"en": "en_core_web_sm", "zh": "zh_core_web_sm",
|
||||
"de": "de_core_news_sm", "fr": "fr_core_news_sm",
|
||||
"es": "es_core_news_sm", "pt": "pt_core_news_sm",
|
||||
"ja": "ja_core_news_sm",
|
||||
}
|
||||
|
||||
# SpaCy labels to skip from NER output
|
||||
_SKIP_LABELS = {"ORDINAL", "CARDINAL"}
|
||||
|
||||
# Labels by confidence tier (for NER confidence scoring)
|
||||
_HIGH_CONF = {"PERSON", "ORG", "GPE", "LOC", "DATE"}
|
||||
_MED_CONF = {"PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "NORP",
|
||||
"MONEY", "TIME", "PERCENT", "FAC", "QUANTITY"}
|
||||
|
||||
|
||||
class NERExtractor:
    """
    Full semantic extraction pipeline (NER + tagger + parser + relations).

    The spaCy pipeline is loaded lazily on first use and cached at class
    level under its model name, so extractors that use the same model share
    one loaded pipeline.

    Usage:
        ext = NERExtractor(language="en")
        result = ext.extract("Apple Inc. was founded by Steve Jobs.")

        # result.entities            → [Entity]
        # result.relations           → [Relation]
        # result.metadata["tokens"]  → [dict] (text, tag, dep, head, index,
        #                                      lemma, pos)
    """

    # Model cache: model name → nlp (shared singleton per process)
    _nlp_cache: Dict[str, Language] = {}

    def __init__(
        self,
        language: str = "en",
        spacy_model: Optional[str] = None,
        confidence_threshold: float = 0.3,
    ):
        """Configure the extractor; no model is loaded here.

        Args:
            language: Language code used to pick a default model from
                ``_MODEL_MAP``. An unknown code falls back to ``"en"`` unless
                ``spacy_model`` is given.
            spacy_model: Explicit spaCy model name; overrides the default
                model for ``language`` when truthy.
            confidence_threshold: Entities whose label-based confidence is
                below this value are dropped; the value is also forwarded to
                the relation extractor.
        """
        if language not in _MODEL_MAP and spacy_model is None:
            # Make the fallback visible: running the English model on another
            # language silently would only surface as poor results.
            logging.warning(
                "Unsupported NER language '%s'; falling back to 'en'", language
            )
            language = "en"
        self.language = language
        self.model_name = spacy_model or _MODEL_MAP.get(language, "en_core_web_sm")
        self.confidence_threshold = confidence_threshold
        self._nlp: Optional[Language] = None

    # ------------------------------------------------------------------
    # Model lifecycle
    # ------------------------------------------------------------------

    def _ensure_model(self):
        """Lazy-load the shared spaCy model into ``self._nlp``.

        The model is loaded with ALL of its pipes (tagger, parser, ner,
        lemmatizer, attribute_ruler) because dependency parsing needs them.

        Raises:
            Exception: whatever ``spacy.load`` raised; it is logged and
                re-raised unchanged.
        """
        if self.model_name in self._nlp_cache:
            self._nlp = self._nlp_cache[self.model_name]
            return
        try:
            nlp = spacy.load(self.model_name)
        except Exception as e:
            logging.error("Failed to load spaCy model '%s': %s",
                          self.model_name, e)
            raise
        self._nlp_cache[self.model_name] = nlp
        self._nlp = nlp

    # ------------------------------------------------------------------
    # Main extraction
    # ------------------------------------------------------------------

    def extract(
        self,
        text: str,
        extract_relations: bool = True,
        include_tokens: bool = True,
    ) -> ExtractionResult:
        """Run the full pipeline on a single text.

        Args:
            text: Raw input text.
            extract_relations: When true and at least two entities were
                found, run dependency-based relation extraction.
            include_tokens: When true, add the per-token list to
                ``result.metadata["tokens"]``.

        Returns:
            ExtractionResult with entities, relations, language and metadata
            (``model``, ``n_tokens``, ``n_entities``, ``n_relations`` and
            optionally ``tokens``).
        """
        self._ensure_model()
        # Single forward pass through spaCy; every later step reuses `doc`.
        doc = self._nlp(text)
        return self._build_result(doc, text, extract_relations, include_tokens)

    def extract_batch(
        self,
        texts: List[str],
        extract_relations: bool = True,
        include_tokens: bool = False,
        batch_size: int = 32,
    ) -> List[ExtractionResult]:
        """Batch extraction using spaCy's nlp.pipe() for efficiency.

        Each result is built the same way as in ``extract`` and carries the
        same metadata keys.

        Args:
            texts: Input texts; results are returned in the same order.
            extract_relations: See ``extract``.
            include_tokens: See ``extract`` (defaults to False here).
            batch_size: Forwarded to ``nlp.pipe``.

        Returns:
            One ExtractionResult per input text.
        """
        self._ensure_model()
        return [
            self._build_result(doc, doc.text, extract_relations, include_tokens)
            for doc in self._nlp.pipe(texts, batch_size=batch_size)
        ]

    def _build_result(
        self,
        doc,
        text: str,
        extract_relations: bool,
        include_tokens: bool,
    ) -> ExtractionResult:
        """Turn one processed spaCy doc into an ExtractionResult.

        Shared by ``extract`` and ``extract_batch`` so both return the same
        structure.
        """
        # Entities from NER.
        entities = self._extract_entities(doc)

        # Typed relations from the dependency parse; a relation needs two
        # entities, so skip the extractor otherwise.
        relations = []
        if extract_relations and len(entities) >= 2:
            dep_ext = DepRelationExtractor(
                language=self.language,
                confidence_threshold=self.confidence_threshold,
            )
            relations = dep_ext.extract(text, entities, doc=doc)

        result = ExtractionResult(
            entities=entities,
            relations=relations,
            language=self.language,
        )
        result.metadata = {
            "model": self.model_name,
            "n_tokens": len(doc),
            "n_entities": len(entities),
            # Only typed relations are counted; generic "related_to" links
            # are excluded from this figure.
            "n_relations": len([r for r in relations if r.predicate != "related_to"]),
        }
        if include_tokens:
            result.metadata["tokens"] = self._build_tokens(doc)
        return result

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _label_confidence(label: str) -> float:
        """Map an entity label to a fixed confidence by tier.

        Returns 0.85 for labels in ``_HIGH_CONF``, 0.65 for labels in
        ``_MED_CONF`` and 0.50 for anything else.
        """
        if label in _HIGH_CONF:
            return 0.85
        if label in _MED_CONF:
            return 0.65
        return 0.50

    def _extract_entities(self, doc) -> List[Entity]:
        """Convert ``doc.ents`` into Entity objects.

        Labels in ``_SKIP_LABELS`` and entities whose label confidence is
        below ``self.confidence_threshold`` are dropped. Duplicates, keyed by
        lower-cased text plus start offset, are emitted once.
        """
        entities = []
        seen = set()
        for ent in doc.ents:
            if ent.label_ in _SKIP_LABELS:
                continue
            confidence = self._label_confidence(ent.label_)
            if confidence < self.confidence_threshold:
                continue
            key = (ent.text.lower(), ent.start_char)
            if key in seen:
                continue
            seen.add(key)
            entities.append(Entity(
                text=ent.text,
                label=ent.label_,
                start_char=ent.start_char,
                end_char=ent.end_char,
                confidence=confidence,
                metadata={"source": "spacy"},
            ))
        return entities

    @staticmethod
    def _build_tokens(doc) -> List[Dict[str, Any]]:
        """Build one dict per token with POS tags and dependency info.

        Keys: ``text``, ``tag``, ``dep``, ``head`` (index of the head token),
        ``index`` (position in the doc), ``lemma``, ``pos``.
        """
        return [
            {
                "text": t.text,
                "tag": t.tag_,
                "dep": t.dep_,
                "head": t.head.i,
                "index": i,
                "lemma": t.lemma_,
                "pos": t.pos_,
            }
            for i, t in enumerate(doc)
        ]

    @staticmethod
    def clear_cache():
        """Clear the NLP model cache (e.g., for testing)."""
        NERExtractor._nlp_cache.clear()
|
||||
|
||||
|
||||
# NOTE: ExtractionResult (see .types) already declares a `metadata` field, so no patching is needed here.
|
||||
|
||||
75
rag/graphrag/ner/types.py
Normal file
75
rag/graphrag/ner/types.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Data types for entity and relation extraction.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
@dataclass
class Entity:
    """A single entity mention extracted from a text.

    Attributes:
        text: Surface form of the mention.
        label: spaCy NER label (PERSON, ORG, GPE, ...).
        start_char: Character offset at which the mention starts.
        end_char: Character offset at which the mention ends.
        confidence: Score attached by the extractor; defaults to 1.0.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    text: str
    label: str
    start_char: int
    end_char: int
    confidence: float = 1.0
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class Relation:
    """A directed relation extracted between two entities.

    Attributes:
        subject: Entity on the subject side of the relation.
        predicate: Relation type, e.g. "founded_by", "works_for", ...
        obj: Entity on the object side of the relation.
        confidence: Score attached by the extractor; defaults to 1.0.
        context: Surrounding text the relation was found in; defaults to "".
        metadata: Free-form extra information; a fresh dict per instance.
    """

    subject: Entity
    predicate: str
    obj: Entity
    confidence: float = 1.0
    context: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ExtractionResult:
    """Everything produced by one full extraction pass.

    Attributes:
        entities: Extracted entities; defaults to an empty list.
        relations: Extracted relations; defaults to an empty list.
        language: Language code of the pass; defaults to "en".
        metadata: Free-form extra information; a fresh dict per instance.
    """

    entities: List[Entity] = field(default_factory=list)
    relations: List[Relation] = field(default_factory=list)
    language: str = "en"
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Maps each spaCy NER label to the application's entity type.
SPACY_TO_APP_ENTITY_TYPE: Dict[str, str] = {
    # People and organizations
    "PERSON": "person",
    "ORG": "organization",
    # Places and facilities
    "GPE": "geo",
    "LOC": "geo",
    "FAC": "geo",
    # Events
    "EVENT": "event",
    # Named things, groups and amounts
    "PRODUCT": "category",
    "WORK_OF_ART": "category",
    "LAW": "category",
    "LANGUAGE": "category",
    "NORP": "category",
    "MONEY": "category",
    "QUANTITY": "category",
    # Temporal expressions are mapped to events
    "TIME": "event",
    "DATE": "event",
    # Numeric labels
    "PERCENT": "category",
    "CARDINAL": "category",
    "ORDINAL": "category",
}

# spaCy labels to skip.
SKIP_SPACY_LABELS = {"ORDINAL", "CARDINAL"}
|
||||
Reference in New Issue
Block a user