mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
fix community context builder (#783)
fix and refactor community context builder Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
d26491a622
commit
7e1529ac19
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Fix missing community reports and refactor community context builder"
|
||||
}
|
@ -115,7 +115,10 @@
|
||||
"\n",
|
||||
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
|
||||
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
|
||||
"print(f\"Report records: {len(report_df)}\")\n",
|
||||
"print(f\"Total report count: {len(report_df)}\")\n",
|
||||
"print(\n",
|
||||
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
|
||||
")\n",
|
||||
"report_df.head()"
|
||||
]
|
||||
},
|
||||
@ -223,17 +226,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"LLM calls: 13. LLM tokens: 184660\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# inspect number of LLM calls and tokens\n",
|
||||
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"
|
||||
|
@ -41,7 +41,39 @@ def build_community_context(
|
||||
|
||||
The calculated weight is added as an attribute to the community reports and added to the context data table.
|
||||
"""
|
||||
if (
|
||||
|
||||
def _is_included(report: CommunityReport) -> bool:
|
||||
return report.rank is not None and report.rank >= min_community_rank
|
||||
|
||||
def _get_header(attributes: list[str]) -> list[str]:
|
||||
header = ["id", "title"]
|
||||
attributes = [col for col in attributes if col not in header]
|
||||
if not include_community_weight:
|
||||
attributes = [col for col in attributes if col != community_weight_name]
|
||||
header.extend(attributes)
|
||||
header.append("summary" if use_community_summary else "content")
|
||||
if include_community_rank:
|
||||
header.append(community_rank_name)
|
||||
return header
|
||||
|
||||
def _report_context_text(
|
||||
report: CommunityReport, attributes: list[str]
|
||||
) -> tuple[str, list[str]]:
|
||||
context: list[str] = [
|
||||
report.short_id if report.short_id else "",
|
||||
report.title,
|
||||
*[
|
||||
str(report.attributes.get(field, "")) if report.attributes else ""
|
||||
for field in attributes
|
||||
],
|
||||
]
|
||||
context.append(report.summary if use_community_summary else report.full_content)
|
||||
if include_community_rank:
|
||||
context.append(str(report.rank))
|
||||
result = column_delimiter.join(context) + "\n"
|
||||
return result, context
|
||||
|
||||
compute_community_weights = (
|
||||
entities
|
||||
and len(community_reports) > 0
|
||||
and include_community_weight
|
||||
@ -49,7 +81,8 @@ def build_community_context(
|
||||
community_reports[0].attributes is None
|
||||
or community_weight_name not in community_reports[0].attributes
|
||||
)
|
||||
):
|
||||
)
|
||||
if compute_community_weights:
|
||||
log.info("Computing community weights...")
|
||||
community_reports = _compute_community_weights(
|
||||
community_reports=community_reports,
|
||||
@ -58,11 +91,7 @@ def build_community_context(
|
||||
normalize=normalize_community_weight,
|
||||
)
|
||||
|
||||
selected_reports = [
|
||||
report
|
||||
for report in community_reports
|
||||
if report.rank and report.rank >= min_community_rank
|
||||
]
|
||||
selected_reports = [report for report in community_reports if _is_included(report)]
|
||||
if selected_reports is None or len(selected_reports) == 0:
|
||||
return ([], {})
|
||||
|
||||
@ -70,99 +99,67 @@ def build_community_context(
|
||||
random.seed(random_state)
|
||||
random.shuffle(selected_reports)
|
||||
|
||||
# add context header
|
||||
current_context_text = f"-----{context_name}-----" + "\n"
|
||||
|
||||
# add header
|
||||
header = ["id", "title"]
|
||||
attribute_cols = (
|
||||
list(selected_reports[0].attributes.keys())
|
||||
if selected_reports[0].attributes
|
||||
# "global" variables
|
||||
attributes = (
|
||||
list(community_reports[0].attributes.keys())
|
||||
if community_reports[0].attributes
|
||||
else []
|
||||
)
|
||||
attribute_cols = [col for col in attribute_cols if col not in header]
|
||||
if not include_community_weight:
|
||||
attribute_cols = [col for col in attribute_cols if col != community_weight_name]
|
||||
header.extend(attribute_cols)
|
||||
header.append("summary" if use_community_summary else "content")
|
||||
if include_community_rank:
|
||||
header.append(community_rank_name)
|
||||
header = _get_header(attributes)
|
||||
all_context_text: list[str] = []
|
||||
all_context_records: list[pd.DataFrame] = []
|
||||
|
||||
current_context_text += column_delimiter.join(header) + "\n"
|
||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
||||
current_context_records = [header]
|
||||
all_context_text = []
|
||||
all_context_records = []
|
||||
# batch variables
|
||||
batch_text: str = ""
|
||||
batch_tokens: int = 0
|
||||
batch_records: list[list[str]] = []
|
||||
|
||||
def _init_batch() -> None:
|
||||
nonlocal batch_text, batch_tokens, batch_records
|
||||
batch_text = (
|
||||
f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n"
|
||||
)
|
||||
batch_tokens = num_tokens(batch_text, token_encoder)
|
||||
batch_records = []
|
||||
|
||||
def _cut_batch() -> None:
|
||||
# convert the current context records to pandas dataframe and sort by weight and rank if exist
|
||||
record_df = _convert_report_context_to_df(
|
||||
context_records=batch_records,
|
||||
header=header,
|
||||
weight_column=community_weight_name
|
||||
if entities and include_community_weight
|
||||
else None,
|
||||
rank_column=community_rank_name if include_community_rank else None,
|
||||
)
|
||||
if len(record_df) == 0:
|
||||
return
|
||||
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
|
||||
all_context_text.append(current_context_text)
|
||||
all_context_records.append(record_df)
|
||||
|
||||
# initialize the first batch
|
||||
_init_batch()
|
||||
|
||||
for report in selected_reports:
|
||||
new_context = [
|
||||
report.short_id,
|
||||
report.title,
|
||||
*[
|
||||
str(report.attributes.get(field, "")) if report.attributes else ""
|
||||
for field in attribute_cols
|
||||
],
|
||||
]
|
||||
new_context.append(
|
||||
report.summary if use_community_summary else report.full_content
|
||||
)
|
||||
if include_community_rank:
|
||||
new_context.append(str(report.rank))
|
||||
new_context_text = column_delimiter.join(new_context) + "\n"
|
||||
|
||||
new_context_text, new_context = _report_context_text(report, attributes)
|
||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||
if current_tokens + new_tokens > max_tokens:
|
||||
# convert the current context records to pandas dataframe and sort by weight and rank if exist
|
||||
if len(current_context_records) > 1:
|
||||
record_df = _convert_report_context_to_df(
|
||||
context_records=current_context_records[1:],
|
||||
header=current_context_records[0],
|
||||
weight_column=community_weight_name
|
||||
if entities and include_community_weight
|
||||
else None,
|
||||
rank_column=community_rank_name if include_community_rank else None,
|
||||
)
|
||||
|
||||
else:
|
||||
record_df = pd.DataFrame()
|
||||
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
|
||||
|
||||
if batch_tokens + new_tokens > max_tokens:
|
||||
# add the current batch to the context data and start a new batch if we are in multi-batch mode
|
||||
_cut_batch()
|
||||
if single_batch:
|
||||
return current_context_text, {context_name.lower(): record_df}
|
||||
break
|
||||
_init_batch()
|
||||
|
||||
all_context_text.append(current_context_text)
|
||||
all_context_records.append(record_df)
|
||||
|
||||
# start a new batch
|
||||
current_context_text = (
|
||||
f"-----{context_name}-----"
|
||||
+ "\n"
|
||||
+ column_delimiter.join(header)
|
||||
+ "\n"
|
||||
)
|
||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
||||
current_context_records = [header]
|
||||
else:
|
||||
current_context_text += new_context_text
|
||||
current_tokens += new_tokens
|
||||
current_context_records.append(new_context)
|
||||
# add current report to the current batch
|
||||
batch_text += new_context_text
|
||||
batch_tokens += new_tokens
|
||||
batch_records.append(new_context)
|
||||
|
||||
# add the last batch if it has not been added
|
||||
if current_context_text not in all_context_text:
|
||||
if len(current_context_records) > 1:
|
||||
record_df = _convert_report_context_to_df(
|
||||
context_records=current_context_records[1:],
|
||||
header=current_context_records[0],
|
||||
weight_column=community_weight_name
|
||||
if entities and include_community_weight
|
||||
else None,
|
||||
rank_column=community_rank_name if include_community_rank else None,
|
||||
)
|
||||
else:
|
||||
record_df = pd.DataFrame()
|
||||
all_context_records.append(record_df)
|
||||
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
|
||||
all_context_text.append(current_context_text)
|
||||
if batch_text not in all_context_text:
|
||||
_cut_batch()
|
||||
|
||||
return all_context_text, {
|
||||
context_name.lower(): pd.concat(all_context_records, ignore_index=True)
|
||||
@ -171,11 +168,14 @@ def build_community_context(
|
||||
|
||||
def _compute_community_weights(
|
||||
community_reports: list[CommunityReport],
|
||||
entities: list[Entity],
|
||||
entities: list[Entity] | None,
|
||||
weight_attribute: str = "occurrence",
|
||||
normalize: bool = True,
|
||||
) -> list[CommunityReport]:
|
||||
"""Calculate a community's weight as count of text units associated with entities within the community."""
|
||||
if not entities:
|
||||
return community_reports
|
||||
|
||||
community_text_units = {}
|
||||
for entity in entities:
|
||||
if entity.community_ids:
|
||||
@ -211,7 +211,7 @@ def _rank_report_context(
|
||||
rank_column: str | None = "rank",
|
||||
) -> pd.DataFrame:
|
||||
"""Sort report context by community weight and rank if exist."""
|
||||
rank_attributes = []
|
||||
rank_attributes: list[str] = []
|
||||
if weight_column:
|
||||
rank_attributes.append(weight_column)
|
||||
report_df[weight_column] = report_df[weight_column].astype(float)
|
||||
@ -230,6 +230,9 @@ def _convert_report_context_to_df(
|
||||
rank_column: str | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Convert report context records to pandas dataframe and sort by weight and rank if exist."""
|
||||
if len(context_records) == 0:
|
||||
return pd.DataFrame()
|
||||
|
||||
record_df = pd.DataFrame(
|
||||
context_records,
|
||||
columns=cast(Any, header),
|
||||
|
Loading…
x
Reference in New Issue
Block a user