haystack/test/nodes/test_summarizer.py
ZanSara ae04ce3c6a
test: mock all Summarizer tests and move a few into e2e (#4299)
* stub e2e folders

* simplify pipeline test

* mocking

* unit tests fixed

* clean up e2e

* pipeline tests work

* pylint

* leftover

* small fix from #2994 and additional tests

* review feedback

* change summaries

* black

* revert models and summaries
2023-03-01 17:30:55 +01:00

86 lines
2.9 KiB
Python

import pytest
import haystack
from haystack.utils.torch_utils import ListDataset
from haystack.schema import Document
from haystack.nodes import TransformersSummarizer
# Shared fixture data: two input documents and the canned summaries that the
# mocked HF pipeline below always returns for them.
DOCS = [Document(content=text) for text in ("First test doc", "Second test doc")]
EXPECTED_SUMMARIES = ["First summary", "Second summary"]
# The same documents with each canned summary attached under meta["summary"].
SUMMARIZED_DOCS = [
    Document(content=original.content, meta={"summary": text})
    for original, text in zip(DOCS, EXPECTED_SUMMARIES)
]
class MockHFPipeline:
    """Stand-in for the Hugging Face summarization pipeline.

    Ignores its inputs and always produces the canned EXPECTED_SUMMARIES,
    so tests run without downloading or invoking a real model.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, docs, *args, **kwargs):
        canned = [{"summary_text": text} for text in EXPECTED_SUMMARIES]
        # A ListDataset input signals batched prediction: return one summary
        # list per dataset entry instead of a single flat list.
        if isinstance(docs, ListDataset):
            return [canned for _ in docs]
        return canned

    def tokenizer(self, *args, **kwargs):
        return {"input_ids": []}
@pytest.fixture
def mock_models(monkeypatch):
    """Patch the HF pipeline factory used by the summarizer with MockHFPipeline."""
    target_module = haystack.nodes.summarizer.transformers
    monkeypatch.setattr(target_module, "pipeline", MockHFPipeline)
@pytest.fixture
def summarizer(mock_models) -> TransformersSummarizer:
    """A TransformersSummarizer backed by the mocked pipeline (the model name is never loaded)."""
    node = TransformersSummarizer(model_name_or_path="irrelevant/anyway", use_gpu=False)
    return node
@pytest.mark.unit
def test_summarization_no_docs(summarizer):
    """Both predict and predict_batch must reject an empty document list."""
    for predict_fn in (summarizer.predict, summarizer.predict_batch):
        with pytest.raises(ValueError, match="at least one document"):
            predict_fn(documents=[])
@pytest.mark.unit
def test_summarization_min_length_greater_than_max_length(summarizer):
    """predict must raise ValueError when min_length exceeds max_length.

    NOTE(review): this function was previously also named
    ``test_summarization_no_docs``, duplicating the test above it. The second
    definition shadowed the first at module level, so pytest only ever
    collected this one and the empty-documents test silently never ran.
    Renamed so both tests are collected and executed.
    """
    summarizer.min_length = 10
    summarizer.max_length = 1
    with pytest.raises(ValueError, match="min_length cannot be greater than max_length"):
        summarizer.predict(documents=DOCS)
@pytest.mark.unit
def test_summarization_one_doc(summarizer):
    """Summarizing a single document yields one result carrying the first canned summary."""
    results = summarizer.predict(documents=[DOCS[0]])
    assert len(results) == 1
    assert results[0].meta["summary"] == EXPECTED_SUMMARIES[0]
@pytest.mark.unit
def test_summarization_more_docs(summarizer):
    """Each input document gets its matching canned summary attached in meta."""
    results = summarizer.predict(documents=DOCS)
    assert len(results) == len(DOCS)
    assert [doc.meta["summary"] for doc in results] == EXPECTED_SUMMARIES
@pytest.mark.unit
def test_summarization_batch_single_doc_list(summarizer):
    """predict_batch on a flat document list returns a flat list of summarized docs."""
    results = summarizer.predict_batch(documents=DOCS)
    assert len(results) == len(DOCS)
    for doc, expected in zip(results, EXPECTED_SUMMARIES):
        assert doc.meta["summary"] == expected
@pytest.mark.unit
def test_summarization_batch_multiple_doc_lists(summarizer):
    """predict_batch on a list of document lists returns one summarized list per input list."""
    results = summarizer.predict_batch(documents=[DOCS, DOCS])
    assert len(results) == 2  # one result list per input document list
    first_batch = results[0]
    assert len(first_batch) == len(DOCS)
    for doc, expected in zip(first_batch, EXPECTED_SUMMARIES):
        assert doc.meta["summary"] == expected