mirror of https://github.com/deepset-ai/haystack.git
synced 2025-07-20 07:21:09 +00:00
[document_store] Raise warning when labels are overwritten (#1257)
* [document_store] SQLDocumentStore write_labels() overwrite warning added.
* [document_store] Bug fixed. #1140
* [document_store] get_labels_by_id() method removed. #1140
* [document_store] Code refactor. fix #1140
* [document_store] ElasticsearchDocumentStore code refactor. fix #1140
* [document_store] Code refactor for better visibility. fix #1140
* [document_store] InMemoryDocumentStore duplicate labels warning added. fix #1140
parent da97d81305
commit 97c1e2cc90
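To make the intent concrete, here is a minimal sketch of how the new behaviour surfaces to a user, assuming the pre-1.0 haystack layout this commit targets; the import paths and the Label constructor arguments below are recalled from that era and should be treated as assumptions rather than as part of the patch:

    import logging
    from haystack import Label                                         # assumed 0.x import path
    from haystack.document_store.memory import InMemoryDocumentStore   # assumed 0.x import path

    logging.basicConfig(level=logging.WARNING)

    store = InMemoryDocumentStore()
    label = Label(question="Who wrote Faust?", answer="Goethe",
                  is_correct_answer=True, is_correct_document=True,
                  origin="user-feedback", id="label-1")

    store.write_labels([label])
    # Writing a label whose id already exists now logs
    # "Duplicate Label IDs: ..." and overwrites the stored label
    # instead of silently keeping a second copy.
    store.write_labels([label])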
@@ -1,4 +1,5 @@
 import logging
+import collections
 from abc import abstractmethod
 from pathlib import Path
 from typing import Optional, Dict, List, Union
@@ -324,3 +325,24 @@ class BaseDocumentStore(BaseComponent):
         documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents))
 
         return documents
+
+    def _get_duplicate_labels(self, labels: list, index: str = None) -> List[Label]:
+        """
+        Return all duplicate labels
+        :param labels: List of Label objects
+        :param index: add an optional index attribute to labels. It can be later used for filtering.
+        :return: List of labels
+        """
+        index = index or self.label_index
+        new_ids: List[str] = [label.id for label in labels]
+        duplicate_ids: List[str] = []
+
+        for label_id, count in collections.Counter(new_ids).items():
+            if count > 1:
+                duplicate_ids.append(label_id)
+
+        for label in self.get_all_labels(index=index):
+            if label.id in new_ids:
+                duplicate_ids.append(label.id)
+
+        return [label for label in labels if label.id in duplicate_ids]
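The helper flags two kinds of duplicates: ids that appear more than once within the incoming batch, and ids that already exist in the label index. A self-contained sketch of that logic, detached from the document store (the existing_ids list stands in for get_all_labels() and is an assumption of this illustration):

    import collections
    from typing import List

    def find_duplicate_ids(new_ids: List[str], existing_ids: List[str]) -> List[str]:
        # Mirror of the detection in _get_duplicate_labels(): batch-internal
        # repeats plus ids that are already stored.
        duplicates = [i for i, count in collections.Counter(new_ids).items() if count > 1]
        duplicates += [i for i in existing_ids if i in new_ids]
        return duplicates

    # The batch repeats "a"; "b" already exists in the store.
    print(find_duplicate_ids(["a", "a", "b", "c"], existing_ids=["b", "x"]))
    # -> ['a', 'b']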
@@ -443,29 +443,31 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if index and not self.client.indices.exists(index=index):
             self._create_label_index(index)
 
+        labels = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels]
+        duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
+        if len(duplicate_ids) > 0:
+            logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
+                           f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
+                           f" the answer annotation and not the question."
+                           f" Problematic ids: {','.join(duplicate_ids)}")
         labels_to_index = []
-        for l in labels:
-            # Make sure we comply to Label class format
-            if isinstance(l, dict):
-                label = Label.from_dict(l)
-            else:
-                label = l
+        for label in labels:
 
             # create timestamps if not available yet
-            if not label.created_at:
-                label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
-            if not label.updated_at:
-                label.updated_at = label.created_at
+            if not label.created_at:  # type: ignore
+                label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")  # type: ignore
+            if not label.updated_at:  # type: ignore
+                label.updated_at = label.created_at  # type: ignore
 
             _label = {
-                "_op_type": "index" if self.duplicate_documents == "overwrite" else "create",
+                "_op_type": "index" if self.duplicate_documents == "overwrite" or label.id in duplicate_ids else  # type: ignore
+                            "create",
                 "_index": index,
-                **label.to_dict()
+                **label.to_dict()  # type: ignore
             }  # type: Dict[str, Any]
 
             # rename id for elastic
-            if label.id is not None:
-                _label["_id"] = str(_label.pop("id"))
+            if label.id is not None:  # type: ignore
+                _label["_id"] = str(_label.pop("id"))  # type: ignore
 
             labels_to_index.append(_label)
 
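The overwrite itself rides on the Elasticsearch bulk "_op_type": "index" inserts or replaces the document with the given "_id", while "create" fails if that "_id" already exists. The changed condition routes duplicate label ids to "index" even when duplicate_documents is not "overwrite". A small restatement of that decision as a pure function (function and parameter names here are illustrative, not part of the patch):

    def label_op_type(duplicate_documents: str, label_id: str, duplicate_ids: set) -> str:
        # Pick the Elasticsearch bulk action: "index" add-or-replaces by _id,
        # "create" raises a conflict if the _id already exists.
        if duplicate_documents == "overwrite" or label_id in duplicate_ids:
            return "index"
        return "create"

    assert label_op_type("skip", "label-1", {"label-1"}) == "index"   # duplicate id -> overwrite
    assert label_op_type("skip", "label-2", {"label-1"}) == "create"  # new id -> strict create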
@@ -608,7 +610,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         """
         index = index or self.label_index
         result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
-        labels = [Label.from_dict(hit["_source"]) for hit in result]
+        labels = [Label.from_dict({**hit["_source"], "id": hit["_id"]}) for hit in result]
         return labels
 
     def _get_all_documents_in_index(
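Because write_labels() pops "id" into the Elasticsearch "_id" field, the stored "_source" no longer carries it; merging hit["_id"] back in restores a Label whose id matches what was written, which is what makes the duplicate detection work on a second write. A tiny sketch of that merge over a fake search hit (the hit dict is fabricated for illustration):

    # A hit as returned by an Elasticsearch search, reduced to the relevant parts.
    hit = {"_id": "label-1", "_source": {"question": "Who wrote Faust?", "answer": "Goethe"}}

    # Old: Label.from_dict(hit["_source"]) dropped the id.
    # New: merge the Elasticsearch _id back into the payload before building the Label.
    payload = {**hit["_source"], "id": hit["_id"]}
    print(payload)
    # -> {'question': 'Who wrote Faust?', 'answer': 'Goethe', 'id': 'label-1'}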
@@ -122,14 +122,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
         index = index or self.label_index
         label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
 
+        duplicate_ids: list = [label.id for label in self._get_duplicate_labels(label_objects, index=index)]
+        if len(duplicate_ids) > 0:
+            logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
+                           f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
+                           f" the answer annotation and not the question."
+                           f" Problematic ids: {','.join(duplicate_ids)}")
+
         for label in label_objects:
-            label_id = str(uuid4())
             # create timestamps if not available yet
             if not label.created_at:
                 label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
             if not label.updated_at:
                 label.updated_at = label.created_at
-            self.indexes[index][label_id] = label
+            self.indexes[index][label.id] = label
 
     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         """Fetch a document by specifying its text id string"""
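Keying the in-memory index by label.id rather than a fresh uuid4() is what makes a second write of the same label replace the first entry instead of accumulating copies. A sketch of the difference using a plain dict (the names are stand-ins, not haystack code):

    from uuid import uuid4

    index_by_uuid, index_by_id = {}, {}

    for _ in range(2):                       # write the "same" label twice
        label = {"id": "label-1", "answer": "Goethe"}
        index_by_uuid[str(uuid4())] = label  # old behaviour: two entries pile up
        index_by_id[label["id"]] = label     # new behaviour: second write overwrites

    print(len(index_by_uuid), len(index_by_id))  # -> 2 1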
@@ -1,5 +1,6 @@
 import itertools
 import logging
+import collections
 from typing import Any, Dict, Union, List, Optional, Generator
 from uuid import uuid4
 
@@ -326,6 +327,13 @@ class SQLDocumentStore(BaseDocumentStore):
 
         labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
         index = index or self.label_index
+
+        duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
+        if len(duplicate_ids) > 0:
+            logger.warning(f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
+                           f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
+                           f" the answer annotation and not the question."
+                           f" Problematic ids: {','.join(duplicate_ids)}")
         # TODO: Use batch_size
         for label in labels:
             label_orm = LabelORM(
@@ -341,6 +349,9 @@ class SQLDocumentStore(BaseDocumentStore):
                 model_id=label.model_id,
                 index=index,
             )
-            self.session.add(label_orm)
+            if label.id in duplicate_ids:
+                self.session.merge(label_orm)
+            else:
+                self.session.add(label_orm)
             self.session.commit()
 
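On the SQL side the overwrite is delegated to SQLAlchemy: session.merge() reconciles the incoming object with an existing row that has the same primary key (an insert-or-update), whereas session.add()-ing a second object with that key would fail with an integrity error at flush time. A minimal, generic SQLAlchemy sketch of the difference (table and field names are invented for the example, not haystack's LabelORM):

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import declarative_base, sessionmaker

    Base = declarative_base()

    class Note(Base):
        __tablename__ = "note"
        id = Column(String, primary_key=True)
        text = Column(String)

    engine = create_engine("sqlite://")          # in-memory database
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    session.add(Note(id="label-1", text="first version"))
    session.commit()

    # merge() looks up the row by primary key and updates it in place...
    session.merge(Note(id="label-1", text="second version"))
    session.commit()
    print(session.get(Note, "label-1").text)     # -> second version

    # ...while add()-ing another object with the same primary key would raise
    # an IntegrityError on commit, which is why duplicates are routed to merge().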
@@ -432,7 +443,8 @@ class SQLDocumentStore(BaseDocumentStore):
             offset_start_in_doc=row.offset_start_in_doc,
             model_id=row.model_id,
             created_at=row.created_at,
-            updated_at=row.updated_at
+            updated_at=row.updated_at,
+            id=row.id
         )
         return label
 