mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Node episodes list (#381)
* added episode list virtual field * in progress tests * add tests * update search return type * linter * copyright notice * mark integration tests
This commit is contained in:
parent
064d9207d2
commit
009467650f
@ -38,6 +38,20 @@ from graphiti_core.utils.datetime_utils import utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_NODE_RETURN: LiteralString = """
|
||||
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
||||
WITH n, collect(e.uuid) AS episodes
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes,
|
||||
episodes"""
|
||||
|
||||
|
||||
class EpisodeType(Enum):
|
||||
"""
|
||||
@ -280,6 +294,9 @@ class EpisodicNode(Node):
|
||||
class EntityNode(Node):
|
||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
||||
episodes: list[str] | None = Field(
|
||||
default=None, description='List of episode uuids that mention this node.'
|
||||
)
|
||||
attributes: dict[str, Any] = Field(
|
||||
default={}, description='Additional attributes of the node. Dependent on node labels'
|
||||
)
|
||||
@ -318,19 +335,14 @@ class EntityNode(Node):
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
""",
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
@ -348,16 +360,8 @@ class EntityNode(Node):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
""",
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN,
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
@ -383,16 +387,8 @@ class EntityNode(Node):
|
||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
ORDER BY n.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
@ -548,6 +544,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
||||
created_at=record['created_at'].to_native(),
|
||||
summary=record['summary'],
|
||||
attributes=record['attributes'],
|
||||
episodes=record['episodes'],
|
||||
)
|
||||
|
||||
entity_node.attributes.pop('uuid', None)
|
||||
|
@ -32,6 +32,7 @@ from graphiti_core.helpers import (
|
||||
semaphore_gather,
|
||||
)
|
||||
from graphiti_core.nodes import (
|
||||
ENTITY_NODE_RETURN,
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
EpisodicNode,
|
||||
@ -53,6 +54,20 @@ DEFAULT_MMR_LAMBDA = 0.5
|
||||
MAX_SEARCH_DEPTH = 3
|
||||
MAX_QUERY_LENGTH = 32
|
||||
|
||||
SEARCH_ENTITY_NODE_RETURN: LiteralString = """
|
||||
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
||||
WITH n, score, collect(e.uuid) AS episodes
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes,
|
||||
episodes"""
|
||||
|
||||
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||
group_ids_filter_list = (
|
||||
@ -230,8 +245,8 @@ async def edge_similarity_search(
|
||||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ filter_query
|
||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
@ -341,27 +356,21 @@ async def node_fulltext_search(
|
||||
|
||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS node, score
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid = node.uuid
|
||||
query = (
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
+ filter_query
|
||||
+ SEARCH_ENTITY_NODE_RETURN
|
||||
+ """
|
||||
RETURN
|
||||
n.uuid AS uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
filter_params,
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
@ -406,19 +415,12 @@ async def node_similarity_search(
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
WHERE score > $min_score"""
|
||||
+ SEARCH_ENTITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query_params,
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
@ -452,16 +454,8 @@ async def node_bfs_search(
|
||||
WHERE n.group_id = origin.group_id
|
||||
"""
|
||||
+ filter_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
RETURN DISTINCT
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
LIMIT $limit
|
||||
""",
|
||||
filter_params,
|
||||
|
@ -65,9 +65,7 @@ async def test_graphiti_init():
|
||||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
|
||||
results = await graphiti.search_(
|
||||
query='Who is the User?',
|
||||
)
|
||||
results = await graphiti.search_(query='Who is the User?')
|
||||
|
||||
pretty_results = search_results_to_context_string(results)
|
||||
|
||||
|
122
tests/test_node_int.py
Normal file
122
tests/test_node_int.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from neo4j import AsyncGraphDatabase
|
||||
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
EpisodeType,
|
||||
EpisodicNode,
|
||||
)
|
||||
|
||||
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
||||
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entity_node():
|
||||
return EntityNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Test Entity',
|
||||
group_id='test_group',
|
||||
labels=['Entity'],
|
||||
name_embedding=[0.5] * 1024,
|
||||
summary='Entity Summary',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_episodic_node():
|
||||
return EpisodicNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Episode 1',
|
||||
group_id='test_group',
|
||||
source=EpisodeType.text,
|
||||
source_description='Test source',
|
||||
content='Some content here',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_community_node():
|
||||
return CommunityNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Community A',
|
||||
name_embedding=[0.5] * 1024,
|
||||
group_id='test_group',
|
||||
summary='Community summary',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_entity_node_save_get_and_delete(sample_entity_node):
|
||||
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
||||
await sample_entity_node.save(neo4j_driver)
|
||||
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
|
||||
assert retrieved.uuid == sample_entity_node.uuid
|
||||
assert retrieved.name == 'Test Entity'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
|
||||
await sample_entity_node.delete(neo4j_driver)
|
||||
|
||||
await neo4j_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_community_node_save_get_and_delete(sample_community_node):
|
||||
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
||||
|
||||
await sample_community_node.save(neo4j_driver)
|
||||
|
||||
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
|
||||
assert retrieved.uuid == sample_community_node.uuid
|
||||
assert retrieved.name == 'Community A'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
assert retrieved.summary == 'Community summary'
|
||||
|
||||
await sample_community_node.delete(neo4j_driver)
|
||||
|
||||
await neo4j_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
|
||||
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
||||
|
||||
await sample_episodic_node.save(neo4j_driver)
|
||||
|
||||
retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
|
||||
assert retrieved.uuid == sample_episodic_node.uuid
|
||||
assert retrieved.name == 'Episode 1'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
assert retrieved.source == EpisodeType.text
|
||||
assert retrieved.source_description == 'Test source'
|
||||
assert retrieved.content == 'Some content here'
|
||||
|
||||
await sample_episodic_node.delete(neo4j_driver)
|
||||
|
||||
await neo4j_driver.close()
|
Loading…
x
Reference in New Issue
Block a user