Make SparseEmbedding a dataclass (#7678)

This commit is contained in:
Silvano Cerza 2024-05-09 17:11:43 +02:00 committed by GitHub
parent 10c675d534
commit 0e1a5a65e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 21 deletions

View File

@ -143,9 +143,6 @@ class Document(metaclass=_BackwardCompatible):
if (blob := data.get("blob")) is not None:
data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}
if (sparse_embedding := data.get("sparse_embedding")) is not None:
data["sparse_embedding"] = sparse_embedding.to_dict()
if flatten:
meta = data.pop("meta")
return {**data, **meta}

View File

@ -2,42 +2,42 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import List
from dataclasses import asdict, dataclass
from typing import Any, Dict, List
@dataclass
class SparseEmbedding:
"""
Class representing a sparse embedding.
:param indices: List of indices of non-zero elements in the embedding.
:param values: List of values of non-zero elements in the embedding.
"""
def __init__(self, indices: List[int], values: List[float]):
"""
Initialize a SparseEmbedding object.
indices: List[int]
values: List[float]
:param indices: List of indices of non-zero elements in the embedding.
:param values: List of values of non-zero elements in the embedding.
:raises ValueError: If the indices and values lists are not of the same length.
def __post_init__(self):
"""
if len(indices) != len(values):
Checks if the indices and values lists are of the same length.
Raises a ValueError if they are not.
"""
if len(self.indices) != len(self.values):
raise ValueError("Length of indices and values must be the same.")
self.indices = indices
self.values = values
def __eq__(self, other):
return self.indices == other.indices and self.values == other.values
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
"""
Convert the SparseEmbedding object to a dictionary.
:returns:
Serialized sparse embedding.
"""
return {"indices": self.indices, "values": self.values}
return asdict(self)
@classmethod
def from_dict(cls, sparse_embedding_dict):
def from_dict(cls, sparse_embedding_dict: Dict[str, Any]) -> "SparseEmbedding":
"""
Deserializes the sparse embedding from a dictionary.
@ -46,4 +46,4 @@ class SparseEmbedding:
:returns:
Deserialized sparse embedding.
"""
return cls(indices=sparse_embedding_dict["indices"], values=sparse_embedding_dict["values"])
return cls(**sparse_embedding_dict)

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Make SparseEmbedding a dataclass. This makes it easier to use the class with Pydantic.