From 4bad2021978418532d16e44a5994f76f1eef3623 Mon Sep 17 00:00:00 2001 From: Ivana Zeljkovic Date: Thu, 14 Sep 2023 11:46:47 +0200 Subject: [PATCH] feat: Pinecone document store refactoring (#5725) * Refactor codebase so that doc_type metadata is used instead of namespaces for making distinction between documents without embeddings, documents with embeddings and labels * Fix parameter name in integration test * Remove code under comment in add_type_metadata_filter method * Fix mypy and pylint checks * Add release note * Apply minimal changes: rename method, update method docs and remove redundant method * Mypy fixes * Fix docstrings * Revert helper methods for fetching documents when the number of documents exceeds Pinecone limit * Remove unnecessary attributes in PineconeDocumentStore * Fix unit test --------- Co-authored-by: Ivana Zeljkovic Co-authored-by: DosticJelena --- haystack/document_stores/pinecone.py | 752 ++++++++++-------- .../refactor-pinecone-document-store.yaml | 6 + test/document_stores/test_pinecone.py | 19 +- 3 files changed, 441 insertions(+), 336 deletions(-) create mode 100644 releasenotes/notes/refactor-pinecone-document-store.yaml diff --git a/haystack/document_stores/pinecone.py b/haystack/document_stores/pinecone.py index 7051618c1..11b2d6d29 100644 --- a/haystack/document_stores/pinecone.py +++ b/haystack/document_stores/pinecone.py @@ -1,31 +1,43 @@ +from __future__ import annotations + import copy import json -from typing import Set, Union, List, Optional, Dict, Generator, Any - import logging -from itertools import islice -from functools import reduce import operator +from functools import reduce +from itertools import islice +from typing import Any, Dict, Generator, List, Literal, Optional, Set, Union import numpy as np from tqdm import tqdm -from haystack.schema import Document, FilterType, Label, Answer, Span from haystack.document_stores import BaseDocumentStore - from haystack.document_stores.filter_utils import 
LogicalFilterClause -from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError -from haystack.nodes.retriever import DenseRetriever +from haystack.errors import DuplicateDocumentError, PineconeDocumentStoreError from haystack.lazy_imports import LazyImport +from haystack.nodes.retriever import DenseRetriever +from haystack.schema import Answer, Document, FilterType, Label, Span with LazyImport("Run 'pip install farm-haystack[pinecone]'") as pinecone_import: import pinecone - logger = logging.getLogger(__name__) +TYPE_METADATA_FIELD = "doc_type" +DOCUMENT_WITH_EMBEDDING = "vector" +DOCUMENT_WITHOUT_EMBEDDING = "no-vector" +LABEL = "label" -def _sanitize_index_name(index: Optional[str]) -> Optional[str]: +AND_OPERATOR = "$and" +IN_OPERATOR = "$in" +EQ_OPERATOR = "$eq" + +DEFAULT_BATCH_SIZE = 32 + +DocTypeMetadata = Literal["vector", "no-vector", "label"] + + +def _sanitize_index(index: Optional[str]) -> Optional[str]: if index: return index.replace("_", "-").lower() return None @@ -69,6 +81,7 @@ class PineconeDocumentStore(BaseDocumentStore): similarity: str = "cosine", replicas: int = 1, shards: int = 1, + namespace: Optional[str] = None, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", @@ -78,8 +91,8 @@ class PineconeDocumentStore(BaseDocumentStore): ): """ :param api_key: Pinecone vector database API key ([https://app.pinecone.io](https://app.pinecone.io)). - :param environment: Pinecone cloud environment uses `"us-west1-gcp"` by default. Other GCP and AWS regions are - supported, contact Pinecone [here](https://www.pinecone.io/contact/) if required. + :param environment: Pinecone cloud environment uses `"us-west1-gcp"` by default. Other GCP and AWS + regions are supported, contact Pinecone [here](https://www.pinecone.io/contact/) if required. :param pinecone_index: pinecone-client Index object, an index will be initialized or loaded if not specified. 
:param embedding_dim: The embedding vector size. :param return_embedding: Whether to return document embeddings. @@ -93,11 +106,11 @@ class PineconeDocumentStore(BaseDocumentStore): :param replicas: The number of replicas. Replicas duplicate the index. They provide higher availability and throughput. :param shards: The number of shards to be used in the index. We recommend to use 1 shard per 1GB of data. + :param namespace: Optional namespace. If not specified, None is default. :param embedding_field: Name of field containing an embedding vector. :param progress_bar: Whether to show a tqdm progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. :param duplicate_documents: Handle duplicate documents based on parameter options.\ - Parameter options: - `"skip"`: Ignore the duplicate documents. - `"overwrite"`: Update any existing documents with the same ID when adding documents. @@ -124,34 +137,23 @@ class PineconeDocumentStore(BaseDocumentStore): self._api_key = api_key # Formal similarity string - if similarity == "cosine": - self.metric_type = similarity - elif similarity == "dot_product": - self.metric_type = "dotproduct" - elif similarity in ("l2", "euclidean"): - self.metric_type = "euclidean" - else: - raise ValueError( - "The Pinecone document store can currently only support dot_product, cosine and euclidean metrics. " - "Please set similarity to one of the above." 
- ) + self._set_similarity_metric(similarity) self.similarity = similarity - self.index: str = self._index_name(index) + self.index: str = self._index(index) self.embedding_dim = embedding_dim self.return_embedding = return_embedding self.embedding_field = embedding_field self.progress_bar = progress_bar self.duplicate_documents = duplicate_documents - self.document_namespace = "no-vectors" - self.embedding_namespace = "vectors" # Pinecone index params self.replicas = replicas self.shards = shards + self.namespace = namespace # Add necessary metadata fields to metadata_config - fields = ["label-id", "query"] + fields = ["label-id", "query", TYPE_METADATA_FIELD] metadata_config["indexed"] += fields self.metadata_config = metadata_config @@ -162,11 +164,10 @@ class PineconeDocumentStore(BaseDocumentStore): # Initialize dictionary to store temporary set of document IDs self.all_ids: dict = {} + # Dummy query to be used during searches self.dummy_query = [0.0] * self.embedding_dim - self.progress_bar = progress_bar - if pinecone_index: if not isinstance(pinecone_index, pinecone.Index): raise PineconeDocumentStoreError( @@ -188,17 +189,8 @@ class PineconeDocumentStore(BaseDocumentStore): super().__init__() - def _add_local_ids(self, index: str, ids: list): - """ - Add all document IDs to the set of all IDs. 
- """ - if index not in self.all_ids: - self.all_ids[index] = set() - self.all_ids[index] = self.all_ids[index].union(set(ids)) - - def _index_name(self, index) -> str: - index = _sanitize_index_name(index) or self.index - # self.index = index # TODO maybe not needed + def _index(self, index) -> str: + index = _sanitize_index(index) or self.index return index def _create_index( @@ -210,20 +202,19 @@ class PineconeDocumentStore(BaseDocumentStore): shards: Optional[int] = 1, recreate_index: bool = False, metadata_config: Optional[Dict] = None, - ): + ) -> pinecone.Index: """ - Create a new index for storing documents in case an - index with the name doesn't exist already. + Create a new index for storing documents in case an index with the name + doesn't exist already. """ if metadata_config is None: metadata_config = {"indexed": []} - index = self._index_name(index) if recreate_index: self.delete_index(index) # Skip if already exists - if index in self.pinecone_indexes.keys(): + if index in self.pinecone_indexes: index_connection = self.pinecone_indexes[index] else: # Search pinecone hosted indexes and create an index if it does not exist @@ -243,19 +234,123 @@ class PineconeDocumentStore(BaseDocumentStore): dims = stats["dimension"] count = stats["namespaces"][""]["vector_count"] if stats["namespaces"].get("") else 0 logger.info("Index statistics: name: %s embedding dimensions: %s, record count: %s", index, dims, count) + # return index connection return index_connection + def _index_connection_exists(self, index: str, create: bool = False) -> Optional["pinecone.Index"]: + """ + Check if the index connection exists. If specified, create an index if it does not exist yet. + + :param index: Index name. + :param create: Indicates if an index needs to be created or not. If set to `True`, create an index + and return connection to it, otherwise raise `PineconeDocumentStoreError` error. 
+ :raises PineconeDocumentStoreError: Exception trigger when index connection not found. + """ + if index not in self.pinecone_indexes: + if create: + return self._create_index( + embedding_dim=self.embedding_dim, + index=index, + metric_type=self.metric_type, + replicas=self.replicas, + shards=self.shards, + recreate_index=False, + metadata_config=self.metadata_config, + ) + raise PineconeDocumentStoreError( + f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " + f"'update_embeddings()' to create and populate an index." + ) + return None + + def _set_similarity_metric(self, similarity: str): + """ + Set vector similarity metric. + """ + if similarity == "cosine": + self.metric_type = similarity + elif similarity == "dot_product": + self.metric_type = "dotproduct" + elif similarity in ["l2", "euclidean"]: + self.metric_type = "euclidean" + else: + raise ValueError( + "The Pinecone document store can currently only support dot_product, cosine and euclidean metrics. " + "Please set similarity to one of the above." + ) + + def _add_local_ids(self, index: str, ids: List[str]): + """ + Add all document IDs to the set of all IDs. + """ + if index not in self.all_ids: + self.all_ids[index] = set() + self.all_ids[index] = self.all_ids[index].union(set(ids)) + + def _add_type_metadata_filter(self, filters: FilterType, type_value: Optional[DocTypeMetadata]) -> FilterType: + """ + Add new filter for `doc_type` metadata field. 
+ """ + if type_value: + new_type_filter = {TYPE_METADATA_FIELD: {EQ_OPERATOR: type_value}} + if AND_OPERATOR not in filters and TYPE_METADATA_FIELD not in filters: + # extend filters with new `doc_type` filter and add $and operator + filters.update(new_type_filter) + all_filters = filters + return {AND_OPERATOR: all_filters} + + filters_content = filters[AND_OPERATOR] if AND_OPERATOR in filters else filters + if TYPE_METADATA_FIELD in filters_content: # type: ignore + current_type_filter = filters_content[TYPE_METADATA_FIELD] # type: ignore + type_values = {type_value} + if isinstance(current_type_filter, str): + type_values.add(current_type_filter) # type: ignore + elif isinstance(current_type_filter, dict): + if EQ_OPERATOR in current_type_filter: + # current `doc_type` filter has single value + type_values.add(current_type_filter[EQ_OPERATOR]) + else: + # current `doc_type` filter has multiple values + type_values.update(set(current_type_filter[IN_OPERATOR])) + new_type_filter = {TYPE_METADATA_FIELD: {IN_OPERATOR: list(type_values)}} # type: ignore + filters_content.update(new_type_filter) # type: ignore + + return filters + + def _get_default_type_metadata(self, index: Optional[str], namespace: Optional[str] = None) -> str: + """ + Get default value for `doc_type` metadata filed. If there is at least one embedding, default value + will be `vector`, otherwise it will be `no-vector`. 
+ """ + if self.get_embedding_count(index=index, namespace=namespace) > 0: + return DOCUMENT_WITH_EMBEDDING + return DOCUMENT_WITHOUT_EMBEDDING + + def _get_vector_count(self, index: str, filters: Optional[FilterType], namespace: Optional[str]) -> int: + res = self.pinecone_indexes[index].query( + self.dummy_query, + top_k=self.top_k_limit, + include_values=False, + include_metadata=False, + filter=filters, + namespace=namespace, + ) + return len(res["matches"]) + def get_document_count( self, filters: Optional[FilterType] = None, index: Optional[str] = None, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None, + namespace: Optional[str] = None, + type_metadata: Optional[DocTypeMetadata] = None, ) -> int: """ - Return the count of embeddings in the document store. - :param filters: Optional filters to narrow down the documents for which embeddings are to be updated. + Return the count of documents in the document store. + + :param filters: Optional filters to narrow down the documents which will be counted. Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, `"$gte"`, `"$lt"`, `"$lte"`), or a metadata field name. @@ -280,33 +375,63 @@ class PineconeDocumentStore(BaseDocumentStore): } } ``` - :param index: Optional index to use for the query. If not provided, the default index is used. + :param index: Optional index name to use for the query. If not provided, the default index name is used. :param only_documents_without_embedding: If set to `True`, only documents without embeddings are counted. :param headers: PineconeDocumentStore does not support headers. + :param namespace: Optional namespace to count documents from. If not specified, None is default. + :param type_metadata: Optional value for `doc_type` metadata to reference documents that need to be counted. 
+ Parameter options: + - `"vector"`: Documents with embedding. + - `"no-vector"`: Documents without embedding (dummy embedding only). + - `"label"`: Labels. """ if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." - ) + index = self._index(index) + self._index_connection_exists(index) + + filters = filters or {} + if not type_metadata: + # add filter for `doc_type` metadata related to documents without embeddings + filters = self._add_type_metadata_filter(filters, type_value=DOCUMENT_WITHOUT_EMBEDDING) # type: ignore + if not only_documents_without_embedding: + # add filter for `doc_type` metadata related to documents with embeddings + filters = self._add_type_metadata_filter(filters, type_value=DOCUMENT_WITH_EMBEDDING) # type: ignore + else: + # if value for `doc_type` metadata is specified, add filter with given value + filters = self._add_type_metadata_filter(filters, type_value=type_metadata) pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None + return self._get_vector_count(index, filters=pinecone_syntax_filter, namespace=namespace) - stats = self.pinecone_indexes[index].describe_index_stats(filter=pinecone_syntax_filter) - if only_documents_without_embedding: - return sum(value["vector_count"] for key, value in stats["namespaces"].items() if "no-vectors" in key) - return sum(value["vector_count"] for value in stats["namespaces"].values()) + def get_embedding_count( + self, filters: Optional[FilterType] = None, index: Optional[str] = None, namespace: Optional[str] = None + ) -> int: + """ + Return the count of embeddings in the document store. 
+ + :param index: Optional index name to retrieve all documents from. + :param filters: Filters are not supported for `get_embedding_count` in Pinecone. + :param namespace: Optional namespace to count embeddings from. If not specified, None is default. + """ + if filters: + raise NotImplementedError("Filters are not supported for get_embedding_count in PineconeDocumentStore") + + index = self._index(index) + self._index_connection_exists(index) + + pinecone_filters = self._meta_for_pinecone({TYPE_METADATA_FIELD: DOCUMENT_WITH_EMBEDDING}) + return self._get_vector_count(index, filters=pinecone_filters, namespace=namespace) def _validate_index_sync(self, index: Optional[str] = None): """ - This check ensures the correct number of documents and embeddings are found in the + This check ensures the correct number of documents with embeddings and embeddings are found in the Pinecone database. """ - if self.get_document_count(index=index) != self.get_embedding_count(index=index): + if self.get_document_count( + index=index, type_metadata=DOCUMENT_WITH_EMBEDDING # type: ignore + ) != self.get_embedding_count(index=index): raise PineconeDocumentStoreError( f"The number of documents present in Pinecone ({self.get_document_count(index=index)}) " "does not match the number of embeddings in Pinecone " @@ -319,10 +444,11 @@ class PineconeDocumentStore(BaseDocumentStore): self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, duplicate_documents: Optional[str] = None, headers: Optional[Dict[str, str]] = None, labels: Optional[bool] = False, + namespace: Optional[str] = None, ): """ Add new documents to the DocumentStore. @@ -333,64 +459,61 @@ class PineconeDocumentStore(BaseDocumentStore): :param batch_size: Number of documents to process at a time. When working with large number of documents, batching can help to reduce the memory footprint. 
:param duplicate_documents: handle duplicate documents based on parameter options. - Parameter options: - `"skip"`: Ignore the duplicate documents. - `"overwrite"`: Update any existing documents with the same ID when adding documents. - `"fail"`: An error is raised if the document ID of the document being added already exists. :param headers: PineconeDocumentStore does not support headers. :param labels: Tells us whether these records are labels or not. Defaults to False. + :param namespace: Optional namespace to write documents to. If not specified, None is default. :raises DuplicateDocumentError: Exception trigger on duplicate document. """ if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") - index = self._index_name(index) + index = self._index(index) duplicate_documents = duplicate_documents or self.duplicate_documents assert ( duplicate_documents in self.duplicate_documents_options ), f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}" - if index not in self.pinecone_indexes: - self.pinecone_indexes[index] = self._create_index( - embedding_dim=self.embedding_dim, - index=index, - metric_type=self.metric_type, - replicas=self.replicas, - shards=self.shards, - recreate_index=False, - metadata_config=self.metadata_config, - ) + index_connection = self._index_connection_exists(index, create=True) + if index_connection: + self.pinecone_indexes[index] = index_connection field_map = self._create_document_field_map() - document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents] + document_objects = [ + Document.from_dict(doc, field_map=field_map) if isinstance(doc, dict) else doc for doc in documents + ] document_objects = self._handle_duplicate_documents( documents=document_objects, index=index, duplicate_documents=duplicate_documents ) - if len(document_objects) > 0: + if document_objects: add_vectors = False if document_objects[0].embedding 
is None else True - # If these are not labels, we need to find the correct namespace + # If these are not labels, we need to find the correct value for `doc_type` metadata field if not labels: - # If not adding vectors we use document namespace - namespace = self.embedding_namespace if add_vectors else self.document_namespace + type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING else: - namespace = "labels" + type_metadata = LABEL if not add_vectors: # To store documents in Pinecone, we use dummy embeddings (to be replaced with real embeddings later) embeddings_to_index = np.zeros((batch_size, self.embedding_dim), dtype="float32") # Convert embeddings to list objects embeddings = [embed.tolist() if embed is not None else None for embed in embeddings_to_index] + with tqdm( total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents" ) as progress_bar: for i in range(0, len(document_objects), batch_size): document_batch = document_objects[i : i + batch_size] ids = [doc.id for doc in document_batch] - # If duplicate_documents set to skip or fail, we need to check for existing documents + # If duplicate_documents set to `skip` or `fail`, we need to check for existing documents if duplicate_documents in ["skip", "fail"]: - existing_documents = self.get_documents_by_id(ids=ids, index=index, namespace=namespace) + existing_documents = self.get_documents_by_id( + ids=ids, index=index, namespace=namespace, include_type_metadata=True + ) # First check for documents in current batch that exist in the index - if len(existing_documents) > 0: + if existing_documents: if duplicate_documents == "skip": # If we should skip existing documents, we drop the ids that already exist skip_ids = [doc.id for doc in existing_documents] @@ -406,7 +529,7 @@ class PineconeDocumentStore(BaseDocumentStore): ) # Now check for duplicate documents within the batch itself if len(ids) != len(set(ids)): - if duplicate_documents in 
"skip": + if duplicate_documents == "skip": # We just keep the first instance of each duplicate document ids = [] temp_document_batch = [] @@ -419,7 +542,14 @@ class PineconeDocumentStore(BaseDocumentStore): # Otherwise, we raise an error raise DuplicateDocumentError(f"Duplicate document IDs found in batch: {ids}") metadata = [ - self._meta_for_pinecone({"content": doc.content, "content_type": doc.content_type, **doc.meta}) + self._meta_for_pinecone( + { + TYPE_METADATA_FIELD: type_metadata, # add `doc_type` in metadata + "content": doc.content, + "content_type": doc.content_type, + **doc.meta, + } + ) for doc in document_objects[i : i + batch_size] ] if add_vectors: @@ -447,7 +577,8 @@ class PineconeDocumentStore(BaseDocumentStore): index: Optional[str] = None, update_existing_embeddings: bool = True, filters: Optional[FilterType] = None, - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, + namespace: Optional[str] = None, ): """ Updates the embeddings in the document store using the encoding model specified in the retriever. @@ -487,15 +618,19 @@ class PineconeDocumentStore(BaseDocumentStore): ``` :param batch_size: Number of documents to process at a time. When working with large number of documents, batching can help reduce memory footprint. + :param namespace: Optional namespace to retrieve document from. If not specified, None is default. """ - index = self._index_name(index) + index = self._index(index) if index not in self.pinecone_indexes: raise ValueError( f"Couldn't find a the index '{index}' in Pinecone. Try to init the " f"PineconeDocumentStore() again ..." 
) document_count = self.get_document_count( - index=index, filters=filters, only_documents_without_embedding=not update_existing_embeddings + index=index, + filters=filters, + only_documents_without_embedding=not update_existing_embeddings, + namespace=namespace, ) if document_count == 0: logger.warning("Calling DocumentStore.update_embeddings() on an empty index") @@ -503,15 +638,20 @@ class PineconeDocumentStore(BaseDocumentStore): logger.info("Updating embeddings for %s docs...", document_count) - # If the embedding namespace is empty or the user does not want to update existing embeddings, we use document namespace + # If embeddings don't exist or the user doesn't want to update existing embeddings, update dummy embeddings if self.get_embedding_count(index=index) == 0 or not update_existing_embeddings: - namespace = self.document_namespace + type_value = DOCUMENT_WITHOUT_EMBEDDING else: - # Else, we use the embedding namespace as this is the primary namespace for embeddings - namespace = self.embedding_namespace + type_value = DOCUMENT_WITH_EMBEDDING documents = self.get_all_documents_generator( - index=index, namespace=namespace, filters=filters, return_embedding=False, batch_size=batch_size + index=index, + type_metadata=type_value, # type: ignore + filters=filters, + return_embedding=False, + batch_size=batch_size, + namespace=namespace, + include_type_metadata=True, ) with tqdm( @@ -537,15 +677,21 @@ class PineconeDocumentStore(BaseDocumentStore): ids = [] for doc in document_batch: metadata.append( - self._meta_for_pinecone({"content": doc.content, "content_type": doc.content_type, **doc.meta}) + self._meta_for_pinecone( + { + "content": doc.content, + "content_type": doc.content_type, + **doc.meta, + # set `doc_type` metadata field to `vector` since the dummy embedding is updated + TYPE_METADATA_FIELD: DOCUMENT_WITH_EMBEDDING, + } + ) ) ids.append(doc.id) # Update existing vectors in pinecone index self.pinecone_indexes[index].upsert( - vectors=zip(ids, 
embeddings.tolist(), metadata), namespace=self.embedding_namespace + vectors=zip(ids, embeddings.tolist(), metadata), namespace=namespace ) - # Delete existing vectors from document namespace if they exist there - self.delete_documents(index=index, ids=ids, namespace=self.document_namespace) # Add these vector IDs to local store self._add_local_ids(index, ids) progress_bar.set_description_str("Documents Processed") @@ -556,8 +702,9 @@ class PineconeDocumentStore(BaseDocumentStore): index: Optional[str] = None, filters: Optional[FilterType] = None, return_embedding: Optional[bool] = None, - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, headers: Optional[Dict[str, str]] = None, + type_metadata: Optional[DocTypeMetadata] = None, namespace: Optional[str] = None, ) -> List[Document]: """ @@ -591,21 +738,25 @@ class PineconeDocumentStore(BaseDocumentStore): ``` :param return_embedding: Optional flag to return the embedding of the document. :param batch_size: Number of documents to process at a time. When working with large number of documents, - batching can help reduce memory footprint. + batching can help reduce memory footprint. :param headers: Pinecone does not support headers. - :param namespace: Optional namespace to retrieve documents from. + :param type_metadata: Value of `doc_type` metadata that indicates which documents need to be retrieved. + :param namespace: Optional namespace to retrieve documents from. If not specified, None is default. 
""" if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace + if not type_metadata: + # set default value for `doc_type` metadata field + type_metadata = self._get_default_type_metadata(index, namespace) # type: ignore result = self.get_all_documents_generator( - index=index, namespace=namespace, filters=filters, return_embedding=return_embedding, batch_size=batch_size + index=index, + type_metadata=type_metadata, + filters=filters, + return_embedding=return_embedding, + batch_size=batch_size, + namespace=namespace, ) documents: List[Document] = list(result) return documents @@ -615,9 +766,11 @@ class PineconeDocumentStore(BaseDocumentStore): index: Optional[str] = None, filters: Optional[FilterType] = None, return_embedding: Optional[bool] = None, - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, headers: Optional[Dict[str, str]] = None, namespace: Optional[str] = None, + type_metadata: Optional[DocTypeMetadata] = None, + include_type_metadata: Optional[bool] = False, ) -> Generator[Document, None, None]: """ Get all documents from the document store. Under-the-hood, documents are fetched in batches from the @@ -625,7 +778,7 @@ class PineconeDocumentStore(BaseDocumentStore): a large number of documents without having to load all documents in memory. :param index: Name of the index to get the documents from. If None, the - DocumentStore's default index (self.index) will be used. + DocumentStore's default index (self.index) will be used. :param filters: Optional filters to narrow down the documents for which embeddings are to be updated. Filters are defined as nested dictionaries. 
The keys of the dictionaries can be a logical operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, @@ -654,7 +807,10 @@ class PineconeDocumentStore(BaseDocumentStore): :param return_embedding: Whether to return the document embeddings. :param batch_size: When working with large number of documents, batching can help reduce memory footprint. :param headers: PineconeDocumentStore does not support headers. - :param namespace: Optional namespace to retrieve documents from. + :param namespace: Optional namespace to retrieve document from. If not specified, None is default. + :param type_metadata: Value of `doc_type` metadata that indicates which documents need to be retrieved. + :param include_type_metadata: Indicates if `doc_type` value will be included in document metadata or not. + If not specified, `doc_type` field will be dropped from document metadata. """ if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") @@ -662,22 +818,18 @@ class PineconeDocumentStore(BaseDocumentStore): if return_embedding is None: return_embedding = self.return_embedding - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." 
- ) + index = self._index(index) + self._index_connection_exists(index) - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace + if not type_metadata: + # set default value for `doc_type` metadata field + type_metadata = self._get_default_type_metadata(index, namespace) # type: ignore - ids = self._get_all_document_ids(index=index, namespace=namespace, filters=filters, batch_size=batch_size) + ids = self._get_all_document_ids( + index=index, type_metadata=type_metadata, filters=filters, namespace=namespace, batch_size=batch_size + ) - if filters is not None and len(ids) == 0: + if filters is not None and not ids: logger.warning( "This query might have been done without metadata indexed and thus no DOCUMENTS were retrieved. " "Make sure the desired metadata you want to filter with is indexed." @@ -688,9 +840,10 @@ class PineconeDocumentStore(BaseDocumentStore): documents = self.get_documents_by_id( ids=ids[i:i_end], index=index, - namespace=namespace, batch_size=batch_size, return_embedding=return_embedding, + namespace=namespace, + include_type_metadata=include_type_metadata, ) for doc in documents: yield doc @@ -698,24 +851,15 @@ class PineconeDocumentStore(BaseDocumentStore): def _get_all_document_ids( self, index: Optional[str] = None, - namespace: Optional[str] = None, + type_metadata: Optional[DocTypeMetadata] = None, filters: Optional[FilterType] = None, - batch_size: int = 32, + namespace: Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, ) -> List[str]: - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." 
- ) + index = self._index(index) + self._index_connection_exists(index) - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace - - document_count = self.get_document_count(index=index) + document_count = self.get_document_count(index=index, namespace=namespace, type_metadata=type_metadata) if index not in self.all_ids: self.all_ids[index] = set() @@ -726,16 +870,20 @@ class PineconeDocumentStore(BaseDocumentStore): # Otherwise we must query and extract IDs from the original namespace, then move the retrieved embeddings # to a temporary namespace and query again for new items. We repeat this process until all embeddings # have been retrieved. - target_namespace = f"{namespace}-copy" + target_namespace = f"{namespace}-copy" if namespace is not None else "copy" all_ids: Set[str] = set() vector_id_matrix = ["dummy-id"] with tqdm( total=document_count, disable=not self.progress_bar, position=0, unit=" ids", desc="Retrieving IDs" ) as progress_bar: - while len(vector_id_matrix) != 0: + while vector_id_matrix: # Retrieve IDs from Pinecone vector_id_matrix = self._get_ids( - index=index, namespace=namespace, batch_size=batch_size, filters=filters + index=index, + namespace=namespace, + filters=filters, + type_metadata=type_metadata, + batch_size=batch_size, ) # Save IDs all_ids = all_ids.union(set(vector_id_matrix)) @@ -749,8 +897,10 @@ class PineconeDocumentStore(BaseDocumentStore): ) progress_bar.set_description_str("Retrieved IDs") progress_bar.update(len(set(vector_id_matrix))) + # Now move all documents back to source namespace - self._namespace_cleanup(index) + self._namespace_cleanup(index=index, namespace=target_namespace, batch_size=batch_size) + self._add_local_ids(index, list(all_ids)) return list(all_ids) @@ -758,11 +908,11 @@ class PineconeDocumentStore(BaseDocumentStore): self, ids: List[str], index: Optional[str] = None, - source_namespace: Optional[str] = 
"vectors", + source_namespace: Optional[str] = None, target_namespace: Optional[str] = "copy", - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, ): - index = self._index_name(index) + index = self._index(index) if index not in self.pinecone_indexes: raise PineconeDocumentStoreError( f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " @@ -794,14 +944,40 @@ class PineconeDocumentStore(BaseDocumentStore): progress_bar.set_description_str("Documents Moved") progress_bar.update(len(id_batch)) + def _namespace_cleanup(self, index: str, namespace: str, batch_size: int = DEFAULT_BATCH_SIZE): + """ + Shifts vectors back from "-copy" namespace to the original namespace. + """ + with tqdm( + total=1, disable=not self.progress_bar, position=0, unit=" namespaces", desc="Cleaning Namespace" + ) as progress_bar: + target_namespace = namespace[:-5] if namespace != "copy" else None + while True: + # Retrieve IDs from Pinecone + vector_id_matrix = self._get_ids(index=index, namespace=namespace, batch_size=batch_size) + # Once we reach final item, we break + if len(vector_id_matrix) == 0: + break + # Move these IDs to new namespace + self._move_documents_by_id_namespace( + ids=vector_id_matrix, + index=index, + source_namespace=namespace, + target_namespace=target_namespace, + batch_size=batch_size, + ) + progress_bar.set_description_str("Cleaned Namespace") + progress_bar.update(1) + def get_documents_by_id( self, ids: List[str], index: Optional[str] = None, - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, headers: Optional[Dict[str, str]] = None, return_embedding: Optional[bool] = None, namespace: Optional[str] = None, + include_type_metadata: Optional[bool] = False, ) -> List[Document]: """ Retrieves all documents in the index using their IDs. @@ -812,27 +988,18 @@ class PineconeDocumentStore(BaseDocumentStore): batching can help reduce memory footprint. :param headers: Pinecone does not support headers. 
:param return_embedding: Optional flag to return the embedding of the document. - :param namespace: Optional namespace to retrieve documents from. + :param namespace: Optional namespace to retrieve document from. If not specified, None is default. + :param include_type_metadata: Indicates if `doc_type` value will be included in document metadata or not. + If not specified, `doc_type` field will be dropped from document metadata. """ - if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") if return_embedding is None: return_embedding = self.return_embedding - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace - - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." - ) + index = self._index(index) + self._index_connection_exists(index) documents = [] for i in range(0, len(ids), batch_size): @@ -843,9 +1010,12 @@ class PineconeDocumentStore(BaseDocumentStore): vector_id_matrix = [] meta_matrix = [] embedding_matrix = [] - for _id in result["vectors"].keys(): + for _id in result["vectors"]: vector_id_matrix.append(_id) - meta_matrix.append(self._pinecone_meta_format(result["vectors"][_id]["metadata"])) + metadata = result["vectors"][_id]["metadata"] + if not include_type_metadata and TYPE_METADATA_FIELD in metadata: + metadata.pop(TYPE_METADATA_FIELD) + meta_matrix.append(self._pinecone_meta_format(metadata)) if return_embedding: embedding_matrix.append(result["vectors"][_id]["values"]) if return_embedding: @@ -874,66 +1044,31 @@ class PineconeDocumentStore(BaseDocumentStore): :param index: Optional index name to retrieve all documents from. :param headers: Pinecone does not support headers. 
:param return_embedding: Optional flag to return the embedding of the document. - :param namespace: Optional namespace to retrieve documents from. + :param namespace: Optional namespace to retrieve document from. If not specified, None is default. """ documents = self.get_documents_by_id( - ids=[id], namespace=namespace, index=index, headers=headers, return_embedding=return_embedding + ids=[id], index=index, headers=headers, return_embedding=return_embedding, namespace=namespace ) return documents[0] - def get_embedding_count(self, index: Optional[str] = None, filters: Optional[FilterType] = None) -> int: - """ - Return the count of embeddings in the document store. - - :param index: Optional index name to retrieve all documents from. - :param filters: Filters are not supported for `get_embedding_count` in Pinecone. - """ - if filters: - raise NotImplementedError("Filters are not supported for get_embedding_count in PineconeDocumentStore") - - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." - ) - - stats = self.pinecone_indexes[index].describe_index_stats() - # if no embeddings namespace return zero - if self.embedding_namespace in stats["namespaces"]: - count = stats["namespaces"][self.embedding_namespace]["vector_count"] - else: - count = 0 - return count - - def update_document_meta(self, id: str, meta: Dict[str, str], namespace: Optional[str] = None, index: Optional[str] = None): # type: ignore + def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[str] = None): """ Update the metadata dictionary of a document by specifying its string ID. :param id: ID of the Document to update. :param meta: Dictionary of new metadata. - :param namespace: Optional namespace to update documents from. 
If not specified, defaults to the embedding - namespace (vectors) if it exists, otherwise the document namespace (no-vectors). + Note: the update is applied in the namespace configured on the document store (`self.namespace`). :param index: Optional index name to update documents from. """ - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." - ) + index = self._index(index) + self._index_connection_exists(index) - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace + doc = self.get_document_by_id(id=id, index=index, return_embedding=True) - doc = self.get_documents_by_id(ids=[id], index=index, return_embedding=True)[0] if doc.embedding is not None: meta = {"content": doc.content, "content_type": doc.content_type, **meta} - self.pinecone_indexes[index].upsert(vectors=[(id, doc.embedding.tolist(), meta)], namespace=namespace) + self.pinecone_indexes[index].upsert(vectors=[(id, doc.embedding.tolist(), meta)], namespace=self.namespace) def delete_documents( self, @@ -943,15 +1078,14 @@ class PineconeDocumentStore(BaseDocumentStore): headers: Optional[Dict[str, str]] = None, drop_ids: Optional[bool] = True, namespace: Optional[str] = None, + type_metadata: Optional[DocTypeMetadata] = None, ): """ Delete documents from the document store. :param index: Index name to delete the documents from. If `None`, the DocumentStore's default index - (`self.index`) will be used. + name (`self.index`) will be used. :param ids: Optional list of IDs to narrow down the documents to be deleted. - :param namespace: Optional namespace string. By default, it deletes vectors from the embeddings namespace - unless the namespace is empty, in which case it deletes from the documents namespace. 
:param filters: Optional filters to narrow down the documents for which embeddings are to be updated. Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, @@ -978,27 +1112,21 @@ class PineconeDocumentStore(BaseDocumentStore): } ``` :param headers: PineconeDocumentStore does not support headers. - :param drop_ids: Specifies if the locally stored IDs should be deleted. The default - is True. - :param namespace: Optional namespace to delete documents from. If not specified, defaults to the embedding - namespace (vectors) if it exists, otherwise the document namespace (no-vectors). + :param drop_ids: Specifies if the locally stored IDs should be deleted. The default is True. + :param namespace: Optional namespace. If not specified, None is default. + :param type_metadata: Optional value for `doc_type` metadata field as reference for documents to delete. :return None: """ if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace + index = self._index(index) + self._index_connection_exists(index) - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." 
- ) + if type_metadata: + # add filter for `doc_type` metadata field + filters = filters or {} + filters = self._add_type_metadata_filter(filters, type_metadata) pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None @@ -1020,20 +1148,21 @@ class PineconeDocumentStore(BaseDocumentStore): filter_ids = [doc.id for doc in docs] # Find the intersect id_values = list(set(id_values).intersection(set(filter_ids))) - if len(id_values) > 0: + if id_values: # Now we delete self.pinecone_indexes[index].delete(ids=id_values, namespace=namespace) if drop_ids: self.all_ids[index] = self.all_ids[index].difference(set(id_values)) - def delete_index(self, index: str): + def delete_index(self, index: Optional[str]): """ Delete an existing index. The index including all data will be removed. :param index: The name of the index to delete. :return: None """ - index = self._index_name(index) + index = self._index(index) + if index in pinecone.list_indexes(): pinecone.delete_index(index) logger.info("Index '%s' deleted.", index) @@ -1052,6 +1181,7 @@ class PineconeDocumentStore(BaseDocumentStore): headers: Optional[Dict[str, str]] = None, scale_score: bool = True, namespace: Optional[str] = None, + type_metadata: Optional[DocTypeMetadata] = None, ) -> List[Document]: """ Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. @@ -1126,6 +1256,8 @@ class PineconeDocumentStore(BaseDocumentStore): :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + :param namespace: Optional namespace to query document from. If not specified, None is default. 
+ :param type_metadata: Value of `doc_type` metadata that indicates which documents need to be queried. """ if headers: raise NotImplementedError("PineconeDocumentStore does not support headers.") @@ -1134,21 +1266,22 @@ class PineconeDocumentStore(BaseDocumentStore): return_embedding = self.return_embedding self._limit_check(top_k, include_values=return_embedding) - pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None + index = self._index(index) + self._index_connection_exists(index) - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." - ) query_emb = query_emb.astype(np.float32) if self.similarity == "cosine": self.normalize_embedding(query_emb) - if namespace is None: - namespace = self.embedding_namespace + # if `doc_type` metadata not set, set to documents with embeddings + if type_metadata is None: + type_metadata = DOCUMENT_WITH_EMBEDDING # type: ignore + + filters = filters or {} + filters = self._add_type_metadata_filter(filters, type_metadata) + + pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None res = self.pinecone_indexes[index].query( query_emb.tolist(), @@ -1177,7 +1310,7 @@ class PineconeDocumentStore(BaseDocumentStore): vector_id_matrix, meta_matrix, values=values, index=index, return_embedding=return_embedding ) - if filters is not None and len(documents) == 0: + if filters is not None and not documents: logger.warning( "This query might have been done without metadata indexed and thus no results were retrieved. " "Make sure the desired metadata you want to filter with is indexed." 
@@ -1201,7 +1334,6 @@ class PineconeDocumentStore(BaseDocumentStore): ids: List[str], metadata: List[dict], values: Optional[List[List[float]]] = None, - namespace: Optional[str] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None, return_embedding: Optional[bool] = None, @@ -1212,13 +1344,7 @@ class PineconeDocumentStore(BaseDocumentStore): if return_embedding is None: return_embedding = self.return_embedding - if namespace is None: - if self.get_embedding_count(index=index) > 0: - namespace = self.embedding_namespace - else: - namespace = self.document_namespace - - index = self._index_name(index) + index = self._index(index) # extract ID, content, and metadata to create Documents documents = [] @@ -1233,7 +1359,7 @@ class PineconeDocumentStore(BaseDocumentStore): if values is None: # If no embedding values are provided, we must request the embeddings from Pinecone for doc in documents: - self._attach_embedding_to_document(document=doc, index=index, namespace=namespace) + self._attach_embedding_to_document(document=doc, index=index) else: # If embedding values are given, we just add for doc, embedding in zip(documents, values): @@ -1241,12 +1367,12 @@ class PineconeDocumentStore(BaseDocumentStore): return documents - def _attach_embedding_to_document(self, document: Document, index: str, namespace: str): + def _attach_embedding_to_document(self, document: Document, index: str): """ Fetches the Document's embedding from the specified Pinecone index and attaches it to the Document's embedding field. 
""" - result = self.pinecone_indexes[index].fetch(ids=[document.id], namespace=namespace) + result = self.pinecone_indexes[index].fetch(ids=[document.id]) if result["vectors"].get(document.id, False): embedding = result["vectors"][document.id].get("values", None) document.embedding = np.asarray(embedding, dtype=np.float32) @@ -1283,57 +1409,32 @@ class PineconeDocumentStore(BaseDocumentStore): res = self.pinecone_indexes[index].fetch(ids=[id], namespace=namespace) return bool(res["vectors"].get(id, False)) - def _namespace_cleanup(self, index: str, batch_size: int = 32): - """ - Searches for any "-copy" namespaces and shifts vectors back to the original namespace. - """ - namespaces = self._list_namespaces(index) - namespaces = [name for name in namespaces if name[-5:] == "-copy"] - - with tqdm( - total=len(namespaces), - disable=not self.progress_bar, - position=0, - unit=" namespaces", - desc="Cleaning Namespace", - ) as progress_bar: - for namespace in namespaces: - target_namespace = namespace[:-5] - while True: - # Retrieve IDs from Pinecone - vector_id_matrix = self._get_ids(index=index, namespace=namespace, batch_size=batch_size) - # Once we reach final item, we break - if len(vector_id_matrix) == 0: - break - # Move these IDs to new namespace - self._move_documents_by_id_namespace( - ids=vector_id_matrix, - index=index, - source_namespace=namespace, - target_namespace=target_namespace, - batch_size=batch_size, - ) - progress_bar.set_description_str("Cleaned Namespace") - progress_bar.update(1) - def _get_ids( - self, index: str, namespace: str, batch_size: int = 32, filters: Optional[FilterType] = None + self, + index: str, + namespace: Optional[str] = None, + type_metadata: Optional[DocTypeMetadata] = None, + filters: Optional[FilterType] = None, + batch_size: int = DEFAULT_BATCH_SIZE, ) -> List[str]: """ Retrieves a list of IDs that satisfy a particular filter condition (or any) using a dummy query embedding. 
""" + filters = filters or {} + if type_metadata: + filters = self._add_type_metadata_filter(filters, type_value=type_metadata) pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None # Retrieve embeddings from Pinecone try: res = self.pinecone_indexes[index].query( self.dummy_query, - namespace=namespace, top_k=batch_size, include_values=False, include_metadata=False, filter=pinecone_syntax_filter, + namespace=namespace, ) except pinecone.ApiException as e: raise PineconeDocumentStoreError( @@ -1502,7 +1603,7 @@ class PineconeDocumentStore(BaseDocumentStore): doc.meta[k[20:]] = v # Extract offsets offsets: Dict[str, Optional[List[Span]]] = {"document": None, "context": None} - for mode in offsets.keys(): + for mode in offsets: if label_meta.get(f"label-answer-offsets-in-{mode}-start") is not None: offsets[mode] = [ Span( @@ -1560,75 +1661,76 @@ class PineconeDocumentStore(BaseDocumentStore): ids: Optional[List[str]] = None, filters: Optional[FilterType] = None, headers: Optional[Dict[str, str]] = None, - batch_size: int = 32, + batch_size: int = DEFAULT_BATCH_SIZE, + namespace: Optional[str] = None, ): """ Default class method used for deleting labels. Not supported by PineconeDocumentStore. """ - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." 
- ) - - pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None + index = self._index(index) + self._index_connection_exists(index) i = 0 dummy_query = np.asarray(self.dummy_query) - # Set label namespace - namespace = "labels" + + type_metadata = LABEL while True: if ids is None: # Iteratively upsert new records without the labels metadata docs = self.query_by_embedding( dummy_query, - filters=pinecone_syntax_filter, + filters=filters, top_k=batch_size, index=index, return_embedding=True, namespace=namespace, + type_metadata=type_metadata, # type: ignore ) update_ids = [doc.id for doc in docs] else: i_end = min(i + batch_size, len(ids)) update_ids = ids[i:i_end] - if pinecone_syntax_filter: - pinecone_syntax_filter["label-id"] = {"$in": update_ids} + if filters: + filters["label-id"] = {IN_OPERATOR: update_ids} else: - pinecone_syntax_filter = {"label-id": {"$in": update_ids}} + filters = {"label-id": {IN_OPERATOR: update_ids}} # Retrieve embeddings and metadata for the batch of documents docs = self.query_by_embedding( dummy_query, - filters=pinecone_syntax_filter, + filters=filters, top_k=batch_size, index=index, return_embedding=True, namespace=namespace, + type_metadata=type_metadata, # type: ignore ) # Apply filter to update IDs, finding intersection update_ids = list(set(update_ids).intersection({doc.id for doc in docs})) i = i_end - if len(update_ids) == 0: + if not update_ids: break # Delete the documents self.delete_documents(ids=update_ids, index=index, namespace=namespace) def get_all_labels( - self, index=None, filters: Optional[FilterType] = None, headers: Optional[Dict[str, str]] = None + self, + index=None, + filters: Optional[FilterType] = None, + headers: Optional[Dict[str, str]] = None, + namespace: Optional[str] = None, ): """ Default class method used for getting all labels. 
""" - index = self._index_name(index) - if index not in self.pinecone_indexes: - raise PineconeDocumentStoreError( - f"Index named '{index}' does not exist. Try reinitializing PineconeDocumentStore() and running " - f"'update_embeddings()' to create and populate an index." - ) + index = self._index(index) + self._index_connection_exists(index) - documents = self.get_all_documents(index=index, filters=filters, headers=headers, namespace="labels") + # add filter for `doc_type` metadata field + filters = filters or {} + filters = self._add_type_metadata_filter(filters, LABEL) # type: ignore + + documents = self.get_all_documents(index=index, filters=filters, headers=headers, namespace=namespace) for doc in documents: doc.meta = self._pinecone_meta_format(doc.meta, labels=True) labels = self._meta_to_labels(documents) @@ -1640,30 +1742,26 @@ class PineconeDocumentStore(BaseDocumentStore): """ raise NotImplementedError("Labels are not supported by PineconeDocumentStore.") - def write_labels(self, labels, index=None, headers: Optional[Dict[str, str]] = None): + def write_labels( + self, labels, index=None, headers: Optional[Dict[str, str]] = None, namespace: Optional[str] = None + ): """ Default class method used for writing labels. 
""" - index = self._index_name(index) - if index not in self.pinecone_indexes: - self.pinecone_indexes[index] = self._create_index( - embedding_dim=self.embedding_dim, - index=index, - metric_type=self.metric_type, - replicas=self.replicas, - shards=self.shards, - recreate_index=False, - metadata_config=self.metadata_config, - ) + index = self._index(index) + index_connection = self._index_connection_exists(index, create=True) + if index_connection: + self.pinecone_indexes[index] = index_connection # Convert Label objects to dictionary of metadata metadata = self._label_to_meta(labels) ids = list(metadata.keys()) - # Set label namespace - namespace = "labels" - # Check if vectors exist in the namespace - existing_documents = self.get_documents_by_id(ids=ids, index=index, namespace=namespace, return_embedding=True) - if len(existing_documents) != 0: + + # Check if labels exist + existing_documents = self.get_documents_by_id( + ids=ids, index=index, namespace=namespace, return_embedding=True, include_type_metadata=True + ) + if existing_documents: # If they exist, we loop through and partial update their metadata with the new labels existing_ids = [doc.id for doc in existing_documents] for _id in existing_ids: @@ -1672,9 +1770,9 @@ class PineconeDocumentStore(BaseDocumentStore): # After update, we delete the ID from the metadata list del metadata[_id] # If there are any remaining IDs, we create new documents with the remaining metadata - if len(metadata) != 0: + if metadata: documents = [] for _id, meta in metadata.items(): metadata[_id] = self._meta_for_pinecone(meta) documents.append(Document(id=_id, content=meta["label-document-content"], meta=meta)) - self.write_documents(documents, index=index, labels=True) + self.write_documents(documents, index=index, labels=True, namespace=namespace) diff --git a/releasenotes/notes/refactor-pinecone-document-store.yaml b/releasenotes/notes/refactor-pinecone-document-store.yaml new file mode 100644 index 000000000..d67d134a3 
--- /dev/null +++ b/releasenotes/notes/refactor-pinecone-document-store.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Refactor PineconeDocumentStore to use metadata instead of namespaces + for distinction between documents with embeddings, documents without + embeddings and labels \ No newline at end of file diff --git a/test/document_stores/test_pinecone.py b/test/document_stores/test_pinecone.py index 4a3a9f28e..d91961367 100644 --- a/test/document_stores/test_pinecone.py +++ b/test/document_stores/test_pinecone.py @@ -1,20 +1,18 @@ -from typing import List, Union, Dict, Any - import os -import numpy as np from inspect import getmembers, isclass, isfunction +from typing import Any, Dict, List, Union from unittest.mock import MagicMock +import numpy as np import pytest -from haystack.document_stores.pinecone import pinecone -from haystack.document_stores.pinecone import PineconeDocumentStore -from haystack.schema import Document +from haystack.document_stores.pinecone import DOCUMENT_WITH_EMBEDDING, PineconeDocumentStore, pinecone from haystack.errors import FilterError, PineconeDocumentStoreError +from haystack.schema import Document from haystack.testing import DocumentStoreBaseTestAbstract -from ..mocks import pinecone as pinecone_mock from ..conftest import MockBaseRetriever +from ..mocks import pinecone as pinecone_mock # Set metadata fields used during testing for PineconeDocumentStore meta_config META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document"] @@ -150,6 +148,7 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract): # # Tests # + @pytest.mark.integration def test_doc_store_wrong_init(self): """ @@ -545,7 +544,8 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract): assert doc_store_with_docs.get_document_count() == initial_document_count + 2 # remove one of the documents with embedding - all_embedding_docs = doc_store_with_docs.get_all_documents(namespace="vectors") + all_embedding_docs = 
doc_store_with_docs.get_all_documents(type_metadata=DOCUMENT_WITH_EMBEDDING) + doc_store_with_docs.delete_documents(ids=[all_embedding_docs[0].id]) # since we deleted one doc, we expect initial_document_count + 1 documents in total @@ -577,7 +577,7 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract): assert doc_store_with_docs.get_document_count() == initial_document_count + 2 # remove one of the documents without embedding - all_non_embedding_docs = doc_store_with_docs.get_all_documents(namespace="no-vectors") + all_non_embedding_docs = doc_store_with_docs.get_all_documents(type_metadata="no-vector") doc_store_with_docs.delete_documents(ids=[all_non_embedding_docs[0].id]) # since we deleted one doc, we expect initial_document_count + 1 documents in total @@ -656,6 +656,7 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract): mocked_ds.write_documents([doc]) call_args = mocked_ds.pinecone_indexes["document"].upsert.call_args.kwargs assert list(call_args["vectors"])[0][2] == { + "doc_type": "no-vector", "content": "test", "content_type": "text", "_split_overlap": '[{"doc_id": "test_id", "range": [0, 10]}]',