From 2616d4d55b3fe672a2fc46f75276d1d9fe6d8c12 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 29 May 2025 09:49:11 +0200 Subject: [PATCH 1/2] test: speed up some tests + minor refactorings (#9451) * this is an integration test * more improvements * rm redundant comments --- .../components/embedders/types/protocol.py | 4 +- ...to-component-rankers-0e3f54e523e42141.yaml | 3 +- test/components/agents/test_agent.py | 3 +- .../embedders/types/test_protocol.py | 64 --------- .../fetchers/test_link_content_fetcher.py | 10 +- test/utils/test_requests_utils.py | 136 ++++++++++-------- 6 files changed, 91 insertions(+), 129 deletions(-) delete mode 100644 test/components/embedders/types/test_protocol.py diff --git a/haystack/components/embedders/types/protocol.py b/haystack/components/embedders/types/protocol.py index 1e27548f6..de49b06c2 100644 --- a/haystack/components/embedders/types/protocol.py +++ b/haystack/components/embedders/types/protocol.py @@ -2,9 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Protocol, TypeVar - -T = TypeVar("T", bound="TextEmbedder") +from typing import Any, Dict, Protocol # See https://github.com/pylint-dev/pylint/issues/9319. # pylint: disable=unnecessary-ellipsis diff --git a/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml b/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml index aff1d2798..c14bbd5b9 100644 --- a/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml +++ b/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml @@ -2,5 +2,4 @@ features: - | Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API. - This component supports both both self-hosted Text Embeddings Inference services and Hugging Face Inference - Endpoints. + This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference Endpoints. diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 2665e2cc7..9e53c2255 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -804,8 +804,9 @@ class TestAgent: assert isinstance(result["last_message"], ChatMessage) assert result["messages"][-1] == result["last_message"] + @pytest.mark.integration @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") - def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool): + def test_agent_streaming_with_tool_call(self, weather_tool): chat_generator = OpenAIChatGenerator() agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) agent.warm_up() diff --git a/test/components/embedders/types/test_protocol.py b/test/components/embedders/types/test_protocol.py deleted file mode 100644 index 8aa93f1d7..000000000 --- a/test/components/embedders/types/test_protocol.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import inspect -from typing import Any, Dict - -import pytest - -from haystack import component -from haystack.components.embedders.types.protocol import TextEmbedder - - -@component -class MockTextEmbedder: - def run(self, text: str, param_a: str = "default", param_b: str = "another_default") -> Dict[str, Any]: - return {"embedding": [0.1, 0.2, 0.3], "metadata": {"text": text, "param_a": param_a, "param_b": param_b}} - - -@component -class MockInvalidTextEmbedder: - def run(self, something_else: float) -> dict[str, bool]: - return {"result": True} - - -def test_protocol_implementation(): - embedder: TextEmbedder = MockTextEmbedder() # should not raise any type errors - - # check if the run method has the correct signature - run_signature = inspect.signature(MockTextEmbedder.run) - assert "text" in run_signature.parameters - assert run_signature.parameters["text"].annotation == str - assert run_signature.return_annotation == Dict[str, Any] - - result = embedder.run("test text") - assert isinstance(result, dict) - assert "embedding" in result - assert isinstance(result["embedding"], list) - assert all(isinstance(x, float) for x in result["embedding"]) - assert isinstance(result["metadata"], dict) - - -def test_protocol_optional_parameters(): - embedder = MockTextEmbedder() - - # default parameters - result1 = embedder.run("test text") - - # with custom parameters - result2 = embedder.run("test text", param_a="custom_a", param_b="custom_b") - - assert result1["metadata"]["param_a"] == "default" - assert result2["metadata"]["param_a"] == "custom_a" - assert result2["metadata"]["param_b"] == "custom_b" - - -def test_protocol_invalid_implementation(): - run_signature = inspect.signature(MockInvalidTextEmbedder.run) - - with pytest.raises(AssertionError): - assert "text" in run_signature.parameters and run_signature.parameters["text"].annotation == str - - with pytest.raises(AssertionError): - assert run_signature.return_annotation == Dict[str, Any] diff --git a/test/components/fetchers/test_link_content_fetcher.py b/test/components/fetchers/test_link_content_fetcher.py index ea2cb14a4..5c0e99d68 100644 --- a/test/components/fetchers/test_link_content_fetcher.py +++ b/test/components/fetchers/test_link_content_fetcher.py @@ -299,7 +299,13 @@ class TestLinkContentFetcherAsync: @pytest.mark.asyncio async def test_run_async_user_agent_rotation(self): """Test user agent rotation in async fetching""" - with patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get: + with ( + patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get, + patch("asyncio.sleep") as mock_sleep, + ): + # Mock asyncio.sleep used by tenacity to keep this test fast + mock_sleep.return_value = None + # First call raises an error to trigger user agent rotation first_response = Mock(status_code=403) first_response.raise_for_status.side_effect = httpx.HTTPStatusError( @@ -320,6 +326,8 @@ class TestLinkContentFetcherAsync: assert len(streams) == 1 assert streams[0].data == b"Success" + mock_sleep.assert_called_once() + @pytest.mark.asyncio @pytest.mark.integration async def test_run_async_multiple_integration(self): diff --git a/test/utils/test_requests_utils.py b/test/utils/test_requests_utils.py index 7b25e4788..6674ff6bd 100644 --- a/test/utils/test_requests_utils.py +++ b/test/utils/test_requests_utils.py @@ -79,46 +79,56 @@ class TestRequestWithRetry: def test_request_with_retry_retries_on_error(self): """Test that request_with_retry retries on HTTP errors""" - error_response = requests.Response() - error_response.status_code = 503 + with patch("time.sleep") as mock_sleep: + # Mock time.sleep used by tenacity to keep this test fast + mock_sleep.return_value = None - success_response = requests.Response() - success_response.status_code = 200 + error_response = requests.Response() + error_response.status_code = 503 - with patch("requests.request") as mock_request: - # First call raises an error, second call succeeds - mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response] + success_response = requests.Response() + success_response.status_code = 200 - response = request_with_retry(method="GET", url="https://example.com", attempts=2) + with patch("requests.request") as mock_request: + # First call raises an error, second call succeeds + mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response] - assert response == success_response - assert mock_request.call_count == 2 + response = request_with_retry(method="GET", url="https://example.com", attempts=2) + + assert response == success_response + assert mock_request.call_count == 2 + mock_sleep.assert_called() def test_request_with_retry_retries_on_status_code(self): """Test that request_with_retry retries on specified status codes""" - error_response = requests.Response() - error_response.status_code = 503 + with patch("time.sleep") as mock_sleep: + # Mock time.sleep used by tenacity to keep this test fast + mock_sleep.return_value = None - def raise_for_status(): - if error_response.status_code in [503]: - raise requests.exceptions.HTTPError("Service Unavailable") + error_response = requests.Response() + error_response.status_code = 503 - error_response.raise_for_status = raise_for_status + def raise_for_status(): + if error_response.status_code in [503]: + raise requests.exceptions.HTTPError("Service Unavailable") - success_response = requests.Response() - success_response.status_code = 200 - success_response.raise_for_status = lambda: None + error_response.raise_for_status = raise_for_status - with patch("requests.request") as mock_request: - # First call returns error status code, second call succeeds - mock_request.side_effect = [error_response, success_response] + success_response = requests.Response() + success_response.status_code = 200 + success_response.raise_for_status = lambda: None - response = request_with_retry( - method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503] - ) + with patch("requests.request") as mock_request: + # First call returns error status code, second call succeeds + mock_request.side_effect = [error_response, success_response] - assert response == success_response - assert mock_request.call_count == 2 + response = request_with_retry( + method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503] + ) + + assert response == success_response + assert mock_request.call_count == 2 + mock_sleep.assert_called() class TestAsyncRequestWithRetry: @@ -183,44 +193,54 @@ class TestAsyncRequestWithRetry: @pytest.mark.asyncio async def test_async_request_with_retry_retries_on_error(self): """Test that async_request_with_retry retries on HTTP errors""" - error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com")) - success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com")) + with patch("asyncio.sleep") as mock_sleep: + # Mock asyncio.sleep used by tenacity to keep this test fast + mock_sleep.return_value = None - with patch("httpx.AsyncClient.request") as mock_request: - # First call raises an error, second call succeeds - mock_request.side_effect = [ - httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")), - success_response, - ] + error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com")) + success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com")) - response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2) + with patch("httpx.AsyncClient.request") as mock_request: + # First call raises an error, second call succeeds + mock_request.side_effect = [ + httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")), + success_response, + ] - assert response == success_response - assert mock_request.call_count == 2 + response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2) + + assert response == success_response + assert mock_request.call_count == 2 + mock_sleep.assert_called() @pytest.mark.asyncio async def test_async_request_with_retry_retries_on_status_code(self): """Test that async_request_with_retry retries on specified status codes""" - error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com")) + with patch("asyncio.sleep") as mock_sleep: + # Mock asyncio.sleep used by tenacity to keep this test fast + mock_sleep.return_value = None - def raise_for_status(): - if error_response.status_code in [503]: - raise httpx.HTTPStatusError( - "Service Unavailable", request=error_response.request, response=error_response + error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com")) + + def raise_for_status(): + if error_response.status_code in [503]: + raise httpx.HTTPStatusError( + "Service Unavailable", request=error_response.request, response=error_response + ) + + error_response.raise_for_status = raise_for_status + + success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com")) + success_response.raise_for_status = lambda: None + + with patch("httpx.AsyncClient.request") as mock_request: + # First call returns error status code, second call succeeds + mock_request.side_effect = [error_response, success_response] + + response = await async_request_with_retry( + method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503] ) - error_response.raise_for_status = raise_for_status - - success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com")) - success_response.raise_for_status = lambda: None - - with patch("httpx.AsyncClient.request") as mock_request: - # First call returns error status code, second call succeeds - mock_request.side_effect = [error_response, success_response] - - response = await async_request_with_retry( - method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503] - ) - - assert response == success_response - assert mock_request.call_count == 2 + assert response == success_response + assert mock_request.call_count == 2 + mock_sleep.assert_called() From 25c8d7ef9a487ef543c747ca5cd8ee7c4f73d201 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> Date: Fri, 30 May 2025 10:07:37 +0200 Subject: [PATCH 2/2] fix: In State schema validation use `!=` instead of `is not` for checking the type of `messages` (#9454) * Use != instead of is not * Add reno * Use more == instead of is * Fix mypy --- haystack/components/agents/state/state.py | 2 +- haystack/components/agents/state/state_utils.py | 4 ++-- haystack/components/generators/chat/hugging_face_local.py | 2 +- .../notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml | 5 +++++ 4 files changed, 9 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml diff --git a/haystack/components/agents/state/state.py b/haystack/components/agents/state/state.py index 1224f0b6a..18f43562f 100644 --- a/haystack/components/agents/state/state.py +++ b/haystack/components/agents/state/state.py @@ -69,7 +69,7 @@ def _validate_schema(schema: Dict[str, Any]) -> None: raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}") if definition.get("handler") is not None and not callable(definition["handler"]): raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None") - if param == "messages" and definition["type"] is not List[ChatMessage]: + if param == "messages" and definition["type"] != List[ChatMessage]: raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}") diff --git a/haystack/components/agents/state/state_utils.py b/haystack/components/agents/state/state_utils.py index 2b392d812..8b8caec7d 100644 --- a/haystack/components/agents/state/state_utils.py +++ b/haystack/components/agents/state/state_utils.py @@ -31,7 +31,7 @@ def _is_valid_type(obj: Any) -> bool: False """ # Handle Union types (including Optional) - if hasattr(obj, "__origin__") and obj.__origin__ is Union: + if hasattr(obj, "__origin__") and obj.__origin__ == Union: return True # Handle normal classes and generic types @@ -45,7 +45,7 @@ def _is_list_type(type_hint: Any) -> bool: :param type_hint: The type hint to check :return: True if the type hint represents a list, False otherwise """ - return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list) + return type_hint == list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) == list) def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]: diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 52a990d79..a174bfb93 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -422,7 +422,7 @@ class HuggingFaceLocalChatGenerator: replies = [o.get("generated_text", "") for o in output] # Remove stop words from replies if present - for stop_word in stop_words: + for stop_word in stop_words or []: replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ diff --git a/releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml b/releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml new file mode 100644 index 000000000..bfb93f8d8 --- /dev/null +++ b/releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix type comparison in schema validation by replacing `is not` with `!=` when checking the type `List[ChatMessage]`. + This prevents false mismatches due to Python's `is` operator comparing object identity instead of equality.