2024-07-01 15:25:30 -06:00
|
|
|
# Copyright (c) 2024 Microsoft Corporation.
|
|
|
|
# Licensed under the MIT License
|
|
|
|
|
2025-02-07 11:11:03 -08:00
|
|
|
"""A module containing run_graph_intelligence and run_extract_graph methods to run graph intelligence."""
|
2024-07-01 15:25:30 -06:00
|
|
|
|
2024-12-05 09:57:26 -08:00
|
|
|
import networkx as nx
|
2025-02-13 16:56:37 -05:00
|
|
|
from fnllm.types import ChatLLM
|
2024-07-01 15:25:30 -06:00
|
|
|
|
2024-07-11 10:22:27 -06:00
|
|
|
import graphrag.config.defaults as defs
|
2024-11-27 13:27:43 -05:00
|
|
|
from graphrag.cache.pipeline_cache import PipelineCache
|
2025-01-06 10:58:59 -08:00
|
|
|
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
2025-01-21 15:52:06 -08:00
|
|
|
from graphrag.config.models.language_model_config import LanguageModelConfig
|
|
|
|
from graphrag.index.llm.load_llm import load_llm
|
2025-02-07 11:11:03 -08:00
|
|
|
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
|
|
|
|
from graphrag.index.operations.extract_graph.typing import (
|
2024-07-01 15:25:30 -06:00
|
|
|
Document,
|
|
|
|
EntityExtractionResult,
|
|
|
|
EntityTypes,
|
|
|
|
StrategyConfig,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2024-10-09 13:46:44 -07:00
|
|
|
async def run_graph_intelligence(
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the graph intelligence entity extraction strategy.

    Builds a ``LanguageModelConfig`` from ``args["llm"]``, obtains a model via
    ``load_llm("extract_graph", ...)`` wired to the given callbacks and cache,
    and delegates the extraction itself to ``run_extract_graph``.
    """
    model_config = LanguageModelConfig(**args["llm"])
    chat_model = load_llm(
        "extract_graph",
        model_config,
        callbacks=callbacks,
        cache=cache,
    )
    return await run_extract_graph(
        llm=chat_model,
        docs=docs,
        entity_types=entity_types,
        callbacks=callbacks,
        args=args,
    )
|
2024-07-01 15:25:30 -06:00
|
|
|
|
|
|
|
|
2025-02-07 11:11:03 -08:00
|
|
|
async def run_extract_graph(
    llm: ChatLLM,
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: WorkflowCallbacks | None,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the entity extraction chain.

    Extracts a graph from the stripped text of ``docs`` with a
    ``GraphExtractor``. It then rewrites every node and edge ``source_id``
    from positions in ``docs`` to the corresponding documents' ``id`` values.

    Args:
        llm: Chat model handed to the ``GraphExtractor`` as its invoker.
        docs: Documents to extract from. The extractor reports source
            documents by their position in this list.
        entity_types: Entity types forwarded to the extractor's prompt
            variables.
        callbacks: Optional workflow callbacks. When provided, extraction
            errors are reported through ``callbacks.error``.
        args: Strategy configuration. Reads the optional keys
            ``tuple_delimiter``, ``record_delimiter``,
            ``completion_delimiter``, ``extraction_prompt`` and
            ``encoding_name`` (each defaulting to ``None``), plus
            ``max_gleanings`` (defaulting to
            ``defs.EXTRACT_GRAPH_MAX_GLEANINGS``).

    Returns:
        An ``EntityExtractionResult`` built from three parts:
        - the list of entity dicts (node title plus node attributes);
        - the edge list produced by ``nx.to_pandas_edgelist``;
        - the extracted graph itself.
    """
    tuple_delimiter = args.get("tuple_delimiter", None)
    record_delimiter = args.get("record_delimiter", None)
    completion_delimiter = args.get("completion_delimiter", None)
    extraction_prompt = args.get("extraction_prompt", None)
    encoding_model = args.get("encoding_name", None)
    max_gleanings = args.get("max_gleanings", defs.EXTRACT_GRAPH_MAX_GLEANINGS)

    extractor = GraphExtractor(
        llm_invoker=llm,
        prompt=extraction_prompt,
        encoding_model=encoding_model,
        max_gleanings=max_gleanings,
        on_error=lambda e, s, d: (
            callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
        ),
    )
    text_list = [doc.text.strip() for doc in docs]

    # text_list is freshly built and not reused, so no defensive copy is needed.
    results = await extractor(
        text_list,
        {
            "entity_types": entity_types,
            "tuple_delimiter": tuple_delimiter,
            "record_delimiter": record_delimiter,
            "completion_delimiter": completion_delimiter,
        },
    )

    graph = results.output

    # "source_id" holds comma-separated indices into ``docs``; map each one
    # back to that document's "id" field, for nodes and edges alike.
    for _, node_attrs in graph.nodes(data=True):  # type: ignore
        _remap_source_ids(node_attrs, docs)
    for _, _, edge_attrs in graph.edges(data=True):  # type: ignore
        _remap_source_ids(edge_attrs, docs)

    entities = [
        {"title": title, **(node_attrs or {})}
        for title, node_attrs in graph.nodes(data=True)
    ]

    relationships = nx.to_pandas_edgelist(graph)

    return EntityExtractionResult(entities, relationships, graph)


def _remap_source_ids(attrs: dict | None, docs: list[Document]) -> None:
    """Rewrite ``attrs["source_id"]`` in place from ``docs`` indices to document ids."""
    if attrs is None:
        return
    attrs["source_id"] = ",".join(
        docs[int(doc_index)].id for doc_index in attrs["source_id"].split(",")
    )
|