mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
Fix to_dict and from_dict of Multilabel such that to_dict outputs a json serializable object (using Label.to_dict()) (#5257)
This commit is contained in:
parent
195077eca9
commit
87281b2e10
@ -633,7 +633,7 @@ def is_positive_label(label):
|
||||
|
||||
|
||||
class MultiLabel:
|
||||
def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answers=False):
|
||||
def __init__(self, labels: List[Label], drop_negative_labels: bool = False, drop_no_answers: bool = False):
|
||||
"""
|
||||
There are often multiple `Labels` associated with a single query. For example, there can be multiple annotated
|
||||
answers for one question or multiple documents contain the information you want for a query.
|
||||
@ -753,12 +753,17 @@ class MultiLabel:
|
||||
|
||||
def to_dict(self):
|
||||
# convert internal attribute names to property names
|
||||
return {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()}
|
||||
result = {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()}
|
||||
# convert Label object to dict
|
||||
result["labels"] = [label.to_dict() for label in result["labels"]]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, dict: Dict):
|
||||
# exclude extra arguments
|
||||
return cls(**{k: v for k, v in dict.items() if k in inspect.signature(cls).parameters})
|
||||
inputs = {k: v for k, v in dict.items() if k in inspect.signature(cls).parameters}
|
||||
inputs["labels"] = [Label.from_dict(label) for label in inputs["labels"]]
|
||||
return cls(**inputs)
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_dict(), default=pydantic_encoder)
|
||||
@ -769,7 +774,6 @@ class MultiLabel:
|
||||
dict_data = json.loads(data)
|
||||
else:
|
||||
dict_data = data
|
||||
dict_data["labels"] = [Label.from_dict(l) for l in dict_data["labels"]]
|
||||
return cls.from_dict(dict_data)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
||||
@ -932,6 +932,52 @@ def test_multilabel_serialization():
|
||||
assert json_deserialized_multilabel.labels[0] == label
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_table_multilabel_serialization():
|
||||
tabel_label_dict = {
|
||||
"id": "011079cf-c93f-49e6-83bb-42cd850dce12",
|
||||
"query": "What is the first number?",
|
||||
"document": {
|
||||
"content": [["col1", "col2"], [1, 3], [2, 4]],
|
||||
"content_type": "table",
|
||||
"id": "table1",
|
||||
"meta": {},
|
||||
"score": None,
|
||||
"embedding": None,
|
||||
},
|
||||
"is_correct_answer": True,
|
||||
"is_correct_document": True,
|
||||
"origin": "user-feedback",
|
||||
"answer": {
|
||||
"answer": "1",
|
||||
"type": "extractive",
|
||||
"score": None,
|
||||
"context": [["col1", "col2"], [1, 3], [2, 4]],
|
||||
"offsets_in_document": [{"row": 0, "col": 0}],
|
||||
"offsets_in_context": [{"row": 0, "col": 0}],
|
||||
"document_ids": ["table1"],
|
||||
"meta": {},
|
||||
},
|
||||
"no_answer": False,
|
||||
"pipeline_id": None,
|
||||
"created_at": "2022-07-22T13:29:33.699781+00:00",
|
||||
"updated_at": "2022-07-22T13:29:33.784895+00:00",
|
||||
"meta": {"answer_id": "374394", "document_id": "604995", "question_id": "345530"},
|
||||
"filters": None,
|
||||
}
|
||||
|
||||
label = Label.from_dict(tabel_label_dict)
|
||||
original_multilabel = MultiLabel([label])
|
||||
|
||||
deserialized_multilabel = MultiLabel.from_dict(original_multilabel.to_dict())
|
||||
assert deserialized_multilabel == original_multilabel
|
||||
assert deserialized_multilabel.labels[0] == label
|
||||
|
||||
json_deserialized_multilabel = MultiLabel.from_json(original_multilabel.to_json())
|
||||
assert json_deserialized_multilabel == original_multilabel
|
||||
assert json_deserialized_multilabel.labels[0] == label
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_span_in():
|
||||
assert 10 in Span(5, 15)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user