diff --git a/test/mocks/pinecone.py b/test/mocks/pinecone.py index 17f73186f..f733bc277 100644 --- a/test/mocks/pinecone.py +++ b/test/mocks/pinecone.py @@ -1,10 +1,9 @@ -from typing import Optional, List +from typing import Optional, List, Dict, Union import logging logger = logging.getLogger(__name__) - # Mock Pinecone instance CONFIG: dict = {"api_key": None, "environment": None, "indexes": {}} @@ -58,10 +57,23 @@ class Index: upsert_count += 1 return {"upserted_count": upsert_count} - def describe_index_stats(self): + def update(self, namespace: str, id: str, set_metadata: dict): + # Get existing item metadata + meta = self.index_config.namespaces[namespace][id]["metadata"] + # Add new metadata to existing item metadata + self.index_config.namespaces[namespace][id]["metadata"] = {**meta, **set_metadata} + + def describe_index_stats(self, filter=None): namespaces = {} for namespace in self.index_config.namespaces.items(): - namespaces[namespace[0]] = {"vector_count": len(namespace[1])} + records = self.index_config.namespaces[namespace[0]] + if filter: + filtered_records = [] + for record in records.values(): + if self._filter(metadata=record["metadata"], filters=filter, top_level=True): + filtered_records.append(record) + records = filtered_records + namespaces[namespace[0]] = {"vector_count": len(records)} return {"dimension": self.index_config.dimension, "index_fullness": 0.0, "namespaces": namespaces} def query( @@ -87,7 +99,11 @@ class Index: if include_metadata: match["metadata"] = records[_id]["metadata"].copy() match["score"] = 0.0 - response["matches"].append(match) + if filter is None or ( + filter is not None and self._filter(records[_id]["metadata"], filter, top_level=True) + ): + # filter if needed + response["matches"].append(match) return response def fetch(self, ids: List[str], namespace: str = ""): @@ -107,6 +123,117 @@ class Index: } return response + def _filter( + self, + metadata: dict, + filters: Dict[str, Union[str, int, float, bool, list]], + mode: Optional[str] = "$and", + top_level=False, + ) -> dict: + """ + Mock filtering function + """ + bools = [] + if type(filters) is list: + list_bools = [] + for _filter in filters: + res = self._filter(metadata, _filter, mode=mode) + for key, value in res.items(): + if key == "$and": + list_bools.append(all(value)) + else: + list_bools.append(any(value)) + if mode == "$and": + bools.append(all(list_bools)) + elif mode == "$or": + bools.append(any(list_bools)) + else: + for field, potential_value in filters.items(): + if field in ["$and", "$or"]: + bools.append(self._filter(metadata, potential_value, mode=field)) + mode = field + cond = field + else: + if type(potential_value) is dict: + sub_bool = [] + for cond, value in potential_value.items(): + if len(potential_value.keys()) > 1: + sub_filter = {field: {cond: value}} + bools.append(self._filter(metadata, sub_filter)) + if len(sub_bool) > 1: + if field == "$or": + bools.append(any(sub_bool)) + else: + bools.append(all(sub_bool)) + elif type(potential_value) is list: + cond = "$in" + value = potential_value + else: + cond = "$eq" + value = potential_value + # main chunk of condition checks + if cond == "$eq": + if field in metadata and metadata[field] == value: + bools.append(True) + else: + bools.append(False) + elif cond == "$ne": + if field in metadata and metadata[field] != value: + bools.append(True) + else: + bools.append(False) + elif cond == "$in": + if field in metadata and metadata[field] in value: + bools.append(True) + else: + bools.append(False) + elif cond == "$nin": + if field in metadata and metadata[field] not in value: + bools.append(True) + else: + bools.append(False) + elif cond == "$gt": + if field in metadata and metadata[field] > value: + bools.append(True) + else: + bools.append(False) + elif cond == "$lt": + if field in metadata and metadata[field] < value: + bools.append(True) + else: + bools.append(False) + elif cond == "$gte": + if field in metadata and metadata[field] >= value: + bools.append(True) + else: + bools.append(False) + elif cond == "$lte": + if field in metadata and metadata[field] <= value: + bools.append(True) + else: + bools.append(False) + if top_level: + final = [] + for item in bools: + if type(item) is dict: + for key, value in item.items(): + if key == "$and": + final.append(all(value)) + else: + final.append(any(value)) + else: + final.append(item) + if mode == "$and": + bools = all(final) + else: + bools = any(final) + else: + if mode == "$and": + return {"$and": bools} + else: + return {"$or": bools} + return bools + def delete( self, ids: Optional[List[str]] = None, @@ -114,15 +241,24 @@ class Index: filters: Optional[dict] = None, delete_all: bool = False, ): - if delete_all: + if filters: + # Get a filtered list of IDs + matches = self.query(filters=filters, namespace=namespace, include_values=False, include_metadata=False)[ + "vectors" + ] + filter_ids: List[str] = matches.keys() # .keys() returns an object that supports set operators already + elif delete_all: self.index_config.namespaces[namespace] = {} if namespace not in self.index_config.namespaces: pass elif ids is not None: id_list: List[str] = ids + if filters: + # We find the intersect between the IDs and filtered IDs + id_list = set(id_list).intersection(filter_ids) records = self.index_config.namespaces[namespace] - for _id in list(records.keys()): + for _id in records.keys(): if _id in id_list: del records[_id] else: