feat: Add Lost In The Middle Ranker (#6995)

* add lost in the middle ranker

* update

* add release notes

* update release notes

* fix mypy

* Update

* fix mypy

* fix mypy [union-attr] for content.split

* remove e2e tests and negative topk param

* remove query param, validate params

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
Varun Mathur 2024-02-21 00:25:41 +05:30 committed by GitHub
parent 327c2d260d
commit b335b5d723
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 223 additions and 1 deletions

View File

@ -1,4 +1,5 @@
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
from haystack.components.rankers.meta_field import MetaFieldRanker
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
__all__ = ["MetaFieldRanker", "TransformersSimilarityRanker"]
__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"]

View File

@ -0,0 +1,109 @@
from typing import Any, Dict, List, Optional
from haystack import Document, component, default_to_dict
@component
class LostInTheMiddleRanker:
"""
The LostInTheMiddleRanker implements a ranker that reorders documents based on the "lost in the middle" order.
"Lost in the Middle: How Language Models Use Long Contexts" paper by Liu et al. aims to lay out paragraphs into LLM
context so that the relevant paragraphs are at the beginning or end of the input context, while the least relevant
information is in the middle of the context.
See https://arxiv.org/abs/2307.03172 for more details.
"""
def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[int] = None):
"""
If 'word_count_threshold' is specified, this ranker includes all documents up until the point where adding
another document would exceed the 'word_count_threshold'. The last document that causes the threshold to
be breached will be included in the resulting list of documents, but all subsequent documents will be
discarded.
:param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
:param top_k: The maximum number of documents to return.
"""
if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
raise ValueError(
f"Invalid value for word_count_threshold: {word_count_threshold}. " f"word_count_threshold must be > 0."
)
if isinstance(top_k, int) and top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.word_count_threshold = word_count_threshold
self.top_k = top_k
def to_dict(self) -> Dict[str, Any]:
"""
Serialize object to a dictionary.
"""
return default_to_dict(self, word_count_threshold=self.word_count_threshold, top_k=self.top_k)
def run(
self, documents: List[Document], top_k: Optional[int] = None, word_count_threshold: Optional[int] = None
) -> Dict[str, List[Document]]:
"""
Reranks documents based on the "lost in the middle" order.
Returns a list of Documents reordered based on the input query.
:param documents: List of Documents to reorder.
:param top_k: The number of documents to return.
:param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
:return: The reordered documents.
"""
if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
raise ValueError(
f"Invalid value for word_count_threshold: {word_count_threshold}. " f"word_count_threshold must be > 0."
)
if isinstance(top_k, int) and top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
if not documents:
return {"documents": []}
top_k = top_k or self.top_k
word_count_threshold = word_count_threshold or self.word_count_threshold
documents_to_reorder = documents[:top_k] if top_k else documents
# If there's only one document, return it as is
if len(documents_to_reorder) == 1:
return {"documents": documents_to_reorder}
# Raise an error if any document is not textual
if any(not doc.content_type == "text" for doc in documents_to_reorder):
raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.")
# Initialize word count and indices for the "lost in the middle" order
word_count = 0
document_index = list(range(len(documents_to_reorder)))
lost_in_the_middle_indices = [0]
# If word count threshold is set and the first document has content, calculate word count for the first document
if word_count_threshold and documents_to_reorder[0].content:
word_count = len(documents_to_reorder[0].content.split())
# If the first document already meets the word count threshold, return it
if word_count >= word_count_threshold:
return {"documents": [documents_to_reorder[0]]}
# Start from the second document and create "lost in the middle" order
for doc_idx in document_index[1:]:
# Calculate the index at which the current document should be inserted
insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2
# Insert the document index at the calculated position
lost_in_the_middle_indices.insert(insertion_index, doc_idx)
# If word count threshold is set and the document has content, calculate the total word count
if word_count_threshold and documents_to_reorder[doc_idx].content:
word_count += len(documents_to_reorder[doc_idx].content.split()) # type: ignore[union-attr]
# If the total word count meets the threshold, stop processing further documents
if word_count >= word_count_threshold:
break
# Documents in the "lost in the middle" order
ranked_docs = [documents_to_reorder[idx] for idx in lost_in_the_middle_indices]
return {"documents": ranked_docs}

