# Mirror of https://github.com/microsoft/graphrag.git
# (synced 2025-07-08 09:31:54 +00:00; 76 lines, 2.9 KiB, Python)
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
import logging
|
|
import os
|
|
import unittest
|
|
|
|
from graphrag.index.run import run_pipeline_with_config
|
|
from graphrag.index.typing import PipelineRunResult
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class TestRun(unittest.IsolatedAsyncioTestCase):
    """End-to-end test that runs the ``megapipeline.yml`` pipeline config."""

    async def test_megapipeline(self):
        """Run ``megapipeline.yml`` and check it finishes without errors.

        The config path is resolved relative to this test file. The ``errors``
        of every workflow run result are collected and the test fails if any
        were reported; afterwards the text-unit / entity outputs are
        cross-checked for dangling references.
        """
        pipeline_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "./megapipeline.yml",
        )
        pipeline_result = [gen async for gen in run_pipeline_with_config(pipeline_path)]

        errors = []
        for result in pipeline_result:
            if result.errors:
                errors.extend(result.errors)

        if errors:
            print("Errors: ", errors)
        # Errors may be exception objects rather than strings; stringify them
        # explicitly so building the failure message cannot itself raise
        # TypeError and mask the real failure.
        assert not errors, "received errors!\n" + "\n".join(str(e) for e in errors)

        self._assert_text_units_and_entities_reference_each_other(pipeline_result)

    @staticmethod
    def _find_workflow_output(pipeline_result, workflow_name):
        """Return ``.result`` of the first run result whose ``.workflow`` matches.

        Fails with an AssertionError when no run result has that workflow
        name. A bare ``next(...)`` would raise StopIteration here, which is
        converted to an opaque RuntimeError once it escapes the async test
        coroutine (PEP 479).
        """
        match = next((r for r in pipeline_result if r.workflow == workflow_name), None)
        assert match is not None, f"Workflow {workflow_name} produced no run result"
        return match.result

    def _assert_text_units_and_entities_reference_each_other(
        self, pipeline_result: list[PipelineRunResult]
    ):
        """Assert text units and entities only reference each other's known ids.

        Every id in a ``create_final_text_units`` row's ``entity_ids`` must be
        the ``id`` of some ``create_final_entities`` row, and every id in an
        entity row's ``text_unit_ids`` must be the ``id`` of some text-unit
        row. Missing / None id lists are treated as empty.
        """
        text_unit_df = self._find_workflow_output(
            pipeline_result, "create_final_text_units"
        )
        entity_df = self._find_workflow_output(pipeline_result, "create_final_entities")

        assert text_unit_df is not None, "Text unit dataframe should not be None"
        assert entity_df is not None, "Entity dataframe should not be None"

        # Narrows the types for the type checker; also bails out quietly if the
        # asserts above were stripped (python -O).
        if text_unit_df is None or entity_df is None:
            return

        assert len(text_unit_df) > 0, "Text unit dataframe should not be empty"
        assert len(entity_df) > 0, "Entity dataframe should not be empty"

        log.info("text_unit_df %s", text_unit_df.columns)

        text_unit_entity_map = {}
        for _, row in text_unit_df.iterrows():
            values = row.get("entity_ids", [])
            text_unit_entity_map[row["id"]] = set() if values is None else set(values)

        entity_text_unit_map = {}
        for _, row in entity_df.iterrows():
            # NOTE(review): every entity is expected to have text units, but an
            # entity with none is not flagged here -- only dangling ids are.
            values = row.get("text_unit_ids", [])
            entity_text_unit_map[row["id"]] = set() if values is None else set(values)

        text_unit_ids = set(text_unit_entity_map)
        entity_ids = set(entity_text_unit_map)

        # Report only the offending ids, not the whole reference set.
        for text_unit_id, text_unit_entities in text_unit_entity_map.items():
            unknown_entities = text_unit_entities - entity_ids
            assert not unknown_entities, (
                f"Text unit {text_unit_id} has entities {unknown_entities} "
                "that are not in the entity set"
            )
        for entity_id, entity_text_units in entity_text_unit_map.items():
            unknown_text_units = entity_text_units - text_unit_ids
            assert not unknown_text_units, (
                f"Entity {entity_id} has text units {unknown_text_units} "
                "that are not in the text unit set"
            )
|