Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-07-18 22:42:24 +00:00)

* Add feature to Top P Sampler
* Add release notes
* Fix zip call
* Fix mypy
* Restore doc string and make mypy happy hopefully
* Make mypy happy
* PR comment
* Revert change to make mypy happy
* Add back type ignore
* try to fix typing
* Update haystack/components/samplers/top_p.py
  Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
* Update haystack/components/samplers/top_p.py
---------
Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
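The feature referenced above is presumably the min_top_k parameter exercised by test_run_min_top_k at the bottom of this file. Based only on what these tests show (the constructor arguments and the dict returned by run()), a usage sketch might look like the following; the document contents and scores are the illustrative values used in the tests:

from haystack import Document
from haystack.components.samplers.top_p import TopPSampler

# Keep at least two documents even when the cumulative top-p mass
# is already covered by the single highest-scoring one.
sampler = TopPSampler(top_p=0.2, min_top_k=2)
docs = [
    Document(content="Sarajevo", score=0.7),
    Document(content="Belgrade", score=0.01),
    Document(content="Berlin", score=0.001),
]
result = sampler.run(documents=docs)
print([doc.content for doc in result["documents"]])  # per test_run_min_top_k: ["Sarajevo", "Belgrade"]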
132 lines · 5.2 KiB · Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import random
from typing import List

import pytest

from haystack import Document
from haystack.components.samplers.top_p import TopPSampler


@pytest.fixture
def documents_with_score_field() -> List[Document]:
    return [
        Document(content="Sarajevo", meta={"similarity_score": 0.7}),
        Document(content="Belgrade", meta={"similarity_score": 0.01}),
        Document(content="Berlin", meta={"similarity_score": 0.001}),
    ]


@pytest.fixture
def documents_with_score() -> List[Document]:
    return [
        Document(content="Sarajevo", score=0.7),
        Document(content="Belgrade", score=0.01),
        Document(content="Berlin", score=0.001),
    ]


class TestTopPSampler:
    def test_init_raises_value_error(self) -> None:
        with pytest.raises(ValueError):
            TopPSampler(top_p=2.0)

    def test_run_raises_value_error(self, documents_with_score: List[Document]) -> None:
        sampler = TopPSampler(top_p=0.95)
        with pytest.raises(ValueError):
            sampler.run(documents=documents_with_score, top_p=2.0)

    def test_run_score_field(self, documents_with_score_field: List[Document]) -> None:
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        docs = documents_with_score_field
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 2
        assert docs[0].content == "Sarajevo"
        assert docs[1].content == "Belgrade"

    def test_run_score_field_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
        sampler = TopPSampler(top_p=1.0, score_field="similarity_score")
        docs = [
            Document(content="Sarajevo", meta={"similarity_score": 0.7}),
            Document(content="Belgrade", meta={"similarity_score": 0.01}),
            Document(content="Berlin", meta={"similarity_score": None}),
        ]
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 2
        assert docs[0].content == "Sarajevo"
        assert docs[1].content == "Belgrade"
        assert "Score field" in caplog.text

    def test_run(self, documents_with_score: List[Document]) -> None:
        sampler = TopPSampler(top_p=0.99)
        docs = documents_with_score
        random.shuffle(docs)
        sorted_scores = sorted([doc.score for doc in docs], reverse=True)

        # top_p = 0.99 keeps the top two documents
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == 2
        assert docs_filtered[0].content == "Sarajevo"
        assert docs_filtered[1].content == "Belgrade"

        assert [doc.score for doc in docs_filtered] == sorted_scores[:2]

    def test_run_top_p_1(self, documents_with_score: List[Document]) -> None:
        sampler = TopPSampler(top_p=1.0)
        docs = documents_with_score
        random.shuffle(docs)
        output = sampler.run(documents=docs)
        docs_filtered = output["documents"]
        assert len(docs_filtered) == len(docs)
        assert docs_filtered[0].content == "Sarajevo"
        assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True)

    def test_run_top_p_0(self, caplog: pytest.LogCaptureFixture, documents_with_score: List[Document]) -> None:
        sampler = TopPSampler(top_p=0.0)
        docs = documents_with_score
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 1
        assert docs[0].content == "Sarajevo"
        assert "Top-p sampling with p=" in caplog.text

    def test_run_returns_empty_list_no_documents(self) -> None:
        sampler = TopPSampler()
        output = sampler.run(documents=[])
        assert output["documents"] == []

    def test_run_no_score_field(self, caplog: pytest.LogCaptureFixture, documents_with_score: List[Document]) -> None:
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        docs = documents_with_score
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 3
        assert docs[0].content == "Sarajevo"
        assert "Score field 'similarity_score' not found" in caplog.text

    def test_run_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
        sampler = TopPSampler(top_p=0.95)
        docs = [
            Document(content="Sarajevo", score=0.7),
            Document(content="Belgrade", score=0.01),
            Document(content="Berlin", score=None),
        ]
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 1
        assert docs[0].content == "Sarajevo"
        assert "Ensure all documents have a valid score value" in caplog.text

    def test_run_min_top_k(self, documents_with_score: List[Document]) -> None:
        sampler = TopPSampler(min_top_k=2, top_p=0.2)
        docs = documents_with_score
        output = sampler.run(documents=docs)
        docs = output["documents"]
        assert len(docs) == 2
        assert docs[0].content == "Sarajevo"
        assert docs[1].content == "Belgrade"
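For readers of these tests, below is a minimal sketch of the selection rule the assertions appear to assume: a softmax over the document scores, a descending sort, documents kept while the cumulative probability stays within top_p, and never fewer than max(1, min_top_k) documents returned. It is an illustration consistent with the expected outputs above, not the component's actual implementation, which lives in haystack/components/samplers/top_p.py and may differ.

import math
from typing import List, Optional


def top_p_filter(scores: List[float], top_p: float, min_top_k: Optional[int] = None) -> List[int]:
    """Return the indices of the scores kept by a top-p (nucleus) cutoff."""
    # Softmax over the raw scores (shift by the max for numerical stability).
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort indices by probability, highest first, and accumulate mass.
    order = sorted(range(len(scores)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for i in order:
        if cumulative + probs[i] > top_p + 1e-6:  # small tolerance for float error
            break
        cumulative += probs[i]
        kept.append(i)

    # Never return fewer than one document (or min_top_k, when given).
    floor = max(1, min_top_k or 1)
    return kept if len(kept) >= floor else order[:floor]


# Mirrors test_run: scores 0.7 / 0.01 / 0.001 with top_p=0.99 keep the top two.
print(top_p_filter([0.7, 0.01, 0.001], top_p=0.99))                # [0, 1]
# Mirrors test_run_min_top_k: top_p=0.2 alone would keep none, min_top_k=2 keeps two.
print(top_p_filter([0.7, 0.01, 0.001], top_p=0.2, min_top_k=2))    # [0, 1]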