# -*- coding: utf-8 -*- # Copyright 2023 OpenSPG Authors # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. import copy import logging from typing import Dict, Type, List from kag.interface import LLMClient from tenacity import stop_after_attempt, retry from kag.interface import ExtractorABC, PromptABC, ExternalGraphLoaderABC from kag.common.conf import KAG_PROJECT_CONF from kag.common.utils import processing_phrases, to_camel_case from kag.builder.model.chunk import Chunk from kag.builder.model.sub_graph import SubGraph from kag.builder.prompt.utils import init_prompt_with_fallback from knext.schema.client import CHUNK_TYPE, BASIC_TYPES from knext.common.base.runnable import Input, Output from knext.schema.client import SchemaClient logger = logging.getLogger(__name__) @ExtractorABC.register("schema_constraint") @ExtractorABC.register("schema_constraint_extractor") class SchemaConstraintExtractor(ExtractorABC): """ Perform knowledge extraction for enforcing schema constraints, including entities, events and their edges. The types of entities and events, along with their respective attributes, are automatically inherited from the project's schema. """ def __init__( self, llm: LLMClient, ner_prompt: PromptABC = None, std_prompt: PromptABC = None, relation_prompt: PromptABC = None, event_prompt: PromptABC = None, external_graph: ExternalGraphLoaderABC = None, ): """ Initializes the SchemaBasedExtractor instance. Args: llm (LLMClient): The language model client used for extraction. ner_prompt (PromptABC, optional): The prompt for named entity recognition. Defaults to None. std_prompt (PromptABC, optional): The prompt for named entity standardization. Defaults to None. relation_prompt (PromptABC, optional): The prompt for relation extraction. Defaults to None. event_prompt (PromptABC, optional): The prompt for event extraction. Defaults to None. external_graph (ExternalGraphLoaderABC, optional): The external graph loader for additional data. Defaults to None. """ super().__init__() self.llm = llm self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load() self.ner_prompt = ner_prompt self.std_prompt = std_prompt self.relation_prompt = relation_prompt self.event_prompt = event_prompt biz_scene = KAG_PROJECT_CONF.biz_scene if self.ner_prompt is None: self.ner_prompt = init_prompt_with_fallback("ner", biz_scene) if self.std_prompt is None: self.std_prompt = init_prompt_with_fallback("std", biz_scene) self.external_graph = external_graph @property def input_types(self) -> Type[Input]: return Chunk @property def output_types(self) -> Type[Output]: return SubGraph @retry(stop=stop_after_attempt(3)) def named_entity_recognition(self, passage: str): """ Performs named entity recognition on a given text passage. Args: passage (str): The text to perform named entity recognition on. Returns: The result of the named entity recognition operation. """ ner_result = self.llm.invoke( {"input": passage}, self.ner_prompt, with_except=False ) if self.external_graph: extra_ner_result = self.external_graph.ner(passage) else: extra_ner_result = [] output = [] dedup = set() for item in extra_ner_result: name = item.name if name not in dedup: dedup.add(name) output.append( { "name": name, "category": item.label, "properties": item.properties, } ) for item in ner_result: name = item.get("name", None) category = item.get("category", None) if name is None or category is None: continue if not isinstance(name, str): continue if name not in dedup: dedup.add(name) output.append(item) return output @retry(stop=stop_after_attempt(3)) def named_entity_standardization(self, passage: str, entities: List[Dict]): """ Performs named entity standardization on a given text passage and entities. Args: passage (str): The text passage. entities (List[Dict]): The list of entities to standardize. Returns: The result of the named entity standardization operation. """ return self.llm.invoke( {"input": passage, "named_entities": entities}, self.std_prompt, with_except=False, ) @retry(stop=stop_after_attempt(3)) def relations_extraction(self, passage: str, entities: List[Dict]): """ Performs relation extraction on a given text passage and entities. Args: passage (str): The text passage. entities (List[Dict]): The list of entities. Returns: The result of the relation extraction operation. """ if self.relation_prompt is None: logger.debug("Relation extraction prompt not configured, skip.") return [] return self.llm.invoke( {"input": passage, "entity_list": entities}, self.relation_prompt, with_except=False, ) @retry(stop=stop_after_attempt(3)) def event_extraction(self, passage: str): """ Performs event extraction on a given text passage. Args: passage (str): The text passage. Returns: The result of the event extraction operation. """ if self.event_prompt is None: logger.debug("Event extraction prompt not configured, skip.") return [] return self.llm.invoke({"input": passage}, self.event_prompt, with_except=False) def parse_nodes_and_edges(self, entities: List[Dict], category: str = None): """ Parses nodes and edges from a list of entities. Args: entities (List[Dict]): The list of entities. Returns: Tuple[List[Node], List[Edge]]: The parsed nodes and edges. """ graph = SubGraph([], []) entities = copy.deepcopy(entities) root_nodes = [] for record in entities: if record is None: continue if isinstance(record, str): record = {"name": record} s_name = record.get("name", "") s_label = record.get("category", category) properties = record.get("properties", {}) # At times, the name and/or label is placed in the properties. if not s_name: s_name = properties.pop("name", "") if not s_label: s_label = properties.pop("category", "") if not s_name or not s_label: continue s_name = processing_phrases(s_name) root_nodes.append((s_name, s_label)) tmp_properties = copy.deepcopy(properties) spg_type = self.schema.get(s_label) for prop_name, prop_value in properties.items(): if prop_value is None: tmp_properties.pop(prop_name) continue if prop_name in spg_type.properties: prop_schema = spg_type.properties.get(prop_name) o_label = prop_schema.object_type_name_en if o_label not in BASIC_TYPES: # pop and convert property to node and edge if not isinstance(prop_value, list): prop_value = [prop_value] ( new_root_nodes, new_nodes, new_edges, ) = self.parse_nodes_and_edges(prop_value, o_label) graph.nodes.extend(new_nodes) graph.edges.extend(new_edges) # connect current node to property generated nodes for node in new_root_nodes: graph.add_edge( s_id=s_name, s_label=s_label, p=prop_name, o_id=node[0], o_label=node[1], ) tmp_properties.pop(prop_name) record["properties"] = tmp_properties # NOTE: For property converted to nodes/edges, we keep a copy of the original property values. # Perhaps it is not necessary? graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties) if "official_name" in record: official_name = processing_phrases(record["official_name"]) if official_name != s_name: graph.add_node( id=official_name, name=official_name, label=s_label, properties=dict(properties), ) graph.add_edge( s_id=s_name, s_label=s_label, p="OfficialName", o_id=official_name, o_label=s_label, ) return root_nodes, graph.nodes, graph.edges @staticmethod def add_relations_to_graph( sub_graph: SubGraph, entities: List[Dict], relations: List[list] ): """ Add edges to the subgraph based on a list of relations and entities. Args: sub_graph (SubGraph): The subgraph to add edges to. entities (List[Dict]): A list of entities, for looking up category information. relations (List[list]): A list of relations, each representing a relationship to be added to the subgraph. Returns: The constructed subgraph. """ for rel in relations: if len(rel) != 5: continue s_name, s_category, predicate, o_name, o_category = rel s_name = processing_phrases(s_name) sub_graph.add_node(s_name, s_name, s_category) o_name = processing_phrases(o_name) sub_graph.add_node(o_name, o_name, o_category) edge_type = to_camel_case(predicate) if edge_type: sub_graph.add_edge(s_name, s_category, edge_type, o_name, o_category) return sub_graph @staticmethod def add_chunk_to_graph(sub_graph: SubGraph, chunk: Chunk): """ Associates a Chunk object with the subgraph, adding it as a node and connecting it with existing nodes. Args: sub_graph (SubGraph): The subgraph to add the chunk information to. chunk (Chunk): The chunk object containing the text and metadata. Returns: The constructed subgraph. """ for node in sub_graph.nodes: sub_graph.add_edge(node.id, node.label, "source", chunk.id, CHUNK_TYPE) sub_graph.add_node( id=chunk.id, name=chunk.name, label=CHUNK_TYPE, properties={ "id": chunk.id, "name": chunk.name, "content": f"{chunk.name}\n{chunk.content}", **chunk.kwargs, }, ) sub_graph.id = chunk.id return sub_graph def assemble_subgraph( self, chunk: Chunk, entities: List[Dict], relations: List[list], events: List[Dict], ): """ Assembles a subgraph from the given chunk, entities, events, and relations. Args: chunk (Chunk): The chunk object. entities (List[Dict]): The list of entities. events (List[Dict]): The list of events. Returns: The constructed subgraph. """ graph = SubGraph([], []) _, entity_nodes, entity_edges = self.parse_nodes_and_edges(entities) graph.nodes.extend(entity_nodes) graph.edges.extend(entity_edges) _, event_nodes, event_edges = self.parse_nodes_and_edges(events) graph.nodes.extend(event_nodes) graph.edges.extend(event_edges) self.add_relations_to_graph(graph, entities, relations) self.add_chunk_to_graph(graph, chunk) return graph def append_official_name( self, source_entities: List[Dict], entities_with_official_name: List[Dict] ): """ Appends official names to entities. Args: source_entities (List[Dict]): A list of source entities. entities_with_official_name (List[Dict]): A list of entities with official names. """ tmp_dict = {} for tmp_entity in entities_with_official_name: name = tmp_entity["name"] category = tmp_entity["category"] official_name = tmp_entity["official_name"] key = f"{category}{name}" tmp_dict[key] = official_name for tmp_entity in source_entities: name = tmp_entity["name"] category = tmp_entity["category"] key = f"{category}{name}" if key in tmp_dict: official_name = tmp_dict[key] tmp_entity["official_name"] = official_name def postprocess_graph(self, graph): """ Postprocesses the graph by merging nodes with the same name and label. Args: graph (SubGraph): The graph to postprocess. Returns: The postprocessed graph. """ try: all_node_properties = {} for node in graph.nodes: id_ = node.id name = node.name label = node.label key = (id_, name, label) if key not in all_node_properties: all_node_properties[key] = node.properties else: all_node_properties[key].update(node.properties) new_graph = SubGraph([], []) for key, node_properties in all_node_properties.items(): id_, name, label = key new_graph.add_node( id=id_, name=name, label=label, properties=node_properties ) new_graph.edges = graph.edges return new_graph except: return graph def _invoke(self, input: Input, **kwargs) -> List[Output]: """ Invokes the extractor on the given input. Args: input (Input): The input data. **kwargs: Additional keyword arguments. Returns: List[Output]: The list of output results. """ title = input.name passage = title + "\n" + input.content out = [] entities = self.named_entity_recognition(passage) events = self.event_extraction(passage) named_entities = [] for entity in entities: named_entities.append( {"name": entity["name"], "category": entity["category"]} ) relations = self.relations_extraction(passage, named_entities) std_entities = self.named_entity_standardization(passage, named_entities) self.append_official_name(entities, std_entities) subgraph = self.assemble_subgraph(input, entities, relations, events) out.append(self.postprocess_graph(subgraph)) logger.debug(f"input passage:\n{passage}") logger.debug(f"output graphs:\n{out}") return out