Mirror of https://github.com/deepset-ai/haystack.git
Improve the progress bar in update_embeddings() + Fix filters in update_embeddings() (#1063)
* [document_stores] Add a progress bar in update_embeddings() to track overall document progress (closes #1037)
* Change the second-level loop to docs; switch to tqdm.auto
* [document_stores] Elasticsearch: add new method get_document_without_embedding_count()
* [test_case] Add a test case for the Elasticsearch document store's get_document_without_embedding_count()
* [document_stores] Add a new bool arg to the get_document_count() method (fixes #1082)
* [document_stores] Fix typo (#1082)

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
parent f46b09c756
commit 4ab1bc3c3e
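The core of the Elasticsearch change is a new only_documents_without_embedding flag on get_document_count() that is translated into a must_not/exists clause. A minimal sketch of the query this produces, assuming a local cluster, an index named "document", and an embedding field named "embedding" (the store itself uses self.index and self.embedding_field):

    from elasticsearch import Elasticsearch

    client = Elasticsearch("http://localhost:9200")
    # Count only documents that have no value in the embedding field.
    body = {"query": {"bool": {"must_not": [{"exists": {"field": "embedding"}}]}}}
    print(client.count(index="document", body=body)["count"])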
@@ -9,6 +9,7 @@ from elasticsearch.helpers import bulk, scan
 from elasticsearch.exceptions import RequestError
 import numpy as np
 from scipy.special import expit
+from tqdm.auto import tqdm

 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
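Switching the import to tqdm.auto lets tqdm pick the right frontend at runtime: the widget-based bar inside Jupyter notebooks (when ipywidgets is available) and the plain terminal bar everywhere else. A quick illustration, not part of the diff:

    from tqdm.auto import tqdm

    # Renders as a notebook widget under Jupyter, as a text bar in a terminal.
    for _ in tqdm(range(1_000), unit=" Docs", desc="Demo"):
        pass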
@@ -475,13 +476,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         body = {"doc": meta}
         self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type)

-    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
+    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
+                           only_documents_without_embedding: bool = False) -> int:
         """
         Return the number of documents in the document store.
         """
         index = index or self.index

         body: dict = {"query": {"bool": {}}}
+        if only_documents_without_embedding:
+            body['query']['bool']['must_not'] = [{"exists": {"field": self.embedding_field}}]
+
         if filters:
             filter_clause = []
             for key, values in filters.items():
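With the extra keyword argument in place, callers can ask the store how many documents still lack an embedding. A hedged usage sketch (the index name and a running Elasticsearch instance are assumptions):

    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

    document_store = ElasticsearchDocumentStore(index="document")
    total = document_store.get_document_count()
    pending = document_store.get_document_count(only_documents_without_embedding=True)
    print(f"{pending} of {total} docs still need embeddings")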
@@ -620,7 +625,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             body["query"]["bool"]["filter"] = filter_clause

         if only_documents_without_embedding:
-            body["query"]["bool"] = {"must_not": {"exists": {"field": self.embedding_field}}}
+            body['query']['bool']['must_not'] = [{"exists": {"field": self.embedding_field}}]

         result = scan(self.client, query=body, index=index, size=batch_size, scroll="1d")
         yield from result
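The one-line change above is the actual filter fix from #1082: the old code assigned a fresh dict to body["query"]["bool"], silently discarding the filter clause assembled from the filters argument just before it, so "update only documents without embeddings, restricted to these filters" ignored the filters. The new line adds a must_not key next to the existing filter instead. The difference in plain Python (a standalone illustration with made-up filter values):

    body = {"query": {"bool": {"filter": [{"terms": {"year": ["2021"]}}]}}}

    # Old behaviour: the whole bool clause is replaced and the filter is lost.
    body["query"]["bool"] = {"must_not": {"exists": {"field": "embedding"}}}
    assert "filter" not in body["query"]["bool"]

    # Fixed behaviour: must_not is added alongside the existing filter.
    body = {"query": {"bool": {"filter": [{"terms": {"year": ["2021"]}}]}}}
    body["query"]["bool"]["must_not"] = [{"exists": {"field": "embedding"}}]
    assert "filter" in body["query"]["bool"] and "must_not" in body["query"]["bool"]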
@@ -894,9 +899,12 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")

         if update_existing_embeddings:
-            logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...")
+            document_count = self.get_document_count(index=index)
+            logger.info(f"Updating embeddings for all {document_count} docs ...")
         else:
-            logger.info(f"Updating embeddings for new docs without embeddings ...")
+            document_count = self.get_document_count(index=index, filters=filters,
+                                                     only_documents_without_embedding=True)
+            logger.info(f"Updating embeddings for {document_count} docs without embeddings ...")

         result = self._get_all_documents_in_index(
             index=index,
@@ -905,25 +913,29 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )

-        for result_batch in get_batches_from_generator(result, batch_size):
-            document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
-            embeddings = retriever.embed_passages(document_batch)  # type: ignore
-            assert len(document_batch) == len(embeddings)
+        logging.getLogger("elasticsearch").setLevel(logging.CRITICAL)

-            if embeddings[0].shape[0] != self.embedding_dim:
-                raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
-                                   f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
-                                   "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
-            doc_updates = []
-            for doc, emb in zip(document_batch, embeddings):
-                update = {"_op_type": "update",
-                          "_index": index,
-                          "_id": doc.id,
-                          "doc": {self.embedding_field: emb.tolist()},
-                          }
-                doc_updates.append(update)
+        with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
+            for result_batch in get_batches_from_generator(result, batch_size):
+                document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
+                embeddings = retriever.embed_passages(document_batch)  # type: ignore
+                assert len(document_batch) == len(embeddings)

-            bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
+                if embeddings[0].shape[0] != self.embedding_dim:
+                    raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+                                       f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
+                                       "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
+                doc_updates = []
+                for doc, emb in zip(document_batch, embeddings):
+                    update = {"_op_type": "update",
+                              "_index": index,
+                              "_id": doc.id,
+                              "doc": {self.embedding_field: emb.tolist()},
+                              }
+                    doc_updates.append(update)

+                bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
+                progress_bar.update(batch_size)

     def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         """
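The bar wraps the whole batch loop, uses the document_count computed earlier as its total, and advances by batch_size per iteration (so the final partial batch may overshoot the total slightly). The same pattern in isolation, with a simplified stand-in for haystack's get_batches_from_generator():

    from tqdm.auto import tqdm

    def get_batches(items, batch_size):
        # Simplified stand-in for get_batches_from_generator().
        for i in range(0, len(items), batch_size):
            yield items[i:i + batch_size]

    docs, batch_size = list(range(10_000)), 512
    with tqdm(total=len(docs), position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
        for batch in get_batches(docs, batch_size):
            # ... embed the batch and bulk-index the vectors here ...
            progress_bar.update(batch_size)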
@@ -967,7 +979,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if self.refresh_type == "wait_for":
             time.sleep(2)

-
 class OpenDistroElasticsearchDocumentStore(ElasticsearchDocumentStore):
     """
     Document Store using the Open Distro for Elasticsearch. It is compatible with the AWS Elasticsearch Service.
@@ -235,7 +235,8 @@ class FAISSDocumentStore(SQLDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
         )
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating Embedding") as progress_bar:
             for document_batch in batched_documents:
                 embeddings = retriever.embed_passages(document_batch)  # type: ignore
                 assert len(document_batch) == len(embeddings)
@@ -248,8 +249,8 @@ class FAISSDocumentStore(SQLDocumentStore):
                     vector_id_map[doc.id] = vector_id
                     vector_id += 1
                 self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.set_description_str("Documents Processed")
                 progress_bar.update(batch_size)
-        progress_bar.close()

     def get_all_documents(
         self,
@@ -210,7 +210,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
         document_count = len(result)
         logger.info(f"Updating embeddings for {document_count} docs ...")
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating Embedding") as progress_bar:
             for document_batch in batched_documents:
                 embeddings = retriever.embed_passages(document_batch)  # type: ignore
                 assert len(document_batch) == len(embeddings)
@@ -222,6 +223,8 @@ class InMemoryDocumentStore(BaseDocumentStore):

                 for doc, emb in zip(document_batch, embeddings):
                     self.indexes[index][doc.id].embedding = emb
+                progress_bar.set_description_str("Documents Processed")
+                progress_bar.update(batch_size)

     def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
         """
@@ -267,7 +267,8 @@ class MilvusDocumentStore(SQLDocumentStore):
             only_documents_without_embedding=not update_existing_embeddings
        )
         batched_documents = get_batches_from_generator(result, batch_size)
-        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
+        with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
+                  desc="Updating Embedding") as progress_bar:
             for document_batch in batched_documents:
                 self._delete_vector_ids_from_milvus(documents=document_batch, index=index)

@@ -284,8 +285,8 @@ class MilvusDocumentStore(SQLDocumentStore):
                     vector_id_map[doc.id] = vector_id

                 self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.set_description_str("Documents Processed")
                 progress_bar.update(batch_size)
-        progress_bar.close()

         self.milvus_server.flush([index])
         self.milvus_server.compact(collection_name=index)
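The FAISS, in-memory, and Milvus stores all get the same treatment: the existing with tqdm(...) bar gains position=0, a unit, and a description; set_description_str() relabels the bar once batches start completing; and the trailing progress_bar.close() is dropped because the with block already closes the bar on exit. The shared pattern in sketch form (counts and the loop body are placeholders):

    from tqdm.auto import tqdm

    show_progress, document_count, batch_size = True, 2_048, 256

    with tqdm(total=document_count, disable=not show_progress, position=0,
              unit=" docs", desc="Updating Embedding") as progress_bar:
        for start in range(0, document_count, batch_size):
            # ... embed one batch and write the vectors back to the store ...
            progress_bar.set_description_str("Documents Processed")
            progress_bar.update(batch_size)
    # No explicit close() needed: the context manager closes the bar here.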
@@ -3,7 +3,7 @@ from typing import List, Union, Optional
 import torch
 import numpy as np
 from pathlib import Path
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document
@@ -240,16 +240,19 @@ class DensePassageRetriever(BaseRetriever):
         else:
             disable_tqdm = not self.progress_bar

-        for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
-            batch = {key: batch[key].to(self.device) for key in batch}
+        with tqdm(total=len(data_loader)*self.batch_size, unit=" Docs", desc=f"Create embeddings", position=1,
+                  leave=False, disable=disable_tqdm) as progress_bar:
+            for batch in data_loader:
+                batch = {key: batch[key].to(self.device) for key in batch}

-            # get logits
-            with torch.no_grad():
-                query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
-                if query_embeddings is not None:
-                    all_embeddings["query"].append(query_embeddings.cpu().numpy())
-                if passage_embeddings is not None:
-                    all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+                # get logits
+                with torch.no_grad():
+                    query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
+                    if query_embeddings is not None:
+                        all_embeddings["query"].append(query_embeddings.cpu().numpy())
+                    if passage_embeddings is not None:
+                        all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+                progress_bar.update(self.batch_size)

         if all_embeddings["passages"]:
             all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
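In the retriever, the per-batch bar becomes a nested, self-erasing one: position=1 places it on the line below the document-level bar that update_embeddings() runs at position=0, and leave=False clears it once embedding finishes; the total is estimated as len(data_loader) * batch_size. A minimal two-level sketch of that layout (sizes and the sleep are placeholders for real work):

    import time
    from tqdm.auto import tqdm

    num_docs, batch_size = 1_000, 100

    with tqdm(total=num_docs, position=0, unit=" Docs", desc="Updating embeddings") as outer:
        for _ in range(0, num_docs, batch_size):
            with tqdm(total=batch_size, position=1, leave=False,
                      unit=" Docs", desc="Create embeddings") as inner:
                for _ in range(batch_size):
                    time.sleep(0.001)  # stand-in for a model forward pass
                    inner.update(1)
            outer.update(batch_size)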
@@ -603,3 +603,29 @@ def test_elasticsearch_custom_fields(elasticsearch_fixture):
     assert len(documents) == 1
     assert documents[0].text == "test"
     np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)
+
+
+@pytest.mark.elasticsearch
+def test_get_document_count_only_documents_without_embedding_arg():
+    documents = [
+        {"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32), "meta_field_for_count": "a"},
+        {"text": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "b"},
+        {"text": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
+        {"text": "text4", "id": "4", "meta_field_for_count": "b"},
+        {"text": "text5", "id": "5", "meta_field_for_count": "b"},
+        {"text": "text6", "id": "6", "meta_field_for_count": "c"},
+        {"text": "text7", "id": "7", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "c"},
+    ]
+
+    _index: str = "haystack_test_count"
+    document_store = ElasticsearchDocumentStore(index=_index)
+    document_store.delete_documents(index=_index)
+
+    document_store.write_documents(documents)
+
+    assert document_store.get_document_count() == 7
+    assert document_store.get_document_count(only_documents_without_embedding=True) == 3
+    assert document_store.get_document_count(only_documents_without_embedding=True,
+                                             filters={"meta_field_for_count": ["c"]}) == 1
+    assert document_store.get_document_count(only_documents_without_embedding=True,
+                                             filters={"meta_field_for_count": ["b"]}) == 2