Merge branch 'main' into feature/chinese-document-splitter

This commit is contained in:
MaChi 2025-06-01 07:02:25 +08:00 committed by GitHub
commit 78d10d3cfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 100 additions and 133 deletions

View File

@ -69,7 +69,7 @@ def _validate_schema(schema: Dict[str, Any]) -> None:
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
if definition.get("handler") is not None and not callable(definition["handler"]):
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
if param == "messages" and definition["type"] is not List[ChatMessage]:
if param == "messages" and definition["type"] != List[ChatMessage]:
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")

View File

@ -31,7 +31,7 @@ def _is_valid_type(obj: Any) -> bool:
False
"""
# Handle Union types (including Optional)
if hasattr(obj, "__origin__") and obj.__origin__ is Union:
if hasattr(obj, "__origin__") and obj.__origin__ == Union:
return True
# Handle normal classes and generic types
@ -45,7 +45,7 @@ def _is_list_type(type_hint: Any) -> bool:
:param type_hint: The type hint to check
:return: True if the type hint represents a list, False otherwise
"""
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
return type_hint == list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) == list)
def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:

View File

@ -2,9 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Protocol, TypeVar
T = TypeVar("T", bound="TextEmbedder")
from typing import Any, Dict, Protocol
# See https://github.com/pylint-dev/pylint/issues/9319.
# pylint: disable=unnecessary-ellipsis

View File

