mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
search updates (#14)
* search updates * test updates * add opinionated search * update
This commit is contained in:
parent
8141a783b1
commit
63b9790026
123
core/graphiti.py
123
core/graphiti.py
@ -1,15 +1,15 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Callable, LiteralString
|
||||
from typing import Callable
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from dotenv import load_dotenv
|
||||
from time import time
|
||||
import os
|
||||
|
||||
from core.llm_client.config import EMBEDDING_DIM
|
||||
from core.nodes import EntityNode, EpisodicNode, Node
|
||||
from core.edges import EntityEdge, Edge, EpisodicEdge
|
||||
from core.nodes import EntityNode, EpisodicNode
|
||||
from core.edges import EntityEdge, EpisodicEdge
|
||||
from core.search.search import SearchConfig, hybrid_search
|
||||
from core.utils import (
|
||||
build_episodic_edges,
|
||||
retrieve_episodes,
|
||||
@ -19,22 +19,21 @@ from core.utils.bulk_utils import (
|
||||
BulkEpisode,
|
||||
extract_nodes_and_edges_bulk,
|
||||
retrieve_previous_episodes_bulk,
|
||||
compress_nodes,
|
||||
dedupe_nodes_bulk,
|
||||
resolve_edge_pointers,
|
||||
dedupe_edges_bulk,
|
||||
)
|
||||
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
|
||||
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
||||
from core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
build_indices_and_constraints,
|
||||
)
|
||||
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
invalidate_edges,
|
||||
prepare_edges_for_invalidation,
|
||||
)
|
||||
from core.utils.search.search_utils import (
|
||||
edge_similarity_search,
|
||||
entity_fulltext_search,
|
||||
bfs,
|
||||
from core.search.search_utils import (
|
||||
get_relevant_nodes,
|
||||
get_relevant_edges,
|
||||
)
|
||||
@ -64,10 +63,13 @@ class Graphiti:
|
||||
def close(self):
|
||||
self.driver.close()
|
||||
|
||||
async def build_indices_and_constraints(self):
|
||||
await build_indices_and_constraints(self.driver)
|
||||
|
||||
async def retrieve_episodes(
|
||||
self,
|
||||
reference_time: datetime,
|
||||
last_n: int,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
sources: list[str] | None = "messages",
|
||||
) -> list[EpisodicNode]:
|
||||
"""Retrieve the last n episodic nodes from the graph"""
|
||||
@ -103,9 +105,7 @@ class Graphiti:
|
||||
embedder = self.llm_client.client.embeddings
|
||||
now = datetime.now()
|
||||
|
||||
previous_episodes = await self.retrieve_episodes(
|
||||
reference_time, last_n=EPISODE_WINDOW_LEN
|
||||
)
|
||||
previous_episodes = await self.retrieve_episodes(reference_time)
|
||||
episode = EpisodicNode(
|
||||
name=name,
|
||||
labels=[],
|
||||
@ -220,80 +220,6 @@ class Graphiti:
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def build_indices(self):
|
||||
index_queries: list[LiteralString] = [
|
||||
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
|
||||
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
|
||||
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
|
||||
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
|
||||
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
|
||||
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
|
||||
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
|
||||
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
|
||||
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
|
||||
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
|
||||
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
|
||||
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
|
||||
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
|
||||
"CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
|
||||
"CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
|
||||
"""
|
||||
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
|
||||
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
"""
|
||||
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
|
||||
FOR (n:Entity) ON (n.name_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
"""
|
||||
CREATE CONSTRAINT entity_name IF NOT EXISTS
|
||||
FOR (n:Entity) REQUIRE n.name IS UNIQUE
|
||||
""",
|
||||
"""
|
||||
CREATE CONSTRAINT edge_facts IF NOT EXISTS
|
||||
FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
|
||||
""",
|
||||
]
|
||||
|
||||
await asyncio.gather(
|
||||
*[self.driver.execute_query(query) for query in index_queries]
|
||||
)
|
||||
|
||||
async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
|
||||
text = query.replace("\n", " ")
|
||||
search_vector = (
|
||||
(
|
||||
await self.llm_client.client.embeddings.create(
|
||||
input=[text], model="text-embedding-3-small"
|
||||
)
|
||||
)
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
|
||||
edges = await edge_similarity_search(search_vector, self.driver)
|
||||
nodes = await entity_fulltext_search(query, self.driver)
|
||||
|
||||
node_ids = [node.uuid for node in nodes]
|
||||
|
||||
for edge in edges:
|
||||
node_ids.append(edge.source_node_uuid)
|
||||
node_ids.append(edge.target_node_uuid)
|
||||
|
||||
node_ids = list(dict.fromkeys(node_ids))
|
||||
|
||||
context = await bfs(node_ids, self.driver)
|
||||
|
||||
return context
|
||||
|
||||
async def add_episode_bulk(
|
||||
self,
|
||||
bulk_episodes: list[BulkEpisode],
|
||||
@ -368,3 +294,24 @@ class Graphiti:
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def search(self, query: str, num_results=10):
|
||||
search_config = SearchConfig(num_episodes=0, num_results=num_results)
|
||||
edges = (
|
||||
await hybrid_search(
|
||||
self.driver,
|
||||
self.llm_client.client.embeddings,
|
||||
query,
|
||||
datetime.now(),
|
||||
search_config,
|
||||
)
|
||||
)["edges"]
|
||||
|
||||
facts = [edge.fact for edge in edges]
|
||||
|
||||
return facts
|
||||
|
||||
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
|
||||
return await hybrid_search(
|
||||
self.driver, self.llm_client.client.embeddings, query, timestamp, config
|
||||
)
|
||||
|
@ -112,7 +112,7 @@ def node_list(context: dict[str, any]) -> list[Message]:
|
||||
|
||||
Task:
|
||||
1. Group nodes together such that all duplicate nodes are in the same list of names
|
||||
2. All dupolicate names should be grouped together in the same list
|
||||
2. All duplicate names should be grouped together in the same list
|
||||
|
||||
Guidelines:
|
||||
1. Each name from the list of nodes should appear EXACTLY once in your response
|
||||
|
104
core/search/search.py
Normal file
104
core/search/search.py
Normal file
@ -0,0 +1,104 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.edges import EntityEdge, Edge
|
||||
from core.llm_client.config import EMBEDDING_DIM
|
||||
from core.nodes import Node
|
||||
from core.search.search_utils import (
|
||||
edge_similarity_search,
|
||||
edge_fulltext_search,
|
||||
get_mentioned_nodes,
|
||||
rrf,
|
||||
)
|
||||
from core.utils import retrieve_episodes
|
||||
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
num_results: int = 10
|
||||
num_episodes: int = EPISODE_WINDOW_LEN
|
||||
similarity_search: str = "cosine"
|
||||
text_search: str = "BM25"
|
||||
reranker: str = "rrf"
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
|
||||
) -> dict[str, [Node | Edge]]:
|
||||
start = time()
|
||||
|
||||
episodes = []
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
search_results = []
|
||||
|
||||
if config.num_episodes > 0:
|
||||
episodes.extend(await retrieve_episodes(driver, timestamp))
|
||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
||||
|
||||
if config.text_search == "BM25":
|
||||
text_search = await edge_fulltext_search(query, driver)
|
||||
search_results.append(text_search)
|
||||
|
||||
if config.similarity_search == "cosine":
|
||||
query_text = query.replace("\n", " ")
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query_text], model="text-embedding-3-small"))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
|
||||
similarity_search = await edge_similarity_search(search_vector, driver)
|
||||
search_results.append(similarity_search)
|
||||
|
||||
if len(search_results) == 1:
|
||||
edges = search_results[0]
|
||||
|
||||
elif len(search_results) > 1 and not config.reranker == "rrf":
|
||||
logger.exception("Multiple searches enabled without a reranker")
|
||||
raise Exception("Multiple searches enabled without a reranker")
|
||||
|
||||
elif config.reranker == "rrf":
|
||||
edge_uuid_map = {}
|
||||
search_result_uuids = []
|
||||
|
||||
logger.info([[edge.fact for edge in result] for result in search_results])
|
||||
|
||||
for result in search_results:
|
||||
result_uuids = []
|
||||
for edge in result:
|
||||
result_uuids.append(edge.uuid)
|
||||
edge_uuid_map[edge.uuid] = edge
|
||||
|
||||
search_result_uuids.append(result_uuids)
|
||||
|
||||
search_result_uuids = [
|
||||
[edge.uuid for edge in result] for result in search_results
|
||||
]
|
||||
|
||||
reranked_uuids = rrf(search_result_uuids)
|
||||
|
||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
edges.extend(reranked_edges)
|
||||
|
||||
context = {
|
||||
"episodes": episodes,
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
}
|
||||
|
||||
end = time()
|
||||
|
||||
logger.info(
|
||||
f"search returned context for query {query} in {(end - start) * 1000} ms"
|
||||
)
|
||||
|
||||
return context
|
@ -1,23 +1,54 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
|
||||
from core.edges import EntityEdge
|
||||
from core.nodes import EntityNode
|
||||
from core.nodes import EntityNode, EpisodicNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RELEVANT_SCHEMA_LIMIT = 3
|
||||
|
||||
|
||||
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
""",
|
||||
uuids=episode_uuids,
|
||||
)
|
||||
|
||||
nodes: list[EntityNode] = []
|
||||
|
||||
for record in records:
|
||||
nodes.append(
|
||||
EntityNode(
|
||||
uuid=record["uuid"],
|
||||
name=record["name"],
|
||||
labels=["Entity"],
|
||||
created_at=datetime.now(),
|
||||
summary=record["summary"],
|
||||
)
|
||||
)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def bfs(node_ids: list[str], driver: AsyncDriver):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
|
||||
RETURN
|
||||
RETURN DISTINCT
|
||||
n.uuid AS source_node_uuid,
|
||||
n.name AS source_name,
|
||||
n.summary AS source_summary,
|
||||
@ -138,7 +169,7 @@ async def entity_similarity_search(
|
||||
EntityNode(
|
||||
uuid=record["uuid"],
|
||||
name=record["name"],
|
||||
labels=[],
|
||||
labels=["Entity"],
|
||||
created_at=datetime.now(),
|
||||
summary=record["summary"],
|
||||
)
|
||||
@ -155,7 +186,7 @@ async def entity_fulltext_search(
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
|
||||
RETURN
|
||||
RETURN
|
||||
node.uuid As uuid,
|
||||
node.name AS name,
|
||||
node.created_at AS created_at,
|
||||
@ -173,7 +204,7 @@ async def entity_fulltext_search(
|
||||
EntityNode(
|
||||
uuid=record["uuid"],
|
||||
name=record["name"],
|
||||
labels=[],
|
||||
labels=["Entity"],
|
||||
created_at=datetime.now(),
|
||||
summary=record["summary"],
|
||||
)
|
||||
@ -193,7 +224,7 @@ async def edge_fulltext_search(
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS r, score
|
||||
MATCH (n:Entity)-[r]->(m:Entity)
|
||||
RETURN
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
@ -291,3 +322,18 @@ async def get_relevant_edges(
|
||||
)
|
||||
|
||||
return relevant_edges
|
||||
|
||||
|
||||
# takes in a list of rankings of uuids
|
||||
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||
scores: dict[str, int] = defaultdict(int)
|
||||
for result in results:
|
||||
for i, uuid in enumerate(result):
|
||||
scores[uuid] += 1 / (i + rank_const)
|
||||
|
||||
scored_uuids = [term for term in scores.items()]
|
||||
scored_uuids.sort(reverse=True, key=lambda term: term[1])
|
||||
|
||||
sorted_uuids = [term[0] for term in scored_uuids]
|
||||
|
||||
return sorted_uuids
|
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
@ -21,7 +20,7 @@ from core.utils.maintenance.node_operations import (
|
||||
dedupe_node_list,
|
||||
dedupe_extracted_nodes,
|
||||
)
|
||||
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges
|
||||
from core.search.search_utils import get_relevant_nodes, get_relevant_edges
|
||||
|
||||
CHUNK_SIZE = 10
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import LiteralString
|
||||
|
||||
from core.nodes import EpisodicNode
|
||||
from neo4j import AsyncDriver
|
||||
@ -9,6 +11,64 @@ EPISODE_WINDOW_LEN = 3
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def build_indices_and_constraints(driver: AsyncDriver):
|
||||
constraints: list[LiteralString] = [
|
||||
"""
|
||||
CREATE CONSTRAINT entity_name IF NOT EXISTS
|
||||
FOR (n:Entity) REQUIRE n.name IS UNIQUE
|
||||
""",
|
||||
"""
|
||||
CREATE CONSTRAINT edge_facts IF NOT EXISTS
|
||||
FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
|
||||
""",
|
||||
]
|
||||
|
||||
range_indices: list[LiteralString] = [
|
||||
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
|
||||
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
|
||||
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
|
||||
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
|
||||
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
|
||||
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
|
||||
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
|
||||
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
|
||||
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
|
||||
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
|
||||
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
|
||||
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
|
||||
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
|
||||
]
|
||||
|
||||
fulltext_indices: list[LiteralString] = [
|
||||
"CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
|
||||
"CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
|
||||
]
|
||||
|
||||
vector_indices: list[LiteralString] = [
|
||||
"""
|
||||
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
|
||||
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
"""
|
||||
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
|
||||
FOR (n:Entity) ON (n.name_embedding)
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""",
|
||||
]
|
||||
index_queries: list[LiteralString] = (
|
||||
constraints + range_indices + fulltext_indices + vector_indices
|
||||
)
|
||||
|
||||
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
|
||||
|
||||
|
||||
async def clear_data(driver: AsyncDriver):
|
||||
async with driver.session() as session:
|
||||
|
||||
@ -21,7 +81,7 @@ async def clear_data(driver: AsyncDriver):
|
||||
async def retrieve_episodes(
|
||||
driver: AsyncDriver,
|
||||
reference_time: datetime,
|
||||
last_n: int,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
sources: list[str] | None = "messages",
|
||||
) -> list[EpisodicNode]:
|
||||
"""Retrieve the last n episodic nodes from the graph"""
|
||||
|
@ -61,7 +61,7 @@ async def main(use_bulk: bool = True):
|
||||
episode_type="string",
|
||||
reference_time=message.actual_timestamp,
|
||||
)
|
||||
for i, message in enumerate(messages[3:7])
|
||||
for i, message in enumerate(messages[3:14])
|
||||
]
|
||||
|
||||
await client.add_episode_bulk(episodes)
|
||||
|
@ -4,6 +4,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.search.search import SearchConfig
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import asyncio
|
||||
@ -51,16 +53,13 @@ def setup_logging():
|
||||
return logger
|
||||
|
||||
|
||||
def format_context(context):
|
||||
def format_context(facts):
|
||||
formatted_string = ""
|
||||
for uuid, data in context.items():
|
||||
formatted_string += f"UUID: {uuid}\n"
|
||||
formatted_string += f" Name: {data['name']}\n"
|
||||
formatted_string += f" Summary: {data['summary']}\n"
|
||||
formatted_string += " Facts:\n"
|
||||
for fact in data["facts"]:
|
||||
formatted_string += f" - {fact}\n"
|
||||
formatted_string += "\n"
|
||||
formatted_string += "FACTS:\n"
|
||||
for fact in facts:
|
||||
formatted_string += f" - {fact}\n"
|
||||
formatted_string += "\n"
|
||||
|
||||
return formatted_string.strip()
|
||||
|
||||
|
||||
@ -68,19 +67,18 @@ def format_context(context):
|
||||
async def test_graphiti_init():
|
||||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
|
||||
await graphiti.build_indices()
|
||||
|
||||
context = await graphiti.search("Freakenomics guest")
|
||||
facts = await graphiti.search("Freakenomics guest")
|
||||
|
||||
logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))
|
||||
logger.info("\nQUERY: Freakenomics guest\n" + format_context(facts))
|
||||
|
||||
context = await graphiti.search("tania tetlow")
|
||||
facts = await graphiti.search("tania tetlow\n")
|
||||
|
||||
logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))
|
||||
logger.info("\nQUERY: Tania Tetlow\n" + format_context(facts))
|
||||
|
||||
context = await graphiti.search("issues with higher ed")
|
||||
facts = await graphiti.search("issues with higher ed")
|
||||
|
||||
logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
|
||||
logger.info("\nQUERY: issues with higher ed\n" + format_context(facts))
|
||||
graphiti.close()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user