Mirror of https://github.com/getzep/graphiti.git, synced 2025-12-28 07:33:30 +00:00
Refactor maintenance structure, add prompt library (#4)
* chore: Initial draft of stubs
* chore: Add comments and mock implementation of the add_episode method
* chore: Add success and error callbacks
* chore: Add success and error callbacks
* refactor: Fix conflicts with the latest merge
This commit is contained in:
parent b728ff0f68 · commit f1c2224c0e
core/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from .graphiti import Graphiti
+
+__all__ = ["Graphiti"]
core/edges.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from pydantic import BaseModel, Field
 from datetime import datetime
 from neo4j import AsyncDriver
-from uuid import uuid1
+from uuid import uuid4
 import logging
 
 from core.nodes import Node
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 
 
 class Edge(BaseModel, ABC):
-    uuid: str = Field(default_factory=lambda: uuid1().hex)
+    uuid: str = Field(default_factory=lambda: str(uuid4()))
     source_node: Node
    target_node: Node
     created_at: datetime
@@ -22,6 +22,11 @@ class Edge(BaseModel, ABC):
 
 class EpisodicEdge(Edge):
     async def save(self, driver: AsyncDriver):
+        if self.uuid is None:
+            uuid = uuid4()
+            logger.info(f"Created uuid: {uuid} for episodic edge")
+            self.uuid = str(uuid)
+
         result = await driver.execute_query(
             """
             MATCH (episode:Episodic {uuid: $episode_uuid})
@@ -45,13 +50,25 @@ class EpisodicEdge(Edge):
 
 
 class EntityEdge(Edge):
-    name: str
-    fact: str
-    fact_embedding: list[float] = None
-    episodes: list[str] = None  # list of episode ids that reference these entity edges
-    expired_at: datetime = None  # datetime of when the node was invalidated
-    valid_at: datetime = None  # datetime of when the fact became true
-    invalid_at: datetime = None  # datetime of when the fact stopped being true
+    name: str = Field(description="name of the edge, relation name")
+    fact: str = Field(
+        description="fact representing the edge and nodes that it connects"
+    )
+    fact_embedding: list[float] | None = Field(
+        default=None, description="embedding of the fact"
+    )
+    episodes: list[str] | None = Field(
+        default=None, description="list of episode ids that reference these entity edges"
+    )
+    expired_at: datetime | None = Field(
+        default=None, description="datetime of when the node was invalidated"
+    )
+    valid_at: datetime | None = Field(
+        default=None, description="datetime of when the fact became true"
+    )
+    invalid_at: datetime | None = Field(
+        default=None, description="datetime of when the fact stopped being true"
+    )
 
     def generate_embedding(self, embedder, model="text-embedding-3-large"):
         text = self.fact.replace("\n", " ")
@@ -62,6 +79,7 @@ class EntityEdge(Edge):
 
     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
+
             """
             MATCH (source:Entity {uuid: $source_uuid})
             MATCH (target:Entity {uuid: $target_uuid})
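The switch to explicit pydantic Fields above makes every optional attribute default to None, so a bare relationship needs only a name and a fact. A minimal construction sketch (not part of the commit; the example nodes and values are illustrative):

    # Hypothetical usage of the new EntityEdge schema shown in the diff above.
    from datetime import datetime

    from core.edges import EntityEdge
    from core.nodes import EntityNode

    now = datetime.now()
    paul = EntityNode(name="Paul", summary="A person", created_at=now)
    apples = EntityNode(name="Apples", summary="A fruit", created_at=now)

    edge = EntityEdge(
        source_node=paul,
        target_node=apples,
        name="LOVES",
        fact="Paul loves apples",
        created_at=now,
    )
    print(edge.uuid)            # random UUID4 string from the default_factory
    print(edge.fact_embedding)  # None until generate_embedding is called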
core/graphiti.py (208 lines changed)
@@ -1,47 +1,124 @@
 import asyncio
 from datetime import datetime
 import logging
-from typing import Callable, Tuple, LiteralString
+from typing import Callable, LiteralString, Tuple
 from neo4j import AsyncGraphDatabase
 
 from dotenv import load_dotenv
 import os
 from core.nodes import EntityNode, EpisodicNode, Node
 from core.edges import EntityEdge, Edge
-from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
+from core.utils import (
+    build_episodic_edges,
+    retrieve_relevant_schema,
+    extract_new_edges,
+    extract_new_nodes,
+    clear_data,
+    retrieve_episodes,
+)
+from core.llm_client import LLMClient, OpenAIClient, LLMConfig
 
 logger = logging.getLogger(__name__)
 
 
-class LLMConfig:
-    """Configuration for the language model"""
-
-    def __init__(
-        self,
-        api_key: str,
-        model: str = "gpt-4o",
-        base_url: str = "https://api.openai.com",
-    ):
-        self.base_url = base_url
-        self.api_key = api_key
-        self.model = model
+load_dotenv()
 
 
 class Graphiti:
     def __init__(
-        self, uri: str, user: str, password: str, llm_config: LLMConfig | None
+        self, uri: str, user: str, password: str, llm_client: LLMClient | None = None
     ):
         self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
         self.database = "neo4j"
 
         self.build_indices()
 
-        if llm_config:
-            self.llm_config = llm_config
+        if llm_client:
+            self.llm_client = llm_client
         else:
-            self.llm_config = None
+            self.llm_client = OpenAIClient(
+                LLMConfig(
+                    api_key=os.getenv("OPENAI_API_KEY"),
+                    model="gpt-4o",
+                    base_url="https://api.openai.com/v1",
+                )
+            )
 
     def close(self):
         self.driver.close()
 
+    async def retrieve_episodes(
+        self, last_n: int, sources: list[str] | None = "messages"
+    ) -> list[EpisodicNode]:
+        """Retrieve the last n episodic nodes from the graph"""
+        return await retrieve_episodes(self.driver, last_n, sources)
+
+    async def retrieve_relevant_schema(self, query: str = None) -> dict[str, any]:
+        """Retrieve relevant nodes and edges to a specific query"""
+        return await retrieve_relevant_schema(self.driver, query)
+        ...
+
+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[EntityNode],
+        new_edges: list[EntityEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...
+
+    async def add_episode(
+        self,
+        name: str,
+        episode_body: str,
+        source_description: str,
+        reference_time: datetime = None,
+        episode_type="string",
+        success_callback: Callable | None = None,
+        error_callback: Callable | None = None,
+    ):
+        """Process an episode and update the graph"""
+        try:
+            nodes: list[Node] = []
+            edges: list[Edge] = []
+            previous_episodes = await self.retrieve_episodes(last_n=3)
+            episode = EpisodicNode(
+                name=name,
+                labels=[],
+                source="messages",
+                content=episode_body,
+                source_description=source_description,
+                created_at=datetime.now(),
+                valid_at=reference_time,
+            )
+            await episode.save(self.driver)
+            relevant_schema = await self.retrieve_relevant_schema(episode.content)
+            new_nodes = await extract_new_nodes(
+                self.llm_client, episode, relevant_schema, previous_episodes
+            )
+            nodes.extend(new_nodes)
+            new_edges = await extract_new_edges(
+                self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
+            )
+            edges.extend(new_edges)
+            episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
+            edges.extend(episodic_edges)
+
+            # invalidated_edges = await self.invalidate_edges(
+            #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
+            # )
+
+            # edges.extend(invalidated_edges)
+            # Future optimization would be using batch operations to save nodes and edges
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            # for node in nodes:
+            #     if isinstance(node, EntityNode):
+            #         await node.update_summary(self.driver)
+            if success_callback:
+                await success_callback(episode)
+        except Exception as e:
+            if error_callback:
+                await error_callback(episode, e)
+            else:
+                raise e
 
     async def build_indices(self):
         index_queries: list[LiteralString] = [
             "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
@@ -76,18 +153,9 @@ class Graphiti:
             """
         )
 
-    async def retrieve_episodes(
-        self, last_n: int, sources: list[str] | None = "messages"
-    ) -> list[EpisodicNode]:
-        """Retrieve the last n episodic nodes from the graph"""
-        ...
-
     # Utility function, to be removed from this class
     async def clear_data(self): ...
 
     async def search(
         self, query: str, config
-    ) -> (list)[Tuple[EntityNode, list[EntityEdge]]]:
+    ) -> (list)[tuple[EntityNode, list[EntityEdge]]]:
         (vec_nodes, vec_edges) = similarity_search(query, embedder)
         (text_nodes, text_edges) = fulltext_search(query)
 
@@ -102,32 +170,6 @@ class Graphiti:
 
         return [(node, edges)], episodes
 
-    async def get_relevant_schema(
-        self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
-    ) -> list[Tuple[EntityNode, list[EntityEdge]]]:
-        pass
-
-    # Call llm with the specified messages, and return the response
-    # Will be used in the conjunction with a prompt library
-    async def generate_llm_response(self, messages: list[any]) -> str: ...
-
-    # Extract new edges from the episode
-    async def extract_new_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ) -> list[EntityEdge]: ...
-
-    # Extract new nodes from the episode
-    async def extract_new_nodes(
-        self,
-        episode: EpisodicNode,
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ) -> list[EntityNode]: ...
-
-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
@@ -137,51 +179,3 @@ class Graphiti:
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
-
-    async def add_episode(
-        self,
-        name: str,
-        episode_body: str,
-        source_description: str,
-        reference_time: datetime = None,
-        episode_type="string",
-        success_callback: Callable | None = None,
-        error_callback: Callable | None = None,
-    ):
-        """Process an episode and update the graph"""
-        try:
-            nodes: list[Node] = []
-            edges: list[Edge] = []
-            previous_episodes = await self.retrieve_episodes(last_n=3)
-            episode = EpisodicNode()
-            await episode.save(self.driver)
-            relevant_schema = await self.retrieve_relevant_schema(episode.content)
-            new_nodes = await self.extract_new_nodes(
-                episode, relevant_schema, previous_episodes
-            )
-            nodes.extend(new_nodes)
-            new_edges = await self.extract_new_edges(
-                episode, new_nodes, relevant_schema, previous_episodes
-            )
-            edges.extend(new_edges)
-            episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
-            edges.extend(episodic_edges)
-
-            invalidated_edges = await self.invalidate_edges(
-                episode, new_nodes, new_edges, relevant_schema, previous_episodes
-            )
-
-            edges.extend(invalidated_edges)
-
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
-            for node in nodes:
-                if isinstance(node, EntityNode):
-                    await node.update_summary(self.driver)
-            if success_callback:
-                await success_callback(episode)
-        except Exception as e:
-            if error_callback:
-                await error_callback(episode, e)
-            else:
-                raise e
 
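With the constructor change above, passing llm_client is optional; when omitted, Graphiti builds an OpenAIClient from the OPENAI_API_KEY environment variable. A hedged wiring sketch (the callback names and credentials are illustrative, and callbacks must be coroutine functions because add_episode awaits them):

    # Illustrative wiring of the new constructor and callbacks; not from the commit.
    import asyncio

    from core import Graphiti
    from core.llm_client import LLMConfig, OpenAIClient


    async def on_success(episode):
        print(f"processed episode: {episode.name}")


    async def on_error(episode, error):
        print(f"failed to process episode: {error}")


    async def main():
        llm_client = OpenAIClient(LLMConfig(api_key="sk-..."))  # placeholder key
        client = Graphiti("bolt://localhost:7687", "neo4j", "password", llm_client)
        await client.add_episode(
            name="Message 1",
            episode_body="Paul: I love apples",
            source_description="WhatsApp Message",
            success_callback=on_success,
            error_callback=on_error,
        )


    asyncio.run(main())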
core/llm_client/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from .client import LLMClient
+from .openai_client import OpenAIClient
+from .config import LLMConfig
+
+__all__ = ["LLMClient", "OpenAIClient", "LLMConfig"]
core/llm_client/client.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from .config import LLMConfig
+
+
+class LLMClient(ABC):
+    @abstractmethod
+    def __init__(self, config: LLMConfig):
+        pass
+
+    @abstractmethod
+    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+        pass
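Because the ABC pins down just these two methods, swapping in a different backend or a test double is straightforward. A hypothetical stub (not in the commit) that satisfies the interface without network access:

    # Hypothetical in-memory client implementing the LLMClient ABC above.
    from core.llm_client import LLMClient, LLMConfig


    class FakeLLMClient(LLMClient):
        def __init__(self, config: LLMConfig):
            self.config = config

        async def generate_response(
            self, messages: list[dict[str, str]]
        ) -> dict[str, any]:
            # Report "nothing new" so callers exercise their empty-result paths.
            return {"new_nodes": [], "new_edges": []}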
core/llm_client/config.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+class LLMConfig:
+    """
+    Configuration class for the Large Language Model (LLM).
+
+    This class encapsulates the necessary parameters to interact with an LLM API,
+    such as OpenAI's GPT models. It stores the API key, model name, and base URL
+    for making requests to the LLM service.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str = "gpt-4o",
+        base_url: str = "https://api.openai.com",
+    ):
+        """
+        Initialize the LLMConfig with the provided parameters.
+
+        Args:
+            api_key (str): The authentication key for accessing the LLM API.
+                This is required for making authorized requests.
+
+            model (str, optional): The specific LLM model to use for generating responses.
+                Defaults to "gpt-4o". Other common values include "gpt-3.5-turbo" or "gpt-4".
+
+            base_url (str, optional): The base URL of the LLM API service.
+                Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
+                This can be changed if using a different provider or a custom endpoint.
+        """
+        self.base_url = base_url
+        self.api_key = api_key
+        self.model = model
core/llm_client/openai_client.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import json
+from openai import AsyncOpenAI
+from .client import LLMClient
+from .config import LLMConfig
+
+
+class OpenAIClient(LLMClient):
+    def __init__(self, config: LLMConfig):
+        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        self.model = config.model
+
+    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+                temperature=0.1,
+                max_tokens=3000,
+                response_format={"type": "json_object"},
+            )
+            return json.loads(response.choices[0].message.content)
+        except Exception as e:
+            print(f"Error in generating LLM response: {e}")
+            raise
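Since generate_response pins response_format to json_object and parses the first choice, callers always receive a dict rather than raw text. One caveat: OpenAI's JSON mode expects the word "JSON" to appear somewhere in the messages, which the prompt library's templates already satisfy. A usage sketch with an illustrative prompt:

    # Illustrative call; assumes a valid API key is available in the environment.
    import asyncio
    import os

    from core.llm_client import LLMConfig, OpenAIClient


    async def main():
        client = OpenAIClient(LLMConfig(api_key=os.getenv("OPENAI_API_KEY")))
        result = await client.generate_response(
            [{"role": "user", "content": 'Reply with a JSON object {"ok": true}.'}]
        )
        print(result)  # already parsed into a dict, e.g. {"ok": True}


    asyncio.run(main())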
core/nodes.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
+from pydantic import Field
 from datetime import datetime
-from uuid import uuid1
+from uuid import uuid4
 
 from openai import OpenAI
 from pydantic import BaseModel, Field
@@ -11,9 +12,9 @@ logger = logging.getLogger(__name__)
 
 
 class Node(BaseModel, ABC):
-    uuid: str = Field(default_factory=lambda: uuid1().hex)
+    uuid: str = Field(default_factory=lambda: str(uuid4()))
     name: str
-    labels: list[str]
+    labels: list[str] = Field(default_factory=list)
     created_at: datetime
 
     @abstractmethod
@@ -21,11 +22,17 @@ class Node(BaseModel, ABC):
 
 
 class EpisodicNode(Node):
-    source: str  # source type
-    source_description: str  # description of the data source
-    content: str  # raw episode data
-    entity_edges: list[str]  # list of entity edge ids referenced in this episode
-    valid_at: datetime = None  # datetime of when the original document was created
+    source: str = Field(description="source type")
+    source_description: str = Field(description="description of the data source")
+    content: str = Field(description="raw episode data")
+    entity_edges: list[str] = Field(
+        description="list of entity edges referenced in this episode",
+        default_factory=list,
+    )
+    valid_at: datetime | None = Field(
+        description="datetime of when the original document was created",
+        default=None,
+    )
 
     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
@@ -51,7 +58,9 @@ class EpisodicNode(Node):
 
 
 class EntityNode(Node):
-    summary: str  # regional summary of surrounding edges
+    summary: str = Field(description="regional summary of surrounding edges")
 
     async def update_summary(self, driver: AsyncDriver): ...
 
+    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
+
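With labels and entity_edges now carrying default factories and valid_at made optional, an episode can be built from its content fields alone, mirroring how add_episode constructs one. A minimal sketch (values are illustrative):

    # Construction sketch for the updated EpisodicNode defaults; not from the commit.
    from datetime import datetime

    from core.nodes import EpisodicNode

    episode = EpisodicNode(
        name="Message 1",
        source="messages",
        source_description="WhatsApp Message",
        content="Paul: I love apples",
        created_at=datetime.now(),
    )
    print(episode.labels)        # [] via default_factory
    print(episode.entity_edges)  # [] via default_factory
    print(episode.valid_at)      # None until a reference time is provided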
core/prompts/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from .lib import prompt_library
+from .models import Message
+
+__all__ = ["prompt_library", "Message"]
core/prompts/extract_edges.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+import json
+from typing import TypedDict, Protocol
+
+from .models import Message, PromptVersion, PromptFunction
+
+
+class Prompt(Protocol):
+    v1: PromptVersion
+
+
+class Versions(TypedDict):
+    v1: PromptFunction
+
+
+def v1(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that extracts graph edges from provided context.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph:
+
+        Current Graph Structure:
+        {context['relevant_schema']}
+
+        New Nodes:
+        {json.dumps(context['new_nodes'], indent=2)}
+
+        New Episode:
+        Content: {context['episode_content']}
+        Timestamp: {context['episode_timestamp']}
+
+        Previous Episodes:
+        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
+
+        Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes.
+
+        Guidelines:
+        1. Create edges only between semantic nodes (not episodic nodes like messages).
+        2. Each edge should represent a clear relationship between two semantic nodes.
+        3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
+        4. Provide a more detailed fact describing the relationship.
+        5. If a relationship seems to update an existing one, create a new edge with the updated information.
+        6. Consider temporal aspects of relationships when relevant.
+        7. Do not create edges involving episodic nodes (like Message 1 or Message 2).
+        8. Use existing nodes from the current graph structure when appropriate.
+
+        Respond with a JSON object in the following format:
+        {{
+            "new_edges": [
+                {{
+                    "relation_type": "RELATION_TYPE_IN_CAPS",
+                    "source_node": "Name of the source semantic node",
+                    "target_node": "Name of the target semantic node",
+                    "fact": "Detailed description of the relationship",
+                    "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
+                    "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
+                }}
+            ]
+        }}
+
+        If no new edges need to be added, return an empty list for "new_edges".
+        """,
+        ),
+    ]
+
+
+versions: Versions = {
+    "v1": v1,
+}
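The v1 function itself is plain: given a context dict with the five keys the template references, it returns a system and a user Message ready for LLMClient.generate_response. A rendering sketch with toy values:

    # Toy context for rendering the extract_edges v1 prompt; values are illustrative.
    from core.prompts.extract_edges import versions

    context = {
        "relevant_schema": '{"nodes": {}, "relationships": []}',
        "new_nodes": [{"name": "Paul", "summary": "A person"}],
        "episode_content": "Paul: I love apples",
        "episode_timestamp": "2024-07-01T12:00:00Z",
        "previous_episodes": [{"content": "Paul: hi", "timestamp": None}],
    }

    messages = versions["v1"](context)
    print(messages[0].role)     # "system"
    print(messages[1].content)  # the filled-in user prompt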
core/prompts/extract_nodes.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import json
+from typing import TypedDict, Protocol
+
+from .models import Message, PromptVersion, PromptFunction
+
+
+class Prompt(Protocol):
+    v1: PromptVersion
+
+
+class Versions(TypedDict):
+    v1: PromptFunction
+
+
+def v1(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that extracts graph nodes from provided context.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, extract new semantic nodes that need to be added to the knowledge graph:
+
+        Existing Nodes:
+        {json.dumps(context['existing_nodes'], indent=2)}
+
+        Previous Episodes:
+        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
+
+        New Episode:
+        Content: {context["episode_content"]}
+        Timestamp: {context['episode_timestamp']}
+
+        Extract new semantic nodes based on the content of the current episode, while considering the existing nodes and context from previous episodes.
+
+        Guidelines:
+        1. Only extract new nodes that don't already exist in the graph structure.
+        2. Focus on entities, concepts, or actors that are central to the current episode.
+        3. Avoid creating nodes for relationships or actions (these will be handled as edges later).
+        4. Provide a brief but informative summary for each node.
+        5. If a node seems to represent an existing concept but with updated information, don't create a new node. This will be handled by edge updates.
+        6. Do not create nodes for episodic content (like Message 1 or Message 2).
+
+        Respond with a JSON object in the following format:
+        {{
+            "new_nodes": [
+                {{
+                    "name": "Unique identifier for the node",
+                    "labels": ["Semantic", "OptionalAdditionalLabel"],
+                    "summary": "Brief summary of the node's role or significance"
+                }}
+            ]
+        }}
+
+        If no new nodes need to be added, return an empty list for "new_nodes".
+        """,
+        ),
+    ]
+
+
+versions: Versions = {
+    "v1": v1,
+}
core/prompts/lib.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+from typing import TypedDict, Protocol
+from .models import Message, PromptFunction
+from typing import TypedDict, Protocol
+from .models import Message, PromptFunction
+from .extract_nodes import (
+    Prompt as ExtractNodesPrompt,
+    Versions as ExtractNodesVersions,
+    versions as extract_nodes_versions,
+)
+
+from .extract_edges import (
+    Prompt as ExtractEdgesPrompt,
+    Versions as ExtractEdgesVersions,
+    versions as extract_edges_versions,
+)
+
+
+class PromptLibrary(Protocol):
+    extract_nodes: ExtractNodesPrompt
+    extract_edges: ExtractEdgesPrompt
+
+
+class PromptLibraryImpl(TypedDict):
+    extract_nodes: ExtractNodesVersions
+    extract_edges: ExtractEdgesVersions
+
+
+class VersionWrapper:
+    def __init__(self, func: PromptFunction):
+        self.func = func
+
+    def __call__(self, context: dict[str, any]) -> list[Message]:
+        return self.func(context)
+
+
+class PromptTypeWrapper:
+    def __init__(self, versions: dict[str, PromptFunction]):
+        for version, func in versions.items():
+            setattr(self, version, VersionWrapper(func))
+
+
+class PromptLibraryWrapper:
+    def __init__(self, library: PromptLibraryImpl):
+        for prompt_type, versions in library.items():
+            setattr(self, prompt_type, PromptTypeWrapper(versions))
+
+
+PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
+    "extract_nodes": extract_nodes_versions,
+    "extract_edges": extract_edges_versions,
+}
+
+prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
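The wrapper classes exist so prompts read as attributes instead of nested dict lookups: prompt_library.extract_nodes.v1 is a VersionWrapper whose __call__ forwards to the function registered under "v1". For example (context values are illustrative):

    # Attribute-style prompt access enabled by the wrappers above.
    from core.prompts import prompt_library

    context = {
        "existing_nodes": [],
        "previous_episodes": [],
        "episode_content": "Paul: I love apples",
        "episode_timestamp": None,
    }

    # Equivalent to extract_nodes_versions["v1"](context), minus the lookups.
    messages = prompt_library.extract_nodes.v1(context)
    print([m.role for m in messages])  # ["system", "user"]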
core/prompts/models.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from typing import Callable, Protocol
+
+from pydantic import BaseModel
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class PromptVersion(Protocol):
+    def __call__(self, context: dict[str, any]) -> list[Message]: ...
+
+
+PromptFunction = Callable[[dict[str, any]], list[Message]]
core/utils/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from .maintenance import (
+    extract_new_edges,
+    build_episodic_edges,
+    extract_new_nodes,
+    clear_data,
+    retrieve_relevant_schema,
+    retrieve_episodes,
+)
+
+__all__ = [
+    "extract_new_edges",
+    "build_episodic_edges",
+    "extract_new_nodes",
+    "clear_data",
+    "retrieve_relevant_schema",
+    "retrieve_episodes",
+]
core/utils/maintenance/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+from .edge_operations import extract_new_edges, build_episodic_edges
+from .node_operations import extract_new_nodes
+from .graph_data_operations import (
+    clear_data,
+    retrieve_relevant_schema,
+    retrieve_episodes,
+)
+
+__all__ = [
+    "extract_new_edges",
+    "build_episodic_edges",
+    "extract_new_nodes",
+    "clear_data",
+    "retrieve_relevant_schema",
+    "retrieve_episodes",
+]
core/utils/maintenance/edge_operations.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+import json
+from typing import List
+from datetime import datetime
+
+from core.nodes import EntityNode, EpisodicNode
+from core.edges import EpisodicEdge, EntityEdge
+import logging
+
+from core.prompts import prompt_library
+from core.llm_client import LLMClient
+
+logger = logging.getLogger(__name__)
+
+
+def build_episodic_edges(
+    semantic_nodes: List[EntityNode],
+    episode: EpisodicNode,
+    transaction_from: datetime,
+) -> List[EpisodicEdge]:
+    edges: List[EpisodicEdge] = []
+
+    for node in semantic_nodes:
+        edge = EpisodicEdge(
+            source_node=episode, target_node=node, created_at=transaction_from
+        )
+        edges.append(edge)
+
+    return edges
+
+
+async def extract_new_edges(
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    new_nodes: list[EntityNode],
+    relevant_schema: dict[str, any],
+    previous_episodes: list[EpisodicNode],
+) -> list[EntityEdge]:
+    # Prepare context for LLM
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "relevant_schema": json.dumps(relevant_schema, indent=2),
+        "new_nodes": [
+            {"name": node.name, "summary": node.summary} for node in new_nodes
+        ],
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_edges.v1(context)
+    )
+    new_edges_data = llm_response.get("new_edges", [])
+
+    # Convert the extracted data into EntityEdge objects
+    new_edges = []
+    for edge_data in new_edges_data:
+        source_node = next(
+            (node for node in new_nodes if node.name == edge_data["source_node"]),
+            None,
+        )
+        target_node = next(
+            (node for node in new_nodes if node.name == edge_data["target_node"]),
+            None,
+        )
+
+        # If source or target is not in new_nodes, check if it's an existing node
+        if source_node is None and edge_data["source_node"] in relevant_schema["nodes"]:
+            existing_node_data = relevant_schema["nodes"][edge_data["source_node"]]
+            source_node = EntityNode(
+                uuid=existing_node_data["uuid"],
+                name=edge_data["source_node"],
+                labels=[existing_node_data["label"]],
+                summary="",
+                created_at=datetime.now(),
+            )
+        if target_node is None and edge_data["target_node"] in relevant_schema["nodes"]:
+            existing_node_data = relevant_schema["nodes"][edge_data["target_node"]]
+            target_node = EntityNode(
+                uuid=existing_node_data["uuid"],
+                name=edge_data["target_node"],
+                labels=[existing_node_data["label"]],
+                summary="",
+                created_at=datetime.now(),
+            )
+
+        if (
+            source_node
+            and target_node
+            and not (
+                source_node.name.startswith("Message")
+                or target_node.name.startswith("Message")
+            )
+        ):
+            valid_at = (
+                datetime.fromisoformat(edge_data["valid_at"])
+                if edge_data["valid_at"]
+                else episode.valid_at or datetime.now()
+            )
+            invalid_at = (
+                datetime.fromisoformat(edge_data["invalid_at"])
+                if edge_data["invalid_at"]
+                else None
+            )
+
+            new_edge = EntityEdge(
+                source_node=source_node,
+                target_node=target_node,
+                name=edge_data["relation_type"],
+                fact=edge_data["fact"],
+                episodes=[episode.uuid],
+                created_at=datetime.now(),
+                valid_at=valid_at,
+                invalid_at=invalid_at,
+            )
+            new_edges.append(new_edge)
+            logger.info(
+                f"Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})"
+            )
+
+    return new_edges
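build_episodic_edges is the deterministic half of this module: one EpisodicEdge per extracted entity node, each pointing from the episode to the node and stamped with the same transaction time. A minimal sketch (example data is illustrative):

    # Sketch of build_episodic_edges with toy nodes; not part of the commit.
    from datetime import datetime

    from core.nodes import EntityNode, EpisodicNode
    from core.utils.maintenance.edge_operations import build_episodic_edges

    now = datetime.now()
    episode = EpisodicNode(
        name="Message 1",
        source="messages",
        source_description="WhatsApp Message",
        content="Paul: I love apples",
        created_at=now,
    )
    nodes = [
        EntityNode(name="Paul", summary="A person", created_at=now),
        EntityNode(name="Apples", summary="A fruit", created_at=now),
    ]

    edges = build_episodic_edges(nodes, episode, now)
    # One edge per node, episode -> node, created_at = transaction time.
    assert len(edges) == 2 and edges[0].source_node is episode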
core/utils/maintenance/graph_data_operations.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+from datetime import datetime, timezone
+
+from core.nodes import EpisodicNode
+from neo4j import AsyncDriver
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+async def clear_data(driver: AsyncDriver):
+    async with driver.session() as session:
+
+        async def delete_all(tx):
+            await tx.run("MATCH (n) DETACH DELETE n")
+
+        await session.execute_write(delete_all)
+
+
+async def retrieve_relevant_schema(
+    driver: AsyncDriver, query: str = None
+) -> dict[str, any]:
+    async with driver.session() as session:
+        summary_query = """
+        MATCH (n)
+        OPTIONAL MATCH (n)-[r]->(m)
+        RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
+               type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
+        """
+        result = await session.run(summary_query)
+        records = [record async for record in result]
+
+        schema = {"nodes": {}, "relationships": []}
+
+        for record in records:
+            node_label = record["node_labels"][0]  # Assuming one label per node
+            node_uuid = record["node_uuid"]
+            node_name = record["node_name"]
+            rel_type = record["relationship_type"]
+            rel_name = record["relationship_name"]
+            related_node = record["related_node_name"]
+
+            if node_name not in schema["nodes"]:
+                schema["nodes"][node_name] = {
+                    "uuid": node_uuid,
+                    "label": node_label,
+                    "relationships": [],
+                }
+
+            if rel_type and related_node:
+                schema["nodes"][node_name]["relationships"].append(
+                    {"type": rel_type, "name": rel_name, "target": related_node}
+                )
+                schema["relationships"].append(
+                    {
+                        "source": node_name,
+                        "type": rel_type,
+                        "name": rel_name,
+                        "target": related_node,
+                    }
+                )
+
+        return schema
+
+
+async def retrieve_episodes(
+    driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
+) -> list[EpisodicNode]:
+    """Retrieve the last n episodic nodes from the graph"""
+    async with driver.session() as session:
+        query = """
+        MATCH (e:EpisodicNode)
+        RETURN e.content as text, e.timestamp as timestamp, e.reference_timestamp as reference_timestamp
+        ORDER BY e.timestamp DESC
+        LIMIT $num_episodes
+        """
+        result = await session.run(query, num_episodes=last_n)
+        episodes = [
+            EpisodicNode(
+                content=record["text"],
+                transaction_from=datetime.fromtimestamp(
+                    record["timestamp"].to_native().timestamp(), timezone.utc
+                ),
+                valid_at=(
+                    datetime.fromtimestamp(
+                        record["reference_timestamp"].to_native().timestamp(),
+                        timezone.utc,
+                    )
+                    if record["reference_timestamp"] is not None
+                    else None
+                ),
+            )
+            async for record in result
+        ]
+        return list(reversed(episodes))  # Return in chronological order
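A sketch of exercising these helpers directly (the URI and credentials are assumptions for a local Neo4j): clear_data wipes the graph in a single write transaction, and retrieve_episodes then returns the most recent episodes oldest-first:

    # Illustrative direct usage; connection details are assumptions.
    import asyncio

    from neo4j import AsyncGraphDatabase

    from core.utils.maintenance.graph_data_operations import (
        clear_data,
        retrieve_episodes,
    )


    async def main():
        driver = AsyncGraphDatabase.driver(
            "bolt://localhost:7687", auth=("neo4j", "password")
        )
        await clear_data(driver)  # MATCH (n) DETACH DELETE n
        episodes = await retrieve_episodes(driver, last_n=3)
        print(len(episodes))  # 0 on a freshly cleared graph
        await driver.close()


    asyncio.run(main())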
core/utils/maintenance/node_operations.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+from datetime import datetime
+
+from core.nodes import EntityNode, EpisodicNode
+import logging
+from core.llm_client import LLMClient
+
+from core.prompts import prompt_library
+
+logger = logging.getLogger(__name__)
+
+
+async def extract_new_nodes(
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    relevant_schema: dict[str, any],
+    previous_episodes: list[EpisodicNode],
+) -> list[EntityNode]:
+    # Prepare context for LLM
+    existing_nodes = [
+        {"name": node_name, "label": node_info["label"], "uuid": node_info["uuid"]}
+        for node_name, node_info in relevant_schema["nodes"].items()
+    ]
+
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "existing_nodes": existing_nodes,
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.v1(context)
+    )
+    new_nodes_data = llm_response.get("new_nodes", [])
+    logger.info(f"Extracted new nodes: {new_nodes_data}")
+    # Convert the extracted data into EntityNode objects
+    new_nodes = []
+    for node_data in new_nodes_data:
+        # Check if the node already exists
+        if not any(
+            existing_node["name"] == node_data["name"]
+            for existing_node in existing_nodes
+        ):
+            new_node = EntityNode(
+                name=node_data["name"],
+                labels=node_data["labels"],
+                summary=node_data["summary"],
+                created_at=datetime.now(),
+            )
+            new_nodes.append(new_node)
+            logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
+        else:
+            logger.info(f"Node {node_data['name']} already exists, skipping creation.")
+
+    return new_nodes
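extract_new_nodes only reads relevant_schema["nodes"], a name-keyed mapping carrying each node's label and uuid, and it skips any LLM-proposed node whose name is already a key there. The expected shape, sketched with illustrative values:

    # Shape of the relevant_schema input consumed above (illustrative values).
    relevant_schema = {
        "nodes": {
            "Paul": {
                "uuid": "some-existing-uuid",
                "label": "Entity",
                "relationships": [
                    {"type": "LOVES", "name": "LOVES", "target": "Apples"},
                ],
            },
        },
        "relationships": [],
    }

    # extract_new_nodes flattens it into the LLM context like this:
    existing_nodes = [
        {"name": name, "label": info["label"], "uuid": info["uuid"]}
        for name, info in relevant_schema["nodes"].items()
    ]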
core/utils/maintenance/utils.py (new file, empty)
runner.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+from core import Graphiti
+from core.utils.maintenance.graph_data_operations import clear_data
+from dotenv import load_dotenv
+import os
+import asyncio
+import logging
+import sys
+
+load_dotenv()
+
+neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687"
+neo4j_user = os.environ.get("NEO4J_USER") or "neo4j"
+neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password"
+
+
+def setup_logging():
+    # Create a logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)  # Set the logging level to INFO
+
+    # Create console handler and set level to INFO
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)
+
+    # Create formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    # Add formatter to console handler
+    console_handler.setFormatter(formatter)
+
+    # Add console handler to logger
+    logger.addHandler(console_handler)
+
+    return logger
+
+
+async def main():
+    setup_logging()
+    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+    await clear_data(client.driver)
+    # await client.build_indices()
+    await client.add_episode(
+        name="Message 1",
+        episode_body="Paul: I love apples",
+        source_description="WhatsApp Message",
+    )
+    await client.add_episode(
+        name="Message 2",
+        episode_body="Paul: I love bananas",
+        source_description="WhatsApp Message",
+    )
+    await client.add_episode(
+        name="Message 3",
+        episode_body="Assistant: The best type of apples available are Fuji apples",
+        source_description="WhatsApp Message",
+    )
+    await client.add_episode(
+        name="Message 4",
+        episode_body="Paul: Oh, I actually hate those",
+        source_description="WhatsApp Message",
+    )
+
+
+asyncio.run(main())