mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 10:19:23 +00:00
feat: Add embed_meta_fields to Ranker nodes (#5361)
* Adding embed_meta_fields to ranker nodes * Fix tests by adding case where embed_meta_fields=None * Adding unit test for _add_meta_fields_to_docs * Fix pylint * Add unit test * Added another unit test. Caught a bug. * Adding more unit tests * Add unit test * Updating some older tests into unit tests using mocking * Convert another test to unit test * Test run method * One last unit test
This commit is contained in:
parent
e0cf1421c6
commit
f7642e83ea
@ -4,6 +4,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from functools import wraps
|
||||
from time import perf_counter
|
||||
from copy import deepcopy
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.base import BaseComponent
|
||||
@ -32,6 +33,37 @@ class BaseRanker(BaseComponent):
|
||||
) -> Union[List[Document], List[List[Document]]]:
|
||||
pass
|
||||
|
||||
def _add_meta_fields_to_docs(
    self, documents: List[Document], embed_meta_fields: Optional[List[str]] = None
) -> List[Document]:
    """
    Prepend selected metadata values to each document's content.

    For every document, the values of the requested meta fields are collected in the order given by
    `embed_meta_fields`, converted to strings and joined with the content using newlines
    (metadata first, content last). List values contribute one line per item.

    Fields that are missing, `None`, or empty (empty string / empty container) are skipped.
    Falsy but meaningful values such as `0`, `0.0` or `False` are kept.

    :param documents: List of documents to add metadata to.
    :param embed_meta_fields: Concatenate the provided meta fields into the text passage that is then used in
        reranking.
    :return: List of documents with metadata. If `embed_meta_fields` is empty or None, the input list itself
        is returned. Documents that receive metadata are deep copies, so the inputs are never modified;
        documents without any applicable metadata are passed through as-is.
    """
    if not embed_meta_fields:
        return documents

    docs_with_meta = []
    for doc in documents:
        meta = doc.meta or {}

        # Gather all relevant metadata values, already converted to strings (e.g. for ints or floats)
        meta_values: List[str] = []
        for key in embed_meta_fields:
            value = meta.get(key)
            if value is None:
                continue
            # Only skip *empty* strings/containers; a plain truthiness test would also drop 0 and False.
            if isinstance(value, (str, list, tuple, dict, set)) and not value:
                continue
            if isinstance(value, list):
                meta_values.extend(str(item) for item in value)
            else:
                meta_values.append(str(value))

        if not meta_values:
            # Nothing to prepend: skip the (potentially expensive) deep copy.
            docs_with_meta.append(doc)
            continue

        # Work on a copy so the caller's documents keep their original content.
        doc_with_meta = deepcopy(doc)
        doc_with_meta.content = "\n".join(meta_values + [doc.content])
        docs_with_meta.append(doc_with_meta)
    return docs_with_meta
