mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
Merge branch 'main' into feature/chinese-document-splitter
This commit is contained in:
commit
78d10d3cfb
@ -69,7 +69,7 @@ def _validate_schema(schema: Dict[str, Any]) -> None:
|
||||
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
|
||||
if definition.get("handler") is not None and not callable(definition["handler"]):
|
||||
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
|
||||
if param == "messages" and definition["type"] is not List[ChatMessage]:
|
||||
if param == "messages" and definition["type"] != List[ChatMessage]:
|
||||
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")
|
||||
|
||||
|
||||
|
@ -31,7 +31,7 @@ def _is_valid_type(obj: Any) -> bool:
|
||||
False
|
||||
"""
|
||||
# Handle Union types (including Optional)
|
||||
if hasattr(obj, "__origin__") and obj.__origin__ is Union:
|
||||
if hasattr(obj, "__origin__") and obj.__origin__ == Union:
|
||||
return True
|
||||
|
||||
# Handle normal classes and generic types
|
||||
@ -45,7 +45,7 @@ def _is_list_type(type_hint: Any) -> bool:
|
||||
:param type_hint: The type hint to check
|
||||
:return: True if the type hint represents a list, False otherwise
|
||||
"""
|
||||
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
|
||||
return type_hint == list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) == list)
|
||||
|
||||
|
||||
def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
|
||||
|
@ -2,9 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, Protocol, TypeVar
|
||||
|
||||
T = TypeVar("T", bound="TextEmbedder")
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
# See https://github.com/pylint-dev/pylint/issues/9319.
|
||||
# pylint: disable=unnecessary-ellipsis
|
||||
|
@ -422,7 +422,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
replies = [o.get("generated_text", "") for o in output]
|
||||
|
||||
# Remove stop words from replies if present
|
||||
for stop_word in stop_words:
|
||||
for stop_word in stop_words or []:
|
||||
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
|
||||
|
||||
chat_messages = [
|
||||
|
@ -2,5 +2,4 @@
|
||||
features:
|
||||
- |
|
||||
Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API.
|
||||
This component supports both both self-hosted Text Embeddings Inference services and Hugging Face Inference
|
||||
Endpoints.
|
||||
This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference Endpoints.
|
||||
|
@ -0,0 +1,5 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fix type comparison in schema validation by replacing `is not` with `!=` when checking the type `List[ChatMessage]`.
|
||||
This prevents false mismatches due to Python's `is` operator comparing object identity instead of equality.
|
@ -804,8 +804,9 @@ class TestAgent:
|
||||
assert isinstance(result["last_message"], ChatMessage)
|
||||
assert result["messages"][-1] == result["last_message"]
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
|
||||
def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
|
||||
def test_agent_streaming_with_tool_call(self, weather_tool):
|
||||
chat_generator = OpenAIChatGenerator()
|
||||
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
|
||||
agent.warm_up()
|
||||
|
@ -1,64 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import component
|
||||
from haystack.components.embedders.types.protocol import TextEmbedder
|
||||
|
||||
|
||||
@component
|
||||
class MockTextEmbedder:
|
||||
def run(self, text: str, param_a: str = "default", param_b: str = "another_default") -> Dict[str, Any]:
|
||||
return {"embedding": [0.1, 0.2, 0.3], "metadata": {"text": text, "param_a": param_a, "param_b": param_b}}
|
||||
|
||||
|
||||
@component
|
||||
class MockInvalidTextEmbedder:
|
||||
def run(self, something_else: float) -> dict[str, bool]:
|
||||
return {"result": True}
|
||||
|
||||
|
||||
def test_protocol_implementation():
|
||||
embedder: TextEmbedder = MockTextEmbedder() # should not raise any type errors
|
||||
|
||||
# check if the run method has the correct signature
|
||||
run_signature = inspect.signature(MockTextEmbedder.run)
|
||||
assert "text" in run_signature.parameters
|
||||
assert run_signature.parameters["text"].annotation == str
|
||||
assert run_signature.return_annotation == Dict[str, Any]
|
||||
|
||||
result = embedder.run("test text")
|
||||
assert isinstance(result, dict)
|
||||
assert "embedding" in result
|
||||
assert isinstance(result["embedding"], list)
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
assert isinstance(result["metadata"], dict)
|
||||
|
||||
|
||||
def test_protocol_optional_parameters():
|
||||
embedder = MockTextEmbedder()
|
||||
|
||||
# default parameters
|
||||
result1 = embedder.run("test text")
|
||||
|
||||
# with custom parameters
|
||||
result2 = embedder.run("test text", param_a="custom_a", param_b="custom_b")
|
||||
|
||||
assert result1["metadata"]["param_a"] == "default"
|
||||
assert result2["metadata"]["param_a"] == "custom_a"
|
||||
assert result2["metadata"]["param_b"] == "custom_b"
|
||||
|
||||
|
||||
def test_protocol_invalid_implementation():
|
||||
run_signature = inspect.signature(MockInvalidTextEmbedder.run)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert "text" in run_signature.parameters and run_signature.parameters["text"].annotation == str
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert run_signature.return_annotation == Dict[str, Any]
|
@ -299,7 +299,13 @@ class TestLinkContentFetcherAsync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_user_agent_rotation(self):
|
||||
"""Test user agent rotation in async fetching"""
|
||||
with patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get:
|
||||
with (
|
||||
patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get,
|
||||
patch("asyncio.sleep") as mock_sleep,
|
||||
):
|
||||
# Mock asyncio.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
# First call raises an error to trigger user agent rotation
|
||||
first_response = Mock(status_code=403)
|
||||
first_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
@ -320,6 +326,8 @@ class TestLinkContentFetcherAsync:
|
||||
assert len(streams) == 1
|
||||
assert streams[0].data == b"Success"
|
||||
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_run_async_multiple_integration(self):
|
||||
|
@ -79,46 +79,56 @@ class TestRequestWithRetry:
|
||||
|
||||
def test_request_with_retry_retries_on_error(self):
|
||||
"""Test that request_with_retry retries on HTTP errors"""
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Mock time.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
|
||||
response = request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
response = request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
def test_request_with_retry_retries_on_status_code(self):
|
||||
"""Test that request_with_retry retries on specified status codes"""
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Mock time.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise requests.exceptions.HTTPError("Service Unavailable")
|
||||
error_response = requests.Response()
|
||||
error_response.status_code = 503
|
||||
|
||||
error_response.raise_for_status = raise_for_status
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise requests.exceptions.HTTPError("Service Unavailable")
|
||||
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
success_response.raise_for_status = lambda: None
|
||||
error_response.raise_for_status = raise_for_status
|
||||
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
success_response = requests.Response()
|
||||
success_response.status_code = 200
|
||||
success_response.raise_for_status = lambda: None
|
||||
|
||||
response = request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
with patch("requests.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
response = request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
|
||||
class TestAsyncRequestWithRetry:
|
||||
@ -183,44 +193,54 @@ class TestAsyncRequestWithRetry:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_retry_retries_on_error(self):
|
||||
"""Test that async_request_with_retry retries on HTTP errors"""
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
with patch("asyncio.sleep") as mock_sleep:
|
||||
# Mock asyncio.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [
|
||||
httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
|
||||
success_response,
|
||||
]
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
|
||||
response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call raises an error, second call succeeds
|
||||
mock_request.side_effect = [
|
||||
httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
|
||||
success_response,
|
||||
]
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_retry_retries_on_status_code(self):
|
||||
"""Test that async_request_with_retry retries on specified status codes"""
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
with patch("asyncio.sleep") as mock_sleep:
|
||||
# Mock asyncio.sleep used by tenacity to keep this test fast
|
||||
mock_sleep.return_value = None
|
||||
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise httpx.HTTPStatusError(
|
||||
"Service Unavailable", request=error_response.request, response=error_response
|
||||
error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))
|
||||
|
||||
def raise_for_status():
|
||||
if error_response.status_code in [503]:
|
||||
raise httpx.HTTPStatusError(
|
||||
"Service Unavailable", request=error_response.request, response=error_response
|
||||
)
|
||||
|
||||
error_response.raise_for_status = raise_for_status
|
||||
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response.raise_for_status = lambda: None
|
||||
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
|
||||
response = await async_request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
|
||||
error_response.raise_for_status = raise_for_status
|
||||
|
||||
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
|
||||
success_response.raise_for_status = lambda: None
|
||||
|
||||
with patch("httpx.AsyncClient.request") as mock_request:
|
||||
# First call returns error status code, second call succeeds
|
||||
mock_request.side_effect = [error_response, success_response]
|
||||
|
||||
response = await async_request_with_retry(
|
||||
method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
|
||||
)
|
||||
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
assert response == success_response
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called()
|
||||
|
Loading…
x
Reference in New Issue
Block a user