diff --git a/haystack/document_store/elasticsearch.py b/haystack/document_store/elasticsearch.py
index 8725a6b99..9f41aef0c 100644
--- a/haystack/document_store/elasticsearch.py
+++ b/haystack/document_store/elasticsearch.py
@@ -9,6 +9,7 @@ from elasticsearch.helpers import bulk, scan
 from elasticsearch.exceptions import RequestError
 import numpy as np
 from scipy.special import expit
+from tqdm.auto import tqdm

 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
@@ -475,13 +476,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             body = {"doc": meta}
         self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type)

-    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
+    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
+                           only_documents_without_embedding: bool = False) -> int:
         """
         Return the number of documents in the document store.
         """
         index = index or self.index

         body: dict = {"query": {"bool": {}}}
+        if only_documents_without_embedding:
+            body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}]
+
         if filters:
             filter_clause = []
             for key, values in filters.items():
@@ -620,7 +625,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             body["query"]["bool"]["filter"] = filter_clause

         if only_documents_without_embedding:
-            body["query"]["bool"] = {"must_not": {"exists": {"field": self.embedding_field}}}
+            body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}]

         result = scan(self.client, query=body, index=index, size=batch_size, scroll="1d")
         yield from result
@@ -894,9 +899,12 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")

         if update_existing_embeddings:
-            logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...")
+            document_count = self.get_document_count(index=index)
+            logger.info(f"Updating embeddings for all {document_count} docs ...")
         else:
-            logger.info(f"Updating embeddings for new docs without embeddings ...")
+            document_count = self.get_document_count(index=index, filters=filters,
+                                                     only_documents_without_embedding=True)
+            logger.info(f"Updating embeddings for {document_count} docs without embeddings ...")

         result = self._get_all_documents_in_index(
             index=index,
@@ -905,25 +913,29 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )

-        for result_batch in get_batches_from_generator(result, batch_size):
-            document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
-            embeddings = retriever.embed_passages(document_batch)  # type: ignore
-            assert len(document_batch) == len(embeddings)
+        logging.getLogger("elasticsearch").setLevel(logging.CRITICAL)

-            if embeddings[0].shape[0] != self.embedding_dim:
-                raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
-                                   f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
- "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()") - doc_updates = [] - for doc, emb in zip(document_batch, embeddings): - update = {"_op_type": "update", - "_index": index, - "_id": doc.id, - "doc": {self.embedding_field: emb.tolist()}, - } - doc_updates.append(update) + with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar: + for result_batch in get_batches_from_generator(result, batch_size): + document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch] + embeddings = retriever.embed_passages(document_batch) # type: ignore + assert len(document_batch) == len(embeddings) - bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type) + if embeddings[0].shape[0] != self.embedding_dim: + raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})" + f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})." + "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()") + doc_updates = [] + for doc, emb in zip(document_batch, embeddings): + update = {"_op_type": "update", + "_index": index, + "_id": doc.id, + "doc": {self.embedding_field: emb.tolist()}, + } + doc_updates.append(update) + + bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type) + progress_bar.update(batch_size) def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None): """ @@ -967,7 +979,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore): if self.refresh_type == "wait_for": time.sleep(2) - class OpenDistroElasticsearchDocumentStore(ElasticsearchDocumentStore): """ Document Store using the Open Distro for Elasticsearch. It is compatible with the AWS Elasticsearch Service. 
diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py
index 98256463b..a3a63c754 100644
--- a/haystack/document_store/faiss.py
+++ b/haystack/document_store/faiss.py
@@ -235,7 +235,8 @@ class FAISSDocumentStore(SQLDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating embeddings") as progress_bar:
             for document_batch in batched_documents:
                 embeddings = retriever.embed_passages(document_batch)  # type: ignore
                 assert len(document_batch) == len(embeddings)
@@ -248,8 +249,8 @@ class FAISSDocumentStore(SQLDocumentStore):
                     vector_id_map[doc.id] = vector_id
                     vector_id += 1
                 self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.set_description_str("Documents Processed")
                 progress_bar.update(batch_size)
-            progress_bar.close()

     def get_all_documents(
         self,
diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py
index 73693a54e..56962c014 100644
--- a/haystack/document_store/memory.py
+++ b/haystack/document_store/memory.py
@@ -210,7 +210,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
         document_count = len(result)
         logger.info(f"Updating embeddings for {document_count} docs ...")
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating embeddings") as progress_bar:
             for document_batch in batched_documents:
                 embeddings = retriever.embed_passages(document_batch)  # type: ignore
                 assert len(document_batch) == len(embeddings)
@@ -222,6 +223,8 @@ class InMemoryDocumentStore(BaseDocumentStore):

                 for doc, emb in zip(document_batch, embeddings):
                     self.indexes[index][doc.id].embedding = emb
+                progress_bar.set_description_str("Documents Processed")
+                progress_bar.update(batch_size)

     def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
         """
diff --git a/haystack/document_store/milvus.py b/haystack/document_store/milvus.py
index 606e50c99..a40744b02 100644
--- a/haystack/document_store/milvus.py
+++ b/haystack/document_store/milvus.py
@@ -267,7 +267,8 @@ class MilvusDocumentStore(SQLDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating embeddings") as progress_bar:
             for document_batch in batched_documents:
                 self._delete_vector_ids_from_milvus(documents=document_batch, index=index)

@@ -284,8 +285,8 @@ class MilvusDocumentStore(SQLDocumentStore):
                     vector_id_map[doc.id] = vector_id
                 self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.set_description_str("Documents Processed")
                 progress_bar.update(batch_size)
-            progress_bar.close()

         self.milvus_server.flush([index])
         self.milvus_server.compact(collection_name=index)
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index ddf893d6a..6dd2e812f 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -3,7 +3,7 @@ from typing import List, Union, Optional
 import torch
 import numpy as np
 from pathlib import Path
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document
@@ -240,16 +240,19 @@ class DensePassageRetriever(BaseRetriever):
         else:
             disable_tqdm = not self.progress_bar

-        for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
-            batch = {key: batch[key].to(self.device) for key in batch}
+        with tqdm(total=len(data_loader) * self.batch_size, unit=" Docs", desc="Create embeddings", position=1,
+                  leave=False, disable=disable_tqdm) as progress_bar:
+            for batch in data_loader:
+                batch = {key: batch[key].to(self.device) for key in batch}

-            # get logits
-            with torch.no_grad():
-                query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
-                if query_embeddings is not None:
-                    all_embeddings["query"].append(query_embeddings.cpu().numpy())
-                if passage_embeddings is not None:
-                    all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+                # get logits
+                with torch.no_grad():
+                    query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
+                    if query_embeddings is not None:
+                        all_embeddings["query"].append(query_embeddings.cpu().numpy())
+                    if passage_embeddings is not None:
+                        all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+                progress_bar.update(self.batch_size)

         if all_embeddings["passages"]:
             all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
diff --git a/test/test_document_store.py b/test/test_document_store.py
index 856e671a0..b1f93e08e 100644
--- a/test/test_document_store.py
+++ b/test/test_document_store.py
@@ -603,3 +603,29 @@ def test_elasticsearch_custom_fields(elasticsearch_fixture):
     assert len(documents) == 1
     assert documents[0].text == "test"
     np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)
+
+
+@pytest.mark.elasticsearch
+def test_get_document_count_only_documents_without_embedding_arg():
+    documents = [
+        {"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32), "meta_field_for_count": "a"},
+        {"text": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "b"},
+        {"text": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
+        {"text": "text4", "id": "4", "meta_field_for_count": "b"},
+        {"text": "text5", "id": "5", "meta_field_for_count": "b"},
+        {"text": "text6", "id": "6", "meta_field_for_count": "c"},
+        {"text": "text7", "id": "7", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "c"},
+    ]
+
+    _index: str = "haystack_test_count"
+    document_store = ElasticsearchDocumentStore(index=_index)
+    document_store.delete_documents(index=_index)
+
+    document_store.write_documents(documents)
+
+    assert document_store.get_document_count() == 7
+    assert document_store.get_document_count(only_documents_without_embedding=True) == 3
+    assert document_store.get_document_count(only_documents_without_embedding=True,
+                                             filters={"meta_field_for_count": ["c"]}) == 1
+    assert document_store.get_document_count(only_documents_without_embedding=True,
+                                             filters={"meta_field_for_count": ["b"]}) == 2
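
Note for reviewers: the retriever-side bar in `dense.py` is deliberately drawn at `position=1` with `leave=False` so it nests under the document store's outer bar at `position=0` and is cleared after each batch, leaving only the per-document totals visible. A minimal, self-contained sketch of that two-bar layout (document counts, batch size, and the `sleep` stand-in for the forward pass are illustrative, not taken from this PR):

```python
from time import sleep
from tqdm.auto import tqdm

total_docs, batch_size = 64, 16

# Outer bar: driven by the document store's update_embeddings() loop (position=0).
with tqdm(total=total_docs, position=0, unit=" Docs", desc="Updating embeddings") as outer:
    for _ in range(total_docs // batch_size):
        # Inner bar: driven by the retriever's embed_passages() (position=1);
        # leave=False removes it once the batch has been embedded.
        with tqdm(total=batch_size, position=1, leave=False, unit=" Docs",
                  desc="Create embeddings") as inner:
            for _ in range(batch_size):
                sleep(0.01)  # stand-in for the model forward pass
                inner.update(1)
        outer.update(batch_size)
```

One consequence of sizing the inner bar with `len(data_loader) * self.batch_size` is that a final partial batch makes the inner total slightly overshoot the real document count; the outer bar is unaffected because it is sized from `get_document_count()`.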