mirror of
https://github.com/getzep/graphiti.git
synced 2026-01-06 12:20:47 +00:00
Add max_coroutines parameter to Graphiti and update semaphore_gather function (#619)
- Introduced max_coroutines parameter in the Graphiti class to control the maximum number of concurrent operations. - Updated the semaphore_gather function to accept max_coroutines as an optional argument, defaulting to SEMAPHORE_LIMIT if not provided. - Adjusted multiple calls to semaphore_gather throughout the Graphiti class to utilize the new max_coroutines parameter for better concurrency management.
This commit is contained in:
parent
ae7f2234a8
commit
fe870b953f
@ -103,6 +103,7 @@ class Graphiti:
|
||||
cross_encoder: CrossEncoderClient | None = None,
|
||||
store_raw_episode_content: bool = True,
|
||||
graph_driver: GraphDriver | None = None,
|
||||
max_coroutines: int | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize a Graphiti instance.
|
||||
@ -121,6 +122,20 @@ class Graphiti:
|
||||
llm_client : LLMClient | None, optional
|
||||
An instance of LLMClient for natural language processing tasks.
|
||||
If not provided, a default OpenAIClient will be initialized.
|
||||
embedder : EmbedderClient | None, optional
|
||||
An instance of EmbedderClient for embedding tasks.
|
||||
If not provided, a default OpenAIEmbedder will be initialized.
|
||||
cross_encoder : CrossEncoderClient | None, optional
|
||||
An instance of CrossEncoderClient for reranking tasks.
|
||||
If not provided, a default OpenAIRerankerClient will be initialized.
|
||||
store_raw_episode_content : bool, optional
|
||||
Whether to store the raw content of episodes. Defaults to True.
|
||||
graph_driver : GraphDriver | None, optional
|
||||
An instance of GraphDriver for database operations.
|
||||
If not provided, a default Neo4jDriver will be initialized.
|
||||
max_coroutines : int | None, optional
|
||||
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
||||
If not set, the Graphiti default is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -145,6 +160,7 @@ class Graphiti:
|
||||
|
||||
self.database = DEFAULT_DATABASE
|
||||
self.store_raw_episode_content = store_raw_episode_content
|
||||
self.max_coroutines = max_coroutines
|
||||
if llm_client:
|
||||
self.llm_client = llm_client
|
||||
else:
|
||||
@ -393,6 +409,7 @@ class Graphiti:
|
||||
group_id,
|
||||
edge_types,
|
||||
),
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
||||
@ -409,6 +426,7 @@ class Graphiti:
|
||||
extract_attributes_from_nodes(
|
||||
self.clients, nodes, episode, previous_episodes, entity_types
|
||||
),
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
|
||||
@ -432,7 +450,8 @@ class Graphiti:
|
||||
*[
|
||||
update_community(self.driver, self.llm_client, self.embedder, node)
|
||||
for node in nodes
|
||||
]
|
||||
],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
end = time()
|
||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||
@ -499,7 +518,10 @@ class Graphiti:
|
||||
]
|
||||
|
||||
# Save all the episodes
|
||||
await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
|
||||
await semaphore_gather(
|
||||
*[episode.save(self.driver) for episode in episodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
# Get previous episode context for each episode
|
||||
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
||||
@ -515,16 +537,21 @@ class Graphiti:
|
||||
await semaphore_gather(
|
||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
# Dedupe extracted nodes, compress extracted edges
|
||||
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
|
||||
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
||||
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
# save nodes to KG
|
||||
await semaphore_gather(*[node.save(self.driver) for node in nodes])
|
||||
await semaphore_gather(
|
||||
*[node.save(self.driver) for node in nodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
||||
@ -536,7 +563,8 @@ class Graphiti:
|
||||
|
||||
# save episodic edges to KG
|
||||
await semaphore_gather(
|
||||
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
|
||||
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
# Dedupe extracted edges
|
||||
@ -548,7 +576,10 @@ class Graphiti:
|
||||
# invalidate edges
|
||||
|
||||
# save edges to KG
|
||||
await semaphore_gather(*[edge.save(self.driver) for edge in edges])
|
||||
await semaphore_gather(
|
||||
*[edge.save(self.driver) for edge in edges],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
end = time()
|
||||
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
||||
@ -572,11 +603,18 @@ class Graphiti:
|
||||
)
|
||||
|
||||
await semaphore_gather(
|
||||
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
||||
*[node.generate_name_embedding(self.embedder) for node in community_nodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
|
||||
await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
|
||||
await semaphore_gather(
|
||||
*[node.save(self.driver) for node in community_nodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
await semaphore_gather(
|
||||
*[edge.save(self.driver) for edge in community_edges],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
return community_nodes
|
||||
|
||||
@ -683,7 +721,8 @@ class Graphiti:
|
||||
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
||||
|
||||
edges_list = await semaphore_gather(
|
||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
||||
@ -759,6 +798,12 @@ class Graphiti:
|
||||
if record['episode_count'] == 1:
|
||||
nodes_to_delete.append(node)
|
||||
|
||||
await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
|
||||
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
|
||||
await semaphore_gather(
|
||||
*[node.delete(self.driver) for node in nodes_to_delete],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
await semaphore_gather(
|
||||
*[edge.delete(self.driver) for edge in edges_to_delete],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
await episode.delete(self.driver)
|
||||
|
||||
@ -94,9 +94,9 @@ def normalize_l2(embedding: list[float]) -> NDArray:
|
||||
# Use this instead of asyncio.gather() to bound coroutines
|
||||
async def semaphore_gather(
|
||||
*coroutines: Coroutine,
|
||||
max_coroutines: int = SEMAPHORE_LIMIT,
|
||||
max_coroutines: int | None = None,
|
||||
):
|
||||
semaphore = asyncio.Semaphore(max_coroutines)
|
||||
semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
|
||||
|
||||
async def _wrap_coroutine(coroutine):
|
||||
async with semaphore:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user