Pavlo Paliychuk ba48f64492
Add Fastapi graph service (#88)
* chore: Folder rearrangement

* chore: Remove unused deps, and add mypy step in CI for graph-service

* fix: Mypy errors

* fix: linter

* fix mypy

* fix mypy

* chore: Update docker setup

* chore: Reduce graph service image size

* chore: Install graph service deps on CI

* remove cache from typecheck

* chore: install graph-service deps on typecheck action

* update graph service mypy direction

* feat: Add release service image step

* chore: Update depot configuration

* chore: Update release image job to run on releases

* chore: Test depot multiplatform build

* update release action tag

* chore: Update action to be in accordance with zep image publish

* test

* test

* revert

* chore: Update python slim image used in service docker

* chore: Remove unused endpoints and dtos
2024-09-06 11:07:45 -04:00

79 lines
2.2 KiB
Python

import asyncio
import logging
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType # type: ignore
from graphiti_core.utils import clear_data # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep
class AsyncWorker:
    """Run queued jobs one at a time on a single background asyncio task.

    Producers put zero-argument callables that return an awaitable (for
    example a ``functools.partial`` over a coroutine function) onto
    ``queue``; ``worker`` awaits them sequentially in FIFO order.
    """

    def __init__(self):
        # Unbounded FIFO of pending jobs.
        self.queue = asyncio.Queue()
        # Task running `worker`; stays None until `start` is awaited.
        self.task = None

    async def worker(self):
        """Consume and run jobs until this task is cancelled.

        A job that raises is logged and skipped: otherwise the first failing
        job would end this loop and every job queued after it would never run.
        """
        while True:
            try:
                job = await self.queue.get()
                # Logged after dequeueing so the message reflects a real job
                # and the size shown is what is still waiting.
                print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
                try:
                    await job()
                except asyncio.CancelledError:
                    # Cancellation must reach the outer handler, not be
                    # treated as a job failure.
                    raise
                except Exception:
                    logging.getLogger(__name__).exception('Queued job failed; continuing')
                finally:
                    # Keep Queue.join() accurate for every dequeued item.
                    self.queue.task_done()
            except asyncio.CancelledError:
                break

    async def start(self):
        """Spawn the background task that runs `worker`."""
        self.task = asyncio.create_task(self.worker())

    async def stop(self):
        """Cancel the worker, wait for it to finish, and drop unprocessed jobs."""
        if self.task:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                # Raised when the task is cancelled before its first step,
                # so `worker` never reached its own CancelledError handler.
                pass
        # Jobs still pending are discarded, not run.
        while not self.queue.empty():
            self.queue.get_nowait()
            self.queue.task_done()
# Module-level worker; the /messages route enqueues ingestion jobs on it.
async_worker = AsyncWorker()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run the background worker for as long as the app's lifespan is active.

    The worker is stopped in a ``finally`` so cleanup still happens if an
    exception is raised into this context at the ``yield`` (with a bare
    ``yield`` the statements after it would be skipped).
    """
    await async_worker.start()
    try:
        yield
    finally:
        await async_worker.stop()


router = APIRouter(lifespan=lifespan)
@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    """Enqueue one ingestion job per message and return immediately (HTTP 202).

    Each job later calls ``graphiti.add_episode`` for its message; the jobs
    are executed one at a time by ``async_worker``, after this handler has
    already returned.

    NOTE(review): every job closes over the request-scoped ``graphiti``
    dependency — confirm the dependency does not close its driver once the
    request ends, or the deferred jobs may use a closed client.
    """

    async def ingest(msg: Message):
        # Will pass a group_id to the add_episode call once it is implemented
        speaker = msg.role or ''
        await graphiti.add_episode(
            name=msg.name,
            episode_body=f'{speaker}({msg.role_type}): {msg.content}',
            reference_time=msg.timestamp,
            source=EpisodeType.message,
            source_description=msg.source_description,
        )

    pending = [partial(ingest, msg) for msg in request.messages]
    for job in pending:
        await async_worker.queue.put(job)

    return Result(message='Messages added to processing queue', success=True)
@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    """Clear the graph via ``clear_data`` and rebuild indices and constraints.

    NOTE(review): jobs already queued on ``async_worker`` are not drained
    here, so messages accepted before this call may still be ingested after
    the graph is cleared — confirm this is intended.
    """
    driver = graphiti.driver
    await clear_data(driver)
    await graphiti.build_indices_and_constraints()
    outcome = Result(message='Graph cleared', success=True)
    return outcome