2024-09-25 16:30:22 -07:00
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from pandas . testing import assert_series_equal
2025-01-03 13:59:26 -08:00
from graphrag . config . create_graphrag_config import create_graphrag_config
2025-02-20 08:56:20 -06:00
from graphrag . config . enums import ModelType
2025-02-27 09:31:46 -08:00
from graphrag . data_model . schemas import COVARIATES_FINAL_COLUMNS
2025-02-07 11:11:03 -08:00
from graphrag . index . workflows . extract_covariates import (
2025-01-03 13:59:26 -08:00
run_workflow ,
2024-09-25 16:30:22 -07:00
)
2025-01-03 13:59:26 -08:00
from graphrag . utils . storage import load_table_from_storage
2024-09-25 16:30:22 -07:00
from . util import (
2025-01-21 15:52:06 -08:00
DEFAULT_MODEL_CONFIG ,
2025-01-03 13:59:26 -08:00
create_test_context ,
2024-12-05 09:57:26 -08:00
load_test_table ,
2024-09-25 16:30:22 -07:00
)
2024-10-09 13:46:44 -07:00
# Canned completion handed to the MockChat model: a single claim record whose
# fields are separated by "<|>" and wrapped in parentheses. Field order, as
# checked by the assertions in test_extract_covariates: subject, object,
# claim type, status, start date, end date, description, source text.
MOCK_LLM_RESPONSES = [
    "("
    + "<|>".join([
        "COMPANY A",
        "GOVERNMENT AGENCY B",
        "ANTI-COMPETITIVE PRACTICES",
        "TRUE",
        "2022-01-10T00:00:00",
        "2022-01-10T00:00:00",
        "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10",
        "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.",
    ])
    + ")"
]
2024-09-25 16:30:22 -07:00
2025-02-07 11:11:03 -08:00
async def test_extract_covariates():
    """Run the extract_covariates workflow against a mocked LLM and check its output.

    The claim-extraction model is configured as ``ModelType.MockChat`` with the
    canned ``MOCK_LLM_RESPONSES``. The test checks four things:

    - every final covariate column is present;
    - the output has one row per input text unit;
    - the human-readable ids start at 1 and increment;
    - the fields of the mocked claim land in the expected columns.
    """
    # Named `text_units` rather than `input` to avoid shadowing the builtin.
    text_units = load_test_table("text_units")

    context = await create_test_context(
        storage=["text_units"],
    )

    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    llm_settings = config.get_language_model_config(
        config.extract_claims.model_id
    ).model_dump()
    # Replace the configured model with the mock chat model and its canned responses.
    llm_settings["type"] = ModelType.MockChat
    llm_settings["responses"] = MOCK_LLM_RESPONSES
    config.extract_claims.strategy = {
        "type": "graph_intelligence",
        "llm": llm_settings,
        "claim_description": "description",
    }

    await run_workflow(config, context)

    actual = await load_table_from_storage("covariates", context.storage)

    # Collect every missing column so a failure reports all of them at once.
    missing = [col for col in COVARIATES_FINAL_COLUMNS if col not in actual.columns]
    assert not missing, f"covariates output is missing columns: {missing}"

    # our mock only returns one covariate per text unit, so that's a 1:1 mapping versus the LLM-extracted content in the test data
    assert len(actual) == len(text_units)
    # assert all of the columns that covariates copied from the input
    assert_series_equal(actual["text_unit_id"], text_units["id"], check_names=False)

    # make sure the human ids are incrementing; positional (.iloc) access keeps
    # these checks independent of the loaded table's index labels
    human_ids = actual["human_readable_id"]
    assert human_ids.iloc[0] == 1
    assert human_ids.iloc[1] == 2

    # check that the mock data is parsed and inserted into the correct columns
    first = actual.iloc[0]
    assert first["covariate_type"] == "claim"
    assert first["subject_id"] == "COMPANY A"
    assert first["object_id"] == "GOVERNMENT AGENCY B"
    assert first["type"] == "ANTI-COMPETITIVE PRACTICES"
    assert first["status"] == "TRUE"
    assert first["start_date"] == "2022-01-10T00:00:00"
    assert first["end_date"] == "2022-01-10T00:00:00"
    assert (
        first["description"]
        == "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10"
    )
    assert (
        first["source_text"]
        == "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."
    )