renaming and add indices (#3)

rename and add indices
This commit is contained in:
Preston Rasmussen 2024-08-15 11:04:57 -04:00 committed by GitHub
parent 83c7640d9c
commit b728ff0f68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 181 additions and 55 deletions

View File

@ -11,10 +11,10 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
uuid: Field(default_factory=lambda: uuid1().hex)
uuid: str = Field(default_factory=lambda: uuid1().hex)
source_node: Node
target_node: Node
transaction_from: datetime
created_at: datetime
@abstractmethod
async def save(self, driver: AsyncDriver): ...
@ -25,14 +25,14 @@ class EpisodicEdge(Edge):
result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Semantic {uuid: $semantic_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, transaction_from: $transaction_from}
SET r = {uuid: $uuid, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node.uuid,
semantic_uuid=self.target_node.uuid,
entity_uuid=self.target_node.uuid,
uuid=self.uuid,
transaction_from=self.transaction_from,
created_at=self.created_at,
)
logger.info(f"Saved edge to neo4j: {self.uuid}")
@ -44,14 +44,14 @@ class EpisodicEdge(Edge):
# Right now we have all edge nodes as type RELATES_TO
class SemanticEdge(Edge):
class EntityEdge(Edge):
name: str
fact: str
fact_embedding: list[int] = None
episodes: list[str] = None # list of episodes that reference these semantic edges
transaction_to: datetime = None # datetime of when the node was invalidated
valid_from: datetime = None # datetime of when the fact became true
valid_to: datetime = None # datetime of when the fact stopped being true
fact_embedding: list[float] = None
episodes: list[str] = None # list of episode ids that reference these entity edges
expired_at: datetime = None # datetime of when the node was invalidated
valid_at: datetime = None # datetime of when the fact became true
invalid_at: datetime = None # datetime of when the fact stopped being true
def generate_embedding(self, embedder, model="text-embedding-3-large"):
text = self.fact.replace("\n", " ")
@ -63,12 +63,12 @@ class SemanticEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (source:Semantic {uuid: $source_uuid})
MATCH (target:Semantic {uuid: $target_uuid})
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to,
valid_from: $valid_from, valid_to: $valid_to}
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node.uuid,
target_uuid=self.target_node.uuid,
@ -77,10 +77,10 @@ class SemanticEdge(Edge):
fact=self.fact,
fact_embedding=self.fact_embedding,
episodes=self.episodes,
transaction_from=self.transaction_from,
transaction_to=self.transaction_to,
valid_from=self.valid_from,
valid_to=self.valid_to,
created_at=self.created_at,
expired_at=self.expired_at,
valid_at=self.valid_at,
invalid_at=self.invalid_at,
)
logger.info(f"Saved Node to neo4j: {self.uuid}")

View File

@ -1,11 +1,11 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, Tuple
from typing import Callable, Tuple, LiteralString
from neo4j import AsyncGraphDatabase
from core.nodes import SemanticNode, EpisodicNode, Node
from core.edges import SemanticEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
logger = logging.getLogger(__name__)
@ -31,6 +31,9 @@ class Graphiti:
):
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = "neo4j"
self.build_indices()
if llm_config:
self.llm_config = llm_config
else:
@ -39,6 +42,40 @@ class Graphiti:
def close(self):
    """Close the underlying Neo4j driver.

    ``self.driver`` is created with ``AsyncGraphDatabase.driver``, so its
    ``close()`` call yields an awaitable rather than closing immediately.
    That awaitable is returned so the caller can ``await graphiti.close()``;
    discarding it (as before) leaves the driver open.

    Returns:
        Whatever ``self.driver.close()`` returns (an awaitable for the
        async driver).
    """
    return self.driver.close()
