import asyncio
import logging
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType  # type: ignore
from graphiti_core.utils import clear_data  # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep

logger = logging.getLogger(__name__)


class AsyncWorker:
    """Run queued jobs one at a time in a single background asyncio task.

    A job is any zero-argument callable that returns an awaitable. Jobs are
    put on ``queue`` by producers and awaited by ``worker`` in FIFO order.
    """

    def __init__(self):
        # Pending jobs, consumed by ``worker``.
        self.queue = asyncio.Queue()
        # Background task created by ``start``; ``None`` until then.
        self.task = None

    async def worker(self) -> None:
        """Consume and await jobs from the queue until the task is cancelled.

        A job that raises is logged and skipped, so one bad job cannot
        permanently stop processing of the jobs queued behind it.
        """
        while True:
            try:
                job = await self.queue.get()
                # Report only after a job has actually been dequeued, so the
                # message is not emitted while the worker is merely idle.
                print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
                await job()
            except asyncio.CancelledError:
                # Cancellation (requested by ``stop``) ends the loop cleanly.
                break
            except Exception:
                # Top-level boundary of the background task: record the
                # failure and keep serving the remaining jobs.
                logger.exception('Queued job failed; continuing with the next job')

    async def start(self) -> None:
        """Spawn the background task that runs ``worker``."""
        self.task = asyncio.create_task(self.worker())

    async def stop(self) -> None:
        """Cancel the background task, wait for it, and discard pending jobs."""
        if self.task:
            self.task.cancel()
            await self.task
        # Anything still queued is dropped, not executed.
        while not self.queue.empty():
            self.queue.get_nowait()


# Module-level worker shared by the endpoints below.
async_worker = AsyncWorker()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Start the shared worker before the yield and stop it afterwards."""
    await async_worker.start()
    yield
    await async_worker.stop()


router = APIRouter(lifespan=lifespan)


# NOTE: the endpoints are documented with comments rather than docstrings so
# that no endpoint docstring is introduced; FastAPI uses an endpoint's
# docstring as its OpenAPI description.
@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    # Enqueue one ingest job per message and return immediately (202); the
    # jobs are executed later by ``async_worker``.
    #
    # NOTE(review): the queued closure keeps using ``graphiti`` after this
    # handler has returned. If ZepGraphitiDep is a request-scoped dependency
    # that is closed once the response is sent, the job may run against a
    # closed client -- confirm against the dependency definition.
    async def add_messages_task(m: Message):
        # Will pass a group_id to the add_episode call once it is implemented
        await graphiti.add_episode(
            name=m.name,
            episode_body=f"{m.role or ''}({m.role_type}): {m.content}",
            reference_time=m.timestamp,
            source=EpisodeType.message,
            source_description=m.source_description,
        )

    for m in request.messages:
        # ``partial`` binds the current message now, so each queued job keeps
        # its own message rather than the loop variable's final value.
        await async_worker.queue.put(partial(add_messages_task, m))

    return Result(message='Messages added to processing queue', success=True)


@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    # Clear the graph data through the driver, then rebuild indices and
    # constraints on the same client.
    await clear_data(graphiti.driver)
    await graphiti.build_indices_and_constraints()
    return Result(message='Graph cleared', success=True)