search updates (#14)

* search updates

* test updates

* add opinionated search

* update
Preston Rasmussen 2024-08-22 14:26:26 -04:00 committed by GitHub
parent 8141a783b1
commit 63b9790026
9 changed files with 269 additions and 115 deletions

View File

@@ -1,15 +1,15 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, LiteralString
from typing import Callable
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
from time import time
import os
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.nodes import EntityNode, EpisodicNode
from core.edges import EntityEdge, EpisodicEdge
from core.search.search import SearchConfig, hybrid_search
from core.utils import (
build_episodic_edges,
retrieve_episodes,
@@ -19,22 +19,21 @@ from core.utils.bulk_utils (
BulkEpisode,
extract_nodes_and_edges_bulk,
retrieve_previous_episodes_bulk,
compress_nodes,
dedupe_nodes_bulk,
resolve_edge_pointers,
dedupe_edges_bulk,
)
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
)
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
invalidate_edges,
prepare_edges_for_invalidation,
)
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
bfs,
from core.search.search_utils import (
get_relevant_nodes,
get_relevant_edges,
)
@@ -64,10 +63,13 @@ class Graphiti:
def close(self):
self.driver.close()
async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)
async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int,
last_n: int = EPISODE_WINDOW_LEN,
sources: str | list[str] | None = "messages",
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
@@ -103,9 +105,7 @@ class Graphiti:
embedder = self.llm_client.client.embeddings
now = datetime.now()
previous_episodes = await self.retrieve_episodes(
reference_time, last_n=EPISODE_WINDOW_LEN
)
previous_episodes = await self.retrieve_episodes(reference_time)
episode = EpisodicNode(
name=name,
labels=[],
@@ -220,80 +220,6 @@ class Graphiti:
else:
raise e
async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
"CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
"CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE CONSTRAINT entity_name IF NOT EXISTS
FOR (n:Entity) REQUIRE n.name IS UNIQUE
""",
"""
CREATE CONSTRAINT edge_facts IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
""",
]
await asyncio.gather(
*[self.driver.execute_query(query) for query in index_queries]
)
async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
text = query.replace("\n", " ")
search_vector = (
(
await self.llm_client.client.embeddings.create(
input=[text], model="text-embedding-3-small"
)
)
.data[0]
.embedding[:EMBEDDING_DIM]
)
edges = await edge_similarity_search(search_vector, self.driver)
nodes = await entity_fulltext_search(query, self.driver)
node_ids = [node.uuid for node in nodes]
for edge in edges:
node_ids.append(edge.source_node_uuid)
node_ids.append(edge.target_node_uuid)
node_ids = list(dict.fromkeys(node_ids))
context = await bfs(node_ids, self.driver)
return context
async def add_episode_bulk(
self,
bulk_episodes: list[BulkEpisode],
@@ -368,3 +294,24 @@ class Graphiti:
except Exception as e:
raise e
async def search(self, query: str, num_results=10):
search_config = SearchConfig(num_episodes=0, num_results=num_results)
edges = (
await hybrid_search(
self.driver,
self.llm_client.client.embeddings,
query,
datetime.now(),
search_config,
)
)["edges"]
facts = [edge.fact for edge in edges]
return facts
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
return await hybrid_search(
self.driver, self.llm_client.client.embeddings, query, timestamp, config
)
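With this change, callers get ranked facts from a single call. A minimal usage sketch (placeholder connection settings; the import path for Graphiti is assumed):

import asyncio
from core.graphiti import Graphiti  # assumed module path

async def main():
    # Placeholder Neo4j credentials; None falls back to the default LLM client,
    # mirroring the integration test below.
    client = Graphiti("bolt://localhost:7687", "neo4j", "password", None)
    facts = await client.search("Freakenomics guest", num_results=5)
    for fact in facts:
        print(fact)
    client.close()

asyncio.run(main())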

View File

@@ -112,7 +112,7 @@ def node_list(context: dict[str, any]) -> list[Message]:
Task:
1. Group nodes together such that all duplicate nodes are in the same list of names
2. All dupolicate names should be grouped together in the same list
2. All duplicate names should be grouped together in the same list
Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response

core/search/search.py (new file, 104 lines)
View File

@@ -0,0 +1,104 @@
import asyncio
import logging
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from pydantic import BaseModel
from core.edges import EntityEdge, Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
edge_similarity_search,
edge_fulltext_search,
get_mentioned_nodes,
rrf,
)
from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
logger = logging.getLogger(__name__)
class SearchConfig(BaseModel):
num_results: int = 10
num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = "cosine"
text_search: str = "BM25"
reranker: str = "rrf"
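Reading hybrid_search below, each stage is gated on an exact string match, so any other value disables it. A sketch of two configurations (the non-default field values are illustrative):

# Default pipeline, but skip episode retrieval and keep 5 fused results:
config = SearchConfig(num_episodes=0, num_results=5)

# BM25 only: a non-"cosine" value disables similarity search, and with a
# single searcher the reranker branch is never reached.
config = SearchConfig(similarity_search="disabled", reranker="rrf")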
async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, list[Node | Edge]]:
start = time()
episodes = []
nodes = []
edges = []
search_results = []
if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))
if config.text_search == "BM25":
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)
if config.similarity_search == "cosine":
query_text = query.replace("\n", " ")
search_vector = (
(await embedder.create(input=[query_text], model="text-embedding-3-small"))
.data[0]
.embedding[:EMBEDDING_DIM]
)
similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search)
if len(search_results) == 1:
edges = search_results[0]
elif len(search_results) > 1 and config.reranker != "rrf":
logger.error("Multiple searches enabled without a reranker")
raise Exception("Multiple searches enabled without a reranker")
elif config.reranker == "rrf":
edge_uuid_map = {}
search_result_uuids = []
logger.info([[edge.fact for edge in result] for result in search_results])
for result in search_results:
result_uuids = []
for edge in result:
result_uuids.append(edge.uuid)
edge_uuid_map[edge.uuid] = edge
search_result_uuids.append(result_uuids)
reranked_uuids = rrf(search_result_uuids)
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)
context = {
"episodes": episodes,
"nodes": nodes,
"edges": edges,
}
end = time()
logger.info(
f"search returned context for query {query} in {(end - start) * 1000} ms"
)
return context
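Callers that want the full context rather than bare facts can invoke hybrid_search directly. A sketch, assuming a reachable Neo4j instance and an embedder exposing the same async create(...) interface as OpenAI's embeddings client:

from datetime import datetime
from neo4j import AsyncGraphDatabase
from openai import AsyncOpenAI  # assumed embedding backend

async def run_query():
    # Placeholder connection settings.
    driver = AsyncGraphDatabase.driver(
        "bolt://localhost:7687", auth=("neo4j", "password")
    )
    config = SearchConfig(num_episodes=0, num_results=25)
    context = await hybrid_search(
        driver, AsyncOpenAI().embeddings, "tania tetlow", datetime.now(), config
    )
    print(f"{len(context['edges'])} edges, {len(context['nodes'])} nodes")
    await driver.close()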

View File

@@ -1,23 +1,54 @@
import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from core.edges import EntityEdge
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
""",
uuids=episode_uuids,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=["Entity"],
created_at=datetime.now(),
summary=record["summary"],
)
)
return nodes
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN
RETURN DISTINCT
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
@@ -138,7 +169,7 @@ async def entity_similarity_search(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=[],
labels=["Entity"],
created_at=datetime.now(),
summary=record["summary"],
)
@@ -155,7 +186,7 @@ async def entity_fulltext_search(
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
@@ -173,7 +204,7 @@ async def entity_fulltext_search(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=[],
labels=["Entity"],
created_at=datetime.now(),
summary=record["summary"],
)
@@ -193,7 +224,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
@@ -291,3 +322,18 @@ async def get_relevant_edges(
)
return relevant_edges
# Reciprocal Rank Fusion: merge several rankings of uuids into one list,
# scoring each uuid by the sum of 1 / (rank + rank_const) over all rankings.
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
scores: dict[str, float] = defaultdict(float)
for result in results:
for i, uuid in enumerate(result):
scores[uuid] += 1 / (i + rank_const)
scored_uuids = list(scores.items())
scored_uuids.sort(reverse=True, key=lambda term: term[1])
sorted_uuids = [term[0] for term in scored_uuids]
return sorted_uuids
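A worked example of the fusion with the scores computed by hand (the edge uuids are hypothetical):

# Rankings as they might come back from the BM25 and cosine searches:
bm25_ranking = ["e1", "e2", "e3"]
cosine_ranking = ["e3", "e1", "e4"]

# With rank_const=1: e1 = 1/1 + 1/2 = 1.5, e3 = 1/3 + 1/1 ≈ 1.33,
# e2 = 1/2 = 0.5, e4 = 1/3 ≈ 0.33.
assert rrf([bm25_ranking, cosine_ranking]) == ["e1", "e3", "e2", "e4"]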

View File

@@ -1,5 +1,4 @@
import asyncio
from collections import defaultdict
from datetime import datetime
from neo4j import AsyncDriver
@@ -21,7 +20,7 @@ from core.utils.maintenance.node_operations import (
dedupe_node_list,
dedupe_extracted_nodes,
)
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges
from core.search.search_utils import get_relevant_nodes, get_relevant_edges
CHUNK_SIZE = 10

View File

@@ -1,4 +1,6 @@
import asyncio
from datetime import datetime, timezone
from typing import LiteralString
from core.nodes import EpisodicNode
from neo4j import AsyncDriver
@@ -9,6 +11,64 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: AsyncDriver):
constraints: list[LiteralString] = [
"""
CREATE CONSTRAINT entity_name IF NOT EXISTS
FOR (n:Entity) REQUIRE n.name IS UNIQUE
""",
"""
CREATE CONSTRAINT edge_facts IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
""",
]
range_indices: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
]
fulltext_indices: list[LiteralString] = [
"CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
"CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
]
vector_indices: list[LiteralString] = [
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
]
index_queries: list[LiteralString] = (
constraints + range_indices + fulltext_indices + vector_indices
)
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
async def clear_data(driver: AsyncDriver):
async with driver.session() as session:
@@ -21,7 +81,7 @@ async def clear_data(driver: AsyncDriver):
async def retrieve_episodes(
driver: AsyncDriver,
reference_time: datetime,
last_n: int,
last_n: int = EPISODE_WINDOW_LEN,
sources: str | list[str] | None = "messages",
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""

View File

@@ -61,7 +61,7 @@ async def main(use_bulk: bool = True):
episode_type="string",
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:7])
for i, message in enumerate(messages[3:14])
]
await client.add_episode_bulk(episodes)

View File

@@ -4,6 +4,8 @@ import os
import pytest
from core.search.search import SearchConfig
pytestmark = pytest.mark.integration
import asyncio
@@ -51,16 +53,13 @@ def setup_logging():
return logger
def format_context(context):
def format_context(facts):
formatted_string = ""
for uuid, data in context.items():
formatted_string += f"UUID: {uuid}\n"
formatted_string += f" Name: {data['name']}\n"
formatted_string += f" Summary: {data['summary']}\n"
formatted_string += " Facts:\n"
for fact in data["facts"]:
formatted_string += f" - {fact}\n"
formatted_string += "\n"
formatted_string += "FACTS:\n"
for fact in facts:
formatted_string += f" - {fact}\n"
formatted_string += "\n"
return formatted_string.strip()
@@ -68,19 +67,18 @@ def format_context(context):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
await graphiti.build_indices()
context = await graphiti.search("Freakenomics guest")
facts = await graphiti.search("Freakenomics guest")
logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))
logger.info("\nQUERY: Freakenomics guest\n" + format_context(facts))
context = await graphiti.search("tania tetlow")
facts = await graphiti.search("tania tetlow\n")
logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))
logger.info("\nQUERY: Tania Tetlow\n" + format_context(facts))
context = await graphiti.search("issues with higher ed")
facts = await graphiti.search("issues with higher ed")
logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
logger.info("\nQUERY: issues with higher ed\n" + format_context(facts))
graphiti.close()