2024-09-25 16:30:22 -07:00
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
2024-10-09 13:46:44 -07:00
import pytest
2024-09-25 16:30:22 -07:00
from pandas . testing import assert_series_equal
2025-01-03 13:59:26 -08:00
from graphrag . callbacks . noop_verb_callbacks import NoopVerbCallbacks
from graphrag . config . create_graphrag_config import create_graphrag_config
2024-10-09 13:46:44 -07:00
from graphrag . config . enums import LLMType
2025-01-03 13:59:26 -08:00
from graphrag . index . run . derive_from_rows import ParallelizationError
from graphrag . index . workflows . create_final_covariates import (
run_workflow ,
2024-09-25 16:30:22 -07:00
workflow_name ,
)
2025-01-03 13:59:26 -08:00
from graphrag . utils . storage import load_table_from_storage
2024-09-25 16:30:22 -07:00
from . util import (
2025-01-03 13:59:26 -08:00
create_test_context ,
2024-12-05 09:57:26 -08:00
load_test_table ,
2024-09-25 16:30:22 -07:00
)
2024-10-09 13:46:44 -07:00
# Canned claim-extraction output used as the mock LLM's only response.
# It is a single parenthesised claim tuple whose eight fields are separated by
# the "<|>" delimiter, in this order: subject, object, claim type, status,
# start date, end date, description, source text.
# The delimiter and field values must not contain stray whitespace, otherwise
# the tuple cannot be split into fields and the parsed values will not match.
MOCK_LLM_RESPONSES = [
    """
(COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10<|>According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
    """.strip()
]
# LLM settings handed to the claim-extraction strategy in the tests: selects the
# StaticResponse LLM type and supplies MOCK_LLM_RESPONSES as its responses.
# Keys are plain "type" / "responses" (no surrounding whitespace) so that
# lookups by those names succeed.
MOCK_LLM_CONFIG = {"type": LLMType.StaticResponse, "responses": MOCK_LLM_RESPONSES}
2024-09-25 16:30:22 -07:00
async def test_create_final_covariates():
    """Run the covariates workflow with a mock LLM and check the stored table.

    The workflow is executed with a claim-extraction strategy that uses
    MOCK_LLM_CONFIG, then the table written under ``workflow_name`` is loaded
    back from the context's storage and compared against the input text units
    and the values contained in the mock response.
    """
    # Named `text_units` rather than `input` to avoid shadowing the builtin.
    text_units = load_test_table("create_base_text_units")
    expected = load_test_table(workflow_name)

    context = await create_test_context(
        storage=["create_base_text_units"],
    )

    config = create_graphrag_config()
    config.claim_extraction.strategy = {
        "type": "graph_intelligence",
        "llm": MOCK_LLM_CONFIG,
        "claim_description": "description",
    }

    await run_workflow(
        config,
        context,
        NoopVerbCallbacks(),
    )

    actual = await load_table_from_storage(workflow_name, context.storage)

    # Only the column count is compared against the reference table; the row
    # contents come from the mock and are checked individually below.
    assert len(actual.columns) == len(expected.columns)
    # our mock only returns one covariate per text unit, so that's a 1:1 mapping versus the LLM-extracted content in the test data
    assert len(actual) == len(text_units)
    # assert all of the columns that covariates copied from the input
    assert_series_equal(actual["text_unit_id"], text_units["id"], check_names=False)

    # make sure the human ids are incrementing
    assert actual["human_readable_id"][0] == 1
    assert actual["human_readable_id"][1] == 2

    # check that the mock data is parsed and inserted into the correct columns
    assert actual["covariate_type"][0] == "claim"
    assert actual["subject_id"][0] == "COMPANY A"
    assert actual["object_id"][0] == "GOVERNMENT AGENCY B"
    assert actual["type"][0] == "ANTI-COMPETITIVE PRACTICES"
    assert actual["status"][0] == "TRUE"
    assert actual["start_date"][0] == "2022-01-10T00:00:00"
    assert actual["end_date"][0] == "2022-01-10T00:00:00"
    assert (
        actual["description"][0]
        == "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10"
    )
    assert (
        actual["source_text"][0]
        == "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."
    )
2024-10-09 13:46:44 -07:00
async def test_create_final_covariates_missing_llm_throws():
    """Check that the workflow raises when the strategy has no ``llm`` entry.

    The claim-extraction strategy is configured with only ``type`` and
    ``claim_description``; running the workflow is then expected to raise
    ``ParallelizationError``.
    """
    context = await create_test_context(
        storage=["create_base_text_units"],
    )

    config = create_graphrag_config()
    # Deliberately omit the "llm" key that the passing test supplies.
    config.claim_extraction.strategy = {
        "type": "graph_intelligence",
        "claim_description": "description",
    }

    with pytest.raises(ParallelizationError):
        await run_workflow(
            config,
            context,
            NoopVerbCallbacks(),
        )