test: review integration tests (#9306)
* AzureOCR: convert integration test to unit test and simplify
* clean up HuggingFaceAPITextEmbedder
* clean up LinkContentFetcher
* simplify HuggingFaceLocalGenerator
* clean up OpenAIGenerator
* OpenAIChatGenerator
* SentenceTransformersDiversityRanker
* TransformersSimilarityRanker
* ChatMessage: rm outdated tests
* fail fast false
* typo
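
The recurring pattern in these changes is replacing live-API integration tests with mocked unit tests. A minimal sketch of that conversion, assembled from the diff below (import paths assumed from Haystack 2.x, not an exact copy of any single test):

    # Sketch: assert the error path of OpenAIGenerator without an OPENAI_API_KEY.
    from unittest.mock import MagicMock

    import pytest
    from openai import OpenAIError

    from haystack.components.generators import OpenAIGenerator
    from haystack.utils import Secret


    def test_run_with_wrong_model():
        # Mocked OpenAI client: any completion call raises, simulating a bad model name.
        mock_client = MagicMock()
        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")

        generator = OpenAIGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
        generator.client = mock_client

        with pytest.raises(OpenAIError):
            generator.run("Whatever")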
This commit is contained in:
parent f97472329f
commit 38c39a49de
Changed file: .github/workflows/slow.yml (vendored), 1 change
@@ -101,6 +101,7 @@ jobs:
         (needs.check-if-changed.outputs.changes == 'true')

     strategy:
+      fail-fast: false
       matrix:
         os: [ubuntu-latest, macos-latest, windows-latest]
         include:
@@ -82,6 +82,22 @@ def get_sample_pdf_2_text(page_layout: Literal["natural", "single_column"]) -> s
     )


+@pytest.fixture
+def mock_poller(test_files_path):
+    """Fixture that returns a MockPoller class factory that can be used to create mock pollers for different JSON files."""
+
+    class MockPoller:
+        def __init__(self, json_file: str):
+            self.json_file = json_file
+
+        def result(self) -> AnalyzeResult:
+            with open(test_files_path / "json" / self.json_file, encoding="utf-8") as azure_file:
+                result = json.load(azure_file)
+            return AnalyzeResult.from_dict(result)
+
+    return MockPoller
+
+
 class TestAzureOCRDocumentConverter:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("AZURE_AI_API_KEY", raising=False)
@@ -109,17 +125,11 @@ class TestAzureOCRDocumentConverter:
         }

     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_azure_converter_with_pdf(self, mock_resolve_value, test_files_path) -> None:
+    def test_azure_converter_with_pdf(self, mock_resolve_value, test_files_path, mock_poller) -> None:
         mock_resolve_value.return_value = "test_api_key"

-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_2.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_2.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_2.pdf"])
             assert len(out["documents"]) == 1
@@ -129,18 +139,12 @@
     @pytest.mark.parametrize("page_layout", ["natural", "single_column"])
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
     def test_azure_converter_with_table(
-        self, mock_resolve_value, page_layout: Literal["natural", "single_column"], test_files_path
+        self, mock_resolve_value, page_layout: Literal["natural", "single_column"], test_files_path, mock_poller
     ) -> None:
         mock_resolve_value.return_value = "test_api_key"

-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="", page_layout=page_layout)
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])

@@ -177,17 +181,13 @@ D,$54.35,$6345.,
         assert pages[3] == gold_pages[3]

     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_azure_converter_with_table_no_bounding_region(self, mock_resolve_value, test_files_path) -> None:
+    def test_azure_converter_with_table_no_bounding_region(
+        self, mock_resolve_value, test_files_path, mock_poller
+    ) -> None:
         mock_resolve_value.return_value = "test_api_key"

-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])

@@ -211,19 +211,14 @@ D,$54.35,$6345.,
         assert docs[0].meta["following_context"] == ""

     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_azure_converter_with_multicolumn_header_table(self, mock_resolve_value, test_files_path) -> None:
+    def test_azure_converter_with_multicolumn_header_table(
+        self, mock_resolve_value, test_files_path, mock_poller
+    ) -> None:
         mock_resolve_value.return_value = "test_api_key"

-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_3.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_3.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")

             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_3.pdf"])

             docs = out["documents"]
@@ -237,17 +232,11 @@ D,$54.35,$6345.,
         assert docs[0].meta["page"] == 1

     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_table_pdf_with_non_empty_meta(self, mock_resolve_value, test_files_path) -> None:
+    def test_table_pdf_with_non_empty_meta(self, mock_resolve_value, test_files_path, mock_poller) -> None:
         mock_resolve_value.return_value = "test_api_key"

-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             out = ocr_node.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"], meta=[{"test": "value_1"}])

@@ -299,33 +288,30 @@ D,$54.35,$6345.,
         assert "Now we are in Page 2" in documents[0].content
         assert "Page 3 was empty this is page 4" in documents[0].content

-    @pytest.mark.integration
-    @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_ENDPOINT", None), reason="Azure endpoint not available")
-    @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
-    def test_run_with_store_full_path_false(self, test_files_path):
-        component = AzureOCRDocumentConverter(
-            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"],
-            api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY"),
-            store_full_path=False,
-        )
-        output = component.run(sources=[test_files_path / "docx" / "sample_docx.docx"])
-        documents = output["documents"]
-        assert len(documents) == 1
-        assert "Sample Docx File" in documents[0].content
-        assert documents[0].meta["file_path"] == "sample_docx.docx"
-
     @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
-    def test_meta_from_byte_stream(self, mock_resolve_value, test_files_path) -> None:
+    def test_run_with_store_full_path_false(self, mock_resolve_value, test_files_path, mock_poller):
         mock_resolve_value.return_value = "test_api_key"

-        class MockPoller:
-            def result(self) -> AnalyzeResult:
-                with open(test_files_path / "json" / "azure_sample_pdf_1.json", encoding="utf-8") as azure_file:
-                    result = json.load(azure_file)
-                return AnalyzeResult.from_dict(result)
-
+        with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
+            component = AzureOCRDocumentConverter(
+                endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"],
+                api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY"),
+                store_full_path=False,
+            )
+            output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
+
+        documents = output["documents"]
+        assert len(documents) == 2
+        for doc in documents:
+            assert doc.meta["file_path"] == "sample_pdf_1.pdf"
+
+    @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
+    def test_meta_from_byte_stream(self, mock_resolve_value, test_files_path, mock_poller) -> None:
+        mock_resolve_value.return_value = "test_api_key"
+
         with patch("azure.ai.formrecognizer.DocumentAnalysisClient.begin_analyze_document") as azure_mock:
-            azure_mock.return_value = MockPoller()
+            azure_mock.return_value = mock_poller("azure_sample_pdf_1.json")
             ocr_node = AzureOCRDocumentConverter(endpoint="")
             bytes_ = (test_files_path / "pdf" / "sample_pdf_1.pdf").read_bytes()
             byte_stream = ByteStream(data=bytes_, meta={"test_from": "byte_stream"})
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import os
-import asyncio
 from unittest.mock import MagicMock, patch

 import random
@@ -230,93 +229,21 @@ class TestHuggingFaceAPITextEmbedder:
         assert len(result["embedding"]) == 384
         assert all(isinstance(x, float) for x in result["embedding"])


 class TestHuggingFaceAPITextEmbedderAsync:
-    """
-    Integration tests for HuggingFaceAPITextEmbedder that verify the async functionality with a real API.
-    These tests require a valid Hugging Face API token.
-    """
-
     @pytest.mark.integration
     @pytest.mark.asyncio
     @pytest.mark.slow
     @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
-    async def test_run_async_with_real_api(self):
-        """
-        Integration test that verifies the async functionality with a real API.
-        This test requires a valid Hugging Face API token.
-        """
-        # Use a small, reliable model for testing
+    async def test_live_run_async_serverless(self):
         model_name = "sentence-transformers/all-MiniLM-L6-v2"
-
         embedder = HuggingFaceAPITextEmbedder(
             api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model_name}
         )
-
-        # Test with a simple text
         text = "This is a test sentence for embedding."
         result = await embedder.run_async(text=text)

-        # Verify the result
         assert "embedding" in result
         assert isinstance(result["embedding"], list)
         assert all(isinstance(x, float) for x in result["embedding"])
         assert len(result["embedding"]) == 384  # MiniLM-L6-v2 has 384 dimensions
-
-        # Test with a longer text
-        long_text = "This is a longer test sentence for embedding. " * 10
-        result = await embedder.run_async(text=long_text)
-
-        # Verify the result
-        assert "embedding" in result
-        assert isinstance(result["embedding"], list)
-        assert all(isinstance(x, float) for x in result["embedding"])
-        assert len(result["embedding"]) == 384
-
-        # Test with prefix and suffix
-        embedder_with_prefix_suffix = HuggingFaceAPITextEmbedder(
-            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": model_name},
-            prefix="prefix: ",
-            suffix=" :suffix",
-        )
-
-        result = await embedder_with_prefix_suffix.run_async(text=text)
-
-        # Verify the result
-        assert "embedding" in result
-        assert isinstance(result["embedding"], list)
-        assert all(isinstance(x, float) for x in result["embedding"])
-        assert len(result["embedding"]) == 384
-
-    @pytest.mark.integration
-    @pytest.mark.asyncio
-    @pytest.mark.slow
-    @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
-    async def test_run_async_concurrent_requests(self):
-        """
-        Integration test that verifies the async functionality with concurrent requests.
-        This test requires a valid Hugging Face API token.
-        """
-        model_name = "sentence-transformers/all-MiniLM-L6-v2"
-        embedder = HuggingFaceAPITextEmbedder(
-            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model_name}
-        )
-
-        texts = [
-            "This is the first test sentence.",
-            "This is the second test sentence.",
-            "This is the third test sentence.",
-            "This is the fourth test sentence.",
-            "This is the fifth test sentence.",
-        ]
-
-        # Run concurrent requests
-        tasks = [embedder.run_async(text=text) for text in texts]
-        results = await asyncio.gather(*tasks)
-
-        for i, result in enumerate(results):
-            assert "embedding" in result
-            assert isinstance(result["embedding"], list)
-            assert all(isinstance(x, float) for x in result["embedding"])
-            assert len(result["embedding"]) == 384  # MiniLM-L6-v2 has 384 dimensions
@@ -120,7 +120,7 @@ class TestLinkContentFetcher:
         assert first_stream.meta["content_type"] == "application/pdf"
         assert first_stream.mime_type == "application/pdf"

-    def test_run_bad_status_code(self):
+    def test_run_bad_request_no_exception(self):
         """Test behavior when a request results in an error status code"""
         empty_byte_stream = b""
         fetcher = LinkContentFetcher(raise_on_failure=False, retry_attempts=0)
@@ -140,6 +140,23 @@ class TestLinkContentFetcher:
         assert first_stream.meta["content_type"] == "text/html"
         assert first_stream.mime_type == "text/html"

+    def test_bad_request_exception_raised(self):
+        """
+        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to
+        do so.
+        """
+        fetcher = LinkContentFetcher(raise_on_failure=True, retry_attempts=0)
+
+        mock_response = Mock(status_code=403)
+        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+            "403 Client Error", request=Mock(), response=mock_response
+        )
+
+        with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get:
+            mock_get.return_value = mock_response
+            with pytest.raises(httpx.HTTPStatusError):
+                fetcher.run(["https://non_existent_website_dot.com/"])
+
     @pytest.mark.integration
     def test_link_content_fetcher_html(self):
         """
@@ -166,19 +183,6 @@ class TestLinkContentFetcher:
         assert "url" in first_stream.meta and first_stream.meta["url"] == TEXT_URL
         assert first_stream.mime_type == "text/plain"

-    @pytest.mark.integration
-    def test_link_content_fetcher_pdf(self):
-        """
-        Test fetching PDF content from a real URL.
-        """
-        fetcher = LinkContentFetcher()
-        streams = fetcher.run([PDF_URL])["streams"]
-        assert len(streams) == 1
-        first_stream = streams[0]
-        assert first_stream.meta["content_type"] in ("application/octet-stream", "application/pdf")
-        assert "url" in first_stream.meta and first_stream.meta["url"] == PDF_URL
-        assert first_stream.mime_type in ("application/octet-stream", "application/pdf")
-
     @pytest.mark.integration
     def test_link_content_fetcher_multiple_different_content_types(self):
         """
@@ -222,35 +226,13 @@ class TestLinkContentFetcher:
        In such a case, the fetcher should return the content of the URLs that were successfully fetched and not raise
        an exception.
        """
-        fetcher = LinkContentFetcher()
+        fetcher = LinkContentFetcher(retry_attempts=0)
         result = fetcher.run(["https://non_existent_website_dot.com/", "https://www.google.com/"])
         assert len(result["streams"]) == 1
         first_stream = result["streams"][0]
         assert first_stream.meta["content_type"] == "text/html"
         assert first_stream.mime_type == "text/html"

-    @pytest.mark.integration
-    def test_bad_request_exception_raised(self):
-        """
-        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to
-        do so.
-        """
-        fetcher = LinkContentFetcher()
-        with pytest.raises((httpx.ConnectError, httpx.ConnectTimeout)):
-            fetcher.run(["https://non_existent_website_dot.com/"])
-
-    @pytest.mark.integration
-    def test_link_content_fetcher_audio(self):
-        """
-        Test fetching audio content from a real URL.
-        """
-        fetcher = LinkContentFetcher()
-        streams = fetcher.run(["https://download.samplelib.com/mp3/sample-3s.mp3"])["streams"]
-        first_stream = streams[0]
-        assert first_stream.meta["content_type"] == "audio/mpeg"
-        assert first_stream.mime_type == "audio/mpeg"
-        assert len(first_stream.data) > 0
-

 class TestLinkContentFetcherAsync:
     @pytest.mark.asyncio
@@ -337,17 +319,6 @@ class TestLinkContentFetcherAsync:
         assert len(streams) == 1
         assert streams[0].data == b"Success"

-    @pytest.mark.asyncio
-    @pytest.mark.integration
-    async def test_run_async_integration(self):
-        """Test async fetching with real HTTP requests"""
-        fetcher = LinkContentFetcher()
-        streams = (await fetcher.run_async([HTML_URL]))["streams"]
-        first_stream = streams[0]
-        assert "Haystack" in first_stream.data.decode("utf-8")
-        assert first_stream.meta["content_type"] == "text/html"
-        assert first_stream.mime_type == "text/html"
-
     @pytest.mark.asyncio
     @pytest.mark.integration
     async def test_run_async_multiple_integration(self):
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import patch, MagicMock, AsyncMock
+from unittest.mock import patch, MagicMock
 import pytest


@@ -894,15 +894,16 @@ class TestOpenAIChatGenerator:
         assert message.meta["finish_reason"] == "stop"
         assert message.meta["usage"]["prompt_tokens"] > 0

-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_wrong_model(self, chat_messages):
-        component = OpenAIChatGenerator(model="something-obviously-wrong")
+    async def test_run_with_wrong_model(self):
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")
+
+        generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
+
+        generator.client = mock_client
+
         with pytest.raises(OpenAIError):
-            component.run(chat_messages)
+            generator.run([ChatMessage.from_user("irrelevant")])

     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -944,27 +945,6 @@
         assert message.meta["usage"]["completion_tokens"] > 0
         assert message.meta["usage"]["total_tokens"] > 0

-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_with_tools(self, tools):
-        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = OpenAIChatGenerator(tools=tools)
-        results = component.run(chat_messages)
-        assert len(results["replies"]) == 1
-        message = results["replies"][0]
-
-        assert not message.texts
-        assert not message.text
-        assert message.tool_calls
-        tool_call = message.tool_call
-        assert isinstance(tool_call, ToolCall)
-        assert tool_call.tool_name == "weather"
-        assert tool_call.arguments == {"city": "Paris"}
-        assert message.meta["finish_reason"] == "tool_calls"
-
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
@@ -293,16 +293,17 @@ class TestOpenAIChatGeneratorAsync:
         assert "gpt-4o" in message.meta["model"]
         assert message.meta["finish_reason"] == "stop"

-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
     @pytest.mark.asyncio
-    async def test_live_run_wrong_model_async(self, chat_messages):
-        component = OpenAIChatGenerator(model="something-obviously-wrong")
+    async def test_run_with_wrong_model_async(self):
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")
+
+        generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
+
+        generator.client = mock_client
+
         with pytest.raises(OpenAIError):
-            await component.run_async(chat_messages)
+            await generator.run_async([ChatMessage.from_user("irrelevant")])

     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -439,7 +439,6 @@ class TestHuggingFaceLocalGenerator:
        Test that StopWordsCriteria catches stop word tokens in a continuous and sequential order in the input_ids
        using a real Huggingface tokenizer.
        """
-        from transformers import AutoTokenizer

         model_name = "google/flan-t5-small"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -462,8 +461,6 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-small", task="text2text-generation", stop_words=["unambiguously"]
         )
         generator.warm_up()
-        results = generator.run(prompt="something that triggers something")
-        assert results["replies"] != []
         assert generator.stopping_criteria_list is not None

     @pytest.mark.integration
@@ -1,11 +1,10 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
+from datetime import datetime
 import logging
 import os
 from typing import List

-from datetime import datetime
 import pytest
 from openai import OpenAIError
 from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
@@ -255,54 +254,16 @@ class TestOpenAIGenerator:
         assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
         assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0

-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_wrong_model(self):
-        component = OpenAIGenerator(model="something-obviously-wrong")
+    def test_run_with_wrong_model(self):
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.side_effect = OpenAIError("Invalid model name")
+
+        generator = OpenAIGenerator(api_key=Secret.from_token("test-api-key"), model="something-obviously-wrong")
+
+        generator.client = mock_client
+
         with pytest.raises(OpenAIError):
-            component.run("Whatever")
-
-    @pytest.mark.skipif(
-        not os.environ.get("OPENAI_API_KEY", None),
-        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-    )
-    @pytest.mark.integration
-    def test_live_run_streaming(self):
-        class Callback:
-            def __init__(self):
-                self.responses = ""
-                self.counter = 0
-
-            def __call__(self, chunk: StreamingChunk) -> None:
-                self.counter += 1
-                self.responses += chunk.content if chunk.content else ""
-
-        callback = Callback()
-        component = OpenAIGenerator(streaming_callback=callback)
-        results = component.run("What's the capital of France?")
-
-        assert len(results["replies"]) == 1
-        assert len(results["meta"]) == 1
-        response: str = results["replies"][0]
-        assert "Paris" in response
-
-        metadata = results["meta"][0]
-
-        assert "gpt-4o-mini" in metadata["model"]
-        assert metadata["finish_reason"] == "stop"
-
-        assert "completion_start_time" in metadata
-        assert datetime.fromisoformat(metadata["completion_start_time"]) <= datetime.now()
-
-        # unfortunately, the usage is not available for streaming calls
-        # we keep the key in the metadata for compatibility
-        assert "usage" in metadata and len(metadata["usage"]) == 0
-
-        assert callback.counter > 1
-        assert "Paris" in callback.responses
+            generator.run("Whatever")

     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -576,35 +576,6 @@ class TestSentenceTransformersDiversityRanker:
         pipe_serialized = pipe.dumps()
         assert Pipeline.loads(pipe_serialized) == pipe

-    @pytest.mark.integration
-    @pytest.mark.slow
-    @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
-    def test_run(self, similarity, monkeypatch):
-        """
-        Tests that run method returns documents in the correct order
-        """
-        monkeypatch.delenv("HF_API_TOKEN", raising=False)  # https://github.com/deepset-ai/haystack/issues/8811
-        ranker = SentenceTransformersDiversityRanker(
-            model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
-        )
-        ranker.warm_up()
-        query = "city"
-        documents = [
-            Document(content="France"),
-            Document(content="Germany"),
-            Document(content="Eiffel Tower"),
-            Document(content="Berlin"),
-            Document(content="Bananas"),
-            Document(content="Silicon Valley"),
-            Document(content="Brandenburg Gate"),
-        ]
-        result = ranker.run(query=query, documents=documents)
-        ranked_docs = result["documents"]
-        ranked_order = ", ".join([doc.content for doc in ranked_docs])
-        expected_order = "Berlin, Bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
-
-        assert ranked_order == expected_order
-
     @pytest.mark.integration
     @pytest.mark.slow
     @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
@@ -361,32 +361,6 @@ class TestSimilarityRanker:
         assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
         assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)

-    @pytest.mark.integration
-    @pytest.mark.slow
-    def test_run_small_batch_size(self):
-        """
-        Test if the component ranks documents correctly.
-        """
-        ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=2)
-        ranker.warm_up()
-
-        query = "City in Bosnia and Herzegovina"
-        docs_before_texts = ["Berlin", "Belgrade", "Sarajevo"]
-        expected_first_text = "Sarajevo"
-        expected_scores = [2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331]
-
-        docs_before = [Document(content=text) for text in docs_before_texts]
-        output = ranker.run(query=query, documents=docs_before)
-        docs_after = output["documents"]
-
-        assert len(docs_after) == 3
-        assert docs_after[0].content == expected_first_text
-
-        sorted_scores = sorted(expected_scores, reverse=True)
-        assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
-        assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
-        assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)
-
     def test_returns_empty_list_if_no_documents_are_provided(self):
         sampler = TransformersSimilarityRanker()
         sampler.model = MagicMock()
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import pytest
-from transformers import AutoTokenizer
 import json

 from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent
@@ -430,35 +429,3 @@ def test_from_openai_dict_format_unsupported_role():
 def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
     with pytest.raises(ValueError):
         ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})
-
-
-@pytest.mark.integration
-def test_apply_chat_templating_on_chat_message():
-    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
-    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    formatted_messages = [m.to_openai_dict_format() for m in messages]
-    tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
-    assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"
-
-
-@pytest.mark.integration
-def test_apply_custom_chat_templating_on_chat_message():
-    anthropic_template = (
-        "{%- for message in messages %}"
-        "{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}"
-        "{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}"
-        "{%- elif message.role == 'function' %}{{ raise('anthropic does not support function calls.') }}"
-        "{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}"
-        "{%- else %}{{ raise('Invalid message role: ' + message.role) }}"
-        "{%- endif %}"
-        "{%- endfor %}"
-        "\n\nAssistant:"
-    )
-    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
-    # could be any tokenizer, let's use the one we already likely have in cache
-    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    formatted_messages = [m.to_openai_dict_format() for m in messages]
-    tokenized_messages = tokenizer.apply_chat_template(
-        formatted_messages, chat_template=anthropic_template, tokenize=False
-    )
-    assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"