mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00

* chore: Folder rearrangement * chore: Remove unused deps, and add mypy step in CI for graph-service * fix: Mypy errors * fix: linter * fix mypy * fix mypy * chore: Update docker setup * chore: Reduce graph service image size * chore: Install graph service deps on CI * remove cache from typecheck * chore: install graph-service deps on typecheck action * update graph service mypy direction * feat: Add release service image step * chore: Update depot configuration * chore: Update release image job to run on releases * chore: Test depot multiplatform build * update release action tag * chore: Update action to be in accordance with zep image publish * test * test * revert * chore: Update python slim image used in service docker * chore: Remove unused endpoints and dtos
79 lines
2.2 KiB
Python
79 lines
2.2 KiB
Python
import asyncio
import logging
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType  # type: ignore
from graphiti_core.utils import clear_data  # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep
|
|
|
|
|
|
class AsyncWorker:
    """Run queued async jobs one at a time on a single background task.

    Jobs are zero-argument callables whose call result is awaited (for example
    a ``functools.partial`` over a coroutine function). ``start`` spawns the
    consumer task; ``stop`` cancels it and discards any jobs still queued.
    """

    # Reports jobs that raise, so a failure is visible without ending the loop.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.queue = asyncio.Queue()
        # Background consumer task; None until start() and again after stop().
        self.task = None

    async def worker(self):
        """Take jobs off ``self.queue`` and await them until cancelled.

        A job that raises is logged and skipped so the remaining jobs are still
        served. ``queue.task_done()`` is called for every dequeued job, which
        makes ``queue.join()`` usable for waiting on the backlog.
        """
        while True:
            try:
                job = await self.queue.get()
                # Report after dequeuing so the count reflects jobs still waiting.
                print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
                try:
                    await job()
                except Exception:
                    # Must not escape: it would end this task and silently
                    # stall every job queued after it.
                    self._logger.exception('Job %r failed', job)
                finally:
                    self.queue.task_done()
            except asyncio.CancelledError:
                break

    async def start(self):
        """Spawn the background consumer task on the running event loop."""
        self.task = asyncio.create_task(self.worker())

    async def stop(self):
        """Cancel the consumer task, wait for it to end, and drop queued jobs."""
        if self.task is not None:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                # A task cancelled before it ever ran re-raises the cancellation
                # when awaited instead of returning; shutdown still proceeds.
                pass
            self.task = None
        while not self.queue.empty():
            self.queue.get_nowait()
            self.queue.task_done()
|
|
|
|
|
|
# Module-level background worker that the request handlers enqueue jobs on;
# it is started and stopped by ``lifespan``.
async_worker: AsyncWorker = AsyncWorker()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Keep the shared background worker running for the app's lifetime.

    Code before ``yield`` runs at startup and code after it at shutdown. The
    ``try``/``finally`` ensures the worker is stopped even when an exception
    is thrown into the context instead of a normal shutdown.
    """
    await async_worker.start()
    try:
        yield
    finally:
        await async_worker.stop()
|
|
|
|
|
|
# Router for the ingest endpoints; its lifespan drives the background worker.
# NOTE(review): confirm the application actually runs this router's lifespan
# when the router is included (behavior depends on the FastAPI version/wiring).
router: APIRouter = APIRouter(lifespan=lifespan)
|
|
|
|
|
|
@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    """Queue every message in *request* for ingestion as a graphiti episode.

    Each message becomes one job on ``async_worker``; the endpoint returns as
    soon as the jobs are queued, without waiting for them to run.

    NOTE(review): the queued jobs capture this request's ``graphiti``
    dependency and run after the response is sent — confirm the client behind
    ``ZepGraphitiDep`` stays usable beyond the request.
    """

    async def ingest_message(message: Message):
        # Will pass a group_id to the add_episode call once it is implemented
        speaker = message.role or ''
        await graphiti.add_episode(
            name=message.name,
            episode_body=f'{speaker}({message.role_type}): {message.content}',
            reference_time=message.timestamp,
            source=EpisodeType.message,
            source_description=message.source_description,
        )

    for message in request.messages:
        # partial binds the current message now, avoiding late binding.
        job = partial(ingest_message, message)
        await async_worker.queue.put(job)

    return Result(message='Messages added to processing queue', success=True)
|
|
|
|
|
|
@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    """Clear the graph via ``clear_data`` and rebuild indices and constraints.

    NOTE(review): jobs still waiting on the background worker are not touched
    here and may write to the graph after it has been cleared.
    """
    driver = graphiti.driver
    await clear_data(driver)
    # Rebuild the indices/constraints on the emptied graph.
    await graphiti.build_indices_and_constraints()
    return Result(message='Graph cleared', success=True)
|