feat: Pinecone document store optimizations (#5902)

* Optimize methods for deleting documents and getting the vector count. Enable warning messages when Pinecone limits are exceeded on the Starter index type.

* Fix typo

* Add release note

* Fix mypy errors

* Remove unused import. Fix warning logging message.

* Update release note with a description of the limits of the Starter index type in Pinecone

* Improve code base by:
- Adding new test cases for the get_embedding_count method
- Fixing the get_embedding_count method
- Improving delete_documents
- Fixing label retrieval
- Increasing the default batch size
- Improving the get_document_count method

* Remove unused variable

* Fix mypy issues
Ivana Zeljkovic 2023-10-16 19:26:24 +02:00 committed by GitHub
parent b43fc35deb
commit 2326f2f9fe
5 changed files with 317 additions and 132 deletions
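For orientation, here is a minimal usage sketch of the behavior this commit enables; the store construction and field names are illustrative placeholders, not taken from the diff:

```python
from haystack.document_stores.pinecone import PineconeDocumentStore

# Hypothetical setup: the API key and index name are placeholders.
doc_store = PineconeDocumentStore(api_key="YOUR_API_KEY", index="my-index", embedding_dim=768)

# Counting embeddings can now be narrowed with metadata filters;
# before this commit, passing filters here raised NotImplementedError.
count = doc_store.get_embedding_count(filters={"meta_field": {"$eq": "test-1"}})
```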

View File

@@ -2,6 +2,8 @@ import copy
 import json
 import logging
 import operator
+from copy import deepcopy
+from datetime import datetime
 from functools import reduce
 from itertools import islice
 from typing import Any, Dict, Generator, List, Literal, Optional, Set, Union
@@ -14,7 +16,8 @@ from haystack.document_stores.filter_utils import LogicalFilterClause
 from haystack.errors import DuplicateDocumentError, PineconeDocumentStoreError
 from haystack.lazy_imports import LazyImport
 from haystack.nodes.retriever import DenseRetriever
-from haystack.schema import Answer, Document, FilterType, Label, Span
+from haystack.schema import LABEL_DATETIME_FORMAT, Answer, Document, FilterType, Label, Span
+from haystack.utils.batching import get_batches_from_generator

 with LazyImport("Run 'pip install farm-haystack[pinecone]'") as pinecone_import:
     import pinecone
@@ -30,7 +33,9 @@ AND_OPERATOR = "$and"
 IN_OPERATOR = "$in"
 EQ_OPERATOR = "$eq"

-DEFAULT_BATCH_SIZE = 32
+DEFAULT_BATCH_SIZE = 128
+
+PINECONE_STARTER_POD = "starter"

 DocTypeMetadata = Literal["vector", "no-vector", "label"]
@@ -290,15 +295,15 @@ class PineconeDocumentStore(BaseDocumentStore):
         """
         Add new filter for `doc_type` metadata field.
         """
+        all_filters = deepcopy(filters)
         if type_value:
             new_type_filter = {TYPE_METADATA_FIELD: {EQ_OPERATOR: type_value}}
-            if AND_OPERATOR not in filters and TYPE_METADATA_FIELD not in filters:
+            if AND_OPERATOR not in all_filters and TYPE_METADATA_FIELD not in all_filters:
                 # extend filters with new `doc_type` filter and add $and operator
-                filters.update(new_type_filter)
-                all_filters = filters
+                all_filters.update(new_type_filter)
                 return {AND_OPERATOR: all_filters}
-            filters_content = filters[AND_OPERATOR] if AND_OPERATOR in filters else filters
+            filters_content = all_filters[AND_OPERATOR] if AND_OPERATOR in all_filters else all_filters
             if TYPE_METADATA_FIELD in filters_content:  # type: ignore
                 current_type_filter = filters_content[TYPE_METADATA_FIELD]  # type: ignore
                 type_values = {type_value}
@@ -314,7 +319,19 @@ class PineconeDocumentStore(BaseDocumentStore):
                 new_type_filter = {TYPE_METADATA_FIELD: {IN_OPERATOR: list(type_values)}}  # type: ignore
                 filters_content.update(new_type_filter)  # type: ignore

-        return filters
+        return all_filters
+
+    def _remove_type_metadata_filter(self, filters: FilterType) -> FilterType:
+        """
+        Remove filter for `doc_type` metadata field if it exists.
+        """
+        all_filters = deepcopy(filters)
+        for key, value in all_filters.copy().items():
+            if key == TYPE_METADATA_FIELD:
+                del all_filters[key]
+            elif isinstance(value, dict):
+                all_filters[key] = self._remove_type_metadata_filter(filters=value)
+        return all_filters

     def _get_default_type_metadata(self, index: Optional[str], namespace: Optional[str] = None) -> str:
         """
@@ -325,16 +342,49 @@ class PineconeDocumentStore(BaseDocumentStore):
             return DOCUMENT_WITH_EMBEDDING
         return DOCUMENT_WITHOUT_EMBEDDING

-    def _get_vector_count(self, index: str, filters: Optional[FilterType], namespace: Optional[str]) -> int:
+    def _get_vector_count(
+        self, index: str, filters: Optional[FilterType], namespace: Optional[str], types_metadata: Set[DocTypeMetadata]
+    ) -> int:
+        index = self._index(index)
+        self._index_connection_exists(index)
+        pinecone_index = self.pinecone_indexes[index]
+
+        filters = filters or {}
+        for type_value in types_metadata:
+            # add filter for each `doc_type` metadata value
+            filters = self._add_type_metadata_filter(filters, type_value)
+        pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None
+
+        if pinecone.describe_index(index).pod_type != PINECONE_STARTER_POD:
+            stats = pinecone_index.describe_index_stats(filter=pinecone_syntax_filter)
+            namespaces = stats["namespaces"]
+            if namespace is None and namespace not in namespaces:
+                namespace = ""
+            return namespaces[namespace]["vector_count"] if namespace in namespaces else 0
+
+        # Due to missing support for metadata filtering in the `describe_index_stats()` method for `gcp-starter`,
+        # use a dummy query for getting the vector count
         res = self.pinecone_indexes[index].query(
             self.dummy_query,
             top_k=self.top_k_limit,
             include_values=False,
             include_metadata=False,
-            filter=filters,
+            filter=pinecone_syntax_filter,
+            namespace=namespace,
         )
-        return len(res["matches"])
+        vector_count = len(res["matches"])
+        if vector_count >= self.top_k_limit:
+            logger.warning(
+                "Current index type 'Starter' doesn't support the features 'Namespace' and metadata filtering "
+                "as part of the describe_index_stats operation. "
+                "Limit for fetching documents in the 'Starter' index type is %s.",
+                self.top_k_limit,
+            )
+        return vector_count
+
+    def _delete_vectors(self, index: str, ids: List[str], namespace: Optional[str]) -> None:
+        batch_size = self.top_k_limit_vectors
+        for id_batch in get_batches_from_generator(ids, batch_size):
+            self.pinecone_indexes[index].delete(ids=list(id_batch), namespace=namespace)
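The batched deletion above leans on `get_batches_from_generator`. As a rough mental model (an assumption about the helper's behavior, suggested by the `list(id_batch)` calls throughout this diff, which imply it yields lazy slices rather than lists):

```python
from itertools import islice
from typing import Any, Generator, Iterable, List

def batches(iterable: Iterable[Any], n: int) -> Generator[List[Any], None, None]:
    # Sketch of get_batches_from_generator's presumed behavior:
    # yield successive chunks of up to n items until the iterable is exhausted.
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk
```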
     def get_document_count(
         self,
@@ -386,22 +436,20 @@ class PineconeDocumentStore(BaseDocumentStore):
         if headers:
             raise NotImplementedError("PineconeDocumentStore does not support headers.")

-        index = self._index(index)
-        self._index_connection_exists(index)
-
-        filters = filters or {}
-        if not type_metadata:
-            # add filter for `doc_type` metadata related to documents without embeddings
-            filters = self._add_type_metadata_filter(filters, type_value=DOCUMENT_WITHOUT_EMBEDDING)  # type: ignore
+        # add `doc_type` value if specified
+        if type_metadata:
+            types_metadata = {type_metadata}
+        # otherwise add default `doc_type` value which is related to documents without embeddings,
+        # but only if `doc_type` doesn't already exist in filters
+        elif TYPE_METADATA_FIELD not in str(filters):
+            types_metadata = {DOCUMENT_WITHOUT_EMBEDDING}  # type: ignore
             if not only_documents_without_embedding:
-                # add filter for `doc_type` metadata related to documents with embeddings
-                filters = self._add_type_metadata_filter(filters, type_value=DOCUMENT_WITH_EMBEDDING)  # type: ignore
+                # add `doc_type` related to documents with embeddings
+                types_metadata.add(DOCUMENT_WITH_EMBEDDING)  # type: ignore
         else:
-            # if value for `doc_type` metadata is specified, add filter with given value
-            filters = self._add_type_metadata_filter(filters, type_value=type_metadata)
+            types_metadata = set()

-        pinecone_syntax_filter = LogicalFilterClause.parse(filters).convert_to_pinecone() if filters else None
-        return self._get_vector_count(index, filters=pinecone_syntax_filter, namespace=namespace)
+        return self._get_vector_count(index, filters=filters, namespace=namespace, types_metadata=types_metadata)  # type: ignore
     def get_embedding_count(
         self, filters: Optional[FilterType] = None, index: Optional[str] = None, namespace: Optional[str] = None
@@ -410,17 +458,39 @@ class PineconeDocumentStore(BaseDocumentStore):
         Return the count of embeddings in the document store.

         :param index: Optional index name to retrieve all documents from.
-        :param filters: Filters are not supported for `get_embedding_count` in Pinecone.
+        :param filters: Optional filters to narrow down the documents with embedding which
+            will be counted. Filters are defined as nested dictionaries. The keys of the dictionaries
+            can be a logical operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`,
+            `"$in"`, `"$gt"`, `"$gte"`, `"$lt"`, `"$lte"`), or a metadata field name.
+            Logical operator keys take a dictionary of metadata field names or logical operators as
+            value. Metadata field names take a dictionary of comparison operators as value. Comparison
+            operator keys take a single value or (in case of `"$in"`) a list of values as value.
+            If no logical operator is provided, `"$and"` is used as default operation. If no comparison
+            operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
+            operation.
+
+                __Example__:
+
+                ```python
+                filters = {
+                    "$and": {
+                        "type": {"$eq": "article"},
+                        "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
+                        "rating": {"$gte": 3},
+                        "$or": {
+                            "genre": {"$in": ["economy", "politics"]},
+                            "publisher": {"$eq": "nytimes"}
+                        }
+                    }
+                }
+                ```
         :param namespace: Optional namespace to count embeddings from. If not specified, None is default.
         """
-        if filters:
-            raise NotImplementedError("Filters are not supported for get_embedding_count in PineconeDocumentStore")
-
-        index = self._index(index)
-        self._index_connection_exists(index)
-
-        pinecone_filters = self._meta_for_pinecone({TYPE_METADATA_FIELD: DOCUMENT_WITH_EMBEDDING})
-        return self._get_vector_count(index, filters=pinecone_filters, namespace=namespace)
+        # drop filter for `doc_type` if it exists
+        if TYPE_METADATA_FIELD in str(filters):
+            filters = self._remove_type_metadata_filter(filters)  # type: ignore
+        return self._get_vector_count(
+            index, filters=filters, namespace=namespace, types_metadata={DOCUMENT_WITH_EMBEDDING}  # type: ignore
+        )
     def _validate_index_sync(self, index: Optional[str] = None):
         """
@@ -502,8 +572,9 @@ class PineconeDocumentStore(BaseDocumentStore):
             with tqdm(
                 total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
             ) as progress_bar:
-                for i in range(0, len(document_objects), batch_size):
-                    document_batch = document_objects[i : i + batch_size]
+                for document_batch in get_batches_from_generator(document_objects, batch_size):
+                    document_batch = list(document_batch)
+                    document_batch_copy = deepcopy(document_batch)
                     ids = [doc.id for doc in document_batch]
                     # If duplicate_documents set to `skip` or `fail`, we need to check for existing documents
                     if duplicate_documents in ["skip", "fail"]:
@@ -548,10 +619,10 @@ class PineconeDocumentStore(BaseDocumentStore):
                                 **doc.meta,
                             }
                         )
-                        for doc in document_objects[i : i + batch_size]
+                        for doc in document_batch_copy
                     ]
                     if add_vectors:
-                        embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
+                        embeddings = [doc.embedding for doc in document_batch_copy]
                         embeddings_to_index = np.array(embeddings, dtype="float32")
                         if self.similarity == "cosine":
                             # Normalize embeddings inplace
@@ -744,7 +815,7 @@ class PineconeDocumentStore(BaseDocumentStore):
         if headers:
             raise NotImplementedError("PineconeDocumentStore does not support headers.")

-        if not type_metadata:
+        if not type_metadata and TYPE_METADATA_FIELD not in str(filters):
             # set default value for `doc_type` metadata field
             type_metadata = self._get_default_type_metadata(index, namespace)  # type: ignore
@@ -819,7 +890,7 @@ class PineconeDocumentStore(BaseDocumentStore):
         index = self._index(index)
         self._index_connection_exists(index)

-        if not type_metadata:
+        if not type_metadata and TYPE_METADATA_FIELD not in str(filters):
             # set default value for `doc_type` metadata field
             type_metadata = self._get_default_type_metadata(index, namespace)  # type: ignore
@@ -833,10 +904,9 @@ class PineconeDocumentStore(BaseDocumentStore):
                 "Make sure the desired metadata you want to filter with is indexed."
             )

-        for i in range(0, len(ids), batch_size):
-            i_end = min(len(ids), i + batch_size)
+        for id_batch in get_batches_from_generator(ids, batch_size):
             documents = self.get_documents_by_id(
-                ids=ids[i:i_end],
+                ids=list(id_batch),
                 index=index,
                 batch_size=batch_size,
                 return_embedding=return_embedding,
@@ -857,7 +927,9 @@ class PineconeDocumentStore(BaseDocumentStore):
         index = self._index(index)
         self._index_connection_exists(index)

-        document_count = self.get_document_count(index=index, namespace=namespace, type_metadata=type_metadata)
+        document_count = self.get_document_count(
+            index=index, namespace=namespace, type_metadata=type_metadata, filters=filters
+        )

         if index not in self.all_ids:
             self.all_ids[index] = set()
@@ -865,16 +937,26 @@ class PineconeDocumentStore(BaseDocumentStore):
             # We have all of the IDs and don't need to extract from Pinecone
             return list(self.all_ids[index])
         else:
-            # Otherwise we must query and extract IDs from the original namespace, then move the retrieved embeddings
-            # to a temporary namespace and query again for new items. We repeat this process until all embeddings
-            # have been retrieved.
-            target_namespace = f"{namespace}-copy" if namespace is not None else "copy"
-            all_ids: Set[str] = set()
-            vector_id_matrix = ["dummy-id"]
-            with tqdm(
-                total=document_count, disable=not self.progress_bar, position=0, unit=" ids", desc="Retrieving IDs"
-            ) as progress_bar:
-                while vector_id_matrix:
-                    # Retrieve IDs from Pinecone
-                    vector_id_matrix = self._get_ids(
-                        index=index,
+            if pinecone.describe_index(index).pod_type == PINECONE_STARTER_POD:
+                # Due to missing support for Namespace in the Starter Pinecone index type, retrieve up to 10000 vectors
+                logger.warning(
+                    "Current index type 'Starter' doesn't support the 'Namespace' feature. "
+                    "Limit for fetching documents in the 'Starter' index type is %s.",
+                    self.top_k_limit,
+                )
+                all_ids = self._get_ids(
+                    index=index, filters=filters, type_metadata=type_metadata, batch_size=self.top_k_limit
+                )
+            else:
+                # If we don't have all IDs, we must query and extract IDs from the original namespace, then move the
+                # retrieved documents to a temporary namespace and query again for new items. We repeat this process
+                # until all documents have been retrieved.
+                target_namespace = f"{namespace}-copy" if namespace is not None else "copy"
+                all_ids: Set[str] = set()  # type: ignore
+                with tqdm(
+                    total=document_count, disable=not self.progress_bar, position=0, unit=" ids", desc="Retrieving IDs"
+                ) as progress_bar:
+                    while True:
+                        # Retrieve IDs from Pinecone
+                        vector_id_matrix = self._get_ids(
+                            index=index,
@@ -883,8 +965,11 @@ class PineconeDocumentStore(BaseDocumentStore):
                             type_metadata=type_metadata,
                             batch_size=batch_size,
                         )
+                        if not vector_id_matrix:
+                            break
                         # Save IDs
-                    all_ids = all_ids.union(set(vector_id_matrix))
+                        unique_ids = set(vector_id_matrix)
+                        all_ids = all_ids.union(unique_ids)  # type: ignore
                         # Move these IDs to new namespace
                         self._move_documents_by_id_namespace(
                             ids=vector_id_matrix,
@@ -894,10 +979,12 @@ class PineconeDocumentStore(BaseDocumentStore):
                             batch_size=batch_size,
                         )
                         progress_bar.set_description_str("Retrieved IDs")
-                    progress_bar.update(len(set(vector_id_matrix)))
+                        progress_bar.update(len(unique_ids))

                 # Now move all documents back to source namespace
-            self._namespace_cleanup(index=index, namespace=target_namespace, batch_size=batch_size)
+                self._namespace_cleanup(
+                    index=index, ids=list(all_ids), namespace=target_namespace, batch_size=batch_size
+                )

             self._add_local_ids(index, list(all_ids))
         return list(all_ids)
@@ -924,11 +1011,8 @@ class PineconeDocumentStore(BaseDocumentStore):
             with tqdm(
                 total=len(ids), disable=not self.progress_bar, position=0, unit=" docs", desc="Moving Documents"
             ) as progress_bar:
-                for i in range(0, len(ids), batch_size):
-                    i_end = min(len(ids), i + batch_size)
-                    # TODO if i == i_end:
-                    #     break
-                    id_batch = ids[i:i_end]
+                for id_batch in get_batches_from_generator(ids, batch_size):
+                    id_batch = list(id_batch)
                     # Retrieve documents from source_namespace
                     result = self.pinecone_indexes[index].fetch(ids=id_batch, namespace=source_namespace)
                     vector_id_matrix = result["vectors"].keys()
@@ -938,27 +1022,24 @@ class PineconeDocumentStore(BaseDocumentStore):
                     # Store metadata and embeddings in new target_namespace
                     self.pinecone_indexes[index].upsert(vectors=data_to_write_to_pinecone, namespace=target_namespace)
                     # Delete vectors from source_namespace
-                    self.delete_documents(index=index, ids=ids[i:i_end], namespace=source_namespace, drop_ids=False)
+                    self.delete_documents(index=index, ids=id_batch, namespace=source_namespace, drop_ids=False)
                     progress_bar.set_description_str("Documents Moved")
                     progress_bar.update(len(id_batch))
-    def _namespace_cleanup(self, index: str, namespace: str, batch_size: int = DEFAULT_BATCH_SIZE):
+    def _namespace_cleanup(self, index: str, ids: List[str], namespace: str, batch_size: int = DEFAULT_BATCH_SIZE):
         """
-        Shifts vectors back from "-copy" namespace to the original namespace.
+        Shifts vectors back from "*-copy" namespace to the original namespace.
         """
         with tqdm(
             total=1, disable=not self.progress_bar, position=0, unit=" namespaces", desc="Cleaning Namespace"
         ) as progress_bar:
             target_namespace = namespace[:-5] if namespace != "copy" else None
-            while True:
-                # Retrieve IDs from Pinecone
-                vector_id_matrix = self._get_ids(index=index, namespace=namespace, batch_size=batch_size)
-                # Once we reach final item, we break
-                if len(vector_id_matrix) == 0:
+            for id_batch in get_batches_from_generator(ids, batch_size):
+                id_batch = list(id_batch)
+                if not id_batch:
                     break
-                # Move these IDs to new namespace
                 self._move_documents_by_id_namespace(
-                    ids=vector_id_matrix,
+                    ids=id_batch,
                     index=index,
                     source_namespace=namespace,
                     target_namespace=target_namespace,
@@ -1000,10 +1081,8 @@ class PineconeDocumentStore(BaseDocumentStore):
         self._index_connection_exists(index)
         documents = []

-        for i in range(0, len(ids), batch_size):
-            i_end = min(len(ids), i + batch_size)
-            id_batch = ids[i:i_end]
-            result = self.pinecone_indexes[index].fetch(ids=id_batch, namespace=namespace)
+        for id_batch in get_batches_from_generator(ids, batch_size):
+            result = self.pinecone_indexes[index].fetch(ids=list(id_batch), namespace=namespace)

             vector_id_matrix = []
             meta_matrix = []
@@ -1135,20 +1214,15 @@ class PineconeDocumentStore(BaseDocumentStore):
             self.pinecone_indexes[index].delete(delete_all=True, namespace=namespace)
             id_values = list(self.all_ids[index])
         else:
-            if ids is None:
-                # In this case we identify all IDs that satisfy the filter condition
-                id_values = self._get_all_document_ids(index=index, namespace=namespace, filters=pinecone_syntax_filter)
-            else:
-                id_values = ids
+            id_values = ids or []
             if pinecone_syntax_filter:
-                # We must first identify the IDs that satisfy the filter condition
-                docs = self.get_all_documents(index=index, namespace=namespace, filters=pinecone_syntax_filter)
-                filter_ids = [doc.id for doc in docs]
-                # Find the intersect
-                id_values = list(set(id_values).intersection(set(filter_ids)))
+                # Extract IDs for all documents that satisfy given filters
+                doc_ids = self._get_all_document_ids(index=index, namespace=namespace, filters=filters)
+                # Extend the list of document IDs that should be deleted
+                id_values = list(set(id_values).union(set(doc_ids)))
             if id_values:
-                # Now we delete
-                self.pinecone_indexes[index].delete(ids=id_values, namespace=namespace)
+                self._delete_vectors(index, id_values, namespace)

         if drop_ids:
             self.all_ids[index] = self.all_ids[index].difference(set(id_values))
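Note the behavioral shift in this hunk: when both `ids` and `filters` are supplied, the union of the two ID sets is now deleted, where the old code deleted their intersection. A hedged illustration (the field and value names are hypothetical):

```python
# Deletes document "id-1" AND every document matching the filter,
# whereas the previous implementation deleted only documents in both sets.
doc_store.delete_documents(
    ids=["id-1"],
    filters={"meta_field": {"$eq": "obsolete"}},
)
```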
@@ -1636,14 +1710,20 @@ class PineconeDocumentStore(BaseDocumentStore):
             if k.startswith("label-meta-"):
                 label_meta_metadata[k[11:]] = v
         # Rebuild Label object
+        created_at = label_meta.get("label-created-at")
+        updated_at = label_meta.get("label-updated-at")
+        if created_at and isinstance(created_at, datetime):
+            created_at = created_at.strftime(LABEL_DATETIME_FORMAT)
+        if updated_at and isinstance(updated_at, datetime):
+            updated_at = updated_at.strftime(LABEL_DATETIME_FORMAT)
         label = Label(
             id=label_meta["label-id"],
             query=label_meta["query"],
             document=doc,
             answer=answer,
             pipeline_id=label_meta["label-pipeline-id"],
-            created_at=label_meta["label-created-at"],
-            updated_at=label_meta["label-updated-at"],
+            created_at=created_at,
+            updated_at=updated_at,
             is_correct_answer=label_meta["label-is-correct-answer"],
             is_correct_document=label_meta["label-is-correct-document"],
             origin=label_meta["label-origin"],
@@ -1724,11 +1804,9 @@ class PineconeDocumentStore(BaseDocumentStore):
         index = self._index(index)
         self._index_connection_exists(index)

-        # add filter for `doc_type` metadata field
-        filters = filters or {}
-        filters = self._add_type_metadata_filter(filters, LABEL)  # type: ignore
-
-        documents = self.get_all_documents(index=index, filters=filters, headers=headers, namespace=namespace)
+        documents = self.get_all_documents(
+            index=index, filters=filters, headers=headers, namespace=namespace, type_metadata=LABEL  # type: ignore
+        )
         for doc in documents:
             doc.meta = self._pinecone_meta_format(doc.meta, labels=True)
         labels = self._meta_to_labels(documents)

View File

@@ -1,33 +1,30 @@
 from __future__ import annotations

+import ast
 import csv
 import hashlib
 import inspect
+import json
-from typing import Any, Optional, Dict, List, Union, Literal
-from pathlib import Path
-from uuid import uuid4
 import logging
 import time
-import json
-import ast
 from dataclasses import asdict
+from pathlib import Path
+from typing import Any, Dict, List, Literal, Optional, Union
+from uuid import uuid4

 import numpy as np
-from numpy import ndarray
 import pandas as pd
+from numpy import ndarray
 from pandas import DataFrame
 from pydantic import BaseConfig, Field
-from pydantic.json import pydantic_encoder

 # We are using Pydantic dataclasses instead of vanilla Python's
 # See #1598 for the reasons behind this choice & performance considerations
 from pydantic.dataclasses import dataclass
+from pydantic.json import pydantic_encoder

 from haystack.mmh3 import hash128

 logger = logging.getLogger(__name__)
@@ -38,6 +35,8 @@ BaseConfig.arbitrary_types_allowed = True
 ContentTypes = Literal["text", "table", "image", "audio"]
 FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]

+LABEL_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S"
+

 @dataclass
 class Document:
@@ -526,7 +525,7 @@ class Label:
         :param pipeline_id: pipeline identifier (any str) that was involved for generating this label (in-case of user feedback).
         :param created_at: Timestamp of creation with format yyyy-MM-dd HH:mm:ss.
             Generate in Python via time.strftime("%Y-%m-%d %H:%M:%S").
-        :param created_at: Timestamp of update with format yyyy-MM-dd HH:mm:ss.
+        :param updated_at: Timestamp of update with format yyyy-MM-dd HH:mm:ss.
             Generate in Python via time.strftime("%Y-%m-%d %H:%M:%S")
         :param meta: Meta fields like "annotator_name" in the form of a custom dict (any keys and values allowed).
         :param filters: filters that should be applied to the query to rule out non-relevant documents. For example, if there are different correct answers
@@ -540,7 +539,7 @@ class Label:
             self.id = str(uuid4())

         if created_at is None:
-            created_at = time.strftime("%Y-%m-%d %H:%M:%S")
+            created_at = time.strftime(LABEL_DATETIME_FORMAT)
         self.created_at = created_at

         self.updated_at = updated_at
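A small sketch of why the shared constant matters: it keeps label timestamps round-trippable between `str` and `datetime`, mirroring the conversion added in the Pinecone store's label rebuilding above (the constant and format come from this diff; the rest is illustrative):

```python
from datetime import datetime

from haystack.schema import LABEL_DATETIME_FORMAT

# Format a datetime the way Label timestamps are stored ...
created_at = datetime(2023, 10, 16, 19, 26, 24).strftime(LABEL_DATETIME_FORMAT)
assert created_at == "2023-10-16 19:26:24"
# ... and parse it back with the same constant.
parsed = datetime.strptime(created_at, LABEL_DATETIME_FORMAT)
```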

View File

@@ -0,0 +1,11 @@
+---
+enhancements:
+  - |
+    Optimize particular methods from PineconeDocumentStore (delete_documents and _get_vector_count).
+upgrade:
+  - |
+    This update enables all Pinecone index types to be used, including Starter.
+    Previously, the Pinecone Starter index type couldn't be used as a document store. Due to limitations of this
+    index type (https://docs.pinecone.io/docs/starter-environment), fetching documents in the current implementation
+    is limited to the Pinecone query vector limit (10000 vectors). Accordingly, if the number of documents in the
+    index is above this limit, some PineconeDocumentStore functions will be limited.
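A sketch of the guard this limitation implies, using `pinecone.describe_index`, which the store itself calls in the diff above (the index name is a placeholder; treat the exact `pod_type` string as an assumption matching the `PINECONE_STARTER_POD = "starter"` constant introduced here):

```python
import pinecone

# Illustrative: warn before relying on full retrieval on a Starter index,
# since document fetching is capped at the query vector limit (10000 vectors).
if pinecone.describe_index("my-index").pod_type == "starter":
    print("Starter index: some PineconeDocumentStore functions are limited to 10000 vectors.")
```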

View File

@@ -6,7 +6,13 @@ from unittest.mock import MagicMock
 import numpy as np
 import pytest

-from haystack.document_stores.pinecone import DOCUMENT_WITH_EMBEDDING, PineconeDocumentStore, pinecone
+from haystack.document_stores.pinecone import (
+    DOCUMENT_WITH_EMBEDDING,
+    DOCUMENT_WITHOUT_EMBEDDING,
+    TYPE_METADATA_FIELD,
+    PineconeDocumentStore,
+    pinecone,
+)
 from haystack.errors import FilterError, PineconeDocumentStoreError
 from haystack.schema import Document
 from haystack.testing import DocumentStoreBaseTestAbstract
@@ -15,7 +21,7 @@ from ..conftest import MockBaseRetriever
 from ..mocks import pinecone as pinecone_mock

 # Set metadata fields used during testing for PineconeDocumentStore meta_config
-META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document"]
+META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document", "doc_type"]
 class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
@@ -57,6 +63,7 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
             pass

         pinecone.init = MagicMock()
+        pinecone.describe_index = MagicMock()
         DSMock._create_index = MagicMock()
         mocked_ds = DSMock(api_key="MOCK")
@@ -466,10 +473,60 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
         We expect 1 doc with an embedding because all documents already written in doc_store_with_docs contain no
         embeddings.
         """
-        doc = Document(content="Doc with embedding", embedding=np.random.rand(768).astype(np.float32))
+        doc = Document(
+            content="Doc with embedding",
+            embedding=np.random.rand(768).astype(np.float32),
+            meta={"meta_field": "test-1"},
+        )
         doc_store_with_docs.write_documents([doc])
         assert doc_store_with_docs.get_embedding_count() == 1

+    @pytest.mark.integration
+    def test_get_embedding_count_with_filters(self, doc_store_with_docs: PineconeDocumentStore):
+        """
+        We expect 1 doc with an embedding and given filters, because only two documents with embeddings are
+        written in doc_store_with_docs, and only one of them satisfies the given filters.
+        """
+        doc_1 = Document(
+            content="Doc with embedding 1",
+            embedding=np.random.rand(768).astype(np.float32),
+            meta={"meta_field": "test-1"},
+        )
+        doc_2 = Document(
+            content="Doc with embedding 2",
+            embedding=np.random.rand(768).astype(np.float32),
+            meta={"meta_field": "test-2"},
+        )
+        doc_store_with_docs.write_documents([doc_1, doc_2])
+
+        assert doc_store_with_docs.get_embedding_count(filters={"meta_field": "test-1"}) == 1
+
+    @pytest.mark.integration
+    def test_get_embedding_count_with_doc_type_filters(self, doc_store_with_docs: PineconeDocumentStore):
+        """
+        We expect 2 docs with an embedding and given filters, because only two documents with embeddings are
+        written in doc_store_with_docs and both of them satisfy the given filters (the `meta_field` filter).
+        Even though the filters include `doc_type` with the value related to documents without embeddings
+        (`no-vector`), we expect this particular filter to be ignored (it is irrelevant, since documents with
+        embeddings have `doc_type` set to `vector`).
+        """
+        doc_1 = Document(
+            content="Doc with embedding 1",
+            embedding=np.random.rand(768).astype(np.float32),
+            meta={"meta_field": "test-2"},
+        )
+        doc_2 = Document(
+            content="Doc with embedding 2",
+            embedding=np.random.rand(768).astype(np.float32),
+            meta={"meta_field": "test-2"},
+        )
+        doc_store_with_docs.write_documents([doc_1, doc_2])
+
+        assert (
+            doc_store_with_docs.get_embedding_count(
+                filters={TYPE_METADATA_FIELD: DOCUMENT_WITHOUT_EMBEDDING, "meta_field": "test-2"}
+            )
+            == 2
+        )
     @pytest.mark.integration
     def test_get_document_count_after_write_doc_with_embedding(self, doc_store_with_docs: PineconeDocumentStore):
         """

View File

@@ -1,10 +1,8 @@
-from typing import Optional, List, Union
-
 import logging
+from typing import Any, Dict, List, Optional, Union

 from haystack.schema import FilterType

 logger = logging.getLogger(__name__)
@@ -12,6 +10,33 @@ logger = logging.getLogger(__name__)
 CONFIG: dict = {"api_key": None, "environment": None, "indexes": {}}

+
+# Mock Pinecone Index Description instance
+class IndexDescription:
+    def __init__(
+        self,
+        name: str,
+        metric: Optional[str] = None,
+        replicas: Optional[int] = None,
+        dimension: Optional[int] = None,
+        shards: Optional[int] = None,
+        pods: Optional[int] = None,
+        pod_type: Optional[str] = None,
+        status: Optional[Dict[str, Any]] = None,
+        metadata_config: Optional[dict] = None,
+        source_collection: Optional[str] = None,
+    ) -> None:
+        self.name = name
+        self.metric = metric
+        self.replicas = replicas
+        self.dimension = dimension
+        self.shards = shards
+        self.pods = pods
+        self.pod_type = pod_type
+        self.status = status
+        self.metadata_config = metadata_config
+        self.source_collection = source_collection
+
+
 # Mock Pinecone Index instance
 class IndexObject:
     def __init__(
@@ -331,3 +356,18 @@ def create_index(

 def delete_index(index: str):
     del CONFIG["indexes"][index]
+
+
+def describe_index(index: str):
+    return IndexDescription(
+        name=index,
+        metric="dotproduct",
+        replicas=1,
+        dimension=768.0,
+        shards=1,
+        pods=1,
+        pod_type="p1.x1",
+        status={"ready": True, "state": "Ready"},
+        metadata_config=None,
+        source_collection="",
+    )