# graphrag/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""
import networkx as nx
from fnllm.types import ChatLLM
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_graph.typing import (
Document,
EntityExtractionResult,
EntityTypes,
StrategyConfig,
)
async def run_graph_intelligence(
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the graph intelligence entity extraction strategy.

    Builds a language model from the strategy config and delegates the
    actual extraction to run_extract_graph.
    """
    # The model settings travel inside the strategy config under the "llm" key.
    model_config = LanguageModelConfig(**args["llm"])
    model = load_llm(
        "extract_graph",
        model_config,
        callbacks=callbacks,
        cache=cache,
    )
    return await run_extract_graph(model, docs, entity_types, callbacks, args)
async def run_extract_graph(
    llm: ChatLLM,
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: WorkflowCallbacks | None,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the entity extraction chain.

    Args:
        llm: Chat model used to perform the extraction.
        docs: Documents to extract entities and relationships from.
        entity_types: Entity types the extractor should look for.
        callbacks: Optional workflow callbacks used for error reporting.
        args: Strategy configuration (prompt, delimiters, max gleanings, ...).

    Returns:
        EntityExtractionResult holding the entity records, a relationship
        edge list, and the extracted graph itself.
    """
    tuple_delimiter = args.get("tuple_delimiter", None)
    record_delimiter = args.get("record_delimiter", None)
    completion_delimiter = args.get("completion_delimiter", None)
    extraction_prompt = args.get("extraction_prompt", None)
    encoding_model = args.get("encoding_name", None)
    max_gleanings = args.get("max_gleanings", defs.EXTRACT_GRAPH_MAX_GLEANINGS)

    extractor = GraphExtractor(
        llm_invoker=llm,
        prompt=extraction_prompt,
        encoding_model=encoding_model,
        max_gleanings=max_gleanings,
        # Errors are reported through the callbacks when available; otherwise
        # they are silently dropped (best-effort, as before).
        on_error=lambda e, s, d: (
            callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
        ),
    )

    text_list = [doc.text.strip() for doc in docs]
    results = await extractor(
        text_list,  # already a fresh list; no defensive copy needed
        {
            "entity_types": entity_types,
            "tuple_delimiter": tuple_delimiter,
            "record_delimiter": record_delimiter,
            "completion_delimiter": completion_delimiter,
        },
    )

    graph = results.output
    # The extractor records provenance as comma-separated positional indexes
    # into `docs`; map each "source_id" back to the documents' "id" field.
    # NOTE: nodes(data=True)/edges(data=True) always yield attribute dicts
    # (never None), so we guard on the "source_id" key instead of the old
    # always-true `is not None` checks — nodes/edges without provenance are
    # now skipped rather than raising KeyError.
    for _, node in graph.nodes(data=True):  # type: ignore
        if "source_id" in node:
            node["source_id"] = ",".join(
                docs[int(doc_index)].id for doc_index in node["source_id"].split(",")
            )
    for _, _, edge in graph.edges(data=True):  # type: ignore
        if "source_id" in edge:
            edge["source_id"] = ",".join(
                docs[int(doc_index)].id for doc_index in edge["source_id"].split(",")
            )

    entities = [
        {"title": title, **(attrs or {})} for title, attrs in graph.nodes(data=True)
    ]
    relationships = nx.to_pandas_edgelist(graph)

    return EntityExtractionResult(entities, relationships, graph)