mirror of https://github.com/getzep/graphiti.git (synced 2025-12-29 08:05:02 +00:00)

Pagination for get by group_id (#218)

* add pagination to subgraphs
* update pagination
* update LiteralString import
* cleanup
* cleanup
* update embedding dims

parent 397291de4b
commit 0fbe5c0704
@@ -1,7 +1,6 @@
 import os
 import re
 from datetime import datetime, timedelta, timezone
-from typing import List
 
 from pydantic import BaseModel
 

@@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta:
         return timedelta()  # Return 0 duration if parsing fails
 
 
-def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]:
+def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
     with open(file_path) as file:
         content = file.read()
 

@@ -15,7 +15,6 @@ limitations under the License.
 """
 
 import asyncio
-from typing import List, Tuple
 
 from sentence_transformers import CrossEncoder
 

@@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
     def __init__(self):
         self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
 
-    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
         if not passages:
             return []
 

@@ -15,7 +15,6 @@ limitations under the License.
 """
 
 from abc import ABC, abstractmethod
-from typing import List, Tuple
 
 
 class CrossEncoderClient(ABC):

@@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
     """
 
     @abstractmethod
-    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
         """
         Rank the given passages based on their relevance to the query.
 
         Args:
             query (str): The query string.
-            passages (List[str]): A list of passages to rank.
+            passages (list[str]): A list of passages to rank.
 
         Returns:
-            List[Tuple[str, float]]: A list of tuples containing the passage and its score,
+            List[tuple[str, float]]: A list of tuples containing the passage and its score,
             sorted in descending order of relevance.
         """
         pass

@@ -23,10 +23,11 @@ from uuid import uuid4
 
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
+from typing_extensions import LiteralString
 
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
+from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
 from graphiti_core.models.edges.edge_db_queries import (
     COMMUNITY_EDGE_SAVE,
     ENTITY_EDGE_SAVE,
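Note on the typing_extensions import added above: the neo4j driver annotates query text as LiteralString, so the optional pagination fragment introduced in the hunks below has to remain a LiteralString when it is spliced into the query. A minimal standalone sketch of that composition pattern (not part of this commit; the claim about the driver's typing is an assumption):

from typing_extensions import LiteralString

def build_query(use_cursor: bool) -> LiteralString:
    # Per PEP 675, concatenating LiteralString values yields another LiteralString,
    # so the composed query still satisfies a LiteralString-typed parameter.
    cursor_query: LiteralString = 'AND e.created_at < $created_at' if use_cursor else ''
    return (
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.group_id IN $group_ids
        """
        + cursor_query
        + """
        RETURN e.uuid AS uuid
        """
    )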
@@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
     async def delete(self, driver: AsyncDriver):
         result = await driver.execute_query(
             """
-        MATCH (n)-[e {uuid: $uuid}]->(m)
+        MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
         DELETE e
         """,
             uuid=self.uuid,

@@ -137,19 +138,34 @@ class EpisodicEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
         WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             e.uuid As uuid,
             e.group_id AS group_id,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
             e.created_at AS created_at
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
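A minimal usage sketch of the new signature above (not part of this diff; driver and group_ids stand in for an existing neo4j AsyncDriver and a list of group ids). The first call takes the default page, and the next page is requested by passing the oldest created_at seen so far as the cursor:

from graphiti_core.edges import EpisodicEdge
from graphiti_core.helpers import DEFAULT_PAGE_LIMIT

async def first_two_pages(driver, group_ids):
    # First page: no cursor, so the created_at filter is left out of the query.
    page_1 = await EpisodicEdge.get_by_group_ids(driver, group_ids, limit=DEFAULT_PAGE_LIMIT)
    if not page_1:
        return page_1, []
    # Next page: only edges created strictly before the oldest edge already seen.
    cursor = min(edge.created_at for edge in page_1)
    page_2 = await EpisodicEdge.get_by_group_ids(
        driver, group_ids, limit=DEFAULT_PAGE_LIMIT, created_at=cursor
    )
    return page_1, page_2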
@@ -274,11 +290,22 @@ class EntityEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
         WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             e.uuid AS uuid,
             n.uuid AS source_node_uuid,

@@ -292,8 +319,12 @@ class EntityEdge(Edge):
             e.expired_at AS expired_at,
             e.valid_at AS valid_at,
             e.invalid_at AS invalid_at
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )

@@ -365,19 +396,34 @@ class CommunityEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
         WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             e.uuid As uuid,
             e.group_id AS group_id,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
             e.created_at AS created_at
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )

@@ -15,7 +15,7 @@ limitations under the License.
 """
 
 from abc import ABC, abstractmethod
-from typing import Iterable, List, Literal
+from collections.abc import Iterable
 
 from pydantic import BaseModel, Field
 

@@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
 
 
 class EmbedderConfig(BaseModel):
-    embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
+    embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
 
 
 class EmbedderClient(ABC):
     @abstractmethod
     async def create(
-        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         pass
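The EmbedderConfig hunk above relaxes embedding_dim from Literal[1024] to a plain int ("update embedding dims" in the commit message), so a config is no longer pinned to 1024-dimensional vectors. A small sketch, assuming EmbedderConfig is importable from the embedder client module (the exact import path is an assumption):

from graphiti_core.embedder.client import EmbedderConfig  # import path is an assumption

# The default stays 1024 via EMBEDDING_DIM, but a smaller dimension is now accepted.
config = EmbedderConfig(embedding_dim=384)
assert config.embedding_dim == 384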
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import Iterable, List
+from collections.abc import Iterable
 
 from openai import AsyncOpenAI
 from openai.types import EmbeddingModel

@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 
     async def create(
-        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         result = await self.client.embeddings.create(
             input=input_data, model=self.config.embedding_model

@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import Iterable, List
+from collections.abc import Iterable
 
 import voyageai  # type: ignore
 from pydantic import Field

@@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
         self.client = voyageai.AsyncClient(api_key=config.api_key)
 
     async def create(
-        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         if isinstance(input_data, str):
             input_list = [input_data]
-        elif isinstance(input_data, List):
+        elif isinstance(input_data, list):
             input_list = [str(i) for i in input_data if i]
         else:
             input_list = [str(i) for i in input_data if i is not None]

@@ -26,6 +26,7 @@ load_dotenv()
 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 MAX_REFLEXION_ITERATIONS = 2
+DEFAULT_PAGE_LIMIT = 20
 
 
 def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:

@@ -24,10 +24,11 @@ from uuid import uuid4
 
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
+from typing_extensions import LiteralString
 
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import NodeNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
 from graphiti_core.models.nodes.node_db_queries import (
     COMMUNITY_NODE_SAVE,
     ENTITY_NODE_SAVE,

@@ -207,10 +208,21 @@ class EpisodicNode(Node):
         return episodes
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic) WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN DISTINCT
             e.content AS content,
             e.created_at AS created_at,

@@ -220,8 +232,12 @@ class EpisodicNode(Node):
             e.group_id AS group_id,
             e.source_description AS source_description,
             e.source AS source
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )

@@ -308,10 +324,21 @@ class EntityNode(Node):
         return nodes
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity) WHERE n.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             n.uuid As uuid,
             n.name AS name,

@@ -319,8 +346,12 @@ class EntityNode(Node):
             n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
+        ORDER BY n.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )

@@ -407,10 +438,21 @@ class CommunityNode(Node):
         return communities
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community) WHERE n.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             n.uuid As uuid,
             n.name AS name,

@@ -418,8 +460,12 @@ class CommunityNode(Node):
             n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
+        ORDER BY n.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
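Taken together with the edge changes earlier in this diff, every get_by_group_ids classmethod now shares the same (limit, created_at) pagination arguments, which is what the commit message refers to as pagination for subgraphs. A rough sketch of pulling one page of each type for a set of groups in parallel (illustrative only; driver, group_ids, and cursor are placeholders):

import asyncio

from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import DEFAULT_PAGE_LIMIT
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode

async def get_subgraph_page(driver, group_ids, cursor=None, limit=DEFAULT_PAGE_LIMIT):
    # Each classmethod below accepts the same pagination arguments after this commit.
    return await asyncio.gather(
        EpisodicNode.get_by_group_ids(driver, group_ids, limit=limit, created_at=cursor),
        EntityNode.get_by_group_ids(driver, group_ids, limit=limit, created_at=cursor),
        CommunityNode.get_by_group_ids(driver, group_ids, limit=limit, created_at=cursor),
        EpisodicEdge.get_by_group_ids(driver, group_ids, limit=limit, created_at=cursor),
        EntityEdge.get_by_group_ids(driver, group_ids, limit=limit, created_at=cursor),
        CommunityEdge.get_by_group_ids(driver, group_ids, limit=limit, created_at=cursor),
    )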
@@ -40,7 +40,7 @@ from graphiti_core.nodes import (
 
 logger = logging.getLogger(__name__)
 
-RELEVANT_SCHEMA_LIMIT = 3
+RELEVANT_SCHEMA_LIMIT = 10
 DEFAULT_MIN_SCORE = 0.6
 DEFAULT_MMR_LAMBDA = 0.5
 MAX_SEARCH_DEPTH = 3

@@ -18,7 +18,6 @@ import asyncio
 import logging
 from datetime import datetime, timezone
 from time import time
-from typing import List
 
 from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS

@@ -34,11 +33,11 @@ logger = logging.getLogger(__name__)
 
 
 def build_episodic_edges(
-    entity_nodes: List[EntityNode],
+    entity_nodes: list[EntityNode],
     episode: EpisodicNode,
     created_at: datetime,
-) -> List[EpisodicEdge]:
-    edges: List[EpisodicEdge] = [
+) -> list[EpisodicEdge]:
+    edges: list[EpisodicEdge] = [
         EpisodicEdge(
             source_node_uuid=episode.uuid,
             target_node_uuid=node.uuid,

@@ -52,11 +51,11 @@ def build_episodic_edges(
 
 
 def build_community_edges(
-    entity_nodes: List[EntityNode],
+    entity_nodes: list[EntityNode],
     community_node: CommunityNode,
     created_at: datetime,
-) -> List[CommunityEdge]:
-    edges: List[CommunityEdge] = [
+) -> list[CommunityEdge]:
+    edges: list[CommunityEdge] = [
         CommunityEdge(
             source_node_uuid=community_node.uuid,
             target_node_uuid=node.uuid,

@@ -17,7 +17,6 @@ limitations under the License.
 import logging
 from datetime import datetime
 from time import time
-from typing import List
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.llm_client import LLMClient

@@ -31,7 +30,7 @@ async def extract_edge_dates(
     llm_client: LLMClient,
     edge: EntityEdge,
     current_episode: EpisodicNode,
-    previous_episodes: List[EpisodicNode],
+    previous_episodes: list[EpisodicNode],
 ) -> tuple[datetime | None, datetime | None]:
     context = {
         'edge_fact': edge.fact,