Bounded semaphore - limiting concurrency (#244)

* WIP

* add semaphore

* remove unused imports

* remove unused imports

* lower concurrency limit
Preston Rasmussen 2024-12-17 13:08:18 -05:00 committed by GitHub
parent 0186ac920c
commit 00fe87679e
13 changed files with 87 additions and 64 deletions
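
The change itself is mechanical but broad: every unbounded asyncio.gather(...) call below is replaced with a new semaphore_gather(...) helper in graphiti_core.helpers, which awaits the same coroutines while holding a shared asyncio.Semaphore so that at most SEMAPHORE_LIMIT of them (default 20, overridable through the SEMAPHORE_LIMIT environment variable) are in flight at once. A minimal sketch of the before/after pattern; fetch is a hypothetical stand-in for the LLM and database calls touched in this diff:

import asyncio

from graphiti_core.helpers import semaphore_gather


async def fetch(i: int) -> int:
    # Hypothetical stand-in for an LLM or database call.
    await asyncio.sleep(0.1)
    return i


async def main() -> None:
    # Before: await asyncio.gather(*[fetch(i) for i in range(100)])
    # started all 100 coroutines at once. The drop-in replacement below
    # keeps at most SEMAPHORE_LIMIT coroutines running at any moment.
    results = await semaphore_gather(*[fetch(i) for i in range(100)])
    print(len(results))


asyncio.run(main())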

View File

@@ -25,6 +25,7 @@ from dotenv import load_dotenv
 from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
 from graphiti_core import Graphiti
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.prompts import prompt_library
 from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
@@ -122,7 +123,7 @@ async def main():
         qa_chunk = qa[i : i + 20]
         group_ids = range(len(qa))[i : i + 20]
         results = list(
-            await asyncio.gather(
+            await semaphore_gather(
                 *[
                     evaluate_qa(graphiti, str(group_id), query, answer)
                     for group_id, (query, answer) in zip(group_ids, qa_chunk)

View File

@@ -26,6 +26,7 @@ from examples.multi_session_conversation_memory.parse_msc_messages import (
     parse_msc_messages,
 )
 from graphiti_core import Graphiti
+from graphiti_core.helpers import semaphore_gather
 load_dotenv()
@@ -75,7 +76,7 @@ async def main():
         msc_message_slice = msc_messages[i : i + 10]
         group_ids = range(len(msc_messages))[i : i + 10]
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 add_conversation(graphiti, str(group_id), messages)
                 for group_id, messages in zip(group_ids, msc_message_slice)

View File

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from typing import Any
@@ -22,6 +21,7 @@ import openai
 from openai import AsyncOpenAI
 from pydantic import BaseModel
+from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
 from ..prompts import Message
 from .client import CrossEncoderClient
@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
             for passage in passages
         ]
         try:
-            responses = await asyncio.gather(
+            responses = await semaphore_gather(
                 *[
                     self.client.chat.completions.create(
                         model=DEFAULT_MODEL,
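
Because the reranker issues one chat.completions.create request per passage, it is among the call sites most exposed to provider rate limits. Note that semaphore_gather also takes a per-call max_coroutines keyword, so an individual call site could be bounded more tightly than the global default. A hypothetical sketch, not part of this diff, where score stands in for the per-passage API request:

import asyncio

from graphiti_core.helpers import semaphore_gather


async def score(passage: str) -> float:
    # Hypothetical stand-in for the per-passage OpenAI request.
    await asyncio.sleep(0.05)
    return float(len(passage))


async def rerank(passages: list[str]) -> list[float]:
    # max_coroutines tightens the bound for this call site only,
    # independently of the global SEMAPHORE_LIMIT default.
    return await semaphore_gather(*[score(p) for p in passages], max_coroutines=5)


print(asyncio.run(rerank(['graph', 'memory', 'semaphore'])))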

View File

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from datetime import datetime
 from time import time
@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -340,13 +339,13 @@ class Graphiti:
             # Calculate Embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
             )
             # Find relevant nodes already in the graph
             existing_nodes_lists: list[list[EntityNode]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
                 )
             )
@@ -354,7 +353,7 @@ class Graphiti:
             # Resolve extracted nodes with nodes already in the graph and extract facts
             logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
-            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+            (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.llm_client,
                     extracted_nodes,
@@ -374,7 +373,7 @@ class Graphiti:
             )
             # calculate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[
                     edge.generate_embedding(self.embedder)
                     for edge in extracted_edges_with_resolved_pointers
@@ -383,7 +382,7 @@ class Graphiti:
             # Resolve extracted edges with related edges already in the graph
             related_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -404,7 +403,7 @@ class Graphiti:
             )
             existing_source_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -419,7 +418,7 @@ class Graphiti:
             )
             existing_target_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -468,7 +467,7 @@ class Graphiti:
             # Update any communities
             if update_communities:
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         update_community(self.driver, self.llm_client, self.embedder, node)
                         for node in nodes
@@ -538,7 +537,7 @@ class Graphiti:
             ]
             # Save all the episodes
-            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+            await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
             # Get previous episode context for each episode
             episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -551,19 +550,19 @@ class Graphiti:
             ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
             # Generate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
                 *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
             )
             # Dedupe extracted nodes, compress extracted edges
-            (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+            (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
                 dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
                 extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
             )
             # save nodes to KG
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await semaphore_gather(*[node.save(self.driver) for node in nodes])
             # re-map edge pointers so that they don't point to discard dupe nodes
             extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -574,7 +573,7 @@ class Graphiti:
             )
             # save episodic edges to KG
-            await asyncio.gather(
+            await semaphore_gather(
                 *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
             )
@@ -587,7 +586,7 @@ class Graphiti:
             # invalidate edges
             # save edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            await semaphore_gather(*[edge.save(self.driver) for edge in edges])
             end = time()
             logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -610,12 +609,12 @@ class Graphiti:
             self.driver, self.llm_client, group_ids
         )
-        await asyncio.gather(
+        await semaphore_gather(
            *[node.generate_name_embedding(self.embedder) for node in community_nodes]
        )
-        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
-        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+        await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
+        await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
         return community_nodes
@@ -698,7 +697,7 @@ class Graphiti:
     async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
         episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
-        edges_list = await asyncio.gather(
+        edges_list = await semaphore_gather(
             *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
         )
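
One scoping detail worth noting before the helper's definition below: the asyncio.Semaphore is created inside each semaphore_gather call, so the bound is per call site rather than process-wide. Nested uses, such as the gathers above whose inner coroutines themselves call semaphore_gather, each get their own budget of SEMAPHORE_LIMIT coroutines. A minimal sketch of that behavior, with noop as a hypothetical coroutine:

import asyncio

from graphiti_core.helpers import semaphore_gather


async def noop() -> None:
    await asyncio.sleep(0)


async def nested() -> None:
    # Each call constructs its own semaphore, so the two inner gathers
    # can together keep up to 2 * SEMAPHORE_LIMIT coroutines in flight.
    await semaphore_gather(
        semaphore_gather(*[noop() for _ in range(50)]),
        semaphore_gather(*[noop() for _ in range(50)]),
    )


asyncio.run(nested())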

View File

@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
+import asyncio
 import os
+from collections.abc import Coroutine
 from datetime import datetime
 import numpy as np
@@ -25,6 +27,7 @@ load_dotenv()
 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
+SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = 2
 DEFAULT_PAGE_LIMIT = 20
@@ -80,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
     else:
         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
     return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+
+
+# Use this instead of asyncio.gather() to bound coroutines
+async def semaphore_gather(
+    *coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
+):
+    semaphore = asyncio.Semaphore(max_coroutines)
+
+    async def _wrap_coroutine(coroutine):
+        async with semaphore:
+            return await coroutine
+
+    return await asyncio.gather(
+        *(_wrap_coroutine(coroutine) for coroutine in coroutines),
+        return_exceptions=return_exceptions,
+    )
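
One behavioral difference from a bare asyncio.gather call: the helper defaults to return_exceptions=True, whereas asyncio's default is False, so a failing coroutine is returned as an exception object inside the results list rather than propagating out of the gather. A minimal sketch of what callers see under that default; boom and demo are illustrative only:

import asyncio

from graphiti_core.helpers import semaphore_gather


async def boom() -> None:
    raise ValueError('simulated failure')


async def demo() -> None:
    # With return_exceptions=True, the ValueError comes back as a list
    # element instead of being raised out of semaphore_gather.
    results = await semaphore_gather(asyncio.sleep(0), boom())
    assert isinstance(results[1], ValueError)


asyncio.run(demo())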

View File

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.edges import EntityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import SearchRerankerError
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.nodes import CommunityNode, EntityNode
 from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
@@ -78,7 +78,7 @@ async def search(
     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids else None
-    edges, nodes, communities = await asyncio.gather(
+    edges, nodes, communities = await semaphore_gather(
         edge_search(
             driver,
             cross_encoder,
@@ -141,7 +141,7 @@ async def edge_search(
         return []
     search_results: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 edge_fulltext_search(driver, query, group_ids, 2 * limit),
                 edge_similarity_search(
@@ -226,7 +226,7 @@ async def node_search(
         return []
     search_results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 node_fulltext_search(driver, query, group_ids, 2 * limit),
                 node_similarity_search(
@@ -295,7 +295,7 @@ async def community_search(
         return []
     search_results: list[list[CommunityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 community_fulltext_search(driver, query, group_ids, 2 * limit),
                 community_similarity_search(

View File

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -30,6 +29,7 @@ from graphiti_core.helpers import (
     USE_PARALLEL_RUNTIME,
     lucene_sanitize,
     normalize_l2,
+    semaphore_gather,
 )
 from graphiti_core.nodes import (
     CommunityNode,
@@ -549,7 +549,7 @@ async def hybrid_node_search(
     start = time()
     results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
             *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
         )
@@ -619,7 +619,7 @@ async def get_relevant_edges(
     relevant_edges: list[EntityEdge] = []
     relevant_edge_uuids = set()
-    results = await asyncio.gather(
+    results = await semaphore_gather(
         *[
             edge_similarity_search(
                 driver,

View File

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 import typing
 from collections import defaultdict
@@ -26,6 +25,7 @@ from numpy import dot, sqrt
 from pydantic import BaseModel
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
     ENTITY_EDGE_SAVE_BULK,
@@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
 async def retrieve_previous_episodes_bulk(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
-    previous_episodes_list = await asyncio.gather(
+    previous_episodes_list = await semaphore_gather(
         *[
             retrieve_episodes(
                 driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
@@ -118,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
 async def extract_nodes_and_edges_bulk(
     llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await asyncio.gather(
+    extracted_nodes_bulk = await semaphore_gather(
         *[
             extract_nodes(llm_client, episode, previous_episodes)
             for episode, previous_episodes in episode_tuples
@@ -130,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
         [episode[1] for episode in episode_tuples],
     )
-    extracted_edges_bulk = await asyncio.gather(
+    extracted_edges_bulk = await semaphore_gather(
         *[
             extract_edges(
                 llm_client,
@@ -171,13 +171,13 @@ async def dedupe_nodes_bulk(
     node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
     existing_nodes_chunks: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
         )
     )
     results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
                 for i, node_chunk in enumerate(node_chunks)
@@ -205,13 +205,13 @@ async def dedupe_edges_bulk(
     ]
     relevant_edges_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
        )
    )
     resolved_edge_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
                for i, edge_chunk in enumerate(edge_chunks)
@@ -292,7 +292,9 @@ async def compress_nodes(
         # add both nodes to the shortest chunk
         node_chunks[-1].extend([n, m])
-    results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+    results = await semaphore_gather(
+        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
+    )
     extended_map = dict(uuid_map)
     compressed_nodes: list[EntityNode] = []
@@ -315,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
     # We build a map of the edges based on their source and target nodes.
     edge_chunks = chunk_edges_by_nodes(edges)
-    results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+    results = await semaphore_gather(
+        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
+    )
     compressed_edges: list[EntityEdge] = []
     for edge_chunk in results:
@@ -368,7 +372,7 @@ async def extract_edge_dates_bulk(
         episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
     }
-    results = await asyncio.gather(
+    results = await semaphore_gather(
         *[
             extract_edge_dates(
                 llm_client,

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.embedder import EmbedderClient
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import (
     CommunityNode,
@@ -71,7 +71,7 @@ async def get_community_clusters(
         community_clusters.extend(
             list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
                 )
             )
@@ -164,7 +164,7 @@ async def build_community(
         odd_one_out = summaries.pop()
         length -= 1
     new_summaries: list[str] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                for left_summary, right_summary in zip(
@@ -207,7 +207,9 @@ async def build_communities(
         return await build_community(llm_client, cluster)
     communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
-        await asyncio.gather(*[limited_build_community(cluster) for cluster in community_clusters])
+        await semaphore_gather(
+            *[limited_build_community(cluster) for cluster in community_clusters]
+        )
     )
     community_nodes: list[CommunityNode] = []

View File

@@ -14,13 +14,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from datetime import datetime
 from time import time
 from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
+from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
@@ -199,7 +198,7 @@ async def resolve_extracted_edges(
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 resolve_extracted_edge(
                     llm_client,
@@ -266,7 +265,7 @@ async def resolve_extracted_edge(
     current_episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> tuple[EntityEdge, list[EntityEdge]]:
-    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
         dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
         extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
         get_edge_contradictions(llm_client, extracted_edge, existing_edges),

View File

@@ -14,14 +14,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from datetime import datetime, timezone
 from neo4j import AsyncDriver
 from typing_extensions import LiteralString
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.nodes import EpisodeType, EpisodicNode
 EPISODE_WINDOW_LEN = 3
@@ -38,7 +37,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
             database_=DEFAULT_DATABASE,
         )
         index_names = [record['name'] for record in records]
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                driver.execute_query(
                    """DROP INDEX $name""",
@@ -82,7 +81,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
     index_queries: list[LiteralString] = range_indices + fulltext_indices
-    await asyncio.gather(
+    await semaphore_gather(
         *[
             driver.execute_query(
                 query,

View File

@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 from time import time
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
+from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.prompts import prompt_library
@@ -223,7 +222,7 @@ async def resolve_extracted_nodes(
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
     results: list[tuple[EntityNode, dict[str, str]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                resolve_extracted_node(
                    llm_client, extracted_node, existing_nodes, episode, previous_episodes
@@ -275,7 +274,7 @@ async def resolve_extracted_node(
         else [],
     }
-    llm_response, node_summary_response = await asyncio.gather(
+    llm_response, node_summary_response = await semaphore_gather(
         llm_client.generate_response(
             prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
         ),

View File

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import asyncio
 import logging
 import os
 import sys
@@ -25,6 +24,7 @@ from dotenv import load_dotenv
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.graphiti import Graphiti
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.nodes import EntityNode, EpisodicNode
 from graphiti_core.search.search_config_recipes import (
     COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
@@ -137,8 +137,8 @@ async def test_graph_integration():
     edges = [episodic_edge_1, episodic_edge_2, entity_edge]
     # test save
-    await asyncio.gather(*[node.save(driver) for node in nodes])
-    await asyncio.gather(*[edge.save(driver) for edge in edges])
+    await semaphore_gather(*[node.save(driver) for node in nodes])
+    await semaphore_gather(*[edge.save(driver) for edge in edges])
     # test get
     assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
@@ -147,5 +147,5 @@ async def test_graph_integration():
     assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
     # test delete
-    await asyncio.gather(*[node.delete(driver) for node in nodes])
-    await asyncio.gather(*[edge.delete(driver) for edge in edges])
+    await semaphore_gather(*[node.delete(driver) for node in nodes])
+    await semaphore_gather(*[edge.delete(driver) for edge in edges])