unpin ruff and update code (#9040)

This commit is contained in:
Stefano Fiorucci 2025-03-14 15:53:25 +01:00 committed by GitHub
parent 6366f6577e
commit c5cde40d3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 33 additions and 33 deletions

View File

@ -17,7 +17,7 @@ repos:
args: [--markdown-linebreak-ext=md]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.2
rev: v0.11.0
hooks:
- id: ruff
- id: ruff-format

View File

@ -95,7 +95,7 @@ class OutputAdapter:
input_types.update(route_input_names)
# the env is not needed, discarded automatically
component.set_input_types(self, **{var: Any for var in input_types})
component.set_input_types(self, **dict.fromkeys(input_types, Any))
component.set_output_types(self, **{"output": output_type})
self.output_type = output_type

View File

@ -127,8 +127,8 @@ class AzureOpenAIDocumentEmbedder:
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

View File

@ -107,8 +107,8 @@ class AzureOpenAITextEmbedder:
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.prefix = prefix
self.suffix = suffix
self.default_headers = default_headers or {}

View File

@ -108,9 +108,9 @@ class OpenAIDocumentEmbedder:
self.embedding_separator = embedding_separator
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.client = OpenAI(
api_key=api_key.resolve_value(),

View File

@ -90,9 +90,9 @@ class OpenAITextEmbedder:
self.api_key = api_key
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.client = OpenAI(
api_key=api_key.resolve_value(),

View File

@ -141,8 +141,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
self.azure_deployment = azure_deployment
self.organization = organization
self.model: str = azure_deployment or "gpt-4o-mini"
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

View File

@ -149,8 +149,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
self.azure_deployment = azure_deployment
self.organization = organization
self.model = azure_deployment or "gpt-4o-mini"
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

View File

@ -146,9 +146,9 @@ class OpenAIChatGenerator:
_check_duplicate_tool_names(tools)
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
client_args: Dict[str, Any] = {
"api_key": api_key.resolve_value(),

View File

@ -112,9 +112,9 @@ class OpenAIGenerator:
self.organization = organization
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.client = OpenAI(
api_key=api_key.resolve_value(),

View File

@ -71,8 +71,8 @@ class DALLEImageGenerator:
self.api_base_url = api_base_url
self.organization = organization
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.client: Optional[OpenAI] = None

View File

@ -225,7 +225,7 @@ class ConditionalRouter:
logger.warning(msg)
# add mandatory input types
component.set_input_types(self, **{var: Any for var in mandatory_input_types})
component.set_input_types(self, **dict.fromkeys(mandatory_input_types, Any))
# now add optional input types
for optional_var_name in self.optional_variables:

View File

@ -96,7 +96,7 @@ class FileTypeRouter:
component.set_output_types(
self,
unclassified=List[Union[str, Path, ByteStream]],
**{mime_type: List[Union[str, Path, ByteStream]] for mime_type in mime_types},
**dict.fromkeys(mime_types, List[Union[str, Path, ByteStream]]),
)
self.mime_types = mime_types
self._additional_mimetypes = additional_mimetypes

View File

@ -81,7 +81,7 @@ class MetadataRouter:
raise ValueError(
"Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
)
component.set_output_types(self, unmatched=List[Document], **{edge: List[Document] for edge in rules})
component.set_output_types(self, unmatched=List[Document], **dict.fromkeys(rules, List[Document]))
def run(self, documents: List[Document]):
"""

View File

@ -59,7 +59,7 @@ class TextLanguageRouter:
if not languages:
languages = ["en"]
self.languages = languages
component.set_output_types(self, unmatched=str, **{language: str for language in languages})
component.set_output_types(self, unmatched=str, **dict.fromkeys(languages, str))
def run(self, text: str) -> Dict[str, str]:
"""

View File

@ -116,7 +116,7 @@ class TransformersTextRouter:
self.labels = list(config.label2id.keys())
else:
self.labels = labels
component.set_output_types(self, **{label: str for label in self.labels})
component.set_output_types(self, **dict.fromkeys(self.labels, str))
self.pipeline = None

View File

@ -128,7 +128,7 @@ class TransformersZeroShotTextRouter:
self.token = token
self.labels = labels
self.multi_label = multi_label
component.set_output_types(self, **{label: str for label in labels})
component.set_output_types(self, **dict.fromkeys(labels, str))
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},

View File

@ -143,7 +143,7 @@ class AsyncPipeline(PipelineBase):
# For quick lookup of downstream receivers
ordered_names = sorted(self.graph.nodes.keys())
cached_receivers = {n: self._find_receivers_from(n) for n in ordered_names}
component_visits = {component_name: 0 for component_name in ordered_names}
component_visits = dict.fromkeys(ordered_names, 0)
cached_topological_sort = None
# We fill the queue once and raise if all components are BLOCKED

View File

@ -202,7 +202,7 @@ class Pipeline(PipelineBase):
ordered_component_names = sorted(self.graph.nodes.keys())
# We track component visits to decide if a component can run.
component_visits = {component_name: 0 for component_name in ordered_component_names}
component_visits = dict.fromkeys(ordered_component_names, 0)
# We need to access a component's receivers multiple times during a pipeline run.
# We store them here for easy access.

View File

@ -214,7 +214,7 @@ def component_class( # pylint: disable=too-many-positional-arguments
def run(self, **kwargs): # pylint: disable=unused-argument
if output is not None:
return output
return {name: None for name in output_types.keys()}
return dict.fromkeys(output_types.keys())
def to_dict(self):
return default_to_dict(self)

View File

@ -18,7 +18,7 @@ class FString:
self.variables = variables or []
if "template" in self.variables:
raise ValueError("The variable name 'template' is reserved and cannot be used.")
component.set_input_types(self, **{variable: Any for variable in self.variables})
component.set_input_types(self, **dict.fromkeys(self.variables, Any))
@component.output_types(string=str)
def run(self, template: Optional[str] = None, **kwargs):

View File

@ -11,10 +11,10 @@ from haystack.core.component import component
class Repeat:
def __init__(self, outputs: List[str]):
self._outputs = outputs
component.set_output_types(self, **{k: int for k in outputs})
component.set_output_types(self, **dict.fromkeys(outputs, int))
def run(self, value: int):
"""
:param value: the value to repeat.
"""
return {val: value for val in self._outputs}
return dict.fromkeys(self._outputs, value)

View File

@ -64,7 +64,7 @@ dependencies = [
installer = "uv"
dependencies = [
"pre-commit",
"ruff<0.10.0",
"ruff",
"toml",
"reno",
# dulwich is a reno dependency, they pin it at >=0.15.0 so pip takes a ton of time to resolve the dependency tree.