2024-08-23 13:01:33 -07:00
|
|
|
"""
|
|
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
"""
|
|
|
|
|
2024-08-13 14:35:43 -04:00
|
|
|
import asyncio
|
|
|
|
import logging
|
2024-08-22 12:26:13 -07:00
|
|
|
from datetime import datetime
|
|
|
|
from time import time
|
|
|
|
|
2024-08-15 12:03:41 -04:00
|
|
|
from dotenv import load_dotenv
|
2024-08-22 12:26:13 -07:00
|
|
|
from neo4j import AsyncGraphDatabase
|
2024-08-18 13:22:31 -04:00
|
|
|
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
2024-08-26 10:13:05 -07:00
|
|
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
2024-08-26 10:30:22 -04:00
|
|
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
2024-09-16 14:03:05 -04:00
|
|
|
from graphiti_core.search.search import SearchConfig, search
|
|
|
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
|
|
from graphiti_core.search.search_config_recipes import (
|
|
|
|
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
|
|
EDGE_HYBRID_SEARCH_RRF,
|
|
|
|
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
|
|
NODE_HYBRID_SEARCH_RRF,
|
|
|
|
)
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.search.search_utils import (
|
2024-08-30 10:48:28 -04:00
|
|
|
RELEVANT_SCHEMA_LIMIT,
|
2024-08-23 14:18:45 -04:00
|
|
|
get_relevant_edges,
|
|
|
|
get_relevant_nodes,
|
2024-08-22 12:26:13 -07:00
|
|
|
)
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.utils import (
|
2024-08-23 14:18:45 -04:00
|
|
|
build_episodic_edges,
|
|
|
|
retrieve_episodes,
|
2024-08-15 12:03:41 -04:00
|
|
|
)
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.utils.bulk_utils import (
|
2024-08-26 10:30:22 -04:00
|
|
|
RawEpisode,
|
2024-08-23 14:18:45 -04:00
|
|
|
dedupe_edges_bulk,
|
|
|
|
dedupe_nodes_bulk,
|
2024-08-30 10:48:28 -04:00
|
|
|
extract_edge_dates_bulk,
|
2024-08-23 14:18:45 -04:00
|
|
|
extract_nodes_and_edges_bulk,
|
|
|
|
resolve_edge_pointers,
|
|
|
|
retrieve_previous_episodes_bulk,
|
2024-08-20 16:29:19 -04:00
|
|
|
)
|
2024-09-11 12:06:35 -04:00
|
|
|
from graphiti_core.utils.maintenance.community_operations import (
|
|
|
|
build_communities,
|
|
|
|
remove_communities,
|
|
|
|
)
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.utils.maintenance.edge_operations import (
|
2024-08-23 14:18:45 -04:00
|
|
|
extract_edges,
|
2024-09-03 13:25:52 -04:00
|
|
|
resolve_extracted_edges,
|
2024-08-22 18:09:44 -04:00
|
|
|
)
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
2024-08-23 14:18:45 -04:00
|
|
|
EPISODE_WINDOW_LEN,
|
|
|
|
build_indices_and_constraints,
|
2024-08-22 14:26:26 -04:00
|
|
|
)
|
2024-09-03 13:25:52 -04:00
|
|
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
|
|
extract_nodes,
|
|
|
|
resolve_extracted_nodes,
|
|
|
|
)
|
2024-08-13 14:35:43 -04:00
|
|
|
|
|
|
|
# Module-level logger; configuration (handlers, level) is left to the host app.
logger = logging.getLogger(__name__)

# Load environment variables (e.g. OPENAI_API_KEY for the default LLM client)
# from a local .env file, if present.
load_dotenv()
|
2024-08-14 10:17:12 -04:00
|
|
|
|
|
|
|
|
2024-08-13 14:35:43 -04:00
|
|
|
class Graphiti:
    def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
        """
        Create a Graphiti instance backed by a Neo4j database.

        Opens an async driver connection to Neo4j and selects the LLM client
        used for natural-language processing tasks.

        Parameters
        ----------
        uri : str
            The URI of the Neo4j database.
        user : str
            The username for authenticating with the Neo4j database.
        password : str
            The password for authenticating with the Neo4j database.
        llm_client : LLMClient | None, optional
            Client used for LLM calls. When omitted, a default OpenAIClient is
            created, which expects OPENAI_API_KEY to be set in the environment.

        Returns
        -------
        None

        Notes
        -----
        The database name defaults to 'neo4j'; set it separately after
        initialization if a different database is required.
        """
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
        self.database = 'neo4j'
        # Fall back to a default OpenAI-backed client when none is supplied.
        self.llm_client = llm_client or OpenAIClient()
