feat: Add meta_fields_to_embed to TransformersSimilarityRanker (#6564)

* Add initial implementation following SentenceTransformersDocumentEmbedder

* Add test for embedding metadata

* Add release notes

* Update name

* Fix tests and to_dict

* Fix release notes
This commit is contained in:
Sebastian Husch Lee 2023-12-18 11:28:16 +01:00 committed by GitHub
parent 0ac1bdc6a0
commit 3e0e81b1e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 1 deletions

View File

@ -40,6 +40,8 @@ class TransformersSimilarityRanker:
device: str = "cpu",
token: Union[bool, str, None] = None,
top_k: int = 10,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Creates an instance of TransformersSimilarityRanker.
@ -51,6 +53,8 @@ class TransformersSimilarityRanker:
If this parameter is set to `True`, the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
:param top_k: The maximum number of Documents to return per query.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
torch_and_transformers_import.check()
@ -62,6 +66,8 @@ class TransformersSimilarityRanker:
self.token = token
self.model = None
self.tokenizer = None
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -89,6 +95,8 @@ class TransformersSimilarityRanker:
model_name_or_path=self.model_name_or_path,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
top_k=self.top_k,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
@component.output_types(documents=List[Document])
@ -116,7 +124,14 @@ class TransformersSimilarityRanker:
f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)
query_doc_pairs = [[query, doc.content] for doc in documents]
query_doc_pairs = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
]
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
query_doc_pairs.append([query, text_to_embed])
features = self.tokenizer(
query_doc_pairs, padding=True, truncation=True, return_tensors="pt"
).to( # type: ignore

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Add `meta_fields_to_embed`, following the implementation in `SentenceTransformersDocumentEmbedder`, so that
selected metadata fields can be embedded together with the document content.

View File

@ -1,4 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from haystack import Document, ComponentError
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
@ -15,6 +17,8 @@ class TestSimilarityRanker:
"top_k": 10,
"token": None,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
}
@ -30,9 +34,37 @@ class TestSimilarityRanker:
"model_name_or_path": "my_model",
"token": None, # we don't serialize valid tokens,
"top_k": 5,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
}
@patch("torch.sort")
def test_embed_meta(self, mocked_sort):
    """Configured meta fields must be joined to each document's content (via the separator) before tokenization."""
    # torch.sort is patched so run() can complete without a real model output.
    mocked_sort.return_value = (None, torch.tensor([0]))
    ranker = TransformersSimilarityRanker(
        model_name_or_path="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
    )
    # Stub out the heavy HF objects — we only inspect what the tokenizer receives.
    ranker.model = MagicMock()
    ranker.tokenizer = MagicMock()

    docs = [
        Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"})
        for i in range(5)
    ]
    ranker.run(query="test", documents=docs)

    # Each pair is [query, "<meta value>\n<content>"] — meta text prepended with the separator.
    expected_pairs = [["test", f"meta_value {i}\ndocument number {i}"] for i in range(5)]
    ranker.tokenizer.assert_called_once_with(
        expected_pairs,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
@pytest.mark.integration
@pytest.mark.parametrize(
"query,docs_before_texts,expected_first_text",