Date filters (#240)

* add search filters

* add search filters

* mypy

* mypy

* update filtering

* date-filters

* update

* update filter queries

* update dictionary
This commit is contained in:
Preston Rasmussen 2025-01-28 11:52:53 -05:00 committed by GitHub
parent d3b2cecbe5
commit 6ef2f5e097
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 225 additions and 27 deletions

View File

@ -35,6 +35,7 @@ from graphiti_core.search.search_config_recipes import (
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
EDGE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes,
@ -625,6 +626,7 @@ class Graphiti:
center_node_uuid: str | None = None,
group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None,
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
@ -670,6 +672,7 @@ class Graphiti:
query,
group_ids,
search_config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
)
).edges
@ -683,6 +686,7 @@ class Graphiti:
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None,
) -> SearchResults:
return await search(
self.driver,
@ -691,6 +695,7 @@ class Graphiti:
query,
group_ids,
config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
bfs_origin_node_uuids,
)

View File

@ -39,6 +39,7 @@ from graphiti_core.search.search_config import (
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
@ -64,6 +65,7 @@ async def search(
query: str,
group_ids: list[str] | None,
config: SearchConfig,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
@ -86,6 +88,7 @@ async def search(
query_vector,
group_ids,
config.edge_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
@ -133,6 +136,7 @@ async def edge_search(
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
@ -143,11 +147,20 @@ async def edge_search(
search_results: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
edge_fulltext_search(driver, query, group_ids, 2 * limit),
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
driver,
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
),
edge_bfs_search(
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)
@ -155,7 +168,9 @@ async def edge_search(
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append(
await edge_bfs_search(driver, source_node_uuids, config.bfs_max_depth, 2 * limit)
await edge_bfs_search(
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
)
)
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

View File

@ -0,0 +1,152 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
class ComparisonOperator(Enum):
equals = '='
not_equals = '<>'
greater_than = '>'
less_than = '<'
greater_than_equal = '>='
less_than_equal = '<='
class DateFilter(BaseModel):
date: datetime = Field(description='A datetime to filter on')
comparison_operator: ComparisonOperator = Field(
description='Comparison operator for date filter'
)
class SearchFilters(BaseModel):
valid_at: list[list[DateFilter]] | None = Field(default=None)
invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None)
def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}
if filters.valid_at is not None:
valid_at_filter = 'AND ('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
filter_params['valid_at_' + str(j)] = date_filter.date
and_filters = [
'(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
valid_at_filter += and_filter_query
if i == len(or_list) - 1:
valid_at_filter += ')'
else:
valid_at_filter += ' OR '
filter_query += valid_at_filter
if filters.invalid_at is not None:
invalid_at_filter = 'AND ('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
filter_params['invalid_at_' + str(j)] = date_filter.date
and_filters = [
'(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
invalid_at_filter += and_filter_query
if i == len(or_list) - 1:
invalid_at_filter += ')'
else:
invalid_at_filter += ' OR '
filter_query += invalid_at_filter
if filters.created_at is not None:
created_at_filter = 'AND ('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
filter_params['created_at_' + str(j)] = date_filter.date
and_filters = [
'(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
created_at_filter += and_filter_query
if i == len(or_list) - 1:
created_at_filter += ')'
else:
created_at_filter += ' OR '
filter_query += created_at_filter
if filters.expired_at is not None:
expired_at_filter = 'AND ('
for i, or_list in enumerate(filters.expired_at):
for j, date_filter in enumerate(or_list):
filter_params['expired_at_' + str(j)] = date_filter.date
and_filters = [
'(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
expired_at_filter += and_filter_query
if i == len(or_list) - 1:
expired_at_filter += ')'
else:
expired_at_filter += ' OR '
filter_query += expired_at_filter
return filter_query, filter_params

View File

@ -38,6 +38,7 @@ from graphiti_core.nodes import (
get_community_node_from_record,
get_entity_node_from_record,
)
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
logger = logging.getLogger(__name__)
@ -136,6 +137,7 @@ async def get_communities_by_nodes(
async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
@ -144,28 +146,36 @@ async def edge_fulltext_search(
if fuzzy_query == '':
return []
cypher_query = Query("""
filter_query, filter_params = search_filter_query_constructor(search_filter)
cypher_query = Query(
"""
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
YIELD relationship AS r, score
WITH r, score, startNode(r) AS n, endNode(r) AS m
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
YIELD relationship AS rel, score
MATCH (:ENTITY)-[r:RELATES_TO]->(:ENTITY)
WHERE r.group_id IN $group_ids"""
+ filter_query
+ """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
cypher_query,
filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
@ -183,6 +193,7 @@ async def edge_similarity_search(
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
@ -194,6 +205,9 @@ async def edge_similarity_search(
query_params: dict[str, Any] = {}
filter_query, filter_params = search_filter_query_constructor(search_filter)
query_params.update(filter_params)
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE r.group_id IN $group_ids'
@ -209,9 +223,10 @@ async def edge_similarity_search(
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
@ -254,17 +269,25 @@ async def edge_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
search_filter: SearchFilters,
limit: int,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
if bfs_origin_node_uuids is None:
return []
query = Query("""
filter_query, filter_params = search_filter_query_constructor(search_filter)
query = Query(
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
MATCH ()-[r:RELATES_TO]-()
WHERE r.uuid = rel.uuid
"""
+ filter_query
+ """
RETURN DISTINCT
r.uuid AS uuid,
r.group_id AS group_id,
@ -279,10 +302,12 @@ async def edge_bfs_search(
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
LIMIT $limit
""")
"""
)
records, _, _ = await driver.execute_query(
query,
filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
@ -626,6 +651,7 @@ async def get_relevant_edges(
edge.fact_embedding,
source_node_uuid,
target_node_uuid,
SearchFilters(),
[edge.group_id],
limit,
)