mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-03 03:09:28 +00:00
feat: Add embed_meta_fields to Ranker nodes (#5361)
* Adding embed_meta_fields to ranker nodes * Fix tests by adding case where embed_meta_fields=None * Adding unit test for _add_meta_fields_to_docs * Fix pylint * Add unit test * Added another unit test. Caught a bug. * Adding more unit tests * Add unit test * Updating some older tests into unit tests using mocking * Convert another test to unit test * Test run method * One last unit test
This commit is contained in:
parent
e0cf1421c6
commit
f7642e83ea
@ -4,6 +4,7 @@ import logging
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from haystack.schema import Document
|
from haystack.schema import Document
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.base import BaseComponent
|
||||||
@ -32,6 +33,37 @@ class BaseRanker(BaseComponent):
|
|||||||
) -> Union[List[Document], List[List[Document]]]:
|
) -> Union[List[Document], List[List[Document]]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _add_meta_fields_to_docs(
    self, documents: List[Document], embed_meta_fields: Optional[List[str]] = None
) -> List[Document]:
    """
    Prepend the values of the selected metadata fields to each document's content.

    The input documents are never modified: whenever ``embed_meta_fields`` is non-empty, every document is
    deep-copied and only the copy receives the extended content. Values are converted with ``str()`` and joined
    with newlines, followed by the original content. List values contribute one line per item.

    Fields that are missing, ``None``, an empty string or an empty list are skipped. Other falsy values such as
    ``0``, ``0.0`` or ``False`` are kept, because they are real metadata values.

    :param documents: List of documents to add metadata to.
    :param embed_meta_fields: Names of the meta fields whose values are concatenated into the text passage that
                              is then used in reranking. If None or empty, ``documents`` is returned as is.
    :return: List of documents with metadata.
    """
    if not embed_meta_fields:
        return documents

    docs_with_meta = []
    for doc in documents:
        # Work on a copy so that the caller's documents keep their original content.
        doc = deepcopy(doc)
        # Gather all relevant metadata fields, already converted to strings (e.g. for ints or floats).
        meta_data_fields: List[str] = []
        for key in embed_meta_fields:
            value = doc.meta.get(key)
            # Only "no value" is skipped; a plain truthiness test would also discard 0, 0.0 and False.
            if value is None or (isinstance(value, (str, list)) and not value):
                continue
            if isinstance(value, list):
                meta_data_fields.extend(str(item) for item in value)
            else:
                meta_data_fields.append(str(value))
        doc.content = "\n".join(meta_data_fields + [doc.content])
        docs_with_meta.append(doc)
    return docs_with_meta
|
||||||
|
|
||||||
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): # type: ignore
|
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): # type: ignore
|
||||||
self.query_count += 1
|
self.query_count += 1
|
||||||
if documents:
|
if documents:
|
||||||
|
|||||||
@ -42,7 +42,12 @@ class CohereRanker(BaseRanker):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api_key: str, model_name_or_path: str, top_k: int = 10, max_chunks_per_doc: Optional[int] = None
|
self,
|
||||||
|
api_key: str,
|
||||||
|
model_name_or_path: str,
|
||||||
|
top_k: int = 10,
|
||||||
|
max_chunks_per_doc: Optional[int] = None,
|
||||||
|
embed_meta_fields: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates an instance of CohereInvocationLayer for the specified Cohere model.
|
Creates an instance of CohereInvocationLayer for the specified Cohere model.
|
||||||
@ -54,6 +59,8 @@ class CohereRanker(BaseRanker):
|
|||||||
chunks a document can be split into. If None, the default of 10 is used.
|
chunks a document can be split into. If None, the default of 10 is used.
|
||||||
For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10
|
For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10
|
||||||
chunks each of 512 tokens and the last 880 tokens will be disregarded.
|
chunks each of 512 tokens and the last 880 tokens will be disregarded.
|
||||||
|
:param embed_meta_fields: Concatenate the provided meta fields into the text passage that is then used in
|
||||||
|
reranking. The original documents are returned so the concatenated metadata is not included in the returned documents.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
valid_api_key = isinstance(api_key, str) and api_key
|
valid_api_key = isinstance(api_key, str) and api_key
|
||||||
@ -71,6 +78,7 @@ class CohereRanker(BaseRanker):
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.max_chunks_per_doc = max_chunks_per_doc
|
self.max_chunks_per_doc = max_chunks_per_doc
|
||||||
|
self.embed_meta_fields = embed_meta_fields
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
@ -139,7 +147,10 @@ class CohereRanker(BaseRanker):
|
|||||||
top_k = self.top_k
|
top_k = self.top_k
|
||||||
|
|
||||||
# See https://docs.cohere.com/reference/rerank-1
|
# See https://docs.cohere.com/reference/rerank-1
|
||||||
cohere_docs = [{"text": d.content} for d in documents]
|
docs_with_meta_fields = self._add_meta_fields_to_docs(
|
||||||
|
documents=documents, embed_meta_fields=self.embed_meta_fields
|
||||||
|
)
|
||||||
|
cohere_docs = [{"text": d.content} for d in docs_with_meta_fields]
|
||||||
if len(cohere_docs) > 1000:
|
if len(cohere_docs) > 1000:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The Cohere reranking endpoint only supports 1000 documents. "
|
"The Cohere reranking endpoint only supports 1000 documents. "
|
||||||
@ -216,5 +227,6 @@ class CohereRanker(BaseRanker):
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for query, cur_docs in zip(queries, documents):
|
for query, cur_docs in zip(queries, documents):
|
||||||
|
assert isinstance(cur_docs, list)
|
||||||
results.append(self.predict(query=query, documents=cur_docs, top_k=top_k)) # type: ignore
|
results.append(self.predict(query=query, documents=cur_docs, top_k=top_k)) # type: ignore
|
||||||
return results
|
return results
|
||||||
|
|||||||
@ -56,6 +56,7 @@ class SentenceTransformersRanker(BaseRanker):
|
|||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
progress_bar: bool = True,
|
progress_bar: bool = True,
|
||||||
use_auth_token: Optional[Union[str, bool]] = None,
|
use_auth_token: Optional[Union[str, bool]] = None,
|
||||||
|
embed_meta_fields: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
|
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
|
||||||
@ -78,6 +79,8 @@ class SentenceTransformersRanker(BaseRanker):
|
|||||||
A list containing torch device objects and/or strings is supported (For example
|
A list containing torch device objects and/or strings is supported (For example
|
||||||
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||||||
parameter is not used and a single cpu device is used for inference.
|
parameter is not used and a single cpu device is used for inference.
|
||||||
|
:param embed_meta_fields: Concatenate the provided meta fields into the text passage that is then used in
|
||||||
|
reranking. The original documents are returned so the concatenated metadata is not included in the returned documents.
|
||||||
"""
|
"""
|
||||||
torch_and_transformers_import.check()
|
torch_and_transformers_import.check()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -109,6 +112,7 @@ class SentenceTransformersRanker(BaseRanker):
|
|||||||
self.model = DataParallel(self.transformer_model, device_ids=self.devices)
|
self.model = DataParallel(self.transformer_model, device_ids=self.devices)
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
self.embed_meta_fields = embed_meta_fields
|
||||||
|
|
||||||
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
|
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
@ -124,12 +128,12 @@ class SentenceTransformersRanker(BaseRanker):
|
|||||||
if top_k is None:
|
if top_k is None:
|
||||||
top_k = self.top_k
|
top_k = self.top_k
|
||||||
|
|
||||||
|
docs_with_meta_fields = self._add_meta_fields_to_docs(
|
||||||
|
documents=documents, embed_meta_fields=self.embed_meta_fields
|
||||||
|
)
|
||||||
|
docs = [doc.content for doc in docs_with_meta_fields]
|
||||||
features = self.transformer_tokenizer(
|
features = self.transformer_tokenizer(
|
||||||
[query for _ in documents],
|
[query for _ in documents], docs, padding=True, truncation=True, return_tensors="pt"
|
||||||
[doc.content for doc in documents],
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.devices[0])
|
).to(self.devices[0])
|
||||||
|
|
||||||
# SentenceTransformerRanker uses:
|
# SentenceTransformerRanker uses:
|
||||||
@ -214,9 +218,12 @@ class SentenceTransformersRanker(BaseRanker):
|
|||||||
number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
|
number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
|
||||||
queries=queries, documents=documents
|
queries=queries, documents=documents
|
||||||
)
|
)
|
||||||
|
all_docs_with_meta_fields = self._add_meta_fields_to_docs(
|
||||||
|
documents=all_docs, embed_meta_fields=self.embed_meta_fields
|
||||||
|
)
|
||||||
|
|
||||||
batches = self._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=batch_size)
|
batches = self._get_batches(all_queries=all_queries, all_docs=all_docs_with_meta_fields, batch_size=batch_size)
|
||||||
pb = tqdm(total=len(all_docs), disable=not self.progress_bar, desc="Ranking")
|
pb = tqdm(total=len(all_docs_with_meta_fields), disable=not self.progress_bar, desc="Ranking")
|
||||||
preds = []
|
preds = []
|
||||||
for cur_queries, cur_docs in batches:
|
for cur_queries, cur_docs in batches:
|
||||||
features = self.transformer_tokenizer(
|
features = self.transformer_tokenizer(
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import pytest
|
|||||||
import math
|
import math
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
from haystack.schema import Document
|
from haystack.schema import Document
|
||||||
from haystack.nodes.ranker.base import BaseRanker
|
from haystack.nodes.ranker.base import BaseRanker
|
||||||
from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
|
from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
|
||||||
@ -52,6 +53,52 @@ def mock_cohere_post():
|
|||||||
yield cohere_post
|
yield cohere_post
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_transformer_tokenizer():
    """
    Patch ``transformers.AutoTokenizer.from_pretrained`` so that it returns a stub tokenizer.

    Whatever the stub is called with, it produces the same all-zero features: ``input_ids``, ``token_type_ids``
    and ``attention_mask`` tensors of shape 5 x 162. The features mapping has a ``to()`` method that returns itself.
    """

    class StubFeatures(dict):
        def to(self, arg):
            # The target passed in is ignored; the very same mapping is handed back.
            return self

    class StubTokenizer:
        def __call__(self, *args, **kwargs):
            shape = [5, 162]
            features = StubFeatures()
            features["input_ids"] = torch.zeros(shape)
            features["token_type_ids"] = torch.zeros(shape, dtype=torch.long)
            features["attention_mask"] = torch.zeros(shape, dtype=torch.long)
            return features

    with patch("transformers.AutoTokenizer.from_pretrained", return_value=StubTokenizer()) as mock_tokenizer:
        yield mock_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_transformer_model():
    """
    Patch ``transformers.AutoModelForSequenceClassification.from_pretrained`` so that it returns a stub model.

    The stub reports ``num_labels == 1`` and answers every call with the same five single-logit rows, of which
    the last one is the largest. ``eval()`` and ``to()`` simply return the stub itself.
    """

    class StubOutput:
        def __init__(self, logits):
            self.logits = logits

    class StubModel:
        def __init__(self):
            self.num_labels = 1
            self.logits = torch.tensor([[-9.7414], [-11.1572], [-11.1708], [-11.1515], [5.2571]])

        def __call__(self, *args, **kwargs):
            return StubOutput(logits=self.logits)

        def to(self, arg):
            return self

        def eval(self):
            return self

    target = "transformers.AutoModelForSequenceClassification.from_pretrained"
    with patch(target, return_value=StubModel()) as mock_model:
        yield mock_model
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_ranker_preprocess_batch_queries_and_docs_raises():
|
def test_ranker_preprocess_batch_queries_and_docs_raises():
|
||||||
query_1 = "query 1"
|
query_1 = "query 1"
|
||||||
@ -109,19 +156,102 @@ def test_ranker_get_batches():
|
|||||||
assert next(batches) == (all_queries[0:1], all_docs[0:1])
|
assert next(batches) == (all_queries[0:1], all_docs[0:1])
|
||||||
|
|
||||||
|
|
||||||
def test_ranker(ranker, docs):
|
@pytest.mark.unit
def test_add_meta_fields_to_docs():
    """
    Non-empty string, numeric and list meta fields are prepended to the content, one value per line, while the
    empty string and the empty list fields contribute nothing.
    """

    def build_doc(content, str_value, numeric_value, list_value):
        return Document(
            content=content,
            meta={
                "str_field": str_value,
                "empty_str_field": "",
                "numeric_field": numeric_value,
                "list_field": list_value,
                "empty_list_field": [],
            },
        )

    docs = [
        build_doc("dummy doc 1", "test1", 2.0, ["item0.1", "item0.2"]),
        build_doc("dummy doc 2", "test2", 5.0, ["item1.1", "item1.2"]),
    ]
    fields_to_embed = ["str_field", "empty_str_field", "numeric_field", "list_field", "empty_list_field"]

    # The real constructor is bypassed; only the helper under test is exercised.
    init_path = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    with patch(init_path, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
    docs_with_meta = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=fields_to_embed)

    first_content = docs_with_meta[0].content
    second_content = docs_with_meta[1].content
    assert first_content.startswith("test1\n2.0\nitem0.1\nitem0.2\ndummy doc 1")
    assert second_content.startswith("test2\n5.0\nitem1.1\nitem1.2\ndummy doc 2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_add_meta_fields_to_docs_none():
    """A meta field whose value is None must leave the documents equal to the input."""
    docs = [Document(content="dummy doc 1", meta={"none_field": None})]

    # The real constructor is bypassed; only the helper under test is exercised.
    init_path = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    with patch(init_path, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")

    result = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=["none_field"])
    assert result == docs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_add_meta_fields_to_docs_non_existent():
    """Asking for a meta field that the document does not have must leave the documents equal to the input."""
    docs = [Document(content="dummy doc 1", meta={"test_field": "A string"})]

    # The real constructor is bypassed; only the helper under test is exercised.
    init_path = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    with patch(init_path, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")

    result = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=["wrong_field"])
    assert result == docs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_add_meta_fields_to_docs_empty_list():
    """An empty ``embed_meta_fields`` list must leave the documents equal to the input."""
    docs = [Document(content="dummy doc 1", meta={"test_field": "A string"})]

    # The real constructor is bypassed; only the helper under test is exercised.
    init_path = "haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__"
    with patch(init_path, return_value=None):
        ranker = SentenceTransformersRanker(model_name_or_path="fake_model")

    result = ranker._add_meta_fields_to_docs(documents=docs, embed_meta_fields=[])
    assert result == docs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_ranker(docs, mock_transformer_model, mock_transformer_tokenizer):
    """``predict`` on a ranker built from the mocked model and tokenizer must put ``docs[4]`` first."""
    question = "What is the most important building in King's Landing that has a religious background?"
    with patch("torch.nn.DataParallel"):
        reranker = SentenceTransformersRanker(model_name_or_path="fake_model")

    ranked = reranker.predict(query=question, documents=docs)
    top_hit = ranked[0]
    assert top_hit == docs[4]
|
||||||
|
|
||||||
|
|
||||||
def test_ranker_batch_single_query_single_doc_list(ranker, docs):
|
@pytest.mark.unit
def test_ranker_run(docs, mock_transformer_model, mock_transformer_tokenizer):
    """``run`` must expose the ranked documents under ``"documents"`` in its first output, ``docs[4]`` first."""
    question = "What is the most important building in King's Landing that has a religious background?"
    with patch("torch.nn.DataParallel"):
        reranker = SentenceTransformersRanker(model_name_or_path="fake_model")

    run_output = reranker.run(query=question, documents=docs)
    ranked_docs = run_output[0]["documents"]
    assert ranked_docs[0] == docs[4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_ranker_batch_single_query_single_doc_list(docs, mock_transformer_model, mock_transformer_tokenizer):
    """``predict_batch`` with one query and one flat list of documents must put ``docs[4]`` first."""
    question = "What is the most important building in King's Landing that has a religious background?"
    with patch("torch.nn.DataParallel"):
        reranker = SentenceTransformersRanker(model_name_or_path="fake_model")

    ranked = reranker.predict_batch(queries=[question], documents=docs)
    top_hit = ranked[0]
    assert top_hit == docs[4]
|
||||||
|
|
||||||
|
|
||||||
def test_ranker_batch_single_query_multiple_doc_lists(ranker, docs):
|
@pytest.mark.unit
|
||||||
|
def test_ranker_batch_single_query_multiple_doc_lists(docs, mock_transformer_model, mock_transformer_tokenizer):
|
||||||
|
with patch("torch.nn.DataParallel"):
|
||||||
|
ranker = SentenceTransformersRanker(model_name_or_path="fake_model", batch_size=5)
|
||||||
query = "What is the most important building in King's Landing that has a religious background?"
|
query = "What is the most important building in King's Landing that has a religious background?"
|
||||||
results = ranker.predict_batch(queries=[query], documents=[docs, docs])
|
results = ranker.predict_batch(queries=[query], documents=[docs, docs])
|
||||||
assert isinstance(results, list)
|
assert isinstance(results, list)
|
||||||
@ -140,6 +270,15 @@ def test_ranker_batch_multiple_queries_multiple_doc_lists(ranker, docs):
|
|||||||
assert results[1][0] == docs[1]
|
assert results[1][0] == docs[1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_ranker_with_embed_meta_fields(docs, mock_transformer_model, mock_transformer_tokenizer):
    """With ``embed_meta_fields=["name"]`` set, ``predict`` must still put the original ``docs[4]`` first."""
    question = "What is the most important building in King's Landing that has a religious background?"
    with patch("torch.nn.DataParallel"):
        reranker = SentenceTransformersRanker(model_name_or_path="fake_model", embed_meta_fields=["name"])

    ranked = reranker.predict(query=question, documents=docs)
    top_hit = ranked[0]
    assert top_hit == docs[4]
|
||||||
|
|
||||||
|
|
||||||
def test_ranker_two_logits(ranker_two_logits, docs):
|
def test_ranker_two_logits(ranker_two_logits, docs):
|
||||||
assert isinstance(ranker_two_logits, BaseRanker)
|
assert isinstance(ranker_two_logits, BaseRanker)
|
||||||
assert isinstance(ranker_two_logits, SentenceTransformersRanker)
|
assert isinstance(ranker_two_logits, SentenceTransformersRanker)
|
||||||
@ -227,9 +366,7 @@ def test_ranker_returns_raw_score_for_two_logits(ranker_two_logits):
|
|||||||
|
|
||||||
def test_predict_batch_returns_correct_number_of_docs(ranker):
    """``top_k=4`` must limit both ``predict`` and ``predict_batch`` to four of the five documents."""
    question = "where is test 3?"
    docs = [Document(content=f"test {number}") for number in range(5)]

    single_result = ranker.predict(question, docs, top_k=4)
    assert len(single_result) == 4

    batch_result = ranker.predict_batch([question], docs, batch_size=2, top_k=4)
    assert len(batch_result) == 4
|
||||||
|
|
||||||
|
|
||||||
@ -251,6 +388,32 @@ def test_cohere_ranker(docs, mock_cohere_post):
|
|||||||
assert results[0] == docs[4]
|
assert results[0] == docs[4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_cohere_ranker_with_embed_meta_fields(docs, mock_cohere_post):
    """
    With ``embed_meta_fields=["name"]``, ``predict`` must post texts that carry the document name on a line
    before the content (content only when there is no name) and still return the original ``docs[4]`` first.
    """
    query = "What is the most important building in King's Landing that has a religious background?"
    ranker = CohereRanker(api_key="fake_key", model_name_or_path="rerank-english-v2.0", embed_meta_fields=["name"])
    results = ranker.predict(query=query, documents=docs)

    def expected_text(doc):
        # Prep expected input: name and content separated by a newline, or just the content.
        name = doc.meta.get("name")
        return (name + "\n" + doc.content) if name else doc.content

    expected_payload = {
        "model": "rerank-english-v2.0",
        "query": query,
        "documents": [{"text": expected_text(d)} for d in docs],
        "top_n": None,  # By passing None we return all documents and use top_k to truncate later
        "return_documents": False,
        "max_chunks_per_doc": None,
    }
    mock_cohere_post.assert_called_once_with(expected_payload)
    assert results[0] == docs[4]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post):
|
def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post):
|
||||||
query = "What is the most important building in King's Landing that has a religious background?"
|
query = "What is the most important building in King's Landing that has a religious background?"
|
||||||
@ -269,6 +432,31 @@ def test_cohere_ranker_batch_single_query_single_doc_list(docs, mock_cohere_post
|
|||||||
assert results[0] == docs[4]
|
assert results[0] == docs[4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_cohere_ranker_batch_single_query_single_doc_list_with_embed_meta_fields(docs, mock_cohere_post):
    """
    With ``embed_meta_fields=["name"]``, ``predict_batch`` for one query and one flat document list must post
    texts that carry the document name on a line before the content (content only when there is no name) and
    still return the original ``docs[4]`` first.
    """
    query = "What is the most important building in King's Landing that has a religious background?"
    ranker = CohereRanker(api_key="fake_key", model_name_or_path="rerank-english-v2.0", embed_meta_fields=["name"])
    results = ranker.predict_batch(queries=[query], documents=docs)

    def expected_text(doc):
        # Name and content separated by a newline, or just the content when the name is missing or empty.
        name = doc.meta.get("name")
        return (name + "\n" + doc.content) if name else doc.content

    expected_payload = {
        "model": "rerank-english-v2.0",
        "query": query,
        "documents": [{"text": expected_text(d)} for d in docs],
        "top_n": None,  # By passing None we return all documents and use top_k to truncate later
        "return_documents": False,
        "max_chunks_per_doc": None,
    }
    mock_cohere_post.assert_called_once_with(expected_payload)
    assert results[0] == docs[4]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_cohere_ranker_batch_single_query_multiple_doc_lists(docs, mock_cohere_post):
|
def test_cohere_ranker_batch_single_query_multiple_doc_lists(docs, mock_cohere_post):
|
||||||
query = "What is the most important building in King's Landing that has a religious background?"
|
query = "What is the most important building in King's Landing that has a religious background?"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user