Pagination for get by group_id (#218)

* add pagination to subgraphs

* update pagination

* update LiteralString import

* cleanup

* cleanup

* update embedding dims
Preston Rasmussen 2024-12-02 11:17:37 -05:00 committed by GitHub
parent 397291de4b
commit 0fbe5c0704
12 changed files with 123 additions and 35 deletions
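
The substantive change here adds keyset pagination to every get_by_group_ids class method on the edge and node models: each now takes a limit (defaulting to the new DEFAULT_PAGE_LIMIT of 20) and an optional created_at cursor that restricts results to rows created before that timestamp. Below is a minimal sketch of paging through entity edges with the new parameters; the driver setup, helper name, and loop are illustrative assumptions, not part of this diff:

    from datetime import datetime

    from neo4j import AsyncGraphDatabase

    from graphiti_core.edges import EntityEdge


    async def fetch_all_entity_edges(group_ids: list[str]) -> list[EntityEdge]:
        # Placeholder connection details; substitute a real URI and credentials.
        driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
        edges: list[EntityEdge] = []
        cursor: datetime | None = None  # no cursor on the first page
        while True:
            page = await EntityEdge.get_by_group_ids(driver, group_ids, limit=20, created_at=cursor)
            if not page:
                break
            edges.extend(page)
            # Ask for the next page: everything created before the oldest edge seen so far.
            cursor = min(edge.created_at for edge in page)
        await driver.close()
        return edges

Unlike offset paging, the created_at cursor never re-reads rows that earlier pages already returned.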

View File

@@ -1,7 +1,6 @@
import os
import re
from datetime import datetime, timedelta, timezone
from typing import List
from pydantic import BaseModel
@@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta:
return timedelta() # Return 0 duration if parsing fails
def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]:
def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
with open(file_path) as file:
content = file.read()

View File

@@ -15,7 +15,6 @@ limitations under the License.
"""
import asyncio
from typing import List, Tuple
from sentence_transformers import CrossEncoder
@@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
def __init__(self):
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
if not passages:
return []

View File

@@ -15,7 +15,6 @@ limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import List, Tuple
class CrossEncoderClient(ABC):
@@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
"""
@abstractmethod
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
"""
Rank the given passages based on their relevance to the query.
Args:
query (str): The query string.
passages (List[str]): A list of passages to rank.
passages (list[str]): A list of passages to rank.
Returns:
List[Tuple[str, float]]: A list of tuples containing the passage and its score,
List[tuple[str, float]]: A list of tuples containing the passage and its score,
sorted in descending order of relevance.
"""
pass
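
Concrete rerankers (such as the BGE client earlier in this diff) subclass this interface and implement rank. A toy implementation against only the interface shown above; the import path and the overlap-based scoring are illustrative, not from the repository:

    from graphiti_core.cross_encoder.client import CrossEncoderClient  # assumed module path


    class KeywordOverlapReranker(CrossEncoderClient):
        """Toy reranker: scores each passage by word overlap with the query."""

        async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
            query_terms = set(query.lower().split())
            scored = [
                (passage, len(query_terms & set(passage.lower().split())) / (len(query_terms) or 1))
                for passage in passages
            ]
            return sorted(scored, key=lambda pair: pair[1], reverse=True)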

View File

@@ -23,10 +23,11 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE,
@@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n)-[e {uuid: $uuid}]->(m)
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
@@ -137,19 +138,34 @@ class EpisodicEdge(Edge):
return edges
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
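
One detail worth noting in the hunk above: the optional cursor clause is spliced in with '+' rather than an f-string and is annotated as LiteralString, so under PEP 675 the composed query is still a literal string (concatenation of literal strings stays literal) and type checkers that track LiteralString do not flag it as dynamically built query text. A stripped-down sketch of the same pattern, with an abbreviated query body:

    from datetime import datetime

    from typing_extensions import LiteralString


    def episodic_edge_query(created_at: datetime | None) -> LiteralString:
        # Both branches are literals, and '+' between literal strings stays a LiteralString.
        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
        return (
            'MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) WHERE e.group_id IN $group_ids '
            + cursor_query
            + ' RETURN e.uuid AS uuid ORDER BY e.uuid DESC LIMIT $limit'
        )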
@@ -274,11 +290,22 @@ class EntityEdge(Edge):
return edges
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
@@ -292,8 +319,12 @@ class EntityEdge(Edge):
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@@ -365,19 +396,34 @@ class CommunityEdge(Edge):
return edges
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)

View File

@@ -15,7 +15,7 @@ limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import Iterable, List, Literal
from collections.abc import Iterable
from pydantic import BaseModel, Field
@@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
class EmbedderConfig(BaseModel):
embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
class EmbedderClient(ABC):
@abstractmethod
async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
pass
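
Widening embedding_dim from Literal[1024] to int is the 'update embedding dims' item from the commit message: configurations can now declare embedding models with other output sizes. A small sketch, assuming EmbedderConfig is importable as defined in this file (module path and dimension values are illustrative):

    from graphiti_core.embedder.client import EmbedderConfig  # assumed module path

    # Before this change only embedding_dim=1024 type-checked; any int is accepted now.
    small = EmbedderConfig(embedding_dim=768)
    default = EmbedderConfig()  # still defaults to EMBEDDING_DIM (1024)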

View File

@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Iterable, List
from collections.abc import Iterable
from openai import AsyncOpenAI
from openai.types import EmbeddingModel
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embeddings.create(
input=input_data, model=self.config.embedding_model

View File

@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Iterable, List
from collections.abc import Iterable
import voyageai # type: ignore
from pydantic import Field
@@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
self.client = voyageai.AsyncClient(api_key=config.api_key)
async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
if isinstance(input_data, str):
input_list = [input_data]
elif isinstance(input_data, List):
elif isinstance(input_data, list):
input_list = [str(i) for i in input_data if i]
else:
input_list = [str(i) for i in input_data if i is not None]

View File

@@ -26,6 +26,7 @@ load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
MAX_REFLEXION_ITERATIONS = 2
DEFAULT_PAGE_LIMIT = 20
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:

View File

@@ -24,10 +24,11 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
@@ -207,10 +208,21 @@ class EpisodicNode(Node):
return episodes
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
@@ -220,8 +232,12 @@ class EpisodicNode(Node):
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@@ -308,10 +324,21 @@ class EntityNode(Node):
return nodes
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
n.uuid As uuid,
n.name AS name,
@@ -319,8 +346,12 @@ class EntityNode(Node):
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
ORDER BY n.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@@ -407,10 +438,21 @@ class CommunityNode(Node):
return communities
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
n.uuid As uuid,
n.name AS name,
@@ -418,8 +460,12 @@ class CommunityNode(Node):
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
ORDER BY n.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)

View File

@@ -40,7 +40,7 @@ from graphiti_core.nodes import (
logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3

View File

@@ -18,7 +18,6 @@ import asyncio
import logging
from datetime import datetime, timezone
from time import time
from typing import List
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
@@ -34,11 +33,11 @@ logger = logging.getLogger(__name__)
def build_episodic_edges(
entity_nodes: List[EntityNode],
entity_nodes: list[EntityNode],
episode: EpisodicNode,
created_at: datetime,
) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = [
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = [
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
@@ -52,11 +51,11 @@ def build_episodic_edges(
def build_community_edges(
entity_nodes: List[EntityNode],
entity_nodes: list[EntityNode],
community_node: CommunityNode,
created_at: datetime,
) -> List[CommunityEdge]:
edges: List[CommunityEdge] = [
) -> list[CommunityEdge]:
edges: list[CommunityEdge] = [
CommunityEdge(
source_node_uuid=community_node.uuid,
target_node_uuid=node.uuid,
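
Only the annotations change in these two builders; their behavior is identical. For orientation, a usage sketch based purely on the signature above (the wrapper name, import path, and timestamp choice are assumptions):

    from datetime import datetime, timezone

    from graphiti_core.edges import EpisodicEdge
    from graphiti_core.nodes import EntityNode, EpisodicNode
    from graphiti_core.utils.maintenance.edge_operations import build_episodic_edges  # assumed path


    def link_episode_to_entities(
        episode: EpisodicNode, entity_nodes: list[EntityNode]
    ) -> list[EpisodicEdge]:
        # One MENTIONS edge per extracted entity node, stamped with the current UTC time.
        return build_episodic_edges(entity_nodes, episode, datetime.now(timezone.utc))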

View File

@@ -17,7 +17,6 @@ limitations under the License.
import logging
from datetime import datetime
from time import time
from typing import List
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient
@@ -31,7 +30,7 @@ async def extract_edge_dates(
llm_client: LLMClient,
edge: EntityEdge,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
previous_episodes: list[EpisodicNode],
) -> tuple[datetime | None, datetime | None]:
context = {
'edge_fact': edge.fact,