@ -422,7 +422,7 @@ class HuggingFaceLocalChatGenerator:
replies = [o.get("generated_text", "") for o in output]
# Remove stop words from replies if present
for stop_word in stop_words:
for stop_word in stop_words or []:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
chat_messages = [

View File

@ -2,5 +2,4 @@
features:
- |
Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API.
This component supports both both self-hosted Text Embeddings Inference services and Hugging Face Inference
Endpoints.
This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference Endpoints.

View File

@ -0,0 +1,5 @@
---
fixes:
- |
Fix type comparison in schema validation by replacing `is not` with `!=` when checking the type `List[ChatMessage]`.
This prevents false mismatches due to Python's `is` operator comparing object identity instead of equality.

View File

@ -804,8 +804,9 @@ class TestAgent:
assert isinstance(result["last_message"], ChatMessage)
assert result["messages"][-1] == result["last_message"]
@pytest.mark.integration
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
def test_agent_streaming_with_tool_call(self, weather_tool):
chat_generator = OpenAIChatGenerator()
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
agent.warm_up()

View File

@ -1,64 +0,0 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, Dict
import pytest
from haystack import component
from haystack.components.embedders.types.protocol import TextEmbedder
@component
class MockTextEmbedder:
    """Minimal mock component used to exercise the ``TextEmbedder`` protocol in these tests."""

    def run(self, text: str, param_a: str = "default", param_b: str = "another_default") -> Dict[str, Any]:
        """Return a fixed three-float embedding and echo the received arguments under ``metadata``."""
        echoed_arguments = {"text": text, "param_a": param_a, "param_b": param_b}
        fixed_vector = [0.1, 0.2, 0.3]
        return {"embedding": fixed_vector, "metadata": echoed_arguments}
@component
class MockInvalidTextEmbedder:
    """Mock component whose ``run`` has no ``text`` parameter and is annotated to return ``dict[str, bool]``."""

    def run(self, something_else: float) -> dict[str, bool]:
        """Ignore the argument and return a constant result mapping."""
        outcome = {"result": True}
        return outcome
def test_protocol_implementation():
    """Check the signature of ``MockTextEmbedder.run`` and the structure of its output."""
    # The annotation targets static type checkers: it should not raise any type errors.
    embedder: TextEmbedder = MockTextEmbedder()

    # check if the run method has the correct signature
    signature = inspect.signature(MockTextEmbedder.run)
    params = signature.parameters
    assert "text" in params
    assert params["text"].annotation == str
    assert signature.return_annotation == Dict[str, Any]

    output = embedder.run("test text")
    assert isinstance(output, dict)
    assert "embedding" in output
    vector = output["embedding"]
    assert isinstance(vector, list)
    assert all(isinstance(value, float) for value in vector)
    assert isinstance(output["metadata"], dict)
def test_protocol_optional_parameters():
    """Check that ``param_a`` takes its default when omitted and that both optional parameters accept overrides."""
    embedder = MockTextEmbedder()

    # default parameters
    default_meta = embedder.run("test text")["metadata"]
    # with custom parameters
    custom_meta = embedder.run("test text", param_a="custom_a", param_b="custom_b")["metadata"]

    assert default_meta["param_a"] == "default"
    assert custom_meta["param_a"] == "custom_a"
    assert custom_meta["param_b"] == "custom_b"
def test_protocol_invalid_implementation():
    """Check that the signature of ``MockInvalidTextEmbedder.run`` fails both protocol-shape checks."""
    signature = inspect.signature(MockInvalidTextEmbedder.run)
    params = signature.parameters

    # The signature must not expose a ``text: str`` parameter ...
    with pytest.raises(AssertionError):
        assert "text" in params and params["text"].annotation == str

    # ... nor carry the ``Dict[str, Any]`` return annotation.
    with pytest.raises(AssertionError):
        assert signature.return_annotation == Dict[str, Any]

View File

@ -299,7 +299,13 @@ class TestLinkContentFetcherAsync:
@pytest.mark.asyncio
async def test_run_async_user_agent_rotation(self):
"""Test user agent rotation in async fetching"""
with patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get:
with (
patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get,
patch("asyncio.sleep") as mock_sleep,
):
# Mock asyncio.sleep used by tenacity to keep this test fast
mock_sleep.return_value = None
# First call raises an error to trigger user agent rotation
first_response = Mock(status_code=403)
first_response.raise_for_status.side_effect = httpx.HTTPStatusError(
@ -320,6 +326,8 @@ class TestLinkContentFetcherAsync:
assert len(streams) == 1
assert streams[0].data == b"Success"
mock_sleep.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_run_async_multiple_integration(self):

View File

@ -79,46 +79,56 @@ class TestRequestWithRetry:
def test_request_with_retry_retries_on_error(self):
"""Test that request_with_retry retries on HTTP errors"""
error_response = requests.Response()
error_response.status_code = 503
with patch("time.sleep") as mock_sleep:
# Mock time.sleep used by tenacity to keep this test fast
mock_sleep.return_value = None
success_response = requests.Response()
success_response.status_code = 200
error_response = requests.Response()
error_response.status_code = 503
with patch("requests.request") as mock_request:
# First call raises an error, second call succeeds
mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]
success_response = requests.Response()
success_response.status_code = 200
response = request_with_retry(method="GET", url="https://example.com", attempts=2)
with patch("requests.request") as mock_request:
# First call raises an error, second call succeeds
mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]
assert response == success_response
assert mock_request.call_count == 2
response = request_with_retry(method="GET", url="https://example.com", attempts=2)
assert response == success_response
assert mock_request.call_count == 2
mock_sleep.assert_called()
def test_request_with_retry_retries_on_status_code(self):
"""Test that request_with_retry retries on specified status codes"""
error_response = requests.Response()
error_response.status_code = 503
with patch("time.sleep") as mock_sleep:
# Mock time.sleep used by tenacity to keep this test fast
mock_sleep.return_value = None
def raise_for_status():
if error_response.status_code in [503]:
raise requests.exceptions.HTTPError("Service Unavailable")
error_response = requests.Response()
error_response.status_code = 503
error_response.raise_for_status = raise_for_status
def raise_for_status():
if error_response.status_code in [503]:
raise requests.exceptions.HTTPError("Service Unavailable")
success_response = requests.Response()
success_response.status_code = 200
success_response.raise_for_status = lambda: None
error_response.raise_for_status = raise_for_status
with patch("requests.request") as mock_request:
# First call returns error status code, second call succeeds
mock_request.side_effect = [error_response, success_response]
success_response = requests.Response()
success_response.status_code = 200
success_response.raise_for_status = lambda: None
response = request_with_retry(
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
)
with patch("requests.request") as mock_request:
# First call returns error status code, second call succeeds
mock_request.side_effect = [error_response, success_response]
assert response == success_response
assert mock_request.call_count == 2
response = request_with_retry(
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
)
assert response == success_response
assert mock_request.call_count == 2
mock_sleep.assert_called()
class TestAsyncRequestWithRetry:
@ -183,44 +193,54 @@ class TestAsyncRequestWithRetry:
@pytest.mark.asyncio
async def test_async_request_with_retry_retries_on_error(self):
"""Test that async_request_with_retry retries on HTTP errors"""
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
with patch("asyncio.sleep") as mock_sleep:
# Mock asyncio.sleep used by tenacity to keep this test fast
mock_sleep.return_value = None
with patch("httpx.AsyncClient.request") as mock_request:
# First call raises an error, second call succeeds
mock_request.side_effect = [
httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
success_response,
]
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)
with patch("httpx.AsyncClient.request") as mock_request:
# First call raises an error, second call succeeds
mock_request.side_effect = [
httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
success_response,
]
assert response == success_response
assert mock_request.call_count == 2
response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)
assert response == success_response
assert mock_request.call_count == 2
mock_sleep.assert_called()
@pytest.mark.asyncio
async def test_async_request_with_retry_retries_on_status_code(self):
"""Test that async_request_with_retry retries on specified status codes"""
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
with patch("asyncio.sleep") as mock_sleep:
# Mock asyncio.sleep used by tenacity to keep this test fast
mock_sleep.return_value = None
def raise_for_status():
if error_response.status_code in [503]:
raise httpx.HTTPStatusError(
"Service Unavailable", request=error_response.request, response=error_response
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
def raise_for_status():
if error_response.status_code in [503]:
raise httpx.HTTPStatusError(
"Service Unavailable", request=error_response.request, response=error_response
)
error_response.raise_for_status = raise_for_status
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
success_response.raise_for_status = lambda: None
with patch("httpx.AsyncClient.request") as mock_request:
# First call returns error status code, second call succeeds
mock_request.side_effect = [error_response, success_response]
response = await async_request_with_retry(
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
)
error_response.raise_for_status = raise_for_status
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
success_response.raise_for_status = lambda: None
with patch("httpx.AsyncClient.request") as mock_request:
# First call returns error status code, second call succeeds
mock_request.side_effect = [error_response, success_response]
response = await async_request_with_retry(
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
)
assert response == success_response
assert mock_request.call_count == 2
assert response == success_response
assert mock_request.call_count == 2
mock_sleep.assert_called()