mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-19 06:52:56 +00:00
[document_store] Raise warning when labels are overwritten (#1257)
* [document_store]SQLDocumentStore write_labels() overwrite warning added. * [document_store]SQLDocumentStore write_labels() overwrite warning added. * [document_store] bug fixed. #1140 * [document_store] bug fixed. #1140 * [document_store] get_labels_by_id() method removed. #1140 * [document_store] Code refactor. fix #1140 * [document_store] Code refactor. fix #1140 * [document_store] elasticsearch document store Code refactor. fix #1140 * [document_store] elasticsearch document store Code refactor. fix #1140 * [document_store] elasticsearch document store Code refactor. fix #1140 * [document_store] Code refactor for better visibility. fix #1140 * [document_store] Inmemory document store duplicate labels warning added fix #1140
This commit is contained in:
parent
da97d81305
commit
97c1e2cc90
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import collections
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
@ -324,3 +325,24 @@ class BaseDocumentStore(BaseComponent):
|
||||
documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents))
|
||||
|
||||
return documents
|
||||
|
||||
def _get_duplicate_labels(self, labels: list, index: str = None) -> List[Label]:
|
||||
"""
|
||||
Return all duplicate labels
|
||||
:param labels: List of Label objects
|
||||
:param index: add an optional index attribute to labels. It can be later used for filtering.
|
||||
:return: List of labels
|
||||
"""
|
||||
index = index or self.label_index
|
||||
new_ids: List[str] = [label.id for label in labels]
|
||||
duplicate_ids: List[str] = []
|
||||
|
||||
for label_id, count in collections.Counter(new_ids).items():
|
||||
if count > 1:
|
||||
duplicate_ids.append(label_id)
|
||||
|
||||
for label in self.get_all_labels(index=index):
|
||||
if label.id in new_ids:
|
||||
duplicate_ids.append(label.id)
|
||||
|
||||
return [label for label in labels if label.id in duplicate_ids]
|
||||
|
@ -443,29 +443,31 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
if index and not self.client.indices.exists(index=index):
|
||||
self._create_label_index(index)
|
||||
|
||||
labels = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels]
|
||||
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
|
||||
if len(duplicate_ids) > 0:
|
||||
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
|
||||
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
|
||||
f" the answer annotation and not the question."
|
||||
f" Problematic ids: {','.join(duplicate_ids)}")
|
||||
labels_to_index = []
|
||||
for l in labels:
|
||||
# Make sure we comply to Label class format
|
||||
if isinstance(l, dict):
|
||||
label = Label.from_dict(l)
|
||||
else:
|
||||
label = l
|
||||
|
||||
for label in labels:
|
||||
# create timestamps if not available yet
|
||||
if not label.created_at:
|
||||
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if not label.updated_at:
|
||||
label.updated_at = label.created_at
|
||||
if not label.created_at: # type: ignore
|
||||
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") # type: ignore
|
||||
if not label.updated_at: # type: ignore
|
||||
label.updated_at = label.created_at # type: ignore
|
||||
|
||||
_label = {
|
||||
"_op_type": "index" if self.duplicate_documents == "overwrite" else "create",
|
||||
"_op_type": "index" if self.duplicate_documents == "overwrite" or label.id in duplicate_ids else # type: ignore
|
||||
"create",
|
||||
"_index": index,
|
||||
**label.to_dict()
|
||||
**label.to_dict() # type: ignore
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
# rename id for elastic
|
||||
if label.id is not None:
|
||||
_label["_id"] = str(_label.pop("id"))
|
||||
if label.id is not None: # type: ignore
|
||||
_label["_id"] = str(_label.pop("id")) # type: ignore
|
||||
|
||||
labels_to_index.append(_label)
|
||||
|
||||
@ -608,7 +610,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
"""
|
||||
index = index or self.label_index
|
||||
result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
|
||||
labels = [Label.from_dict(hit["_source"]) for hit in result]
|
||||
labels = [Label.from_dict({**hit["_source"], "id": hit["_id"]}) for hit in result]
|
||||
return labels
|
||||
|
||||
def _get_all_documents_in_index(
|
||||
|
@ -122,14 +122,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
index = index or self.label_index
|
||||
label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
|
||||
|
||||
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(label_objects, index=index)]
|
||||
if len(duplicate_ids) > 0:
|
||||
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
|
||||
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
|
||||
f" the answer annotation and not the question."
|
||||
f" Problematic ids: {','.join(duplicate_ids)}")
|
||||
|
||||
for label in label_objects:
|
||||
label_id = str(uuid4())
|
||||
# create timestamps if not available yet
|
||||
if not label.created_at:
|
||||
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if not label.updated_at:
|
||||
label.updated_at = label.created_at
|
||||
self.indexes[index][label_id] = label
|
||||
self.indexes[index][label.id] = label
|
||||
|
||||
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
|
||||
"""Fetch a document by specifying its text id string"""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import logging
|
||||
import collections
|
||||
from typing import Any, Dict, Union, List, Optional, Generator
|
||||
from uuid import uuid4
|
||||
|
||||
@ -326,6 +327,13 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
|
||||
labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
|
||||
index = index or self.label_index
|
||||
|
||||
duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
|
||||
if len(duplicate_ids) > 0:
|
||||
logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
|
||||
f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
|
||||
f" the answer annotation and not the question."
|
||||
f" Problematic ids: {','.join(duplicate_ids)}")
|
||||
# TODO: Use batch_size
|
||||
for label in labels:
|
||||
label_orm = LabelORM(
|
||||
@ -341,6 +349,9 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
model_id=label.model_id,
|
||||
index=index,
|
||||
)
|
||||
if label.id in duplicate_ids:
|
||||
self.session.merge(label_orm)
|
||||
else:
|
||||
self.session.add(label_orm)
|
||||
self.session.commit()
|
||||
|
||||
@ -432,7 +443,8 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
offset_start_in_doc=row.offset_start_in_doc,
|
||||
model_id=row.model_id,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at
|
||||
updated_at=row.updated_at,
|
||||
id=row.id
|
||||
)
|
||||
return label
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user