mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
test: update filtering of Pinecone mock to imitate doc store (#3020)
* updated filtering of doc store to imitate pinecone * Update test/mocks/pinecone.py
This commit is contained in:
parent
74b7c2c12a
commit
82c9cff3d9
@ -1,10 +1,9 @@
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mock Pinecone instance
|
||||
CONFIG: dict = {"api_key": None, "environment": None, "indexes": {}}
|
||||
|
||||
@ -58,10 +57,23 @@ class Index:
|
||||
upsert_count += 1
|
||||
return {"upserted_count": upsert_count}
|
||||
|
||||
def describe_index_stats(self):
|
||||
def update(self, namespace: str, id: str, set_metadata: dict):
|
||||
# Get existing item metadata
|
||||
meta = self.index_config.namespaces[namespace][id]["metadata"]
|
||||
# Add new metadata to existing item metadata
|
||||
self.index_config.namespaces[namespace][id]["metadata"] = {**meta, **set_metadata}
|
||||
|
||||
def describe_index_stats(self, filter=None):
|
||||
namespaces = {}
|
||||
for namespace in self.index_config.namespaces.items():
|
||||
namespaces[namespace[0]] = {"vector_count": len(namespace[1])}
|
||||
records = self.index_config.namespaces[namespace[0]]
|
||||
if filter:
|
||||
filtered_records = []
|
||||
for record in records.values():
|
||||
if self._filter(metadata=record["metadata"], filters=filter, top_level=True):
|
||||
filtered_records.append(record)
|
||||
records = filtered_records
|
||||
namespaces[namespace[0]] = {"vector_count": len(records)}
|
||||
return {"dimension": self.index_config.dimension, "index_fullness": 0.0, "namespaces": namespaces}
|
||||
|
||||
def query(
|
||||
@ -87,7 +99,11 @@ class Index:
|
||||
if include_metadata:
|
||||
match["metadata"] = records[_id]["metadata"].copy()
|
||||
match["score"] = 0.0
|
||||
response["matches"].append(match)
|
||||
if filter is None or (
|
||||
filter is not None and self._filter(records[_id]["metadata"], filter, top_level=True)
|
||||
):
|
||||
# filter if needed
|
||||
response["matches"].append(match)
|
||||
return response
|
||||
|
||||
def fetch(self, ids: List[str], namespace: str = ""):
|
||||
@ -107,6 +123,117 @@ class Index:
|
||||
}
|
||||
return response
|
||||
|
||||
def _filter(
|
||||
self,
|
||||
metadata: dict,
|
||||
filters: Dict[str, Union[str, int, float, bool, list]],
|
||||
mode: Optional[str] = "$and",
|
||||
top_level=False,
|
||||
) -> dict:
|
||||
"""
|
||||
Mock filtering function
|
||||
"""
|
||||
bools = []
|
||||
if type(filters) is list:
|
||||
list_bools = []
|
||||
for _filter in filters:
|
||||
res = self._filter(metadata, _filter, mode=mode)
|
||||
for key, value in res.items():
|
||||
if key == "$and":
|
||||
list_bools.append(all(value))
|
||||
else:
|
||||
list_bools.append(any(value))
|
||||
if mode == "$and":
|
||||
bools.append(all(list_bools))
|
||||
elif mode == "$or":
|
||||
bools.append(any(list_bools))
|
||||
else:
|
||||
for field, potential_value in filters.items():
|
||||
if field in ["$and", "$or"]:
|
||||
bools.append(self._filter(metadata, potential_value, mode=field))
|
||||
mode = field
|
||||
cond = field
|
||||
else:
|
||||
if type(potential_value) is dict:
|
||||
sub_bool = []
|
||||
for cond, value in potential_value.items():
|
||||
if len(potential_value.keys()) > 1:
|
||||
sub_filter = {field: {cond: value}}
|
||||
bools.append(self._filter(metadata, sub_filter))
|
||||
if len(sub_bool) > 1:
|
||||
if field == "$or":
|
||||
bools.append(any(sub_bool))
|
||||
else:
|
||||
bools.append(all(sub_bool))
|
||||
elif type(potential_value) is list:
|
||||
cond = "$in"
|
||||
value = potential_value
|
||||
else:
|
||||
cond = "$eq"
|
||||
value = potential_value
|
||||
# main chunk of condition checks
|
||||
if cond == "$eq":
|
||||
if field in metadata and metadata[field] == value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$ne":
|
||||
if field in metadata and metadata[field] != value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$in":
|
||||
if field in metadata and metadata[field] in value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$nin":
|
||||
if field in metadata and metadata[field] not in value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$gt":
|
||||
if field in metadata and metadata[field] > value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$lt":
|
||||
if field in metadata and metadata[field] < value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$gte":
|
||||
if field in metadata and metadata[field] >= value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
elif cond == "$lte":
|
||||
if field in metadata and metadata[field] <= value:
|
||||
bools.append(True)
|
||||
else:
|
||||
bools.append(False)
|
||||
if top_level:
|
||||
final = []
|
||||
for item in bools:
|
||||
if type(item) is dict:
|
||||
for key, value in item.items():
|
||||
if key == "$and":
|
||||
final.append(all(value))
|
||||
else:
|
||||
final.append(any(value))
|
||||
else:
|
||||
final.append(item)
|
||||
if mode == "$and":
|
||||
bools = all(final)
|
||||
else:
|
||||
bools = any(final)
|
||||
else:
|
||||
if mode == "$and":
|
||||
return {"$and": bools}
|
||||
else:
|
||||
return {"$or": bools}
|
||||
return bools
|
||||
|
||||
def delete(
|
||||
self,
|
||||
ids: Optional[List[str]] = None,
|
||||
@ -114,15 +241,24 @@ class Index:
|
||||
filters: Optional[dict] = None,
|
||||
delete_all: bool = False,
|
||||
):
|
||||
if delete_all:
|
||||
if filters:
|
||||
# Get a filtered list of IDs
|
||||
matches = self.query(filters=filters, namespace=namespace, include_values=False, include_metadata=False)[
|
||||
"vectors"
|
||||
]
|
||||
filter_ids: List[str] = matches.keys() # .keys() returns an object that supports set operators already
|
||||
elif delete_all:
|
||||
self.index_config.namespaces[namespace] = {}
|
||||
|
||||
if namespace not in self.index_config.namespaces:
|
||||
pass
|
||||
elif ids is not None:
|
||||
id_list: List[str] = ids
|
||||
if filters:
|
||||
# We find the intersect between the IDs and filtered IDs
|
||||
id_list = set(id_list).intersection(filter_ids)
|
||||
records = self.index_config.namespaces[namespace]
|
||||
for _id in list(records.keys()):
|
||||
for _id in records.keys():
|
||||
if _id in id_list:
|
||||
del records[_id]
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user