diff --git a/.semversioner/next-release/patch-20240730235138482202.json b/.semversioner/next-release/patch-20240730235138482202.json new file mode 100644 index 00000000..d217abfa --- /dev/null +++ b/.semversioner/next-release/patch-20240730235138482202.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Fix missing community reports and refactor community context builder" +} diff --git a/examples_notebooks/global_search.ipynb b/examples_notebooks/global_search.ipynb index fdde2c70..3c0c462c 100644 --- a/examples_notebooks/global_search.ipynb +++ b/examples_notebooks/global_search.ipynb @@ -115,7 +115,10 @@ "\n", "reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n", "entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n", - "print(f\"Report records: {len(report_df)}\")\n", + "print(f\"Total report count: {len(report_df)}\")\n", + "print(\n", + " f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n", + ")\n", "report_df.head()" ] }, @@ -223,17 +226,9 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "LLM calls: 13. LLM tokens: 184660\n" - ] - } - ], + "outputs": [], "source": [ "# inspect number of LLM calls and tokens\n", "print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")" diff --git a/graphrag/query/context_builder/community_context.py b/graphrag/query/context_builder/community_context.py index dba7f729..5039376f 100644 --- a/graphrag/query/context_builder/community_context.py +++ b/graphrag/query/context_builder/community_context.py @@ -41,7 +41,39 @@ def build_community_context( The calculated weight is added as an attribute to the community reports and added to the context data table. """ - if ( + + def _is_included(report: CommunityReport) -> bool: + return report.rank is not None and report.rank >= min_community_rank + + def _get_header(attributes: list[str]) -> list[str]: + header = ["id", "title"] + attributes = [col for col in attributes if col not in header] + if not include_community_weight: + attributes = [col for col in attributes if col != community_weight_name] + header.extend(attributes) + header.append("summary" if use_community_summary else "content") + if include_community_rank: + header.append(community_rank_name) + return header + + def _report_context_text( + report: CommunityReport, attributes: list[str] + ) -> tuple[str, list[str]]: + context: list[str] = [ + report.short_id if report.short_id else "", + report.title, + *[ + str(report.attributes.get(field, "")) if report.attributes else "" + for field in attributes + ], + ] + context.append(report.summary if use_community_summary else report.full_content) + if include_community_rank: + context.append(str(report.rank)) + result = column_delimiter.join(context) + "\n" + return result, context + + compute_community_weights = ( entities and len(community_reports) > 0 and include_community_weight @@ -49,7 +81,8 @@ def build_community_context( community_reports[0].attributes is None or community_weight_name not in community_reports[0].attributes ) - ): + ) + if compute_community_weights: log.info("Computing community weights...") community_reports = _compute_community_weights( community_reports=community_reports, @@ -58,11 +91,7 @@ def build_community_context( normalize=normalize_community_weight, ) - selected_reports = [ - report - for report in community_reports - if report.rank and report.rank >= min_community_rank - ] + selected_reports = [report for report in community_reports if _is_included(report)] if selected_reports is None or len(selected_reports) == 0: return ([], {}) @@ -70,99 +99,67 @@ def build_community_context( random.seed(random_state) random.shuffle(selected_reports) - # add context header - current_context_text = f"-----{context_name}-----" + "\n" - - # add header - header = ["id", "title"] - attribute_cols = ( - list(selected_reports[0].attributes.keys()) - if selected_reports[0].attributes + # "global" variables + attributes = ( + list(community_reports[0].attributes.keys()) + if community_reports[0].attributes else [] ) - attribute_cols = [col for col in attribute_cols if col not in header] - if not include_community_weight: - attribute_cols = [col for col in attribute_cols if col != community_weight_name] - header.extend(attribute_cols) - header.append("summary" if use_community_summary else "content") - if include_community_rank: - header.append(community_rank_name) + header = _get_header(attributes) + all_context_text: list[str] = [] + all_context_records: list[pd.DataFrame] = [] - current_context_text += column_delimiter.join(header) + "\n" - current_tokens = num_tokens(current_context_text, token_encoder) - current_context_records = [header] - all_context_text = [] - all_context_records = [] + # batch variables + batch_text: str = "" + batch_tokens: int = 0 + batch_records: list[list[str]] = [] - for report in selected_reports: - new_context = [ - report.short_id, - report.title, - *[ - str(report.attributes.get(field, "")) if report.attributes else "" - for field in attribute_cols - ], - ] - new_context.append( - report.summary if use_community_summary else report.full_content + def _init_batch() -> None: + nonlocal batch_text, batch_tokens, batch_records + batch_text = ( + f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n" ) - if include_community_rank: - new_context.append(str(report.rank)) - new_context_text = column_delimiter.join(new_context) + "\n" + batch_tokens = num_tokens(batch_text, token_encoder) + batch_records = [] - new_tokens = num_tokens(new_context_text, token_encoder) - if current_tokens + new_tokens > max_tokens: - # convert the current context records to pandas dataframe and sort by weight and rank if exist - if len(current_context_records) > 1: - record_df = _convert_report_context_to_df( - context_records=current_context_records[1:], - header=current_context_records[0], - weight_column=community_weight_name - if entities and include_community_weight - else None, - rank_column=community_rank_name if include_community_rank else None, - ) - - else: - record_df = pd.DataFrame() - current_context_text = record_df.to_csv(index=False, sep=column_delimiter) - - if single_batch: - return current_context_text, {context_name.lower(): record_df} - - all_context_text.append(current_context_text) - all_context_records.append(record_df) - - # start a new batch - current_context_text = ( - f"-----{context_name}-----" - + "\n" - + column_delimiter.join(header) - + "\n" - ) - current_tokens = num_tokens(current_context_text, token_encoder) - current_context_records = [header] - else: - current_context_text += new_context_text - current_tokens += new_tokens - current_context_records.append(new_context) - - # add the last batch if it has not been added - if current_context_text not in all_context_text: - if len(current_context_records) > 1: - record_df = _convert_report_context_to_df( - context_records=current_context_records[1:], - header=current_context_records[0], - weight_column=community_weight_name - if entities and include_community_weight - else None, - rank_column=community_rank_name if include_community_rank else None, - ) - else: - record_df = pd.DataFrame() - all_context_records.append(record_df) + def _cut_batch() -> None: + # convert the current context records to pandas dataframe and sort by weight and rank if exist + record_df = _convert_report_context_to_df( + context_records=batch_records, + header=header, + weight_column=community_weight_name + if entities and include_community_weight + else None, + rank_column=community_rank_name if include_community_rank else None, + ) + if len(record_df) == 0: + return current_context_text = record_df.to_csv(index=False, sep=column_delimiter) all_context_text.append(current_context_text) + all_context_records.append(record_df) + + # initialize the first batch + _init_batch() + + for report in selected_reports: + new_context_text, new_context = _report_context_text(report, attributes) + new_tokens = num_tokens(new_context_text, token_encoder) + + if batch_tokens + new_tokens > max_tokens: + # add the current batch to the context data and start a new batch if we are in multi-batch mode + _cut_batch() + if single_batch: + break + _init_batch() + + # add current report to the current batch + batch_text += new_context_text + batch_tokens += new_tokens + batch_records.append(new_context) + + # add the last batch if it has not been added + if batch_text not in all_context_text: + _cut_batch() return all_context_text, { context_name.lower(): pd.concat(all_context_records, ignore_index=True) @@ -171,11 +168,14 @@ def build_community_context( def _compute_community_weights( community_reports: list[CommunityReport], - entities: list[Entity], + entities: list[Entity] | None, weight_attribute: str = "occurrence", normalize: bool = True, ) -> list[CommunityReport]: """Calculate a community's weight as count of text units associated with entities within the community.""" + if not entities: + return community_reports + community_text_units = {} for entity in entities: if entity.community_ids: @@ -211,7 +211,7 @@ def _rank_report_context( rank_column: str | None = "rank", ) -> pd.DataFrame: """Sort report context by community weight and rank if exist.""" - rank_attributes = [] + rank_attributes: list[str] = [] if weight_column: rank_attributes.append(weight_column) report_df[weight_column] = report_df[weight_column].astype(float) @@ -230,6 +230,9 @@ def _convert_report_context_to_df( rank_column: str | None = None, ) -> pd.DataFrame: """Convert report context records to pandas dataframe and sort by weight and rank if exist.""" + if len(context_records) == 0: + return pd.DataFrame() + record_df = pd.DataFrame( context_records, columns=cast(Any, header),