diff --git a/haystack/dataclasses/document.py b/haystack/dataclasses/document.py index 397520e07..aed359641 100644 --- a/haystack/dataclasses/document.py +++ b/haystack/dataclasses/document.py @@ -143,9 +143,6 @@ class Document(metaclass=_BackwardCompatible): if (blob := data.get("blob")) is not None: data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]} - if (sparse_embedding := data.get("sparse_embedding")) is not None: - data["sparse_embedding"] = sparse_embedding.to_dict() - if flatten: meta = data.pop("meta") return {**data, **meta} diff --git a/haystack/dataclasses/sparse_embedding.py b/haystack/dataclasses/sparse_embedding.py index 262de439d..3f3c18c80 100644 --- a/haystack/dataclasses/sparse_embedding.py +++ b/haystack/dataclasses/sparse_embedding.py @@ -2,42 +2,42 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import List +from dataclasses import asdict, dataclass +from typing import Any, Dict, List +@dataclass class SparseEmbedding: """ Class representing a sparse embedding. + + :param indices: List of indices of non-zero elements in the embedding. + :param values: List of values of non-zero elements in the embedding. """ - def __init__(self, indices: List[int], values: List[float]): - """ - Initialize a SparseEmbedding object. + indices: List[int] + values: List[float] - :param indices: List of indices of non-zero elements in the embedding. - :param values: List of values of non-zero elements in the embedding. - - :raises ValueError: If the indices and values lists are not of the same length. + def __post_init__(self): """ - if len(indices) != len(values): + Checks if the indices and values lists are of the same length. + + Raises a ValueError if they are not. + """ + if len(self.indices) != len(self.values): raise ValueError("Length of indices and values must be the same.") - self.indices = indices - self.values = values - def __eq__(self, other): - return self.indices == other.indices and self.values == other.values - - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Convert the SparseEmbedding object to a dictionary. :returns: Serialized sparse embedding. """ - return {"indices": self.indices, "values": self.values} + return asdict(self) @classmethod - def from_dict(cls, sparse_embedding_dict): + def from_dict(cls, sparse_embedding_dict: Dict[str, Any]) -> "SparseEmbedding": """ Deserializes the sparse embedding from a dictionary. @@ -46,4 +46,4 @@ class SparseEmbedding: :returns: Deserialized sparse embedding. """ - return cls(indices=sparse_embedding_dict["indices"], values=sparse_embedding_dict["values"]) + return cls(**sparse_embedding_dict) diff --git a/releasenotes/notes/sparse-embedding-dataclass-d75ae1ee6d75e646.yaml b/releasenotes/notes/sparse-embedding-dataclass-d75ae1ee6d75e646.yaml new file mode 100644 index 000000000..bd20c4389 --- /dev/null +++ b/releasenotes/notes/sparse-embedding-dataclass-d75ae1ee6d75e646.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Make SparseEmbedding a dataclass, this makes it easier to use the class with Pydantic