|
||||
|
||||
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): # type: ignore
|
||||
self.query_count += 1
|
||||
if documents:
|
||||
|
||||
@ -42,7 +42,12 @@ class CohereRanker(BaseRanker):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model_name_or_path: str, top_k: int = 10, max_chunks_per_doc: Optional[int] = None
|
||||
self,
|
||||
api_key: str,
|
||||
model_name_or_path: str,
|
||||
top_k: int = 10,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
embed_meta_fields: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of CohereInvocationLayer for the specified Cohere model.
|
||||
@ -54,6 +59,8 @@ class CohereRanker(BaseRanker):
|
||||
chunks a document can be split into. If None, the default of 10 is used.
|
||||
For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10
|
||||
chunks each of 512 tokens and the last 880 tokens will be disregarded.
|
||||
:param embed_meta_fields: Concatenate the provided meta fields into the text passage that is then used in
|
||||
reranking. The original documents are returned so the concatenated metadata is not included in the returned documents.
|
||||
"""
|
||||
super().__init__()
|
||||
valid_api_key = isinstance(api_key, str) and api_key
|
||||
@ -71,6 +78,7 @@ class CohereRanker(BaseRanker):
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
self.max_chunks_per_doc = max_chunks_per_doc
|
||||
self.embed_meta_fields = embed_meta_fields
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
@ -139,7 +147,10 @@ class CohereRanker(BaseRanker):
|
||||
top_k = self.top_k
|
||||
|
||||
# See https://docs.cohere.com/reference/rerank-1
|
||||
cohere_docs = [{"text": d.content} for d in documents]
|
||||
docs_with_meta_fields = self._add_meta_fields_to_docs(
|
||||
documents=documents, embed_meta_fields=self.embed_meta_fields
|
||||
)
|
||||
cohere_docs = [{"text": d.content} for d in docs_with_meta_fields]
|
||||
if len(cohere_docs) > 1000:
|
||||
logger.warning(
|
||||
"The Cohere reranking endpoint only supports 1000 documents. "
|
||||
@ -216,5 +227,6 @@ class CohereRanker(BaseRanker):
|
||||
|
||||
results = []
|
||||
for query, cur_docs in zip(queries, documents):
|
||||
assert isinstance(cur_docs, list)
|
||||
results.append(self.predict(query=query, documents=cur_docs, top_k=top_k)) # type: ignore
|
||||
return results
|
||||
|
||||
@ -56,6 +56,7 @@ class SentenceTransformersRanker(BaseRanker):
|
||||
scale_score: bool = True,
|
||||
progress_bar: bool = True,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
embed_meta_fields: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
|
||||
@ -78,6 +79,8 @@ class SentenceTransformersRanker(BaseRanker):
|
||||
A list containing torch device objects and/or strings is supported (For example
|
||||
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||||
parameter is not used and a single cpu device is used for inference.
|
||||
:param embed_meta_fields: Concatenate the provided meta fields into the text passage that is then used in
|
||||
reranking. The original documents are returned so the concatenated metadata is not included in the returned documents.
|
||||
"""
|
||||
torch_and_transformers_import.check()
|
||||
super().__init__()
|
||||
@ -109,6 +112,7 @@ class SentenceTransformersRanker(BaseRanker):
|
||||
self.model = DataParallel(self.transformer_model, device_ids=self.devices)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.embed_meta_fields = embed_meta_fields
|
||||
|
||||
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
|
||||
"""
|
||||
@ -124,12 +128,12 @@ class SentenceTransformersRanker(BaseRanker):
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
|
||||
docs_with_meta_fields = self._add_meta_fields_to_docs(
|
||||
documents=documents, embed_meta_fields=self.embed_meta_fields
|
||||
)
|
||||
docs = [doc.content for doc in docs_with_meta_fields]
|
||||
features = self.transformer_tokenizer(
|
||||
[query for _ in documents],
|
||||
[doc.content for doc in documents],
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
[query for _ in documents], docs, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.devices[0])
|
||||
|
||||
# SentenceTransformerRanker uses:
|
||||
@ -214,9 +218,12 @@ class SentenceTransformersRanker(BaseRanker):
|
||||
number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
|
||||
queries=queries, documents=documents
|
||||
)
|
||||
all_docs_with_meta_fields = self._add_meta_fields_to_docs(
|
||||
documents=all_docs, embed_meta_fields=self.embed_meta_fields
|
||||
)
|
||||
|
||||
batches = self._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=batch_size)
|
||||
pb = tqdm(total=len(all_docs), disable=not self.progress_bar, desc="Ranking")
|
||||
batches = self._get_batches(all_queries=all_queries, all_docs=all_docs_with_meta_fields, batch_size=batch_size)
|
||||
pb = tqdm(total=len(all_docs_with_meta_fields), disable=not self.progress_bar, desc="Ranking")
|
||||
preds = []
|
||||
for cur_queries, cur_docs in batches:
|
||||
features = self.transformer_tokenizer(
|
||||
|
||||
@ -2,6 +2,7 @@ import pytest
|
||||
import math
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.ranker.base import BaseRanker
|
||||
from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
|
||||
@ -52,6 +53,52 @@ def mock_cohere_post():
|
||||
yield cohere_post
|
||||
|
||||
|
||||
@pytest.fixture
def mock_transformer_tokenizer():
    """
    Patch `transformers.AutoTokenizer.from_pretrained` so it returns a stub tokenizer.

    The stub ignores its arguments and always produces all-zero tensors of shape [5, 162]
    wrapped in a dict subclass whose `to()` is a no-op returning itself.
    """

    class FakeFeatures(dict):
        def to(self, arg):
            return self

    class FakeTokenizer:
        def __call__(self, *args, **kwargs):
            shape = [5, 162]
            return FakeFeatures(
                input_ids=torch.zeros(shape),
                token_type_ids=torch.zeros(shape, dtype=torch.long),
                attention_mask=torch.zeros(shape, dtype=torch.long),
            )

    with patch("transformers.AutoTokenizer.from_pretrained", return_value=FakeTokenizer()) as mock_tokenizer:
        yield mock_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
def mock_transformer_model():
    """
    Patch `transformers.AutoModelForSequenceClassification.from_pretrained` with a stub model.

    The stub reports `num_labels == 1` and, whatever it is called with, returns an object whose
    `logits` hold five fixed single-logit rows; the last row (index 4) has the highest value.
    `eval()` and `to()` are no-ops returning the stub itself.
    """

    class FakeOutput:
        def __init__(self, logits):
            self.logits = logits

    class FakeModel:
        def __init__(self):
            self.num_labels = 1
            self.logits = torch.tensor([[-9.7414], [-11.1572], [-11.1708], [-11.1515], [5.2571]])

        def __call__(self, *args, **kwargs):
            return FakeOutput(logits=self.logits)

        def eval(self):
            return self

        def to(self, arg):
            return self

    target = "transformers.AutoModelForSequenceClassification.from_pretrained"
    with patch(target, return_value=FakeModel()) as mock_model:
        yield mock_model
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ranker_preprocess_batch_queries_and_docs_raises():
|
||||
query_1 = "query 1"
|
||||
@ -109,19 +156,102 @@ def test_ranker_get_batches():
|
||||
assert next(batches) == (all_queries[0:1], all_docs[0:1])
|
||||
|
||||
|
||||
def test_ranker(ranker, docs):
|
||||
@pytest.mark.unit
def test_add_meta_fields_to_docs():
    """
    Metadata values are prepended to the content in the order of `embed_meta_fields`.

    Empty strings and empty lists are skipped, numeric values are stringified and list values
    contribute one line per item. The input documents must stay untouched.
    """
    docs = [
        Document(
            content="dummy doc 1",
            meta={
                "str_field": "test1",
                "empty_str_field": "",
                "numeric_field": 2.0,
                "list_field": ["item0.1", "item0.2"],
                "empty_list_field": [],
            },
        ),
        Document(
            content="dummy doc 2",
            meta={
                "str_field": "test2",
                "empty_str_field": "",
                "numeric_field": 5.0,
                "list_field": ["item1.1", "item1.2"],
                "empty_list_field": [],
            },
        ),
    ]
    # Bypass the real constructor: the helper under test needs no model or tokenizer.
    with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
        mock_ranker_init.return_value = None
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    docs_with_meta = ranker._add_meta_fields_to_docs(
        documents=docs,
        embed_meta_fields=["str_field", "empty_str_field", "numeric_field", "list_field", "empty_list_field"],
    )
    # Exact equality (not startswith) so that unexpected trailing text is caught as well.
    assert docs_with_meta[0].content == "test1\n2.0\nitem0.1\nitem0.2\ndummy doc 1"
    assert docs_with_meta[1].content == "test2\n5.0\nitem1.1\nitem1.2\ndummy doc 2"
    # The helper works on copies; the caller's documents keep their original content.
    assert docs[0].content == "dummy doc 1"
    assert docs[1].content == "dummy doc 2"
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_meta_fields_to_docs_none():
    """A requested meta field whose value is None leaves the documents equal to the input."""
    original_docs = [Document(content="dummy doc 1", meta={"none_field": None})]
    init_target = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    # Replace the constructor with a no-op so no model is loaded.
    with patch(init_target, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    result = ranker._add_meta_fields_to_docs(documents=original_docs, embed_meta_fields=["none_field"])
    assert result == original_docs
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_meta_fields_to_docs_non_existent():
    """Requesting a meta field that no document has leaves the documents equal to the input."""
    original_docs = [Document(content="dummy doc 1", meta={"test_field": "A string"})]
    init_target = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    # Replace the constructor with a no-op so no model is loaded.
    with patch(init_target, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    result = ranker._add_meta_fields_to_docs(documents=original_docs, embed_meta_fields=["wrong_field"])
    assert result == original_docs
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_meta_fields_to_docs_empty_list():
    """An empty `embed_meta_fields` list leaves the documents equal to the input."""
    original_docs = [Document(content="dummy doc 1", meta={"test_field": "A string"})]
    init_target = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    # Replace the constructor with a no-op so no model is loaded.
    with patch(init_target, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    result = ranker._add_meta_fields_to_docs(documents=original_docs, embed_meta_fields=[])
    assert result == original_docs
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_ranker(docs, mock_transformer_model, mock_transformer_tokenizer):
    """`predict` ranks the fifth fixture document first when model and tokenizer are mocked."""
    with patch("torch.nn.DataParallel"):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    ranked = ranker.predict(
        query="What is the most important building in King's Landing that has a religious background?",
        documents=docs,
    )
    top_document = ranked[0]
    assert top_document == docs[4]
|
||||
|
||||
|
||||
def test_ranker_batch_single_query_single_doc_list(ranker, docs):
|
||||
@pytest.mark.unit
def test_ranker_run(docs, mock_transformer_model, mock_transformer_tokenizer):
    """`run` returns the fifth fixture document first under the "documents" key of its first result item."""
    with patch("torch.nn.DataParallel"):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    run_output = ranker.run(
        query="What is the most important building in King's Landing that has a religious background?",
        documents=docs,
    )
    ranked_documents = run_output[0]["documents"]
    assert ranked_documents[0] == docs[4]
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_ranker_batch_single_query_single_doc_list(docs, mock_transformer_model, mock_transformer_tokenizer):
    """`predict_batch` with one query and a flat document list ranks the fifth fixture document first."""
    with patch("torch.nn.DataParallel"):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    question = "What is the most important building in King's Landing that has a religious background?"
    ranked = ranker.predict_batch(queries=[question], documents=docs)
    top_document = ranked[0]
    assert top_document == docs[4]
|
||||
|
||||
|
||||
def test_ranker_batch_single_query_multiple_doc_lists(ranker, docs):
|
||||
@pytest.mark.unit
|
||||
def test_ranker_batch_single_query_multiple_doc_lists(docs, mock_transformer_model, mock_transformer_tokenizer):
|
||||
with patch("torch.nn.DataParallel"):
|
||||
ranker = SentenceTransformersRanker(model_name_or_path="fake_model", batch_size=5)
|
||||
query = "What is the most important building in King's Landing that has a religious background?"
|
||||
results = ranker.predict_batch(queries=[query], documents=[docs, docs])
|
||||
assert isinstance(results, list)
|
||||
@ -140,6 +270,15 @@ def test_ranker_batch_multiple_queries_multiple_doc_lists(ranker, docs):
|
||||
assert results[1][0] == docs[1]
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_ranker_with_embed_meta_fields(docs, mock_transformer_model, mock_transformer_tokenizer):
    """With `embed_meta_fields=["name"]`, `predict` still returns the original fifth fixture document first."""
    with patch("torch.nn.DataParallel"):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model", embed_meta_fields=["name"])
    ranked = ranker.predict(
        query="What is the most important building in King's Landing that has a religious background?",
        documents=docs,
    )
    top_document = ranked[0]
    assert top_document == docs[4]
|
||||
|
||||
|
||||
def test_ranker_two_logits(ranker_two_logits, docs):
|
||||
assert isinstance(ranker_two_logits, BaseRanker)
|
||||
assert isinstance(ranker_two_logits, SentenceTransformersRanker)
|
||||
@ -227,9 +366,7 @@ def test_ranker_returns_raw_score_for_two_logits(ranker_two_logits):
|
||||
|
||||
def test_predict_batch_returns_correct_number_of_docs(ranker):
    """Both `predict` and `predict_batch` return exactly `top_k` documents when given more than `top_k`."""
    question = "where is test 3?"
    docs = [Document(content=f"test {number}") for number in range(5)]

    single_result = ranker.predict(question, docs, top_k=4)
    assert len(single_result) == 4

    batch_result = ranker.predict_batch([question], docs, batch_size=2, top_k=4)
    assert len(batch_result) == 4
|
||||
|
||||
|
||||
@ -251,6 +388,32 @@ def test_cohere_ranker(docs, mock_cohere_post):
|
||||
assert results[0] == docs[4]
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_cohere_ranker_with_embed_meta_fields(docs, mock_cohere_post):
    """
    `CohereRanker.predict` sends the "name" meta value, a newline and the content as the text of each
    document (content only when there is no usable name) and returns the original documents.
    """

    def expected_text(doc):
        # Mirror the expected request: name on its own line before the content, if a name is present.
        name = doc.meta.get("name")
        return name + "\n" + doc.content if name else doc.content

    query = "What is the most important building in King's Landing that has a religious background?"
    ranker = CohereRanker(api_key="fake_key", model_name_or_path="rerank-english-v2.0", embed_meta_fields=["name"])
    results = ranker.predict(query=query, documents=docs)

    expected_payload = {
        "model": "rerank-english-v2.0",
        "query": query,
        "documents": [{"text": expected_text(doc)} for doc in docs],
        "top_n": None,  # By passing None we return all documents and use top_k to truncate later
        "return_documents": False,
        "max_chunks_per_doc": None,
    }
    mock_cohere_post.assert_called_once_with(expected_payload)
    assert results[0] == docs[4]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post):
|
||||
query = "What is the most important building in King's Landing that has a religious background?"
|
||||
@ -269,6 +432,31 @@ def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post
|
||||
assert results[0] == docs[4]
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_cohere_ranker_batch_single_query_single_doc_list_with_embed_meta_fields(docs, mock_cohere_post):
    """
    `CohereRanker.predict_batch` with one query and a flat document list sends the "name" meta value,
    a newline and the content as each document's text (content only when there is no usable name).
    """

    def expected_text(doc):
        # Mirror the expected request: name on its own line before the content, if a name is present.
        name = doc.meta.get("name")
        return name + "\n" + doc.content if name else doc.content

    query = "What is the most important building in King's Landing that has a religious background?"
    ranker = CohereRanker(api_key="fake_key", model_name_or_path="rerank-english-v2.0", embed_meta_fields=["name"])
    results = ranker.predict_batch(queries=[query], documents=docs)

    expected_payload = {
        "model": "rerank-english-v2.0",
        "query": query,
        "documents": [{"text": expected_text(doc)} for doc in docs],
        "top_n": None,  # By passing None we return all documents and use top_k to truncate later
        "return_documents": False,
        "max_chunks_per_doc": None,
    }
    mock_cohere_post.assert_called_once_with(expected_payload)
    assert results[0] == docs[4]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cohere_ranker_batch_single_query_multiple_doc_lists(docs, mock_cohere_post):
|
||||
query = "What is the most important building in King's Landing that has a religious background?"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user