[document_store] Raise warning when labels are overwritten (#1257)

* [document_store]SQLDocumentStore write_labels() overwrite warning added.

* [document_store]SQLDocumentStore write_labels() overwrite warning added.

* [document_store] bug fixed. #1140

* [document_store] bug fixed. #1140

* [document_store] get_labels_by_id() method removed. #1140

* [document_store] Code refactor. fix #1140

* [document_store] Code refactor. fix #1140

* [document_store] elasticsearch document store Code refactor. fix #1140

* [document_store] elasticsearch document store Code refactor. fix #1140

* [document_store] elasticsearch document store Code refactor. fix #1140

* [document_store] Code refactor for better visibility. fix #1140

* [document_store] Inmemory document store duplicate labels warning added fix #1140
This commit is contained in:
Ikram Ali 2021-07-14 19:21:04 +05:00 committed by GitHub
parent da97d81305
commit 97c1e2cc90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 20 deletions

View File

@ -1,4 +1,5 @@
import logging
import collections
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Dict, List, Union
@ -324,3 +325,24 @@ class BaseDocumentStore(BaseComponent):
documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents))
return documents
def _get_duplicate_labels(self, labels: list, index: str = None) -> List[Label]:
"""
Return all duplicate labels
:param labels: List of Label objects
:param index: add an optional index attribute to labels. It can be later used for filtering.
:return: List of labels
"""
index = index or self.label_index
new_ids: List[str] = [label.id for label in labels]
duplicate_ids: List[str] = []
for label_id, count in collections.Counter(new_ids).items():
if count > 1:
duplicate_ids.append(label_id)
for label in self.get_all_labels(index=index):
if label.id in new_ids:
duplicate_ids.append(label.id)
return [label for label in labels if label.id in duplicate_ids]

View File

@ -443,29 +443,31 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if index and not self.client.indices.exists(index=index):
self._create_label_index(index)
labels = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels]
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
if len(duplicate_ids) > 0:
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
f" the answer annotation and not the question."
f" Problematic ids: {','.join(duplicate_ids)}")
labels_to_index = []
for l in labels:
# Make sure we comply to Label class format
if isinstance(l, dict):
label = Label.from_dict(l)
else:
label = l
for label in labels:
# create timestamps if not available yet
if not label.created_at:
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
if not label.updated_at:
label.updated_at = label.created_at
if not label.created_at: # type: ignore
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") # type: ignore
if not label.updated_at: # type: ignore
label.updated_at = label.created_at # type: ignore
_label = {
"_op_type": "index" if self.duplicate_documents == "overwrite" else "create",
"_op_type": "index" if self.duplicate_documents == "overwrite" or label.id in duplicate_ids else # type: ignore
"create",
"_index": index,
**label.to_dict()
**label.to_dict() # type: ignore
} # type: Dict[str, Any]
# rename id for elastic
if label.id is not None:
_label["_id"] = str(_label.pop("id"))
if label.id is not None: # type: ignore
_label["_id"] = str(_label.pop("id")) # type: ignore
labels_to_index.append(_label)
@ -608,7 +610,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
"""
index = index or self.label_index
result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
labels = [Label.from_dict(hit["_source"]) for hit in result]
labels = [Label.from_dict({**hit["_source"], "id": hit["_id"]}) for hit in result]
return labels
def _get_all_documents_in_index(

View File

@ -122,14 +122,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
index = index or self.label_index
label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(label_objects, index=index)]
if len(duplicate_ids) > 0:
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
f" the answer annotation and not the question."
f" Problematic ids: {','.join(duplicate_ids)}")
for label in label_objects:
label_id = str(uuid4())
# create timestamps if not available yet
if not label.created_at:
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
if not label.updated_at:
label.updated_at = label.created_at
self.indexes[index][label_id] = label
self.indexes[index][label.id] = label
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""

View File

@ -1,5 +1,6 @@
import itertools
import logging
import collections
from typing import Any, Dict, Union, List, Optional, Generator
from uuid import uuid4
@ -326,6 +327,13 @@ class SQLDocumentStore(BaseDocumentStore):
labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
index = index or self.label_index
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
if len(duplicate_ids) > 0:
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
f" the answer annotation and not the question."
f" Problematic ids: {','.join(duplicate_ids)}")
# TODO: Use batch_size
for label in labels:
label_orm = LabelORM(
@ -341,6 +349,9 @@ class SQLDocumentStore(BaseDocumentStore):
model_id=label.model_id,
index=index,
)
if label.id in duplicate_ids:
self.session.merge(label_orm)
else:
self.session.add(label_orm)
self.session.commit()
@ -432,7 +443,8 @@ class SQLDocumentStore(BaseDocumentStore):
offset_start_in_doc=row.offset_start_in_doc,
model_id=row.model_id,
created_at=row.created_at,
updated_at=row.updated_at
updated_at=row.updated_at,
id=row.id
)
return label