fix community context builder (#783)

fix and refactor community context builder

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Ha Trinh 2024-07-30 19:14:40 -07:00 committed by GitHub
parent d26491a622
commit 7e1529ac19
3 changed files with 107 additions and 105 deletions


@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix missing community reports and refactor community context builder"
}
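The "missing community reports" part of this change appears to come down to the rank filter shown in the Python diff below: the old list comprehension used a truthiness check (report.rank and report.rank >= min_community_rank), which silently drops any report whose rank is 0, while the new _is_included helper only excludes reports whose rank is None. A minimal sketch of the difference, using a simplified stand-in for CommunityReport rather than the real model:

from dataclasses import dataclass

@dataclass
class Report:  # illustrative stand-in for CommunityReport
    title: str
    rank: float | None

reports = [Report("A", 0.0), Report("B", 7.5), Report("C", None)]
min_community_rank = 0

# old filter: `0.0 and 0.0 >= 0` evaluates to 0.0 (falsy), so report "A" is dropped
old = [r for r in reports if r.rank and r.rank >= min_community_rank]
# new filter: only reports with no rank at all are excluded
new = [r for r in reports if r.rank is not None and r.rank >= min_community_rank]

print([r.title for r in old])  # ['B']
print([r.title for r in new])  # ['A', 'B']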


@@ -115,7 +115,10 @@
"\n",
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"print(f\"Report records: {len(report_df)}\")\n",
"print(f\"Total report count: {len(report_df)}\")\n",
"print(\n",
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
")\n",
"report_df.head()"
]
},
@@ -223,17 +226,9 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM calls: 13. LLM tokens: 184660\n"
]
}
],
"outputs": [],
"source": [
"# inspect number of LLM calls and tokens\n",
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"


@@ -41,7 +41,39 @@ def build_community_context(
The calculated weight is added as an attribute to the community reports and added to the context data table.
"""
if (
def _is_included(report: CommunityReport) -> bool:
return report.rank is not None and report.rank >= min_community_rank
def _get_header(attributes: list[str]) -> list[str]:
header = ["id", "title"]
attributes = [col for col in attributes if col not in header]
if not include_community_weight:
attributes = [col for col in attributes if col != community_weight_name]
header.extend(attributes)
header.append("summary" if use_community_summary else "content")
if include_community_rank:
header.append(community_rank_name)
return header
def _report_context_text(
report: CommunityReport, attributes: list[str]
) -> tuple[str, list[str]]:
context: list[str] = [
report.short_id if report.short_id else "",
report.title,
*[
str(report.attributes.get(field, "")) if report.attributes else ""
for field in attributes
],
]
context.append(report.summary if use_community_summary else report.full_content)
if include_community_rank:
context.append(str(report.rank))
result = column_delimiter.join(context) + "\n"
return result, context
compute_community_weights = (
entities
and len(community_reports) > 0
and include_community_weight
@@ -49,7 +81,8 @@ def build_community_context(
community_reports[0].attributes is None
or community_weight_name not in community_reports[0].attributes
)
):
)
if compute_community_weights:
log.info("Computing community weights...")
community_reports = _compute_community_weights(
community_reports=community_reports,
@@ -58,11 +91,7 @@ def build_community_context(
normalize=normalize_community_weight,
)
selected_reports = [
report
for report in community_reports
if report.rank and report.rank >= min_community_rank
]
selected_reports = [report for report in community_reports if _is_included(report)]
if selected_reports is None or len(selected_reports) == 0:
return ([], {})
@@ -70,99 +99,67 @@ def build_community_context(
random.seed(random_state)
random.shuffle(selected_reports)
# add context header
current_context_text = f"-----{context_name}-----" + "\n"
# add header
header = ["id", "title"]
attribute_cols = (
list(selected_reports[0].attributes.keys())
if selected_reports[0].attributes
# "global" variables
attributes = (
list(community_reports[0].attributes.keys())
if community_reports[0].attributes
else []
)
attribute_cols = [col for col in attribute_cols if col not in header]
if not include_community_weight:
attribute_cols = [col for col in attribute_cols if col != community_weight_name]
header.extend(attribute_cols)
header.append("summary" if use_community_summary else "content")
if include_community_rank:
header.append(community_rank_name)
header = _get_header(attributes)
all_context_text: list[str] = []
all_context_records: list[pd.DataFrame] = []
current_context_text += column_delimiter.join(header) + "\n"
current_tokens = num_tokens(current_context_text, token_encoder)
current_context_records = [header]
all_context_text = []
all_context_records = []
# batch variables
batch_text: str = ""
batch_tokens: int = 0
batch_records: list[list[str]] = []
for report in selected_reports:
new_context = [
report.short_id,
report.title,
*[
str(report.attributes.get(field, "")) if report.attributes else ""
for field in attribute_cols
],
]
new_context.append(
report.summary if use_community_summary else report.full_content
def _init_batch() -> None:
nonlocal batch_text, batch_tokens, batch_records
batch_text = (
f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n"
)
if include_community_rank:
new_context.append(str(report.rank))
new_context_text = column_delimiter.join(new_context) + "\n"
batch_tokens = num_tokens(batch_text, token_encoder)
batch_records = []
new_tokens = num_tokens(new_context_text, token_encoder)
if current_tokens + new_tokens > max_tokens:
# convert the current context records to pandas dataframe and sort by weight and rank if exist
if len(current_context_records) > 1:
record_df = _convert_report_context_to_df(
context_records=current_context_records[1:],
header=current_context_records[0],
weight_column=community_weight_name
if entities and include_community_weight
else None,
rank_column=community_rank_name if include_community_rank else None,
)
else:
record_df = pd.DataFrame()
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
if single_batch:
return current_context_text, {context_name.lower(): record_df}
all_context_text.append(current_context_text)
all_context_records.append(record_df)
# start a new batch
current_context_text = (
f"-----{context_name}-----"
+ "\n"
+ column_delimiter.join(header)
+ "\n"
)
current_tokens = num_tokens(current_context_text, token_encoder)
current_context_records = [header]
else:
current_context_text += new_context_text
current_tokens += new_tokens
current_context_records.append(new_context)
# add the last batch if it has not been added
if current_context_text not in all_context_text:
if len(current_context_records) > 1:
record_df = _convert_report_context_to_df(
context_records=current_context_records[1:],
header=current_context_records[0],
weight_column=community_weight_name
if entities and include_community_weight
else None,
rank_column=community_rank_name if include_community_rank else None,
)
else:
record_df = pd.DataFrame()
all_context_records.append(record_df)
def _cut_batch() -> None:
# convert the current context records to pandas dataframe and sort by weight and rank if exist
record_df = _convert_report_context_to_df(
context_records=batch_records,
header=header,
weight_column=community_weight_name
if entities and include_community_weight
else None,
rank_column=community_rank_name if include_community_rank else None,
)
if len(record_df) == 0:
return
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
all_context_text.append(current_context_text)
all_context_records.append(record_df)
# initialize the first batch
_init_batch()
for report in selected_reports:
new_context_text, new_context = _report_context_text(report, attributes)
new_tokens = num_tokens(new_context_text, token_encoder)
if batch_tokens + new_tokens > max_tokens:
# add the current batch to the context data and start a new batch if we are in multi-batch mode
_cut_batch()
if single_batch:
break
_init_batch()
# add current report to the current batch
batch_text += new_context_text
batch_tokens += new_tokens
batch_records.append(new_context)
# add the last batch if it has not been added
if batch_text not in all_context_text:
_cut_batch()
return all_context_text, {
context_name.lower(): pd.concat(all_context_records, ignore_index=True)
@@ -171,11 +168,14 @@ def build_community_context(
def _compute_community_weights(
community_reports: list[CommunityReport],
entities: list[Entity],
entities: list[Entity] | None,
weight_attribute: str = "occurrence",
normalize: bool = True,
) -> list[CommunityReport]:
"""Calculate a community's weight as count of text units associated with entities within the community."""
if not entities:
return community_reports
community_text_units = {}
for entity in entities:
if entity.community_ids:
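
Per the docstring in this hunk, a community's weight is the count of text units attached to the entities in that community, optionally normalized. A rough sketch of that calculation under the stated definition; the entity shape is a simplified stand-in rather than the real Entity model, and the normalization shown here (divide by the largest weight) is an assumption, not taken from the diff:

from collections import defaultdict

# simplified stand-ins: each entity lists its community ids and its text unit ids
entities = [
    {"community_ids": ["c1"], "text_unit_ids": ["t1", "t2"]},
    {"community_ids": ["c1", "c2"], "text_unit_ids": ["t2", "t3"]},
]

community_text_units: dict[str, set[str]] = defaultdict(set)
for entity in entities:
    for community_id in entity["community_ids"]:
        community_text_units[community_id].update(entity["text_unit_ids"])

weights = {cid: len(units) for cid, units in community_text_units.items()}
if weights:  # assumed normalization to [0, 1] by the largest weight
    max_weight = max(weights.values())
    weights = {cid: count / max_weight for cid, count in weights.items()}
print(weights)  # {'c1': 1.0, 'c2': 0.666...}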
@@ -211,7 +211,7 @@ def _rank_report_context(
rank_column: str | None = "rank",
) -> pd.DataFrame:
"""Sort report context by community weight and rank if exist."""
rank_attributes = []
rank_attributes: list[str] = []
if weight_column:
rank_attributes.append(weight_column)
report_df[weight_column] = report_df[weight_column].astype(float)
@@ -230,6 +230,9 @@ def _convert_report_context_to_df(
rank_column: str | None = None,
) -> pd.DataFrame:
"""Convert report context records to pandas dataframe and sort by weight and rank if exist."""
if len(context_records) == 0:
return pd.DataFrame()
record_df = pd.DataFrame(
context_records,
columns=cast(Any, header),
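
Taken together, the refactor replaces the long inline batching loop with small closures (_init_batch, _cut_batch, _report_context_text) that share batch state via nonlocal and cut a new batch whenever the token budget is exceeded. A condensed, self-contained sketch of that batching pattern follows; the whitespace token counter and plain string rows are stand-ins for num_tokens/tiktoken and the report context records, not the actual GraphRAG implementation:

def build_batches(
    rows: list[str],
    header: str = "id|title|content",
    max_tokens: int = 15,
    single_batch: bool = False,
) -> list[str]:
    """Group delimited rows into batches that stay under a token budget."""

    def _count_tokens(text: str) -> int:
        # stand-in for num_tokens(text, token_encoder); the real code uses tiktoken
        return len(text.split())

    all_batches: list[str] = []
    batch_text = ""
    batch_tokens = 0
    batch_rows: list[str] = []

    def _init_batch() -> None:
        nonlocal batch_text, batch_tokens, batch_rows
        batch_text = header + "\n"
        batch_tokens = _count_tokens(batch_text)
        batch_rows = []

    def _cut_batch() -> None:
        # emit the current batch only if it actually holds rows
        if batch_rows:
            all_batches.append(batch_text)

    _init_batch()
    for row in rows:
        row_text = row + "\n"
        row_tokens = _count_tokens(row_text)
        if batch_tokens + row_tokens > max_tokens:
            # finish the current batch; in single-batch mode, stop after the first one
            _cut_batch()
            if single_batch:
                break
            _init_batch()
        batch_text += row_text
        batch_tokens += row_tokens
        batch_rows.append(row)

    # flush the final, partially filled batch
    if batch_text not in all_batches:
        _cut_batch()
    return all_batches

rows = [f"{i}|report {i}|summary text for community {i}" for i in range(6)]
for batch in build_batches(rows, max_tokens=15):
    print(batch, end="---\n")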