# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [LightRAG](https://github.com/HKUDS/LightRAG)
"""
import dataclasses
import html
import json
import logging
import os
import re
import time
from collections import defaultdict
from hashlib import md5
from typing import Any, Callable, Set, Tuple

import networkx as nx
import numpy as np
import trio
import xxhash
from networkx.readwrite import json_graph

from api import settings
from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN

GRAPH_FIELD_SEP = "<SEP>"

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
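
# Illustrative usage (a sketch, not prescribed by this module): acquire the limiter
# around each LLM request so at most MAX_CONCURRENT_CHATS run concurrently, e.g.
#   async with chat_limiter:
#       answer = await chat_with_llm(...)  # chat_with_llm is a hypothetical caller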


@dataclasses.dataclass
class GraphChange:
    removed_nodes: Set[str] = dataclasses.field(default_factory=set)
    added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
    removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
    added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)


def perform_variable_replacements(
    input: str, history: list[dict] | None = None, variables: dict | None = None
) -> str:
    """Perform variable replacements on the input string and in a chat log."""
    if history is None:
        history = []
    if variables is None:
        variables = {}

    def replace_all(input: str) -> str:
        result = input
        for k, v in variables.items():
            result = result.replace(f"{{{k}}}", v)
        return result

    result = replace_all(input)
    for entry in history:
        if entry.get("role") == "system":
            entry["content"] = replace_all(entry.get("content") or "")
    return result
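
# Illustrative: perform_variable_replacements("Hi {name}", variables={"name": "Ada"}) -> "Hi Ada"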


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input
    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)


def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False
        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True
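
# Illustrative: dict_has_keys_with_types({"a": 1, "b": "x"}, [("a", int), ("b", str)]) -> True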


def get_llm_cache(llmnm, txt, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_llm_cache(llmnm, txt, v, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    # Cache LLM responses for 24 hours.
    REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)


def get_embed_cache(llmnm, txt):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return np.array(json.loads(bin))


def set_embed_cache(llmnm, txt, arr):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    # Cache embeddings for 24 hours.
    REDIS_CONN.set(k, arr.encode("utf-8"), 24 * 3600)


def get_tags_from_cache(kb_ids):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_tags_to_cache(kb_ids, tags):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    # Tag caches are short-lived (10 minutes).
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback):
    """
    Ensure all nodes and edges in the graph have the essential attributes.
    """
    def is_valid_item(attrs: dict) -> bool:
        for attr in ["description", "source_id"]:
            if attr not in attrs:
                return False
        return True

    purged_nodes = []
    for node, node_attrs in graph.nodes(data=True):
        if not is_valid_item(node_attrs):
            purged_nodes.append(node)
    for node in purged_nodes:
        graph.remove_node(node)
    if purged_nodes and callback:
        callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")

    purged_edges = []
    for source, target, attr in graph.edges(data=True):
        if not is_valid_item(attr):
            purged_edges.append((source, target))
        if "keywords" not in attr:
            attr["keywords"] = []
    for source, target in purged_edges:
        graph.remove_edge(source, target)
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)
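
# Illustrative: get_from_to("B", "A") -> ("A", "B"), giving (A, B) and (B, A) the same edge key.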


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
        change.added_updated_nodes.add(node_name)
        if not g1.has_node(node_name):
            g1.add_node(node_name, **attr)
            continue
        node = g1.nodes[node_name]
        node["description"] += GRAPH_FIELD_SEP + attr["description"]
        # A node's source_id indicates which chunks it came from.
        node["source_id"] += attr["source_id"]

    for source, target, attr in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(source, target))
        edge = g1.get_edge_data(source, target)
        if edge is None:
            g1.add_edge(source, target, **attr)
            continue
        edge["weight"] += attr.get("weight", 0)
        edge["description"] += GRAPH_FIELD_SEP + attr["description"]
        edge["keywords"] += attr["keywords"]
        # An edge's source_id indicates which chunks it came from.
        edge["source_id"] += attr["source_id"]

    for node_degree in g1.degree:
        g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    # A graph's source_id indicates which documents it came from.
    if "source_id" not in g1.graph:
        g1.graph["source_id"] = []
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()


def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the graph
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=entity_source_id,
    )


def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as an edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])
    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    pair = sorted([source, target])
    return dict(
        src_id=pair[0],
        tgt_id=pair[1],
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        metadata={"created_at": time.time()},
    )


def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]
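
# Illustrative: pack_user_ass_to_openai_messages("q1", "a1", "q2") ->
#   [{"role": "user", "content": "q1"}, {"role": "assistant", "content": "a1"}, {"role": "user", "content": "q2"}]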


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers."""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]
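
# Illustrative: split_string_by_multi_markers("a<SEP>b##c", ["<SEP>", "##"]) -> ["a", "b", "c"]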


def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
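
# Illustrative: is_float_regex("1.5") and is_float_regex("-2") are True; is_float_regex("1e5") is False.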


def chunk_id(chunk):
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
    chunk = {
        "id": get_uuid(),
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": meta["source_id"],
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
    if ebd is None:
        ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    ents = from_ent_name
    if isinstance(ents, str):
        ents = [from_ent_name]
    if isinstance(to_ent_name, str):
        to_ent_name = [to_ent_name]
    ents.extend(to_ent_name)
    ents = list(set(ents))
    conds = {
        "fields": ["content_with_weight"],
        "size": size,
        "from_entity_kwd": ents,
        "to_entity_kwd": ents,
        "knowledge_graph_kwd": ["relation"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for id in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue
    return res


async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
    chunk = {
        "id": get_uuid(),
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": meta["source_id"],
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    txt = f"{from_ent_name}->{to_ent_name}"
    ebd = get_embed_cache(embd_mdl.llm_name, txt)
    if ebd is None:
        ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, txt, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
    # Get doc_ids of the graph
    fields = ["source_id"]
    condition = {
        "knowledge_graph_kwd": ["graph"],
        "removed_kwd": "N",
    }
    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
    fields2 = settings.docStoreConn.getFields(res, fields)
    graph_doc_ids = set()
    for chunk_id in fields2.keys():
        graph_doc_ids = set(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    conds = {
        "fields": ["source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
    doc_ids = []
    if res.total == 0:
        return doc_ids
    for id in res.ids:
        doc_ids = res.field[id]["source_id"]
    return doc_ids


async def get_graph(tenant_id, kb_id):
    conds = {
        "fields": ["content_with_weight", "source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
    if res.total == 0:
        return None

    for id in res.ids:
        try:
            g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
            if "source_id" not in g.graph:
                g.graph["source_id"] = res.field[id]["source_id"]
            return g
        except Exception:
            continue
    # Fall back to rebuilding the graph from its entity/relation chunks.
    result = await rebuild_graph(tenant_id, kb_id)
    return result
2025-01-22 19:43:14 +08:00
2025-03-26 15:34:42 +08:00
async def set_graph ( tenant_id : str , kb_id : str , embd_mdl , graph : nx . Graph , change : GraphChange , callback ) :
start = trio . current_time ( )
await trio . to_thread . run_sync ( lambda : settings . docStoreConn . delete ( { " knowledge_graph_kwd " : [ " graph " ] } , search . index_name ( tenant_id ) , kb_id ) )
if change . removed_nodes :
await trio . to_thread . run_sync ( lambda : settings . docStoreConn . delete ( { " knowledge_graph_kwd " : [ " entity " ] , " entity_kwd " : sorted ( change . removed_nodes ) } , search . index_name ( tenant_id ) , kb_id ) )
if change . removed_edges :
async with trio . open_nursery ( ) as nursery :
for from_node , to_node in change . removed_edges :
2025-03-31 22:31:35 +08:00
nursery . start_soon ( lambda : trio . to_thread . run_sync ( lambda : settings . docStoreConn . delete ( { " knowledge_graph_kwd " : [ " relation " ] , " from_entity_kwd " : from_node , " to_entity_kwd " : to_node } , search . index_name ( tenant_id ) , kb_id ) ) )
2025-03-26 15:34:42 +08:00
now = trio . current_time ( )
if callback :
callback ( msg = f " set_graph removed { len ( change . removed_nodes ) } nodes and { len ( change . removed_edges ) } edges from index in { now - start : .2f } s. " )
start = now
chunks = [ {
" id " : get_uuid ( ) ,
" content_with_weight " : json . dumps ( nx . node_link_data ( graph , edges = " edges " ) , ensure_ascii = False ) ,
2025-01-22 19:43:14 +08:00
" knowledge_graph_kwd " : " graph " ,
" kb_id " : kb_id ,
2025-03-26 15:34:42 +08:00
" source_id " : graph . graph . get ( " source_id " , [ ] ) ,
2025-01-22 19:43:14 +08:00
" available_int " : 0 ,
" removed_kwd " : " N "
2025-03-26 15:34:42 +08:00
} ]
async with trio . open_nursery ( ) as nursery :
for node in change . added_updated_nodes :
node_attrs = graph . nodes [ node ]
nursery . start_soon ( lambda : graph_node_to_chunk ( kb_id , embd_mdl , node , node_attrs , chunks ) )
for from_node , to_node in change . added_updated_edges :
2025-04-03 11:09:04 +08:00
edge_attrs = graph . get_edge_data ( from_node , to_node )
if not edge_attrs :
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
continue
2025-03-26 15:34:42 +08:00
nursery . start_soon ( lambda : graph_edge_to_chunk ( kb_id , embd_mdl , from_node , to_node , edge_attrs , chunks ) )
now = trio . current_time ( )
if callback :
callback ( msg = f " set_graph converted graph change to { len ( chunks ) } chunks in { now - start : .2f } s. " )
start = now
es_bulk_size = 4
for b in range ( 0 , len ( chunks ) , es_bulk_size ) :
doc_store_result = await trio . to_thread . run_sync ( lambda : settings . docStoreConn . insert ( chunks [ b : b + es_bulk_size ] , search . index_name ( tenant_id ) , kb_id ) )
if doc_store_result :
error_message = f " Insert chunk error: { doc_store_result } , please check log file and Elasticsearch/Infinity status! "
raise Exception ( error_message )
now = trio . current_time ( )
if callback :
callback ( msg = f " set_graph added/updated { len ( change . added_updated_nodes ) } nodes and { len ( change . added_updated_edges ) } edges from index in { now - start : .2f } s. " )


def is_continuous_subsequence(subseq, seq):
    # True if the first element of subseq appears in seq immediately followed by the last element of subseq.
    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False
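
# Illustrative: is_continuous_subsequence((1, 2), (1, 2, 3)) is True; is_continuous_subsequence((1, 4), (1, 2, 3)) is False.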


def merge_tuples(list1, list2):
    # Extend each tuple (path) in list1 with tuples from list2 whose first element matches
    # the path's last element, skipping candidates that already appear as an adjacent pair
    # in the path (in either direction).
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                reversed_match = (match[1], match[0])
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(reversed_match, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result
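
# Illustrative: merge_tuples([(1, 2)], [(2, 3)]) -> [(1, 2, 3)]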


async def get_entity_type2sampels(idxnms, kb_ids: list):
    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
                                                                                "size": 10000,
                                                                                "fields": ["content_with_weight"]},
                                                                               idxnms, kb_ids))

    res = defaultdict(list)
    for id in es_res.ids:
        smp = es_res.field[id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res


def flat_uniq_list(arr, key):
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))
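
# Illustrative: flat_uniq_list([{"k": [1, 2]}, {"k": 2}], "k") -> a deduplicated list such as [1, 2] (set order is not guaranteed)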


async def rebuild_graph(tenant_id, kb_id):
    graph = nx.Graph()
    src_ids = set()
    flds = ["entity_kwd", "from_entity_kwd", "to_entity_kwd", "knowledge_graph_kwd", "content_with_weight", "source_id"]
    bs = 256
    # Page through all entity chunks and add them as nodes.
    for i in range(0, 1024 * bs, bs):
        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
                                                                                    {"kb_id": kb_id, "knowledge_graph_kwd": ["entity"]},
                                                                                    [],
                                                                                    OrderByExpr(),
                                                                                    i, bs, search.index_name(tenant_id), [kb_id]
                                                                                    ))
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            break

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for id, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "entity"
            src_ids.update(d.get("source_id", []))
            attrs = json.loads(d["content_with_weight"])
            graph.add_node(d["entity_kwd"], **attrs)

    # Page through all relation chunks and add edges between known nodes.
    for i in range(0, 1024 * bs, bs):
        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
                                                                                    {"kb_id": kb_id, "knowledge_graph_kwd": ["relation"]},
                                                                                    [],
                                                                                    OrderByExpr(),
                                                                                    i, bs, search.index_name(tenant_id), [kb_id]
                                                                                    ))
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            return None

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for id, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "relation"
            src_ids.update(d.get("source_id", []))
            if graph.has_node(d["from_entity_kwd"]) and graph.has_node(d["to_entity_kwd"]):
                attrs = json.loads(d["content_with_weight"])
                graph.add_edge(d["from_entity_kwd"], d["to_entity_kwd"], **attrs)

    src_ids = sorted(src_ids)
    graph.graph["source_id"] = src_ids
    return graph