mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
fix community context builder (#783)
Fix and refactor community context builder.
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
d26491a622
commit
7e1529ac19
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "patch",
|
||||||
|
"description": "Fix missing community reports and refactor community context builder"
|
||||||
|
}
|
@ -115,7 +115,10 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
|
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
|
||||||
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
|
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
|
||||||
"print(f\"Report records: {len(report_df)}\")\n",
|
"print(f\"Total report count: {len(report_df)}\")\n",
|
||||||
|
"print(\n",
|
||||||
|
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
|
||||||
|
")\n",
|
||||||
"report_df.head()"
|
"report_df.head()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -223,17 +226,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 31,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"LLM calls: 13. LLM tokens: 184660\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# inspect number of LLM calls and tokens\n",
|
"# inspect number of LLM calls and tokens\n",
|
||||||
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"
|
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"
|
||||||
|
@ -41,7 +41,39 @@ def build_community_context(
|
|||||||
|
|
||||||
The calculated weight is added as an attribute to the community reports and added to the context data table.
|
The calculated weight is added as an attribute to the community reports and added to the context data table.
|
||||||
"""
|
"""
|
||||||
if (
|
|
||||||
|
def _is_included(report: CommunityReport) -> bool:
    """Tell whether a report passes the minimum-rank filter.

    A report whose rank is ``None`` is excluded. The check is ``is None``
    rather than truthiness, so a rank of 0 is still compared against
    ``min_community_rank`` instead of being dropped outright.
    """
    rank = report.rank
    if rank is None:
        return False
    return rank >= min_community_rank
|
||||||
|
|
||||||
|
def _get_header(attributes: list[str]) -> list[str]:
    """Build the column header for the context table.

    The header is ``id`` and ``title``, then the attribute columns, then
    the text column (``summary`` or ``content``), and finally the rank
    column when ``include_community_rank`` is set.

    Attribute names equal to ``id`` or ``title`` are dropped. The weight
    column is dropped as well when ``include_community_weight`` is off.
    """
    fixed_columns = ["id", "title"]
    extra_columns = [name for name in attributes if name not in fixed_columns]
    if not include_community_weight:
        extra_columns = [
            name for name in extra_columns if name != community_weight_name
        ]

    text_column = "summary" if use_community_summary else "content"
    columns = [*fixed_columns, *extra_columns, text_column]
    if include_community_rank:
        columns.append(community_rank_name)
    return columns
|
||||||
|
|
||||||
|
def _report_context_text(
    report: CommunityReport, attributes: list[str]
) -> tuple[str, list[str]]:
    """Render one report as a delimited text row.

    Returns a pair: the row joined with ``column_delimiter`` and terminated
    by a newline, and the list of field values the row was built from.

    Field order is short id (empty string when missing), title, one value
    per entry in ``attributes`` (empty string when the report has no
    attributes or lacks the key), the summary or full content, and the
    rank when ``include_community_rank`` is set.

    NOTE(review): no filtering of ``attributes`` happens here, so the row
    has one field per name passed in. Presumably the caller must pass the
    same attribute list that the header was built from -- confirm.
    """
    if report.attributes:
        attribute_values = [
            str(report.attributes.get(field, "")) for field in attributes
        ]
    else:
        attribute_values = [""] * len(attributes)

    fields: list[str] = [report.short_id or "", report.title, *attribute_values]
    fields.append(report.summary if use_community_summary else report.full_content)
    if include_community_rank:
        fields.append(str(report.rank))

    return column_delimiter.join(fields) + "\n", fields
|
||||||
|
|
||||||
|
compute_community_weights = (
|
||||||
entities
|
entities
|
||||||
and len(community_reports) > 0
|
and len(community_reports) > 0
|
||||||
and include_community_weight
|
and include_community_weight
|
||||||
@ -49,7 +81,8 @@ def build_community_context(
|
|||||||
community_reports[0].attributes is None
|
community_reports[0].attributes is None
|
||||||
or community_weight_name not in community_reports[0].attributes
|
or community_weight_name not in community_reports[0].attributes
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
|
if compute_community_weights:
|
||||||
log.info("Computing community weights...")
|
log.info("Computing community weights...")
|
||||||
community_reports = _compute_community_weights(
|
community_reports = _compute_community_weights(
|
||||||
community_reports=community_reports,
|
community_reports=community_reports,
|
||||||
@ -58,11 +91,7 @@ def build_community_context(
|
|||||||
normalize=normalize_community_weight,
|
normalize=normalize_community_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_reports = [
|
selected_reports = [report for report in community_reports if _is_included(report)]
|
||||||
report
|
|
||||||
for report in community_reports
|
|
||||||
if report.rank and report.rank >= min_community_rank
|
|
||||||
]
|
|
||||||
if selected_reports is None or len(selected_reports) == 0:
|
if selected_reports is None or len(selected_reports) == 0:
|
||||||
return ([], {})
|
return ([], {})
|
||||||
|
|
||||||
@ -70,99 +99,67 @@ def build_community_context(
|
|||||||
random.seed(random_state)
|
random.seed(random_state)
|
||||||
random.shuffle(selected_reports)
|
random.shuffle(selected_reports)
|
||||||
|
|
||||||
# add context header
|
# "global" variables
|
||||||
current_context_text = f"-----{context_name}-----" + "\n"
|
attributes = (
|
||||||
|
list(community_reports[0].attributes.keys())
|
||||||
# add header
|
if community_reports[0].attributes
|
||||||
header = ["id", "title"]
|
|
||||||
attribute_cols = (
|
|
||||||
list(selected_reports[0].attributes.keys())
|
|
||||||
if selected_reports[0].attributes
|
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
attribute_cols = [col for col in attribute_cols if col not in header]
|
header = _get_header(attributes)
|
||||||
if not include_community_weight:
|
all_context_text: list[str] = []
|
||||||
attribute_cols = [col for col in attribute_cols if col != community_weight_name]
|
all_context_records: list[pd.DataFrame] = []
|
||||||
header.extend(attribute_cols)
|
|
||||||
header.append("summary" if use_community_summary else "content")
|
|
||||||
if include_community_rank:
|
|
||||||
header.append(community_rank_name)
|
|
||||||
|
|
||||||
current_context_text += column_delimiter.join(header) + "\n"
|
# batch variables
|
||||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
batch_text: str = ""
|
||||||
current_context_records = [header]
|
batch_tokens: int = 0
|
||||||
all_context_text = []
|
batch_records: list[list[str]] = []
|
||||||
all_context_records = []
|
|
||||||
|
def _init_batch() -> None:
|
||||||
|
nonlocal batch_text, batch_tokens, batch_records
|
||||||
|
batch_text = (
|
||||||
|
f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n"
|
||||||
|
)
|
||||||
|
batch_tokens = num_tokens(batch_text, token_encoder)
|
||||||
|
batch_records = []
|
||||||
|
|
||||||
|
def _cut_batch() -> None:
    """Close the current batch and add it to the accumulated results.

    The batch records are converted to a dataframe (sorted by weight and
    rank when those columns are in use). The dataframe's CSV text goes to
    ``all_context_text`` and the dataframe itself to ``all_context_records``.
    A batch that yields an empty dataframe is dropped without recording
    anything.
    """
    # The weight column only takes part when entities were supplied and
    # weights were requested; the rank column only when rank is included.
    use_weight = entities and include_community_weight
    weight_column = community_weight_name if use_weight else None
    rank_column = community_rank_name if include_community_rank else None

    record_df = _convert_report_context_to_df(
        context_records=batch_records,
        header=header,
        weight_column=weight_column,
        rank_column=rank_column,
    )
    if len(record_df) == 0:
        return

    all_context_text.append(record_df.to_csv(index=False, sep=column_delimiter))
    all_context_records.append(record_df)
|
||||||
|
|
||||||
|
# initialize the first batch
|
||||||
|
_init_batch()
|
||||||
|
|
||||||
for report in selected_reports:
|
for report in selected_reports:
|
||||||
new_context = [
|
new_context_text, new_context = _report_context_text(report, attributes)
|
||||||
report.short_id,
|
|
||||||
report.title,
|
|
||||||
*[
|
|
||||||
str(report.attributes.get(field, "")) if report.attributes else ""
|
|
||||||
for field in attribute_cols
|
|
||||||
],
|
|
||||||
]
|
|
||||||
new_context.append(
|
|
||||||
report.summary if use_community_summary else report.full_content
|
|
||||||
)
|
|
||||||
if include_community_rank:
|
|
||||||
new_context.append(str(report.rank))
|
|
||||||
new_context_text = column_delimiter.join(new_context) + "\n"
|
|
||||||
|
|
||||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||||
if current_tokens + new_tokens > max_tokens:
|
|
||||||
# convert the current context records to pandas dataframe and sort by weight and rank if exist
|
|
||||||
if len(current_context_records) > 1:
|
|
||||||
record_df = _convert_report_context_to_df(
|
|
||||||
context_records=current_context_records[1:],
|
|
||||||
header=current_context_records[0],
|
|
||||||
weight_column=community_weight_name
|
|
||||||
if entities and include_community_weight
|
|
||||||
else None,
|
|
||||||
rank_column=community_rank_name if include_community_rank else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
record_df = pd.DataFrame()
|
|
||||||
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
|
|
||||||
|
|
||||||
|
if batch_tokens + new_tokens > max_tokens:
|
||||||
|
# add the current batch to the context data and start a new batch if we are in multi-batch mode
|
||||||
|
_cut_batch()
|
||||||
if single_batch:
|
if single_batch:
|
||||||
return current_context_text, {context_name.lower(): record_df}
|
break
|
||||||
|
_init_batch()
|
||||||
|
|
||||||
all_context_text.append(current_context_text)
|
# add current report to the current batch
|
||||||
all_context_records.append(record_df)
|
batch_text += new_context_text
|
||||||
|
batch_tokens += new_tokens
|
||||||
# start a new batch
|
batch_records.append(new_context)
|
||||||
current_context_text = (
|
|
||||||
f"-----{context_name}-----"
|
|
||||||
+ "\n"
|
|
||||||
+ column_delimiter.join(header)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
|
||||||
current_context_records = [header]
|
|
||||||
else:
|
|
||||||
current_context_text += new_context_text
|
|
||||||
current_tokens += new_tokens
|
|
||||||
current_context_records.append(new_context)
|
|
||||||
|
|
||||||
# add the last batch if it has not been added
|
# add the last batch if it has not been added
|
||||||
if current_context_text not in all_context_text:
|
if batch_text not in all_context_text:
|
||||||
if len(current_context_records) > 1:
|
_cut_batch()
|
||||||
record_df = _convert_report_context_to_df(
|
|
||||||
context_records=current_context_records[1:],
|
|
||||||
header=current_context_records[0],
|
|
||||||
weight_column=community_weight_name
|
|
||||||
if entities and include_community_weight
|
|
||||||
else None,
|
|
||||||
rank_column=community_rank_name if include_community_rank else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
record_df = pd.DataFrame()
|
|
||||||
all_context_records.append(record_df)
|
|
||||||
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
|
|
||||||
all_context_text.append(current_context_text)
|
|
||||||
|
|
||||||
return all_context_text, {
|
return all_context_text, {
|
||||||
context_name.lower(): pd.concat(all_context_records, ignore_index=True)
|
context_name.lower(): pd.concat(all_context_records, ignore_index=True)
|
||||||
@ -171,11 +168,14 @@ def build_community_context(
|
|||||||
|
|
||||||
def _compute_community_weights(
|
def _compute_community_weights(
|
||||||
community_reports: list[CommunityReport],
|
community_reports: list[CommunityReport],
|
||||||
entities: list[Entity],
|
entities: list[Entity] | None,
|
||||||
weight_attribute: str = "occurrence",
|
weight_attribute: str = "occurrence",
|
||||||
normalize: bool = True,
|
normalize: bool = True,
|
||||||
) -> list[CommunityReport]:
|
) -> list[CommunityReport]:
|
||||||
"""Calculate a community's weight as count of text units associated with entities within the community."""
|
"""Calculate a community's weight as count of text units associated with entities within the community."""
|
||||||
|
if not entities:
|
||||||
|
return community_reports
|
||||||
|
|
||||||
community_text_units = {}
|
community_text_units = {}
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if entity.community_ids:
|
if entity.community_ids:
|
||||||
@ -211,7 +211,7 @@ def _rank_report_context(
|
|||||||
rank_column: str | None = "rank",
|
rank_column: str | None = "rank",
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""Sort report context by community weight and rank if exist."""
|
"""Sort report context by community weight and rank if exist."""
|
||||||
rank_attributes = []
|
rank_attributes: list[str] = []
|
||||||
if weight_column:
|
if weight_column:
|
||||||
rank_attributes.append(weight_column)
|
rank_attributes.append(weight_column)
|
||||||
report_df[weight_column] = report_df[weight_column].astype(float)
|
report_df[weight_column] = report_df[weight_column].astype(float)
|
||||||
@ -230,6 +230,9 @@ def _convert_report_context_to_df(
|
|||||||
rank_column: str | None = None,
|
rank_column: str | None = None,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""Convert report context records to pandas dataframe and sort by weight and rank if exist."""
|
"""Convert report context records to pandas dataframe and sort by weight and rank if exist."""
|
||||||
|
if len(context_records) == 0:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
record_df = pd.DataFrame(
|
record_df = pd.DataFrame(
|
||||||
context_records,
|
context_records,
|
||||||
columns=cast(Any, header),
|
columns=cast(Any, header),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user