Improve the progress bar in update_embeddings() + Fix filters in update_embeddings() (#1063)

* [document_stores] Add a progress bar in update_embeddings() to track overall document progress. Closes #1037

* Change the second-level progress bar to count docs instead of batches; switch to tqdm.auto.

* [document_stores] Add new Elasticsearch method get_document_without_embedding_count().

* [test_case] Add Elasticsearch DocumentStore test case for get_document_without_embedding_count().

* [document_stores] Add a new bool arg to the get_document_count() method; fixes #1082 (see the usage sketch below).

* [document_stores] Fix typo (#1082).

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Ikram Ali 2021-05-21 17:18:07 +05:00 committed by GitHub
parent f46b09c756
commit 4ab1bc3c3e
6 changed files with 82 additions and 37 deletions
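
From the user's side, the combined change looks roughly like this; a minimal sketch, assuming a running local Elasticsearch and default retriever settings (the index name and setup below are illustrative, not taken from the diff):

from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = ElasticsearchDocumentStore(index="document")  # assumed index name
retriever = DensePassageRetriever(document_store=document_store)

# New keyword argument from this PR: count only docs that still lack an embedding.
missing = document_store.get_document_count(only_documents_without_embedding=True)
print(f"{missing} docs without embeddings")

# update_embeddings() now renders a top-level tqdm bar tracking overall progress.
document_store.update_embeddings(retriever, update_existing_embeddings=False)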

haystack/document_store/elasticsearch.py

@@ -9,6 +9,7 @@ from elasticsearch.helpers import bulk, scan
 from elasticsearch.exceptions import RequestError
 import numpy as np
 from scipy.special import expit
+from tqdm.auto import tqdm

 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
@@ -475,13 +476,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         body = {"doc": meta}
         self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type)

-    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
+    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
+                           only_documents_without_embedding: bool = False) -> int:
         """
         Return the number of documents in the document store.
         """
         index = index or self.index

         body: dict = {"query": {"bool": {}}}
+        if only_documents_without_embedding:
+            body['query']['bool']['must_not'] = [{"exists": {"field": self.embedding_field}}]
         if filters:
             filter_clause = []
             for key, values in filters.items():
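
With both the new flag and a metadata filter set, the count query is sent with a body along these lines (a sketch: the terms clause mirrors the store's existing filter handling, and the field names are illustrative):

# get_document_count(only_documents_without_embedding=True,
#                    filters={"meta_field_for_count": ["b"]})
# builds a count body roughly like:
body = {
    "query": {
        "bool": {
            "must_not": [{"exists": {"field": "embedding"}}],  # the store's embedding_field
            "filter": [{"terms": {"meta_field_for_count": ["b"]}}],
        }
    }
}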
@@ -620,7 +625,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             body["query"]["bool"]["filter"] = filter_clause

         if only_documents_without_embedding:
-            body["query"]["bool"] = {"must_not": {"exists": {"field": self.embedding_field}}}
+            body['query']['bool']['must_not'] = [{"exists": {"field": self.embedding_field}}]

         result = scan(self.client, query=body, index=index, size=batch_size, scroll="1d")
         yield from result
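
This one-line change is the actual filter fix from #1082: the old assignment replaced the whole bool clause and silently discarded the filter built a few lines above, while the new form only sets the must_not key. A toy illustration of the difference:

# Old behaviour (bug): assigning a fresh dict wipes the existing filter clause.
body = {"query": {"bool": {"filter": [{"terms": {"year": ["2021"]}}]}}}
body["query"]["bool"] = {"must_not": {"exists": {"field": "embedding"}}}
assert "filter" not in body["query"]["bool"]

# New behaviour: must_not is added next to the filter, so both conditions apply.
body = {"query": {"bool": {"filter": [{"terms": {"year": ["2021"]}}]}}}
body["query"]["bool"]["must_not"] = [{"exists": {"field": "embedding"}}]
assert "filter" in body["query"]["bool"] and "must_not" in body["query"]["bool"]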
@@ -894,9 +899,12 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")

         if update_existing_embeddings:
-            logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...")
+            document_count = self.get_document_count(index=index)
+            logger.info(f"Updating embeddings for all {document_count} docs ...")
         else:
-            logger.info(f"Updating embeddings for new docs without embeddings ...")
+            document_count = self.get_document_count(index=index, filters=filters,
+                                                     only_documents_without_embedding=True)
+            logger.info(f"Updating embeddings for {document_count} docs without embeddings ...")

         result = self._get_all_documents_in_index(
             index=index,
@@ -905,25 +913,29 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )

-        for result_batch in get_batches_from_generator(result, batch_size):
-            document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
-            embeddings = retriever.embed_passages(document_batch)  # type: ignore
-            assert len(document_batch) == len(embeddings)
-
-            if embeddings[0].shape[0] != self.embedding_dim:
-                raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
-                                   f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
-                                   "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
-            doc_updates = []
-            for doc, emb in zip(document_batch, embeddings):
-                update = {"_op_type": "update",
-                          "_index": index,
-                          "_id": doc.id,
-                          "doc": {self.embedding_field: emb.tolist()},
-                          }
-                doc_updates.append(update)
-
-            bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
+        logging.getLogger("elasticsearch").setLevel(logging.CRITICAL)
+
+        with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
+            for result_batch in get_batches_from_generator(result, batch_size):
+                document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
+                embeddings = retriever.embed_passages(document_batch)  # type: ignore
+                assert len(document_batch) == len(embeddings)
+
+                if embeddings[0].shape[0] != self.embedding_dim:
+                    raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+                                       f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
+                                       "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
+                doc_updates = []
+                for doc, emb in zip(document_batch, embeddings):
+                    update = {"_op_type": "update",
+                              "_index": index,
+                              "_id": doc.id,
+                              "doc": {self.embedding_field: emb.tolist()},
+                              }
+                    doc_updates.append(update)
+
+                bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
+                progress_bar.update(batch_size)

     def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         """
@@ -967,7 +979,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if self.refresh_type == "wait_for":
             time.sleep(2)

-
 class OpenDistroElasticsearchDocumentStore(ElasticsearchDocumentStore):
     """
     Document Store using the Open Distro for Elasticsearch. It is compatible with the AWS Elasticsearch Service.
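
The update_embeddings() hunk above is the heart of the progress-bar change: one outer bar (position=0) counts documents and advances once per bulk batch, and the elasticsearch logger is silenced, presumably so its request logging does not interleave with the bar. A self-contained sketch of the same pattern with a stand-in generator instead of the scan() results (all numbers made up):

from itertools import islice
from tqdm.auto import tqdm

def get_batches(generator, batch_size):
    # Yield lists of up to batch_size items, like get_batches_from_generator.
    iterator = iter(generator)
    while batch := list(islice(iterator, batch_size)):
        yield batch

docs = (f"doc-{i}" for i in range(950))  # stand-in for the ES scan generator
with tqdm(total=950, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
    for batch in get_batches(docs, 100):
        # embed_passages() + bulk() would run here
        progress_bar.update(len(batch))

Note that updating by len(batch), as in this sketch, avoids overshooting the total on a short final batch; the patch updates by batch_size, which is harmless for display but slightly overcounts.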

haystack/document_store/faiss.py

@@ -235,7 +235,8 @@ class FAISSDocumentStore(SQLDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating Embedding") as progress_bar:
             for document_batch in batched_documents:
                 embeddings = retriever.embed_passages(document_batch)  # type: ignore
                 assert len(document_batch) == len(embeddings)
@@ -248,8 +249,8 @@ class FAISSDocumentStore(SQLDocumentStore):
                     vector_id_map[doc.id] = vector_id
                     vector_id += 1
                 self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.set_description_str("Documents Processed")
                 progress_bar.update(batch_size)
-            progress_bar.close()

     def get_all_documents(
         self,
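
FAISSDocumentStore and, below, InMemoryDocumentStore and MilvusDocumentStore all receive the same treatment: the bar honours the store's progress_bar flag via disable=, and the explicit progress_bar.close() is dropped because the with block already closes the bar on exit. A minimal sketch of that configuration, with made-up totals:

from tqdm.auto import tqdm

show_progress = True  # mirrors the store's progress_bar constructor flag
with tqdm(total=500, disable=not show_progress, position=0, unit=" docs",
          desc="Updating Embedding") as progress_bar:
    for _ in range(5):
        # embed and write vectors for one batch of 100 docs here
        progress_bar.set_description_str("Documents Processed")
        progress_bar.update(100)
# no close() needed: the context manager handles it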

haystack/document_store/memory.py

@@ -210,7 +210,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
         document_count = len(result)
         logger.info(f"Updating embeddings for {document_count} docs ...")
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating Embedding") as progress_bar:
             for document_batch in batched_documents:
                 embeddings = retriever.embed_passages(document_batch)  # type: ignore
                 assert len(document_batch) == len(embeddings)
@@ -222,6 +223,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
                 for doc, emb in zip(document_batch, embeddings):
                     self.indexes[index][doc.id].embedding = emb
+                progress_bar.set_description_str("Documents Processed")
+                progress_bar.update(batch_size)

     def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
         """
"""

haystack/document_store/milvus.py

@@ -267,7 +267,8 @@ class MilvusDocumentStore(SQLDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating Embedding") as progress_bar:
             for document_batch in batched_documents:
                 self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
@@ -284,8 +285,8 @@ class MilvusDocumentStore(SQLDocumentStore):
                     vector_id_map[doc.id] = vector_id
                 self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.set_description_str("Documents Processed")
                 progress_bar.update(batch_size)
-            progress_bar.close()

         self.milvus_server.flush([index])
         self.milvus_server.compact(collection_name=index)

haystack/retriever/dense.py

@@ -3,7 +3,7 @@ from typing import List, Union, Optional
 import torch
 import numpy as np
 from pathlib import Path
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document
@@ -240,16 +240,19 @@ class DensePassageRetriever(BaseRetriever):
         else:
             disable_tqdm = not self.progress_bar

-        for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
-            batch = {key: batch[key].to(self.device) for key in batch}
+        with tqdm(total=len(data_loader)*self.batch_size, unit=" Docs", desc=f"Create embeddings", position=1,
+                  leave=False, disable=disable_tqdm) as progress_bar:
+            for batch in data_loader:
+                batch = {key: batch[key].to(self.device) for key in batch}

-            # get logits
-            with torch.no_grad():
-                query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
-                if query_embeddings is not None:
-                    all_embeddings["query"].append(query_embeddings.cpu().numpy())
-                if passage_embeddings is not None:
-                    all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+                # get logits
+                with torch.no_grad():
+                    query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
+                    if query_embeddings is not None:
+                        all_embeddings["query"].append(query_embeddings.cpu().numpy())
+                    if passage_embeddings is not None:
+                        all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+                progress_bar.update(self.batch_size)

         if all_embeddings["passages"]:
             all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])

test/test_elasticsearch.py

@@ -603,3 +603,29 @@ def test_elasticsearch_custom_fields(elasticsearch_fixture):
     assert len(documents) == 1
     assert documents[0].text == "test"
     np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)
+
+
+@pytest.mark.elasticsearch
+def test_get_document_count_only_documents_without_embedding_arg():
+    documents = [
+        {"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32), "meta_field_for_count": "a"},
+        {"text": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "b"},
+        {"text": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
+        {"text": "text4", "id": "4", "meta_field_for_count": "b"},
+        {"text": "text5", "id": "5", "meta_field_for_count": "b"},
+        {"text": "text6", "id": "6", "meta_field_for_count": "c"},
+        {"text": "text7", "id": "7", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "c"},
+    ]
+
+    _index: str = "haystack_test_count"
+    document_store = ElasticsearchDocumentStore(index=_index)
+    document_store.delete_documents(index=_index)
+
+    document_store.write_documents(documents)
+
+    assert document_store.get_document_count() == 7
+    assert document_store.get_document_count(only_documents_without_embedding=True) == 3
+    assert document_store.get_document_count(only_documents_without_embedding=True,
+                                             filters={"meta_field_for_count": ["c"]}) == 1
+    assert document_store.get_document_count(only_documents_without_embedding=True,
+                                             filters={"meta_field_for_count": ["b"]}) == 2