[document_store] Raise warning when labels are overwritten (#1257)

* [document_store]SQLDocumentStore write_labels() overwrite warning added.

* [document_store]SQLDocumentStore write_labels() overwrite warning added.

* [document_store] bug fixed. #1140

* [document_store] bug fixed. #1140

* [document_store] get_labels_by_id() method removed. #1140

* [document_store] Code refactor. fix #1140

* [document_store] Code refactor. fix #1140

* [document_store] elasticsearch document store Code refactor. fix #1140

* [document_store] elasticsearch document store Code refactor. fix #1140

* [document_store] elasticsearch document store Code refactor. fix #1140

* [document_store] Code refactor for better visibility. fix #1140

* [document_store] Inmemory document store duplicate labels warning added fix #1140
This commit is contained in:
Ikram Ali 2021-07-14 19:21:04 +05:00 committed by GitHub
parent da97d81305
commit 97c1e2cc90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 20 deletions

View File

@ -1,4 +1,5 @@
import logging import logging
import collections
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, List, Union from typing import Optional, Dict, List, Union
@ -324,3 +325,24 @@ class BaseDocumentStore(BaseComponent):
documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents))
return documents return documents
def _get_duplicate_labels(self, labels: list, index: str = None) -> List[Label]:
    """
    Return all labels from `labels` whose id collides with another label —
    either with another label in the same `labels` batch, or with a label
    already stored in the document store.

    :param labels: List of Label objects to check for duplicate ids.
    :param index: Name of the label index to check against. Defaults to
                  ``self.label_index`` when not given.
    :return: List of the labels in `labels` that have a duplicated id.
    """
    index = index or self.label_index
    new_ids: List[str] = [label.id for label in labels]
    # Ids that occur more than once within the incoming batch itself.
    duplicate_ids = {label_id for label_id, count in collections.Counter(new_ids).items() if count > 1}
    # Use a set for O(1) membership tests while scanning the stored labels.
    new_id_set = set(new_ids)
    for label in self.get_all_labels(index=index):
        if label.id in new_id_set:
            duplicate_ids.add(label.id)
    return [label for label in labels if label.id in duplicate_ids]

View File

@ -443,29 +443,31 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if index and not self.client.indices.exists(index=index): if index and not self.client.indices.exists(index=index):
self._create_label_index(index) self._create_label_index(index)
labels = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels]
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
if len(duplicate_ids) > 0:
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
f" the answer annotation and not the question."
f" Problematic ids: {','.join(duplicate_ids)}")
labels_to_index = [] labels_to_index = []
for l in labels: for label in labels:
# Make sure we comply to Label class format
if isinstance(l, dict):
label = Label.from_dict(l)
else:
label = l
# create timestamps if not available yet # create timestamps if not available yet
if not label.created_at: if not label.created_at: # type: ignore
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") # type: ignore
if not label.updated_at: if not label.updated_at: # type: ignore
label.updated_at = label.created_at label.updated_at = label.created_at # type: ignore
_label = { _label = {
"_op_type": "index" if self.duplicate_documents == "overwrite" else "create", "_op_type": "index" if self.duplicate_documents == "overwrite" or label.id in duplicate_ids else # type: ignore
"create",
"_index": index, "_index": index,
**label.to_dict() **label.to_dict() # type: ignore
} # type: Dict[str, Any] } # type: Dict[str, Any]
# rename id for elastic # rename id for elastic
if label.id is not None: if label.id is not None: # type: ignore
_label["_id"] = str(_label.pop("id")) _label["_id"] = str(_label.pop("id")) # type: ignore
labels_to_index.append(_label) labels_to_index.append(_label)
@ -608,7 +610,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
""" """
index = index or self.label_index index = index or self.label_index
result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)) result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
labels = [Label.from_dict(hit["_source"]) for hit in result] labels = [Label.from_dict({**hit["_source"], "id": hit["_id"]}) for hit in result]
return labels return labels
def _get_all_documents_in_index( def _get_all_documents_in_index(

View File

@ -122,14 +122,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
index = index or self.label_index index = index or self.label_index
label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels] label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(label_objects, index=index)]
if len(duplicate_ids) > 0:
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
f" the answer annotation and not the question."
f" Problematic ids: {','.join(duplicate_ids)}")
for label in label_objects: for label in label_objects:
label_id = str(uuid4())
# create timestamps if not available yet # create timestamps if not available yet
if not label.created_at: if not label.created_at:
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
if not label.updated_at: if not label.updated_at:
label.updated_at = label.created_at label.updated_at = label.created_at
self.indexes[index][label_id] = label self.indexes[index][label.id] = label
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]: def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string""" """Fetch a document by specifying its text id string"""

View File

@ -1,5 +1,6 @@
import itertools import itertools
import logging import logging
import collections
from typing import Any, Dict, Union, List, Optional, Generator from typing import Any, Dict, Union, List, Optional, Generator
from uuid import uuid4 from uuid import uuid4
@ -326,6 +327,13 @@ class SQLDocumentStore(BaseDocumentStore):
labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels] labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
index = index or self.label_index index = index or self.label_index
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
if len(duplicate_ids) > 0:
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
f" the answer annotation and not the question."
f" Problematic ids: {','.join(duplicate_ids)}")
# TODO: Use batch_size # TODO: Use batch_size
for label in labels: for label in labels:
label_orm = LabelORM( label_orm = LabelORM(
@ -341,6 +349,9 @@ class SQLDocumentStore(BaseDocumentStore):
model_id=label.model_id, model_id=label.model_id,
index=index, index=index,
) )
if label.id in duplicate_ids:
self.session.merge(label_orm)
else:
self.session.add(label_orm) self.session.add(label_orm)
self.session.commit() self.session.commit()
@ -432,7 +443,8 @@ class SQLDocumentStore(BaseDocumentStore):
offset_start_in_doc=row.offset_start_in_doc, offset_start_in_doc=row.offset_start_in_doc,
model_id=row.model_id, model_id=row.model_id,
created_at=row.created_at, created_at=row.created_at,
updated_at=row.updated_at updated_at=row.updated_at,
id=row.id
) )
return label return label