View File

@ -0,0 +1,8 @@
---
features:
- |
Add LostInTheMiddleRanker.
It reorders documents based on the "Lost in the Middle" order, a strategy that
places the most relevant paragraphs at the beginning or end of the context,
while less relevant paragraphs are positioned in the middle.

View File

@ -0,0 +1,104 @@
import pytest
from haystack import Document
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
class TestLostInTheMiddleRanker:
def test_lost_in_the_middle_order_odd(self):
# tests that lost_in_the_middle order works with an odd number of documents
docs = [Document(content=str(i)) for i in range(1, 10)]
ranker = LostInTheMiddleRanker()
result = ranker.run(documents=docs)
assert result["documents"]
expected_order = "1 3 5 7 9 8 6 4 2".split()
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
def test_lost_in_the_middle_order_even(self):
# tests that lost_in_the_middle order works with an even number of documents
docs = [Document(content=str(i)) for i in range(1, 11)]
ranker = LostInTheMiddleRanker()
result = ranker.run(documents=docs)
expected_order = "1 3 5 7 9 10 8 6 4 2".split()
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
def test_lost_in_the_middle_order_two_docs(self):
# tests that lost_in_the_middle order works with two documents
ranker = LostInTheMiddleRanker()
# two docs
docs = [Document(content="1"), Document(content="2")]
result = ranker.run(documents=docs)
assert result["documents"][0].content == "1"
assert result["documents"][1].content == "2"
def test_lost_in_the_middle_init(self):
# tests that LostInTheMiddleRanker initializes with default values
ranker = LostInTheMiddleRanker()
assert ranker.word_count_threshold is None
ranker = LostInTheMiddleRanker(word_count_threshold=10)
assert ranker.word_count_threshold == 10
def test_lost_in_the_middle_init_invalid_word_count_threshold(self):
# tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0
with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
LostInTheMiddleRanker(word_count_threshold=0)
with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
LostInTheMiddleRanker(word_count_threshold=-5)
def test_lost_in_the_middle_with_word_count_threshold(self):
# tests that lost_in_the_middle with word_count_threshold works as expected
ranker = LostInTheMiddleRanker(word_count_threshold=6)
docs = [Document(content="word" + str(i)) for i in range(1, 10)]
# result, _ = ranker.run(query="", documents=docs)
result = ranker.run(documents=docs)
expected_order = "word1 word3 word5 word6 word4 word2".split()
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
ranker = LostInTheMiddleRanker(word_count_threshold=9)
# result, _ = ranker.run(query="", documents=docs)
result = ranker.run(documents=docs)
expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents(self):
ranker = LostInTheMiddleRanker(word_count_threshold=100)
docs = [Document(content="word" + str(i)) for i in range(1, 10)]
ordered_docs = ranker.run(documents=docs)
# assert len(ordered_docs) == len(docs)
expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs["documents"]))
def test_empty_documents_returns_empty_list(self):
ranker = LostInTheMiddleRanker()
result = ranker.run(documents=[])
assert result == {"documents": []}
def test_list_of_one_document_returns_same_document(self):
ranker = LostInTheMiddleRanker()
doc = Document(content="test")
assert ranker.run(documents=[doc]) == {"documents": [doc]}
@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20])
def test_lost_in_the_middle_order_with_top_k(self, top_k: int):
# tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter
docs = [Document(content=str(i)) for i in range(1, 10)]
ranker = LostInTheMiddleRanker()
result = ranker.run(documents=docs, top_k=top_k)
if top_k < len(docs):
# top_k is less than the number of documents, so only the top_k documents should be returned in LITM order
assert len(result["documents"]) == top_k
expected_order = ranker.run(documents=[Document(content=str(i)) for i in range(1, top_k + 1)])
assert result == expected_order
else:
# top_k is greater than the number of documents, so all documents should be returned in LITM order
assert len(result["documents"]) == len(docs)
assert result == ranker.run(documents=docs)
def test_to_dict(self):
component = LostInTheMiddleRanker()
data = component.to_dict()
assert data == {
"type": "haystack.components.rankers.lost_in_the_middle.LostInTheMiddleRanker",
"init_parameters": {"word_count_threshold": None, "top_k": None},
}