Merge branch 'main' of github.com:getzep/graphiti into config-embedding-model

# Conflicts:
#	graphiti_core/search/search.py
This commit is contained in:
paulpaliychuk 2024-09-26 16:18:00 -04:00
commit e4bc756c31
12 changed files with 519 additions and 211 deletions

View File

@ -0,0 +1,126 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import csv
import logging
import os
import sys
from time import time
from dotenv import load_dotenv
from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
from graphiti_core import Graphiti
from graphiti_core.prompts import prompt_library
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
# load_dotenv() runs first so that values it provides are visible to the
# lookups below (presumably it reads a local .env file — see python-dotenv).
load_dotenv()

# Neo4j connection settings; an unset or empty variable falls back to the
# local-development default on the right.
neo4j_uri = os.getenv('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.getenv('NEO4J_USER') or 'neo4j'
neo4j_password = os.getenv('NEO4J_PASSWORD') or 'password'
def setup_logging():
    """Configure the root logger to write INFO-level records to stdout.

    Returns
    -------
    logging.Logger
        The root logger, with a new stdout stream handler attached.
    """
    # Build the handler completely before attaching it to the logger.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(stdout_handler)
    return root_logger
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
    """Answer one question from the graph and have the LLM grade the answer.

    Parameters
    ----------
    graphiti : Graphiti
        Client used for the search and for both LLM calls.
    group_id : str
        Group the search is restricted to.
    query : str
        Question text; it is presented to the LLM as spoken by 'Bob'.
    answer : str
        Gold-standard answer; it is presented to the grader as spoken by 'Alice'.

    Returns
    -------
    dict
        One result row keyed by 'Group id', 'Question', 'Answer', 'Response',
        'Score' (1 or 0) and 'Search Duration (ms)'.
    """
    # Only the search call is timed.
    started_at = time()
    search_results = await graphiti._search(
        query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)]
    )
    elapsed_seconds = time() - started_at

    asked = 'Bob: ' + query

    # First LLM call: answer the question from the retrieved facts and summaries.
    qa_context = {
        'facts': [edge.fact for edge in search_results.edges],
        'entity_summaries': [
            node.name + ': ' + node.summary for node in search_results.nodes
        ],
        'query': asked,
    }
    qa_reply = await graphiti.llm_client.generate_response(
        prompt_library.eval.qa_prompt(qa_context)
    )
    response = qa_reply.get('ANSWER', '')

    # Second LLM call: grade the generated response against the gold answer.
    grading_context = {
        'query': asked,
        'answer': 'Alice: ' + answer,
        'response': 'Alice: ' + response,
    }
    grading_reply = await graphiti.llm_client.generate_response(
        prompt_library.eval.eval_prompt(grading_context)
    )
    score = 1 if grading_reply.get('is_correct', False) else 0

    return {
        'Group id': group_id,
        'Question': query,
        'Answer': answer,
        'Response': response,
        'Score': score,
        'Search Duration (ms)': elapsed_seconds * 1000,
    }
async def main():
    """Evaluate up to the first 500 question/answer pairs and record results as CSV.

    The header is written to '../data/msc_eval.csv' first. Pairs are then
    evaluated concurrently in batches of 20, and each batch's rows are appended
    to the file as soon as the batch completes. A pair's position in the list
    is used as its group id. The Graphiti client is closed even when a batch
    raises.
    """
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

    fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)']
    output_path = '../data/msc_eval.csv'
    batch_size = 20

    try:
        # Explicit UTF-8 so non-ASCII LLM output does not depend on the platform default.
        with open(output_path, 'w', newline='', encoding='utf-8') as file:
            csv.DictWriter(file, fieldnames=fields).writeheader()

        qa = conversation_q_and_a()[0:500]

        # Bounded by the data actually present, so short inputs produce no empty batches.
        for start in range(0, len(qa), batch_size):
            qa_chunk = qa[start : start + batch_size]
            results = await asyncio.gather(
                *[
                    evaluate_qa(graphiti, str(group_id), query, answer)
                    for group_id, (query, answer) in enumerate(qa_chunk, start=start)
                ]
            )

            # Append per batch so finished work survives a later failure.
            with open(output_path, 'a', newline='', encoding='utf-8') as file:
                csv.DictWriter(file, fieldnames=fields).writerows(results)
    finally:
        await graphiti.close()
