Fix to_dict and from_dict of MultiLabel so that to_dict outputs a JSON-serializable object (using Label.to_dict()) (#5257)

This commit is contained in:
Sebastian Husch Lee 2023-07-04 12:44:11 +02:00 committed by GitHub
parent 195077eca9
commit 87281b2e10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 4 deletions

View File

@ -633,7 +633,7 @@ def is_positive_label(label):
class MultiLabel:
def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answers=False):
def __init__(self, labels: List[Label], drop_negative_labels: bool = False, drop_no_answers: bool = False):
"""
There are often multiple `Labels` associated with a single query. For example, there can be multiple annotated
answers for one question or multiple documents contain the information you want for a query.
@ -753,12 +753,17 @@ class MultiLabel:
def to_dict(self):
# convert internal attribute names to property names
return {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()}
result = {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()}
# convert Label object to dict
result["labels"] = [label.to_dict() for label in result["labels"]]
return result
@classmethod
def from_dict(cls, dict: Dict):
# exclude extra arguments
return cls(**{k: v for k, v in dict.items() if k in inspect.signature(cls).parameters})
inputs = {k: v for k, v in dict.items() if k in inspect.signature(cls).parameters}
inputs["labels"] = [Label.from_dict(label) for label in inputs["labels"]]
return cls(**inputs)
def to_json(self):
return json.dumps(self.to_dict(), default=pydantic_encoder)
@ -769,7 +774,6 @@ class MultiLabel:
dict_data = json.loads(data)
else:
dict_data = data
dict_data["labels"] = [Label.from_dict(l) for l in dict_data["labels"]]
return cls.from_dict(dict_data)
def __eq__(self, other):

View File

@ -932,6 +932,52 @@ def test_multilabel_serialization():
assert json_deserialized_multilabel.labels[0] == label
@pytest.mark.unit
def test_table_multilabel_serialization():
    """
    Round-trip a MultiLabel wrapping a single table Label through to_dict()/from_dict() and
    to_json()/from_json(), and check that the restored objects equal the originals.
    """
    # The nested payloads are spelled out first so the label dict below stays easy to scan.
    table_document = {
        "content": [["col1", "col2"], [1, 3], [2, 4]],
        "content_type": "table",
        "id": "table1",
        "meta": {},
        "score": None,
        "embedding": None,
    }
    table_answer = {
        "answer": "1",
        "type": "extractive",
        "score": None,
        "context": [["col1", "col2"], [1, 3], [2, 4]],
        "offsets_in_document": [{"row": 0, "col": 0}],
        "offsets_in_context": [{"row": 0, "col": 0}],
        "document_ids": ["table1"],
        "meta": {},
    }
    table_label_dict = {
        "id": "011079cf-c93f-49e6-83bb-42cd850dce12",
        "query": "What is the first number?",
        "document": table_document,
        "is_correct_answer": True,
        "is_correct_document": True,
        "origin": "user-feedback",
        "answer": table_answer,
        "no_answer": False,
        "pipeline_id": None,
        "created_at": "2022-07-22T13:29:33.699781+00:00",
        "updated_at": "2022-07-22T13:29:33.784895+00:00",
        "meta": {"answer_id": "374394", "document_id": "604995", "question_id": "345530"},
        "filters": None,
    }

    label = Label.from_dict(table_label_dict)
    original_multilabel = MultiLabel([label])

    # dict round trip
    as_dict = original_multilabel.to_dict()
    deserialized_multilabel = MultiLabel.from_dict(as_dict)
    assert deserialized_multilabel == original_multilabel
    assert deserialized_multilabel.labels[0] == label

    # JSON round trip
    as_json = original_multilabel.to_json()
    json_deserialized_multilabel = MultiLabel.from_json(as_json)
    assert json_deserialized_multilabel == original_multilabel
    assert json_deserialized_multilabel.labels[0] == label
@pytest.mark.unit
def test_span_in():
assert 10 in Span(5, 15)