diff --git a/haystack/preview/components/samplers/__init__.py b/haystack/preview/components/samplers/__init__.py
new file mode 100644
index 000000000..cab0e878e
--- /dev/null
+++ b/haystack/preview/components/samplers/__init__.py
@@ -0,0 +1,3 @@
+from haystack.preview.components.samplers.top_p import TopPSampler
+
+__all__ = ["TopPSampler"]
diff --git a/haystack/preview/components/samplers/top_p.py b/haystack/preview/components/samplers/top_p.py
new file mode 100644
index 000000000..f0c9d6d99
--- /dev/null
+++ b/haystack/preview/components/samplers/top_p.py
@@ -0,0 +1,140 @@
+import logging
+from typing import List, Optional, Dict, Any
+
+from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
+from haystack.preview.lazy_imports import LazyImport
+
+logger = logging.getLogger(__name__)
+
+
+with LazyImport(message="Run 'pip install torch>=1.13'") as torch_import:
+    import torch
+
+
+@component
+class TopPSampler:
+    """
+    Filters documents using top-p (nucleus) sampling based on their similarity scores' cumulative probability.
+
+    Usage example:
+
+    ```python
+    from haystack.preview import Document
+    from haystack.preview.components.samplers import TopPSampler
+
+    sampler = TopPSampler(top_p=0.95)
+    docs = [
+        Document(text="Berlin", metadata={"similarity_score": -10.6}),
+        Document(text="Belgrade", metadata={"similarity_score": -8.9}),
+        Document(text="Sarajevo", metadata={"similarity_score": -4.6}),
+    ]
+    output = sampler.run(documents=docs)
+    docs = output["documents"]
+    assert len(docs) == 1
+    assert docs[0].text == "Sarajevo"
+    ```
+    """
+
+    def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None):
+        """
+        Creates an instance of TopPSampler.
+
+        :param top_p: Cumulative probability threshold (usually between 0.9 and 0.99).
+        :param score_field: Field name in a document's metadata containing the scores. Defaults to the Document score
+            if not provided.
+        """
+        torch_import.check()
+
+        self.top_p = top_p
+        self.score_field = score_field
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(self, top_p=self.top_p, score_field=self.score_field)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TopPSampler":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, documents: List[Document], top_p: Optional[float] = None):
+        """
+        Filter documents based on their similarity scores using top-p sampling.
+
+        :param documents: List of Documents to filter.
+        :param top_p: Cumulative probability threshold. Defaults to the value set during initialization or 1.0
+            if not set.
+        :return: List of filtered Documents.
+        """
+        if not documents:
+            return {"documents": []}
+
+        top_p = top_p or self.top_p or 1.0  # default to 1.0 if both are None
+
+        if not 0 <= top_p <= 1:
+            raise ComponentError(f"top_p must be between 0 and 1. Got {top_p}.")
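+
+        # Note: the raw similarity scores can be any real numbers (the usage example in the class docstring
+        # uses negative values); the softmax below turns them into a probability distribution, so only the
+        # relative differences between scores affect which documents are selected.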
Got {top_p}.") + + similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32) + + # Apply softmax normalization to the similarity scores + probs = torch.nn.functional.softmax(similarity_scores, dim=-1) + + # Sort the probabilities and calculate their cumulative sum + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + + # Check if the cumulative probabilities are close to top_p with a 1e-6 tolerance + close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p, device=cumulative_probs.device), atol=1e-6) + + # Combine the close_to_top_p with original condition using logical OR + condition = (cumulative_probs <= top_p) | close_to_top_p + + # Find the indices with cumulative probabilities that exceed top_p + top_p_indices = torch.where(torch.BoolTensor(condition))[0] + + # Map the selected indices back to their original indices + original_indices = sorted_indices[top_p_indices] + selected_docs = [documents[i.item()] for i in original_indices] + + # If low p resulted in no documents being selected, then + # return at least one document + if not selected_docs: + logger.warning( + "Top-p sampling with p=%s resulted in no documents being selected. " + "Returning the document with the highest similarity score.", + top_p, + ) + highest_prob_indices = torch.argsort(probs, descending=True) + selected_docs = [documents[int(highest_prob_indices[0].item())]] + + return {"documents": selected_docs} + + def _collect_scores(self, documents: List[Document]) -> List[float]: + """ + Collect the scores from the documents' metadata. + :param documents: List of Documents. + :return: List of scores. + """ + if self.score_field: + missing_scores_docs = [d for d in documents if self.score_field not in d.metadata] + if missing_scores_docs: + missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id] + raise ComponentError( + f"Score field '{self.score_field}' not found in metadata of documents " + f"with IDs: {missing_scores_docs_ids}." + f"Make sure that all documents have a score field '{self.score_field}' in their metadata." + ) + return [d.metadata[self.score_field] for d in documents] + else: + missing_scores_docs = [d for d in documents if d.score is None] + if missing_scores_docs: + missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id] + raise ComponentError( + f"Ensure all documents have a valid score value. These docs {missing_scores_docs_ids} don't." + ) + return [d.score for d in documents] # type: ignore ## because Document score is Optional diff --git a/releasenotes/notes/add-top-p-sampler-ad6e0f5d623a6bb5.yaml b/releasenotes/notes/add-top-p-sampler-ad6e0f5d623a6bb5.yaml new file mode 100644 index 000000000..729c8b624 --- /dev/null +++ b/releasenotes/notes/add-top-p-sampler-ad6e0f5d623a6bb5.yaml @@ -0,0 +1,4 @@ +--- +preview: + - | + Adds TopPSampler, a component selects documents based on the cumulative probability of the Document scores using top p (nucleus) sampling. 
diff --git a/test/preview/components/samplers/__init__.py b/test/preview/components/samplers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/preview/components/samplers/test_top_p.py b/test/preview/components/samplers/test_top_p.py
new file mode 100644
index 000000000..111dffa25
--- /dev/null
+++ b/test/preview/components/samplers/test_top_p.py
@@ -0,0 +1,107 @@
+import random
+
+import pytest
+
+from haystack.preview import Document, ComponentError
+from haystack.preview.components.samplers.top_p import TopPSampler
+
+
+class TestTopPSampler:
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = TopPSampler()
+        data = component.to_dict()
+        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 1.0, "score_field": None}}
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        component = TopPSampler(top_p=0.92)
+        data = component.to_dict()
+        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 0.92, "score_field": None}}
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {"type": "TopPSampler", "init_parameters": {"top_p": 0.9, "score_field": None}}
+        component = TopPSampler.from_dict(data)
+        assert component.top_p == 0.9
+
+    @pytest.mark.unit
+    def test_run_scores_from_metadata(self):
+        """
+        Test if the component runs correctly with scores already in the metadata.
+        """
+        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
+        docs = [
+            Document(text="Berlin", metadata={"similarity_score": -10.6}),
+            Document(text="Belgrade", metadata={"similarity_score": -8.9}),
+            Document(text="Sarajevo", metadata={"similarity_score": -4.6}),
+        ]
+        output = sampler.run(documents=docs)
+        docs = output["documents"]
+        assert len(docs) == 1
+        assert docs[0].text == "Sarajevo"
+
+    @pytest.mark.unit
+    def test_run_scores(self):
+        """
+        Test if the component runs correctly with scores in the Document score field.
+        """
+        sampler = TopPSampler(top_p=0.99)
+        docs = [
+            Document(text="Berlin", score=-10.6),
+            Document(text="Belgrade", score=-8.9),
+            Document(text="Sarajevo", score=-4.6),
+        ]
+
+        random.shuffle(docs)
+        sorted_scores = sorted([doc.score for doc in docs], reverse=True)
+
+        # top_p = 0.99 will get the top 1 document
+        output = sampler.run(documents=docs)
+        docs_filtered = output["documents"]
+        assert len(docs_filtered) == 1
+        assert docs_filtered[0].text == "Sarajevo"
+
+        assert [doc.score for doc in docs_filtered] == sorted_scores[:1]
+
+    @pytest.mark.unit
+    def test_run_scores_top_p_1(self):
+        """
+        Test if the component runs correctly with top_p=1.
+        """
+        sampler = TopPSampler(top_p=1.0)
+        docs = [
+            Document(text="Berlin", score=-10.6),
+            Document(text="Belgrade", score=-8.9),
+            Document(text="Sarajevo", score=-4.6),
+        ]
+
+        random.shuffle(docs)
+        output = sampler.run(documents=docs)
+        docs_filtered = output["documents"]
+        assert len(docs_filtered) == len(docs)
+        assert docs_filtered[0].text == "Sarajevo"
+
+        assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True)
+
+    # Returns an empty list if no documents are provided
+    @pytest.mark.unit
+    def test_returns_empty_list_if_no_documents_are_provided(self):
+        sampler = TopPSampler()
+        output = sampler.run(documents=[])
+        assert output["documents"] == []
+
+    @pytest.mark.unit
+    def test_run_scores_no_metadata_present(self):
+        """
+        Test that the component raises an error when the score_field is specified but the scores are missing
+        from the documents' metadata.
+ """ + sampler = TopPSampler(top_p=0.95, score_field="similarity_score") + docs = [ + Document(text="Berlin", score=-10.6), + Document(text="Belgrade", score=-8.9), + Document(text="Sarajevo", score=-4.6), + ] + with pytest.raises(ComponentError, match="Score field 'similarity_score' not found"): + sampler.run(documents=docs)