mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-13 00:23:46 +00:00
Fix to_dict and from_dict of MultiLabel so that to_dict outputs a JSON-serializable object (using Label.to_dict()) (#5257)
This commit is contained in:
parent
195077eca9
commit
87281b2e10
@ -633,7 +633,7 @@ def is_positive_label(label):
|
|||||||
|
|
||||||
|
|
||||||
class MultiLabel:
|
class MultiLabel:
|
||||||
def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answers=False):
|
def __init__(self, labels: List[Label], drop_negative_labels: bool = False, drop_no_answers: bool = False):
|
||||||
"""
|
"""
|
||||||
There are often multiple `Labels` associated with a single query. For example, there can be multiple annotated
|
There are often multiple `Labels` associated with a single query. For example, there can be multiple annotated
|
||||||
answers for one question or multiple documents contain the information you want for a query.
|
answers for one question or multiple documents contain the information you want for a query.
|
||||||
@ -753,12 +753,17 @@ class MultiLabel:
|
|||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
# convert internal attribute names to property names
|
# convert internal attribute names to property names
|
||||||
return {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()}
|
result = {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()}
|
||||||
|
# convert Label object to dict
|
||||||
|
result["labels"] = [label.to_dict() for label in result["labels"]]
|
||||||
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, dict: Dict):
|
def from_dict(cls, dict: Dict):
|
||||||
# exclude extra arguments
|
# exclude extra arguments
|
||||||
return cls(**{k: v for k, v in dict.items() if k in inspect.signature(cls).parameters})
|
inputs = {k: v for k, v in dict.items() if k in inspect.signature(cls).parameters}
|
||||||
|
inputs["labels"] = [Label.from_dict(label) for label in inputs["labels"]]
|
||||||
|
return cls(**inputs)
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
return json.dumps(self.to_dict(), default=pydantic_encoder)
|
return json.dumps(self.to_dict(), default=pydantic_encoder)
|
||||||
@ -769,7 +774,6 @@ class MultiLabel:
|
|||||||
dict_data = json.loads(data)
|
dict_data = json.loads(data)
|
||||||
else:
|
else:
|
||||||
dict_data = data
|
dict_data = data
|
||||||
dict_data["labels"] = [Label.from_dict(l) for l in dict_data["labels"]]
|
|
||||||
return cls.from_dict(dict_data)
|
return cls.from_dict(dict_data)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
|||||||
@ -932,6 +932,52 @@ def test_multilabel_serialization():
|
|||||||
assert json_deserialized_multilabel.labels[0] == label
|
assert json_deserialized_multilabel.labels[0] == label
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_table_multilabel_serialization():
|
||||||
|
tabel_label_dict = {
|
||||||
|
"id": "011079cf-c93f-49e6-83bb-42cd850dce12",
|
||||||
|
"query": "What is the first number?",
|
||||||
|
"document": {
|
||||||
|
"content": [["col1", "col2"], [1, 3], [2, 4]],
|
||||||
|
"content_type": "table",
|
||||||
|
"id": "table1",
|
||||||
|
"meta": {},
|
||||||
|
"score": None,
|
||||||
|
"embedding": None,
|
||||||
|
},
|
||||||
|
"is_correct_answer": True,
|
||||||
|
"is_correct_document": True,
|
||||||
|
"origin": "user-feedback",
|
||||||
|
"answer": {
|
||||||
|
"answer": "1",
|
||||||
|
"type": "extractive",
|
||||||
|
"score": None,
|
||||||
|
"context": [["col1", "col2"], [1, 3], [2, 4]],
|
||||||
|
"offsets_in_document": [{"row": 0, "col": 0}],
|
||||||
|
"offsets_in_context": [{"row": 0, "col": 0}],
|
||||||
|
"document_ids": ["table1"],
|
||||||
|
"meta": {},
|
||||||
|
},
|
||||||
|
"no_answer": False,
|
||||||
|
"pipeline_id": None,
|
||||||
|
"created_at": "2022-07-22T13:29:33.699781+00:00",
|
||||||
|
"updated_at": "2022-07-22T13:29:33.784895+00:00",
|
||||||
|
"meta": {"answer_id": "374394", "document_id": "604995", "question_id": "345530"},
|
||||||
|
"filters": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
label = Label.from_dict(tabel_label_dict)
|
||||||
|
original_multilabel = MultiLabel([label])
|
||||||
|
|
||||||
|
deserialized_multilabel = MultiLabel.from_dict(original_multilabel.to_dict())
|
||||||
|
assert deserialized_multilabel == original_multilabel
|
||||||
|
assert deserialized_multilabel.labels[0] == label
|
||||||
|
|
||||||
|
json_deserialized_multilabel = MultiLabel.from_json(original_multilabel.to_json())
|
||||||
|
assert json_deserialized_multilabel == original_multilabel
|
||||||
|
assert json_deserialized_multilabel.labels[0] == label
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_span_in():
|
def test_span_in():
|
||||||
assert 10 in Span(5, 15)
|
assert 10 in Span(5, 15)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user