Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-05 11:38:20 +00:00
feat: Add TopPSampler Haystack 2.0 component (#5924)
This commit is contained in:
parent 0cb9abb1c2
commit 40b83d8a47
3  haystack/preview/components/samplers/__init__.py  (new file)
@@ -0,0 +1,3 @@
from haystack.preview.components.samplers.top_p import TopPSampler

__all__ = ["TopPSampler"]
140  haystack/preview/components/samplers/top_p.py  (new file)
@@ -0,0 +1,140 @@
import logging
from typing import List, Optional, Dict, Any

from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.preview.lazy_imports import LazyImport

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install torch>=1.13'") as torch_import:
    import torch

@component
class TopPSampler:
    """
    Filters documents using top-p (nucleus) sampling based on their similarity scores' cumulative probability.

    Usage example:

    ```python
    from haystack.preview import Document
    from haystack.preview.components.samplers import TopPSampler

    sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
    docs = [
        Document(text="Berlin", metadata={"similarity_score": -10.6}),
        Document(text="Belgrade", metadata={"similarity_score": -8.9}),
        Document(text="Sarajevo", metadata={"similarity_score": -4.6}),
    ]
    output = sampler.run(documents=docs)
    docs = output["documents"]
    assert len(docs) == 1
    assert docs[0].text == "Sarajevo"
    ```
    """

    def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None):
        """
        Creates an instance of TopPSampler.

        :param top_p: Cumulative probability threshold (usually between 0.9 and 0.99).
        :param score_field: Field name in a document's metadata containing the scores. Defaults to the Document score
            if not provided.
        """
        torch_import.check()

        self.top_p = top_p
        self.score_field = score_field

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(self, top_p=self.top_p, score_field=self.score_field)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TopPSampler":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], top_p: Optional[float] = None):
        """
        Filter documents based on their similarity scores using top-p sampling.

        :param documents: List of Documents to filter.
        :param top_p: Cumulative probability threshold. Defaults to the value set during initialization, or to 1.0
            if neither is set.
        :return: List of filtered Documents.
        """
        if not documents:
            return {"documents": []}

        top_p = top_p or self.top_p or 1.0  # default to 1.0 if both are None

        if not 0 <= top_p <= 1:
            raise ComponentError(f"top_p must be between 0 and 1. Got {top_p}.")

        similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32)

        # Apply softmax normalization to the similarity scores
        probs = torch.nn.functional.softmax(similarity_scores, dim=-1)

        # Sort the probabilities and calculate their cumulative sum
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Check if the cumulative probabilities are close to top_p with a 1e-6 tolerance
        close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p, device=cumulative_probs.device), atol=1e-6)

        # Combine close_to_top_p with the threshold condition using a logical OR
        condition = (cumulative_probs <= top_p) | close_to_top_p

        # Find the indices whose cumulative probability stays within top_p
        # (condition is already a boolean tensor, so it can be passed to torch.where directly)
        top_p_indices = torch.where(condition)[0]

        # Map the selected indices back to their original indices
        original_indices = sorted_indices[top_p_indices]
        selected_docs = [documents[i.item()] for i in original_indices]

        # If a low top_p resulted in no documents being selected,
        # return at least one document
        if not selected_docs:
            logger.warning(
                "Top-p sampling with p=%s resulted in no documents being selected. "
                "Returning the document with the highest similarity score.",
                top_p,
            )
            highest_prob_indices = torch.argsort(probs, descending=True)
            selected_docs = [documents[int(highest_prob_indices[0].item())]]

        return {"documents": selected_docs}

    def _collect_scores(self, documents: List[Document]) -> List[float]:
        """
        Collect the scores from the documents' metadata.
        :param documents: List of Documents.
        :return: List of scores.
        """
        if self.score_field:
            missing_scores_docs = [d for d in documents if self.score_field not in d.metadata]
            if missing_scores_docs:
                missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
                raise ComponentError(
                    f"Score field '{self.score_field}' not found in metadata of documents "
                    f"with IDs: {missing_scores_docs_ids}. "
                    f"Make sure that all documents have a score field '{self.score_field}' in their metadata."
                )
            return [d.metadata[self.score_field] for d in documents]
        else:
            missing_scores_docs = [d for d in documents if d.score is None]
            if missing_scores_docs:
                missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
                raise ComponentError(
                    f"Ensure all documents have a valid score value. These docs {missing_scores_docs_ids} don't."
                )
            return [d.score for d in documents]  # type: ignore  # Document score is Optional
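For reference, here is a minimal, self-contained sketch of the nucleus selection performed in run() above: softmax over the raw scores, a descending sort, a cumulative sum, and a threshold on that sum. It reuses the example scores from the class docstring and is illustrative only; the component itself also applies a 1e-6 tolerance around top_p and falls back to the single best document when nothing passes the threshold.

```python
import torch

# Example similarity scores from the docstring above (Berlin, Belgrade, Sarajevo)
scores = torch.tensor([-10.6, -8.9, -4.6])
top_p = 0.99

# Softmax turns raw scores into probabilities: roughly [0.002, 0.013, 0.984]
probs = torch.nn.functional.softmax(scores, dim=-1)

# Sort descending and accumulate: roughly [0.984, 0.998, 1.000]
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

# Keep the entries whose cumulative probability stays within top_p
keep = torch.where(cumulative_probs <= top_p)[0]
selected = sorted_indices[keep].tolist()
print(selected)  # [2] -> only the highest-scoring document (Sarajevo) survives
```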
@@ -0,0 +1,4 @@
---
preview:
  - |
    Adds TopPSampler, a component that selects documents based on the cumulative probability of the Document scores, using top-p (nucleus) sampling.
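When Documents already carry their score in the score field (rather than in metadata), no score_field is needed. A short standalone sketch, mirroring the unit tests below:

```python
from haystack.preview import Document
from haystack.preview.components.samplers import TopPSampler

# Documents scored via the `score` field instead of metadata
docs = [
    Document(text="Berlin", score=-10.6),
    Document(text="Belgrade", score=-8.9),
    Document(text="Sarajevo", score=-4.6),
]

# With top_p=1.0 every document is kept, returned in descending score order
sampler = TopPSampler(top_p=1.0)
output = sampler.run(documents=docs)
print([doc.text for doc in output["documents"]])  # ['Sarajevo', 'Belgrade', 'Berlin']
```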
0  test/preview/components/samplers/__init__.py  (new file)
107  test/preview/components/samplers/test_top_p.py  (new file)
@@ -0,0 +1,107 @@
import random

import pytest

from haystack.preview import Document, ComponentError
from haystack.preview.components.samplers.top_p import TopPSampler


class TestTopPSampler:
    @pytest.mark.unit
    def test_to_dict(self):
        component = TopPSampler()
        data = component.to_dict()
        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 1.0, "score_field": None}}

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        component = TopPSampler(top_p=0.92)
        data = component.to_dict()
        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 0.92, "score_field": None}}

    @pytest.mark.unit
    def test_from_dict(self):
        data = {"type": "TopPSampler", "init_parameters": {"top_p": 0.9, "score_field": None}}
        component = TopPSampler.from_dict(data)
        assert component.top_p == 0.9

    @pytest.mark.unit
    def test_run_scores_from_metadata(self):
        """
        Test if the component runs correctly with scores already in the metadata.
        """
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        docs = [
            Document(text="Berlin", metadata={"similarity_score": -10.6}),
            Document(text="Belgrade", metadata={"similarity_score": -8.9}),
            Document(text="Sarajevo", metadata={"similarity_score": -4.6}),
        ]
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 1
        assert docs[0].text == "Sarajevo"

    @pytest.mark.unit
    def test_run_scores(self):
        """
        Test if the component runs correctly with scores in the Document score field.
        """
        sampler = TopPSampler(top_p=0.99)
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Belgrade", score=-8.9),
            Document(text="Sarajevo", score=-4.6),
        ]

        random.shuffle(docs)
        sorted_scores = sorted([doc.score for doc in docs], reverse=True)

        # top_p = 0.99 will get the top 1 document
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == 1
        assert docs_filtered[0].text == "Sarajevo"

        assert [doc.score for doc in docs_filtered] == sorted_scores[:1]

    @pytest.mark.unit
    def test_run_scores_top_p_1(self):
        """
        Test if the component runs correctly with top_p=1.
        """
        sampler = TopPSampler(top_p=1.0)
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Belgrade", score=-8.9),
            Document(text="Sarajevo", score=-4.6),
        ]

        random.shuffle(docs)
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == len(docs)
        assert docs_filtered[0].text == "Sarajevo"

        assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True)

    # Returns an empty list if no documents are provided
    @pytest.mark.unit
    def test_returns_empty_list_if_no_documents_are_provided(self):
        sampler = TopPSampler()
        output = sampler.run(documents=[])
        assert output["documents"] == []

    @pytest.mark.unit
    def test_run_scores_no_metadata_present(self):
        """
        Test if the component raises an error when a score_field is specified but the scores are missing
        from the metadata.
        """
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Belgrade", score=-8.9),
            Document(text="Sarajevo", score=-4.6),
        ]
        with pytest.raises(ComponentError, match="Score field 'similarity_score' not found"):
            sampler.run(documents=docs)