feat: Add batch_size parameter and cast timeout_config value to tuple for WeaviateDocumentStore (#5079)

* Add batch_size parameter and cast timeout_config to tuple

* Add unit test

* Remove debug tqdm

* Remove debug tqdm introduced in #5063
This commit is contained in:
bogdankostic 2023-06-06 17:06:10 +02:00 committed by GitHub
parent 1777b22fcb
commit da1f245a84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 9 deletions

View File

@ -1310,7 +1310,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
body = []
all_documents = []
for query_emb, cur_filters in tqdm(zip(query_embs, filters)):
for query_emb, cur_filters in zip(query_embs, filters):
cur_query_body = self._construct_dense_query_body(
query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
)

View File

@ -97,6 +97,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
duplicate_documents: str = "overwrite",
recreate_index: bool = False,
replication_factor: int = 1,
batch_size: int = 10_000,
):
"""
:param host: Weaviate server connection URL for storing and processing documents and vectors.
@ -138,6 +139,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
lost if you choose to recreate the index.
:param replication_factor: Sets the Weaviate Class's replication factor in Weaviate at the time of Class creation.
See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/configuration/replication.html).
:param batch_size: The number of documents to index at once.
"""
super().__init__()
@ -146,6 +148,9 @@ class WeaviateDocumentStore(KeywordDocumentStore):
secret = self._get_auth_secret(
username, password, client_secret, access_token, expires_in, refresh_token, scope
)
# Timeout config can only be defined as a list in YAML, but Weaviate expects a tuple
if isinstance(timeout_config, list):
timeout_config = tuple(timeout_config)
self.weaviate_client = client.Client(
url=weaviate_url,
auth_client_secret=secret,
@ -186,6 +191,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
self.progress_bar = progress_bar
self.duplicate_documents = duplicate_documents
self.replication_factor = replication_factor
self.batch_size = batch_size
self._create_schema_and_index(self.index, recreate_index=recreate_index)
self.uuid_format_warning_raised = False
@ -400,7 +406,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
self,
ids: List[str],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
@ -410,6 +416,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
raise NotImplementedError("WeaviateDocumentStore does not support headers.")
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
# We retrieve the JSON properties from the schema and convert them back to the Python dicts
json_properties = self._get_json_properties(index=index)
documents = []
@ -557,7 +564,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
@ -567,6 +574,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
:param documents: List of `Dicts` or List of `Documents`. A dummy embedding vector for each document is automatically generated if it is not provided. The document id needs to be in uuid format. Otherwise a correctly formatted uuid will be automatically generated based on the provided id.
:param index: index name for storing the docs and metadata
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
If no batch_size is provided, self.batch_size is used.
:param duplicate_documents: Handle duplicates document based on parameter options.
Parameter options : ( 'skip','overwrite','fail')
skip: Ignore the duplicates documents
@ -580,6 +588,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
raise NotImplementedError("WeaviateDocumentStore does not support headers.")
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
self._create_schema_and_index(index, recreate_index=False)
field_map = self._create_document_field_map()
@ -764,7 +773,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
@ -817,11 +826,13 @@ class WeaviateDocumentStore(KeywordDocumentStore):
```
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
If no batch_size is provided, self.batch_size is used.
"""
if headers:
raise NotImplementedError("WeaviateDocumentStore does not support headers.")
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
@ -832,13 +843,14 @@ class WeaviateDocumentStore(KeywordDocumentStore):
self,
index: Optional[str],
filters: Optional[FilterType] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
only_documents_without_embedding: bool = False,
) -> Generator[dict, None, None]:
"""
Return all documents in a specific index in the document store
"""
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
# Build the properties to retrieve from Weaviate
properties = self._get_current_properties(index)
@ -907,7 +919,7 @@ class WeaviateDocumentStore(KeywordDocumentStore):
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> Generator[Document, None, None]:
"""
@ -962,11 +974,13 @@ class WeaviateDocumentStore(KeywordDocumentStore):
```
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
If no batch_size is provided, self.batch_size is used.
"""
if headers:
raise NotImplementedError("WeaviateDocumentStore does not support headers.")
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
if return_embedding is None:
return_embedding = self.return_embedding
@ -1418,11 +1432,11 @@ class WeaviateDocumentStore(KeywordDocumentStore):
index: Optional[str] = None,
filters: Optional[FilterType] = None,
update_existing_embeddings: bool = True,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
):
"""
Updates the embeddings in the the document store using the encoding model specified in the retriever.
This can be useful if want to change the embeddings for your documents (e.g. after changing the retriever config).
Updates the embeddings in the document store using the encoding model specified in the retriever.
This can be useful if you want to change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever to use to update the embeddings.
:param index: Index name to update
@ -1456,9 +1470,11 @@ class WeaviateDocumentStore(KeywordDocumentStore):
}
```
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
If no batch_size is specified, self.batch_size is used.
:return: None
"""
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
if not self.embedding_field:
raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()")

View File

@ -449,3 +449,9 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
)
retrieved_docs = mocked_ds.get_all_documents()
assert retrieved_docs[0].meta["list_dict_field"] == [{"key": "value"}, {"key": "value"}]
@pytest.mark.unit
def test_write_documents_req_for_each_batch(self, mocked_ds, documents):
mocked_ds.batch_size = 2
mocked_ds.write_documents(documents)
assert mocked_ds.weaviate_client.batch.create_objects.call_count == 5