Fix Document init when passing non existing fields (#6286)

* Fix Document init when passing non existing fields

* Update releasenotes/notes/fix-document-init-09c1cbb14202be7d.yaml

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

* Fix linting

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Silvano Cerza 2023-11-13 11:42:42 +01:00 committed by GitHub
parent bf637e9c7e
commit 8e7ce208fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 57 deletions

View File

@ -2,7 +2,7 @@ import io
import hashlib
import logging
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Dict, List, Optional
import numpy
import pandas
@ -40,18 +40,6 @@ class _BackwardCompatible(type):
if "id_hash_keys" in kwargs:
del kwargs["id_hash_keys"]
if kwargs.get("meta") is None:
# This must be a flattened Document, so we treat all keys that are not
# Document fields as metadata.
meta = {}
field_names = [f.name for f in fields(cast(Type[Document], cls))]
keys = list(kwargs.keys()) # get a list of the keys as we'll modify the dict in the loop
for key in keys:
if key in field_names:
continue
meta[key] = kwargs.pop(key)
kwargs["meta"] = meta
return super().__call__(*args, **kwargs)
@ -149,7 +137,15 @@ class Document(metaclass=_BackwardCompatible):
data["dataframe"] = pandas.read_json(io.StringIO(dataframe))
if blob := data.get("blob"):
data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
return cls(**data)
# Unflatten metadata if it was flattened
meta = {}
legacy_fields = ["content_type", "id_hash_keys"]
field_names = legacy_fields + [f.name for f in fields(cls)]
for key in list(data.keys()):
if key not in field_names:
meta[key] = data.pop(key)
return cls(**data, meta=meta)
@property
def content_type(self):

View File

@ -0,0 +1,4 @@
---
preview:
- |
Make Document's constructor fail when is passed fields that are not present in the dataclass. An exception is made for "content_type" and "id_hash_keys": they are accepted in order to keep backward compatibility.

View File

@ -43,6 +43,12 @@ def test_init():
assert doc.embedding == None
@pytest.mark.unit
def test_init_with_wrong_parameters():
with pytest.raises(TypeError):
Document(text="")
@pytest.mark.unit
def test_init_with_parameters():
blob_data = b"some bytes"
@ -80,15 +86,14 @@ def test_init_with_legacy_fields():
@pytest.mark.unit
def test_init_with_legacy_field_and_flat_meta():
def test_init_with_legacy_field():
doc = Document(
content="test text",
content_type="text", # type: ignore
id_hash_keys=["content"], # type: ignore
score=0.812,
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
meta={"date": "10-10-2023", "type": "article"},
)
assert doc.id == "a2c0321b34430cc675294611e55529fceb56140ca3202f1c59a43a8cecac1f43"
assert doc.content == "test text"
@ -98,44 +103,6 @@ def test_init_with_legacy_field_and_flat_meta():
assert doc.embedding == [0.1, 0.2, 0.3]
@pytest.mark.unit
def test_init_with_flat_meta():
blob_data = b"some bytes"
doc = Document(
content="test text",
dataframe=pd.DataFrame([0]),
blob=ByteStream(data=blob_data, mime_type="text/markdown"),
score=0.812,
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
)
assert doc.id == "c6212ad7bb513c572367e11dd12fd671911a1a5499e3d31e4fe3bda7e87c0641"
assert doc.content == "test text"
assert doc.dataframe is not None
assert doc.dataframe.equals(pd.DataFrame([0]))
assert doc.blob.data == blob_data
assert doc.blob.mime_type == "text/markdown"
assert doc.meta == {"date": "10-10-2023", "type": "article"}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
@pytest.mark.unit
def test_init_with_flat_and_non_flat_meta():
with pytest.raises(TypeError):
Document(
content="test text",
dataframe=pd.DataFrame([0]),
blob=ByteStream(data=b"some bytes", mime_type="text/markdown"),
score=0.812,
meta={"test": 10},
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
)
@pytest.mark.unit
def test_basic_equality_type_mismatch():
doc = Document(content="test text")
@ -286,8 +253,7 @@ def test_from_dict_with_legacy_field_and_flat_meta():
id_hash_keys=["content"], # type: ignore
score=0.812,
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
meta={"date": "10-10-2023", "type": "article"},
)