mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
unpin ruff and update code (#9040)
This commit is contained in:
parent
6366f6577e
commit
c5cde40d3a
@ -17,7 +17,7 @@ repos:
|
||||
args: [--markdown-linebreak-ext=md]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.2
|
||||
rev: v0.11.0
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-format
|
||||
|
@ -95,7 +95,7 @@ class OutputAdapter:
|
||||
input_types.update(route_input_names)
|
||||
|
||||
# the env is not needed, discarded automatically
|
||||
component.set_input_types(self, **{var: Any for var in input_types})
|
||||
component.set_input_types(self, **dict.fromkeys(input_types, Any))
|
||||
component.set_output_types(self, **{"output": output_type})
|
||||
self.output_type = output_type
|
||||
|
||||
|
@ -127,8 +127,8 @@ class AzureOpenAIDocumentEmbedder:
|
||||
self.progress_bar = progress_bar
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
self.default_headers = default_headers or {}
|
||||
self.azure_ad_token_provider = azure_ad_token_provider
|
||||
|
||||
|
@ -107,8 +107,8 @@ class AzureOpenAITextEmbedder:
|
||||
self.azure_deployment = azure_deployment
|
||||
self.dimensions = dimensions
|
||||
self.organization = organization
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.default_headers = default_headers or {}
|
||||
|
@ -108,9 +108,9 @@ class OpenAIDocumentEmbedder:
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
if timeout is None:
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
if max_retries is None:
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key.resolve_value(),
|
||||
|
@ -90,9 +90,9 @@ class OpenAITextEmbedder:
|
||||
self.api_key = api_key
|
||||
|
||||
if timeout is None:
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
if max_retries is None:
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key.resolve_value(),
|
||||
|
@ -141,8 +141,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
self.azure_deployment = azure_deployment
|
||||
self.organization = organization
|
||||
self.model: str = azure_deployment or "gpt-4o-mini"
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
self.default_headers = default_headers or {}
|
||||
self.azure_ad_token_provider = azure_ad_token_provider
|
||||
|
||||
|
@ -149,8 +149,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
self.azure_deployment = azure_deployment
|
||||
self.organization = organization
|
||||
self.model = azure_deployment or "gpt-4o-mini"
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
self.default_headers = default_headers or {}
|
||||
self.azure_ad_token_provider = azure_ad_token_provider
|
||||
|
||||
|
@ -146,9 +146,9 @@ class OpenAIChatGenerator:
|
||||
_check_duplicate_tool_names(tools)
|
||||
|
||||
if timeout is None:
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
if max_retries is None:
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
client_args: Dict[str, Any] = {
|
||||
"api_key": api_key.resolve_value(),
|
||||
|
@ -112,9 +112,9 @@ class OpenAIGenerator:
|
||||
self.organization = organization
|
||||
|
||||
if timeout is None:
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
if max_retries is None:
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key.resolve_value(),
|
||||
|
@ -71,8 +71,8 @@ class DALLEImageGenerator:
|
||||
self.api_base_url = api_base_url
|
||||
self.organization = organization
|
||||
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
self.client: Optional[OpenAI] = None
|
||||
|
||||
|
@ -225,7 +225,7 @@ class ConditionalRouter:
|
||||
logger.warning(msg)
|
||||
|
||||
# add mandatory input types
|
||||
component.set_input_types(self, **{var: Any for var in mandatory_input_types})
|
||||
component.set_input_types(self, **dict.fromkeys(mandatory_input_types, Any))
|
||||
|
||||
# now add optional input types
|
||||
for optional_var_name in self.optional_variables:
|
||||
|
@ -96,7 +96,7 @@ class FileTypeRouter:
|
||||
component.set_output_types(
|
||||
self,
|
||||
unclassified=List[Union[str, Path, ByteStream]],
|
||||
**{mime_type: List[Union[str, Path, ByteStream]] for mime_type in mime_types},
|
||||
**dict.fromkeys(mime_types, List[Union[str, Path, ByteStream]]),
|
||||
)
|
||||
self.mime_types = mime_types
|
||||
self._additional_mimetypes = additional_mimetypes
|
||||
|
@ -81,7 +81,7 @@ class MetadataRouter:
|
||||
raise ValueError(
|
||||
"Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
|
||||
)
|
||||
component.set_output_types(self, unmatched=List[Document], **{edge: List[Document] for edge in rules})
|
||||
component.set_output_types(self, unmatched=List[Document], **dict.fromkeys(rules, List[Document]))
|
||||
|
||||
def run(self, documents: List[Document]):
|
||||
"""
|
||||
|
@ -59,7 +59,7 @@ class TextLanguageRouter:
|
||||
if not languages:
|
||||
languages = ["en"]
|
||||
self.languages = languages
|
||||
component.set_output_types(self, unmatched=str, **{language: str for language in languages})
|
||||
component.set_output_types(self, unmatched=str, **dict.fromkeys(languages, str))
|
||||
|
||||
def run(self, text: str) -> Dict[str, str]:
|
||||
"""
|
||||
|
@ -116,7 +116,7 @@ class TransformersTextRouter:
|
||||
self.labels = list(config.label2id.keys())
|
||||
else:
|
||||
self.labels = labels
|
||||
component.set_output_types(self, **{label: str for label in self.labels})
|
||||
component.set_output_types(self, **dict.fromkeys(self.labels, str))
|
||||
|
||||
self.pipeline = None
|
||||
|
||||
|
@ -128,7 +128,7 @@ class TransformersZeroShotTextRouter:
|
||||
self.token = token
|
||||
self.labels = labels
|
||||
self.multi_label = multi_label
|
||||
component.set_output_types(self, **{label: str for label in labels})
|
||||
component.set_output_types(self, **dict.fromkeys(labels, str))
|
||||
|
||||
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
|
||||
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
|
||||
|
@ -143,7 +143,7 @@ class AsyncPipeline(PipelineBase):
|
||||
# For quick lookup of downstream receivers
|
||||
ordered_names = sorted(self.graph.nodes.keys())
|
||||
cached_receivers = {n: self._find_receivers_from(n) for n in ordered_names}
|
||||
component_visits = {component_name: 0 for component_name in ordered_names}
|
||||
component_visits = dict.fromkeys(ordered_names, 0)
|
||||
cached_topological_sort = None
|
||||
|
||||
# We fill the queue once and raise if all components are BLOCKED
|
||||
|
@ -202,7 +202,7 @@ class Pipeline(PipelineBase):
|
||||
ordered_component_names = sorted(self.graph.nodes.keys())
|
||||
|
||||
# We track component visits to decide if a component can run.
|
||||
component_visits = {component_name: 0 for component_name in ordered_component_names}
|
||||
component_visits = dict.fromkeys(ordered_component_names, 0)
|
||||
|
||||
# We need to access a component's receivers multiple times during a pipeline run.
|
||||
# We store them here for easy access.
|
||||
|
@ -214,7 +214,7 @@ def component_class( # pylint: disable=too-many-positional-arguments
|
||||
def run(self, **kwargs): # pylint: disable=unused-argument
|
||||
if output is not None:
|
||||
return output
|
||||
return {name: None for name in output_types.keys()}
|
||||
return dict.fromkeys(output_types.keys())
|
||||
|
||||
def to_dict(self):
|
||||
return default_to_dict(self)
|
||||
|
@ -18,7 +18,7 @@ class FString:
|
||||
self.variables = variables or []
|
||||
if "template" in self.variables:
|
||||
raise ValueError("The variable name 'template' is reserved and cannot be used.")
|
||||
component.set_input_types(self, **{variable: Any for variable in self.variables})
|
||||
component.set_input_types(self, **dict.fromkeys(self.variables, Any))
|
||||
|
||||
@component.output_types(string=str)
|
||||
def run(self, template: Optional[str] = None, **kwargs):
|
||||
|
@ -11,10 +11,10 @@ from haystack.core.component import component
|
||||
class Repeat:
|
||||
def __init__(self, outputs: List[str]):
|
||||
self._outputs = outputs
|
||||
component.set_output_types(self, **{k: int for k in outputs})
|
||||
component.set_output_types(self, **dict.fromkeys(outputs, int))
|
||||
|
||||
def run(self, value: int):
|
||||
"""
|
||||
:param value: the value to repeat.
|
||||
"""
|
||||
return {val: value for val in self._outputs}
|
||||
return dict.fromkeys(self._outputs, value)
|
||||
|
@ -64,7 +64,7 @@ dependencies = [
|
||||
installer = "uv"
|
||||
dependencies = [
|
||||
"pre-commit",
|
||||
"ruff<0.10.0",
|
||||
"ruff",
|
||||
"toml",
|
||||
"reno",
|
||||
# dulwich is a reno dependency; they pin it at >=0.15.0, so pip takes a ton of time to resolve the dependency tree.
|
||||
|
Loading…
x
Reference in New Issue
Block a user