feat: Add embed_meta_fields to Ranker nodes (#5361)

* Adding embed_meta_fields to ranker nodes

* Fix tests by adding case where embed_meta_fields=None

* Adding unit test for _add_meta_fields_to_docs

* Fix pylint

* Add unit test

* Added another unit test. Caught a bug.

* Adding more unit tests

* Add unit test

* Updating some older tests into unit tests using mocking

* Convert another test to unit test

* Test run method

* One last unit test
This commit is contained in:
Sebastian Husch Lee 2023-07-18 09:11:51 +02:00 committed by GitHub
parent e0cf1421c6
commit f7642e83ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 253 additions and 14 deletions

View File

@ -4,6 +4,7 @@ import logging
from abc import abstractmethod
from functools import wraps
from time import perf_counter
from copy import deepcopy
from haystack.schema import Document
from haystack.nodes.base import BaseComponent
@ -32,6 +33,37 @@ class BaseRanker(BaseComponent):
) -> Union[List[Document], List[List[Document]]]:
pass
def _add_meta_fields_to_docs(
self, documents: List[Document], embed_meta_fields: Optional[List[str]] = None
) -> List[Document]:
"""
Concatenates specified metadata fields with the text representations.
:param documents: List of documents to add metadata to.
:param embed_meta_fields: Concatenate the provided meta fields and into the text passage that is then used in
reranking.
:return: List of documents with metadata.
"""
if not embed_meta_fields:
return documents
docs_with_meta = []
for doc in documents:
doc = deepcopy(doc)
# Gather all relevant metadata fields
meta_data_fields = []
for key in embed_meta_fields:
if key in doc.meta and doc.meta[key]:
if isinstance(doc.meta[key], list):
meta_data_fields.extend([item for item in doc.meta[key]])
else:
meta_data_fields.append(doc.meta[key])
# Convert to type string (e.g. for ints or floats)
meta_data_fields = [str(field) for field in meta_data_fields]
doc.content = "\n".join(meta_data_fields + [doc.content])
docs_with_meta.append(doc)
return docs_with_meta
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): # type: ignore
self.query_count += 1
if documents:

View File

