mirror of
https://github.com/getzep/graphiti.git
synced 2025-12-26 14:45:20 +00:00
Merge branch 'main' of github.com:getzep/graphiti into config-embedding-model
# Conflicts: # graphiti_core/search/search.py
This commit is contained in:
commit
e4bc756c31
126
examples/multi_session_conversation_memory/msc_eval.py
Normal file
126
examples/multi_session_conversation_memory/msc_eval.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from time import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
|
||||
|
||||
# Run python-dotenv before the environment lookups below
# (presumably so values from a local .env file are visible -- confirm).
load_dotenv()

# Neo4j connection settings. `or` (rather than a .get default) means an unset
# *or empty* variable falls back to the local development default.
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||
|
||||
|
||||
def setup_logging():
    """Send INFO-and-above records from the root logger to stdout.

    Returns
    -------
    logging.Logger
        The root logger, with one additional stdout handler attached.
    """
    # Build the fully configured handler first, then wire it to the root logger.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(stdout_handler)

    return root_logger
|
||||
|
||||
|
||||
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
    """
    Answer one question from the graph and grade it against the gold answer.

    Parameters
    ----------
    graphiti : Graphiti
        Client used for the graph search and for both LLM calls.
    group_id : str
        Group whose data is searched.
    query : str
        The question (sent to the prompts prefixed with "Bob: ").
    answer : str
        The gold-standard answer (sent to the judge prefixed with "Alice: ").

    Returns
    -------
    dict
        One CSV row: group id, question, gold answer, generated response,
        score (1 = judged correct, 0 = incorrect) and search time in ms.
    """
    # Time only the retrieval step; LLM latency is deliberately excluded.
    search_start = time()
    results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)])
    search_duration = time() - search_start

    facts = [edge.fact for edge in results.edges]
    entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
    context = {'facts': facts, 'entity_summaries': entity_summaries, 'query': 'Bob: ' + query}

    llm_response = await graphiti.llm_client.generate_response(
        prompt_library.eval.qa_prompt(context)
    )
    response = llm_response.get('ANSWER', '')
    # The LLM may return null or a non-string value; normalise so the
    # concatenation below cannot raise.
    if response is None:
        response = ''
    elif not isinstance(response, str):
        response = str(response)

    eval_context = {
        'query': 'Bob: ' + query,
        'answer': 'Alice: ' + answer,
        'response': 'Alice: ' + response,
    }

    eval_llm_response = await graphiti.llm_client.generate_response(
        prompt_library.eval.eval_prompt(eval_context)
    )
    verdict = eval_llm_response.get('is_correct', False)
    # A string verdict such as "false" is truthy in Python; interpret it
    # explicitly so incorrect answers are not scored as correct.
    if isinstance(verdict, str):
        verdict = verdict.strip().lower() in ('true', 'yes', '1')
    eval_response = 1 if verdict else 0

    return {
        'Group id': group_id,
        'Question': query,
        'Answer': answer,
        'Response': response,
        'Score': eval_response,
        'Search Duration (ms)': search_duration * 1000,
    }
