Refactor maintenance structure, add prompt library (#4)

* chore: Initial draft of stubs

* chore: Add comments and mock implementation of the add_episode method

* chore: Add success and error callbacks

* chore: Add success and error callbacks

* refactor: Fix conflicts with the latest merge
Pavlo Paliychuk, 2024-08-15 12:03:41 -04:00 (committed via GitHub)
commit f1c2224c0e, parent b728ff0f68
20 changed files with 818 additions and 125 deletions

core/__init__.py

@@ -0,0 +1,3 @@
from .graphiti import Graphiti
__all__ = ["Graphiti"]

core/edges.py

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from datetime import datetime
from neo4j import AsyncDriver
-from uuid import uuid1
+from uuid import uuid4
import logging
from core.nodes import Node
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
-    uuid: str = Field(default_factory=lambda: uuid1().hex)
+    uuid: str = Field(default_factory=lambda: str(uuid4()))
source_node: Node
target_node: Node
created_at: datetime
@@ -22,6 +22,11 @@ class Edge(BaseModel, ABC):
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid4()
logger.info(f"Created uuid: {uuid} for episodic edge")
self.uuid = str(uuid)
result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
@@ -45,13 +50,25 @@ class EpisodicEdge(Edge):
class EntityEdge(Edge):
-    name: str
-    fact: str
-    fact_embedding: list[float] = None
-    episodes: list[str] = None  # list of episode ids that reference these entity edges
-    expired_at: datetime = None  # datetime of when the node was invalidated
-    valid_at: datetime = None  # datetime of when the fact became true
-    invalid_at: datetime = None  # datetime of when the fact stopped being true
name: str = Field(description="name of the edge, relation name")
fact: str = Field(
description="fact representing the edge and nodes that it connects"
)
fact_embedding: list[float] | None = Field(
default=None, description="embedding of the fact"
)
episodes: list[str] | None = Field(
default=None, description="list of episode ids that reference these entity edges"
)
expired_at: datetime | None = Field(
default=None, description="datetime of when the node was invalidated"
)
valid_at: datetime | None = Field(
default=None, description="datetime of when the fact became true"
)
invalid_at: datetime | None = Field(
default=None, description="datetime of when the fact stopped being true"
)
def generate_embedding(self, embedder, model="text-embedding-3-large"):
text = self.fact.replace("\n", " ")
@@ -62,6 +79,7 @@ class EntityEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
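
The body of generate_embedding is truncated in this diff. For orientation, a hedged sketch of what it plausibly does given the visible first line; the embedder.create(...) call shape is an assumption (an OpenAI-style embeddings handle), not the committed code:

# Hedged sketch only; the committed body is not shown in this diff.
def generate_embedding(self, embedder, model="text-embedding-3-large"):
    text = self.fact.replace("\n", " ")
    # Assumes an OpenAI-style embeddings handle, e.g. OpenAI().embeddings
    embedding = embedder.create(input=[text], model=model).data[0].embedding
    self.fact_embedding = embedding
    return embedding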

core/graphiti.py

@@ -1,47 +1,124 @@
import asyncio
from datetime import datetime
import logging
-from typing import Callable, Tuple, LiteralString
+from typing import Callable, LiteralString, Tuple
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
import os
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge
-from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
+from core.utils import (
+    build_episodic_edges,
+    retrieve_relevant_schema,
+    extract_new_edges,
+    extract_new_nodes,
+    clear_data,
+    retrieve_episodes,
+)
+from core.llm_client import LLMClient, OpenAIClient, LLMConfig
logger = logging.getLogger(__name__)
-class LLMConfig:
-    """Configuration for the language model"""
-
-    def __init__(
-        self,
-        api_key: str,
-        model: str = "gpt-4o",
-        base_url: str = "https://api.openai.com",
-    ):
-        self.base_url = base_url
-        self.api_key = api_key
-        self.model = model
load_dotenv()
class Graphiti:
def __init__(
-        self, uri: str, user: str, password: str, llm_config: LLMConfig | None
+        self, uri: str, user: str, password: str, llm_client: LLMClient | None = None
):
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = "neo4j"
self.build_indices()
-        if llm_config:
-            self.llm_config = llm_config
+        if llm_client:
+            self.llm_client = llm_client
        else:
-            self.llm_config = None
+            self.llm_client = OpenAIClient(
+                LLMConfig(
+                    api_key=os.getenv("OPENAI_API_KEY"),
+                    model="gpt-4o",
+                    base_url="https://api.openai.com/v1",
+                )
+            )
def close(self):
self.driver.close()
    async def retrieve_episodes(
        self, last_n: int, sources: list[str] | str | None = "messages"
    ) -> list[EpisodicNode]:
        """Retrieve the last n episodic nodes from the graph"""
        return await retrieve_episodes(self.driver, last_n, sources)

    async def retrieve_relevant_schema(self, query: str | None = None) -> dict[str, any]:
        """Retrieve relevant nodes and edges to a specific query"""
        return await retrieve_relevant_schema(self.driver, query)
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
        reference_time: datetime | None = None,
episode_type="string",
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode(
name=name,
labels=[],
source="messages",
content=episode_body,
source_description=source_description,
created_at=datetime.now(),
valid_at=reference_time,
)
await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await extract_new_nodes(
self.llm_client, episode, relevant_schema, previous_episodes
)
nodes.extend(new_nodes)
new_edges = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
edges.extend(episodic_edges)
# invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes
# )
# edges.extend(invalidated_edges)
# Future optimization would be using batch operations to save nodes and edges
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)
except Exception as e:
if error_callback:
await error_callback(episode, e)
else:
raise e
async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
@@ -76,18 +153,9 @@ class Graphiti:
"""
)
-    async def retrieve_episodes(
-        self, last_n: int, sources: list[str] | None = "messages"
-    ) -> list[EpisodicNode]:
-        """Retrieve the last n episodic nodes from the graph"""
-        ...
# Utility function, to be removed from this class
async def clear_data(self): ...
async def search(
self, query: str, config
-    ) -> list[Tuple[EntityNode, list[EntityEdge]]]:
+    ) -> list[tuple[EntityNode, list[EntityEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
@@ -102,32 +170,6 @@ class Graphiti:
return [(node, edges)], episodes
-    async def get_relevant_schema(
-        self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
-    ) -> list[Tuple[EntityNode, list[EntityEdge]]]:
-        pass
-
-    # Call llm with the specified messages, and return the response
-    # Will be used in conjunction with a prompt library
-    async def generate_llm_response(self, messages: list[any]) -> str: ...
-
-    # Extract new edges from the episode
-    async def extract_new_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ) -> list[EntityEdge]: ...
-
-    # Extract new nodes from the episode
-    async def extract_new_nodes(
-        self,
-        episode: EpisodicNode,
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ) -> list[EntityNode]: ...
-
-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
@@ -137,51 +179,3 @@ class Graphiti:
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
-
-    async def add_episode(
-        self,
-        name: str,
-        episode_body: str,
-        source_description: str,
-        reference_time: datetime = None,
-        episode_type="string",
-        success_callback: Callable | None = None,
-        error_callback: Callable | None = None,
-    ):
-        """Process an episode and update the graph"""
-        try:
-            nodes: list[Node] = []
-            edges: list[Edge] = []
-            previous_episodes = await self.retrieve_episodes(last_n=3)
-            episode = EpisodicNode()
-            await episode.save(self.driver)
-            relevant_schema = await self.retrieve_relevant_schema(episode.content)
-            new_nodes = await self.extract_new_nodes(
-                episode, relevant_schema, previous_episodes
-            )
-            nodes.extend(new_nodes)
-            new_edges = await self.extract_new_edges(
-                episode, new_nodes, relevant_schema, previous_episodes
-            )
-            edges.extend(new_edges)
-            episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
-            edges.extend(episodic_edges)
-            invalidated_edges = await self.invalidate_edges(
-                episode, new_nodes, new_edges, relevant_schema, previous_episodes
-            )
-            edges.extend(invalidated_edges)
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
-            for node in nodes:
-                if isinstance(node, EntityNode):
-                    await node.update_summary(self.driver)
-            if success_callback:
-                await success_callback(episode)
-        except Exception as e:
-            if error_callback:
-                await error_callback(episode, e)
-            else:
-                raise e
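
Putting the new constructor and callback hooks together, a minimal driving sketch; it assumes a local Neo4j instance and OPENAI_API_KEY in the environment, and the handler names are illustrative:

# Illustrative usage of the new add_episode signature with callbacks.
import asyncio
from datetime import datetime, timezone

from core import Graphiti

async def on_success(episode):
    print(f"processed: {episode.name}")

async def on_error(episode, error):
    print(f"failed: {episode.name}: {error}")

async def demo():
    client = Graphiti("bolt://localhost:7687", "neo4j", "password")
    await client.add_episode(
        name="Message 0",
        episode_body="Paul: I moved to Berlin last month",
        source_description="WhatsApp Message",
        reference_time=datetime.now(timezone.utc),
        success_callback=on_success,
        error_callback=on_error,
    )

asyncio.run(demo())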

core/llm_client/__init__.py

@@ -0,0 +1,5 @@
from .client import LLMClient
from .openai_client import OpenAIClient
from .config import LLMConfig
__all__ = ["LLMClient", "OpenAIClient", "LLMConfig"]

core/llm_client/client.py

@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from .config import LLMConfig
class LLMClient(ABC):
@abstractmethod
def __init__(self, config: LLMConfig):
pass
@abstractmethod
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
pass
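
Any backend can be swapped in by satisfying this small contract. A sketch of a canned-response client that can be useful in tests (hypothetical, not part of this commit):

# Hypothetical stub client honoring the LLMClient contract.
from core.llm_client import LLMClient, LLMConfig

class StaticLLMClient(LLMClient):
    def __init__(self, config: LLMConfig, canned: dict | None = None):
        self.canned = canned or {"new_nodes": [], "new_edges": []}

    async def generate_response(self, messages: list[dict[str, str]]) -> dict:
        # Always returns the canned payload, regardless of the messages.
        return self.canned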

core/llm_client/config.py

@@ -0,0 +1,33 @@
class LLMConfig:
"""
    Configuration class for the Large Language Model (LLM) client.
This class encapsulates the necessary parameters to interact with an LLM API,
such as OpenAI's GPT models. It stores the API key, model name, and base URL
for making requests to the LLM service.
"""
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
        base_url: str = "https://api.openai.com/v1",
):
"""
Initialize the LLMConfig with the provided parameters.
Args:
api_key (str): The authentication key for accessing the LLM API.
This is required for making authorized requests.
            model (str, optional): The specific LLM model to use for generating responses.
                Defaults to "gpt-4o". Other common values include "gpt-4" and
                "gpt-3.5-turbo".
            base_url (str, optional): The base URL of the LLM API service.
                Defaults to "https://api.openai.com/v1", OpenAI's standard API endpoint.
                This can be changed if using a different provider or a custom endpoint.
"""
self.base_url = base_url
self.api_key = api_key
self.model = model

core/llm_client/openai_client.py

@@ -0,0 +1,24 @@
import json
from openai import AsyncOpenAI
from .client import LLMClient
from .config import LLMConfig
class OpenAIClient(LLMClient):
def __init__(self, config: LLMConfig):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
self.model = config.model
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.1,
max_tokens=3000,
response_format={"type": "json_object"},
)
return json.loads(response.choices[0].message.content)
except Exception as e:
print(f"Error in generating LLM response: {e}")
raise
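
Note that with response_format={"type": "json_object"}, the OpenAI API requires the word "JSON" to appear somewhere in the messages; the prompts in this PR satisfy that with their "Respond with a JSON object" instruction. A minimal usage sketch, assuming a valid OPENAI_API_KEY:

# Illustrative call; message content must mention JSON for json_object mode.
import asyncio
import os

from core.llm_client import LLMConfig, OpenAIClient

async def demo():
    client = OpenAIClient(
        LLMConfig(api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.openai.com/v1")
    )
    messages = [
        {"role": "system", "content": "You extract entities. Respond with a JSON object."},
        {"role": "user", "content": "Paul loves apples."},
    ]
    print(await client.generate_response(messages))

asyncio.run(demo())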

core/nodes.py

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from pydantic import Field
from datetime import datetime
-from uuid import uuid1
+from uuid import uuid4
from openai import OpenAI
from pydantic import BaseModel, Field
@@ -11,9 +12,9 @@ logger = logging.getLogger(__name__)
class Node(BaseModel, ABC):
-    uuid: str = Field(default_factory=lambda: uuid1().hex)
+    uuid: str = Field(default_factory=lambda: str(uuid4()))
name: str
-    labels: list[str]
+    labels: list[str] = Field(default_factory=list)
created_at: datetime
@abstractmethod
@@ -21,11 +22,17 @@
class EpisodicNode(Node):
-    source: str  # source type
-    source_description: str  # description of the data source
-    content: str  # raw episode data
-    entity_edges: list[str]  # list of entity edge ids referenced in this episode
-    valid_at: datetime = None  # datetime of when the original document was created
source: str = Field(description="source type")
source_description: str = Field(description="description of the data source")
content: str = Field(description="raw episode data")
entity_edges: list[str] = Field(
description="list of entity edges referenced in this episode",
default_factory=list,
)
valid_at: datetime | None = Field(
description="datetime of when the original document was created",
default=None,
)
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
@@ -51,7 +58,9 @@
class EntityNode(Node):
-    summary: str  # regional summary of surrounding edges
+    summary: str = Field(description="regional summary of surrounding edges")
async def update_summary(self, driver: AsyncDriver): ...
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
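
With the new defaults, uuid, labels, and entity_edges no longer need to be passed explicitly; a quick construction sketch:

# Minimal EpisodicNode construction relying on the new Field defaults.
from datetime import datetime, timezone

from core.nodes import EpisodicNode

episode = EpisodicNode(
    name="Message 1",
    source="messages",
    source_description="WhatsApp Message",
    content="Paul: I love apples",
    created_at=datetime.now(timezone.utc),
)
print(episode.uuid, episode.labels, episode.entity_edges)  # random uuid4 string, [], []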

core/prompts/__init__.py

@@ -0,0 +1,4 @@
from .lib import prompt_library
from .models import Message
__all__ = ["prompt_library", "Message"]

core/prompts/extract_edges.py

@@ -0,0 +1,73 @@
import json
from typing import TypedDict, Protocol
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that extracts graph edges from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph:
Current Graph Structure:
{context['relevant_schema']}
New Nodes:
{json.dumps(context['new_nodes'], indent=2)}
New Episode:
Content: {context['episode_content']}
Timestamp: {context['episode_timestamp']}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes.
Guidelines:
1. Create edges only between semantic nodes (not episodic nodes like messages).
2. Each edge should represent a clear relationship between two semantic nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. If a relationship seems to update an existing one, create a new edge with the updated information.
6. Consider temporal aspects of relationships when relevant.
7. Do not create edges involving episodic nodes (like Message 1 or Message 2).
8. Use existing nodes from the current graph structure when appropriate.
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"relation_type": "RELATION_TYPE_IN_CAPS",
"source_node": "Name of the source semantic node",
"target_node": "Name of the target semantic node",
"fact": "Detailed description of the relationship",
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
}}
]
}}
If no new edges need to be added, return an empty list for "new_edges".
""",
),
]
versions: Versions = {
"v1": v1,
}
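
A quick way to eyeball the rendered prompt is to call the version function directly with a toy context (values are illustrative):

# Hypothetical smoke test for the v1 edge-extraction prompt.
from core.prompts.extract_edges import versions

messages = versions["v1"]({
    "relevant_schema": "{}",
    "new_nodes": [{"name": "Paul", "summary": "A person"}],
    "episode_content": "Paul: I love apples",
    "episode_timestamp": "2024-08-15T12:00:00+00:00",
    "previous_episodes": [],
})
print(messages[0].role)          # "system"
print(messages[1].content[:60])  # start of the rendered user prompt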

core/prompts/extract_nodes.py

@@ -0,0 +1,65 @@
import json
from typing import TypedDict, Protocol
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that extracts graph nodes from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new semantic nodes that need to be added to the knowledge graph:
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
New Episode:
Content: {context["episode_content"]}
Timestamp: {context['episode_timestamp']}
Extract new semantic nodes based on the content of the current episode, while considering the existing nodes and context from previous episodes.
Guidelines:
1. Only extract new nodes that don't already exist in the graph structure.
2. Focus on entities, concepts, or actors that are central to the current episode.
3. Avoid creating nodes for relationships or actions (these will be handled as edges later).
4. Provide a brief but informative summary for each node.
5. If a node seems to represent an existing concept but with updated information, don't create a new node. This will be handled by edge updates.
6. Do not create nodes for episodic content (like Message 1 or Message 2).
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node",
"labels": ["Semantic", "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
If no new nodes need to be added, return an empty list for "new_nodes".
""",
),
]
versions: Versions = {
"v1": v1,
}

core/prompts/lib.py

@@ -0,0 +1,53 @@
from typing import TypedDict, Protocol
from .models import Message, PromptFunction
from .extract_nodes import (
Prompt as ExtractNodesPrompt,
Versions as ExtractNodesVersions,
versions as extract_nodes_versions,
)
from .extract_edges import (
Prompt as ExtractEdgesPrompt,
Versions as ExtractEdgesVersions,
versions as extract_edges_versions,
)
class PromptLibrary(Protocol):
extract_nodes: ExtractNodesPrompt
extract_edges: ExtractEdgesPrompt
class PromptLibraryImpl(TypedDict):
extract_nodes: ExtractNodesVersions
extract_edges: ExtractEdgesVersions
class VersionWrapper:
def __init__(self, func: PromptFunction):
self.func = func
def __call__(self, context: dict[str, any]) -> list[Message]:
return self.func(context)
class PromptTypeWrapper:
def __init__(self, versions: dict[str, PromptFunction]):
for version, func in versions.items():
setattr(self, version, VersionWrapper(func))
class PromptLibraryWrapper:
def __init__(self, library: PromptLibraryImpl):
for prompt_type, versions in library.items():
setattr(self, prompt_type, PromptTypeWrapper(versions))
PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
"extract_nodes": extract_nodes_versions,
"extract_edges": extract_edges_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
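
The wrapper turns the nested dicts into attribute access, so call sites read prompt_library.<prompt_type>.<version>(context). For example:

# Attribute-style access provided by PromptLibraryWrapper.
from core.prompts import prompt_library

context = {
    "existing_nodes": [],
    "previous_episodes": [],
    "episode_content": "Paul: I love apples",
    "episode_timestamp": "2024-08-15T12:00:00+00:00",
}
messages = prompt_library.extract_nodes.v1(context)
print([m.role for m in messages])  # ['system', 'user']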

core/prompts/models.py

@@ -0,0 +1,15 @@
from typing import Callable, Protocol
from pydantic import BaseModel
class Message(BaseModel):
role: str
content: str
class PromptVersion(Protocol):
def __call__(self, context: dict[str, any]) -> list[Message]: ...
PromptFunction = Callable[[dict[str, any]], list[Message]]

core/utils/__init__.py

@@ -0,0 +1,17 @@
from .maintenance import (
extract_new_edges,
build_episodic_edges,
extract_new_nodes,
clear_data,
retrieve_relevant_schema,
retrieve_episodes,
)
__all__ = [
"extract_new_edges",
"build_episodic_edges",
"extract_new_nodes",
"clear_data",
"retrieve_relevant_schema",
"retrieve_episodes",
]

core/utils/maintenance/__init__.py

@@ -0,0 +1,16 @@
from .edge_operations import extract_new_edges, build_episodic_edges
from .node_operations import extract_new_nodes
from .graph_data_operations import (
clear_data,
retrieve_relevant_schema,
retrieve_episodes,
)
__all__ = [
"extract_new_edges",
"build_episodic_edges",
"extract_new_nodes",
"clear_data",
"retrieve_relevant_schema",
"retrieve_episodes",
]

core/utils/maintenance/edge_operations.py

@@ -0,0 +1,128 @@
import json
from typing import List
from datetime import datetime
from core.nodes import EntityNode, EpisodicNode
from core.edges import EpisodicEdge, EntityEdge
import logging
from core.prompts import prompt_library
from core.llm_client import LLMClient
logger = logging.getLogger(__name__)
def build_episodic_edges(
semantic_nodes: List[EntityNode],
episode: EpisodicNode,
transaction_from: datetime,
) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = []
for node in semantic_nodes:
edge = EpisodicEdge(
source_node=episode, target_node=node, created_at=transaction_from
)
edges.append(edge)
return edges
async def extract_new_edges(
llm_client: LLMClient,
episode: EpisodicNode,
new_nodes: list[EntityNode],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
# Prepare context for LLM
context = {
"episode_content": episode.content,
"episode_timestamp": (
episode.valid_at.isoformat() if episode.valid_at else None
),
"relevant_schema": json.dumps(relevant_schema, indent=2),
"new_nodes": [
{"name": node.name, "summary": node.summary} for node in new_nodes
],
"previous_episodes": [
{
"content": ep.content,
"timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_edges.v1(context)
)
new_edges_data = llm_response.get("new_edges", [])
# Convert the extracted data into EntityEdge objects
new_edges = []
for edge_data in new_edges_data:
source_node = next(
(node for node in new_nodes if node.name == edge_data["source_node"]),
None,
)
target_node = next(
(node for node in new_nodes if node.name == edge_data["target_node"]),
None,
)
# If source or target is not in new_nodes, check if it's an existing node
if source_node is None and edge_data["source_node"] in relevant_schema["nodes"]:
existing_node_data = relevant_schema["nodes"][edge_data["source_node"]]
source_node = EntityNode(
uuid=existing_node_data["uuid"],
name=edge_data["source_node"],
labels=[existing_node_data["label"]],
summary="",
created_at=datetime.now(),
)
if target_node is None and edge_data["target_node"] in relevant_schema["nodes"]:
existing_node_data = relevant_schema["nodes"][edge_data["target_node"]]
target_node = EntityNode(
uuid=existing_node_data["uuid"],
name=edge_data["target_node"],
labels=[existing_node_data["label"]],
summary="",
created_at=datetime.now(),
)
if (
source_node
and target_node
and not (
source_node.name.startswith("Message")
or target_node.name.startswith("Message")
)
):
            # Strip a trailing "Z" so datetime.fromisoformat can parse the
            # prompt's "YYYY-MM-DDTHH:MM:SSZ" timestamps on Python < 3.11.
            valid_at = (
                datetime.fromisoformat(edge_data["valid_at"].replace("Z", "+00:00"))
                if edge_data["valid_at"]
                else episode.valid_at or datetime.now()
            )
            invalid_at = (
                datetime.fromisoformat(edge_data["invalid_at"].replace("Z", "+00:00"))
                if edge_data["invalid_at"]
                else None
            )
new_edge = EntityEdge(
source_node=source_node,
target_node=target_node,
name=edge_data["relation_type"],
fact=edge_data["fact"],
episodes=[episode.uuid],
created_at=datetime.now(),
valid_at=valid_at,
invalid_at=invalid_at,
)
new_edges.append(new_edge)
logger.info(
f"Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})"
)
return new_edges
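
For reference, extract_new_edges consumes llm_response entries shaped by the extract_edges prompt; a representative payload with illustrative values:

# Illustrative LLM output consumed by extract_new_edges.
llm_response = {
    "new_edges": [
        {
            "relation_type": "LOVES",
            "source_node": "Paul",
            "target_node": "Apples",
            "fact": "Paul loves apples",
            "valid_at": "2024-08-15T12:00:00+00:00",
            "invalid_at": None,
        }
    ]
}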

core/utils/maintenance/graph_data_operations.py

@@ -0,0 +1,95 @@
from datetime import datetime, timezone
from core.nodes import EpisodicNode
from neo4j import AsyncDriver
import logging
logger = logging.getLogger(__name__)
async def clear_data(driver: AsyncDriver):
async with driver.session() as session:
async def delete_all(tx):
await tx.run("MATCH (n) DETACH DELETE n")
await session.execute_write(delete_all)
async def retrieve_relevant_schema(
    driver: AsyncDriver, query: str | None = None
) -> dict[str, any]:
async with driver.session() as session:
summary_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]->(m)
RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
"""
result = await session.run(summary_query)
records = [record async for record in result]
schema = {"nodes": {}, "relationships": []}
for record in records:
node_label = record["node_labels"][0] # Assuming one label per node
node_uuid = record["node_uuid"]
node_name = record["node_name"]
rel_type = record["relationship_type"]
rel_name = record["relationship_name"]
related_node = record["related_node_name"]
if node_name not in schema["nodes"]:
schema["nodes"][node_name] = {
"uuid": node_uuid,
"label": node_label,
"relationships": [],
}
if rel_type and related_node:
schema["nodes"][node_name]["relationships"].append(
{"type": rel_type, "name": rel_name, "target": related_node}
)
schema["relationships"].append(
{
"source": node_name,
"type": rel_type,
"name": rel_name,
"target": related_node,
}
)
return schema
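
retrieve_relevant_schema flattens the graph into a name-keyed dict. A representative return value for a one-edge graph, with illustrative values:

# Illustrative shape of retrieve_relevant_schema's return value.
schema = {
    "nodes": {
        "Paul": {
            "uuid": "<uuid4>",
            "label": "Entity",
            "relationships": [{"type": "LOVES", "name": "LOVES", "target": "Apples"}],
        }
    },
    "relationships": [
        {"source": "Paul", "type": "LOVES", "name": "LOVES", "target": "Apples"}
    ],
}
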
async def retrieve_episodes(
    driver: AsyncDriver, last_n: int, sources: list[str] | str | None = "messages"
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
async with driver.session() as session:
query = """
MATCH (e:EpisodicNode)
RETURN e.content as text, e.timestamp as timestamp, e.reference_timestamp as reference_timestamp
ORDER BY e.timestamp DESC
LIMIT $num_episodes
"""
result = await session.run(query, num_episodes=last_n)
episodes = [
EpisodicNode(
content=record["text"],
transaction_from=datetime.fromtimestamp(
record["timestamp"].to_native().timestamp(), timezone.utc
),
valid_at=(
datetime.fromtimestamp(
record["reference_timestamp"].to_native().timestamp(),
timezone.utc,
)
if record["reference_timestamp"] is not None
else None
),
)
async for record in result
]
return list(reversed(episodes)) # Return in chronological order

core/utils/maintenance/node_operations.py

@@ -0,0 +1,63 @@
from datetime import datetime
from core.nodes import EntityNode, EpisodicNode
import logging
from core.llm_client import LLMClient
from core.prompts import prompt_library
logger = logging.getLogger(__name__)
async def extract_new_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
# Prepare context for LLM
existing_nodes = [
{"name": node_name, "label": node_info["label"], "uuid": node_info["uuid"]}
for node_name, node_info in relevant_schema["nodes"].items()
]
context = {
"episode_content": episode.content,
"episode_timestamp": (
episode.valid_at.isoformat() if episode.valid_at else None
),
"existing_nodes": existing_nodes,
"previous_episodes": [
{
"content": ep.content,
"timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.v1(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Extracted new nodes: {new_nodes_data}")
# Convert the extracted data into EntityNode objects
new_nodes = []
for node_data in new_nodes_data:
# Check if the node already exists
if not any(
existing_node["name"] == node_data["name"]
for existing_node in existing_nodes
):
new_node = EntityNode(
name=node_data["name"],
labels=node_data["labels"],
summary=node_data["summary"],
created_at=datetime.now(),
)
new_nodes.append(new_node)
logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
else:
logger.info(f"Node {node_data['name']} already exists, skipping creation.")
return new_nodes
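
As with edges, extract_new_nodes consumes entries shaped by the extract_nodes prompt; a representative payload with illustrative values:

# Illustrative LLM output consumed by extract_new_nodes.
llm_response = {
    "new_nodes": [
        {
            "name": "Paul",
            "labels": ["Semantic", "Person"],
            "summary": "Person who loves apples",
        }
    ]
}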


runner.py

@@ -0,0 +1,66 @@
from core import Graphiti
from core.utils.maintenance.graph_data_operations import clear_data
from dotenv import load_dotenv
import os
import asyncio
import logging
import sys
load_dotenv()
neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687"
neo4j_user = os.environ.get("NEO4J_USER") or "neo4j"
neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password"
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
# await client.build_indices()
await client.add_episode(
name="Message 1",
episode_body="Paul: I love apples",
source_description="WhatsApp Message",
)
await client.add_episode(
name="Message 2",
episode_body="Paul: I love bananas",
source_description="WhatsApp Message",
)
await client.add_episode(
name="Message 3",
episode_body="Assistant: The best type of apples available are Fuji apples",
source_description="WhatsApp Message",
)
await client.add_episode(
name="Message 4",
episode_body="Paul: Oh, I actually hate those",
source_description="WhatsApp Message",
)
asyncio.run(main())