Add FastAPI graph service (#88)

* chore: Folder rearrangement

* chore: Remove unused deps, and add mypy step in CI for graph-service

* fix: Mypy errors

* fix: linter

* fix mypy

* fix mypy

* chore: Update docker setup

* chore: Reduce graph service image size

* chore: Install graph service deps on CI

* remove cache from typecheck

* chore: install graph-service deps on typecheck action

* update graph service mypy direction

* feat: Add release service image step

* chore: Update depot configuration

* chore: Update release image job to run on releases

* chore: Test depot multiplatform build

* update release action tag

* chore: Update action to be in accordance with zep image publish

* test

* test

* revert

* chore: Update python slim image used in service docker

* chore: Remove unused endpoints and dtos
Pavlo Paliychuk authored on 2024-09-06 11:07:45 -04:00, committed by GitHub
parent a29c3557d3
commit ba48f64492
25 changed files with 2234 additions and 1124 deletions

@@ -1,4 +1,4 @@
OPENAI_API_KEY=
NEO4J_URI=
NEO4J_USER=
NEO4J_PASSWORD=
NEO4J_PASSWORD=

@@ -0,0 +1,56 @@
name: Build image

on:
  push:
    # Publish semver tags as releases.
    tags: [ 'v*.*.*' ]
  workflow_dispatch:
    inputs:
      tag:
        description: 'Tag to build and publish'
        required: true

env:
  REGISTRY: docker.io
  IMAGE_NAME: zepai/graph-service

jobs:
  docker-image:
    environment:
      name: release
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repo
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.inputs.tag || github.ref }}

      - name: Set up Depot CLI
        uses: depot/setup-action@v1

      - name: Login to DockerHub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Extract Docker metadata
        id: meta
        uses: docker/metadata-action@v4.4.0
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=match,pattern=v(.*-beta),group=1
            type=match,pattern=v.*-(beta),group=1

      - name: Build and push
        uses: depot/build-push-action@v1
        with:
          token: ${{ secrets.DEPOT_PROJECT_TOKEN }}
          context: .
          push: true
          platforms: linux/amd64,linux/arm64
          tags: ${{ steps.meta.outputs.tags || env.TAGS }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
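
For clarity, here is a rough illustration of what the two `type=match` rules above would extract from a pre-release git tag. It is not part of the workflow and only approximates the patterns with Python regular expressions, so treat it as a sketch rather than the exact matching logic of `docker/metadata-action`.

```python
import re

# Sketch only: approximate the two `type=match` tag rules with Python regexes.
# For a pre-release tag such as v1.2.3-beta, the first rule yields the full
# pre-release version and the second yields a rolling "beta" tag.
tag = 'v1.2.3-beta'

m1 = re.match(r'v(.*-beta)', tag)  # pattern=v(.*-beta),group=1
m2 = re.match(r'v.*-(beta)', tag)  # pattern=v.*-(beta),group=1

print(m1.group(1))  # 1.2.3-beta -> zepai/graph-service:1.2.3-beta
print(m2.group(1))  # beta       -> zepai/graph-service:beta
```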

@@ -24,16 +24,9 @@ jobs:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v4
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --with dev
      - name: Run MyPy
      - name: Run MyPy for graphiti-core
        shell: bash
        run: |
          set -o pipefail
@@ -41,3 +34,17 @@ jobs:
            s/^(.*):([0-9]+):([0-9]+): (error|warning): (.+) \[(.+)\]/::error file=\1,line=\2,endLine=\2,col=\3,title=\6::\5/;
            s/^(.*):([0-9]+):([0-9]+): note: (.+)/::notice file=\1,line=\2,endLine=\2,col=\3,title=Note::\4/;
          '
      - name: Install graph-service dependencies
        shell: bash
        run: |
          cd server
          poetry install --no-interaction --with dev
      - name: Run MyPy for graph-service
        shell: bash
        run: |
          cd server
          set -o pipefail
          poetry run mypy . --show-column-numbers --show-error-codes | sed -E '
            s/^(.*):([0-9]+):([0-9]+): (error|warning): (.+) \[(.+)\]/::error file=\1,line=\2,endLine=\2,col=\3,title=\6::\5/;
            s/^(.*):([0-9]+):([0-9]+): note: (.+)/::notice file=\1,line=\2,endLine=\2,col=\3,title=Note::\4/;
          '
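
As an aside, the `sed` pipeline above rewrites mypy output into GitHub Actions annotations. A rough Python equivalent of the error rule, shown only to illustrate the transformation (the sample mypy line is invented for the example):

```python
import re

# Illustration of the sed rewrite above: a mypy "file:line:col: error: msg [code]"
# line becomes a GitHub Actions ::error annotation.
line = 'graph_service/main.py:12:5: error: Incompatible return value type [return-value]'
pattern = r'^(.*):([0-9]+):([0-9]+): (error|warning): (.+) \[(.+)\]$'
annotation = re.sub(pattern, r'::error file=\1,line=\2,endLine=\2,col=\3,title=\6::\5', line)
print(annotation)
# ::error file=graph_service/main.py,line=12,endLine=12,col=5,title=return-value::Incompatible return value type
```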

Dockerfile (new file, 42 lines)

@@ -0,0 +1,42 @@
# Build stage
FROM python:3.12-slim as builder

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Install Poetry
RUN pip install --no-cache-dir poetry

# Copy only the files needed for installation
COPY ./pyproject.toml ./poetry.lock* ./README.md /app/
COPY ./graphiti_core /app/graphiti_core
COPY ./server/pyproject.toml ./server/poetry.lock* /app/server/

RUN poetry config virtualenvs.create false

# Install the local package
RUN poetry build && pip install dist/*.whl

# Install server dependencies
WORKDIR /app/server
RUN poetry install --no-interaction --no-ansi --no-dev

FROM python:3.12-slim

# Copy only the necessary files from the builder stage
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin

# Create the app directory and copy server files
WORKDIR /app
COPY ./server /app

# Set environment variables
ENV PYTHONUNBUFFERED=1

# Command to run the application
CMD ["uvicorn", "graph_service.main:app", "--host", "0.0.0.0", "--port", "8000"]

depot.json (new file, 1 line)

@@ -0,0 +1 @@
{"id":"v9jv1mlpwc"}

docker-compose.yml (new file, 26 lines)

@@ -0,0 +1,26 @@
version: '3.8'

services:
  graph:
    build: .
    ports:
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - NEO4J_URI=bolt://neo4j:${NEO4J_PORT}
      - NEO4J_USER=${NEO4J_USER}
      - NEO4J_PASSWORD=${NEO4J_PASSWORD}
  neo4j:
    image: neo4j:5.22.0
    ports:
      - "7474:7474"                  # HTTP
      - "${NEO4J_PORT}:${NEO4J_PORT}"  # Bolt
    volumes:
      - neo4j_data:/data
    environment:
      - NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}

volumes:
  neo4j_data:

poetry.lock (generated, 1486 lines changed)

File diff suppressed because it is too large.

@@ -10,19 +10,16 @@ authors = [
readme = "README.md"
license = "Apache-2.0"
packages = [
{ include = "graphiti_core", from = "." }
]
packages = [{ include = "graphiti_core", from = "." }]
[tool.poetry.dependencies]
python = "^3.10"
pydantic = "^2.8.2"
fastapi = "^0.112.0"
neo4j = "^5.23.0"
sentence-transformers = "^3.0.1"
diskcache = "^5.6.3"
openai = "^1.38.0"
tenacity = "<9.0.0"
numpy = "^2.1.1"
[tool.poetry.dev-dependencies]
pytest = "^8.3.2"

server/.env.example (new file, 6 lines)

@@ -0,0 +1,6 @@
OPENAI_API_KEY=
NEO4J_PORT=7687
# Only used if not running a neo4j container in docker
NEO4J_URI=bolt://localhost:7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=password

server/Makefile (new file, 32 lines)

@@ -0,0 +1,32 @@
.PHONY: install format lint test all check

# Define variables
PYTHON = python3
POETRY = poetry
PYTEST = $(POETRY) run pytest
RUFF = $(POETRY) run ruff
MYPY = $(POETRY) run mypy

# Default target
all: format lint test

# Install dependencies
install:
	$(POETRY) install --with dev

# Format code
format:
	$(RUFF) check --select I --fix
	$(RUFF) format

# Lint code
lint:
	$(RUFF) check
	$(MYPY) . --show-column-numbers --show-error-codes --pretty

# Run tests
test:
	$(PYTEST)

# Run format, lint, and test
check: format lint test

server/README.md (new file, 32 lines)

@@ -0,0 +1,32 @@
# graph-service
Graph service is a FastAPI server implementing the Graphiti package.
## Running Instructions
1. Ensure you have Docker and Docker Compose installed on your system.
2. Clone the repository and navigate to the `graph-service` directory.
3. Create a `.env` file in the `graph-service` directory with the following content:
```
OPENAI_API_KEY=your_openai_api_key
NEO4J_USER=neo4j
NEO4J_PASSWORD=your_neo4j_password
NEO4J_PORT=7687
```
Replace `your_openai_api_key` and `your_neo4j_password` with your actual OpenAI API key and desired Neo4j password.
4. Run the following command to start the services:
```
docker-compose up --build
```
5. The graph service will be available at `http://localhost:8000`.
6. You can access the Swagger docs at `http://localhost:8000/docs`.
7. You can also access the Neo4j Browser at `http://localhost:7474`.
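
Once the stack is up, the service can be exercised with any HTTP client. A minimal sketch using `requests`; the payload fields mirror the DTOs added later in this diff, and the group id `my-group` plus the message text are just example values:

```python
import requests

BASE_URL = 'http://localhost:8000'  # as exposed by docker-compose above

# Queue a message for ingestion (processed asynchronously by the worker).
requests.post(
    f'{BASE_URL}/messages',
    json={
        'group_id': 'my-group',  # example value
        'messages': [
            {
                'content': 'Alice adopted a dog named Rex.',
                'role_type': 'user',
                'role': 'alice',
            }
        ],
    },
).raise_for_status()

# Search the graph for facts related to a query.
resp = requests.post(
    f'{BASE_URL}/search',
    json={'group_id': 'my-group', 'query': 'Who owns Rex?', 'max_facts': 5},
)
print(resp.json())  # {'facts': [...]}
```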

@@ -0,0 +1,22 @@
from functools import lru_cache
from typing import Annotated

from fastapi import Depends
from pydantic_settings import BaseSettings, SettingsConfigDict  # type: ignore


class Settings(BaseSettings):
    openai_api_key: str
    neo4j_uri: str
    neo4j_user: str
    neo4j_password: str

    model_config = SettingsConfigDict(env_file='.env', extra='ignore')


@lru_cache
def get_settings():
    return Settings()


ZepEnvDep = Annotated[Settings, Depends(get_settings)]
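
A small usage sketch with illustrative values only: `pydantic-settings` resolves each field from the matching (case-insensitive) environment variable or from `.env`, and because `get_settings` is cached with `lru_cache`, the variables must be in place before the first call.

```python
import os

# Illustrative values only; in the service these come from .env / docker-compose.
os.environ.setdefault('OPENAI_API_KEY', 'sk-example')
os.environ.setdefault('NEO4J_URI', 'bolt://localhost:7687')
os.environ.setdefault('NEO4J_USER', 'neo4j')
os.environ.setdefault('NEO4J_PASSWORD', 'password')

from graph_service.config import get_settings

settings = get_settings()  # cached after the first call
print(settings.neo4j_uri)  # bolt://localhost:7687
```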

@@ -0,0 +1,20 @@
from .common import Message, Result
from .ingest import AddMessagesRequest
from .retrieve import (
    FactResult,
    GetMemoryRequest,
    GetMemoryResponse,
    SearchQuery,
    SearchResults,
)

__all__ = [
    'SearchQuery',
    'Message',
    'AddMessagesRequest',
    'SearchResults',
    'FactResult',
    'Result',
    'GetMemoryRequest',
    'GetMemoryResponse',
]

@@ -0,0 +1,28 @@
from datetime import datetime
from typing import Literal

from pydantic import BaseModel, Field


class Result(BaseModel):
    message: str
    success: bool


class Message(BaseModel):
    content: str = Field(..., description='The content of the message')
    name: str = Field(
        default='', description='The name of the episodic node for the message (message uuid)'
    )
    role_type: Literal['user', 'assistant', 'system'] = Field(
        ..., description='The role type of the message (user, assistant or system)'
    )
    role: str | None = Field(
        description='The custom role of the message to be used alongside role_type (user name, bot name, etc.)',
    )
    timestamp: datetime = Field(
        default_factory=datetime.now, description='The timestamp of the message'
    )
    source_description: str = Field(
        default='', description='The description of the source of the message'
    )
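
A quick construction example (all values are illustrative). Note that `role` has no default, so callers must supply it, even if only as `None`:

```python
from datetime import datetime, timezone

from graph_service.dto import Message

msg = Message(
    content='The weather in Paris is sunny today.',
    role_type='user',
    role='alice',  # required alongside role_type; may also be None
    timestamp=datetime.now(timezone.utc),
)
print(msg.model_dump_json(indent=2))
```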

@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field

from graph_service.dto.common import Message


class AddMessagesRequest(BaseModel):
    group_id: str = Field(..., description='The group id of the messages to add')
    messages: list[Message] = Field(..., description='The messages to add')

@@ -0,0 +1,41 @@
from datetime import datetime
from typing import Literal

from pydantic import BaseModel, Field

from graph_service.dto.common import Message


class SearchQuery(BaseModel):
    group_id: str = Field(..., description='The group id of the memory to get')
    query: str
    max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
    search_type: Literal['facts', 'user_centered_facts'] = Field(
        default='facts', description='The type of search to perform'
    )


class FactResult(BaseModel):
    uuid: str
    name: str
    fact: str
    valid_at: datetime | None
    invalid_at: datetime | None
    created_at: datetime
    expired_at: datetime | None


class SearchResults(BaseModel):
    facts: list[FactResult]


class GetMemoryRequest(BaseModel):
    group_id: str = Field(..., description='The group id of the memory to get')
    max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
    messages: list[Message] = Field(
        ..., description='The messages to build the retrieval query from'
    )


class GetMemoryResponse(BaseModel):
    facts: list[FactResult] = Field(..., description='The facts that were retrieved from the graph')
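
For reference, here is a sketch of the two request shapes accepted by the retrieval endpoints; the group id, query text, and message content are illustrative values:

```python
from graph_service.dto import GetMemoryRequest, Message, SearchQuery

# /search payload: a free-text query, optionally re-ranked around the user node.
search = SearchQuery(
    group_id='my-group',  # example value
    query='Where does Alice work?',
    max_facts=5,
    search_type='user_centered_facts',
)

# /get-memory payload: recent messages, flattened server-side into a single query.
memory = GetMemoryRequest(
    group_id='my-group',
    max_facts=10,
    messages=[Message(content='Tell me about Alice.', role_type='user', role='bob')],
)
print(search.model_dump_json(), memory.model_dump_json(), sep='\n')
```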

@@ -0,0 +1,14 @@
from fastapi import FastAPI

from graph_service.routers import ingest, retrieve

app = FastAPI()
app.include_router(retrieve.router)
app.include_router(ingest.router)


@app.get('/')
def read_root():
    return {'Hello': 'World'}
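
For local development outside Docker, the same app can be served directly with uvicorn. A sketch, assuming the server dependencies are installed (for example via `poetry install` in `server/`) and the `.env` file is populated; the module path matches the `graph_service.main:app` target used in the Dockerfile CMD:

```python
# Sketch: run the service locally without Docker.
import uvicorn

if __name__ == '__main__':
    uvicorn.run('graph_service.main:app', host='0.0.0.0', port=8000, reload=True)
```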

@@ -0,0 +1,78 @@
import asyncio
from contextlib import asynccontextmanager
from functools import partial

from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType  # type: ignore
from graphiti_core.utils import clear_data  # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep


class AsyncWorker:
    def __init__(self):
        self.queue = asyncio.Queue()
        self.task = None

    async def worker(self):
        while True:
            try:
                print(f'Got a job: (size of remaining queue: {self.queue.qsize()})')
                job = await self.queue.get()
                await job()
            except asyncio.CancelledError:
                break

    async def start(self):
        self.task = asyncio.create_task(self.worker())

    async def stop(self):
        if self.task:
            self.task.cancel()
            await self.task
        while not self.queue.empty():
            self.queue.get_nowait()


async_worker = AsyncWorker()


@asynccontextmanager
async def lifespan(_: FastAPI):
    await async_worker.start()
    yield
    await async_worker.stop()


router = APIRouter(lifespan=lifespan)


@router.post('/messages', status_code=status.HTTP_202_ACCEPTED)
async def add_messages(
    request: AddMessagesRequest,
    graphiti: ZepGraphitiDep,
):
    async def add_messages_task(m: Message):
        # Will pass a group_id to the add_episode call once it is implemented
        await graphiti.add_episode(
            name=m.name,
            episode_body=f"{m.role or ''}({m.role_type}): {m.content}",
            reference_time=m.timestamp,
            source=EpisodeType.message,
            source_description=m.source_description,
        )

    for m in request.messages:
        await async_worker.queue.put(partial(add_messages_task, m))

    return Result(message='Messages added to processing queue', success=True)


@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
    graphiti: ZepGraphitiDep,
):
    await clear_data(graphiti.driver)
    await graphiti.build_indices_and_constraints()
    return Result(message='Graph cleared', success=True)
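
The `AsyncWorker` above drains a single queue of zero-argument callables, so episodes are ingested one at a time in the order they were accepted while `/messages` returns immediately with 202. A standalone sketch of the same pattern, with no FastAPI or Graphiti involved and a sleep standing in for the actual ingestion call:

```python
import asyncio
from functools import partial


async def main():
    queue: asyncio.Queue = asyncio.Queue()

    async def worker():
        # Process queued jobs sequentially until cancelled.
        while True:
            try:
                job = await queue.get()
                await job()
            except asyncio.CancelledError:
                break

    async def add_episode(name: str):
        await asyncio.sleep(0.1)  # stand-in for graphiti.add_episode(...)
        print(f'processed {name}')

    task = asyncio.create_task(worker())
    for i in range(3):
        await queue.put(partial(add_episode, f'episode-{i}'))

    await asyncio.sleep(0.5)  # give the worker time to drain the queue
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```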

@@ -0,0 +1,51 @@
from fastapi import APIRouter, status

from graph_service.dto import (
    GetMemoryRequest,
    GetMemoryResponse,
    Message,
    SearchQuery,
    SearchResults,
)
from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge

router = APIRouter()


@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
    center_node_uuid: str | None = None
    if query.search_type == 'user_centered_facts':
        user_node = await graphiti.get_user_node(query.group_id)
        if user_node:
            center_node_uuid = user_node.uuid
    relevant_edges = await graphiti.search(
        query=query.query,
        num_results=query.max_facts,
        center_node_uuid=center_node_uuid,
    )
    facts = [get_fact_result_from_edge(edge) for edge in relevant_edges]
    return SearchResults(
        facts=facts,
    )


@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(
    request: GetMemoryRequest,
    graphiti: ZepGraphitiDep,
):
    combined_query = compose_query_from_messages(request.messages)
    result = await graphiti.search(
        query=combined_query,
        num_results=request.max_facts,
    )
    facts = [get_fact_result_from_edge(edge) for edge in result]
    return GetMemoryResponse(facts=facts)


def compose_query_from_messages(messages: list[Message]):
    combined_query = ''
    for message in messages:
        combined_query += f"{message.role_type or ''}({message.role or ''}): {message.content}\n"
    return combined_query
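
To make the retrieval query format concrete: `compose_query_from_messages` flattens the conversation into one `role_type(role): content` line per message. A small sketch, assuming the `graph_service` package is importable; the messages are example values:

```python
from graph_service.dto import Message
from graph_service.routers.retrieve import compose_query_from_messages

messages = [
    Message(content="What is my dog's name?", role_type='user', role='alice'),
    Message(content='Your dog is named Rex.', role_type='assistant', role='bot'),
]
print(compose_query_from_messages(messages), end='')
# user(alice): What is my dog's name?
# assistant(bot): Your dog is named Rex.
```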

@@ -0,0 +1,48 @@
from typing import Annotated

from fastapi import Depends
from graphiti_core import Graphiti  # type: ignore
from graphiti_core.edges import EntityEdge  # type: ignore
from graphiti_core.llm_client import LLMClient  # type: ignore
from graphiti_core.nodes import EntityNode  # type: ignore

from graph_service.config import ZepEnvDep
from graph_service.dto import FactResult


class ZepGraphiti(Graphiti):
    def __init__(
        self, uri: str, user: str, password: str, user_id: str, llm_client: LLMClient | None = None
    ):
        super().__init__(uri, user, password, llm_client)
        self.user_id = user_id

    async def get_user_node(self, user_id: str) -> EntityNode | None: ...


async def get_graphiti(settings: ZepEnvDep):
    client = ZepGraphiti(
        uri=settings.neo4j_uri,
        user=settings.neo4j_user,
        password=settings.neo4j_password,
        user_id='test1234',
    )
    try:
        yield client
    finally:
        client.close()


def get_fact_result_from_edge(edge: EntityEdge):
    return FactResult(
        uuid=edge.uuid,
        name=edge.name,
        fact=edge.fact,
        valid_at=edge.valid_at,
        invalid_at=edge.invalid_at,
        created_at=edge.created_at,
        expired_at=edge.expired_at,
    )


ZepGraphitiDep = Annotated[ZepGraphiti, Depends(get_graphiti)]
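
One practical consequence of exposing the client through a FastAPI dependency is that tests can swap it out without touching Neo4j. A hedged sketch using `app.dependency_overrides`; the fake client is hypothetical and exposes only what the `/search` route needs, and the whole thing assumes the server dependencies (including `graphiti-core` and `httpx`) are installed:

```python
from fastapi.testclient import TestClient

from graph_service.main import app
from graph_service.zep_graphiti import get_graphiti


class FakeGraphiti:
    # Hypothetical stand-in returning no user node and no edges.
    async def get_user_node(self, user_id):
        return None

    async def search(self, query, num_results, center_node_uuid=None):
        return []


app.dependency_overrides[get_graphiti] = lambda: FakeGraphiti()

client = TestClient(app)
resp = client.post('/search', json={'group_id': 'g', 'query': 'anything'})
print(resp.status_code, resp.json())  # 200 {'facts': []}
```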

server/poetry.lock (generated, new file, 1276 lines)

File diff suppressed because it is too large.

server/pyproject.toml (new file, 59 lines)

@@ -0,0 +1,59 @@
[tool.poetry]
name = "graph-service"
version = "0.1.0"
description = "Zep Graph service implementing Graphiti package"
authors = ["Paul Paliychuk <paul@getzep.com>"]
readme = "README.md"
packages = [{ include = "graph_service" }]
[tool.poetry.dependencies]
python = "^3.10"
fastapi = "^0.112.2"
graphiti-core = { path = "../" }
pydantic-settings = "^2.4.0"
uvicorn = "^0.30.6"
[tool.poetry.dev-dependencies]
pytest = "^8.3.2"
python-dotenv = "^1.0.1"
pytest-asyncio = "^0.24.0"
pytest-xdist = "^3.6.1"
ruff = "^0.6.2"
fastapi-cli = "^0.0.5"
[tool.poetry.group.dev.dependencies]
pydantic = "^2.8.2"
mypy = "^1.11.1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
pythonpath = ["."]
[tool.ruff]
line-length = 100
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
]
ignore = ["E501"]
[tool.ruff.format]
quote-style = "single"
indent-style = "space"
docstring-code-format = true