fix community context builder (#783)

fix and refactor community context builder

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
Ha Trinh 2024-07-30 19:14:40 -07:00 committed by GitHub
parent d26491a622
commit 7e1529ac19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 105 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix missing community reports and refactor community context builder"
}

View File

@@ -115,7 +115,10 @@
"\n", "\n",
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n", "reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n", "entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"print(f\"Report records: {len(report_df)}\")\n", "print(f\"Total report count: {len(report_df)}\")\n",
"print(\n",
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
")\n",
"report_df.head()" "report_df.head()"
] ]
}, },
@@ -223,17 +226,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM calls: 13. LLM tokens: 184660\n"
]
}
],
"source": [ "source": [
"# inspect number of LLM calls and tokens\n", "# inspect number of LLM calls and tokens\n",
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")" "print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"

View File

@@ -41,7 +41,39 @@ def build_community_context(
The calculated weight is added as an attribute to the community reports and added to the context data table. The calculated weight is added as an attribute to the community reports and added to the context data table.
""" """
if (
def _is_included(report: CommunityReport) -> bool:
return report.rank is not None and report.rank >= min_community_rank
def _get_header(attributes: list[str]) -> list[str]:
header = ["id", "title"]
attributes = [col for col in attributes if col not in header]
if not include_community_weight:
attributes = [col for col in attributes if col != community_weight_name]
header.extend(attributes)
header.append("summary" if use_community_summary else "content")
if include_community_rank:
header.append(community_rank_name)
return header
def _report_context_text(
report: CommunityReport, attributes: list[str]
) -> tuple[str, list[str]]:
context: list[str] = [
report.short_id if report.short_id else "",
report.title,
*[
str(report.attributes.get(field, "")) if report.attributes else ""
for field in attributes
],
]
context.append(report.summary if use_community_summary else report.full_content)
if include_community_rank:
context.append(str(report.rank))
result = column_delimiter.join(context) + "\n"
return result, context
compute_community_weights = (
entities entities
and len(community_reports) > 0 and len(community_reports) > 0
and include_community_weight and include_community_weight
@@ -49,7 +81,8 @@ def build_community_context(
community_reports[0].attributes is None community_reports[0].attributes is None
or community_weight_name not in community_reports[0].attributes or community_weight_name not in community_reports[0].attributes
) )
): )
if compute_community_weights:
log.info("Computing community weights...") log.info("Computing community weights...")
community_reports = _compute_community_weights( community_reports = _compute_community_weights(
community_reports=community_reports, community_reports=community_reports,
@@ -58,11 +91,7 @@ def build_community_context(
normalize=normalize_community_weight, normalize=normalize_community_weight,
) )
selected_reports = [ selected_reports = [report for report in community_reports if _is_included(report)]
report
for report in community_reports
if report.rank and report.rank >= min_community_rank
]
if selected_reports is None or len(selected_reports) == 0: if selected_reports is None or len(selected_reports) == 0:
return ([], {}) return ([], {})
@@ -70,99 +99,67 @@ def build_community_context(
random.seed(random_state) random.seed(random_state)
random.shuffle(selected_reports) random.shuffle(selected_reports)
# add context header # "global" variables
current_context_text = f"-----{context_name}-----" + "\n" attributes = (
list(community_reports[0].attributes.keys())
# add header if community_reports[0].attributes
header = ["id", "title"]
attribute_cols = (
list(selected_reports[0].attributes.keys())
if selected_reports[0].attributes
else [] else []
) )
attribute_cols = [col for col in attribute_cols if col not in header] header = _get_header(attributes)
if not include_community_weight: all_context_text: list[str] = []
attribute_cols = [col for col in attribute_cols if col != community_weight_name] all_context_records: list[pd.DataFrame] = []
header.extend(attribute_cols)
header.append("summary" if use_community_summary else "content")
if include_community_rank:
header.append(community_rank_name)
current_context_text += column_delimiter.join(header) + "\n" # batch variables
current_tokens = num_tokens(current_context_text, token_encoder) batch_text: str = ""
current_context_records = [header] batch_tokens: int = 0
all_context_text = [] batch_records: list[list[str]] = []
all_context_records = []
def _init_batch() -> None:
nonlocal batch_text, batch_tokens, batch_records
batch_text = (
f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n"
)
batch_tokens = num_tokens(batch_text, token_encoder)
batch_records = []
def _cut_batch() -> None:
# convert the current context records to pandas dataframe and sort by weight and rank if exist
record_df = _convert_report_context_to_df(
context_records=batch_records,
header=header,
weight_column=community_weight_name
if entities and include_community_weight
else None,
rank_column=community_rank_name if include_community_rank else None,
)
if len(record_df) == 0:
return
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
all_context_text.append(current_context_text)
all_context_records.append(record_df)
# initialize the first batch
_init_batch()
for report in selected_reports: for report in selected_reports:
new_context = [ new_context_text, new_context = _report_context_text(report, attributes)
report.short_id,
report.title,
*[
str(report.attributes.get(field, "")) if report.attributes else ""
for field in attribute_cols
],
]
new_context.append(
report.summary if use_community_summary else report.full_content
)
if include_community_rank:
new_context.append(str(report.rank))
new_context_text = column_delimiter.join(new_context) + "\n"
new_tokens = num_tokens(new_context_text, token_encoder) new_tokens = num_tokens(new_context_text, token_encoder)
if current_tokens + new_tokens > max_tokens:
# convert the current context records to pandas dataframe and sort by weight and rank if exist
if len(current_context_records) > 1:
record_df = _convert_report_context_to_df(
context_records=current_context_records[1:],
header=current_context_records[0],
weight_column=community_weight_name
if entities and include_community_weight
else None,
rank_column=community_rank_name if include_community_rank else None,
)
else:
record_df = pd.DataFrame()
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
if batch_tokens + new_tokens > max_tokens:
# add the current batch to the context data and start a new batch if we are in multi-batch mode
_cut_batch()
if single_batch: if single_batch:
return current_context_text, {context_name.lower(): record_df} break
_init_batch()
all_context_text.append(current_context_text) # add current report to the current batch
all_context_records.append(record_df) batch_text += new_context_text
batch_tokens += new_tokens
# start a new batch batch_records.append(new_context)
current_context_text = (
f"-----{context_name}-----"
+ "\n"
+ column_delimiter.join(header)
+ "\n"
)
current_tokens = num_tokens(current_context_text, token_encoder)
current_context_records = [header]
else:
current_context_text += new_context_text
current_tokens += new_tokens
current_context_records.append(new_context)
# add the last batch if it has not been added # add the last batch if it has not been added
if current_context_text not in all_context_text: if batch_text not in all_context_text:
if len(current_context_records) > 1: _cut_batch()
record_df = _convert_report_context_to_df(
context_records=current_context_records[1:],
header=current_context_records[0],
weight_column=community_weight_name
if entities and include_community_weight
else None,
rank_column=community_rank_name if include_community_rank else None,
)
else:
record_df = pd.DataFrame()
all_context_records.append(record_df)
current_context_text = record_df.to_csv(index=False, sep=column_delimiter)
all_context_text.append(current_context_text)
return all_context_text, { return all_context_text, {
context_name.lower(): pd.concat(all_context_records, ignore_index=True) context_name.lower(): pd.concat(all_context_records, ignore_index=True)
@@ -171,11 +168,14 @@ def build_community_context(
def _compute_community_weights( def _compute_community_weights(
community_reports: list[CommunityReport], community_reports: list[CommunityReport],
entities: list[Entity], entities: list[Entity] | None,
weight_attribute: str = "occurrence", weight_attribute: str = "occurrence",
normalize: bool = True, normalize: bool = True,
) -> list[CommunityReport]: ) -> list[CommunityReport]:
"""Calculate a community's weight as count of text units associated with entities within the community.""" """Calculate a community's weight as count of text units associated with entities within the community."""
if not entities:
return community_reports
community_text_units = {} community_text_units = {}
for entity in entities: for entity in entities:
if entity.community_ids: if entity.community_ids:
@@ -211,7 +211,7 @@ def _rank_report_context(
rank_column: str | None = "rank", rank_column: str | None = "rank",
) -> pd.DataFrame: ) -> pd.DataFrame:
"""Sort report context by community weight and rank if exist.""" """Sort report context by community weight and rank if exist."""
rank_attributes = [] rank_attributes: list[str] = []
if weight_column: if weight_column:
rank_attributes.append(weight_column) rank_attributes.append(weight_column)
report_df[weight_column] = report_df[weight_column].astype(float) report_df[weight_column] = report_df[weight_column].astype(float)
@@ -230,6 +230,9 @@ def _convert_report_context_to_df(
rank_column: str | None = None, rank_column: str | None = None,
) -> pd.DataFrame: ) -> pd.DataFrame:
"""Convert report context records to pandas dataframe and sort by weight and rank if exist.""" """Convert report context records to pandas dataframe and sort by weight and rank if exist."""
if len(context_records) == 0:
return pd.DataFrame()
record_df = pd.DataFrame( record_df = pd.DataFrame(
context_records, context_records,
columns=cast(Any, header), columns=cast(Any, header),