2024-09-10 11:00:52 -04:00
|
|
|
from datetime import datetime
|
|
|
|
|
|
2024-09-06 11:07:45 -04:00
|
|
|
from fastapi import APIRouter, status
|
|
|
|
|
|
|
|
|
|
from graph_service.dto import (
|
|
|
|
|
GetMemoryRequest,
|
|
|
|
|
GetMemoryResponse,
|
|
|
|
|
Message,
|
|
|
|
|
SearchQuery,
|
|
|
|
|
SearchResults,
|
|
|
|
|
)
|
|
|
|
|
from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge
|
|
|
|
|
|
|
|
|
|
# Router on which the endpoints below are registered via @router.post / @router.get;
# presumably included into the FastAPI application elsewhere — confirm.
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
    """Search the graph for facts matching a free-text query.

    The query text, group ids and result cap are taken from the request body
    and forwarded to ``graphiti.search``; every returned edge is converted into
    a fact result.

    Args:
        query: Request body providing ``query``, ``group_ids`` and ``max_facts``.
        graphiti: Graphiti client supplied through the ``ZepGraphitiDep`` dependency.

    Returns:
        A ``SearchResults`` whose ``facts`` are built by ``get_fact_result_from_edge``.
    """
    edges = await graphiti.search(
        group_ids=query.group_ids,
        query=query.query,
        num_results=query.max_facts,
    )
    return SearchResults(facts=list(map(get_fact_result_from_edge, edges)))
|
|
|
|
|
|
|
|
|
|
|
2024-09-11 12:53:17 -04:00
|
|
|
@router.get('/entity-edge/{uuid}', status_code=status.HTTP_200_OK)
async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
    """Look up one entity edge by UUID and return it as a fact result.

    Args:
        uuid: Edge identifier taken from the URL path.
        graphiti: Graphiti client supplied through the ``ZepGraphitiDep`` dependency.

    Returns:
        The fact result produced by ``get_fact_result_from_edge`` for that edge.
    """
    # NOTE(review): the missing-edge case is left to graphiti.get_entity_edge;
    # confirm it raises (e.g. an HTTP 404) rather than returning None.
    return get_fact_result_from_edge(await graphiti.get_entity_edge(uuid))
|
2024-09-11 12:53:17 -04:00
|
|
|
|
|
|
|
|
|
2024-09-10 11:00:52 -04:00
|
|
|
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
    """Return recent episodes for a single group.

    Args:
        group_id: Group identifier taken from the URL path.
        last_n: How many recent episodes to request (passed through as ``last_n``).
        graphiti: Graphiti client supplied through the ``ZepGraphitiDep`` dependency.

    Returns:
        The value returned by ``graphiti.retrieve_episodes``, unmodified.
    """
    # NOTE(review): datetime.now() yields a naive local-time value; confirm this
    # matches how episode timestamps are stored before switching to aware UTC.
    reference_time = datetime.now()
    return await graphiti.retrieve_episodes(
        group_ids=[group_id],
        last_n=last_n,
        reference_time=reference_time,
    )
|
|
|
|
|
|
|
|
|
|
|
2024-09-06 11:07:45 -04:00
|
|
|
@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(
    request: GetMemoryRequest,
    graphiti: ZepGraphitiDep,
):
    """Retrieve facts relevant to a conversation.

    The request's messages are flattened into a single query string, which is
    then searched within the request's one ``group_id``.

    Args:
        request: Body providing ``messages``, ``group_id`` and ``max_facts``.
        graphiti: Graphiti client supplied through the ``ZepGraphitiDep`` dependency.

    Returns:
        A ``GetMemoryResponse`` whose ``facts`` are built by ``get_fact_result_from_edge``.
    """
    query_text = compose_query_from_messages(request.messages)
    edges = await graphiti.search(
        group_ids=[request.group_id],
        query=query_text,
        num_results=request.max_facts,
    )
    return GetMemoryResponse(facts=list(map(get_fact_result_from_edge, edges)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compose_query_from_messages(messages: list[Message]) -> str:
    """Flatten chat messages into a single query string.

    Each message becomes one line of the form ``role_type(role): content``,
    terminated by a newline. A falsy ``role_type`` or ``role`` is rendered as
    an empty string.

    Args:
        messages: Messages to combine, in order.

    Returns:
        The concatenated lines, or ``''`` when ``messages`` is empty.
    """
    # Join once rather than growing a str with += in a loop (quadratic in general).
    return ''.join(
        f"{message.role_type or ''}({message.role or ''}): {message.content}\n"
        for message in messages
    )
|