diff --git a/haystack/preview/dataclasses/document.py b/haystack/preview/dataclasses/document.py index 644d12cfe..98db4d944 100644 --- a/haystack/preview/dataclasses/document.py +++ b/haystack/preview/dataclasses/document.py @@ -2,7 +2,7 @@ import io import hashlib import logging from dataclasses import asdict, dataclass, field, fields -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Dict, List, Optional import numpy import pandas @@ -40,18 +40,6 @@ class _BackwardCompatible(type): if "id_hash_keys" in kwargs: del kwargs["id_hash_keys"] - if kwargs.get("meta") is None: - # This must be a flattened Document, so we treat all keys that are not - # Document fields as metadata. - meta = {} - field_names = [f.name for f in fields(cast(Type[Document], cls))] - keys = list(kwargs.keys()) # get a list of the keys as we'll modify the dict in the loop - for key in keys: - if key in field_names: - continue - meta[key] = kwargs.pop(key) - kwargs["meta"] = meta - return super().__call__(*args, **kwargs) @@ -149,7 +137,15 @@ class Document(metaclass=_BackwardCompatible): data["dataframe"] = pandas.read_json(io.StringIO(dataframe)) if blob := data.get("blob"): data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"]) - return cls(**data) + # Unflatten metadata if it was flattened + meta = {} + legacy_fields = ["content_type", "id_hash_keys"] + field_names = legacy_fields + [f.name for f in fields(cls)] + for key in list(data.keys()): + if key not in field_names: + meta[key] = data.pop(key) + + return cls(**data, meta=meta) @property def content_type(self): diff --git a/releasenotes/notes/fix-document-init-09c1cbb14202be7d.yaml b/releasenotes/notes/fix-document-init-09c1cbb14202be7d.yaml new file mode 100644 index 000000000..f1fc6a6cb --- /dev/null +++ b/releasenotes/notes/fix-document-init-09c1cbb14202be7d.yaml @@ -0,0 +1,4 @@ +--- +preview: + - | + Make Document's constructor fail when it is passed fields that are not present 
in the dataclass. An exception is made for "content_type" and "id_hash_keys": they are accepted in order to keep backward compatibility. diff --git a/test/preview/dataclasses/test_document.py b/test/preview/dataclasses/test_document.py index 4305cdfbb..49c61541b 100644 --- a/test/preview/dataclasses/test_document.py +++ b/test/preview/dataclasses/test_document.py @@ -43,6 +43,12 @@ def test_init(): assert doc.embedding == None +@pytest.mark.unit +def test_init_with_wrong_parameters(): + with pytest.raises(TypeError): + Document(text="") + + @pytest.mark.unit def test_init_with_parameters(): blob_data = b"some bytes" @@ -80,15 +86,14 @@ def test_init_with_legacy_fields(): @pytest.mark.unit -def test_init_with_legacy_field_and_flat_meta(): +def test_init_with_legacy_field(): doc = Document( content="test text", content_type="text", # type: ignore id_hash_keys=["content"], # type: ignore score=0.812, embedding=[0.1, 0.2, 0.3], - date="10-10-2023", # type: ignore - type="article", # type: ignore + meta={"date": "10-10-2023", "type": "article"}, ) assert doc.id == "a2c0321b34430cc675294611e55529fceb56140ca3202f1c59a43a8cecac1f43" assert doc.content == "test text" @@ -98,44 +103,6 @@ def test_init_with_legacy_field_and_flat_meta(): assert doc.embedding == [0.1, 0.2, 0.3] -@pytest.mark.unit -def test_init_with_flat_meta(): - blob_data = b"some bytes" - doc = Document( - content="test text", - dataframe=pd.DataFrame([0]), - blob=ByteStream(data=blob_data, mime_type="text/markdown"), - score=0.812, - embedding=[0.1, 0.2, 0.3], - date="10-10-2023", # type: ignore - type="article", # type: ignore - ) - assert doc.id == "c6212ad7bb513c572367e11dd12fd671911a1a5499e3d31e4fe3bda7e87c0641" - assert doc.content == "test text" - assert doc.dataframe is not None - assert doc.dataframe.equals(pd.DataFrame([0])) - assert doc.blob.data == blob_data - assert doc.blob.mime_type == "text/markdown" - assert doc.meta == {"date": "10-10-2023", "type": "article"} - assert doc.score == 0.812 - 
assert doc.embedding == [0.1, 0.2, 0.3] - - -@pytest.mark.unit -def test_init_with_flat_and_non_flat_meta(): - with pytest.raises(TypeError): - Document( - content="test text", - dataframe=pd.DataFrame([0]), - blob=ByteStream(data=b"some bytes", mime_type="text/markdown"), - score=0.812, - meta={"test": 10}, - embedding=[0.1, 0.2, 0.3], - date="10-10-2023", # type: ignore - type="article", # type: ignore - ) - - @pytest.mark.unit def test_basic_equality_type_mismatch(): doc = Document(content="test text") @@ -286,8 +253,7 @@ def test_from_dict_with_legacy_field_and_flat_meta(): id_hash_keys=["content"], # type: ignore score=0.812, embedding=[0.1, 0.2, 0.3], - date="10-10-2023", # type: ignore - type="article", # type: ignore + meta={"date": "10-10-2023", "type": "article"}, )