112 lines
3.3 KiB
Python
Raw Normal View History

import asyncio
import logging
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType # type: ignore
from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore

from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep
logger = logging.getLogger(__name__)


class AsyncWorker:
    """Run queued async jobs one at a time on a single background task.

    Producers put zero-argument callables that return an awaitable directly
    on ``queue``; the consumer task awaits them in the order they were queued.
    """

    def __init__(self):
        # Pending jobs, consumed by ``worker``.
        self.queue = asyncio.Queue()
        # Background consumer task; ``None`` while the worker is not running.
        self.task = None

    async def worker(self):
        """Consume and await jobs until the task is cancelled.

        A job that raises is logged and skipped, so one failing job cannot
        kill the consumer and strand everything queued behind it.
        """
        while True:
            try:
                job = await self.queue.get()
                # Report only after a job has actually been dequeued, so the
                # printed size really is what remains in the queue.
                print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
                await job()
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception('Queued job failed; continuing with the next job')

    async def start(self):
        """Spawn the background consumer task on the running event loop."""
        self.task = asyncio.create_task(self.worker())

    async def stop(self):
        """Cancel the consumer, wait for it to finish and drop pending jobs."""
        if self.task:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                # Happens when the task is cancelled before it ever got to
                # run (e.g. stop() right after start()); it is stopped either
                # way, so there is nothing left to do.
                pass
            self.task = None
        # Discard whatever was still waiting to be processed.
        while not self.queue.empty():
            self.queue.get_nowait()
# Module-level worker shared by the lifespan hook and the request handlers.
async_worker = AsyncWorker()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Start the background worker on entry and stop it on exit.

    The application object passed in is not used. The worker is stopped in a
    ``finally`` block so that it is also shut down when an exception is thrown
    into the context manager at the ``yield``.
    """
    await async_worker.start()
    try:
        yield
    finally:
        await async_worker.stop()
# Router for the endpoints defined below; `lifespan` starts and stops the
# background worker around the router's lifetime.
# NOTE(review): confirm the installed FastAPI/Starlette version actually runs a
# router-level lifespan when this router is included into the application.
router = APIRouter(lifespan=lifespan)
@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    # Queue one ingestion job per message on the background worker. Nothing is
    # ingested during the request itself; the 202 response only confirms that
    # the jobs were queued.
    # NOTE(review): the queued jobs keep using `graphiti` after this handler
    # has returned -- confirm the dependency is still usable at that point.

    async def ingest_message(message: Message):
        # Episode text is rendered as "<role>(<role_type>): <content>", with an
        # empty role when none was supplied.
        body = f'{message.role or ""}({message.role_type}): {message.content}'
        await graphiti.add_episode(
            uuid=message.uuid,
            group_id=request.group_id,
            name=message.name,
            episode_body=body,
            reference_time=message.timestamp,
            source=EpisodeType.message,
            source_description=message.source_description,
        )

    pending = async_worker.queue
    for message in request.messages:
        job = partial(ingest_message, message)
        await pending.put(job)

    return Result(message='Messages added to processing queue', success=True)
@router.post('/entity-node', status_code=status.HTTP_201_CREATED)
async def add_entity_node(
    request: AddEntityNodeRequest,
    graphiti: ZepGraphitiDep,
):
    # Forward the request fields to graphiti.save_entity_node and return its
    # result unchanged as the response body.
    node_fields = {
        'uuid': request.uuid,
        'group_id': request.group_id,
        'name': request.name,
        'summary': request.summary,
    }
    return await graphiti.save_entity_node(**node_fields)
@router.delete('/entity-edge/{uuid}', status_code=status.HTTP_200_OK)
async def delete_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
    # Delete the edge identified by the `uuid` path parameter; the confirmation
    # is only built once the delete call has completed.
    await graphiti.delete_entity_edge(uuid)
    confirmation = Result(message='Entity Edge deleted', success=True)
    return confirmation
@router.delete('/group/{group_id}', status_code=status.HTTP_200_OK)
async def delete_group(group_id: str, graphiti: ZepGraphitiDep):
    # Delete the group named by the `group_id` path parameter, then report
    # success to the caller.
    await graphiti.delete_group(group_id)
    confirmation = Result(message='Group deleted', success=True)
    return confirmation
@router.delete('/episode/{uuid}', status_code=status.HTTP_200_OK)
async def delete_episode(uuid: str, graphiti: ZepGraphitiDep):
    # Episodes are removed through graphiti.delete_episodic_node, keyed by the
    # `uuid` path parameter.
    await graphiti.delete_episodic_node(uuid)
    confirmation = Result(message='Episode deleted', success=True)
    return confirmation
@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    # Two ordered steps: run clear_data against the graph driver, then have
    # graphiti build its indices and constraints again.
    driver = graphiti.driver
    await clear_data(driver)
    await graphiti.build_indices_and_constraints()
    confirmation = Result(message='Graph cleared', success=True)
    return confirmation