diff --git a/haystack/schema.py b/haystack/schema.py index 6ada951c7..a6547e546 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -633,7 +633,7 @@ def is_positive_label(label): class MultiLabel: - def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answers=False): + def __init__(self, labels: List[Label], drop_negative_labels: bool = False, drop_no_answers: bool = False): """ There are often multiple `Labels` associated with a single query. For example, there can be multiple annotated answers for one question or multiple documents contain the information you want for a query. @@ -753,12 +753,17 @@ class MultiLabel: def to_dict(self): # convert internal attribute names to property names - return {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()} + result = {k[1:] if k[0] == "_" else k: v for k, v in vars(self).items()} + # convert Label object to dict + result["labels"] = [label.to_dict() for label in result["labels"]] + return result @classmethod def from_dict(cls, dict: Dict): # exclude extra arguments - return cls(**{k: v for k, v in dict.items() if k in inspect.signature(cls).parameters}) + inputs = {k: v for k, v in dict.items() if k in inspect.signature(cls).parameters} + inputs["labels"] = [Label.from_dict(label) for label in inputs["labels"]] + return cls(**inputs) def to_json(self): return json.dumps(self.to_dict(), default=pydantic_encoder) @@ -769,7 +774,6 @@ class MultiLabel: dict_data = json.loads(data) else: dict_data = data - dict_data["labels"] = [Label.from_dict(l) for l in dict_data["labels"]] return cls.from_dict(dict_data) def __eq__(self, other): diff --git a/test/others/test_schema.py b/test/others/test_schema.py index cde25cbf7..867aea1cd 100644 --- a/test/others/test_schema.py +++ b/test/others/test_schema.py @@ -932,6 +932,52 @@ def test_multilabel_serialization(): assert json_deserialized_multilabel.labels[0] == label +@pytest.mark.unit +def test_table_multilabel_serialization(): + tabel_label_dict = { + "id": "011079cf-c93f-49e6-83bb-42cd850dce12", + "query": "What is the first number?", + "document": { + "content": [["col1", "col2"], [1, 3], [2, 4]], + "content_type": "table", + "id": "table1", + "meta": {}, + "score": None, + "embedding": None, + }, + "is_correct_answer": True, + "is_correct_document": True, + "origin": "user-feedback", + "answer": { + "answer": "1", + "type": "extractive", + "score": None, + "context": [["col1", "col2"], [1, 3], [2, 4]], + "offsets_in_document": [{"row": 0, "col": 0}], + "offsets_in_context": [{"row": 0, "col": 0}], + "document_ids": ["table1"], + "meta": {}, + }, + "no_answer": False, + "pipeline_id": None, + "created_at": "2022-07-22T13:29:33.699781+00:00", + "updated_at": "2022-07-22T13:29:33.784895+00:00", + "meta": {"answer_id": "374394", "document_id": "604995", "question_id": "345530"}, + "filters": None, + } + + label = Label.from_dict(tabel_label_dict) + original_multilabel = MultiLabel([label]) + + deserialized_multilabel = MultiLabel.from_dict(original_multilabel.to_dict()) + assert deserialized_multilabel == original_multilabel + assert deserialized_multilabel.labels[0] == label + + json_deserialized_multilabel = MultiLabel.from_json(original_multilabel.to_json()) + assert json_deserialized_multilabel == original_multilabel + assert json_deserialized_multilabel.labels[0] == label + + @pytest.mark.unit def test_span_in(): assert 10 in Span(5, 15)