mirror of
https://github.com/getzep/graphiti.git
synced 2026-01-06 04:10:54 +00:00
Date filters (#240)
* add search filters * add search filters * mypy * mypy * update filtering * date-filters * update * update filter queries * update dictionary
This commit is contained in:
parent
d3b2cecbe5
commit
6ef2f5e097
@ -35,6 +35,7 @@ from graphiti_core.search.search_config_recipes import (
|
||||
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
|
||||
EDGE_HYBRID_SEARCH_RRF,
|
||||
)
|
||||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.search.search_utils import (
|
||||
RELEVANT_SCHEMA_LIMIT,
|
||||
get_communities_by_nodes,
|
||||
@ -625,6 +626,7 @@ class Graphiti:
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
search_filter: SearchFilters | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
@ -670,6 +672,7 @@ class Graphiti:
|
||||
query,
|
||||
group_ids,
|
||||
search_config,
|
||||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
)
|
||||
).edges
|
||||
@ -683,6 +686,7 @@ class Graphiti:
|
||||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
) -> SearchResults:
|
||||
return await search(
|
||||
self.driver,
|
||||
@ -691,6 +695,7 @@ class Graphiti:
|
||||
query,
|
||||
group_ids,
|
||||
config,
|
||||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
bfs_origin_node_uuids,
|
||||
)
|
||||
|
||||
@ -39,6 +39,7 @@ from graphiti_core.search.search_config import (
|
||||
SearchConfig,
|
||||
SearchResults,
|
||||
)
|
||||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.search.search_utils import (
|
||||
community_fulltext_search,
|
||||
community_similarity_search,
|
||||
@ -64,6 +65,7 @@ async def search(
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
search_filter: SearchFilters,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
) -> SearchResults:
|
||||
@ -86,6 +88,7 @@ async def search(
|
||||
query_vector,
|
||||
group_ids,
|
||||
config.edge_config,
|
||||
search_filter,
|
||||
center_node_uuid,
|
||||
bfs_origin_node_uuids,
|
||||
config.limit,
|
||||
@ -133,6 +136,7 @@ async def edge_search(
|
||||
query_vector: list[float],
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig | None,
|
||||
search_filter: SearchFilters,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
@ -143,11 +147,20 @@ async def edge_search(
|
||||
search_results: list[list[EntityEdge]] = list(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
edge_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
||||
edge_similarity_search(
|
||||
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
|
||||
driver,
|
||||
query_vector,
|
||||
None,
|
||||
None,
|
||||
search_filter,
|
||||
group_ids,
|
||||
2 * limit,
|
||||
config.sim_min_score,
|
||||
),
|
||||
edge_bfs_search(
|
||||
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
|
||||
),
|
||||
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
|
||||
]
|
||||
)
|
||||
)
|
||||
@ -155,7 +168,9 @@ async def edge_search(
|
||||
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
||||
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
||||
search_results.append(
|
||||
await edge_bfs_search(driver, source_node_uuids, config.bfs_max_depth, 2 * limit)
|
||||
await edge_bfs_search(
|
||||
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
|
||||
)
|
||||
)
|
||||
|
||||
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
||||
|
||||
152
graphiti_core/search/search_filters.py
Normal file
152
graphiti_core/search/search_filters.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
|
||||
class ComparisonOperator(Enum):
|
||||
equals = '='
|
||||
not_equals = '<>'
|
||||
greater_than = '>'
|
||||
less_than = '<'
|
||||
greater_than_equal = '>='
|
||||
less_than_equal = '<='
|
||||
|
||||
|
||||
class DateFilter(BaseModel):
|
||||
date: datetime = Field(description='A datetime to filter on')
|
||||
comparison_operator: ComparisonOperator = Field(
|
||||
description='Comparison operator for date filter'
|
||||
)
|
||||
|
||||
|
||||
class SearchFilters(BaseModel):
|
||||
valid_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
invalid_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
created_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
expired_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
|
||||
|
||||
def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
|
||||
filter_query: LiteralString = ''
|
||||
filter_params: dict[str, Any] = {}
|
||||
|
||||
if filters.valid_at is not None:
|
||||
valid_at_filter = 'AND ('
|
||||
for i, or_list in enumerate(filters.valid_at):
|
||||
for j, date_filter in enumerate(or_list):
|
||||
filter_params['valid_at_' + str(j)] = date_filter.date
|
||||
|
||||
and_filters = [
|
||||
'(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
|
||||
for j, date_filter in enumerate(or_list)
|
||||
]
|
||||
and_filter_query = ''
|
||||
for j, and_filter in enumerate(and_filters):
|
||||
and_filter_query += and_filter
|
||||
if j != len(and_filter_query) - 1:
|
||||
and_filter_query += ' AND '
|
||||
|
||||
valid_at_filter += and_filter_query
|
||||
|
||||
if i == len(or_list) - 1:
|
||||
valid_at_filter += ')'
|
||||
else:
|
||||
valid_at_filter += ' OR '
|
||||
|
||||
filter_query += valid_at_filter
|
||||
|
||||
if filters.invalid_at is not None:
|
||||
invalid_at_filter = 'AND ('
|
||||
for i, or_list in enumerate(filters.invalid_at):
|
||||
for j, date_filter in enumerate(or_list):
|
||||
filter_params['invalid_at_' + str(j)] = date_filter.date
|
||||
|
||||
and_filters = [
|
||||
'(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
|
||||
for j, date_filter in enumerate(or_list)
|
||||
]
|
||||
and_filter_query = ''
|
||||
for j, and_filter in enumerate(and_filters):
|
||||
and_filter_query += and_filter
|
||||
if j != len(and_filter_query) - 1:
|
||||
and_filter_query += ' AND '
|
||||
|
||||
invalid_at_filter += and_filter_query
|
||||
|
||||
if i == len(or_list) - 1:
|
||||
invalid_at_filter += ')'
|
||||
else:
|
||||
invalid_at_filter += ' OR '
|
||||
|
||||
filter_query += invalid_at_filter
|
||||
|
||||
if filters.created_at is not None:
|
||||
created_at_filter = 'AND ('
|
||||
for i, or_list in enumerate(filters.created_at):
|
||||
for j, date_filter in enumerate(or_list):
|
||||
filter_params['created_at_' + str(j)] = date_filter.date
|
||||
|
||||
and_filters = [
|
||||
'(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
|
||||
for j, date_filter in enumerate(or_list)
|
||||
]
|
||||
and_filter_query = ''
|
||||
for j, and_filter in enumerate(and_filters):
|
||||
and_filter_query += and_filter
|
||||
if j != len(and_filter_query) - 1:
|
||||
and_filter_query += ' AND '
|
||||
|
||||
created_at_filter += and_filter_query
|
||||
|
||||
if i == len(or_list) - 1:
|
||||
created_at_filter += ')'
|
||||
else:
|
||||
created_at_filter += ' OR '
|
||||
|
||||
filter_query += created_at_filter
|
||||
|
||||
if filters.expired_at is not None:
|
||||
expired_at_filter = 'AND ('
|
||||
for i, or_list in enumerate(filters.expired_at):
|
||||
for j, date_filter in enumerate(or_list):
|
||||
filter_params['expired_at_' + str(j)] = date_filter.date
|
||||
|
||||
and_filters = [
|
||||
'(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
|
||||
for j, date_filter in enumerate(or_list)
|
||||
]
|
||||
and_filter_query = ''
|
||||
for j, and_filter in enumerate(and_filters):
|
||||
and_filter_query += and_filter
|
||||
if j != len(and_filter_query) - 1:
|
||||
and_filter_query += ' AND '
|
||||
|
||||
expired_at_filter += and_filter_query
|
||||
|
||||
if i == len(or_list) - 1:
|
||||
expired_at_filter += ')'
|
||||
else:
|
||||
expired_at_filter += ' OR '
|
||||
|
||||
filter_query += expired_at_filter
|
||||
|
||||
return filter_query, filter_params
|
||||
@ -38,6 +38,7 @@ from graphiti_core.nodes import (
|
||||
get_community_node_from_record,
|
||||
get_entity_node_from_record,
|
||||
)
|
||||
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -136,6 +137,7 @@ async def get_communities_by_nodes(
|
||||
async def edge_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
@ -144,28 +146,36 @@ async def edge_fulltext_search(
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
cypher_query = Query("""
|
||||
filter_query, filter_params = search_filter_query_constructor(search_filter)
|
||||
|
||||
cypher_query = Query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
|
||||
YIELD relationship AS r, score
|
||||
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (:ENTITY)-[r:RELATES_TO]->(:ENTITY)
|
||||
WHERE r.group_id IN $group_ids"""
|
||||
+ filter_query
|
||||
+ """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
cypher_query,
|
||||
filter_params,
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
@ -183,6 +193,7 @@ async def edge_similarity_search(
|
||||
search_vector: list[float],
|
||||
source_node_uuid: str | None,
|
||||
target_node_uuid: str | None,
|
||||
search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
@ -194,6 +205,9 @@ async def edge_similarity_search(
|
||||
|
||||
query_params: dict[str, Any] = {}
|
||||
|
||||
filter_query, filter_params = search_filter_query_constructor(search_filter)
|
||||
query_params.update(filter_params)
|
||||
|
||||
group_filter_query: LiteralString = ''
|
||||
if group_ids is not None:
|
||||
group_filter_query += 'WHERE r.group_id IN $group_ids'
|
||||
@ -209,9 +223,10 @@ async def edge_similarity_search(
|
||||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ filter_query
|
||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
@ -254,17 +269,25 @@ async def edge_bfs_search(
|
||||
driver: AsyncDriver,
|
||||
bfs_origin_node_uuids: list[str] | None,
|
||||
bfs_max_depth: int,
|
||||
search_filter: SearchFilters,
|
||||
limit: int,
|
||||
) -> list[EntityEdge]:
|
||||
# vector similarity search over embedded facts
|
||||
if bfs_origin_node_uuids is None:
|
||||
return []
|
||||
|
||||
query = Query("""
|
||||
filter_query, filter_params = search_filter_query_constructor(search_filter)
|
||||
|
||||
query = Query(
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
UNWIND relationships(path) AS rel
|
||||
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
|
||||
MATCH ()-[r:RELATES_TO]-()
|
||||
WHERE r.uuid = rel.uuid
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -279,10 +302,12 @@ async def edge_bfs_search(
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
LIMIT $limit
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
filter_params,
|
||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||
depth=bfs_max_depth,
|
||||
limit=limit,
|
||||
@ -626,6 +651,7 @@ async def get_relevant_edges(
|
||||
edge.fact_embedding,
|
||||
source_node_uuid,
|
||||
target_node_uuid,
|
||||
SearchFilters(),
|
||||
[edge.group_id],
|
||||
limit,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user