mirror of
https://github.com/getzep/graphiti.git
synced 2025-11-09 23:07:25 +00:00
52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
|
|
from fastapi import APIRouter, status
|
||
|
|
|
||
|
|
from graph_service.dto import (
|
||
|
|
GetMemoryRequest,
|
||
|
|
GetMemoryResponse,
|
||
|
|
Message,
|
||
|
|
SearchQuery,
|
||
|
|
SearchResults,
|
||
|
|
)
|
||
|
|
from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge
|
||
|
|
|
||
|
|
# Router that the endpoint handlers in this module register on via @router.post.
router = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
    """Handle ``POST /search``: run a graph search and return the matching facts.

    When ``query.search_type`` is ``'user_centered_facts'``, the user node for
    ``query.group_id`` is looked up and, if one is found, its uuid is forwarded
    to the search as ``center_node_uuid``. In every other case ``None`` is
    forwarded instead.

    Args:
        query: Request body carrying the query text, search type, group id and
            the maximum number of facts to return.
        graphiti: Injected graph client used for the node lookup and the search.

    Returns:
        A ``SearchResults`` whose ``facts`` hold one converted entry per edge
        returned by the search.
    """
    anchor_uuid: str | None = None
    if query.search_type == 'user_centered_facts':
        user_node = await graphiti.get_user_node(query.group_id)
        # A missing (falsy) user node leaves the search without a center node.
        anchor_uuid = user_node.uuid if user_node else None

    edges = await graphiti.search(
        query=query.query,
        num_results=query.max_facts,
        center_node_uuid=anchor_uuid,
    )

    fact_results = []
    for edge in edges:
        fact_results.append(get_fact_result_from_edge(edge))
    return SearchResults(facts=fact_results)
|
||
|
|
|
||
|
|
|
||
|
|
@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(request: GetMemoryRequest, graphiti: ZepGraphitiDep):
    """Handle ``POST /get-memory``: search the graph using a conversation as the query.

    The request's messages are flattened into a single query string with
    ``compose_query_from_messages`` and that string is passed to the graph
    search, limited to ``request.max_facts`` results.

    Args:
        request: Request body carrying the messages and the maximum number of
            facts to return.
        graphiti: Injected graph client used to run the search.

    Returns:
        A ``GetMemoryResponse`` whose ``facts`` hold one converted entry per
        edge returned by the search.
    """
    search_text = compose_query_from_messages(request.messages)
    edges = await graphiti.search(query=search_text, num_results=request.max_facts)

    fact_results = []
    for edge in edges:
        fact_results.append(get_fact_result_from_edge(edge))
    return GetMemoryResponse(facts=fact_results)
|
||
|
|
|
||
|
|
|
||
|
|
def compose_query_from_messages(messages: list[Message]) -> str:
    """Flatten a list of messages into one newline-delimited query string.

    Each message contributes one line of the form ``role_type(role): content``.
    A falsy ``role_type`` or ``role`` (for example ``None``) is rendered as an
    empty string, so a message with neither becomes ``(): content``.

    Args:
        messages: Messages to flatten, in the order they should appear.

    Returns:
        The concatenated lines. Every line, including the last, ends with a
        newline; an empty ``messages`` list yields ``''``.
    """
    # str.join assembles the result in one pass rather than re-allocating the
    # accumulated string on every ``+=``.
    return ''.join(
        f"{message.role_type or ''}({message.role or ''}): {message.content}\n"
        for message in messages
    )
|