|
2024-08-23 14:18:45 -04:00
|
|
|
|
|
|
|
def close(self):
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Close the connection to the Neo4j database.
|
|
|
|
|
|
|
|
This method safely closes the driver connection to the Neo4j database.
|
|
|
|
It should be called when the Graphiti instance is no longer needed or
|
|
|
|
when the application is shutting down.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
2024-09-06 12:33:42 -04:00
|
|
|
self
|
2024-08-26 13:11:50 -04:00
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
None
|
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
It's important to close the driver connection to release system resources
|
|
|
|
and ensure that all pending transactions are completed or rolled back.
|
|
|
|
This method should be called as part of a cleanup process, potentially
|
|
|
|
in a context manager or a shutdown hook.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
graphiti = Graphiti(uri, user, password)
|
|
|
|
try:
|
|
|
|
# Use graphiti...
|
|
|
|
finally:
|
|
|
|
graphiti.close()
|
2024-08-23 14:18:45 -04:00
|
|
|
self.driver.close()
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
2024-08-23 14:18:45 -04:00
|
|
|
|
|
|
|
async def build_indices_and_constraints(self):
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Build indices and constraints in the Neo4j database.
|
|
|
|
|
|
|
|
This method sets up the necessary indices and constraints in the Neo4j database
|
|
|
|
to optimize query performance and ensure data integrity for the knowledge graph.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
2024-09-06 12:33:42 -04:00
|
|
|
self
|
2024-08-26 13:11:50 -04:00
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
None
|
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
This method should typically be called once during the initial setup of the
|
|
|
|
knowledge graph or when updating the database schema. It uses the
|
|
|
|
`build_indices_and_constraints` function from the
|
|
|
|
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
|
|
|
the actual database operations.
|
|
|
|
|
|
|
|
The specific indices and constraints created depend on the implementation
|
|
|
|
of the `build_indices_and_constraints` function. Refer to that function's
|
|
|
|
documentation for details on the exact database schema modifications.
|
|
|
|
|
|
|
|
Caution: Running this method on a large existing database may take some time
|
|
|
|
and could impact database performance during execution.
|
|
|
|
"""
|
2024-08-23 14:18:45 -04:00
|
|
|
await build_indices_and_constraints(self.driver)
|
|
|
|
|
|
|
|
async def retrieve_episodes(
|
2024-09-04 10:05:45 -04:00
|
|
|
self,
|
|
|
|
reference_time: datetime,
|
|
|
|
last_n: int = EPISODE_WINDOW_LEN,
|
2024-09-06 12:33:42 -04:00
|
|
|
group_ids: list[str | None] | None = None,
|
2024-08-23 14:18:45 -04:00
|
|
|
) -> list[EpisodicNode]:
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Retrieve the last n episodic nodes from the graph.
|
|
|
|
|
|
|
|
This method fetches a specified number of the most recent episodic nodes
|
|
|
|
from the graph, relative to the given reference time.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
|
|
|
reference_time : datetime
|
|
|
|
The reference time to retrieve episodes before.
|
|
|
|
last_n : int, optional
|
|
|
|
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
|
2024-09-06 12:33:42 -04:00
|
|
|
group_ids : list[str | None], optional
|
|
|
|
The group ids to return data from.
|
2024-08-26 13:11:50 -04:00
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
list[EpisodicNode]
|
|
|
|
A list of the most recent EpisodicNode objects.
|
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
|
|
from the `graphiti_core.utils` module.
|
|
|
|
"""
|
2024-09-06 12:33:42 -04:00
|
|
|
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
2024-08-23 14:18:45 -04:00
|
|
|
|
|
|
|
async def add_episode(
|
2024-09-04 10:05:45 -04:00
|
|
|
self,
|
|
|
|
name: str,
|
|
|
|
episode_body: str,
|
|
|
|
source_description: str,
|
|
|
|
reference_time: datetime,
|
|
|
|
source: EpisodeType = EpisodeType.message,
|
2024-09-06 12:33:42 -04:00
|
|
|
group_id: str | None = None,
|
|
|
|
uuid: str | None = None,
|
2024-08-23 14:18:45 -04:00
|
|
|
):
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Process an episode and update the graph.
|
|
|
|
|
|
|
|
This method extracts information from the episode, creates nodes and edges,
|
|
|
|
and updates the graph database accordingly.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
|
|
|
name : str
|
|
|
|
The name of the episode.
|
|
|
|
episode_body : str
|
|
|
|
The content of the episode.
|
|
|
|
source_description : str
|
|
|
|
A description of the episode's source.
|
|
|
|
reference_time : datetime
|
|
|
|
The reference time for the episode.
|
|
|
|
source : EpisodeType, optional
|
|
|
|
The type of the episode. Defaults to EpisodeType.message.
|
2024-09-06 12:33:42 -04:00
|
|
|
group_id : str | None
|
|
|
|
An id for the graph partition the episode is a part of.
|
|
|
|
uuid : str | None
|
|
|
|
Optional uuid of the episode.
|
2024-08-26 13:11:50 -04:00
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
None
|
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
This method performs several steps including node extraction, edge extraction,
|
|
|
|
deduplication, and database updates. It also handles embedding generation
|
|
|
|
and edge invalidation.
|
|
|
|
|
|
|
|
It is recommended to run this method as a background process, such as in a queue.
|
|
|
|
It's important that each episode is added sequentially and awaited before adding
|
|
|
|
the next one. For web applications, consider using FastAPI's background tasks
|
|
|
|
or a dedicated task queue like Celery for this purpose.
|
|
|
|
|
|
|
|
Example using FastAPI background tasks:
|
|
|
|
@app.post("/add_episode")
|
|
|
|
async def add_episode_endpoint(episode_data: EpisodeData):
|
|
|
|
background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
|
|
|
|
return {"message": "Episode processing started"}
|
|
|
|
"""
|
2024-08-23 14:18:45 -04:00
|
|
|
try:
|
|
|
|
start = time()
|
|
|
|
|
|
|
|
nodes: list[EntityNode] = []
|
|
|
|
entity_edges: list[EntityEdge] = []
|
|
|
|
embedder = self.llm_client.get_embedder()
|
|
|
|
now = datetime.now()
|
|
|
|
|
2024-09-06 12:33:42 -04:00
|
|
|
previous_episodes = await self.retrieve_episodes(
|
|
|
|
reference_time, last_n=3, group_ids=[group_id]
|
|
|
|
)
|
2024-08-23 14:18:45 -04:00
|
|
|
episode = EpisodicNode(
|
|
|
|
name=name,
|
2024-09-06 12:33:42 -04:00
|
|
|
group_id=group_id,
|
2024-08-23 14:18:45 -04:00
|
|
|
labels=[],
|
2024-08-26 10:30:22 -04:00
|
|
|
source=source,
|
2024-08-23 14:18:45 -04:00
|
|
|
content=episode_body,
|
|
|
|
source_description=source_description,
|
|
|
|
created_at=now,
|
|
|
|
valid_at=reference_time,
|
|
|
|
)
|
2024-09-06 12:33:42 -04:00
|
|
|
episode.uuid = uuid if uuid is not None else episode.uuid
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-09-03 13:25:52 -04:00
|
|
|
# Extract entities as nodes
|
|
|
|
|
2024-08-23 14:18:45 -04:00
|
|
|
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
|
|
|
|
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
|
|
|
|
|
|
# Calculate Embeddings
|
|
|
|
|
|
|
|
await asyncio.gather(
|
|
|
|
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
|
|
|
)
|
2024-09-03 13:25:52 -04:00
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
# Resolve extracted nodes with nodes already in the graph and extract facts
|
2024-09-03 13:25:52 -04:00
|
|
|
existing_nodes_lists: list[list[EntityNode]] = list(
|
|
|
|
await asyncio.gather(
|
|
|
|
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
2024-08-23 14:18:45 -04:00
|
|
|
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
2024-09-03 13:25:52 -04:00
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
|
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
2024-09-06 12:33:42 -04:00
|
|
|
extract_edges(
|
|
|
|
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
|
|
),
|
2024-08-23 14:18:45 -04:00
|
|
|
)
|
2024-09-03 13:25:52 -04:00
|
|
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
|
|
nodes.extend(mentioned_nodes)
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
|
|
|
extracted_edges, uuid_map
|
2024-08-23 14:18:45 -04:00
|
|
|
)
|
|
|
|
|
2024-09-03 13:25:52 -04:00
|
|
|
# calculate embeddings
|
2024-09-05 12:05:44 -04:00
|
|
|
await asyncio.gather(
|
|
|
|
*[
|
|
|
|
edge.generate_embedding(embedder)
|
|
|
|
for edge in extracted_edges_with_resolved_pointers
|
|
|
|
]
|
|
|
|
)
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
# Resolve extracted edges with related edges already in the graph
|
|
|
|
related_edges_list: list[list[EntityEdge]] = list(
|
2024-09-03 13:25:52 -04:00
|
|
|
await asyncio.gather(
|
|
|
|
*[
|
|
|
|
get_relevant_edges(
|
|
|
|
self.driver,
|
2024-09-04 10:05:45 -04:00
|
|
|
[edge],
|
2024-09-03 13:25:52 -04:00
|
|
|
edge.source_node_uuid,
|
|
|
|
edge.target_node_uuid,
|
2024-09-04 10:05:45 -04:00
|
|
|
RELEVANT_SCHEMA_LIMIT,
|
2024-09-03 13:25:52 -04:00
|
|
|
)
|
2024-09-05 12:05:44 -04:00
|
|
|
for edge in extracted_edges_with_resolved_pointers
|
2024-09-03 13:25:52 -04:00
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
logger.info(
|
2024-09-05 12:05:44 -04:00
|
|
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
2024-09-03 13:25:52 -04:00
|
|
|
)
|
2024-09-05 12:05:44 -04:00
|
|
|
logger.info(
|
|
|
|
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
2024-08-23 14:18:45 -04:00
|
|
|
)
|
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
existing_source_edges_list: list[list[EntityEdge]] = list(
|
|
|
|
await asyncio.gather(
|
|
|
|
*[
|
|
|
|
get_relevant_edges(
|
|
|
|
self.driver,
|
|
|
|
[edge],
|
|
|
|
edge.source_node_uuid,
|
|
|
|
None,
|
|
|
|
RELEVANT_SCHEMA_LIMIT,
|
|
|
|
)
|
|
|
|
for edge in extracted_edges_with_resolved_pointers
|
|
|
|
]
|
|
|
|
)
|
2024-09-03 13:25:52 -04:00
|
|
|
)
|
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
existing_target_edges_list: list[list[EntityEdge]] = list(
|
|
|
|
await asyncio.gather(
|
|
|
|
*[
|
|
|
|
get_relevant_edges(
|
|
|
|
self.driver,
|
|
|
|
[edge],
|
|
|
|
None,
|
|
|
|
edge.target_node_uuid,
|
|
|
|
RELEVANT_SCHEMA_LIMIT,
|
|
|
|
)
|
|
|
|
for edge in extracted_edges_with_resolved_pointers
|
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
2024-09-03 13:25:52 -04:00
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
existing_edges_list: list[list[EntityEdge]] = [
|
|
|
|
source_lst + target_lst
|
|
|
|
for source_lst, target_lst in zip(
|
|
|
|
existing_source_edges_list, existing_target_edges_list
|
|
|
|
)
|
2024-09-03 13:25:52 -04:00
|
|
|
]
|
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
2024-08-23 14:18:45 -04:00
|
|
|
self.llm_client,
|
2024-09-05 12:05:44 -04:00
|
|
|
extracted_edges_with_resolved_pointers,
|
|
|
|
related_edges_list,
|
|
|
|
existing_edges_list,
|
2024-08-23 14:18:45 -04:00
|
|
|
episode,
|
|
|
|
previous_episodes,
|
|
|
|
)
|
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
entity_edges.extend(resolved_edges + invalidated_edges)
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-09-05 12:05:44 -04:00
|
|
|
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-09-06 12:33:42 -04:00
|
|
|
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
2024-09-03 13:25:52 -04:00
|
|
|
|
2024-08-23 14:18:45 -04:00
|
|
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
|
|
|
|
|
|
|
# Future optimization would be using batch operations to save nodes and edges
|
|
|
|
await episode.save(self.driver)
|
|
|
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
|
|
|
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
|
|
|
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
|
|
|
|
|
|
|
end = time()
|
2024-08-26 10:30:22 -04:00
|
|
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
2024-09-03 13:25:52 -04:00
|
|
|
|
2024-08-23 14:18:45 -04:00
|
|
|
except Exception as e:
|
2024-09-06 12:33:42 -04:00
|
|
|
raise e
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-09-09 19:12:59 -07:00
|
|
|
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Process multiple episodes in bulk and update the graph.
|
|
|
|
|
|
|
|
This method extracts information from multiple episodes, creates nodes and edges,
|
|
|
|
and updates the graph database accordingly, all in a single batch operation.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
|
|
|
bulk_episodes : list[RawEpisode]
|
|
|
|
A list of RawEpisode objects to be processed and added to the graph.
|
2024-09-06 12:33:42 -04:00
|
|
|
group_id : str | None
|
|
|
|
An id for the graph partition the episode is a part of.
|
2024-08-26 13:11:50 -04:00
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
None
|
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
This method performs several steps including:
|
|
|
|
- Saving all episodes to the database
|
|
|
|
- Retrieving previous episode context for each new episode
|
|
|
|
- Extracting nodes and edges from all episodes
|
|
|
|
- Generating embeddings for nodes and edges
|
|
|
|
- Deduplicating nodes and edges
|
|
|
|
- Saving nodes, episodic edges, and entity edges to the knowledge graph
|
|
|
|
|
|
|
|
This bulk operation is designed for efficiency when processing multiple episodes
|
|
|
|
at once. However, it's important to ensure that the bulk operation doesn't
|
|
|
|
overwhelm system resources. Consider implementing rate limiting or chunking for
|
|
|
|
very large batches of episodes.
|
|
|
|
|
|
|
|
Important: This method does not perform edge invalidation or date extraction steps.
|
|
|
|
If these operations are required, use the `add_episode` method instead for each
|
|
|
|
individual episode.
|
|
|
|
"""
|
2024-08-23 14:18:45 -04:00
|
|
|
try:
|
|
|
|
start = time()
|
|
|
|
embedder = self.llm_client.get_embedder()
|
|
|
|
now = datetime.now()
|
|
|
|
|
|
|
|
episodes = [
|
|
|
|
EpisodicNode(
|
|
|
|
name=episode.name,
|
|
|
|
labels=[],
|
2024-08-26 10:30:22 -04:00
|
|
|
source=episode.source,
|
2024-08-23 14:18:45 -04:00
|
|
|
content=episode.content,
|
|
|
|
source_description=episode.source_description,
|
2024-09-06 12:33:42 -04:00
|
|
|
group_id=group_id,
|
2024-08-23 14:18:45 -04:00
|
|
|
created_at=now,
|
|
|
|
valid_at=episode.reference_time,
|
|
|
|
)
|
|
|
|
for episode in bulk_episodes
|
|
|
|
]
|
|
|
|
|
|
|
|
# Save all the episodes
|
|
|
|
await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
|
|
|
|
|
|
|
|
# Get previous episode context for each episode
|
|
|
|
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
|
|
|
|
|
|
# Extract all nodes and edges
|
|
|
|
(
|
|
|
|
extracted_nodes,
|
|
|
|
extracted_edges,
|
|
|
|
episodic_edges,
|
|
|
|
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
|
|
|
|
|
|
|
|
# Generate embeddings
|
|
|
|
await asyncio.gather(
|
|
|
|
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
|
|
|
|
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
|
|
|
)
|
|
|
|
|
2024-08-30 10:48:28 -04:00
|
|
|
# Dedupe extracted nodes, compress extracted edges
|
|
|
|
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
|
|
|
|
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
|
|
|
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
|
|
|
)
|
2024-08-23 14:18:45 -04:00
|
|
|
|
|
|
|
# save nodes to KG
|
|
|
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
|
|
|
|
|
|
|
# re-map edge pointers so that they don't point to discard dupe nodes
|
|
|
|
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
2024-08-30 10:48:28 -04:00
|
|
|
extracted_edges_timestamped, uuid_map
|
2024-08-23 14:18:45 -04:00
|
|
|
)
|
|
|
|
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
|
|
|
|
episodic_edges, uuid_map
|
|
|
|
)
|
|
|
|
|
|
|
|
# save episodic edges to KG
|
|
|
|
await asyncio.gather(
|
|
|
|
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
|
|
|
|
)
|
|
|
|
|
|
|
|
# Dedupe extracted edges
|
|
|
|
edges = await dedupe_edges_bulk(
|
|
|
|
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
|
|
|
|
)
|
|
|
|
logger.info(f'extracted edge length: {len(edges)}')
|
|
|
|
|
|
|
|
# invalidate edges
|
|
|
|
|
|
|
|
# save edges to KG
|
|
|
|
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
|
|
|
|
|
|
|
end = time()
|
2024-08-26 10:30:22 -04:00
|
|
|
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
2024-08-23 14:18:45 -04:00
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
raise e
|
|
|
|
|
2024-09-11 12:06:35 -04:00
|
|
|
async def build_communities(self):
|
|
|
|
embedder = self.llm_client.get_embedder()
|
|
|
|
|
|
|
|
# Clear existing communities
|
|
|
|
await remove_communities(self.driver)
|
|
|
|
|
|
|
|
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
|
|
|
|
|
|
|
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
|
|
|
|
|
|
|
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
|
|
|
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
|
|
|
|
2024-09-06 12:33:42 -04:00
|
|
|
async def search(
|
|
|
|
self,
|
|
|
|
query: str,
|
|
|
|
center_node_uuid: str | None = None,
|
|
|
|
group_ids: list[str | None] | None = None,
|
2024-09-16 14:03:05 -04:00
|
|
|
num_results=DEFAULT_SEARCH_LIMIT,
|
2024-09-06 12:33:42 -04:00
|
|
|
):
|
2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Perform a hybrid search on the knowledge graph.
|
|
|
|
|
|
|
|
This method executes a search query on the graph, combining vector and
|
|
|
|
text-based search techniques to retrieve relevant facts.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
|
|
|
query : str
|
|
|
|
The search query string.
|
2024-08-26 18:34:57 -04:00
|
|
|
center_node_uuid: str, optional
|
|
|
|
Facts will be reranked based on proximity to this node
|
2024-09-06 12:33:42 -04:00
|
|
|
group_ids : list[str | None] | None, optional
|
|
|
|
The graph partitions to return data from.
|
2024-09-16 14:03:05 -04:00
|
|
|
limit : int, optional
|
2024-08-26 13:11:50 -04:00
|
|
|
The maximum number of results to return. Defaults to 10.
|
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
list
|
2024-08-26 17:24:35 -07:00
|
|
|
A list of EntityEdge objects that are relevant to the search query.
|
2024-08-26 13:11:50 -04:00
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
This method uses a SearchConfig with num_episodes set to 0 and
|
|
|
|
num_results set to the provided num_results parameter. It then calls
|
|
|
|
the hybrid_search function to perform the actual search operation.
|
|
|
|
|
|
|
|
The search is performed using the current date and time as the reference
|
|
|
|
point for temporal relevance.
|
|
|
|
"""
|
2024-09-16 14:03:05 -04:00
|
|
|
search_config = (
|
|
|
|
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
2024-08-26 18:34:57 -04:00
|
|
|
)
|
2024-09-16 14:03:05 -04:00
|
|
|
search_config.limit = num_results
|
|
|
|
|
2024-08-23 14:18:45 -04:00
|
|
|
edges = (
|
2024-09-16 14:03:05 -04:00
|
|
|
await search(
|
2024-08-23 14:18:45 -04:00
|
|
|
self.driver,
|
|
|
|
self.llm_client.get_embedder(),
|
|
|
|
query,
|
2024-09-16 14:03:05 -04:00
|
|
|
group_ids,
|
2024-08-23 14:18:45 -04:00
|
|
|
search_config,
|
2024-08-26 18:34:57 -04:00
|
|
|
center_node_uuid,
|
2024-08-23 14:18:45 -04:00
|
|
|
)
|
|
|
|
).edges
|
|
|
|
|
2024-08-26 17:24:35 -07:00
|
|
|
return edges
|
2024-08-23 14:18:45 -04:00
|
|
|
|
2024-08-26 18:34:57 -04:00
|
|
|
async def _search(
|
2024-09-04 10:05:45 -04:00
|
|
|
self,
|
|
|
|
query: str,
|
|
|
|
config: SearchConfig,
|
2024-09-16 14:03:05 -04:00
|
|
|
group_ids: list[str | None] | None = None,
|
2024-09-04 10:05:45 -04:00
|
|
|
center_node_uuid: str | None = None,
|
2024-09-16 14:03:05 -04:00
|
|
|
) -> SearchResults:
|
|
|
|
return await search(
|
|
|
|
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
|
2024-08-23 14:18:45 -04:00
|
|
|
)
|
2024-08-26 20:00:28 -07:00
|
|
|
|
2024-08-30 10:48:28 -04:00
|
|
|
async def get_nodes_by_query(
|
2024-09-06 12:33:42 -04:00
|
|
|
self,
|
|
|
|
query: str,
|
2024-09-16 14:03:05 -04:00
|
|
|
center_node_uuid: str | None = None,
|
2024-09-06 12:33:42 -04:00
|
|
|
group_ids: list[str | None] | None = None,
|
2024-09-16 14:03:05 -04:00
|
|
|
limit: int = DEFAULT_SEARCH_LIMIT,
|
2024-08-30 10:48:28 -04:00
|
|
|
) -> list[EntityNode]:
|
2024-08-26 20:00:28 -07:00
|
|
|
"""
|
|
|
|
Retrieve nodes from the graph database based on a text query.
|
|
|
|
|
|
|
|
This method performs a hybrid search using both text-based and
|
|
|
|
embedding-based approaches to find relevant nodes.
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
----------
|
|
|
|
query : str
|
2024-09-16 14:03:05 -04:00
|
|
|
The text query to search for in the graph
|
|
|
|
center_node_uuid: str, optional
|
|
|
|
Facts will be reranked based on proximity to this node.
|
2024-09-06 12:33:42 -04:00
|
|
|
group_ids : list[str | None] | None, optional
|
|
|
|
The graph partitions to return data from.
|
2024-08-26 20:00:28 -07:00
|
|
|
limit : int | None, optional
|
|
|
|
The maximum number of results to return per search method.
|
|
|
|
If None, a default limit will be applied.
|
|
|
|
|
|
|
|
Returns
|
|
|
|
-------
|
|
|
|
list[EntityNode]
|
|
|
|
A list of EntityNode objects that match the search criteria.
|
|
|
|
|
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
This method uses the following steps:
|
|
|
|
1. Generates an embedding for the input query using the LLM client's embedder.
|
|
|
|
2. Calls the hybrid_node_search function with both the text query and its embedding.
|
|
|
|
3. The hybrid search combines fulltext search and vector similarity search
|
|
|
|
to find the most relevant nodes.
|
|
|
|
|
|
|
|
The method leverages the LLM client's embedding capabilities to enhance
|
|
|
|
the search with semantic similarity matching. The 'limit' parameter is applied
|
|
|
|
to each individual search method before results are combined and deduplicated.
|
|
|
|
If not specified, a default limit (defined in the search functions) will be used.
|
|
|
|
"""
|
|
|
|
embedder = self.llm_client.get_embedder()
|
2024-09-16 14:03:05 -04:00
|
|
|
search_config = (
|
|
|
|
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
|
2024-09-06 12:33:42 -04:00
|
|
|
)
|
2024-09-16 14:03:05 -04:00
|
|
|
search_config.limit = limit
|
|
|
|
|
|
|
|
nodes = (
|
|
|
|
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
|
|
|
).nodes
|
|
|
|
return nodes
|