diff --git a/examples/podcast/transcript_parser.py b/examples/podcast/transcript_parser.py index 2466dba0..8dce6ea1 100644 --- a/examples/podcast/transcript_parser.py +++ b/examples/podcast/transcript_parser.py @@ -1,7 +1,6 @@ import os import re from datetime import datetime, timedelta, timezone -from typing import List from pydantic import BaseModel @@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta: return timedelta() # Return 0 duration if parsing fails -def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]: +def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]: with open(file_path) as file: content = file.read() diff --git a/graphiti_core/cross_encoder/bge_reranker_client.py b/graphiti_core/cross_encoder/bge_reranker_client.py index 100aabac..9cd2ac3a 100644 --- a/graphiti_core/cross_encoder/bge_reranker_client.py +++ b/graphiti_core/cross_encoder/bge_reranker_client.py @@ -15,7 +15,6 @@ limitations under the License. """ import asyncio -from typing import List, Tuple from sentence_transformers import CrossEncoder @@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient): def __init__(self): self.model = CrossEncoder('BAAI/bge-reranker-v2-m3') - async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]: + async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: if not passages: return [] diff --git a/graphiti_core/cross_encoder/client.py b/graphiti_core/cross_encoder/client.py index 989c0a53..1664d79d 100644 --- a/graphiti_core/cross_encoder/client.py +++ b/graphiti_core/cross_encoder/client.py @@ -15,7 +15,6 @@ limitations under the License. """ from abc import ABC, abstractmethod -from typing import List, Tuple class CrossEncoderClient(ABC): @@ -26,16 +25,16 @@ class CrossEncoderClient(ABC): """ @abstractmethod - async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]: + async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: """ Rank the given passages based on their relevance to the query. Args: query (str): The query string. - passages (List[str]): A list of passages to rank. + passages (list[str]): A list of passages to rank. Returns: - List[Tuple[str, float]]: A list of tuples containing the passage and its score, + List[tuple[str, float]]: A list of tuples containing the passage and its score, sorted in descending order of relevance. """ pass diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index e9b065ac..8e122278 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -23,10 +23,11 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from typing_extensions import LiteralString from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError -from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date +from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date from graphiti_core.models.edges.edge_db_queries import ( COMMUNITY_EDGE_SAVE, ENTITY_EDGE_SAVE, @@ -50,7 +51,7 @@ class Edge(BaseModel, ABC): async def delete(self, driver: AsyncDriver): result = await driver.execute_query( """ - MATCH (n)-[e {uuid: $uuid}]->(m) + MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) DELETE e """, uuid=self.uuid, @@ -137,19 +138,34 @@ class EpisodicEdge(Edge): return edges @classmethod - async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): + async def get_by_group_ids( + cls, + driver: AsyncDriver, + group_ids: list[str], + limit: int = DEFAULT_PAGE_LIMIT, + created_at: datetime | None = None, + ): + cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else '' + records, _, _ = await driver.execute_query( """ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) WHERE e.group_id IN $group_ids + """ + + cursor_query + + """ RETURN e.uuid As uuid, e.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, e.created_at AS created_at + ORDER BY e.uuid DESC + LIMIT $limit """, group_ids=group_ids, + created_at=created_at, + limit=limit, database_=DEFAULT_DATABASE, routing_='r', ) @@ -274,11 +290,22 @@ class EntityEdge(Edge): return edges @classmethod - async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): + async def get_by_group_ids( + cls, + driver: AsyncDriver, + group_ids: list[str], + limit: int = DEFAULT_PAGE_LIMIT, + created_at: datetime | None = None, + ): + cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else '' + records, _, _ = await driver.execute_query( """ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) WHERE e.group_id IN $group_ids + """ + + cursor_query + + """ RETURN e.uuid AS uuid, n.uuid AS source_node_uuid, @@ -292,8 +319,12 @@ class EntityEdge(Edge): e.expired_at AS expired_at, e.valid_at AS valid_at, e.invalid_at AS invalid_at + ORDER BY e.uuid DESC + LIMIT $limit """, group_ids=group_ids, + created_at=created_at, + limit=limit, database_=DEFAULT_DATABASE, routing_='r', ) @@ -365,19 +396,34 @@ class CommunityEdge(Edge): return edges @classmethod - async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): + async def get_by_group_ids( + cls, + driver: AsyncDriver, + group_ids: list[str], + limit: int = DEFAULT_PAGE_LIMIT, + created_at: datetime | None = None, + ): + cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else '' + records, _, _ = await driver.execute_query( """ MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) WHERE e.group_id IN $group_ids + """ + + cursor_query + + """ RETURN e.uuid As uuid, e.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, e.created_at AS created_at + ORDER BY e.uuid DESC + LIMIT $limit """, group_ids=group_ids, + created_at=created_at, + limit=limit, database_=DEFAULT_DATABASE, routing_='r', ) diff --git a/graphiti_core/embedder/client.py b/graphiti_core/embedder/client.py index e120e203..8b8a15f3 100644 --- a/graphiti_core/embedder/client.py +++ b/graphiti_core/embedder/client.py @@ -15,7 +15,7 @@ limitations under the License. """ from abc import ABC, abstractmethod -from typing import Iterable, List, Literal +from collections.abc import Iterable from pydantic import BaseModel, Field @@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024 class EmbedderConfig(BaseModel): - embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True) + embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True) class EmbedderClient(ABC): @abstractmethod async def create( - self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]] ) -> list[float]: pass diff --git a/graphiti_core/embedder/openai.py b/graphiti_core/embedder/openai.py index 8436df68..6f5f86d2 100644 --- a/graphiti_core/embedder/openai.py +++ b/graphiti_core/embedder/openai.py @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import Iterable, List +from collections.abc import Iterable from openai import AsyncOpenAI from openai.types import EmbeddingModel @@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient): self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) async def create( - self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]] ) -> list[float]: result = await self.client.embeddings.create( input=input_data, model=self.config.embedding_model diff --git a/graphiti_core/embedder/voyage.py b/graphiti_core/embedder/voyage.py index 3aa33631..4ef894e1 100644 --- a/graphiti_core/embedder/voyage.py +++ b/graphiti_core/embedder/voyage.py @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import Iterable, List +from collections.abc import Iterable import voyageai # type: ignore from pydantic import Field @@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient): self.client = voyageai.AsyncClient(api_key=config.api_key) async def create( - self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]] ) -> list[float]: if isinstance(input_data, str): input_list = [input_data] - elif isinstance(input_data, List): + elif isinstance(input_data, list): input_list = [str(i) for i in input_data if i] else: input_list = [str(i) for i in input_data if i is not None] diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 253509c9..5455f9ff 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -26,6 +26,7 @@ load_dotenv() DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) MAX_REFLEXION_ITERATIONS = 2 +DEFAULT_PAGE_LIMIT = 20 def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 14629003..eb650c8e 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -24,10 +24,11 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from typing_extensions import LiteralString from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError -from graphiti_core.helpers import DEFAULT_DATABASE +from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT from graphiti_core.models.nodes.node_db_queries import ( COMMUNITY_NODE_SAVE, ENTITY_NODE_SAVE, @@ -207,10 +208,21 @@ class EpisodicNode(Node): return episodes @classmethod - async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): + async def get_by_group_ids( + cls, + driver: AsyncDriver, + group_ids: list[str], + limit: int = DEFAULT_PAGE_LIMIT, + created_at: datetime | None = None, + ): + cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else '' + records, _, _ = await driver.execute_query( """ MATCH (e:Episodic) WHERE e.group_id IN $group_ids + """ + + cursor_query + + """ RETURN DISTINCT e.content AS content, e.created_at AS created_at, @@ -220,8 +232,12 @@ class EpisodicNode(Node): e.group_id AS group_id, e.source_description AS source_description, e.source AS source + ORDER BY e.uuid DESC + LIMIT $limit """, group_ids=group_ids, + created_at=created_at, + limit=limit, database_=DEFAULT_DATABASE, routing_='r', ) @@ -308,10 +324,21 @@ class EntityNode(Node): return nodes @classmethod - async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): + async def get_by_group_ids( + cls, + driver: AsyncDriver, + group_ids: list[str], + limit: int = DEFAULT_PAGE_LIMIT, + created_at: datetime | None = None, + ): + cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else '' + records, _, _ = await driver.execute_query( """ MATCH (n:Entity) WHERE n.group_id IN $group_ids + """ + + cursor_query + + """ RETURN n.uuid As uuid, n.name AS name, @@ -319,8 +346,12 @@ class EntityNode(Node): n.group_id AS group_id, n.created_at AS created_at, n.summary AS summary + ORDER BY n.uuid DESC + LIMIT $limit """, group_ids=group_ids, + created_at=created_at, + limit=limit, database_=DEFAULT_DATABASE, routing_='r', ) @@ -407,10 +438,21 @@ class CommunityNode(Node): return communities @classmethod - async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): + async def get_by_group_ids( + cls, + driver: AsyncDriver, + group_ids: list[str], + limit: int = DEFAULT_PAGE_LIMIT, + created_at: datetime | None = None, + ): + cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else '' + records, _, _ = await driver.execute_query( """ MATCH (n:Community) WHERE n.group_id IN $group_ids + """ + + cursor_query + + """ RETURN n.uuid As uuid, n.name AS name, @@ -418,8 +460,12 @@ class CommunityNode(Node): n.group_id AS group_id, n.created_at AS created_at, n.summary AS summary + ORDER BY n.uuid DESC + LIMIT $limit """, group_ids=group_ids, + created_at=created_at, + limit=limit, database_=DEFAULT_DATABASE, routing_='r', ) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 32eef8bf..f541fafa 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -40,7 +40,7 @@ from graphiti_core.nodes import ( logger = logging.getLogger(__name__) -RELEVANT_SCHEMA_LIMIT = 3 +RELEVANT_SCHEMA_LIMIT = 10 DEFAULT_MIN_SCORE = 0.6 DEFAULT_MMR_LAMBDA = 0.5 MAX_SEARCH_DEPTH = 3 diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index a7375ccd..90f38a65 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -18,7 +18,6 @@ import asyncio import logging from datetime import datetime, timezone from time import time -from typing import List from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS @@ -34,11 +33,11 @@ logger = logging.getLogger(__name__) def build_episodic_edges( - entity_nodes: List[EntityNode], + entity_nodes: list[EntityNode], episode: EpisodicNode, created_at: datetime, -) -> List[EpisodicEdge]: - edges: List[EpisodicEdge] = [ +) -> list[EpisodicEdge]: + edges: list[EpisodicEdge] = [ EpisodicEdge( source_node_uuid=episode.uuid, target_node_uuid=node.uuid, @@ -52,11 +51,11 @@ def build_episodic_edges( def build_community_edges( - entity_nodes: List[EntityNode], + entity_nodes: list[EntityNode], community_node: CommunityNode, created_at: datetime, -) -> List[CommunityEdge]: - edges: List[CommunityEdge] = [ +) -> list[CommunityEdge]: + edges: list[CommunityEdge] = [ CommunityEdge( source_node_uuid=community_node.uuid, target_node_uuid=node.uuid, diff --git a/graphiti_core/utils/maintenance/temporal_operations.py b/graphiti_core/utils/maintenance/temporal_operations.py index 6f740c2a..c95e4bb0 100644 --- a/graphiti_core/utils/maintenance/temporal_operations.py +++ b/graphiti_core/utils/maintenance/temporal_operations.py @@ -17,7 +17,6 @@ limitations under the License. import logging from datetime import datetime from time import time -from typing import List from graphiti_core.edges import EntityEdge from graphiti_core.llm_client import LLMClient @@ -31,7 +30,7 @@ async def extract_edge_dates( llm_client: LLMClient, edge: EntityEdge, current_episode: EpisodicNode, - previous_episodes: List[EpisodicNode], + previous_episodes: list[EpisodicNode], ) -> tuple[datetime | None, datetime | None]: context = { 'edge_fact': edge.fact,