feat: add error handling for missing nodes and edges, introduce new API endpoints, and update ZepGraphiti class (#104)

* feat: Expose crud operations to service + add graphiti errors

* fix: linter
This commit is contained in:
Pavlo Paliychuk 2024-09-11 12:53:17 -04:00 committed by GitHub
parent c0a740ff60
commit 8085b52f2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 69 additions and 5 deletions

View File

@ -24,6 +24,7 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.errors import EdgeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import Node
@ -104,7 +105,8 @@ class EpisodicEdge(Edge):
edges = [get_episodic_edge_from_record(record) for record in records]
logger.info(f'Found Edge: {uuid}')
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
return edges[0]
@ -191,7 +193,8 @@ class EntityEdge(Edge):
edges = [get_entity_edge_from_record(record) for record in records]
logger.info(f'Found Edge: {uuid}')
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
return edges[0]

18
graphiti_core/errors.py Normal file
View File

@ -0,0 +1,18 @@
class GraphitiError(Exception):
"""Base exception class for Graphiti Core."""
class EdgeNotFoundError(GraphitiError):
"""Raised when an edge is not found."""
def __init__(self, uuid: str):
self.message = f'edge {uuid} not found'
super().__init__(self.message)
class NodeNotFoundError(GraphitiError):
"""Raised when a node is not found."""
def __init__(self, uuid: str):
self.message = f'node {uuid} not found'
super().__init__(self.message)

View File

@ -25,6 +25,7 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.llm_client.config import EMBEDDING_DIM
logger = logging.getLogger(__name__)
@ -148,7 +149,7 @@ class EpisodicNode(Node):
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source
""",
@ -159,6 +160,9 @@ class EpisodicNode(Node):
logger.info(f'Found Node: {uuid}')
if len(episodes) == 0:
raise NodeNotFoundError(uuid)
return episodes[0]
@classmethod

View File

@ -83,6 +83,18 @@ async def add_entity_node(
return node
@router.delete('/entity-edge/{uuid}', status_code=status.HTTP_200_OK)
async def delete_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
await graphiti.delete_entity_edge(uuid)
return Result(message='Entity Edge deleted', success=True)
@router.delete('/episode/{uuid}', status_code=status.HTTP_200_OK)
async def delete_episode(uuid: str, graphiti: ZepGraphitiDep):
await graphiti.delete_episodic_node(uuid)
return Result(message='Episode deleted', success=True)
@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
graphiti: ZepGraphitiDep,

View File

@ -27,6 +27,11 @@ async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
)
@router.get('/entity-edge/{uuid}', status_code=status.HTTP_200_OK)
async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
return await graphiti.get_entity_edge(uuid)
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
episodes = await graphiti.retrieve_episodes(

View File

@ -1,10 +1,11 @@
from typing import Annotated
from fastapi import Depends
from fastapi import Depends, HTTPException
from graphiti_core import Graphiti # type: ignore
from graphiti_core.edges import EntityEdge # type: ignore
from graphiti_core.errors import EdgeNotFoundError, NodeNotFoundError # type: ignore
from graphiti_core.llm_client import LLMClient # type: ignore
from graphiti_core.nodes import EntityNode # type: ignore
from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore
from graph_service.config import ZepEnvDep
from graph_service.dto import FactResult
@ -25,6 +26,27 @@ class ZepGraphiti(Graphiti):
await new_node.save(self.driver)
return new_node
async def get_entity_edge(self, uuid: str):
try:
edge = await EntityEdge.get_by_uuid(self.driver, uuid)
return edge
except EdgeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
async def delete_entity_edge(self, uuid: str):
try:
edge = await EntityEdge.get_by_uuid(self.driver, uuid)
await edge.delete(self.driver)
except EdgeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
async def delete_episodic_node(self, uuid: str):
try:
episode = await EpisodicNode.get_by_uuid(self.driver, uuid)
await episode.delete(self.driver)
except NodeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
async def get_graphiti(settings: ZepEnvDep):
client = ZepGraphiti(