feat: Add TopPSampler Haystack 2.0 component (#5924)

Vladimir Blagojevic 2023-10-09 13:44:01 +02:00 committed by GitHub
parent 0cb9abb1c2
commit 40b83d8a47
5 changed files with 254 additions and 0 deletions


@@ -0,0 +1,3 @@
from haystack.preview.components.samplers.top_p import TopPSampler
__all__ = ["TopPSampler"]


@@ -0,0 +1,140 @@
import logging
from typing import List, Optional, Dict, Any

from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.preview.lazy_imports import LazyImport

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install torch>=1.13'") as torch_import:
    import torch


@component
class TopPSampler:
    """
    Filters documents using top-p (nucleus) sampling, based on the cumulative probability of their similarity scores.

    Usage example:

    ```python
    from haystack.preview import Document
    from haystack.preview.components.samplers import TopPSampler

    sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
    docs = [
        Document(text="Berlin", metadata={"similarity_score": -10.6}),
        Document(text="Belgrade", metadata={"similarity_score": -8.9}),
        Document(text="Sarajevo", metadata={"similarity_score": -4.6}),
    ]
    output = sampler.run(documents=docs)
    docs = output["documents"]
    assert len(docs) == 1
    assert docs[0].text == "Sarajevo"
    ```
    """

    def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None):
        """
        Creates an instance of TopPSampler.

        :param top_p: Cumulative probability threshold (usually between 0.9 and 0.99).
        :param score_field: Name of the field in a document's metadata that contains the score. If not provided,
            the Document score field is used.
        """
        torch_import.check()

        self.top_p = top_p
        self.score_field = score_field

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(self, top_p=self.top_p, score_field=self.score_field)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TopPSampler":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], top_p: Optional[float] = None):
        """
        Filter documents based on their similarity scores using top-p sampling.

        :param documents: List of Documents to filter.
        :param top_p: Cumulative probability threshold. Defaults to the value set at initialization, or to 1.0
            if neither is set.
        :return: List of filtered Documents.
        """
        if not documents:
            return {"documents": []}

        top_p = top_p or self.top_p or 1.0  # default to 1.0 if both are None

        if not 0 <= top_p <= 1:
            raise ComponentError(f"top_p must be between 0 and 1. Got {top_p}.")

        similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32)

        # Apply softmax normalization to the similarity scores
        probs = torch.nn.functional.softmax(similarity_scores, dim=-1)

        # Sort the probabilities and calculate their cumulative sum
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Check if the cumulative probabilities are close to top_p with a 1e-6 tolerance
        close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p, device=cumulative_probs.device), atol=1e-6)

        # Combine close_to_top_p with the threshold condition using logical OR
        condition = (cumulative_probs <= top_p) | close_to_top_p

        # Find the indices whose cumulative probability is still within top_p
        top_p_indices = torch.where(condition)[0]

        # Map the selected indices back to their original indices
        original_indices = sorted_indices[top_p_indices]
        selected_docs = [documents[i.item()] for i in original_indices]

        # If a low top_p resulted in no documents being selected, then
        # return at least one document
        if not selected_docs:
            logger.warning(
                "Top-p sampling with p=%s resulted in no documents being selected. "
                "Returning the document with the highest similarity score.",
                top_p,
            )
            highest_prob_indices = torch.argsort(probs, descending=True)
            selected_docs = [documents[int(highest_prob_indices[0].item())]]

        return {"documents": selected_docs}

    def _collect_scores(self, documents: List[Document]) -> List[float]:
        """
        Collect the scores from the documents' metadata.

        :param documents: List of Documents.
        :return: List of scores.
        """
        if self.score_field:
            missing_scores_docs = [d for d in documents if self.score_field not in d.metadata]
            if missing_scores_docs:
                missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
                raise ComponentError(
                    f"Score field '{self.score_field}' not found in metadata of documents "
                    f"with IDs: {missing_scores_docs_ids}. "
                    f"Make sure that all documents have a score field '{self.score_field}' in their metadata."
                )
            return [d.metadata[self.score_field] for d in documents]
        else:
            missing_scores_docs = [d for d in documents if d.score is None]
            if missing_scores_docs:
                missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
                raise ComponentError(
                    f"Ensure all documents have a valid score value. These docs {missing_scores_docs_ids} don't."
                )
            return [d.score for d in documents]  # type: ignore  # because Document score is Optional
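The selection logic in `run` boils down to: softmax the scores, sort the resulting probabilities, and keep documents until the cumulative probability crosses `top_p`. Here is a minimal standalone sketch of that math, using made-up scores rather than `Document` objects (the values 4.0, 2.0, 1.0 and the 0.96 threshold are illustrative only):

```python
import torch

# Illustrative similarity scores (already in descending order here)
scores = torch.tensor([4.0, 2.0, 1.0])
probs = torch.nn.functional.softmax(scores, dim=-1)    # ~[0.844, 0.114, 0.042]

sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)  # ~[0.844, 0.958, 1.000]

top_p = 0.96
# Keep every entry whose cumulative probability is still within top_p
condition = cumulative_probs <= top_p                  # [True, True, False]
kept = sorted_indices[torch.where(condition)[0]]
print(kept.tolist())  # [0, 1] -> the two highest-scoring items survive
```

With the strongly negative scores from the docstring example, the first cumulative probability is already about 0.98 and exceeds the 0.95 threshold; that is exactly the case handled by the fallback branch at the end of `run`, which returns the single highest-scoring document.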


@@ -0,0 +1,4 @@
---
preview:
  - |
    Adds TopPSampler, a component that selects documents based on the cumulative probability of the Document scores, using top-p (nucleus) sampling.
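Besides the run-time usage shown in the component's docstring, the component can be serialized and restored through `to_dict`/`from_dict`. A small round-trip sketch (the dictionary shape mirrors the one asserted in the unit tests below):

```python
from haystack.preview.components.samplers import TopPSampler

sampler = TopPSampler(top_p=0.9, score_field="similarity_score")

# default_to_dict records the init parameters under "init_parameters"
data = sampler.to_dict()
# {'type': 'TopPSampler', 'init_parameters': {'top_p': 0.9, 'score_field': 'similarity_score'}}

restored = TopPSampler.from_dict(data)
assert restored.top_p == 0.9
assert restored.score_field == "similarity_score"
```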


@@ -0,0 +1,107 @@
import random

import pytest

from haystack.preview import Document, ComponentError
from haystack.preview.components.samplers.top_p import TopPSampler


class TestTopPSampler:
    @pytest.mark.unit
    def test_to_dict(self):
        component = TopPSampler()
        data = component.to_dict()
        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 1.0, "score_field": None}}

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        component = TopPSampler(top_p=0.92)
        data = component.to_dict()
        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 0.92, "score_field": None}}

    @pytest.mark.unit
    def test_from_dict(self):
        data = {"type": "TopPSampler", "init_parameters": {"top_p": 0.9, "score_field": None}}
        component = TopPSampler.from_dict(data)
        assert component.top_p == 0.9

    @pytest.mark.unit
    def test_run_scores_from_metadata(self):
        """
        Test if the component runs correctly with scores already in the metadata.
        """
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        docs = [
            Document(text="Berlin", metadata={"similarity_score": -10.6}),
            Document(text="Belgrade", metadata={"similarity_score": -8.9}),
            Document(text="Sarajevo", metadata={"similarity_score": -4.6}),
        ]
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 1
        assert docs[0].text == "Sarajevo"

    @pytest.mark.unit
    def test_run_scores(self):
        """
        Test if the component runs correctly with scores in the Document score field.
        """
        sampler = TopPSampler(top_p=0.99)
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Belgrade", score=-8.9),
            Document(text="Sarajevo", score=-4.6),
        ]
        random.shuffle(docs)
        sorted_scores = sorted([doc.score for doc in docs], reverse=True)

        # top_p = 0.99 will get the top 1 document
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == 1
        assert docs_filtered[0].text == "Sarajevo"
        assert [doc.score for doc in docs_filtered] == sorted_scores[:1]

    @pytest.mark.unit
    def test_run_scores_top_p_1(self):
        """
        Test if the component runs correctly with top_p=1.
        """
        sampler = TopPSampler(top_p=1.0)
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Belgrade", score=-8.9),
            Document(text="Sarajevo", score=-4.6),
        ]
        random.shuffle(docs)
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == len(docs)
        assert docs_filtered[0].text == "Sarajevo"
        assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True)

    # Returns an empty list if no documents are provided
    @pytest.mark.unit
    def test_returns_empty_list_if_no_documents_are_provided(self):
        sampler = TopPSampler()
        output = sampler.run(documents=[])
        assert output["documents"] == []

    @pytest.mark.unit
    def test_run_scores_no_metadata_present(self):
        """
        Test that the component raises an error when score_field is set but the scores are missing from the
        documents' metadata.
        """
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Belgrade", score=-8.9),
            Document(text="Sarajevo", score=-4.6),
        ]
        with pytest.raises(ComponentError, match="Score field 'similarity_score' not found"):
            sampler.run(documents=docs)
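One path the tests above do not exercise is the fallback in `run`: when `top_p` is so low that not even the first cumulative probability fits under it, the sampler still returns the single highest-scoring document instead of an empty list. A possible additional test in the same style (the test name and the 0.01 threshold are my own choices):

```python
    @pytest.mark.unit
    def test_run_falls_back_to_highest_scored_document(self):
        """
        With a very low top_p no cumulative probability fits under the threshold,
        so the sampler falls back to returning the single highest-scoring document.
        """
        sampler = TopPSampler(top_p=0.01)
        docs = [
            Document(text="Berlin", score=-10.6),
            Document(text="Sarajevo", score=-4.6),
        ]
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == 1
        assert docs_filtered[0].text == "Sarajevo"
```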