79 lines
2.2 KiB
Python
Raw Normal View History

import asyncio
import logging
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType # type: ignore
from graphiti_core.utils import clear_data # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep
class AsyncWorker:
    """Run queued async jobs one at a time on a single background task.

    Jobs are zero-argument callables returning an awaitable. They are put on
    ``queue`` and executed in FIFO order by ``worker`` once ``start`` has
    been awaited.
    """

    def __init__(self):
        # Pending jobs (zero-arg callables returning awaitables), FIFO order.
        self.queue = asyncio.Queue()
        # Background task running ``worker``; None until ``start`` is called.
        self.task = None

    async def worker(self):
        """Consume and run jobs from ``queue`` until this task is cancelled.

        A job that raises is logged and skipped, so one failing job does not
        stop processing of every job queued after it.
        """
        while True:
            try:
                job = await self.queue.get()
            except asyncio.CancelledError:
                break
            # Reported after dequeuing so the message reflects a real job.
            print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
            try:
                await job()
            except asyncio.CancelledError:
                break
            except Exception:
                # Keep the worker alive; a dead worker would strand the queue.
                logging.getLogger(__name__).exception('Queued job failed')
            finally:
                # Balance the get() above so queue.join() can complete.
                self.queue.task_done()

    async def start(self):
        """Spawn the background task that runs ``worker``."""
        self.task = asyncio.create_task(self.worker())

    async def stop(self):
        """Cancel the worker, wait for it to finish, and discard pending jobs."""
        if self.task is not None:
            self.task.cancel()
            # asyncio.wait does not re-raise the task's outcome, so this is
            # safe even if the task was cancelled before its first step ran.
            await asyncio.wait({self.task})
        while not self.queue.empty():
            self.queue.get_nowait()
            self.queue.task_done()
# Module-level worker that the request handlers enqueue jobs onto.
async_worker = AsyncWorker()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run the background worker for the lifetime of the application.

    The worker is started before control is yielded and stopped afterwards.
    ``try``/``finally`` guarantees the stop runs even if the lifespan body
    exits with an exception.
    """
    await async_worker.start()
    try:
        yield
    finally:
        await async_worker.stop()


router = APIRouter(lifespan=lifespan)
@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    """Queue every message in ``request`` for ingestion and return immediately.

    Each message becomes one ``graphiti.add_episode`` call, executed later by
    the shared background worker; hence the 202 Accepted status.
    """

    async def ingest(msg: Message):
        # Will pass a group_id to the add_episode call once it is implemented
        body = f"{msg.role or ''}({msg.role_type}): {msg.content}"
        await graphiti.add_episode(
            name=msg.name,
            episode_body=body,
            reference_time=msg.timestamp,
            source=EpisodeType.message,
            source_description=msg.source_description,
        )

    # NOTE(review): these jobs run after this handler has returned. If
    # ZepGraphitiDep is request-scoped and torn down after the response,
    # `graphiti` may already be closed when a job executes -- confirm.
    for msg in request.messages:
        # partial() binds the current msg now, avoiding late-binding closures.
        job = partial(ingest, msg)
        await async_worker.queue.put(job)
    return Result(message='Messages added to processing queue', success=True)
@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    """Clear the graph via ``clear_data`` and rebuild indices/constraints.

    NOTE(review): jobs already queued by ``/messages`` are not touched here
    and may repopulate the graph after this returns -- confirm intended.
    """
    driver = graphiti.driver
    await clear_data(driver)
    await graphiti.build_indices_and_constraints()
    outcome = Result(message='Graph cleared', success=True)
    return outcome