mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 12:07:04 +00:00
feat: Add meta_fields_to_embed to TransformersSimilarityRanker (#6564)
* Add initial implementation following SentenceTransformersDocumentEmbedder
* Add test for embedding metadata
* Add release notes
* Update name
* Fix tests and to dict
* Fix release notes
This commit is contained in:
parent
0ac1bdc6a0
commit
3e0e81b1e0
@ -40,6 +40,8 @@ class TransformersSimilarityRanker:
|
||||
device: str = "cpu",
|
||||
token: Union[bool, str, None] = None,
|
||||
top_k: int = 10,
|
||||
meta_fields_to_embed: Optional[List[str]] = None,
|
||||
embedding_separator: str = "\n",
|
||||
):
|
||||
"""
|
||||
Creates an instance of TransformersSimilarityRanker.
|
||||
@ -51,6 +53,8 @@ class TransformersSimilarityRanker:
|
||||
If this parameter is set to `True`, the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) is used.
|
||||
:param top_k: The maximum number of Documents to return per query.
|
||||
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
|
||||
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
|
||||
"""
|
||||
torch_and_transformers_import.check()
|
||||
|
||||
@ -62,6 +66,8 @@ class TransformersSimilarityRanker:
|
||||
self.token = token
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -89,6 +95,8 @@ class TransformersSimilarityRanker:
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
|
||||
top_k=self.top_k,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
)
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
@ -116,7 +124,14 @@ class TransformersSimilarityRanker:
|
||||
f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
|
||||
)
|
||||
|
||||
query_doc_pairs = [[query, doc.content] for doc in documents]
|
||||
query_doc_pairs = []
|
||||
for doc in documents:
|
||||
meta_values_to_embed = [
|
||||
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
|
||||
]
|
||||
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
|
||||
query_doc_pairs.append([query, text_to_embed])
|
||||
|
||||
features = self.tokenizer(
|
||||
query_doc_pairs, padding=True, truncation=True, return_tensors="pt"
|
||||
).to( # type: ignore
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Add `meta_fields_to_embed` and `embedding_separator` to TransformersSimilarityRanker, following the implementation in SentenceTransformersDocumentEmbedder, to be able to
|
||||
embed meta fields along with the content of the document.
|
||||
@ -1,4 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from haystack import Document, ComponentError
|
||||
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
|
||||
@ -15,6 +17,8 @@ class TestSimilarityRanker:
|
||||
"top_k": 10,
|
||||
"token": None,
|
||||
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
|
||||
@ -30,9 +34,37 @@ class TestSimilarityRanker:
|
||||
"model_name_or_path": "my_model",
|
||||
"token": None, # we don't serialize valid tokens,
|
||||
"top_k": 5,
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
|
||||
@patch("torch.sort")
|
||||
def test_embed_meta(self, mocked_sort):
|
||||
mocked_sort.return_value = (None, torch.tensor([0]))
|
||||
embedder = TransformersSimilarityRanker(
|
||||
model_name_or_path="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
|
||||
)
|
||||
embedder.model = MagicMock()
|
||||
embedder.tokenizer = MagicMock()
|
||||
|
||||
documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
|
||||
|
||||
embedder.run(query="test", documents=documents)
|
||||
|
||||
embedder.tokenizer.assert_called_once_with(
|
||||
[
|
||||
["test", "meta_value 0\ndocument number 0"],
|
||||
["test", "meta_value 1\ndocument number 1"],
|
||||
["test", "meta_value 2\ndocument number 2"],
|
||||
["test", "meta_value 3\ndocument number 3"],
|
||||
["test", "meta_value 4\ndocument number 4"],
|
||||
],
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
"query,docs_before_texts,expected_first_text",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user