# Script entry point: run the evaluation coroutine to completion.
# There is no __main__ guard, so this also executes when the module is imported.
asyncio.run(main())

View File

@ -0,0 +1,91 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import os
import sys
from dotenv import load_dotenv
from examples.multi_session_conversation_memory.parse_msc_messages import (
ParsedMscMessage,
parse_msc_messages,
)
from graphiti_core import Graphiti
# load_dotenv() runs first so that values it provides are visible to the
# lookups below (presumably it reads a local .env file — see python-dotenv).
load_dotenv()

# Neo4j connection settings; an unset or empty variable falls back to the
# local-development default on the right.
neo4j_uri = os.getenv('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.getenv('NEO4J_USER') or 'neo4j'
neo4j_password = os.getenv('NEO4J_PASSWORD') or 'password'
def setup_logging():
    """Send INFO-and-above log records from the root logger to stdout.

    Returns
    -------
    logging.Logger
        The root logger after the stdout handler has been added to it.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(log_format))

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.addHandler(console)
    return root
async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]):
    """Add every message of one conversation to the graph as its own episode.

    Messages are added in list order, and each call is awaited before the
    next one starts.

    Parameters
    ----------
    graphiti : Graphiti
        Client the episodes are added through.
    group_id : str
        Group id stored on every episode; also used in the episode name.
    messages : list[ParsedMscMessage]
        The conversation's messages, in order.
    """
    for position, msg in enumerate(messages):
        # Episode names look like 'Message <group_id>-<index within conversation>'.
        episode_name = 'Message ' + group_id + '-' + str(position)
        await graphiti.add_episode(
            name=episode_name,
            episode_body=f'{msg.speaker_name}: {msg.content}',
            reference_time=msg.actual_timestamp,
            source_description='Multi-Session Conversation',
            group_id=group_id,
        )
async def main():
    """Load up to the first 500 parsed conversations into the graph.

    Conversations are processed in batches of 10. The conversations in a batch
    are added concurrently, while the messages inside each conversation are
    added sequentially by add_conversation. A conversation's position in the
    parsed list is used as its group id. The Graphiti client is closed when
    ingestion finishes or fails.
    """
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

    max_conversations = 500
    batch_size = 10

    try:
        msc_messages = parse_msc_messages()

        # Stop at whichever comes first: the cap or the end of the data.
        for start in range(0, min(len(msc_messages), max_conversations), batch_size):
            batch = msc_messages[start : start + batch_size]
            await asyncio.gather(
                *[
                    add_conversation(graphiti, str(group_id), messages)
                    for group_id, messages in enumerate(batch, start=start)
                ]
            )

        # NOTE: communities are not built by this script; if they are needed,
        # call `await graphiti.build_communities()` here.
    finally:
        await graphiti.close()
# Script entry point: run the ingestion coroutine to completion.
# There is no __main__ guard, so this also executes when the module is imported.
asyncio.run(main())

View File

@ -0,0 +1,85 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from datetime import datetime
from pydantic import BaseModel
class ParsedMscMessage(BaseModel):
    """A single utterance from a multi-session conversation."""

    # Name assigned to the speaker of this utterance.
    speaker_name: str
    # Timestamp attached to the message.
    actual_timestamp: datetime
    # Text of the utterance.
    content: str
    # Identifier of the conversation this message belongs to.
    group_id: str
def parse_msc_messages() -> list[list[ParsedMscMessage]]:
    """Parse '../data/msc.json' into one message list per conversation.

    For each conversation, the utterances of every entry in 'previous_dialogs'
    come first, followed by the utterances of the conversation's own 'dialog'.
    Within each dialog the speakers alternate, starting with 'Alice'. Every
    message is stamped with the time of parsing and uses the conversation's
    index in the file as its group id.

    Returns
    -------
    list[list[ParsedMscMessage]]
        One inner list per conversation, in file order.
    """
    speakers = ['Alice', 'Bob']

    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open('../data/msc.json', encoding='utf-8') as file:
        data = json.load(file)['data']

    msc_messages: list[list[ParsedMscMessage]] = []
    for i, conversation in enumerate(data):
        # Earlier sessions first, then the current session.
        dialogs = [previous['dialog'] for previous in conversation['previous_dialogs']]
        dialogs.append(conversation['dialog'])

        messages: list[ParsedMscMessage] = []
        for dialog in dialogs:
            # The turn counter restarts per dialog, so each dialog opens with speakers[0].
            for turn, utterance in enumerate(dialog):
                messages.append(
                    ParsedMscMessage(
                        speaker_name=speakers[turn % 2],
                        content=utterance['text'],
                        actual_timestamp=datetime.now(),
                        group_id=str(i),
                    )
                )

        msc_messages.append(messages)

    return msc_messages
def conversation_q_and_a() -> list[tuple[str, str]]:
    """Read the question/answer pair of every conversation in '../data/msc.json'.

    Returns
    -------
    list[tuple[str, str]]
        (query, answer) tuples in file order, where the query is the
        conversation's self_instruct 'B' entry and the answer is its 'A' entry.
    """
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open('../data/msc.json', encoding='utf-8') as file:
        data = json.load(file)['data']

    return [
        (conversation['self_instruct']['B'], conversation['self_instruct']['A'])
        for conversation in data
    ]

View File

@ -161,7 +161,7 @@ class Graphiti:
"""
await self.driver.close()
async def build_indices_and_constraints(self):
async def build_indices_and_constraints(self, delete_existing: bool = False):
"""
Build indices and constraints in the Neo4j database.
@ -171,6 +171,9 @@ class Graphiti:
Parameters
----------
self
delete_existing : bool, optional
Whether to clear existing indices before creating new ones.
Returns
-------
@ -191,7 +194,7 @@ class Graphiti:
Caution: Running this method on a large existing database may take some time
and could impact database performance during execution.
"""
await build_indices_and_constraints(self.driver)
await build_indices_and_constraints(self.driver, delete_existing)
async def retrieve_episodes(
self,

View File

@ -21,3 +21,33 @@ from neo4j import time as neo4j_time
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Return `neo_date.to_native()`, or None when `neo_date` is falsy."""
    if not neo_date:
        return None
    return neo_date.to_native()
# Characters with special meaning in Lucene's classic query syntax:
#   + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
# Escaping each single character also covers the two-character operators && and ||.
_LUCENE_SPECIAL_CHARS = '+-&|!(){}[]^"~*?:\\/'
_LUCENE_ESCAPE_MAP = str.maketrans({char: '\\' + char for char in _LUCENE_SPECIAL_CHARS})


def lucene_sanitize(query: str) -> str:
    """Backslash-escape Lucene special characters in `query`.

    Parameters
    ----------
    query : str
        Raw text that will be embedded in a Lucene query string.

    Returns
    -------
    str
        `query` with every special character prefixed by a backslash. The
        forward slash is included because Lucene treats it as a regular
        expression delimiter.
    """
    return query.translate(_LUCENE_ESCAPE_MAP)

View File

@ -0,0 +1,90 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
    """Structural type listing the prompts this module provides."""

    # Prompt that asks the LLM to answer a question from the retrieved context.
    qa_prompt: PromptVersion
    # Prompt that asks the LLM to grade a response against a gold answer.
    eval_prompt: PromptVersion
class Versions(TypedDict):
    """Mapping from prompt name to the function that builds that prompt."""

    # Builder for the question-answering prompt.
    qa_prompt: PromptFunction
    # Builder for the grading prompt.
    eval_prompt: PromptFunction
def qa_prompt(context: dict[str, Any]) -> list[Message]:
    """Build the messages asking the LLM to answer a question as Alice.

    Parameters
    ----------
    context : dict[str, Any]
        Must contain 'entity_summaries' and 'facts' (both JSON-serialisable)
        and 'query' (interpolated into the prompt as-is).

    Returns
    -------
    list[Message]
        A system message followed by a user message. The user message asks for
        a JSON object with a single "ANSWER" key.
    """
    sys_prompt = """You are Alice and should respond to all questions from the first person perspective of Alice"""

    # Every section is wrapped in a matching open/close tag pair.
    user_prompt = f"""
Your task is to briefly answer the question in the way that you think Alice would answer the question.
You are given the following entity summaries and facts to help you determine the answer to your question.
<ENTITY_SUMMARIES>
{json.dumps(context['entity_summaries'])}
</ENTITY_SUMMARIES>
<FACTS>
{json.dumps(context['facts'])}
</FACTS>
<QUESTION>
{context['query']}
</QUESTION>
respond with a JSON object in the following format:
{{
"ANSWER": "how Alice would answer the question"
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
def eval_prompt(context: dict[str, Any]) -> list[Message]:
    """Build the messages asking the LLM to grade a response against a gold answer.

    Parameters
    ----------
    context : dict[str, Any]
        Must contain 'query', 'answer' (the gold standard) and 'response'
        (the text being graded); each is interpolated into the prompt as-is.

    Returns
    -------
    list[Message]
        A system message followed by a user message. The user message asks for
        a JSON object with "is_correct" and "reasoning" keys.
    """
    sys_prompt = (
        """You are a judge that determines if answers to questions match a gold standard answer"""
    )

    # The format example must itself be valid JSON, so its members are comma-separated.
    user_prompt = f"""
Given the QUESTION and the gold standard ANSWER determine if the RESPONSE to the question is correct or incorrect.
Although the RESPONSE may be more verbose, mark it as correct as long as it references the same topic
as the gold standard ANSWER. Also include your reasoning for the grade.
<QUESTION>
{context['query']}
</QUESTION>
<ANSWER>
{context['answer']}
</ANSWER>
<RESPONSE>
{context['response']}
</RESPONSE>
respond with a JSON object in the following format:
{{
"is_correct": "boolean if the answer is correct or incorrect",
"reasoning": "why you determined the response was correct or incorrect"
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
# Maps each prompt name to the function in this module that builds it.
versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}

View File

@ -34,6 +34,9 @@ from .dedupe_nodes import (
from .dedupe_nodes import (
versions as dedupe_nodes_versions,
)
from .eval import Prompt as EvalPrompt
from .eval import Versions as EvalVersions
from .eval import versions as eval_versions
from .extract_edge_dates import (
Prompt as ExtractEdgeDatesPrompt,
)
@ -84,6 +87,7 @@ class PromptLibrary(Protocol):
invalidate_edges: InvalidateEdgesPrompt
extract_edge_dates: ExtractEdgeDatesPrompt
summarize_nodes: SummarizeNodesPrompt
eval: EvalPrompt
class PromptLibraryImpl(TypedDict):
@ -94,6 +98,7 @@ class PromptLibraryImpl(TypedDict):
invalidate_edges: InvalidateEdgesVersions
extract_edge_dates: ExtractEdgeDatesVersions
summarize_nodes: SummarizeNodesVersions
eval: EvalVersions
class VersionWrapper:
@ -124,5 +129,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
'invalidate_edges': invalidate_edges_versions,
'extract_edge_dates': extract_edge_dates_versions,
'summarize_nodes': summarize_nodes_versions,
'eval': eval_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]

View File

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
from collections import defaultdict
from time import time
@ -65,32 +66,20 @@ async def search(
query = query.replace('\n', ' ')
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None
edges = (
await edge_search(
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit, config.embedding_model
)
if config.edge_config is not None
else []
)
nodes = (
await node_search(
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit, config.embedding_model
)
if config.node_config is not None
else []
)
communities = (
await community_search(
driver, embedder, query, group_ids, config.community_config, config.limit, config.embedding_model
)
if config.community_config is not None
else []
edges, nodes, communities = await asyncio.gather(
edge_search(
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
),
node_search(
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
),
community_search(driver, embedder, query, group_ids, config.community_config, config.limit),
)
results = SearchResults(
edges=edges[: config.limit],
nodes=nodes[: config.limit],
communities=communities[: config.limit],
edges=edges,
nodes=nodes,
communities=communities,
)
end = time()
@ -105,11 +94,14 @@ async def edge_search(
embedder,
query: str,
group_ids: list[str] | None,
config: EdgeSearchConfig,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityEdge]:
if config is None:
return []
search_results: list[list[EntityEdge]] = []
if EdgeSearchMethod.bm25 in config.search_methods:
@ -163,7 +155,7 @@ async def edge_search(
if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
return reranked_edges
return reranked_edges[:limit]
async def node_search(
@ -171,11 +163,14 @@ async def node_search(
embedder,
query: str,
group_ids: list[str] | None,
config: NodeSearchConfig,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityNode]:
if config is None:
return []
search_results: list[list[EntityNode]] = []
if NodeSearchMethod.bm25 in config.search_methods:
@ -214,7 +209,7 @@ async def node_search(
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_nodes
return reranked_nodes[:limit]
async def community_search(
@ -222,10 +217,13 @@ async def community_search(
embedder,
query: str,
group_ids: list[str] | None,
config: CommunitySearchConfig,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[CommunityNode]:
if config is None:
return []
search_results: list[list[CommunityNode]] = []
if CommunitySearchMethod.bm25 in config.search_methods:
@ -258,4 +256,4 @@ async def community_search(
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_communities
return reranked_communities[:limit]

View File

@ -16,13 +16,13 @@ limitations under the License.
import asyncio
import logging
import re
from collections import defaultdict
from time import time
from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import lucene_sanitize
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
@ -36,6 +36,22 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
def fulltext_query(query: str, group_ids: list[str] | None = None) -> str:
    """Build the Lucene query string used by the fulltext searches.

    Parameters
    ----------
    query : str
        Raw search text. It is escaped with lucene_sanitize and suffixed with
        '~' (Lucene's fuzzy operator).
    group_ids : list[str] | None, optional
        Group ids to restrict matches to. None or an empty list adds no filter.

    Returns
    -------
    str
        The fuzzy query alone, or '(<group filter>) AND <fuzzy query>' when
        group ids are given.
    """
    fuzzy_query = lucene_sanitize(query) + '~'

    group_clauses = [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids or []]
    if not group_clauses:
        return fuzzy_query

    # Parenthesise the OR-ed group ids so that the AND applies to the whole
    # disjunction instead of binding only to the last group id.
    group_ids_filter = '(' + ' OR '.join(group_clauses) + ')'
    return group_ids_filter + ' AND ' + fuzzy_query
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
@ -91,11 +107,15 @@ async def edge_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = fulltext_query(query, group_ids)
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
AND ($group_ids IS NULL OR n.group_id IN $group_ids)
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -112,72 +132,6 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""")
if source_node_uuid is None and target_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
elif source_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
elif target_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
cypher_query,
query=fuzzy_query,
@ -202,11 +156,12 @@ async def edge_similarity_search(
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
RETURN
vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
@ -220,72 +175,9 @@ async def edge_similarity_search(
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
LIMIT $limit
""")
if source_node_uuid is None and target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif source_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
@ -307,10 +199,11 @@ async def node_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
fuzzy_query = fulltext_query(query, group_ids)
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
@ -341,11 +234,10 @@ async def node_similarity_search(
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
@ -353,6 +245,7 @@ async def node_similarity_search(
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
@ -370,7 +263,8 @@ async def community_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
# BM25 search to get top communities
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
fuzzy_query = fulltext_query(query, group_ids)
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
@ -405,11 +299,10 @@ async def community_similarity_search(
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
YIELD node AS comm, score
MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
RETURN
vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
@ -417,6 +310,7 @@ async def community_similarity_search(
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,

View File

@ -28,7 +28,16 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: AsyncDriver):
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
if delete_existing:
records, _, _ = await driver.execute_query("""
SHOW INDEXES YIELD name
""")
index_names = [record['name'] for record in records]
await asyncio.gather(
*[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names]
)
range_indices: list[LiteralString] = [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@ -52,38 +61,15 @@ async def build_indices_and_constraints(driver: AsyncDriver):
]
fulltext_indices: list[LiteralString] = [
'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
vector_indices: list[LiteralString] = [
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
FOR (n:Community) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
]
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
index_queries: list[LiteralString] = range_indices + fulltext_indices
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.6"
version = "0.3.7"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View File

@ -74,7 +74,6 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()
edges = await graphiti.search(
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
@ -96,7 +95,7 @@ async def test_graphiti_init():
}
logger.info(pretty_results)
graphiti.close()
await graphiti.close()
@pytest.mark.asyncio