|
||||
|
||||
|
||||
async def main():
    """
    Evaluate up to the first 500 MSC question/answer pairs and write the scores to CSV.

    Questions are evaluated concurrently in chunks of 20; each chunk's rows are
    flushed to ``../data/msc_eval.csv`` (relative to the working directory) as
    soon as the chunk finishes, so partial results survive a later failure.
    The conversation's index in the dataset is used as its group id.
    """
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

    fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)']
    chunk_size = 20

    try:
        with open('../data/msc_eval.csv', 'w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=fields)
            writer.writeheader()

            qa = conversation_q_and_a()[0:500]
            # Iterate over what is actually available instead of a fixed 500.
            for start in range(0, len(qa), chunk_size):
                qa_chunk = qa[start : start + chunk_size]
                results = await asyncio.gather(
                    *[
                        evaluate_qa(graphiti, str(group_id), query, answer)
                        for group_id, (query, answer) in enumerate(qa_chunk, start=start)
                    ]
                )
                writer.writerows(results)
                file.flush()
    finally:
        # Release the database driver even when an evaluation fails.
        await graphiti.close()


if __name__ == '__main__':
    asyncio.run(main())
|
||||
91
examples/multi_session_conversation_memory/msc_runner.py
Normal file
91
examples/multi_session_conversation_memory/msc_runner.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from examples.multi_session_conversation_memory.parse_msc_messages import (
|
||||
ParsedMscMessage,
|
||||
parse_msc_messages,
|
||||
)
|
||||
from graphiti_core import Graphiti
|
||||
|
||||
# Run python-dotenv before the environment lookups below
# (presumably so values from a local .env file are visible -- confirm).
load_dotenv()

# Neo4j connection settings. `or` (rather than a .get default) means an unset
# *or empty* variable falls back to the local development default.
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||
|
||||
|
||||
def setup_logging():
    """Configure the root logger to emit INFO-and-above records on stdout.

    Returns
    -------
    logging.Logger
        The root logger, with one additional stdout handler attached.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Console output goes to stdout rather than the default stderr.
    to_stdout = logging.StreamHandler(sys.stdout)
    to_stdout.setLevel(logging.INFO)
    to_stdout.setFormatter(logging.Formatter(log_format))
    root_logger.addHandler(to_stdout)

    return root_logger
|
||||
|
||||
|
||||
async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]):
    """
    Add every message of one conversation to the graph as an episode.

    Messages are added one at a time, in list order, each awaited before the
    next; all of them are stored under ``group_id``.
    """
    position = 0
    for msg in messages:
        episode_name = 'Message ' + (group_id + '-' + str(position))
        episode_text = f'{msg.speaker_name}: {msg.content}'
        await graphiti.add_episode(
            name=episode_name,
            episode_body=episode_text,
            reference_time=msg.actual_timestamp,
            source_description='Multi-Session Conversation',
            group_id=group_id,
        )
        position += 1
|
||||
|
||||
|
||||
async def main():
    """
    Load up to the first 500 MSC conversations into the graph.

    Conversations are processed in batches of 10 that run concurrently; within
    a conversation, messages are added sequentially. The conversation's index
    in the dataset is used as its group id.
    """
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    batch_size = 10

    try:
        msc_messages = parse_msc_messages()[:500]
        # Iterate over what is actually available instead of a fixed bound.
        for start in range(0, len(msc_messages), batch_size):
            batch = msc_messages[start : start + batch_size]
            await asyncio.gather(
                *[
                    add_conversation(graphiti, str(group_id), messages)
                    for group_id, messages in enumerate(batch, start=start)
                ]
            )

        # TODO: optionally build communities here (await graphiti.build_communities()).
    finally:
        # Release the database driver even when ingestion fails.
        await graphiti.close()


if __name__ == '__main__':
    asyncio.run(main())
|
||||
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ParsedMscMessage(BaseModel):
    """A single utterance extracted from the MSC dataset file."""

    # Name given to the speaker of this utterance.
    speaker_name: str
    # Timestamp attached to the message.
    actual_timestamp: datetime
    # Text of the utterance.
    content: str
    # Identifier of the conversation this utterance belongs to.
    group_id: str
|
||||
|
||||
|
||||
def _dialog_to_messages(dialog: list[dict], group_id: str) -> list[ParsedMscMessage]:
    # Convert one dialog into messages; speakers alternate Alice, Bob, Alice, ...
    # and every dialog starts again with Alice.
    speakers = ['Alice', 'Bob']
    return [
        ParsedMscMessage(
            speaker_name=speakers[turn % 2],
            content=utterance['text'],
            actual_timestamp=datetime.now(),
            group_id=group_id,
        )
        for turn, utterance in enumerate(dialog)
    ]


def parse_msc_messages() -> list[list[ParsedMscMessage]]:
    """
    Parse ``../data/msc.json`` into one message list per conversation.

    For each conversation, the utterances of all ``previous_dialogs`` come
    first (in file order), followed by those of the current ``dialog``. The
    conversation's index in the file is used as the ``group_id`` of its
    messages, and each message is stamped with the time it was parsed.

    The path is relative to the current working directory.

    Returns
    -------
    list[list[ParsedMscMessage]]
        One inner list per conversation, in file order.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open('../data/msc.json', encoding='utf-8') as file:
        data = json.load(file)['data']

    msc_messages: list[list[ParsedMscMessage]] = []
    for i, conversation in enumerate(data):
        group_id = str(i)
        messages: list[ParsedMscMessage] = []
        for previous_dialog in conversation['previous_dialogs']:
            messages.extend(_dialog_to_messages(previous_dialog['dialog'], group_id))
        messages.extend(_dialog_to_messages(conversation['dialog'], group_id))
        msc_messages.append(messages)

    return msc_messages
|
||||
|
||||
|
||||
def conversation_q_and_a() -> list[tuple[str, str]]:
    """
    Read the question/answer pair of every conversation in ``../data/msc.json``.

    The question is the ``self_instruct['B']`` entry and the answer is the
    ``self_instruct['A']`` entry of each conversation. The path is relative to
    the current working directory.

    Returns
    -------
    list[tuple[str, str]]
        ``(query, answer)`` tuples, one per conversation, in file order.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open('../data/msc.json', encoding='utf-8') as file:
        data = json.load(file)['data']

    return [
        (conversation['self_instruct']['B'], conversation['self_instruct']['A'])
        for conversation in data
    ]
|
||||
@ -161,7 +161,7 @@ class Graphiti:
|
||||
"""
|
||||
await self.driver.close()
|
||||
|
||||
async def build_indices_and_constraints(self):
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||
"""
|
||||
Build indices and constraints in the Neo4j database.
|
||||
|
||||
@ -171,6 +171,9 @@ class Graphiti:
|
||||
Parameters
|
||||
----------
|
||||
self
|
||||
delete_existing : bool, optional
|
||||
Whether to clear existing indices before creating new ones.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -191,7 +194,7 @@ class Graphiti:
|
||||
Caution: Running this method on a large existing database may take some time
|
||||
and could impact database performance during execution.
|
||||
"""
|
||||
await build_indices_and_constraints(self.driver)
|
||||
await build_indices_and_constraints(self.driver, delete_existing)
|
||||
|
||||
async def retrieve_episodes(
|
||||
self,
|
||||
|
||||
@ -21,3 +21,33 @@ from neo4j import time as neo4j_time
|
||||
|
||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||
return neo_date.to_native() if neo_date else None
|
||||
|
||||
|
||||
def lucene_sanitize(query: str) -> str:
    """
    Escape Lucene query-syntax characters in ``query`` with a backslash.

    Parameters
    ----------
    query : str
        Raw text that will be embedded in a Lucene query string.

    Returns
    -------
    str
        ``query`` with every special character prefixed by a backslash;
        all other characters are left unchanged.
    """
    # Lucene special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
    # '/' is included because Lucene treats it as a regular-expression
    # delimiter, so an unescaped slash can make the query fail to parse.
    escape_map = str.maketrans(
        {
            '+': r'\+',
            '-': r'\-',
            '&': r'\&',
            '|': r'\|',
            '!': r'\!',
            '(': r'\(',
            ')': r'\)',
            '{': r'\{',
            '}': r'\}',
            '[': r'\[',
            ']': r'\]',
            '^': r'\^',
            '"': r'\"',
            '~': r'\~',
            '*': r'\*',
            '?': r'\?',
            ':': r'\:',
            '\\': r'\\',
            '/': r'\/',
        }
    )

    return query.translate(escape_map)
|
||||
|
||||
90
graphiti_core/prompts/eval.py
Normal file
90
graphiti_core/prompts/eval.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
from .models import Message, PromptFunction, PromptVersion
|
||||
|
||||
|
||||
class Prompt(Protocol):
    """Attribute interface of the evaluation prompt group: one attribute per prompt."""

    # Prompt that asks for an answer to a question in Alice's voice.
    qa_prompt: PromptVersion
    # Prompt that asks a judge to grade a response against the gold answer.
    eval_prompt: PromptVersion
|
||||
|
||||
|
||||
class Versions(TypedDict):
    """Dictionary shape that maps each evaluation prompt name to its builder function."""

    qa_prompt: PromptFunction
    eval_prompt: PromptFunction
|
||||
|
||||
|
||||
def qa_prompt(context: dict[str, Any]) -> list[Message]:
    """
    Build the messages asking the LLM to answer a question as Alice.

    Parameters
    ----------
    context : dict[str, Any]
        Must contain ``entity_summaries`` and ``facts`` (JSON-serialisable)
        and ``query`` (the question text).

    Returns
    -------
    list[Message]
        A system message and a user message; the user message requests a JSON
        object with a single ``ANSWER`` key.
    """
    sys_prompt = """You are Alice and should respond to all questions from the first person perspective of Alice"""

    # NOTE: the ENTITY_SUMMARIES closing tag previously lacked its '>'.
    user_prompt = f"""
    Your task is to briefly answer the question in the way that you think Alice would answer the question.
    You are given the following entity summaries and facts to help you determine the answer to your question.
    <ENTITY_SUMMARIES>
    {json.dumps(context['entity_summaries'])}
    </ENTITY_SUMMARIES>
    <FACTS>
    {json.dumps(context['facts'])}
    </FACTS>
    <QUESTION>
    {context['query']}
    </QUESTION>
    respond with a JSON object in the following format:
    {{
        "ANSWER": "how Alice would answer the question"
    }}
    """
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
|
||||
|
||||
|
||||
def eval_prompt(context: dict[str, Any]) -> list[Message]:
    """
    Build the messages asking the LLM to grade a response against the gold answer.

    Parameters
    ----------
    context : dict[str, Any]
        Must contain ``query``, ``answer`` (gold standard) and ``response``
        (the answer being graded).

    Returns
    -------
    list[Message]
        A system message and a user message; the user message requests a JSON
        object with ``is_correct`` and ``reasoning`` keys.
    """
    sys_prompt = (
        """You are a judge that determines if answers to questions match a gold standard answer"""
    )

    # The format example is valid JSON (comma between keys) and shows
    # is_correct as a bare boolean so the model does not return a string.
    user_prompt = f"""
    Given the QUESTION and the gold standard ANSWER determine if the RESPONSE to the question is correct or incorrect.
    Although the RESPONSE may be more verbose, mark it as correct as long as it references the same topic
    as the gold standard ANSWER. Also include your reasoning for the grade.
    <QUESTION>
    {context['query']}
    </QUESTION>
    <ANSWER>
    {context['answer']}
    </ANSWER>
    <RESPONSE>
    {context['response']}
    </RESPONSE>

    respond with a JSON object in the following format, where is_correct is a JSON boolean (true or false):
    {{
        "is_correct": true,
        "reasoning": "why you determined the response was correct or incorrect"
    }}
    """
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
|
||||
|
||||
|
||||
# Maps each evaluation prompt name to the function that builds its messages.
versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
|
||||
@ -34,6 +34,9 @@ from .dedupe_nodes import (
|
||||
from .dedupe_nodes import (
|
||||
versions as dedupe_nodes_versions,
|
||||
)
|
||||
from .eval import Prompt as EvalPrompt
|
||||
from .eval import Versions as EvalVersions
|
||||
from .eval import versions as eval_versions
|
||||
from .extract_edge_dates import (
|
||||
Prompt as ExtractEdgeDatesPrompt,
|
||||
)
|
||||
@ -84,6 +87,7 @@ class PromptLibrary(Protocol):
|
||||
invalidate_edges: InvalidateEdgesPrompt
|
||||
extract_edge_dates: ExtractEdgeDatesPrompt
|
||||
summarize_nodes: SummarizeNodesPrompt
|
||||
eval: EvalPrompt
|
||||
|
||||
|
||||
class PromptLibraryImpl(TypedDict):
|
||||
@ -94,6 +98,7 @@ class PromptLibraryImpl(TypedDict):
|
||||
invalidate_edges: InvalidateEdgesVersions
|
||||
extract_edge_dates: ExtractEdgeDatesVersions
|
||||
summarize_nodes: SummarizeNodesVersions
|
||||
eval: EvalVersions
|
||||
|
||||
|
||||
class VersionWrapper:
|
||||
@ -124,5 +129,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
||||
'invalidate_edges': invalidate_edges_versions,
|
||||
'extract_edge_dates': extract_edge_dates_versions,
|
||||
'summarize_nodes': summarize_nodes_versions,
|
||||
'eval': eval_versions,
|
||||
}
|
||||
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
|
||||
|
||||
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
@ -65,32 +66,20 @@ async def search(
|
||||
query = query.replace('\n', ' ')
|
||||
# if group_ids is empty, set it to None
|
||||
group_ids = group_ids if group_ids else None
|
||||
edges = (
|
||||
await edge_search(
|
||||
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit, config.embedding_model
|
||||
)
|
||||
if config.edge_config is not None
|
||||
else []
|
||||
)
|
||||
nodes = (
|
||||
await node_search(
|
||||
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit, config.embedding_model
|
||||
)
|
||||
if config.node_config is not None
|
||||
else []
|
||||
)
|
||||
communities = (
|
||||
await community_search(
|
||||
driver, embedder, query, group_ids, config.community_config, config.limit, config.embedding_model
|
||||
)
|
||||
if config.community_config is not None
|
||||
else []
|
||||
edges, nodes, communities = await asyncio.gather(
|
||||
edge_search(
|
||||
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
|
||||
),
|
||||
node_search(
|
||||
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
|
||||
),
|
||||
community_search(driver, embedder, query, group_ids, config.community_config, config.limit),
|
||||
)
|
||||
|
||||
results = SearchResults(
|
||||
edges=edges[: config.limit],
|
||||
nodes=nodes[: config.limit],
|
||||
communities=communities[: config.limit],
|
||||
edges=edges,
|
||||
nodes=nodes,
|
||||
communities=communities,
|
||||
)
|
||||
|
||||
end = time()
|
||||
@ -105,11 +94,14 @@ async def edge_search(
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig,
|
||||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[EntityEdge]] = []
|
||||
|
||||
if EdgeSearchMethod.bm25 in config.search_methods:
|
||||
@ -163,7 +155,7 @@ async def edge_search(
|
||||
if config.reranker == EdgeReranker.episode_mentions:
|
||||
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
||||
|
||||
return reranked_edges
|
||||
return reranked_edges[:limit]
|
||||
|
||||
|
||||
async def node_search(
|
||||
@ -171,11 +163,14 @@ async def node_search(
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig,
|
||||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[EntityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[EntityNode]] = []
|
||||
|
||||
if NodeSearchMethod.bm25 in config.search_methods:
|
||||
@ -214,7 +209,7 @@ async def node_search(
|
||||
|
||||
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
|
||||
return reranked_nodes
|
||||
return reranked_nodes[:limit]
|
||||
|
||||
|
||||
async def community_search(
|
||||
@ -222,10 +217,13 @@ async def community_search(
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
embedding_model: str | None = None,
|
||||
) -> list[CommunityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[CommunityNode]] = []
|
||||
|
||||
if CommunitySearchMethod.bm25 in config.search_methods:
|
||||
@ -258,4 +256,4 @@ async def community_search(
|
||||
|
||||
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
|
||||
return reranked_communities
|
||||
return reranked_communities[:limit]
|
||||
|
||||
@ -16,13 +16,13 @@ limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver, Query
|
||||
|
||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||
from graphiti_core.helpers import lucene_sanitize
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
@ -36,6 +36,22 @@ logger = logging.getLogger(__name__)
|
||||
RELEVANT_SCHEMA_LIMIT = 3
|
||||
|
||||
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
    """
    Build the Lucene query string used for full-text index searches.

    Parameters
    ----------
    query : str
        Raw search text; it is escaped and suffixed with ``~``.
    group_ids : list[str] | None, optional
        When non-empty, results are restricted to these group ids.

    Returns
    -------
    str
        ``<escaped query>~`` when no group ids are given, otherwise
        ``(group_id:"a" OR group_id:"b") AND <escaped query>~``.
    """
    fuzzy_query = lucene_sanitize(query) + '~'
    if not group_ids:
        return fuzzy_query

    # The OR-ed group clauses must be parenthesised: without the grouping the
    # AND binds only to the last group clause, so a multi-group filter would
    # not be applied as intended.
    group_ids_filter = ' OR '.join(f'group_id:"{lucene_sanitize(g)}"' for g in group_ids)

    return f'({group_ids_filter}) AND {fuzzy_query}'
|
||||
|
||||
|
||||
async def get_mentioned_nodes(
|
||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[EntityNode]:
|
||||
@ -91,11 +107,15 @@ async def edge_fulltext_search(
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
# fulltext search over facts
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
|
||||
cypher_query = Query("""
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||
AND ($group_ids IS NULL OR n.group_id IN $group_ids)
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -112,72 +132,6 @@ async def edge_fulltext_search(
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
|
||||
if source_node_uuid is None and target_node_uuid is None:
|
||||
cypher_query = Query("""
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
elif source_node_uuid is None:
|
||||
cypher_query = Query("""
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
elif target_node_uuid is None:
|
||||
cypher_query = Query("""
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
cypher_query,
|
||||
query=fuzzy_query,
|
||||
@ -202,11 +156,12 @@ async def edge_similarity_search(
|
||||
) -> list[EntityEdge]:
|
||||
# vector similarity search over embedded facts
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||
RETURN
|
||||
vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
@ -220,72 +175,9 @@ async def edge_similarity_search(
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""")
|
||||
|
||||
if source_node_uuid is None and target_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
elif source_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
elif target_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
@ -307,10 +199,11 @@ async def node_fulltext_search(
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||
YIELD node AS n, score
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
RETURN
|
||||
@ -341,11 +234,10 @@ async def node_similarity_search(
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
||||
YIELD node AS n, score
|
||||
MATCH (n:Entity)
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
RETURN
|
||||
vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
@ -353,6 +245,7 @@ async def node_similarity_search(
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
@ -370,7 +263,8 @@ async def community_fulltext_search(
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
# BM25 search to get top communities
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||
@ -405,11 +299,10 @@ async def community_similarity_search(
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
|
||||
YIELD node AS comm, score
|
||||
MATCH (comm:Community)
|
||||
WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
|
||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||
RETURN
|
||||
vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
|
||||
comm.uuid As uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
@ -417,6 +310,7 @@ async def community_similarity_search(
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
|
||||
@ -28,7 +28,16 @@ EPISODE_WINDOW_LEN = 3
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def build_indices_and_constraints(driver: AsyncDriver):
|
||||
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
|
||||
if delete_existing:
|
||||
records, _, _ = await driver.execute_query("""
|
||||
SHOW INDEXES YIELD name
|
||||
""")
|
||||
index_names = [record['name'] for record in records]
|
||||
await asyncio.gather(
|
||||
*[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names]
|
||||
)
|
||||
|
||||
range_indices: list[LiteralString] = [
|
||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||
@ -52,38 +61,15 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
||||
]
|
||||
|
||||
fulltext_indices: list[LiteralString] = [
|
||||
'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
|
||||
'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
|
||||
'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
|
||||
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
||||
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
||||
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
||||
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
||||
]
|
||||
|
||||
vector_indices: list[LiteralString] = [
|
||||
"""
|
||||
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
|
||||
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
"""
|
||||
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
|
||||
FOR (n:Entity) ON (n.name_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
"""
|
||||
CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
|
||||
FOR (n:Community) ON (n.name_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
]
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.3.6"
|
||||
version = "0.3.7"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
||||
@ -74,7 +74,6 @@ def format_context(facts):
|
||||
async def test_graphiti_init():
|
||||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
await graphiti.build_communities()
|
||||
|
||||
edges = await graphiti.search(
|
||||
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
|
||||
@ -96,7 +95,7 @@ async def test_graphiti_init():
|
||||
}
|
||||
|
||||
logger.info(pretty_results)
|
||||
graphiti.close()
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user