Fix/text unit code cleanup (#1040)

* Optimized _build_text_unit_context function for improved time and space complexity

Refactored the _build_text_unit_context function to enhance its performance and efficiency. Key optimizations include:

1. Set for Text Unit IDs: Replaced list-based membership checks with a set (text_unit_ids_set) to achieve constant-time complexity for membership checks, reducing overall time complexity.
2. Direct Attribute Removal: Utilized pop with a default value (None) to directly remove attributes entity_order and num_relationships from text units, minimizing overhead and avoiding potential KeyError.
3. Default Dictionary for Entity Orders: Implemented defaultdict for managing entity orders, simplifying the ranking process and improving readability.

These improvements result in a more efficient function with better performance, especially when handling large datasets or numerous selected entities. The refactoring ensures that the core functionality remains unchanged while enhancing both time and space complexity.

* Format

* Ruff fixes

* semver

---------

Co-authored-by: arjun-234 <arjun.darji@yudiz.com>
Co-authored-by: Arjun D. <103405661+arjun-234@users.noreply.github.com>
This commit is contained in:
Alonso Guevara 2024-08-27 16:15:16 -06:00 committed by GitHub
parent 5d8e60ceb7
commit 22df2f80d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 31 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Refactor text unit build at local search"
}

View File

@ -309,42 +309,36 @@ class LocalSearchMixedContext(LocalContextBuilder):
context_name: str = "Sources",
) -> tuple[str, dict[str, pd.DataFrame]]:
"""Rank matching text units and add them to the context window until it hits the max_tokens limit."""
if len(selected_entities) == 0 or len(self.text_units) == 0:
if not selected_entities or not self.text_units:
return ("", {context_name.lower(): pd.DataFrame()})
selected_text_units = list[TextUnit]()
# for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships
# that the text unit has with the matching entities
for index, entity in enumerate(selected_entities):
if entity.text_unit_ids:
for text_id in entity.text_unit_ids:
if (
text_id not in [unit.id for unit in selected_text_units]
and text_id in self.text_units
):
selected_unit = self.text_units[text_id]
num_relationships = count_relationships(
selected_unit, entity, self.relationships
)
if selected_unit.attributes is None:
selected_unit.attributes = {}
selected_unit.attributes["entity_order"] = index
selected_unit.attributes["num_relationships"] = (
num_relationships
)
selected_text_units.append(selected_unit)
selected_text_units = []
text_unit_ids_set = set()
for index, entity in enumerate(selected_entities):
for text_id in entity.text_unit_ids or []:
if text_id not in text_unit_ids_set and text_id in self.text_units:
text_unit_ids_set.add(text_id)
selected_unit = self.text_units[text_id]
num_relationships = count_relationships(
selected_unit, entity, self.relationships
)
if selected_unit.attributes is None:
selected_unit.attributes = {}
selected_unit.attributes["entity_order"] = index
selected_unit.attributes["num_relationships"] = num_relationships
selected_text_units.append(selected_unit)
# sort selected text units by ascending order of entity order and descending order of number of relationships
selected_text_units.sort(
key=lambda x: (
x.attributes["entity_order"], # type: ignore
-x.attributes["num_relationships"], # type: ignore
x.attributes["entity_order"],
-x.attributes["num_relationships"],
)
)
for unit in selected_text_units:
del unit.attributes["entity_order"] # type: ignore
del unit.attributes["num_relationships"] # type: ignore
unit.attributes.pop("entity_order", None)
unit.attributes.pop("num_relationships", None)
context_text, context_data = build_text_unit_context(
text_units=selected_text_units,
@ -362,8 +356,8 @@ class LocalSearchMixedContext(LocalContextBuilder):
)
context_key = context_name.lower()
if context_key not in context_data:
candidate_context_data["in_context"] = False
context_data[context_key] = candidate_context_data
context_data[context_key]["in_context"] = False
else:
if (
"id" in candidate_context_data.columns
@ -371,12 +365,11 @@ class LocalSearchMixedContext(LocalContextBuilder):
):
candidate_context_data["in_context"] = candidate_context_data[
"id"
].isin( # cspell:disable-line
context_data[context_key]["id"]
)
].isin(context_data[context_key]["id"])
context_data[context_key] = candidate_context_data
else:
context_data[context_key]["in_context"] = True
return (str(context_text), context_data)
def _build_local_context(