async def build_indices(self):
    """Create the range, fulltext and vector indices in Neo4j.

    Every statement carries ``IF NOT EXISTS``. The statements are sent to
    ``self.driver.execute_query`` one at a time, in a fixed order: the nine
    range indices first, then the fulltext index over entity name/summary,
    then the vector index over ``RELATES_TO.fact_embedding``.
    """
    # Range indices on node and relationship properties.
    range_statements: list[LiteralString] = [
        "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
        "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
        "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
        "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
        "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
        "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
        "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
        "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
        "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
    ]

    # Fulltext index over entity names and summaries.
    fulltext_statement = (
        "\n"
        "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]\n"
    )

    # Vector index over the fact embeddings stored on RELATES_TO edges.
    # NOTE(review): 1024 must match the length of the stored embeddings --
    # confirm against the embedding code.
    vector_statement = (
        "\n"
        "CREATE VECTOR INDEX fact_embedding IF NOT EXISTS\n"
        "FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)\n"
        "OPTIONS {indexConfig: {\n"
        "`vector.dimensions`: 1024,\n"
        "`vector.similarity_function`: 'cosine'\n"
        "}}\n"
    )

    for statement in (*range_statements, fulltext_statement, vector_statement):
        await self.driver.execute_query(statement)
async def retrieve_episodes(
self, last_n: int, sources: list[str] | None = "messages"
) -> list[EpisodicNode]:
@ -48,8 +85,9 @@ class Graphiti:
# Utility function, to be removed from this class
async def clear_data(self): ...
async def search(self, query: str, config) -> (
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
async def search(
self, query: str, config
) -> (list)[Tuple[EntityNode, list[EntityEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
@ -64,8 +102,9 @@ class Graphiti:
return [(node, edges)], episodes
async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> (
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
async def get_relevant_schema(
self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[Tuple[EntityNode, list[EntityEdge]]]:
pass
# Call llm with the specified messages, and return the response
@ -76,10 +115,10 @@ class Graphiti:
async def extract_new_edges(
self,
episode: EpisodicNode,
new_nodes: list[SemanticNode],
new_nodes: list[EntityNode],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[SemanticEdge]: ...
) -> list[EntityEdge]: ...
# Extract new nodes from the episode
async def extract_new_nodes(
@ -87,14 +126,14 @@ class Graphiti:
episode: EpisodicNode,
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[SemanticNode]: ...
) -> list[EntityNode]: ...
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[SemanticNode],
new_edges: list[SemanticEdge],
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
@ -137,7 +176,7 @@ class Graphiti:
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
for node in nodes:
if isinstance(node, SemanticNode):
if isinstance(node, EntityNode):
await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)

View File

@ -11,10 +11,10 @@ logger = logging.getLogger(__name__)
class Node(BaseModel, ABC):
uuid: Field(default_factory=lambda: uuid1().hex)
uuid: str = Field(default_factory=lambda: uuid1().hex)
name: str
labels: list[str]
transaction_from: datetime
created_at: datetime
@abstractmethod
async def save(self, driver: AsyncDriver): ...
@ -24,23 +24,23 @@ class EpisodicNode(Node):
source: str # source type
source_description: str # description of the data source
content: str # raw episode data
semantic_edges: list[str] # list of semantic edges referenced in this episode
valid_from: datetime = None # datetime of when the original document was created
entity_edges: list[str] # list of entity edge ids referenced in this episode
valid_at: datetime = None # datetime of when the original document was created
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
source_description=self.source_description,
content=self.content,
semantic_edges=self.semantic_edges,
transaction_from=self.transaction_from,
valid_from=self.valid_from,
entity_edges=self.entity_edges,
created_at=self.created_at,
valid_at=self.valid_at,
_database="neo4j",
)
@ -50,7 +50,7 @@ class EpisodicNode(Node):
return result
class SemanticNode(Node):
class EntityNode(Node):
summary: str # regional summary of surrounding edges
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
@ -58,13 +58,13 @@ class SemanticNode(Node):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Semantic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
summary=self.summary,
transaction_from=self.transaction_from,
created_at=self.created_at,
)
logger.info(f"Saved Node to neo4j: {self.uuid}")

View File

@ -1,12 +1,12 @@
from typing import Tuple
from core.edges import EpisodicEdge, SemanticEdge, Edge
from core.nodes import SemanticNode, EpisodicNode, Node
from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node
async def bfs(
nodes: list[Node], edges: list[Edge], k: int
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# Breadth first search over nodes and edges with desired depth
@ -14,7 +14,7 @@ async def bfs(
async def similarity_search(
query: str, embedder
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# vector similarity search over embedded facts
@ -22,23 +22,23 @@ async def similarity_search(
async def fulltext_search(
query: str,
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# fulltext search over names and summary
def build_episodic_edges(
semantic_nodes: list[SemanticNode], episode: EpisodicNode
entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in semantic_nodes:
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node=episode,
target_node=node,
transaction_from=episode.transaction_from,
created_at=episode.created_at,
)
)

5
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@ -2186,6 +2186,7 @@ description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"},
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"},
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"},
]
@ -4937,4 +4938,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "cefc4469afc33f38b93547ee72ed623000f15faae3889d432a12ddcb33643848"
content-hash = "142d26cbdbf9c07019dfdb8599b70e8efb9c3842a3c95588d1f59b9c187e44ba"

View File

@ -23,6 +23,7 @@ python-dotenv = "^1.0.1"
pandas = "^2.2.2"
pytest-asyncio = "^0.23.8"
pytest-xdist = "^3.6.1"
pytest = "^8.3.2"
[build-system]

View File

@ -0,0 +1,85 @@
import os
import pytest
import asyncio
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from openai import OpenAI
from core.edges import EpisodicEdge, EntityEdge
from core.graphiti import Graphiti
from core.nodes import EpisodicNode, EntityNode
from datetime import datetime
pytest_plugins = ("pytest_asyncio",)
load_dotenv()
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4j_USER = os.getenv("NEO4J_USER")
NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")
@pytest.mark.asyncio
async def test_graphiti_init():
    """Smoke test: construct a Graphiti client and build its indices.

    Connection settings come from the module-level NEO4J_* values, so a
    reachable Neo4j instance is required.
    """
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
    try:
        await graphiti.build_indices()
    finally:
        # Release the client even when index creation fails.
        graphiti.close()
@pytest.mark.asyncio
async def test_graph_integration():
    """Persist one episode, two entities and the edges that connect them.

    Builds an episodic node, the "Alice" and "Bob" entity nodes, one
    MENTIONS-style episodic edge per entity and a single "likes" entity edge,
    then saves all nodes followed by all edges through a Neo4j async driver.
    Connection settings come from the module-level NEO4J_* values.
    """
    embedder = OpenAI().embeddings
    now = datetime.now()

    episode = EpisodicNode(
        name="test_episode",
        labels=[],
        created_at=now,
        source="message",
        source_description="conversation message",
        content="Alice likes Bob",
        entity_edges=[],
    )
    alice_node = EntityNode(
        name="Alice",
        labels=[],
        created_at=now,
        summary="Alice summary",
    )
    bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")

    episodic_edge_1 = EpisodicEdge(
        source_node=episode, target_node=alice_node, created_at=now
    )
    episodic_edge_2 = EpisodicEdge(
        source_node=episode, target_node=bob_node, created_at=now
    )
    entity_edge = EntityEdge(
        source_node=alice_node,
        target_node=bob_node,
        created_at=now,
        name="likes",
        fact="Alice likes Bob",
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
    )
    entity_edge.generate_embedding(embedder)

    nodes = [episode, alice_node, bob_node]
    edges = [episodic_edge_1, episodic_edge_2, entity_edge]

    # The async context manager closes the driver even if a save fails;
    # previously the driver was never closed.
    async with AsyncGraphDatabase.driver(
        NEO4J_URI,
        auth=(NEO4j_USER, NEO4j_PASSWORD),
    ) as driver:
        # Nodes must exist before the edges that MATCH on them are saved.
        await asyncio.gather(*[node.save(driver) for node in nodes])
        await asyncio.gather(*[edge.save(driver) for edge in edges])