chore: make mypy run with --check-untyped-defs; fix some errors (#9447)

* chore: make mypy run with --check-untyped-defs; fix some errors

* small fixes

* use HfPipeline

* fix license error
Stefano Fiorucci, 2025-05-27 09:35:25 +02:00, committed by GitHub
parent da60156174
commit d8487c4d8d
8 changed files with 59 additions and 49 deletions
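For context: by default mypy skips the bodies of functions that carry no type annotations at all; --check-untyped-defs (or check_untyped_defs = true in pyproject.toml, as added below) makes it type-check those bodies as well, with unannotated parameters treated as Any. A minimal illustration of the kind of slip the flag starts to catch — the function is hypothetical, not part of this commit:

    import datetime

    def days_since_release():  # no annotations: mypy previously skipped this whole body
        release = datetime.date(2025, 5, 27)
        delta = datetime.date.today() - release
        # With check_untyped_defs enabled, mypy would now flag a variant such as
        #     return delta.days() > 0    # error: "int" not callable
        # because timedelta.days is an int attribute, not a method.
        return delta.days > 0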

View File

@@ -41,10 +41,11 @@ jobs:
requirements: ${{ env.REQUIREMENTS_FILE }}
fail: "Copyleft,Other,Error"
# Exclusions in the vanilla distribution must be explicitly motivated
# - jsonschema is MIT but, starting from version 4.24.0, pip-license-checker does not recognize it
# - tqdm is MLP but there are no better alternatives
# - typing_extensions>=4.13.0 has a Python Software Foundation License 2.0 but pip-license-checker does not recognize it
# (https://github.com/pilosus/pip-license-checker/issues/143)
# - tqdm is MLP but there are no better alternatives
exclude: "(?i)^(tqdm|typing_extensions).*"
exclude: "(?i)^(jsonschema|tqdm|typing_extensions).*"
# We keep the license inventory on FOSSA
- name: Send license report to Fossa

View File

@@ -10,7 +10,8 @@ from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
from transformers import Pipeline, pipeline
from transformers import Pipeline as HfPipeline
from transformers import pipeline
@component
@@ -129,7 +130,7 @@ class TransformersZeroShotDocumentClassifier:
)
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.pipeline: Optional[Pipeline] = None
self.pipeline: Optional[HfPipeline] = None
def _get_telemetry_data(self) -> Dict[str, Any]:
"""

View File

@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
from huggingface_hub import model_info
from transformers import Pipeline as HfPipeline
from transformers import StoppingCriteriaList, pipeline
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -235,7 +236,7 @@ class HuggingFaceLocalChatGenerator:
self.generation_kwargs = generation_kwargs
self.chat_template = chat_template
self.streaming_callback = streaming_callback
self.pipeline = None
self.pipeline: Optional[HfPipeline] = None
self.tools = tools
self._owns_executor = async_executor is None
@@ -352,9 +353,11 @@ class HuggingFaceLocalChatGenerator:
tools = tools or self.tools
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
tokenizer = self.pipeline.tokenizer
# initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
assert tokenizer is not None
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
@@ -398,6 +401,9 @@
tools=[tc.tool_spec for tc in tools] if tools else None,
)
# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
@@ -515,9 +521,11 @@
tools = tools or self.tools
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
tokenizer = self.pipeline.tokenizer
# initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
assert tokenizer is not None
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
@@ -557,8 +565,8 @@ class HuggingFaceLocalChatGenerator:
hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)
# prepared_prompt is a string, but transformers has some type issues
prepared_prompt = cast(str, prepared_prompt)
# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (

View File

@@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
from transformers import Pipeline, StoppingCriteriaList, pipeline
from transformers import Pipeline as HfPipeline
from transformers import StoppingCriteriaList, pipeline
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
HFTokenStreamingHandler,
@@ -126,7 +127,7 @@ class HuggingFaceLocalGenerator:
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.generation_kwargs = generation_kwargs
self.stop_words = stop_words
self.pipeline: Optional[Pipeline] = None
self.pipeline: Optional[HfPipeline] = None
self.stopping_criteria_list: Optional[StoppingCriteriaList] = None
self.streaming_callback = streaming_callback
@@ -152,7 +153,7 @@
return
if self.pipeline is None:
self.pipeline = cast(Pipeline, pipeline(**self.huggingface_pipeline_kwargs))
self.pipeline = cast(HfPipeline, pipeline(**self.huggingface_pipeline_kwargs))
if self.stop_words:
# text-generation and text2text-generation pipelines always have a non-None tokenizer

View File

@@ -9,7 +9,8 @@ from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
from transformers import Pipeline, pipeline
from transformers import Pipeline as HfPipeline
from transformers import pipeline
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
deserialize_hf_model_kwargs,
@@ -136,7 +137,7 @@ class TransformersZeroShotTextRouter:
token=token,
)
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.pipeline: Optional["Pipeline"] = None
self.pipeline: Optional[HfPipeline] = None
def _get_telemetry_data(self) -> Dict[str, Any]:
"""

View File

@@ -80,10 +80,10 @@ from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Protocol,
Tuple,
Type,
TypeVar,
Union,
@@ -196,7 +196,7 @@ class Component(Protocol):
class ComponentMeta(type):
@staticmethod
def _positional_to_kwargs(cls_type: Type, args: List) -> Dict[str, Any]:
def _positional_to_kwargs(cls_type: Type, args: Tuple[Any, ...]) -> Dict[str, Any]:
"""
Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
"""

View File

@@ -290,14 +290,11 @@ disallow_incomplete_defs = true
warn_return_any = false
warn_unused_configs = true
ignore_missing_imports = true
check_untyped_defs = true
[[tool.mypy.overrides]]
# TODO: Fix component typings
module = "haystack.components.*"
disallow_incomplete_defs = false
[[tool.mypy.overrides]]
module = "haystack.testing.*"
module = ["haystack.components.*", "haystack.testing.*"]
disallow_incomplete_defs = false
[tool.ruff]
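Two different knobs are in play in this pyproject.toml change: check_untyped_defs = true now applies everywhere, while disallow_incomplete_defs stays relaxed for haystack.components.* and haystack.testing.* via the single merged overrides table. Roughly, and with hypothetical functions:

    from typing import List

    # Partially annotated signature: still accepted where the override keeps
    # disallow_incomplete_defs = false (haystack.components.*, haystack.testing.*).
    def chunk_documents(documents, chunk_size: int = 100) -> List[str]:
        return [str(d) for d in documents][:chunk_size]

    # Fully unannotated helper: its body used to be skipped, but check_untyped_defs = true
    # means mypy now looks inside and would report a slip such as
    #     return label.uppercase()    # error: "str" has no attribute "uppercase"
    def legacy_helper():
        label = "haystack"
        return label.upper()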

View File

@@ -47,13 +47,14 @@ def model_info_mock():
@pytest.fixture
def mock_pipeline_tokenizer():
def mock_pipeline_with_tokenizer():
# Mocking the pipeline
mock_pipeline = Mock(return_value=[{"generated_text": "Berlin is cool"}])
# Mocking the tokenizer
mock_tokenizer = Mock(spec=PreTrainedTokenizer)
mock_tokenizer.encode.return_value = ["Berlin", "is", "cool"]
mock_tokenizer.apply_chat_template.return_value = "Berlin is cool"
mock_tokenizer.pad_token_id = 100
mock_pipeline.tokenizer = mock_tokenizer
@@ -249,11 +250,11 @@ class TestHuggingFaceLocalChatGenerator:
model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", token=None, device="cpu"
)
def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
def test_run(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
results = generator.run(messages=chat_messages)
@@ -263,16 +264,16 @@ class TestHuggingFaceLocalChatGenerator:
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.text == "Berlin is cool"
def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
results = generator.run(messages=chat_messages, generation_kwargs=generation_kwargs)
# check kwargs passed pipeline
@@ -287,7 +288,7 @@ class TestHuggingFaceLocalChatGenerator:
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.text == "Berlin is cool"
def test_run_with_streaming_callback(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
def test_run_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
# Define the streaming callback function
def streaming_callback_fn(chunk: StreamingChunk): ...
@@ -296,7 +297,7 @@
)
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
results = generator.run(messages=chat_messages)
@@ -308,14 +309,16 @@ class TestHuggingFaceLocalChatGenerator:
generator.pipeline.assert_called_once()
generator.pipeline.call_args[1]["streamer"].token_handler == streaming_callback_fn
def test_run_with_streaming_callback_in_run_method(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
def test_run_with_streaming_callback_in_run_method(
self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages
):
# Define the streaming callback function
def streaming_callback_fn(chunk: StreamingChunk): ...
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
results = generator.run(messages=chat_messages, streaming_callback=streaming_callback_fn)
@@ -447,15 +450,12 @@ class TestHuggingFaceLocalChatGenerator:
assert "22°C" in message.text
assert message.meta["finish_reason"] == "stop"
def test_run_with_custom_tool_parser(self, model_info_mock, tools):
def test_run_with_custom_tool_parser(self, model_info_mock, mock_pipeline_with_tokenizer, tools):
"""Test that a custom tool parsing function works correctly."""
generator = HuggingFaceLocalChatGenerator(
model="meta-llama/Llama-2-13b-chat-hf", tools=tools, tool_parsing_function=custom_tool_parser
)
generator.pipeline = Mock(return_value=[{"mocked_response": "Mocked response, we don't use it"}])
generator.pipeline.tokenizer = Mock()
generator.pipeline.tokenizer.encode.return_value = [1, 2, 3]
generator.pipeline.tokenizer.pad_token_id = 1
generator.pipeline = mock_pipeline_with_tokenizer
messages = [ChatMessage.from_user("What's the weather like in Berlin?")]
results = generator.run(messages=messages)
@@ -474,6 +474,7 @@ class TestHuggingFaceLocalChatGenerator:
generator.pipeline.tokenizer = Mock()
generator.pipeline.tokenizer.encode.return_value = [1, 2, 3]
generator.pipeline.tokenizer.pad_token_id = 1
generator.pipeline.tokenizer.apply_chat_template.return_value = "Irrelevant"
messages = [ChatMessage.from_user("What's the weather like in Berlin?")]
results = generator.run(messages=messages)
@@ -485,10 +486,10 @@ class TestHuggingFaceLocalChatGenerator:
# Async tests
async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
async def test_run_async(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
"""Test basic async functionality"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
results = await generator.run_async(messages=chat_messages)
@@ -498,10 +499,10 @@ class TestHuggingFaceLocalChatGenerator:
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.text == "Berlin is cool"
async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools):
async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_with_tokenizer, tools):
"""Test async functionality with tools"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
# Mock the pipeline to return a tool call format
generator.pipeline.return_value = [{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}]
@@ -516,10 +517,10 @@ class TestHuggingFaceLocalChatGenerator:
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Berlin"}
async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
"""Test handling of multiple concurrent async requests"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
# Create multiple concurrent requests
tasks = [generator.run_async(messages=chat_messages) for _ in range(5)]
@@ -530,7 +531,7 @@ class TestHuggingFaceLocalChatGenerator:
assert isinstance(result["replies"][0], ChatMessage)
assert result["replies"][0].text == "Berlin is cool"
async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer):
async def test_async_error_handling(self, model_info_mock, mock_pipeline_with_tokenizer):
"""Test error handling in async context"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
@@ -539,7 +540,7 @@ class TestHuggingFaceLocalChatGenerator:
await generator.run_async(messages=[ChatMessage.from_user("test")])
# Test with invalid streaming callback
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"):
await generator.run_async(
messages=[ChatMessage.from_user("test")],
@@ -547,7 +548,7 @@ class TestHuggingFaceLocalChatGenerator:
tools=[Tool(name="test", description="test", parameters={}, function=lambda: None)],
)
def test_executor_shutdown(self, model_info_mock, mock_pipeline_tokenizer):
def test_executor_shutdown(self, model_info_mock, mock_pipeline_with_tokenizer):
with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
executor = generator.executor
@@ -557,12 +558,12 @@ class TestHuggingFaceLocalChatGenerator:
mock_shutdown.assert_called_once_with(wait=True)
def test_hugging_face_local_generator_with_toolset_initialization(
self, model_info_mock, mock_pipeline_tokenizer, tools
self, model_info_mock, mock_pipeline_with_tokenizer, tools
):
"""Test that the HuggingFaceLocalChatGenerator can be initialized with a Toolset."""
toolset = Toolset(tools)
generator = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
assert generator.tools == toolset
def test_from_dict_with_toolset(self, model_info_mock, tools):
@@ -577,11 +578,11 @@ class TestHuggingFaceLocalChatGenerator:
assert len(deserialized_component.tools) == len(tools)
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, tools):
def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_with_tokenizer, tools):
"""Test that the HuggingFaceLocalChatGenerator can be serialized to a dictionary with a Toolset."""
toolset = Toolset(tools)
generator = HuggingFaceLocalChatGenerator(huggingface_pipeline_kwargs={"model": "irrelevant"}, tools=toolset)
generator.pipeline = mock_pipeline_tokenizer
generator.pipeline = mock_pipeline_with_tokenizer
data = generator.to_dict()
expected_tools_data = {