test: review integration tests (#9306)

* AzureOCR: convert integration test to unit test and simplify

* clean up HuggingFaceAPITextEmbedder

* clean up LinkContentFetcher

* simplify HuggingFaceLocalGenerator

* clean up OpenAIGenerator

* OpenAIChatGenerator

* SentenceTransformersDiversityRanker

* TransformersSimilarityRanker

* ChatMessage: rm outdated tests

* fail fast false

* typo
Stefano Fiorucci 2025-04-25 09:07:57 +02:00 committed by GitHub
parent f97472329f
commit 38c39a49de
11 changed files with 99 additions and 363 deletions
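The pattern repeated across these files: tests that previously hit a live API become offline unit tests by patching the SDK client and asserting on the simulated failure or payload. A minimal, self-contained sketch of that pattern (names are illustrative, not from the repo):

import pytest
from unittest.mock import MagicMock

class UpstreamError(Exception):
    """Stand-in for a provider error such as openai.OpenAIError."""

def test_wrong_model_raises_without_network():
    client = MagicMock()
    client.chat.completions.create.side_effect = UpstreamError("Invalid model name")
    # the component under test would hold this client; calling through it
    # surfaces the provider error with no HTTP request involved
    with pytest.raises(UpstreamError):
        client.chat.completions.create(model="something-obviously-wrong", messages=[])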

View File

@@ -101,6 +101,7 @@ jobs:
       (needs.check-if-changed.outputs.changes == 'true')
     strategy:
+      fail-fast: false
       matrix:
         os: [ubuntu-latest, macos-latest, windows-latest]
         include:

View File

@@ -82,6 +82,22 @@ def get_sample_pdf_2_text(page_layout: Literal["natural", "single_column"]) -> s
     )
 
 
+@pytest.fixture
+def mock_poller(test_files_path):
+    """Fixture that returns a MockPoller class factory that can be used to create mock pollers for different JSON files."""
+
+    class MockPoller:
+        def __init__(self, json_file: str):
+            self.json_file = json_file
+
+        def result(self) -> AnalyzeResult:
+            with open(test_files_path / "json" / self.json_file, encoding="utf-8") as azure_file:
+                result = json.load(azure_file)
+            return AnalyzeResult.from_dict(result)
+
+    return MockPoller
+
+
 class TestAzureOCRDocumentConverter:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("AZURE_AI_API_KEY", raising=False)
@@ -109,17 +125,11 @@ class TestAzureOCRDocumentConverter:
         }
 
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_azure_converter_with_pdf(self, mock_resolve_value, test_files_path) -> None:
+    def test_azure_converter_with_pdf(self, mock_resolve_value, test_files_path, mock_poller) -> None:
         mock_resolve_value.return_value = "test_api_key"
 
-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_2.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_2.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_2.pdf"])
             assert len(out["documents"]) == 1
@@ -129,18 +139,12 @@ class TestAzureOCRDocumentConverter:
     @pytest.mark.parametrize("page_layout", ["natural", "single_column"])
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
     def test_azure_converter_with_table(
-        self, mock_resolve_value, page_layout: Literal["natural", "single_column"], test_files_path
+        self, mock_resolve_value, page_layout: Literal["natural", "single_column"], test_files_path, mock_poller
     ) -> None:
         mock_resolve_value.return_value = "test_api_key"
 
-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="", page_layout=page_layout)
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
@@ -177,17 +181,13 @@ D,$54.35,$6345.,
         assert pages[3] == gold_pages[3]
 
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_azure_converter_with_table_no_bounding_region(self, mock_resolve_value, test_files_path) -> None:
+    def test_azure_converter_with_table_no_bounding_region(
+        self, mock_resolve_value, test_files_path, mock_poller
+    ) -> None:
         mock_resolve_value.return_value = "test_api_key"
 
-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
@@ -211,19 +211,14 @@ D,$54.35,$6345.,
         assert docs[0].meta["following_context"] == ""
 
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_azure_converter_with_multicolumn_header_table(self, mock_resolve_value, test_files_path) -> None:
+    def test_azure_converter_with_multicolumn_header_table(
+        self, mock_resolve_value, test_files_path, mock_poller
+    ) -> None:
         mock_resolve_value.return_value = "test_api_key"
 
-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_3.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_3.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_3.pdf"])
             docs = out["documents"]
@@ -237,17 +232,11 @@ D,$54.35,$6345.,
         assert docs[0].meta["page"] == 1
 
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_table_pdf_with_non_empty_meta(self, mock_resolve_value, test_files_path) -> None:
+    def test_table_pdf_with_non_empty_meta(self, mock_resolve_value, test_files_path, mock_poller) -> None:
         mock_resolve_value.return_value = "test_api_key"
 
-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"], meta=[{"test": "value_1"}])
@@ -299,33 +288,30 @@ D,$54.35,$6345.,
         assert "Now we are in Page 2" in documents[0].content
         assert "Page 3 was empty this is page 4" in documents[0].content
 
-    @pytest.mark.integration
-    @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_ENDPOINT", None), reason="Azure endpoint not available")
-    @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
-    def test_run_with_store_full_path_false(self, test_files_path):
-        component = AzureOCRDocumentConverter(
-            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"],
-            api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY"),
-            store_full_path=False,
-        )
-        output = component.run(sources=[test_files_path / "docx" / "sample_docx.docx"])
-        documents = output["documents"]
-        assert len(documents) == 1
-        assert "Sample Docx File" in documents[0].content
-        assert documents[0].meta["file_path"] == "sample_docx.docx"
-
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_meta_from_byte_stream(self, mock_resolve_value, test_files_path) -> None:
+    def test_run_with_store_full_path_false(self, mock_resolve_value, test_files_path, mock_poller):
         mock_resolve_value.return_value = "test_api_key"
 
-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
+            component = AzureOCRDocumentConverter(
+                endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"],
+                api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY"),
+                store_full_path=False,
+            )
+            output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
+            documents = output["documents"]
+            assert len(documents) == 2
+            for doc in documents:
+                assert doc.meta["file_path"] == "sample_pdf_1.pdf"
+
+    @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
+    def test_meta_from_byte_stream(self, mock_resolve_value, test_files_path, mock_poller) -> None:
+        mock_resolve_value.return_value = "test_api_key"
+
+        with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             bytes_ = (test_files_path / "pdf" / "sample_pdf_1.pdf").read_bytes()
             byte_stream = ByteStream(data=bytes_, meta={"test_from": "byte_stream"})
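For reference, a self-contained sketch of the fixture-factory idea used by mock_poller above: the fixture returns a class, and each test instantiates it with the payload it needs (names here are illustrative, not from the repo):

import pytest

@pytest.fixture
def fake_poller():
    class FakePoller:
        def __init__(self, payload: dict):
            self.payload = payload

        def result(self) -> dict:
            # mirrors the poller protocol the converter expects: .result() returns the analysis payload
            return self.payload

    return FakePoller

def test_poller_factory(fake_poller):
    poller = fake_poller({"pages": [1, 2]})
    assert poller.result() == {"pages": [1, 2]}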

View File

@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-import asyncio
 from unittest.mock import MagicMock, patch
 import random
@@ -230,93 +229,21 @@ class TestHuggingFaceAPITextEmbedder:
         assert len(result["embedding"]) == 384
         assert all(isinstance(x, float) for x in result["embedding"])
 
-
-class TestHuggingFaceAPITextEmbedderAsync:
-    """
-    Integration tests for HuggingFaceAPITextEmbedder that verify the async functionality with a real API.
-    These tests require a valid Hugging Face API token.
-    """
-
     @pytest.mark.integration
     @pytest.mark.asyncio
     @pytest.mark.slow
     @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
-    async def test_run_async_with_real_api(self):
-        """
-        Integration test that verifies the async functionality with a real API.
-        This test requires a valid Hugging Face API token.
-        """
-        # Use a small, reliable model for testing
+    async def test_live_run_async_serverless(self):
         model_name = "sentence-transformers/all-MiniLM-L6-v2"
         embedder = HuggingFaceAPITextEmbedder(
             api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model_name}
         )
 
-        # Test with a simple text
         text = "This is a test sentence for embedding."
         result = await embedder.run_async(text=text)
 
-        # Verify the result
         assert "embedding" in result
         assert isinstance(result["embedding"], list)
         assert all(isinstance(x, float) for x in result["embedding"])
         assert len(result["embedding"]) == 384  # MiniLM-L6-v2 has 384 dimensions
 
-        # Test with a longer text
-        long_text = "This is a longer test sentence for embedding. " * 10
-        result = await embedder.run_async(text=long_text)
-
-        # Verify the result
-        assert "embedding" in result
-        assert isinstance(result["embedding"], list)
-        assert all(isinstance(x, float) for x in result["embedding"])
-        assert len(result["embedding"]) == 384
-
-        # Test with prefix and suffix
-        embedder_with_prefix_suffix = HuggingFaceAPITextEmbedder(
-            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": model_name},
-            prefix="prefix: ",
-            suffix=" :suffix",
-        )
-        result = await embedder_with_prefix_suffix.run_async(text=text)
-
-        # Verify the result
-        assert "embedding" in result
-        assert isinstance(result["embedding"], list)
-        assert all(isinstance(x, float) for x in result["embedding"])
-        assert len(result["embedding"]) == 384
-
-    @pytest.mark.integration
-    @pytest.mark.asyncio
-    @pytest.mark.slow
-    @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
-    async def test_run_async_concurrent_requests(self):
-        """
-        Integration test that verifies the async functionality with concurrent requests.
-        This test requires a valid Hugging Face API token.
-        """
-        model_name = "sentence-transformers/all-MiniLM-L6-v2"
-        embedder = HuggingFaceAPITextEmbedder(
-            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model_name}
-        )
-
-        texts = [
-            "This is the first test sentence.",
-            "This is the second test sentence.",
-            "This is the third test sentence.",
-            "This is the fourth test sentence.",
-            "This is the fifth test sentence.",
-        ]
-
-        # Run concurrent requests
-        tasks = [embedder.run_async(text=text) for text in texts]
-        results = await asyncio.gather(*tasks)
-
-        for i, result in enumerate(results):
-            assert "embedding" in result
-            assert isinstance(result["embedding"], list)
-            assert all(isinstance(x, float) for x in result["embedding"])
-            assert len(result["embedding"]) == 384  # MiniLM-L6-v2 has 384 dimensions
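The deleted concurrent test boils down to dispatching several run_async-style calls with asyncio.gather. An API-free sketch of that pattern, with a stub standing in for HuggingFaceAPITextEmbedder.run_async:

import asyncio

async def embed(text: str) -> dict:
    await asyncio.sleep(0)  # stand-in for the HTTP round trip
    return {"embedding": [0.0] * 384}

async def main():
    texts = ["first", "second", "third"]
    results = await asyncio.gather(*(embed(t) for t in texts))
    assert all(len(r["embedding"]) == 384 for r in results)

asyncio.run(main())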

View File

@@ -120,7 +120,7 @@ class TestLinkContentFetcher:
         assert first_stream.meta["content_type"] == "application/pdf"
         assert first_stream.mime_type == "application/pdf"
 
-    def test_run_bad_status_code(self):
+    def test_run_bad_request_no_exception(self):
         """Test behavior when a request results in an error status code"""
         empty_byte_stream = b""
         fetcher = LinkContentFetcher(raise_on_failure=False, retry_attempts=0)
@@ -140,6 +140,23 @@ class TestLinkContentFetcher:
         assert first_stream.meta["content_type"] == "text/html"
         assert first_stream.mime_type == "text/html"
 
+    def test_bad_request_exception_raised(self):
+        """
+        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to
+        do so.
+        """
+        fetcher = LinkContentFetcher(raise_on_failure=True, retry_attempts=0)
+        mock_response = Mock(status_code=403)
+        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+            "403 Client Error", request=Mock(), response=mock_response
+        )
+
+        with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get:
+            mock_get.return_value = mock_response
+            with pytest.raises(httpx.HTTPStatusError):
+                fetcher.run(["https://non_existent_website_dot.com/"])
+
     @pytest.mark.integration
     def test_link_content_fetcher_html(self):
         """
@@ -166,19 +183,6 @@ class TestLinkContentFetcher:
         assert "url" in first_stream.meta and first_stream.meta["url"] == TEXT_URL
         assert first_stream.mime_type == "text/plain"
 
-    @pytest.mark.integration
-    def test_link_content_fetcher_pdf(self):
-        """
-        Test fetching PDF content from a real URL.
-        """
-        fetcher = LinkContentFetcher()
-        streams = fetcher.run([PDF_URL])["streams"]
-        assert len(streams) == 1
-        first_stream = streams[0]
-        assert first_stream.meta["content_type"] in ("application/octet-stream", "application/pdf")
-        assert "url" in first_stream.meta and first_stream.meta["url"] == PDF_URL
-        assert first_stream.mime_type in ("application/octet-stream", "application/pdf")
-
     @pytest.mark.integration
     def test_link_content_fetcher_multiple_different_content_types(self):
         """
@@ -222,35 +226,13 @@ class TestLinkContentFetcher:
         In such a case, the fetcher should return the content of the URLs that were successfully fetched and not raise
         an exception.
         """
-        fetcher = LinkContentFetcher()
+        fetcher = LinkContentFetcher(retry_attempts=0)
         result = fetcher.run(["https://non_existent_website_dot.com/", "https://www.google.com/"])
         assert len(result["streams"]) == 1
         first_stream = result["streams"][0]
         assert first_stream.meta["content_type"] == "text/html"
         assert first_stream.mime_type == "text/html"
 
-    @pytest.mark.integration
-    def test_bad_request_exception_raised(self):
-        """
-        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to
-        do so.
-        """
-        fetcher = LinkContentFetcher()
-        with pytest.raises((httpx.ConnectError, httpx.ConnectTimeout)):
-            fetcher.run(["https://non_existent_website_dot.com/"])
-
-    @pytest.mark.integration
-    def test_link_content_fetcher_audio(self):
-        """
-        Test fetching audio content from a real URL.
-        """
-        fetcher = LinkContentFetcher()
-        streams = fetcher.run(["https://download.samplelib.com/mp3/sample-3s.mp3"])["streams"]
-        first_stream = streams[0]
-        assert first_stream.meta["content_type"] == "audio/mpeg"
-        assert first_stream.mime_type == "audio/mpeg"
-        assert len(first_stream.data) > 0
-
 
 class TestLinkContentFetcherAsync:
     @pytest.mark.asyncio
@@ -337,17 +319,6 @@ class TestLinkContentFetcherAsync:
         assert len(streams) == 1
         assert streams[0].data == b"Success"
 
-    @pytest.mark.asyncio
-    @pytest.mark.integration
-    async def test_run_async_integration(self):
-        """Test async fetching with real HTTP requests"""
-        fetcher = LinkContentFetcher()
-        streams = (await fetcher.run_async([HTML_URL]))["streams"]
-        first_stream = streams[0]
-        assert "Haystack" in first_stream.data.decode("utf-8")
-        assert first_stream.meta["content_type"] == "text/html"
-        assert first_stream.mime_type == "text/html"
-
     @pytest.mark.asyncio
     @pytest.mark.integration
     async def test_run_async_multiple_integration(self):
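The new test_bad_request_exception_raised above simulates the failing request instead of opening a connection. The same httpx mocking idea as a standalone script, not tied to Haystack internals:

from unittest.mock import Mock, patch

import httpx

mock_response = Mock(status_code=403)
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
    "403 Client Error", request=Mock(), response=mock_response
)

with patch("httpx.Client.get", return_value=mock_response):
    response = httpx.Client().get("https://example.invalid/")
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        print(f"raised as expected: {err}")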

View File

@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from unittest.mock import patch, MagicMock, AsyncMock
+from unittest.mock import patch, MagicMock
 
 import pytest
@@ -894,15 +894,16 @@ class TestOpenAIChatGenerator:
         assert message.meta["finish_reason"] == "stop"
         assert message.meta["usage"]["prompt_tokens"] > 0
 
-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_wrong_model(self, chat_messages):
-        component = OpenAIChatGenerator(model="something-obviously-wrong")
+    async def test_run_with_wrong_model(self):
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")
+
+        generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
+        generator.client = mock_client
+
         with pytest.raises(OpenAIError):
-            component.run(chat_messages)
+            generator.run([ChatMessage.from_user("irrelevant")])
 
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -944,27 +945,6 @@ class TestOpenAIChatGenerator:
         assert message.meta["usage"]["completion_tokens"] > 0
         assert message.meta["usage"]["total_tokens"] > 0
 
-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_with_tools(self, tools):
-        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = OpenAIChatGenerator(tools=tools)
-        results = component.run(chat_messages)
-        assert len(results["replies"]) == 1
-        message = results["replies"][0]
-
-        assert not message.texts
-        assert not message.text
-        assert message.tool_calls
-        tool_call = message.tool_call
-        assert isinstance(tool_call, ToolCall)
-        assert tool_call.tool_name == "weather"
-        assert tool_call.arguments == {"city": "Paris"}
-        assert message.meta["finish_reason"] == "tool_calls"
-
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",

View File

@@ -293,16 +293,17 @@ class TestOpenAIChatGeneratorAsync:
         assert "gpt-4o" in message.meta["model"]
         assert message.meta["finish_reason"] == "stop"
 
-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
     @pytest.mark.asyncio
-    async def test_live_run_wrong_model_async(self, chat_messages):
-        component = OpenAIChatGenerator(model="something-obviously-wrong")
+    async def test_run_with_wrong_model_async(self):
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")
+
+        generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
+        generator.client = mock_client
+
         with pytest.raises(OpenAIError):
-            await component.run_async(chat_messages)
+            await generator.run_async([ChatMessage.from_user("irrelevant")])
 
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
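A note on the async variant: for a fully offline async error path, unittest.mock.AsyncMock raises its side effect when the call is awaited. A sketch assuming pytest-asyncio (which this suite already uses); the error class is an illustrative stand-in for openai.OpenAIError:

import pytest
from unittest.mock import AsyncMock

class FakeProviderError(Exception):
    pass

@pytest.mark.asyncio
async def test_async_error_path():
    client = AsyncMock()
    client.chat.completions.create.side_effect = FakeProviderError("Invalid model name")
    with pytest.raises(FakeProviderError):
        await client.chat.completions.create(model="something-obviously-wrong", messages=[])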

View File

@@ -439,7 +439,6 @@ class TestHuggingFaceLocalGenerator:
         Test that StopWordsCriteria catches stop word tokens in a continuous and sequential order in the input_ids
         using a real Huggingface tokenizer.
         """
-        from transformers import AutoTokenizer
 
         model_name = "google/flan-t5-small"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -462,8 +461,6 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-small", task="text2text-generation", stop_words=["unambiguously"]
         )
         generator.warm_up()
-        results = generator.run(prompt="something that triggers something")
-        assert results["replies"] != []
         assert generator.stopping_criteria_list is not None
 
     @pytest.mark.integration
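The trimmed test now stops at the warm_up assertion. For context, a rough sketch of how a stop-words criterion hooks into transformers generation; this is illustrative, not Haystack's StopWordsCriteria:

import torch
from transformers import AutoTokenizer, StoppingCriteria

class SimpleStopWords(StoppingCriteria):
    def __init__(self, tokenizer, stop_word: str):
        self.stop_ids = tokenizer(stop_word, add_special_tokens=False).input_ids

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # stop once the tail of the generated sequence equals the stop word's token ids
        tail = input_ids[0, -len(self.stop_ids):].tolist()
        return tail == self.stop_ids

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
criteria = SimpleStopWords(tokenizer, "unambiguously")
ids = tokenizer("unambiguously", add_special_tokens=False, return_tensors="pt").input_ids
assert criteria(ids, scores=None)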

View File

@@ -1,11 +1,10 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from datetime import datetime
 import logging
 import os
 from typing import List
+from datetime import datetime
 
 import pytest
 from openai import OpenAIError
 from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
@@ -255,54 +254,16 @@ class TestOpenAIGenerator:
         assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
         assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
 
-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_wrong_model(self):
-        component = OpenAIGenerator(model="something-obviously-wrong")
+    def test_run_with_wrong_model(self):
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")
+
+        generator = OpenAIGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
+        generator.client = mock_client
+
         with pytest.raises(OpenAIError):
-            component.run("Whatever")
-
-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_streaming(self):
-        class Callback:
-            def __init__(self):
-                self.responses = ""
-                self.counter = 0
-
-            def __call__(self, chunk: StreamingChunk) -> None:
-                self.counter += 1
-                self.responses += chunk.content if chunk.content else ""
-
-        callback = Callback()
-        component = OpenAIGenerator(streaming_callback=callback)
-        results = component.run("What's the capital of France?")
-
-        assert len(results["replies"]) == 1
-        assert len(results["meta"]) == 1
-        response: str = results["replies"][0]
-        assert "Paris" in response
-
-        metadata = results["meta"][0]
-        assert "gpt-4o-mini" in metadata["model"]
-        assert metadata["finish_reason"] == "stop"
-
-        assert "completion_start_time" in metadata
-        assert datetime.fromisoformat(metadata["completion_start_time"]) <= datetime.now()
-
-        # unfortunately, the usage is not available for streaming calls
-        # we keep the key in the metadata for compatibility
-        assert "usage" in metadata and len(metadata["usage"]) == 0
-
-        assert callback.counter > 1
-        assert "Paris" in callback.responses
+            generator.run("Whatever")
 
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
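The removed streaming test's Callback is still a handy reference for writing streaming callbacks. A dependency-free sketch of the accumulate-and-count idea, with Chunk as an illustrative stand-in for haystack's StreamingChunk:

from dataclasses import dataclass

@dataclass
class Chunk:
    content: str

class Callback:
    def __init__(self):
        self.responses = ""
        self.counter = 0

    def __call__(self, chunk: Chunk) -> None:
        self.counter += 1
        self.responses += chunk.content if chunk.content else ""

callback = Callback()
for piece in ("Par", "is"):
    callback(Chunk(piece))
assert callback.counter == 2
assert callback.responses == "Paris"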

View File

@@ -576,35 +576,6 @@ class TestSentenceTransformersDiversityRanker:
         pipe_serialized = pipe.dumps()
         assert Pipeline.loads(pipe_serialized) == pipe
 
-    @pytest.mark.integration
-    @pytest.mark.slow
-    @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
-    def test_run(self, similarity, monkeypatch):
-        """
-        Tests that run method returns documents in the correct order
-        """
-        monkeypatch.delenv("HF_API_TOKEN", raising=False)  # https://github.com/deepset-ai/haystack/issues/8811
-        ranker = SentenceTransformersDiversityRanker(
-            model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
-        )
-        ranker.warm_up()
-        query = "city"
-        documents = [
-            Document(content="France"),
-            Document(content="Germany"),
-            Document(content="Eiffel Tower"),
-            Document(content="Berlin"),
-            Document(content="Bananas"),
-            Document(content="Silicon Valley"),
-            Document(content="Brandenburg Gate"),
-        ]
-        result = ranker.run(query=query, documents=documents)
-        ranked_docs = result["documents"]
-        ranked_order = ", ".join([doc.content for doc in ranked_docs])
-        expected_order = "Berlin, Bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
-        assert ranked_order == expected_order
-
     @pytest.mark.integration
     @pytest.mark.slow
     @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
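The removed test pinned an exact greedy ordering. Roughly, the greedy diversity idea it exercised: start from the document most similar to the query, then repeatedly add the document least similar on average to those already picked. A numpy sketch under that assumption, not the component's actual implementation:

import numpy as np

def greedy_diversity_order(query_emb: np.ndarray, doc_embs: np.ndarray) -> list:
    selected = [int(np.argmax(doc_embs @ query_emb))]  # closest to the query first
    while len(selected) < len(doc_embs):
        pool = [i for i in range(len(doc_embs)) if i not in selected]
        # mean similarity of each candidate to the already-selected set
        mean_sim = [(float(np.mean(doc_embs[selected] @ doc_embs[i])), i) for i in pool]
        selected.append(min(mean_sim)[1])
    return selected

rng = np.random.default_rng(0)
docs = rng.normal(size=(5, 8))
docs /= np.linalg.norm(docs, axis=1, keepdims=True)  # cosine via normalized dot products
print(greedy_diversity_order(docs[0], docs))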

View File

@@ -361,32 +361,6 @@ class TestSimilarityRanker:
         assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
         assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)
 
-    @pytest.mark.integration
-    @pytest.mark.slow
-    def test_run_small_batch_size(self):
-        """
-        Test if the component ranks documents correctly.
-        """
-        ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=2)
-        ranker.warm_up()
-        query = "City in Bosnia and Herzegovina"
-        docs_before_texts = ["Berlin", "Belgrade", "Sarajevo"]
-        expected_first_text = "Sarajevo"
-        expected_scores = [2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331]
-
-        docs_before = [Document(content=text) for text in docs_before_texts]
-        output = ranker.run(query=query, documents=docs_before)
-        docs_after = output["documents"]
-
-        assert len(docs_after) == 3
-        assert docs_after[0].content == expected_first_text
-
-        sorted_scores = sorted(expected_scores, reverse=True)
-        assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
-        assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
-        assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)
-
     def test_returns_empty_list_if_no_documents_are_provided(self):
         sampler = TransformersSimilarityRanker()
         sampler.model = MagicMock()
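The removed batch-size test exercised plain cross-encoder scoring. The same check, sketched directly against sentence-transformers (downloads the model; expected top document as in the test above):

from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
query = "City in Bosnia and Herzegovina"
docs = ["Berlin", "Belgrade", "Sarajevo"]
scores = model.predict([(query, d) for d in docs])
ranked = [doc for _, doc in sorted(zip(scores, docs), reverse=True)]
assert ranked[0] == "Sarajevo"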

View File

@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-from transformers import AutoTokenizer
 import json
 
 from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent
@@ -430,35 +429,3 @@ def test_from_openai_dict_format_unsupported_role():
 def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
     with pytest.raises(ValueError):
         ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})
-
-
-@pytest.mark.integration
-def test_apply_chat_templating_on_chat_message():
-    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
-    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    formatted_messages = [m.to_openai_dict_format() for m in messages]
-    tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
-    assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"
-
-
-@pytest.mark.integration
-def test_apply_custom_chat_templating_on_chat_message():
-    anthropic_template = (
-        "{%- for message in messages %}"
-        "{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}"
-        "{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}"
-        "{%- elif message.role == 'function' %}{{ raise('anthropic does not support function calls.') }}"
-        "{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}"
-        "{%- else %}{{ raise('Invalid message role: ' + message.role) }}"
-        "{%- endif %}"
-        "{%- endfor %}"
-        "\n\nAssistant:"
-    )
-    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
-    # could be any tokenizer, let's use the one we already likely have in cache
-    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    formatted_messages = [m.to_openai_dict_format() for m in messages]
-    tokenized_messages = tokenizer.apply_chat_template(
-        formatted_messages, chat_template=anthropic_template, tokenize=False
-    )
-    assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"