@ -42,7 +42,12 @@ class CohereRanker(BaseRanker):
"""
def __init__(
self, api_key: str, model_name_or_path: str, top_k: int = 10, max_chunks_per_doc: Optional[int] = None
self,
api_key: str,
model_name_or_path: str,
top_k: int = 10,
max_chunks_per_doc: Optional[int] = None,
embed_meta_fields: Optional[List[str]] = None,
):
"""
Creates an instance of CohereInvocationLayer for the specified Cohere model.
@ -54,6 +59,8 @@ class CohereRanker(BaseRanker):
chunks a document can be split into. If None, the default of 10 is used.
For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10
chunks each of 512 tokens and the last 880 tokens will be disregarded.
:param embed_meta_fields: Concatenate the provided meta fields and into the text passage that is then used in
reranking. The original documents are returned so the concatenated metadata is not included in the returned documents.
"""
super().__init__()
valid_api_key = isinstance(api_key, str) and api_key
@ -71,6 +78,7 @@ class CohereRanker(BaseRanker):
self.api_key = api_key
self.top_k = top_k
self.max_chunks_per_doc = max_chunks_per_doc
self.embed_meta_fields = embed_meta_fields
@property
def url(self) -> str:
@ -139,7 +147,10 @@ class CohereRanker(BaseRanker):
top_k = self.top_k
# See https://docs.cohere.com/reference/rerank-1
cohere_docs = [{"text": d.content} for d in documents]
docs_with_meta_fields = self._add_meta_fields_to_docs(
documents=documents, embed_meta_fields=self.embed_meta_fields
)
cohere_docs = [{"text": d.content} for d in docs_with_meta_fields]
if len(cohere_docs) > 1000:
logger.warning(
"The Cohere reranking endpoint only supports 1000 documents. "
@ -216,5 +227,6 @@ class CohereRanker(BaseRanker):
results = []
for query, cur_docs in zip(queries, documents):
assert isinstance(cur_docs, list)
results.append(self.predict(query=query, documents=cur_docs, top_k=top_k)) # type: ignore
return results

View File

@ -56,6 +56,7 @@ class SentenceTransformersRanker(BaseRanker):
scale_score: bool = True,
progress_bar: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
embed_meta_fields: Optional[List[str]] = None,
):
"""
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@ -78,6 +79,8 @@ class SentenceTransformersRanker(BaseRanker):
A list containing torch device objects and/or strings is supported (For example
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
:param embed_meta_fields: Concatenate the provided meta fields and into the text passage that is then used in
reranking. The original documents are returned so the concatenated metadata is not included in the returned documents.
"""
torch_and_transformers_import.check()
super().__init__()
@ -109,6 +112,7 @@ class SentenceTransformersRanker(BaseRanker):
self.model = DataParallel(self.transformer_model, device_ids=self.devices)
self.batch_size = batch_size
self.embed_meta_fields = embed_meta_fields
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
"""
@ -124,12 +128,12 @@ class SentenceTransformersRanker(BaseRanker):
if top_k is None:
top_k = self.top_k
docs_with_meta_fields = self._add_meta_fields_to_docs(
documents=documents, embed_meta_fields=self.embed_meta_fields
)
docs = [doc.content for doc in docs_with_meta_fields]
features = self.transformer_tokenizer(
[query for _ in documents],
[doc.content for doc in documents],
padding=True,
truncation=True,
return_tensors="pt",
[query for _ in documents], docs, padding=True, truncation=True, return_tensors="pt"
).to(self.devices[0])
# SentenceTransformerRanker uses:
@ -214,9 +218,12 @@ class SentenceTransformersRanker(BaseRanker):
number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
queries=queries, documents=documents
)
all_docs_with_meta_fields = self._add_meta_fields_to_docs(
documents=all_docs, embed_meta_fields=self.embed_meta_fields
)
batches = self._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=batch_size)
pb = tqdm(total=len(all_docs), disable=not self.progress_bar, desc="Ranking")
batches = self._get_batches(all_queries=all_queries, all_docs=all_docs_with_meta_fields, batch_size=batch_size)
pb = tqdm(total=len(all_docs_with_meta_fields), disable=not self.progress_bar, desc="Ranking")
preds = []
for cur_queries, cur_docs in batches:
features = self.transformer_tokenizer(

View File

@ -2,6 +2,7 @@ import pytest
import math
from unittest.mock import patch
import torch
from haystack.schema import Document
from haystack.nodes.ranker.base import BaseRanker
from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
@ -52,6 +53,52 @@ def mock_cohere_post():
yield cohere_post
@pytest.fixture
def mock_transformer_tokenizer():
class Features(dict):
def to(self, arg):
return self
class Tokenizer:
def __call__(self, *args, **kwargs):
return Features(
{
"input_ids": torch.zeros([5, 162]),
"token_type_ids": torch.zeros([5, 162], dtype=torch.long),
"attention_mask": torch.zeros([5, 162], dtype=torch.long),
}
)
with patch("transformers.AutoTokenizer.from_pretrained") as mock_tokenizer:
mock_tokenizer.return_value = Tokenizer()
yield mock_tokenizer
@pytest.fixture
def mock_transformer_model():
class Logits:
def __init__(self, logits):
self.logits = logits
class Model:
def __init__(self):
self.logits = torch.tensor([[-9.7414], [-11.1572], [-11.1708], [-11.1515], [5.2571]])
self.num_labels = 1
def __call__(self, *args, **kwargs):
return Logits(logits=self.logits)
def eval(self):
return self
def to(self, arg):
return self
with patch("transformers.AutoModelForSequenceClassification.from_pretrained") as mock_model:
mock_model.return_value = Model()
yield mock_model
@pytest.mark.unit
def test_ranker_preprocess_batch_queries_and_docs_raises():
query_1 = "query 1"
@ -109,19 +156,102 @@ def test_ranker_get_batches():
assert next(batches) == (all_queries[0:1], all_docs[0:1])
def test_ranker(ranker, docs):
@pytest.mark.unit
def test_add_meta_fields_to_docs():
docs = [
Document(
content="dummy doc 1",
meta={
"str_field": "test1",
"empty_str_field": "",
"numeric_field": 2.0,
"list_field": ["item0.1", "item0.2"],
"empty_list_field": [],
},
),
Document(
content="dummy doc 2",
meta={
"str_field": "test2",
"empty_str_field": "",
"numeric_field": 5.0,
"list_field": ["item1.1", "item1.2"],
"empty_list_field": [],
},
),
]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
docs_with_meta = ranker._add_meta_fields_to_docs(
documents=docs,
embed_meta_fields=["str_field", "empty_str_field", "numeric_field", "list_field", "empty_list_field"],
)
assert docs_with_meta[0].content.startswith("test1\n2.0\nitem0.1\nitem0.2\ndummy doc 1")
assert docs_with_meta[1].content.startswith("test2\n5.0\nitem1.1\nitem1.2\ndummy doc 2")
@pytest.mark.unit
def test_add_meta_fields_to_docs_none():
docs = [Document(content="dummy doc 1", meta={"none_field": None})]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
docs_with_meta = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=["none_field"])
assert docs_with_meta == docs
@pytest.mark.unit
def test_add_meta_fields_to_docs_non_existent():
docs = [Document(content="dummy doc 1", meta={"test_field": "A string"})]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
docs_with_meta = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=["wrong_field"])
assert docs_with_meta == docs
@pytest.mark.unit
def test_add_meta_fields_to_docs_empty_list():
docs = [Document(content="dummy doc 1", meta={"test_field": "A string"})]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
docs_with_meta = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=[])
assert docs_with_meta == docs
@pytest.mark.unit
def test_ranker(docs, mock_transformer_model, mock_transformer_tokenizer):
with patch("torch.nn.DataParallel"):
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
query = "What is the most important building in King's Landing that has a religious background?"
results = ranker.predict(query=query, documents=docs)
assert results[0] == docs[4]
def test_ranker_batch_single_query_single_doc_list(ranker, docs):
@pytest.mark.unit
def test_ranker_run(docs, mock_transformer_model, mock_transformer_tokenizer):
with patch("torch.nn.DataParallel"):
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
query = "What is the most important building in King's Landing that has a religious background?"
results = ranker.run(query=query, documents=docs)
assert results[0]["documents"][0] == docs[4]
@pytest.mark.unit
def test_ranker_batch_single_query_single_doc_list(docs, mock_transformer_model, mock_transformer_tokenizer):
with patch("torch.nn.DataParallel"):
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
query = "What is the most important building in King's Landing that has a religious background?"
results = ranker.predict_batch(queries=[query], documents=docs)
assert results[0] == docs[4]
def test_ranker_batch_single_query_multiple_doc_lists(ranker, docs):
@pytest.mark.unit
def test_ranker_batch_single_query_multiple_doc_lists(docs, mock_transformer_model, mock_transformer_tokenizer):
with patch("torch.nn.DataParallel"):
ranker = SentenceTransformersRanker(model_name_or_path="fake_model", batch_size=5)
query = "What is the most important building in King's Landing that has a religious background?"
results = ranker.predict_batch(queries=[query], documents=[docs, docs])
assert isinstance(results, list)
@ -140,6 +270,15 @@ def test_ranker_batch_multiple_queries_multiple_doc_lists(ranker, docs):
assert results[1][0] == docs[1]
@pytest.mark.unit
def test_ranker_with_embed_meta_fields(docs, mock_transformer_model, mock_transformer_tokenizer):
with patch("torch.nn.DataParallel"):
ranker = SentenceTransformersRanker(model_name_or_path="fake_model", embed_meta_fields=["name"])
query = "What is the most important building in King's Landing that has a religious background?"
results = ranker.predict(query=query, documents=docs)
assert results[0] == docs[4]
def test_ranker_two_logits(ranker_two_logits, docs):
assert isinstance(ranker_two_logits, BaseRanker)
assert isinstance(ranker_two_logits, SentenceTransformersRanker)
@ -227,9 +366,7 @@ def test_ranker_returns_raw_score_for_two_logits(ranker_two_logits):
def test_predict_batch_returns_correct_number_of_docs(ranker):
docs = [Document(content=f"test {number}") for number in range(5)]
assert len(ranker.predict("where is test 3?", docs, top_k=4)) == 4
assert len(ranker.predict_batch(["where is test 3?"], docs, batch_size=2, top_k=4)) == 4
@ -251,6 +388,32 @@ def test_cohere_ranker(docs, mock_cohere_post):
assert results[0] == docs[4]
@pytest.mark.unit
def test_cohere_ranker_with_embed_meta_fields(docs, mock_cohere_post):
query = "What is the most important building in King's Landing that has a religious background?"
ranker = CohereRanker(api_key="fake_key", model_name_or_path="rerank-english-v2.0", embed_meta_fields=["name"])
results = ranker.predict(query=query, documents=docs)
# Prep expected input
documents = []
for d in docs:
meta = d.meta.get("name")
if meta:
documents.append({"text": d.meta["name"] + "\n" + d.content})
else:
documents.append({"text": d.content})
mock_cohere_post.assert_called_once_with(
{
"model": "rerank-english-v2.0",
"query": query,
"documents": documents,
"top_n": None, # By passing None we return all documents and use top_k to truncate later
"return_documents": False,
"max_chunks_per_doc": None,
}
)
assert results[0] == docs[4]
@pytest.mark.unit
def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post):
query = "What is the most important building in King's Landing that has a religious background?"
@ -269,6 +432,31 @@ def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post
assert results[0] == docs[4]
@pytest.mark.unit
def test_cohere_ranker_batch_single_query_single_doc_list_with_embed_meta_fields(docs, mock_cohere_post):
query = "What is the most important building in King's Landing that has a religious background?"
ranker = CohereRanker(api_key="fake_key", model_name_or_path="rerank-english-v2.0", embed_meta_fields=["name"])
results = ranker.predict_batch(queries=[query], documents=docs)
documents = []
for d in docs:
meta = d.meta.get("name")
if meta:
documents.append({"text": d.meta["name"] + "\n" + d.content})
else:
documents.append({"text": d.content})
mock_cohere_post.assert_called_once_with(
{
"model": "rerank-english-v2.0",
"query": query,
"documents": documents,
"top_n": None, # By passing None we return all documents and use top_k to truncate later
"return_documents": False,
"max_chunks_per_doc": None,
}
)
assert results[0] == docs[4]
@pytest.mark.unit
def test_cohere_ranker_batch_single_query_multiple_doc_lists(docs, mock_cohere_post):
query = "What is the most important building in King's Landing that has a religious background?"