Add max_coroutines parameter to Graphiti and update semaphore_gather function (#619)

- Introduced max_coroutines parameter in the Graphiti class to control the maximum number of concurrent operations.
- Updated the semaphore_gather function to accept max_coroutines as an optional argument, defaulting to SEMAPHORE_LIMIT if not provided.
- Adjusted multiple calls to semaphore_gather throughout the Graphiti class to utilize the new max_coroutines parameter for better concurrency management.
This commit is contained in:
Daniel Chalef 2025-06-24 09:32:16 -07:00 committed by GitHub
parent ae7f2234a8
commit fe870b953f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 13 deletions

View File

@ -103,6 +103,7 @@ class Graphiti:
cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True,
graph_driver: GraphDriver | None = None,
max_coroutines: int | None = None,
):
"""
Initialize a Graphiti instance.
@ -121,6 +122,20 @@ class Graphiti:
llm_client : LLMClient | None, optional
An instance of LLMClient for natural language processing tasks.
If not provided, a default OpenAIClient will be initialized.
embedder : EmbedderClient | None, optional
An instance of EmbedderClient for embedding tasks.
If not provided, a default OpenAIEmbedder will be initialized.
cross_encoder : CrossEncoderClient | None, optional
An instance of CrossEncoderClient for reranking tasks.
If not provided, a default OpenAIRerankerClient will be initialized.
store_raw_episode_content : bool, optional
Whether to store the raw content of episodes. Defaults to True.
graph_driver : GraphDriver | None, optional
An instance of GraphDriver for database operations.
If not provided, a default Neo4jDriver will be initialized.
max_coroutines : int | None, optional
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
If not set, the Graphiti default is used.
Returns
-------
@ -145,6 +160,7 @@ class Graphiti:
self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content
self.max_coroutines = max_coroutines
if llm_client:
self.llm_client = llm_client
else:
@ -393,6 +409,7 @@ class Graphiti:
group_id,
edge_types,
),
max_coroutines=self.max_coroutines,
)
edges = resolve_edge_pointers(extracted_edges, uuid_map)
@ -409,6 +426,7 @@ class Graphiti:
extract_attributes_from_nodes(
self.clients, nodes, episode, previous_episodes, entity_types
),
max_coroutines=self.max_coroutines,
)
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
@ -432,7 +450,8 @@ class Graphiti:
*[
update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes
]
],
max_coroutines=self.max_coroutines,
)
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
@ -499,7 +518,10 @@ class Graphiti:
]
# Save all the episodes
await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
await semaphore_gather(
*[episode.save(self.driver) for episode in episodes],
max_coroutines=self.max_coroutines,
)
# Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@ -515,16 +537,21 @@ class Graphiti:
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
max_coroutines=self.max_coroutines,
)
# Dedupe extracted nodes, compress extracted edges
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
max_coroutines=self.max_coroutines,
)
# save nodes to KG
await semaphore_gather(*[node.save(self.driver) for node in nodes])
await semaphore_gather(
*[node.save(self.driver) for node in nodes],
max_coroutines=self.max_coroutines,
)
# re-map edge pointers so that they don't point to discard dupe nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@ -536,7 +563,8 @@ class Graphiti:
# save episodic edges to KG
await semaphore_gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
max_coroutines=self.max_coroutines,
)
# Dedupe extracted edges
@ -548,7 +576,10 @@ class Graphiti:
# invalidate edges
# save edges to KG
await semaphore_gather(*[edge.save(self.driver) for edge in edges])
await semaphore_gather(
*[edge.save(self.driver) for edge in edges],
max_coroutines=self.max_coroutines,
)
end = time()
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@ -572,11 +603,18 @@ class Graphiti:
)
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
*[node.generate_name_embedding(self.embedder) for node in community_nodes],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
await semaphore_gather(
*[node.save(self.driver) for node in community_nodes],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(
*[edge.save(self.driver) for edge in community_edges],
max_coroutines=self.max_coroutines,
)
return community_nodes
@ -683,7 +721,8 @@ class Graphiti:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
edges_list = await semaphore_gather(
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
max_coroutines=self.max_coroutines,
)
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
@ -759,6 +798,12 @@ class Graphiti:
if record['episode_count'] == 1:
nodes_to_delete.append(node)
await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
await semaphore_gather(
*[node.delete(self.driver) for node in nodes_to_delete],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(
*[edge.delete(self.driver) for edge in edges_to_delete],
max_coroutines=self.max_coroutines,
)
await episode.delete(self.driver)

View File

@ -94,9 +94,9 @@ def normalize_l2(embedding: list[float]) -> NDArray:
# Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather(
*coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT,
max_coroutines: int | None = None,
):
semaphore = asyncio.Semaphore(max_coroutines)
semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
async def _wrap_coroutine(coroutine):
async with semaphore: