#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
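"""Entity resolution for GraphRAG knowledge graphs.

Groups graph nodes by entity type, proposes likely-duplicate pairs with a cheap
similarity heuristic, asks an LLM to confirm them, and merges confirmed
duplicates into a single node.
"""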
import logging
import itertools
import re
import time
from dataclasses import dataclass
from typing import Any, Callable

import networkx as nx
import trio

from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, chat_limiter

DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
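# These delimiters are substituted into ENTITY_RESOLUTION_PROMPT and are the
# markers `_process_results` looks for when parsing the model's structured answer.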


@dataclass
class EntityResolutionResult:
    """Entity resolution result class definition."""

    graph: nx.Graph
    removed_entities: list


class EntityResolution(Extractor):
    """Entity resolution class definition."""

    _resolution_prompt: str
    _output_formatter_prompt: str
    _record_delimiter_key: str
    _entity_index_delimiter_key: str
    _resolution_result_delimiter_key: str

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        get_entity: Callable | None = None,
        set_entity: Callable | None = None,
        get_relation: Callable | None = None,
        set_relation: Callable | None = None
    ):
        """Init method definition."""
        super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
        self._llm = llm_invoker
        self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
        self._record_delimiter_key = "record_delimiter"
        self._entity_index_delimiter_key = "entity_index_delimiter"
        self._resolution_result_delimiter_key = "resolution_result_delimiter"
        self._input_text_key = "input_text"

    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
        """Resolve duplicate entities in ``graph`` and merge them in place."""
        if prompt_variables is None:
            prompt_variables = {}

        # Wire defaults into the prompt variables
        self.prompt_variables = {
            **prompt_variables,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key)
            or DEFAULT_ENTITY_INDEX_DELIMITER,
            self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
            or DEFAULT_RESOLUTION_RESULT_DELIMITER,
        }
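
        # Group nodes by entity type and, within each type, propose candidate
        # duplicate pairs with the cheap `is_similarity` heuristic below.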
        nodes = graph.nodes
        entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
        node_clusters = {entity_type: [] for entity_type in entity_types}
        for node in nodes:
            node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)

        candidate_resolution = {entity_type: [] for entity_type in entity_types}
        for k, v in node_clusters.items():
            candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
        num_candidates = sum(len(candidates) for _, candidates in candidate_resolution.items())
        if callback:
            callback(msg=f"Identified {num_candidates} candidate pairs")
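
        # One task per entity type, throttled by `chat_limiter`, asks the LLM
        # which candidate pairs actually refer to the same entity.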
        resolution_result = set()
        async with trio.open_nursery() as nursery:
            for candidate_resolution_i in candidate_resolution.items():
                if not candidate_resolution_i[1]:
                    continue
                # Pass arguments positionally so each task binds its own item;
                # a bare lambda here would late-bind the loop variable.
                nursery.start_soon(self._resolve_candidate, candidate_resolution_i, resolution_result)
        if callback:
            callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
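
        # Confirmed duplicate pairs form an auxiliary graph; each connected
        # component of it is one group of entities to collapse into a single node.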
        connect_graph = nx.Graph()
        removed_entities = []
        connect_graph.add_edges_from(resolution_result)
        all_entities_data = []
        all_relationships_data = []
        all_remove_nodes = []
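
        # For every group: keep one node, merge the other nodes' entity data into it,
        # and re-route or merge their edges before removing them from the graph.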
        async with trio.open_nursery() as nursery:
            for sub_connect_graph in nx.connected_components(connect_graph):
                sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
                remove_nodes = list(sub_connect_graph.nodes)
                keep_node = remove_nodes.pop()
                all_remove_nodes.append(remove_nodes)
                # Bind arguments positionally; a bare lambda would late-bind the
                # loop variables and could merge the wrong component.
                nursery.start_soon(self._merge_nodes, keep_node, self._get_entity_(remove_nodes), all_entities_data)
                for remove_node in remove_nodes:
                    removed_entities.append(remove_node)
                    remove_node_neighbors = graph[remove_node]
                    remove_node_neighbors = list(remove_node_neighbors)
                    for remove_node_neighbor in remove_node_neighbors:
                        rel = self._get_relation_(remove_node, remove_node_neighbor)
                        if graph.has_edge(remove_node, remove_node_neighbor):
                            graph.remove_edge(remove_node, remove_node_neighbor)
                        if remove_node_neighbor == keep_node:
                            if graph.has_edge(keep_node, remove_node):
                                graph.remove_edge(keep_node, remove_node)
                            continue
                        if not rel:
                            continue
                        if graph.has_edge(keep_node, remove_node_neighbor):
                            nursery.start_soon(self._merge_edges, keep_node, remove_node_neighbor, [rel], all_relationships_data)
                        else:
                            pair = sorted([keep_node, remove_node_neighbor])
                            graph.add_edge(pair[0], pair[1], weight=rel['weight'])
                            self._set_relation_(pair[0], pair[1],
                                                dict(
                                                    src_id=pair[0],
                                                    tgt_id=pair[1],
                                                    weight=rel['weight'],
                                                    description=rel['description'],
                                                    keywords=[],
                                                    source_id=rel.get("source_id", ""),
                                                    metadata={"created_at": time.time()}
                                                ))
                    graph.remove_node(remove_node)

        return EntityResolutionResult(
            graph=graph,
            removed_entities=removed_entities
        )
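
    # Batches all candidate pairs of one entity type into a single prompt and
    # records the pairs the model confirms as duplicates.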
    async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
        gen_conf = {"temperature": 0.5}
        pair_txt = [
            f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
        for index, candidate in enumerate(candidate_resolution_i[1]):
            pair_txt.append(
                f'Question {index + 1}: name of {candidate_resolution_i[0]} A is {candidate[0]}, name of {candidate_resolution_i[0]} B is {candidate[1]}')
        sent = 'question above' if len(candidate_resolution_i[1]) == 1 else f'above {len(candidate_resolution_i[1])} questions'
        pair_txt.append(
            f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
        pair_prompt = '\n'.join(pair_txt)
        variables = {
            **self.prompt_variables,
            self._input_text_key: pair_prompt
        }
        text = perform_variable_replacements(self._resolution_prompt, variables=variables)
        logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
        async with chat_limiter:
            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
        logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
        result = self._process_results(len(candidate_resolution_i[1]), response,
                                       self.prompt_variables.get(self._record_delimiter_key,
                                                                 DEFAULT_RECORD_DELIMITER),
                                       self.prompt_variables.get(self._entity_index_delimiter_key,
                                                                 DEFAULT_ENTITY_INDEX_DELIMITER),
                                       self.prompt_variables.get(self._resolution_result_delimiter_key,
                                                                 DEFAULT_RESOLUTION_RESULT_DELIMITER))
        for result_i in result:
            resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
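
    # Each record is expected to carry the question index wrapped in the
    # entity-index delimiter and a yes/no verdict wrapped in the
    # resolution-result delimiter; anything else is ignored.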
    def _process_results(
        self,
        records_length: int,
        results: str,
        record_delimiter: str,
        entity_index_delimiter: str,
        resolution_result_delimiter: str
    ) -> list:
        ans_list = []
        records = [r.strip() for r in results.split(record_delimiter)]
        for record in records:
            pattern_int = rf"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
            match_int = re.search(pattern_int, record)
            res_int = int(str(match_int.group(1) if match_int else '0'))
            if res_int > records_length:
                continue

            pattern_bool = rf"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
            match_bool = re.search(pattern_bool, record)
            res_bool = str(match_bool.group(1) if match_bool else '')

            if res_int and res_bool:
                if res_bool.lower() == 'yes':
                    ans_list.append((res_int, "yes"))

        return ans_list
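
    # Cheap pre-filter: edit distance for English names; character overlap
    # otherwise (e.g. shared CJK characters).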
    def is_similarity(self, a, b):
        if is_english(a) and is_english(b):
            if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
                return True

        if len(set(a) & set(b)) > 1:
            return True

        return False
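

# Rough usage sketch (illustrative only; the chat model and the get_/set_ storage
# callbacks are assumed to come from the surrounding graphrag pipeline):
#
#     async def deduplicate(graph: nx.Graph) -> nx.Graph:
#         er = EntityResolution(chat_model,
#                               get_entity=get_entity, set_entity=set_entity,
#                               get_relation=get_relation, set_relation=set_relation)
#         result = await er(graph, callback=lambda msg: logging.info(msg))
#         return result.graph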