mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
feat: Add Lost In The Middle Ranker (#6995)
* add lost in the middle ranker * update * add release notes * update release notes * fix mypy * Update * fix mypy * fix mypy [union-attr] for content.split * remove e2e tests and negative topk param * remove query param, validate params --------- Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
parent
327c2d260d
commit
b335b5d723
@ -1,4 +1,5 @@
|
||||
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
|
||||
from haystack.components.rankers.meta_field import MetaFieldRanker
|
||||
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
|
||||
|
||||
__all__ = ["MetaFieldRanker", "TransformersSimilarityRanker"]
|
||||
__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"]
|
||||
|
||||
109
haystack/components/rankers/lost_in_the_middle.py
Normal file
109
haystack/components/rankers/lost_in_the_middle.py
Normal file
@ -0,0 +1,109 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from haystack import Document, component, default_to_dict
|
||||
|
||||
|
||||
@component
|
||||
class LostInTheMiddleRanker:
|
||||
"""
|
||||
The LostInTheMiddleRanker implements a ranker that reorders documents based on the "lost in the middle" order.
|
||||
"Lost in the Middle: How Language Models Use Long Contexts" paper by Liu et al. aims to lay out paragraphs into LLM
|
||||
context so that the relevant paragraphs are at the beginning or end of the input context, while the least relevant
|
||||
information is in the middle of the context.
|
||||
|
||||
See https://arxiv.org/abs/2307.03172 for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[int] = None):
|
||||
"""
|
||||
If 'word_count_threshold' is specified, this ranker includes all documents up until the point where adding
|
||||
another document would exceed the 'word_count_threshold'. The last document that causes the threshold to
|
||||
be breached will be included in the resulting list of documents, but all subsequent documents will be
|
||||
discarded.
|
||||
|
||||
:param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
|
||||
:param top_k: The maximum number of documents to return.
|
||||
"""
|
||||
if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
|
||||
raise ValueError(
|
||||
f"Invalid value for word_count_threshold: {word_count_threshold}. " f"word_count_threshold must be > 0."
|
||||
)
|
||||
if isinstance(top_k, int) and top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
|
||||
self.word_count_threshold = word_count_threshold
|
||||
self.top_k = top_k
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize object to a dictionary.
|
||||
"""
|
||||
return default_to_dict(self, word_count_threshold=self.word_count_threshold, top_k=self.top_k)
|
||||
|
||||
def run(
|
||||
self, documents: List[Document], top_k: Optional[int] = None, word_count_threshold: Optional[int] = None
|
||||
) -> Dict[str, List[Document]]:
|
||||
"""
|
||||
Reranks documents based on the "lost in the middle" order.
|
||||
Returns a list of Documents reordered based on the input query.
|
||||
:param documents: List of Documents to reorder.
|
||||
:param top_k: The number of documents to return.
|
||||
:param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
|
||||
|
||||
:return: The reordered documents.
|
||||
"""
|
||||
if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
|
||||
raise ValueError(
|
||||
f"Invalid value for word_count_threshold: {word_count_threshold}. " f"word_count_threshold must be > 0."
|
||||
)
|
||||
if isinstance(top_k, int) and top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
|
||||
if not documents:
|
||||
return {"documents": []}
|
||||
|
||||
top_k = top_k or self.top_k
|
||||
word_count_threshold = word_count_threshold or self.word_count_threshold
|
||||
|
||||
documents_to_reorder = documents[:top_k] if top_k else documents
|
||||
|
||||
# If there's only one document, return it as is
|
||||
if len(documents_to_reorder) == 1:
|
||||
return {"documents": documents_to_reorder}
|
||||
|
||||
# Raise an error if any document is not textual
|
||||
if any(not doc.content_type == "text" for doc in documents_to_reorder):
|
||||
raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.")
|
||||
|
||||
# Initialize word count and indices for the "lost in the middle" order
|
||||
word_count = 0
|
||||
document_index = list(range(len(documents_to_reorder)))
|
||||
lost_in_the_middle_indices = [0]
|
||||
|
||||
# If word count threshold is set and the first document has content, calculate word count for the first document
|
||||
if word_count_threshold and documents_to_reorder[0].content:
|
||||
word_count = len(documents_to_reorder[0].content.split())
|
||||
|
||||
# If the first document already meets the word count threshold, return it
|
||||
if word_count >= word_count_threshold:
|
||||
return {"documents": [documents_to_reorder[0]]}
|
||||
|
||||
# Start from the second document and create "lost in the middle" order
|
||||
for doc_idx in document_index[1:]:
|
||||
# Calculate the index at which the current document should be inserted
|
||||
insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2
|
||||
|
||||
# Insert the document index at the calculated position
|
||||
lost_in_the_middle_indices.insert(insertion_index, doc_idx)
|
||||
|
||||
# If word count threshold is set and the document has content, calculate the total word count
|
||||
if word_count_threshold and documents_to_reorder[doc_idx].content:
|
||||
word_count += len(documents_to_reorder[doc_idx].content.split()) # type: ignore[union-attr]
|
||||
|
||||
# If the total word count meets the threshold, stop processing further documents
|
||||
if word_count >= word_count_threshold:
|
||||
break
|
||||
|
||||
# Documents in the "lost in the middle" order
|
||||
ranked_docs = [documents_to_reorder[idx] for idx in lost_in_the_middle_indices]
|
||||
return {"documents": ranked_docs}
|
||||
@ -0,0 +1,8 @@
|
||||
---
|
||||
|
||||
features:
|
||||
- |
|
||||
Add LostInTheMiddleRanker.
|
||||
It reorders documents based on the "Lost in the Middle" order, a strategy that
|
||||
places the most relevant paragraphs at the beginning or end of the context,
|
||||
while less relevant paragraphs are positioned in the middle.
|
||||
104
test/components/rankers/test_lost_in_the_middle.py
Normal file
104
test/components/rankers/test_lost_in_the_middle.py
Normal file
@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
from haystack import Document
|
||||
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
|
||||
|
||||
|
||||
class TestLostInTheMiddleRanker:
|
||||
def test_lost_in_the_middle_order_odd(self):
|
||||
# tests that lost_in_the_middle order works with an odd number of documents
|
||||
docs = [Document(content=str(i)) for i in range(1, 10)]
|
||||
ranker = LostInTheMiddleRanker()
|
||||
result = ranker.run(documents=docs)
|
||||
assert result["documents"]
|
||||
expected_order = "1 3 5 7 9 8 6 4 2".split()
|
||||
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
|
||||
|
||||
def test_lost_in_the_middle_order_even(self):
|
||||
# tests that lost_in_the_middle order works with an even number of documents
|
||||
docs = [Document(content=str(i)) for i in range(1, 11)]
|
||||
ranker = LostInTheMiddleRanker()
|
||||
result = ranker.run(documents=docs)
|
||||
expected_order = "1 3 5 7 9 10 8 6 4 2".split()
|
||||
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
|
||||
|
||||
def test_lost_in_the_middle_order_two_docs(self):
|
||||
# tests that lost_in_the_middle order works with two documents
|
||||
ranker = LostInTheMiddleRanker()
|
||||
# two docs
|
||||
docs = [Document(content="1"), Document(content="2")]
|
||||
result = ranker.run(documents=docs)
|
||||
assert result["documents"][0].content == "1"
|
||||
assert result["documents"][1].content == "2"
|
||||
|
||||
def test_lost_in_the_middle_init(self):
|
||||
# tests that LostInTheMiddleRanker initializes with default values
|
||||
ranker = LostInTheMiddleRanker()
|
||||
assert ranker.word_count_threshold is None
|
||||
|
||||
ranker = LostInTheMiddleRanker(word_count_threshold=10)
|
||||
assert ranker.word_count_threshold == 10
|
||||
|
||||
def test_lost_in_the_middle_init_invalid_word_count_threshold(self):
|
||||
# tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0
|
||||
with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
|
||||
LostInTheMiddleRanker(word_count_threshold=0)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
|
||||
LostInTheMiddleRanker(word_count_threshold=-5)
|
||||
|
||||
def test_lost_in_the_middle_with_word_count_threshold(self):
|
||||
# tests that lost_in_the_middle with word_count_threshold works as expected
|
||||
ranker = LostInTheMiddleRanker(word_count_threshold=6)
|
||||
docs = [Document(content="word" + str(i)) for i in range(1, 10)]
|
||||
# result, _ = ranker.run(query="", documents=docs)
|
||||
result = ranker.run(documents=docs)
|
||||
expected_order = "word1 word3 word5 word6 word4 word2".split()
|
||||
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
|
||||
|
||||
ranker = LostInTheMiddleRanker(word_count_threshold=9)
|
||||
# result, _ = ranker.run(query="", documents=docs)
|
||||
result = ranker.run(documents=docs)
|
||||
expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
|
||||
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
|
||||
|
||||
def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents(self):
|
||||
ranker = LostInTheMiddleRanker(word_count_threshold=100)
|
||||
docs = [Document(content="word" + str(i)) for i in range(1, 10)]
|
||||
ordered_docs = ranker.run(documents=docs)
|
||||
# assert len(ordered_docs) == len(docs)
|
||||
expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
|
||||
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs["documents"]))
|
||||
|
||||
def test_empty_documents_returns_empty_list(self):
|
||||
ranker = LostInTheMiddleRanker()
|
||||
result = ranker.run(documents=[])
|
||||
assert result == {"documents": []}
|
||||
|
||||
def test_list_of_one_document_returns_same_document(self):
|
||||
ranker = LostInTheMiddleRanker()
|
||||
doc = Document(content="test")
|
||||
assert ranker.run(documents=[doc]) == {"documents": [doc]}
|
||||
|
||||
@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20])
|
||||
def test_lost_in_the_middle_order_with_top_k(self, top_k: int):
|
||||
# tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter
|
||||
docs = [Document(content=str(i)) for i in range(1, 10)]
|
||||
ranker = LostInTheMiddleRanker()
|
||||
result = ranker.run(documents=docs, top_k=top_k)
|
||||
if top_k < len(docs):
|
||||
# top_k is less than the number of documents, so only the top_k documents should be returned in LITM order
|
||||
assert len(result["documents"]) == top_k
|
||||
expected_order = ranker.run(documents=[Document(content=str(i)) for i in range(1, top_k + 1)])
|
||||
assert result == expected_order
|
||||
else:
|
||||
# top_k is greater than the number of documents, so all documents should be returned in LITM order
|
||||
assert len(result["documents"]) == len(docs)
|
||||
assert result == ranker.run(documents=docs)
|
||||
|
||||
def test_to_dict(self):
|
||||
component = LostInTheMiddleRanker()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.rankers.lost_in_the_middle.LostInTheMiddleRanker",
|
||||
"init_parameters": {"word_count_threshold": None, "top_k": None},
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user