mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
79 lines
2.2 KiB
Python
79 lines
2.2 KiB
Python
![]() |
import asyncio
import logging
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType  # type: ignore
from graphiti_core.utils import clear_data  # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep
|
||
|
|
||
|
|
||
|
class AsyncWorker:
    """Run queued jobs one at a time, in FIFO order, on a background asyncio task.

    A job is any zero-argument callable that returns an awaitable when called,
    e.g. ``functools.partial`` of an ``async def``. ``start`` spawns the
    consumer task. ``stop`` cancels it and discards any jobs still queued.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        # Consumer task created by start(); None until start() is called.
        self.task = None

    async def worker(self):
        """Consume jobs until cancelled, awaiting each one to completion.

        If a job raises, the error is logged and the loop moves on to the next
        job, so one failure does not halt ingestion of everything queued after
        it. Cancellation (from stop()) ends the loop, whether it arrives while
        idle or in the middle of a job.
        """
        while True:
            try:
                job = await self.queue.get()
            except asyncio.CancelledError:
                break
            # Report only once a job has actually been dequeued.
            print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
            try:
                await job()
            except asyncio.CancelledError:
                # stop() cancelled us while this job was in flight.
                break
            except Exception:
                logging.getLogger(__name__).exception('AsyncWorker job failed; continuing')
            finally:
                # Balance the get() above so queue.join() can complete.
                self.queue.task_done()

    async def start(self):
        """Spawn the consumer task on the currently running event loop."""
        self.task = asyncio.create_task(self.worker())

    async def stop(self):
        """Cancel the consumer task, wait for it to finish, and drop pending jobs."""
        if self.task:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                # The task was cancelled before worker() began running, so the
                # cancellation surfaces here instead of being handled in the loop.
                pass
        while not self.queue.empty():
            self.queue.get_nowait()
            # Keep the unfinished-task count consistent for any join() waiters.
            self.queue.task_done()
|
||
|
|
||
|
|
||
|
# Shared, module-wide job runner: add_messages enqueues ingestion jobs on it,
# and the router's lifespan handler starts and stops it.
async_worker = AsyncWorker()
|
||
|
|
||
|
|
||
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Start the shared AsyncWorker for the app's lifetime and stop it on shutdown.

    ``stop()`` runs in a ``finally`` clause so the worker task is cancelled and
    its queue drained even if the lifespan context exits with an exception.
    """
    await async_worker.start()
    try:
        yield
    finally:
        await async_worker.stop()
|
||
|
|
||
|
|
||
|
# Router for the ingest endpoints. Its lifespan starts and stops the shared
# async_worker. NOTE(review): confirm that the installed FastAPI version runs
# router-level lifespans when this router is included in the app.
router = APIRouter(lifespan=lifespan)
|
||
|
|
||
|
|
||
|
@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    """Queue every message in the request for background ingestion as an episode.

    Each message becomes one job on the shared worker queue. The endpoint
    returns 202 right away; the ``add_episode`` calls run later, one at a time.

    NOTE(review): the jobs capture the request-scoped ``graphiti`` dependency.
    If ZepGraphitiDep closes its client when the request finishes, the queued
    jobs would run against a closed client. Confirm against zep_graphiti.
    """

    async def ingest_one(message: Message):
        # Will pass a group_id to the add_episode call once it is implemented
        speaker = message.role or ''
        await graphiti.add_episode(
            name=message.name,
            episode_body=f'{speaker}({message.role_type}): {message.content}',
            reference_time=message.timestamp,
            source=EpisodeType.message,
            source_description=message.source_description,
        )

    for message in request.messages:
        job = partial(ingest_one, message)
        await async_worker.queue.put(job)

    return Result(message='Messages added to processing queue', success=True)
|
||
|
|
||
|
|
||
|
@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    """Erase the graph via ``clear_data`` and then rebuild indices and constraints.

    This does not touch the ingest queue. NOTE(review): messages still queued
    on async_worker may be written into the graph after the clear completes.
    """
    driver = graphiti.driver
    await clear_data(driver)
    # Recreate indices/constraints so the emptied graph is ready for new ingestion.
    await graphiti.build_indices_and_constraints()
    return Result(message='Graph cleared', success=True)
|