mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Node episodes list (#381)
* added episode list virtual field * in progress tests * add tests * update search return type * linter * copyright notice * mark integration tests
This commit is contained in:
parent
064d9207d2
commit
009467650f
@ -38,6 +38,20 @@ from graphiti_core.utils.datetime_utils import utc_now
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ENTITY_NODE_RETURN: LiteralString = """
|
||||||
|
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
||||||
|
WITH n, collect(e.uuid) AS episodes
|
||||||
|
RETURN
|
||||||
|
n.uuid As uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS attributes,
|
||||||
|
episodes"""
|
||||||
|
|
||||||
|
|
||||||
class EpisodeType(Enum):
|
class EpisodeType(Enum):
|
||||||
"""
|
"""
|
||||||
@ -280,6 +294,9 @@ class EpisodicNode(Node):
|
|||||||
class EntityNode(Node):
|
class EntityNode(Node):
|
||||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||||
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
||||||
|
episodes: list[str] | None = Field(
|
||||||
|
default=None, description='List of episode uuids that mention this node.'
|
||||||
|
)
|
||||||
attributes: dict[str, Any] = Field(
|
attributes: dict[str, Any] = Field(
|
||||||
default={}, description='Additional attributes of the node. Dependent on node labels'
|
default={}, description='Additional attributes of the node. Dependent on node labels'
|
||||||
)
|
)
|
||||||
@ -318,19 +335,14 @@ class EntityNode(Node):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
RETURN
|
"""
|
||||||
n.uuid As uuid,
|
+ ENTITY_NODE_RETURN
|
||||||
n.name AS name,
|
)
|
||||||
n.name_embedding AS name_embedding,
|
records, _, _ = await driver.execute_query(
|
||||||
n.group_id AS group_id,
|
query,
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
""",
|
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
@ -348,16 +360,8 @@ class EntityNode(Node):
|
|||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
||||||
RETURN
|
"""
|
||||||
n.uuid As uuid,
|
+ ENTITY_NODE_RETURN,
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
""",
|
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
@ -383,16 +387,8 @@ class EntityNode(Node):
|
|||||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
ORDER BY n.uuid DESC
|
ORDER BY n.uuid DESC
|
||||||
"""
|
"""
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
@ -548,6 +544,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
attributes=record['attributes'],
|
attributes=record['attributes'],
|
||||||
|
episodes=record['episodes'],
|
||||||
)
|
)
|
||||||
|
|
||||||
entity_node.attributes.pop('uuid', None)
|
entity_node.attributes.pop('uuid', None)
|
||||||
|
@ -32,6 +32,7 @@ from graphiti_core.helpers import (
|
|||||||
semaphore_gather,
|
semaphore_gather,
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import (
|
from graphiti_core.nodes import (
|
||||||
|
ENTITY_NODE_RETURN,
|
||||||
CommunityNode,
|
CommunityNode,
|
||||||
EntityNode,
|
EntityNode,
|
||||||
EpisodicNode,
|
EpisodicNode,
|
||||||
@ -53,6 +54,20 @@ DEFAULT_MMR_LAMBDA = 0.5
|
|||||||
MAX_SEARCH_DEPTH = 3
|
MAX_SEARCH_DEPTH = 3
|
||||||
MAX_QUERY_LENGTH = 32
|
MAX_QUERY_LENGTH = 32
|
||||||
|
|
||||||
|
SEARCH_ENTITY_NODE_RETURN: LiteralString = """
|
||||||
|
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
||||||
|
WITH n, score, collect(e.uuid) AS episodes
|
||||||
|
RETURN
|
||||||
|
n.uuid As uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS attributes,
|
||||||
|
episodes"""
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
group_ids_filter_list = (
|
group_ids_filter_list = (
|
||||||
@ -230,8 +245,8 @@ async def edge_similarity_search(
|
|||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
@ -341,27 +356,21 @@ async def node_fulltext_search(
|
|||||||
|
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
"""
|
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
||||||
YIELD node AS node, score
|
|
||||||
MATCH (n:Entity)
|
|
||||||
WHERE n.uuid = node.uuid
|
|
||||||
"""
|
"""
|
||||||
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||||
|
YIELD node AS n, score
|
||||||
|
WHERE n:Entity
|
||||||
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
|
+ SEARCH_ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
|
||||||
n.uuid AS uuid,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
"""
|
||||||
""",
|
)
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
filter_params,
|
filter_params,
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
@ -406,19 +415,12 @@ async def node_similarity_search(
|
|||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score"""
|
||||||
RETURN
|
+ SEARCH_ENTITY_NODE_RETURN
|
||||||
n.uuid As uuid,
|
+ """
|
||||||
n.group_id AS group_id,
|
ORDER BY score DESC
|
||||||
n.name AS name,
|
LIMIT $limit
|
||||||
n.name_embedding AS name_embedding,
|
""",
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
""",
|
|
||||||
query_params,
|
query_params,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
@ -452,16 +454,8 @@ async def node_bfs_search(
|
|||||||
WHERE n.group_id = origin.group_id
|
WHERE n.group_id = origin.group_id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
filter_params,
|
filter_params,
|
||||||
|
@ -65,9 +65,7 @@ async def test_graphiti_init():
|
|||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
|
|
||||||
results = await graphiti.search_(
|
results = await graphiti.search_(query='Who is the User?')
|
||||||
query='Who is the User?',
|
|
||||||
)
|
|
||||||
|
|
||||||
pretty_results = search_results_to_context_string(results)
|
pretty_results = search_results_to_context_string(results)
|
||||||
|
|
||||||
|
122
tests/test_node_int.py
Normal file
122
tests/test_node_int.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from neo4j import AsyncGraphDatabase
|
||||||
|
|
||||||
|
from graphiti_core.nodes import (
|
||||||
|
CommunityNode,
|
||||||
|
EntityNode,
|
||||||
|
EpisodeType,
|
||||||
|
EpisodicNode,
|
||||||
|
)
|
||||||
|
|
||||||
|
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
||||||
|
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
||||||
|
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_entity_node():
|
||||||
|
return EntityNode(
|
||||||
|
uuid=str(uuid4()),
|
||||||
|
name='Test Entity',
|
||||||
|
group_id='test_group',
|
||||||
|
labels=['Entity'],
|
||||||
|
name_embedding=[0.5] * 1024,
|
||||||
|
summary='Entity Summary',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_episodic_node():
|
||||||
|
return EpisodicNode(
|
||||||
|
uuid=str(uuid4()),
|
||||||
|
name='Episode 1',
|
||||||
|
group_id='test_group',
|
||||||
|
source=EpisodeType.text,
|
||||||
|
source_description='Test source',
|
||||||
|
content='Some content here',
|
||||||
|
valid_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_community_node():
|
||||||
|
return CommunityNode(
|
||||||
|
uuid=str(uuid4()),
|
||||||
|
name='Community A',
|
||||||
|
name_embedding=[0.5] * 1024,
|
||||||
|
group_id='test_group',
|
||||||
|
summary='Community summary',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_entity_node_save_get_and_delete(sample_entity_node):
|
||||||
|
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
||||||
|
await sample_entity_node.save(neo4j_driver)
|
||||||
|
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
|
||||||
|
assert retrieved.uuid == sample_entity_node.uuid
|
||||||
|
assert retrieved.name == 'Test Entity'
|
||||||
|
assert retrieved.group_id == 'test_group'
|
||||||
|
|
||||||
|
await sample_entity_node.delete(neo4j_driver)
|
||||||
|
|
||||||
|
await neo4j_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_community_node_save_get_and_delete(sample_community_node):
|
||||||
|
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
||||||
|
|
||||||
|
await sample_community_node.save(neo4j_driver)
|
||||||
|
|
||||||
|
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
|
||||||
|
assert retrieved.uuid == sample_community_node.uuid
|
||||||
|
assert retrieved.name == 'Community A'
|
||||||
|
assert retrieved.group_id == 'test_group'
|
||||||
|
assert retrieved.summary == 'Community summary'
|
||||||
|
|
||||||
|
await sample_community_node.delete(neo4j_driver)
|
||||||
|
|
||||||
|
await neo4j_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
|
||||||
|
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
||||||
|
|
||||||
|
await sample_episodic_node.save(neo4j_driver)
|
||||||
|
|
||||||
|
retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
|
||||||
|
assert retrieved.uuid == sample_episodic_node.uuid
|
||||||
|
assert retrieved.name == 'Episode 1'
|
||||||
|
assert retrieved.group_id == 'test_group'
|
||||||
|
assert retrieved.source == EpisodeType.text
|
||||||
|
assert retrieved.source_description == 'Test source'
|
||||||
|
assert retrieved.content == 'Some content here'
|
||||||
|
|
||||||
|
await sample_episodic_node.delete(neo4j_driver)
|
||||||
|
|
||||||
|
await neo4j_driver.close()
|
Loading…
x
Reference in New Issue
Block a user