mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 21:48:52 +00:00
feat: introduce SparseEmbedding (#7382)
* introduce SparseEmbedding * reno * add to pydoc config
This commit is contained in:
parent
610ad6f6b2
commit
dbfd351da7
@ -2,7 +2,7 @@ loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/dataclasses]
|
||||
modules:
|
||||
["answer", "byte_stream", "chat_message", "document", "streaming_chunk"]
|
||||
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
|
||||
@ -2,6 +2,7 @@ from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
|
||||
from haystack.dataclasses.document import Document
|
||||
from haystack.dataclasses.sparse_embedding import SparseEmbedding
|
||||
from haystack.dataclasses.streaming_chunk import StreamingChunk
|
||||
|
||||
__all__ = [
|
||||
@ -13,4 +14,5 @@ __all__ = [
|
||||
"ChatMessage",
|
||||
"ChatRole",
|
||||
"StreamingChunk",
|
||||
"SparseEmbedding",
|
||||
]
|
||||
|
||||
@ -8,6 +8,7 @@ from pandas import DataFrame, read_json
|
||||
|
||||
from haystack import logging
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from haystack.dataclasses.sparse_embedding import SparseEmbedding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,7 +58,8 @@ class Document(metaclass=_BackwardCompatible):
|
||||
:param blob: Binary data associated with the document, if the document has any binary data associated with it.
|
||||
:param meta: Additional custom metadata for the document. Must be JSON-serializable.
|
||||
:param score: Score of the document. Used for ranking, usually assigned by retrievers.
|
||||
:param embedding: Vector representation of the document.
|
||||
:param embedding: dense vector representation of the document.
|
||||
:param sparse_embedding: sparse vector representation of the document.
|
||||
"""
|
||||
|
||||
id: str = field(default="")
|
||||
@ -67,6 +69,7 @@ class Document(metaclass=_BackwardCompatible):
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
score: Optional[float] = field(default=None)
|
||||
embedding: Optional[List[float]] = field(default=None)
|
||||
sparse_embedding: Optional[SparseEmbedding] = field(default=None)
|
||||
|
||||
def __repr__(self):
|
||||
fields = []
|
||||
@ -84,6 +87,8 @@ class Document(metaclass=_BackwardCompatible):
|
||||
fields.append(f"score: {self.score}")
|
||||
if self.embedding is not None:
|
||||
fields.append(f"embedding: vector of size {len(self.embedding)}")
|
||||
if self.sparse_embedding is not None:
|
||||
fields.append(f"sparse_embedding: vector with {len(self.sparse_embedding.indices)} non-zero elements")
|
||||
fields_str = ", ".join(fields)
|
||||
return f"{self.__class__.__name__}(id={self.id}, {fields_str})"
|
||||
|
||||
@ -114,7 +119,8 @@ class Document(metaclass=_BackwardCompatible):
|
||||
mime_type = self.blob.mime_type if self.blob is not None else None
|
||||
meta = self.meta or {}
|
||||
embedding = self.embedding if self.embedding is not None else None
|
||||
data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}"
|
||||
sparse_embedding = self.sparse_embedding.to_dict() if self.sparse_embedding is not None else ""
|
||||
data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}{sparse_embedding}"
|
||||
return hashlib.sha256(data.encode("utf-8")).hexdigest()
|
||||
|
||||
def to_dict(self, flatten=True) -> Dict[str, Any]:
|
||||
@ -132,6 +138,9 @@ class Document(metaclass=_BackwardCompatible):
|
||||
if (blob := data.get("blob")) is not None:
|
||||
data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}
|
||||
|
||||
if (sparse_embedding := data.get("sparse_embedding")) is not None:
|
||||
data["sparse_embedding"] = sparse_embedding.to_dict()
|
||||
|
||||
if flatten:
|
||||
meta = data.pop("meta")
|
||||
return {**data, **meta}
|
||||
@ -149,6 +158,9 @@ class Document(metaclass=_BackwardCompatible):
|
||||
data["dataframe"] = read_json(io.StringIO(dataframe))
|
||||
if blob := data.get("blob"):
|
||||
data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
|
||||
if sparse_embedding := data.get("sparse_embedding"):
|
||||
data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)
|
||||
|
||||
# Store metadata for a moment while we try un-flattening allegedly flatten metadata.
|
||||
# We don't expect both a `meta=` keyword and flatten metadata keys so we'll raise a
|
||||
# ValueError later if this is the case.
|
||||
|
||||
26
haystack/dataclasses/sparse_embedding.py
Normal file
26
haystack/dataclasses/sparse_embedding.py
Normal file
@ -0,0 +1,26 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
class SparseEmbedding:
    """
    Class representing a sparse embedding.

    A sparse embedding stores only the non-zero elements of a vector as two
    parallel lists: the positions (`indices`) and the corresponding `values`.
    """

    def __init__(self, indices: List[int], values: List[float]):
        """
        :param indices: List of indices of non-zero elements in the embedding.
        :param values: List of values of non-zero elements in the embedding.

        :raises ValueError: If the indices and values lists are not of the same length.
        """
        # The two lists are parallel: indices[i] is the position of values[i].
        if len(indices) != len(values):
            raise ValueError("Length of indices and values must be the same.")
        self.indices = indices
        self.values = values

    def to_dict(self) -> dict:
        """
        Serialize the sparse embedding to a JSON-serializable dictionary.

        :returns: A dictionary with the keys `indices` and `values`.
        """
        return {"indices": self.indices, "values": self.values}

    @classmethod
    def from_dict(cls, sparse_embedding_dict: dict) -> "SparseEmbedding":
        """
        Deserialize a sparse embedding from a dictionary.

        :param sparse_embedding_dict: A dictionary with the keys `indices` and `values`,
            as produced by `to_dict`.
        :returns: A new `SparseEmbedding` instance.
        :raises KeyError: If either the `indices` or `values` key is missing.
        :raises ValueError: If the deserialized lists differ in length.
        """
        return cls(indices=sparse_embedding_dict["indices"], values=sparse_embedding_dict["values"])
|
||||
@ -0,0 +1,7 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Introduce a new `SparseEmbedding` class which can be used to store a sparse
|
||||
vector representation of a Document.
|
||||
It will be instrumental to support Sparse Embedding Retrieval with
|
||||
the subsequent introduction of Sparse Embedders and Sparse Embedding Retrievers.
|
||||
@ -3,6 +3,7 @@ import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from haystack.dataclasses.sparse_embedding import SparseEmbedding
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -37,6 +38,7 @@ def test_init():
|
||||
assert doc.meta == {}
|
||||
assert doc.score == None
|
||||
assert doc.embedding == None
|
||||
assert doc.sparse_embedding == None
|
||||
|
||||
|
||||
def test_init_with_wrong_parameters():
|
||||
@ -46,6 +48,7 @@ def test_init_with_wrong_parameters():
|
||||
|
||||
def test_init_with_parameters():
|
||||
blob_data = b"some bytes"
|
||||
sparse_embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
|
||||
doc = Document(
|
||||
content="test text",
|
||||
dataframe=pd.DataFrame([0]),
|
||||
@ -53,8 +56,9 @@ def test_init_with_parameters():
|
||||
meta={"text": "test text"},
|
||||
score=0.812,
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
sparse_embedding=sparse_embedding,
|
||||
)
|
||||
assert doc.id == "ec92455f3f4576d40031163c89b1b4210b34ea1426ee0ff68ebed86cb7ba13f8"
|
||||
assert doc.id == "967b7bd4a21861ad9e863f638cefcbdd6bf6306bebdd30aa3fedf0c26bc636ed"
|
||||
assert doc.content == "test text"
|
||||
assert doc.dataframe is not None
|
||||
assert doc.dataframe.equals(pd.DataFrame([0]))
|
||||
@ -63,6 +67,7 @@ def test_init_with_parameters():
|
||||
assert doc.meta == {"text": "test text"}
|
||||
assert doc.score == 0.812
|
||||
assert doc.embedding == [0.1, 0.2, 0.3]
|
||||
assert doc.sparse_embedding == sparse_embedding
|
||||
|
||||
|
||||
def test_init_with_legacy_fields():
|
||||
@ -76,6 +81,7 @@ def test_init_with_legacy_fields():
|
||||
assert doc.meta == {}
|
||||
assert doc.score == 0.812
|
||||
assert doc.embedding == [0.1, 0.2, 0.3]
|
||||
assert doc.sparse_embedding == None
|
||||
|
||||
|
||||
def test_init_with_legacy_field():
|
||||
@ -93,6 +99,7 @@ def test_init_with_legacy_field():
|
||||
assert doc.meta == {"date": "10-10-2023", "type": "article"}
|
||||
assert doc.score == 0.812
|
||||
assert doc.embedding == [0.1, 0.2, 0.3]
|
||||
assert doc.sparse_embedding == None
|
||||
|
||||
|
||||
def test_basic_equality_type_mismatch():
|
||||
@ -121,6 +128,7 @@ def test_to_dict():
|
||||
"blob": None,
|
||||
"score": None,
|
||||
"embedding": None,
|
||||
"sparse_embedding": None,
|
||||
}
|
||||
|
||||
|
||||
@ -134,6 +142,7 @@ def test_to_dict_without_flattening():
|
||||
"meta": {},
|
||||
"score": None,
|
||||
"embedding": None,
|
||||
"sparse_embedding": None,
|
||||
}
|
||||
|
||||
|
||||
@ -145,6 +154,7 @@ def test_to_dict_with_custom_parameters():
|
||||
meta={"some": "values", "test": 10},
|
||||
score=0.99,
|
||||
embedding=[10.0, 10.0],
|
||||
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
|
||||
)
|
||||
|
||||
assert doc.to_dict() == {
|
||||
@ -156,6 +166,7 @@ def test_to_dict_with_custom_parameters():
|
||||
"test": 10,
|
||||
"score": 0.99,
|
||||
"embedding": [10.0, 10.0],
|
||||
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
|
||||
}
|
||||
|
||||
|
||||
@ -167,6 +178,7 @@ def test_to_dict_with_custom_parameters_without_flattening():
|
||||
meta={"some": "values", "test": 10},
|
||||
score=0.99,
|
||||
embedding=[10.0, 10.0],
|
||||
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
|
||||
)
|
||||
|
||||
assert doc.to_dict(flatten=False) == {
|
||||
@ -177,6 +189,7 @@ def test_to_dict_with_custom_parameters_without_flattening():
|
||||
"meta": {"some": "values", "test": 10},
|
||||
"score": 0.99,
|
||||
"embedding": [10, 10],
|
||||
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
|
||||
}
|
||||
|
||||
|
||||
@ -194,6 +207,7 @@ def from_from_dict_with_parameters():
|
||||
"meta": {"text": "test text"},
|
||||
"score": 0.812,
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
|
||||
}
|
||||
) == Document(
|
||||
content="test text",
|
||||
@ -202,6 +216,7 @@ def from_from_dict_with_parameters():
|
||||
meta={"text": "test text"},
|
||||
score=0.812,
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
|
||||
)
|
||||
|
||||
|
||||
@ -249,6 +264,7 @@ def test_from_dict_with_flat_meta():
|
||||
"blob": {"data": list(blob_data), "mime_type": "text/markdown"},
|
||||
"score": 0.812,
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
|
||||
"date": "10-10-2023",
|
||||
"type": "article",
|
||||
}
|
||||
@ -258,6 +274,7 @@ def test_from_dict_with_flat_meta():
|
||||
blob=ByteStream(blob_data, mime_type="text/markdown"),
|
||||
score=0.812,
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
|
||||
meta={"date": "10-10-2023", "type": "article"},
|
||||
)
|
||||
|
||||
|
||||
23
test/dataclasses/test_sparse_embedding.py
Normal file
23
test/dataclasses/test_sparse_embedding.py
Normal file
@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from haystack.dataclasses.sparse_embedding import SparseEmbedding
|
||||
|
||||
|
||||
class TestSparseEmbedding:
    def test_init(self):
        # Constructor stores indices and values as given.
        embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
        assert embedding.indices == [0, 2, 4]
        assert embedding.values == [0.1, 0.2, 0.3]

    def test_init_with_wrong_parameters(self):
        # Mismatched list lengths must be rejected at construction time.
        with pytest.raises(ValueError):
            SparseEmbedding(indices=[0, 2], values=[0.1, 0.2, 0.3, 0.4])

    def test_to_dict(self):
        embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
        expected = {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}
        assert embedding.to_dict() == expected

    def test_from_dict(self):
        data = {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}
        embedding = SparseEmbedding.from_dict(data)
        assert embedding.indices == [0, 2, 4]
        assert embedding.values == [0.1, 0.2, 0.3]
|
||||
@ -25,15 +25,15 @@ class TestTypeCoercion:
|
||||
(NonSerializableClass(), "NonSerializableClass"),
|
||||
(
|
||||
Document(id="1", content="text"),
|
||||
'{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}',
|
||||
'{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}',
|
||||
),
|
||||
(
|
||||
[Document(id="1", content="text")],
|
||||
'[{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}]',
|
||||
'[{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}]',
|
||||
),
|
||||
(
|
||||
{"key": Document(id="1", content="text")},
|
||||
'{"key": {"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}}',
|
||||
'{"key": {"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}}',
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user