mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-03 18:40:08 +00:00
Fix/text unit code cleanup (#1040)
* Optimized _build_text_unit_context function for improved time and space complexity Refactored the _build_text_unit_context function to enhance its performance and efficiency. Key optimizations include: 1. Set for Text Unit IDs: Replaced list-based membership checks with a set (text_unit_ids_set) to achieve constant-time complexity for membership checks, reducing overall time complexity. 2. Direct Attribute Removal: Utilized pop with a default value (None) to directly remove attributes entity_order and num_relationships from text units, minimizing overhead and avoiding potential KeyError. 3. Default Dictionary for Entity Orders: Implemented defaultdict for managing entity orders, simplifying the ranking process and improving readability. These improvements result in a more efficient function with better performance, especially when handling large datasets or numerous selected entities. The refactoring ensures that the core functionality remains unchanged while enhancing both time and space complexity. * Format * Ruff fixes * semver --------- Co-authored-by: arjun-234 <arjun.darji@yudiz.com> Co-authored-by: Arjun D. <103405661+arjun-234@users.noreply.github.com>
This commit is contained in:
parent
5d8e60ceb7
commit
22df2f80d0
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Refactor text unit build at local search"
|
||||
}
|
||||
@ -309,42 +309,36 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
context_name: str = "Sources",
|
||||
) -> tuple[str, dict[str, pd.DataFrame]]:
|
||||
"""Rank matching text units and add them to the context window until it hits the max_tokens limit."""
|
||||
if len(selected_entities) == 0 or len(self.text_units) == 0:
|
||||
if not selected_entities or not self.text_units:
|
||||
return ("", {context_name.lower(): pd.DataFrame()})
|
||||
|
||||
selected_text_units = list[TextUnit]()
|
||||
# for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships
|
||||
# that the text unit has with the matching entities
|
||||
for index, entity in enumerate(selected_entities):
|
||||
if entity.text_unit_ids:
|
||||
for text_id in entity.text_unit_ids:
|
||||
if (
|
||||
text_id not in [unit.id for unit in selected_text_units]
|
||||
and text_id in self.text_units
|
||||
):
|
||||
selected_unit = self.text_units[text_id]
|
||||
num_relationships = count_relationships(
|
||||
selected_unit, entity, self.relationships
|
||||
)
|
||||
if selected_unit.attributes is None:
|
||||
selected_unit.attributes = {}
|
||||
selected_unit.attributes["entity_order"] = index
|
||||
selected_unit.attributes["num_relationships"] = (
|
||||
num_relationships
|
||||
)
|
||||
selected_text_units.append(selected_unit)
|
||||
selected_text_units = []
|
||||
text_unit_ids_set = set()
|
||||
|
||||
for index, entity in enumerate(selected_entities):
|
||||
for text_id in entity.text_unit_ids or []:
|
||||
if text_id not in text_unit_ids_set and text_id in self.text_units:
|
||||
text_unit_ids_set.add(text_id)
|
||||
selected_unit = self.text_units[text_id]
|
||||
num_relationships = count_relationships(
|
||||
selected_unit, entity, self.relationships
|
||||
)
|
||||
if selected_unit.attributes is None:
|
||||
selected_unit.attributes = {}
|
||||
selected_unit.attributes["entity_order"] = index
|
||||
selected_unit.attributes["num_relationships"] = num_relationships
|
||||
selected_text_units.append(selected_unit)
|
||||
|
||||
# sort selected text units by ascending order of entity order and descending order of number of relationships
|
||||
selected_text_units.sort(
|
||||
key=lambda x: (
|
||||
x.attributes["entity_order"], # type: ignore
|
||||
-x.attributes["num_relationships"], # type: ignore
|
||||
x.attributes["entity_order"],
|
||||
-x.attributes["num_relationships"],
|
||||
)
|
||||
)
|
||||
|
||||
for unit in selected_text_units:
|
||||
del unit.attributes["entity_order"] # type: ignore
|
||||
del unit.attributes["num_relationships"] # type: ignore
|
||||
unit.attributes.pop("entity_order", None)
|
||||
unit.attributes.pop("num_relationships", None)
|
||||
|
||||
context_text, context_data = build_text_unit_context(
|
||||
text_units=selected_text_units,
|
||||
@ -362,8 +356,8 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
)
|
||||
context_key = context_name.lower()
|
||||
if context_key not in context_data:
|
||||
candidate_context_data["in_context"] = False
|
||||
context_data[context_key] = candidate_context_data
|
||||
context_data[context_key]["in_context"] = False
|
||||
else:
|
||||
if (
|
||||
"id" in candidate_context_data.columns
|
||||
@ -371,12 +365,11 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
):
|
||||
candidate_context_data["in_context"] = candidate_context_data[
|
||||
"id"
|
||||
].isin( # cspell:disable-line
|
||||
context_data[context_key]["id"]
|
||||
)
|
||||
].isin(context_data[context_key]["id"])
|
||||
context_data[context_key] = candidate_context_data
|
||||
else:
|
||||
context_data[context_key]["in_context"] = True
|
||||
|
||||
return (str(context_text), context_data)
|
||||
|
||||
def _build_local_context(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user