mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-26 06:28:33 +00:00
test: speed up some tests + minor refactorings (#9451)
* this is an integration test * more improvements * rm redundant comments
This commit is contained in:
parent
81c0cefa41
commit
2616d4d55b
@ -2,9 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, Protocol, TypeVar
|
||||
|
||||
T = TypeVar("T", bound="TextEmbedder")
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
# See https://github.com/pylint-dev/pylint/issues/9319.
|
||||
# pylint: disable=unnecessary-ellipsis
|
||||
|
||||
@ -2,5 +2,4 @@
|
||||
features:
|
||||
- |
|
||||
Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API.
|
||||
This component supports both both self-hosted Text Embeddings Inference services and Hugging Face Inference
|
||||
Endpoints.
|
||||
This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference Endpoints.
|
||||
|
||||
@ -804,8 +804,9 @@ class TestAgent:
|
||||
assert isinstance(result["last_message"], ChatMessage)
|
||||
assert result["messages"][-1] == result["last_message"]
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
|
||||
def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
|
||||
def test_agent_streaming_with_tool_call(self, weather_tool):
|
||||
chat_generator = OpenAIChatGenerator()
|
||||
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
|
||||
agent.warm_up()
|
||||
|
||||
@ -1,64 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import component
|
||||
from haystack.components.embedders.types.protocol import TextEmbedder
|
||||
|
||||
|
||||
@component
|
||||
class MockTextEmbedder:
|
||||
def run(self, text: str, param_a: str = "default", param_b: str = "another_default") -> Dict[str, Any]:
|
||||
return {"embedding": [0.1, 0.2, 0.3], "metadata": {"text": text, "param_a": param_a, "param_b": param_b}}
|
||||
|
||||
|
||||
@component
|
||||
class MockInvalidTextEmbedder:
|
||||
def run(self, something_else: float) -> dict[str, bool]:
|
||||
return {"result": True}
|
||||
|
||||
|
||||
def test_protocol_implementation():
|
||||
embedder: TextEmbedder = MockTextEmbedder() # should not raise any type errors
|
||||
|
||||
# check if the run method has the correct signature
|
||||
run_signature = inspect.signature(MockTextEmbedder.run)
|
||||
assert "text" in run_signature.parameters
|
||||
assert run_signature.parameters["text"].annotation == str
|
||||
assert run_signature.return_annotation == Dict[str, Any]
|
||||
|
||||
result = embedder.run("test text")
|
||||
assert isinstance(result, dict)
|
||||
assert "embedding" in result
|
||||
assert isinstance(result["embedding"], list)
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
assert isinstance(result["metadata"], dict)
|
||||
|
||||
|
||||
def test_protocol_optional_parameters():
|
||||
embedder = MockTextEmbedder()
|
||||
|
||||
# default parameters
|
||||
result1 = embedder.run("test text")
|
||||
|
||||
# with custom parameters
|
||||
result2 = embedder.run("test text", param_a="custom_a", param_b="custom_b")
|
||||
|
||||
assert result1["metadata"]["param_a"] == "default"
|
||||
assert result2["metadata"]["param_a"] == "custom_a"
|
||||
assert result2["metadata"]["param_b"] == "custom_b"
|
||||
|
||||
|
||||
def test_protocol_invalid_implementation():
|
||||
run_signature = inspect.signature(MockInvalidTextEmbedder.run)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert "text" in run_signature.parameters and run_signature.parameters["text"].annotation == str
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert run_signature.return_annotation == Dict[str, Any]
|
||||
@ -299,7 +299,13 @@ class TestLinkContentFetcherAsync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_user_agent_rotation(self):
|
||||
"""Test user agent rotation in async fetching"""
|
||||
with patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get:
|
||||
with (
|
||||
patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get,
|
||||
patch("asyncio.sleep") as mock_sleep,
|
||||
):
|
||||
# Mock asyncio.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
# First call raises an error to trigger user agent rotation
|
||||
first_response = Mock(status_code=403)
|
||||
first_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
@ -320,6 +326,8 @@ class TestLinkContentFetcherAsync:
|
||||
assert len(streams) == 1
|
||||
assert streams[0].data == b"Success"
|
||||
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_run_async_multiple_integration(self):
|
||||
|
||||
@ -79,46 +79,56 @@ class TestRequestWithRetry:
|
||||
|
||||
def test_request_with_retry_retries_on_error(self):
|
||||
"""Test that request_with_retry retries on HTTP errors"""
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Mock time.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
|
||||
response = request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
response = request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
def test_request_with_retry_retries_on_status_code(self):
|
||||
"""Test that request_with_retry retries on specified status codes"""
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Mock time.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise requests.exceptions.HTTPError("Service Unavailable")
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
|
||||
error_response.raise_for_status = raise_for_status
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise requests.exceptions.HTTPError("Service Unavailable")
|
||||
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
success_response.raise_for_status = lambda: None
|
||||
error_response.raise_for_status = raise_for_status
|
||||
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
success_response.raise_for_status = lambda: None
|
||||
|
||||
response = request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
response = request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
|
||||
class TestAsyncRequestWithRetry:
|
||||
@ -183,44 +193,54 @@ class TestAsyncRequestWithRetry:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_retry_retries_on_error(self):
|
||||
"""Test that async_request_with_retry retries on HTTP errors"""
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
with patch("asyncio.sleep") as mock_sleep:
|
||||
# Mock asyncio.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [
|
||||
httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
|
||||
success_response,
|
||||
]
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
|
||||
response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [
|
||||
httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
|
||||
success_response,
|
||||
]
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_retry_retries_on_status_code(self):
|
||||
"""Test that async_request_with_retry retries on specified status codes"""
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
with patch("asyncio.sleep") as mock_sleep:
|
||||
# Mock asyncio.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise httpx.HTTPStatusError(
|
||||
"Service Unavailable", request=error_response.request, response=error_response
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise httpx.HTTPStatusError(
|
||||
"Service Unavailable", request=error_response.request, response=error_response
|
||||
)
|
||||
|
||||
error_response.raise_for_status = raise_for_status
|
||||
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response.raise_for_status = lambda: None
|
||||
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
|
||||
response = await async_request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
|
||||
error_response.raise_for_status = raise_for_status
|
||||
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response.raise_for_status = lambda: None
|
||||
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
|
||||
response = await async_request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user