From 97c1e2cc9013f0a57d25c683093e87dbe29a88ef Mon Sep 17 00:00:00 2001 From: Ikram Ali Date: Wed, 14 Jul 2021 19:21:04 +0500 Subject: [PATCH] [document_store] Raise warning when labels are overwritten (#1257) * [document_store]SQLDocumentStore write_labels() overwrite warning added. * [document_store]SQLDocumentStore write_labels() overwrite warning added. * [document_store] bug fixed. #1140 * [document_store] bug fixed. #1140 * [document_store] get_labels_by_id() method removed. #1140 * [document_store] Code refactor. fix #1140 * [document_store] Code refactor. fix #1140 * [document_store] elasticsearch document store Code refactor. fix #1140 * [document_store] elasticsearch document store Code refactor. fix #1140 * [document_store] elasticsearch document store Code refactor. fix #1140 * [document_store] Code refactor for better visibility. fix #1140 * [document_store] Inmemory document store duplicate labels warning added fix #1140 --- haystack/document_store/base.py | 22 +++++++++++++++ haystack/document_store/elasticsearch.py | 34 +++++++++++++----------- haystack/document_store/memory.py | 10 +++++-- haystack/document_store/sql.py | 16 +++++++++-- 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/haystack/document_store/base.py b/haystack/document_store/base.py index 4235def6a..86f353bea 100644 --- a/haystack/document_store/base.py +++ b/haystack/document_store/base.py @@ -1,4 +1,5 @@ import logging +import collections from abc import abstractmethod from pathlib import Path from typing import Optional, Dict, List, Union @@ -324,3 +325,24 @@ class BaseDocumentStore(BaseComponent): documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) return documents + + def _get_duplicate_labels(self, labels: list, index: str = None) -> List[Label]: + """ + Return all duplicate labels + :param labels: List of Label objects + :param index: add an optional index attribute to labels. It can be later used for filtering. + :return: List of labels + """ + index = index or self.label_index + new_ids: List[str] = [label.id for label in labels] + duplicate_ids: List[str] = [] + + for label_id, count in collections.Counter(new_ids).items(): + if count > 1: + duplicate_ids.append(label_id) + + for label in self.get_all_labels(index=index): + if label.id in new_ids: + duplicate_ids.append(label.id) + + return [label for label in labels if label.id in duplicate_ids] diff --git a/haystack/document_store/elasticsearch.py b/haystack/document_store/elasticsearch.py index 4efe4babf..7d03008f0 100644 --- a/haystack/document_store/elasticsearch.py +++ b/haystack/document_store/elasticsearch.py @@ -443,29 +443,31 @@ class ElasticsearchDocumentStore(BaseDocumentStore): if index and not self.client.indices.exists(index=index): self._create_label_index(index) + labels = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels] + duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)] + if len(duplicate_ids) > 0: + logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store." + f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of" + f" the answer annotation and not the question." + f" Problematic ids: {','.join(duplicate_ids)}") labels_to_index = [] - for l in labels: - # Make sure we comply to Label class format - if isinstance(l, dict): - label = Label.from_dict(l) - else: - label = l - + for label in labels: # create timestamps if not available yet - if not label.created_at: - label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") - if not label.updated_at: - label.updated_at = label.created_at + if not label.created_at: # type: ignore + label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") # type: ignore + if not label.updated_at: # type: ignore + label.updated_at = label.created_at # type: ignore _label = { - "_op_type": "index" if self.duplicate_documents == "overwrite" else "create", + "_op_type": "index" if self.duplicate_documents == "overwrite" or label.id in duplicate_ids else # type: ignore + "create", "_index": index, - **label.to_dict() + **label.to_dict() # type: ignore } # type: Dict[str, Any] # rename id for elastic - if label.id is not None: - _label["_id"] = str(_label.pop("id")) + if label.id is not None: # type: ignore + _label["_id"] = str(_label.pop("id")) # type: ignore labels_to_index.append(_label) @@ -608,7 +610,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore): """ index = index or self.label_index result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)) - labels = [Label.from_dict(hit["_source"]) for hit in result] + labels = [Label.from_dict({**hit["_source"], "id": hit["_id"]}) for hit in result] return labels def _get_all_documents_in_index( diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py index e1ab1871d..c90ccd5d2 100644 --- a/haystack/document_store/memory.py +++ b/haystack/document_store/memory.py @@ -122,14 +122,20 @@ class InMemoryDocumentStore(BaseDocumentStore): index = index or self.label_index label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels] + duplicate_ids: list = [label.id for label in self._get_duplicate_labels(label_objects, index=index)] + if len(duplicate_ids) > 0: + logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store." + f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of" + f" the answer annotation and not the question." + f" Problematic ids: {','.join(duplicate_ids)}") + for label in label_objects: - label_id = str(uuid4()) # create timestamps if not available yet if not label.created_at: label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") if not label.updated_at: label.updated_at = label.created_at - self.indexes[index][label_id] = label + self.indexes[index][label.id] = label def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]: """Fetch a document by specifying its text id string""" diff --git a/haystack/document_store/sql.py b/haystack/document_store/sql.py index 65a9c87c5..91829ccf7 100644 --- a/haystack/document_store/sql.py +++ b/haystack/document_store/sql.py @@ -1,5 +1,6 @@ import itertools import logging +import collections from typing import Any, Dict, Union, List, Optional, Generator from uuid import uuid4 @@ -326,6 +327,13 @@ class SQLDocumentStore(BaseDocumentStore): labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels] index = index or self.label_index + + duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)] + if len(duplicate_ids) > 0: + logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store." + f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of" + f" the answer annotation and not the question." + f" Problematic ids: {','.join(duplicate_ids)}") # TODO: Use batch_size for label in labels: label_orm = LabelORM( @@ -341,7 +349,10 @@ class SQLDocumentStore(BaseDocumentStore): model_id=label.model_id, index=index, ) - self.session.add(label_orm) + if label.id in duplicate_ids: + self.session.merge(label_orm) + else: + self.session.add(label_orm) self.session.commit() def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None, batch_size: int = 10_000): @@ -432,7 +443,8 @@ class SQLDocumentStore(BaseDocumentStore): offset_start_in_doc=row.offset_start_in_doc, model_id=row.model_id, created_at=row.created_at, - updated_at=row.updated_at + updated_at=row.updated_at, + id=row.id ) return label