mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Bounded semaphore - limiting concurrency (#244)
* WIP * add semaphore * remove unused imports * remove unused imports * lower concurrency limit
This commit is contained in:
parent
0186ac920c
commit
00fe87679e
@ -25,6 +25,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
|
||||
|
||||
@ -122,7 +123,7 @@ async def main():
|
||||
qa_chunk = qa[i : i + 20]
|
||||
group_ids = range(len(qa))[i : i + 20]
|
||||
results = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
evaluate_qa(graphiti, str(group_id), query, answer)
|
||||
for group_id, (query, answer) in zip(group_ids, qa_chunk)
|
||||
|
@ -26,6 +26,7 @@ from examples.multi_session_conversation_memory.parse_msc_messages import (
|
||||
parse_msc_messages,
|
||||
)
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -75,7 +76,7 @@ async def main():
|
||||
msc_message_slice = msc_messages[i : i + 10]
|
||||
group_ids = range(len(msc_messages))[i : i + 10]
|
||||
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
add_conversation(graphiti, str(group_id), messages)
|
||||
for group_id, messages in zip(group_ids, msc_message_slice)
|
||||
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@ -22,6 +21,7 @@ import openai
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..helpers import semaphore_gather
|
||||
from ..llm_client import LLMConfig, RateLimitError
|
||||
from ..prompts import Message
|
||||
from .client import CrossEncoderClient
|
||||
@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||
for passage in passages
|
||||
]
|
||||
try:
|
||||
responses = await asyncio.gather(
|
||||
responses = await semaphore_gather(
|
||||
*[
|
||||
self.client.chat.completions.create(
|
||||
model=DEFAULT_MODEL,
|
||||
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import SearchConfig, search
|
||||
@ -340,13 +339,13 @@ class Graphiti:
|
||||
|
||||
# Calculate Embeddings
|
||||
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
|
||||
)
|
||||
|
||||
# Find relevant nodes already in the graph
|
||||
existing_nodes_lists: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
|
||||
)
|
||||
)
|
||||
@ -354,7 +353,7 @@ class Graphiti:
|
||||
# Resolve extracted nodes with nodes already in the graph and extract facts
|
||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||
|
||||
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
||||
(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
|
||||
resolve_extracted_nodes(
|
||||
self.llm_client,
|
||||
extracted_nodes,
|
||||
@ -374,7 +373,7 @@ class Graphiti:
|
||||
)
|
||||
|
||||
# calculate embeddings
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
edge.generate_embedding(self.embedder)
|
||||
for edge in extracted_edges_with_resolved_pointers
|
||||
@ -383,7 +382,7 @@ class Graphiti:
|
||||
|
||||
# Resolve extracted edges with related edges already in the graph
|
||||
related_edges_list: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
get_relevant_edges(
|
||||
self.driver,
|
||||
@ -404,7 +403,7 @@ class Graphiti:
|
||||
)
|
||||
|
||||
existing_source_edges_list: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
get_relevant_edges(
|
||||
self.driver,
|
||||
@ -419,7 +418,7 @@ class Graphiti:
|
||||
)
|
||||
|
||||
existing_target_edges_list: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
get_relevant_edges(
|
||||
self.driver,
|
||||
@ -468,7 +467,7 @@ class Graphiti:
|
||||
|
||||
# Update any communities
|
||||
if update_communities:
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
update_community(self.driver, self.llm_client, self.embedder, node)
|
||||
for node in nodes
|
||||
@ -538,7 +537,7 @@ class Graphiti:
|
||||
]
|
||||
|
||||
# Save all the episodes
|
||||
await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
|
||||
await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
|
||||
|
||||
# Get previous episode context for each episode
|
||||
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
||||
@ -551,19 +550,19 @@ class Graphiti:
|
||||
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
|
||||
|
||||
# Generate embeddings
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
||||
)
|
||||
|
||||
# Dedupe extracted nodes, compress extracted edges
|
||||
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
|
||||
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
|
||||
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
||||
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
||||
)
|
||||
|
||||
# save nodes to KG
|
||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||
await semaphore_gather(*[node.save(self.driver) for node in nodes])
|
||||
|
||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
||||
@ -574,7 +573,7 @@ class Graphiti:
|
||||
)
|
||||
|
||||
# save episodic edges to KG
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
|
||||
)
|
||||
|
||||
@ -587,7 +586,7 @@ class Graphiti:
|
||||
# invalidate edges
|
||||
|
||||
# save edges to KG
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
||||
await semaphore_gather(*[edge.save(self.driver) for edge in edges])
|
||||
|
||||
end = time()
|
||||
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
||||
@ -610,12 +609,12 @@ class Graphiti:
|
||||
self.driver, self.llm_client, group_ids
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
||||
)
|
||||
|
||||
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
||||
await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
|
||||
await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
|
||||
|
||||
return community_nodes
|
||||
|
||||
@ -698,7 +697,7 @@ class Graphiti:
|
||||
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
||||
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
||||
|
||||
edges_list = await asyncio.gather(
|
||||
edges_list = await semaphore_gather(
|
||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
||||
)
|
||||
|
||||
|
@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
@ -25,6 +27,7 @@ load_dotenv()
|
||||
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
||||
MAX_REFLEXION_ITERATIONS = 2
|
||||
DEFAULT_PAGE_LIMIT = 20
|
||||
|
||||
@ -80,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
|
||||
else:
|
||||
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
|
||||
return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
|
||||
|
||||
|
||||
# Use this instead of asyncio.gather() to bound coroutines
|
||||
async def semaphore_gather(
|
||||
*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
|
||||
):
|
||||
semaphore = asyncio.Semaphore(max_coroutines)
|
||||
|
||||
async def _wrap_coroutine(coroutine):
|
||||
async with semaphore:
|
||||
return await coroutine
|
||||
|
||||
return await asyncio.gather(
|
||||
*(_wrap_coroutine(coroutine) for coroutine in coroutines),
|
||||
return_exceptions=return_exceptions,
|
||||
)
|
||||
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import SearchRerankerError
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.search.search_config import (
|
||||
DEFAULT_SEARCH_LIMIT,
|
||||
@ -78,7 +78,7 @@ async def search(
|
||||
|
||||
# if group_ids is empty, set it to None
|
||||
group_ids = group_ids if group_ids else None
|
||||
edges, nodes, communities = await asyncio.gather(
|
||||
edges, nodes, communities = await semaphore_gather(
|
||||
edge_search(
|
||||
driver,
|
||||
cross_encoder,
|
||||
@ -141,7 +141,7 @@ async def edge_search(
|
||||
return []
|
||||
|
||||
search_results: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
edge_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||
edge_similarity_search(
|
||||
@ -226,7 +226,7 @@ async def node_search(
|
||||
return []
|
||||
|
||||
search_results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
node_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||
node_similarity_search(
|
||||
@ -295,7 +295,7 @@ async def community_search(
|
||||
return []
|
||||
|
||||
search_results: list[list[CommunityNode]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
community_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||
community_similarity_search(
|
||||
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
@ -30,6 +29,7 @@ from graphiti_core.helpers import (
|
||||
USE_PARALLEL_RUNTIME,
|
||||
lucene_sanitize,
|
||||
normalize_l2,
|
||||
semaphore_gather,
|
||||
)
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
@ -549,7 +549,7 @@ async def hybrid_node_search(
|
||||
|
||||
start = time()
|
||||
results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
||||
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
|
||||
)
|
||||
@ -619,7 +619,7 @@ async def get_relevant_edges(
|
||||
relevant_edges: list[EntityEdge] = []
|
||||
relevant_edge_uuids = set()
|
||||
|
||||
results = await asyncio.gather(
|
||||
results = await semaphore_gather(
|
||||
*[
|
||||
edge_similarity_search(
|
||||
driver,
|
||||
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
@ -26,6 +25,7 @@ from numpy import dot, sqrt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
ENTITY_EDGE_SAVE_BULK,
|
||||
@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
|
||||
async def retrieve_previous_episodes_bulk(
|
||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
||||
previous_episodes_list = await asyncio.gather(
|
||||
previous_episodes_list = await semaphore_gather(
|
||||
*[
|
||||
retrieve_episodes(
|
||||
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
|
||||
@ -118,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
||||
extracted_nodes_bulk = await asyncio.gather(
|
||||
extracted_nodes_bulk = await semaphore_gather(
|
||||
*[
|
||||
extract_nodes(llm_client, episode, previous_episodes)
|
||||
for episode, previous_episodes in episode_tuples
|
||||
@ -130,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
|
||||
[episode[1] for episode in episode_tuples],
|
||||
)
|
||||
|
||||
extracted_edges_bulk = await asyncio.gather(
|
||||
extracted_edges_bulk = await semaphore_gather(
|
||||
*[
|
||||
extract_edges(
|
||||
llm_client,
|
||||
@ -171,13 +171,13 @@ async def dedupe_nodes_bulk(
|
||||
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
||||
|
||||
existing_nodes_chunks: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
|
||||
)
|
||||
)
|
||||
|
||||
results: list[tuple[list[EntityNode], dict[str, str]]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
|
||||
for i, node_chunk in enumerate(node_chunks)
|
||||
@ -205,13 +205,13 @@ async def dedupe_edges_bulk(
|
||||
]
|
||||
|
||||
relevant_edges_chunks: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
|
||||
)
|
||||
)
|
||||
|
||||
resolved_edge_chunks: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
|
||||
for i, edge_chunk in enumerate(edge_chunks)
|
||||
@ -292,7 +292,9 @@ async def compress_nodes(
|
||||
# add both nodes to the shortest chunk
|
||||
node_chunks[-1].extend([n, m])
|
||||
|
||||
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
|
||||
results = await semaphore_gather(
|
||||
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
|
||||
)
|
||||
|
||||
extended_map = dict(uuid_map)
|
||||
compressed_nodes: list[EntityNode] = []
|
||||
@ -315,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
|
||||
# We build a map of the edges based on their source and target nodes.
|
||||
edge_chunks = chunk_edges_by_nodes(edges)
|
||||
|
||||
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
||||
results = await semaphore_gather(
|
||||
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
|
||||
)
|
||||
|
||||
compressed_edges: list[EntityEdge] = []
|
||||
for edge_chunk in results:
|
||||
@ -368,7 +372,7 @@ async def extract_edge_dates_bulk(
|
||||
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
|
||||
}
|
||||
|
||||
results = await asyncio.gather(
|
||||
results = await semaphore_gather(
|
||||
*[
|
||||
extract_edge_dates(
|
||||
llm_client,
|
||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import CommunityEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
@ -71,7 +71,7 @@ async def get_community_clusters(
|
||||
|
||||
community_clusters.extend(
|
||||
list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
|
||||
)
|
||||
)
|
||||
@ -164,7 +164,7 @@ async def build_community(
|
||||
odd_one_out = summaries.pop()
|
||||
length -= 1
|
||||
new_summaries: list[str] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
||||
for left_summary, right_summary in zip(
|
||||
@ -207,7 +207,9 @@ async def build_communities(
|
||||
return await build_community(llm_client, cluster)
|
||||
|
||||
communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
|
||||
await asyncio.gather(*[limited_build_community(cluster) for cluster in community_clusters])
|
||||
await semaphore_gather(
|
||||
*[limited_build_community(cluster) for cluster in community_clusters]
|
||||
)
|
||||
)
|
||||
|
||||
community_nodes: list[CommunityNode] = []
|
||||
|
@ -14,13 +14,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||
from graphiti_core.prompts import prompt_library
|
||||
@ -199,7 +198,7 @@ async def resolve_extracted_edges(
|
||||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
|
||||
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
resolve_extracted_edge(
|
||||
llm_client,
|
||||
@ -266,7 +265,7 @@ async def resolve_extracted_edge(
|
||||
current_episode: EpisodicNode,
|
||||
previous_episodes: list[EpisodicNode],
|
||||
) -> tuple[EntityEdge, list[EntityEdge]]:
|
||||
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
|
||||
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
|
||||
dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
|
||||
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
|
||||
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
||||
|
@ -14,14 +14,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
|
||||
EPISODE_WINDOW_LEN = 3
|
||||
@ -38,7 +37,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
index_names = [record['name'] for record in records]
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
"""DROP INDEX $name""",
|
||||
@ -82,7 +81,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
||||
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
|
@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from time import time
|
||||
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.prompts import prompt_library
|
||||
@ -223,7 +222,7 @@ async def resolve_extracted_nodes(
|
||||
uuid_map: dict[str, str] = {}
|
||||
resolved_nodes: list[EntityNode] = []
|
||||
results: list[tuple[EntityNode, dict[str, str]]] = list(
|
||||
await asyncio.gather(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
resolve_extracted_node(
|
||||
llm_client, extracted_node, existing_nodes, episode, previous_episodes
|
||||
@ -275,7 +274,7 @@ async def resolve_extracted_node(
|
||||
else [],
|
||||
}
|
||||
|
||||
llm_response, node_summary_response = await asyncio.gather(
|
||||
llm_response, node_summary_response = await semaphore_gather(
|
||||
llm_client.generate_response(
|
||||
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
|
||||
),
|
||||
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@ -25,6 +24,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_config_recipes import (
|
||||
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||
@ -137,8 +137,8 @@ async def test_graph_integration():
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
||||
# test save
|
||||
await asyncio.gather(*[node.save(driver) for node in nodes])
|
||||
await asyncio.gather(*[edge.save(driver) for edge in edges])
|
||||
await semaphore_gather(*[node.save(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.save(driver) for edge in edges])
|
||||
|
||||
# test get
|
||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
||||
@ -147,5 +147,5 @@ async def test_graph_integration():
|
||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
||||
|
||||
# test delete
|
||||
await asyncio.gather(*[node.delete(driver) for node in nodes])
|
||||
await asyncio.gather(*[edge.delete(driver) for edge in edges])
|
||||
await semaphore_gather(*[node.delete(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.delete(driver) for edge in edges])
|
||||
|
Loading…
x
Reference in New Issue
Block a user