# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the run_graph_intelligence and run_extract_graph methods to run graph intelligence."""

import networkx as nx
from fnllm.types import ChatLLM

import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_graph.typing import (
    Document,
    EntityExtractionResult,
    EntityTypes,
    StrategyConfig,
)


async def run_graph_intelligence(
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the graph intelligence entity extraction strategy."""
    llm_config = LanguageModelConfig(**args["llm"])
    llm = load_llm(
        "extract_graph",
        llm_config,
        callbacks=callbacks,
        cache=cache,
    )
    return await run_extract_graph(llm, docs, entity_types, callbacks, args)


async def run_extract_graph(
    llm: ChatLLM,
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: WorkflowCallbacks | None,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the entity extraction chain."""
    # Optional prompt-formatting settings; the extractor falls back to its
    # own defaults when these are None.
    tuple_delimiter = args.get("tuple_delimiter", None)
    record_delimiter = args.get("record_delimiter", None)
    completion_delimiter = args.get("completion_delimiter", None)
    extraction_prompt = args.get("extraction_prompt", None)
    encoding_model = args.get("encoding_name", None)
    max_gleanings = args.get("max_gleanings", defs.EXTRACT_GRAPH_MAX_GLEANINGS)

    extractor = GraphExtractor(
        llm_invoker=llm,
        prompt=extraction_prompt,
        encoding_model=encoding_model,
        max_gleanings=max_gleanings,
        on_error=lambda e, s, d: (
            callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
        ),
    )
    text_list = [doc.text.strip() for doc in docs]

    results = await extractor(
        list(text_list),
        {
            "entity_types": entity_types,
            "tuple_delimiter": tuple_delimiter,
            "record_delimiter": record_delimiter,
            "completion_delimiter": completion_delimiter,
        },
    )

    graph = results.output
    # The extractor records each source_id as an index into text_list; map the
    # indices back to the original document ids.
    for _, node in graph.nodes(data=True):  # type: ignore
        if node is not None:
            node["source_id"] = ",".join(
                docs[int(doc_index)].id for doc_index in node["source_id"].split(",")
            )

    for _, _, edge in graph.edges(data=True):  # type: ignore
        if edge is not None:
            edge["source_id"] = ",".join(
                docs[int(doc_index)].id for doc_index in edge["source_id"].split(",")
            )

    entities = [
        {"title": item[0], **(item[1] or {})}
        for item in graph.nodes(data=True)
        if item is not None
    ]

    relationships = nx.to_pandas_edgelist(graph)

    return EntityExtractionResult(entities, relationships, graph)
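

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): one way to drive
# run_extract_graph directly with an already-loaded ChatLLM. The `my_llm`
# value, the sample document, and the strategy config contents below are
# assumptions for the example; the extraction prompt and delimiters are left
# to the extractor's defaults.
#
#   import asyncio
#
#   async def _demo(my_llm: ChatLLM) -> None:
#       my_docs = [
#           Document(text="Alice Palmer is the CEO of Contoso.", id="doc-0"),
#       ]
#       result = await run_extract_graph(
#           my_llm,
#           my_docs,
#           entity_types=["person", "organization"],
#           callbacks=None,
#           args={"max_gleanings": 1},
#       )
#       # result.entities is a list of dicts, result.relationships a pandas
#       # edge list, and result.graph the underlying networkx graph.
#       print(result.graph)
#
#   asyncio.run(_demo(my_llm))  # my_llm: a ChatLLM, e.g. from load_llm
# ---------------------------------------------------------------------------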