diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index be574c07..cf11bd70 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -21,6 +21,8 @@ from typing import Any from pydantic import BaseModel, Field from typing_extensions import LiteralString +from graphiti_core.helpers import lucene_sanitize + class ComparisonOperator(Enum): equals = '=' @@ -42,6 +44,9 @@ class SearchFilters(BaseModel): node_labels: list[str] | None = Field( default=None, description='List of node labels to filter on' ) + edge_types: list[str] | None = Field( + default=None, description='List of edge types to filter on' + ) valid_at: list[list[DateFilter]] | None = Field(default=None) invalid_at: list[list[DateFilter]] | None = Field(default=None) created_at: list[list[DateFilter]] | None = Field(default=None) @@ -55,7 +60,7 @@ def node_search_filter_query_constructor( filter_params: dict[str, Any] = {} if filters.node_labels is not None: - node_labels = '|'.join(filters.node_labels) + node_labels = '|'.join(list(map(lucene_sanitize, filters.node_labels))) node_label_filter = ' AND n:' + node_labels filter_query += node_label_filter @@ -68,8 +73,19 @@ def edge_search_filter_query_constructor( filter_query: LiteralString = '' filter_params: dict[str, Any] = {} + if filters.edge_types is not None: + edge_types = filters.edge_types + edge_types_filter = '\nAND r.name in $edge_types' + filter_query += edge_types_filter + filter_params['edge_types'] = edge_types + + if filters.node_labels is not None: + node_labels = '|'.join(list(map(lucene_sanitize, filters.node_labels))) + node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels + filter_query += node_label_filter + if filters.valid_at is not None: - valid_at_filter = ' AND (' + valid_at_filter = '\nAND (' for i, or_list in enumerate(filters.valid_at): for j, date_filter in enumerate(or_list): filter_params['valid_at_' + str(j)] = date_filter.date diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 86f26ef3..707704aa 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -159,7 +159,7 @@ async def edge_fulltext_search( """ CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit}) YIELD relationship AS rel, score - MATCH (:Entity)-[r:RELATES_TO]->(:Entity) + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) WHERE r.group_id IN $group_ids""" + filter_query + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m @@ -211,9 +211,9 @@ async def edge_similarity_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query_params.update(filter_params) - group_filter_query: LiteralString = '' + group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL' if group_ids is not None: - group_filter_query += 'WHERE r.group_id IN $group_ids' + group_filter_query += '\nAND r.group_id IN $group_ids' query_params['group_ids'] = group_ids query_params['source_node_uuid'] = source_node_uuid query_params['target_node_uuid'] = target_node_uuid @@ -227,8 +227,8 @@ async def edge_similarity_search( query: LiteralString = ( RUNTIME_QUERY + """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -287,7 +287,7 @@ async def edge_bfs_search( UNWIND $bfs_origin_node_uuids AS origin_uuid MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) UNWIND relationships(path) AS rel - MATCH ()-[r:RELATES_TO]-() + MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) WHERE r.uuid = rel.uuid """ + filter_query @@ -340,10 +340,10 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -376,9 +376,9 @@ async def node_similarity_search( # vector similarity search over entity names query_params: dict[str, Any] = {} - group_filter_query: LiteralString = '' + group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL' if group_ids is not None: - group_filter_query += 'WHERE n.group_id IN $group_ids' + group_filter_query += ' AND n.group_id IN $group_ids' query_params['group_ids'] = group_ids filter_query, filter_params = node_search_filter_query_constructor(search_filter) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 3b21700d..9bb2fed8 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -276,8 +276,8 @@ async def resolve_extracted_edges( # Determine which edge types are relevant for each edge edge_types_lst: list[dict[str, BaseModel]] = [] for extracted_edge in extracted_edges: - source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels - target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels + ['Entity'] + target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + ['Entity'] label_tuples = [ (source_label, target_label) for source_label in source_node_labels diff --git a/poetry.lock b/poetry.lock index 517386c6..d9b35e0d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -333,7 +333,7 @@ description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["main", "dev"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -2665,7 +2665,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation != \"PyPy\"" files = [ {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"}, {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"}, @@ -4497,7 +4496,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, diff --git a/pyproject.toml b/pyproject.toml index 612bf343..8c427c74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.12.0pre1" +version = "0.12.0" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },