mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
unpin ruff and update code (#9040)
This commit is contained in:
parent
6366f6577e
commit
c5cde40d3a
@ -17,7 +17,7 @@ repos:
|
|||||||
args: [--markdown-linebreak-ext=md]
|
args: [--markdown-linebreak-ext=md]
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.9.2
|
rev: v0.11.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
@ -95,7 +95,7 @@ class OutputAdapter:
|
|||||||
input_types.update(route_input_names)
|
input_types.update(route_input_names)
|
||||||
|
|
||||||
# the env is not needed, discarded automatically
|
# the env is not needed, discarded automatically
|
||||||
component.set_input_types(self, **{var: Any for var in input_types})
|
component.set_input_types(self, **dict.fromkeys(input_types, Any))
|
||||||
component.set_output_types(self, **{"output": output_type})
|
component.set_output_types(self, **{"output": output_type})
|
||||||
self.output_type = output_type
|
self.output_type = output_type
|
||||||
|
|
||||||
|
@ -127,8 +127,8 @@ class AzureOpenAIDocumentEmbedder:
|
|||||||
self.progress_bar = progress_bar
|
self.progress_bar = progress_bar
|
||||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||||
self.embedding_separator = embedding_separator
|
self.embedding_separator = embedding_separator
|
||||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
self.default_headers = default_headers or {}
|
self.default_headers = default_headers or {}
|
||||||
self.azure_ad_token_provider = azure_ad_token_provider
|
self.azure_ad_token_provider = azure_ad_token_provider
|
||||||
|
|
||||||
|
@ -107,8 +107,8 @@ class AzureOpenAITextEmbedder:
|
|||||||
self.azure_deployment = azure_deployment
|
self.azure_deployment = azure_deployment
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.organization = organization
|
self.organization = organization
|
||||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.suffix = suffix
|
self.suffix = suffix
|
||||||
self.default_headers = default_headers or {}
|
self.default_headers = default_headers or {}
|
||||||
|
@ -108,9 +108,9 @@ class OpenAIDocumentEmbedder:
|
|||||||
self.embedding_separator = embedding_separator
|
self.embedding_separator = embedding_separator
|
||||||
|
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
if max_retries is None:
|
if max_retries is None:
|
||||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=api_key.resolve_value(),
|
api_key=api_key.resolve_value(),
|
||||||
|
@ -90,9 +90,9 @@ class OpenAITextEmbedder:
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
if max_retries is None:
|
if max_retries is None:
|
||||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=api_key.resolve_value(),
|
api_key=api_key.resolve_value(),
|
||||||
|
@ -141,8 +141,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
|||||||
self.azure_deployment = azure_deployment
|
self.azure_deployment = azure_deployment
|
||||||
self.organization = organization
|
self.organization = organization
|
||||||
self.model: str = azure_deployment or "gpt-4o-mini"
|
self.model: str = azure_deployment or "gpt-4o-mini"
|
||||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
self.default_headers = default_headers or {}
|
self.default_headers = default_headers or {}
|
||||||
self.azure_ad_token_provider = azure_ad_token_provider
|
self.azure_ad_token_provider = azure_ad_token_provider
|
||||||
|
|
||||||
|
@ -149,8 +149,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
|||||||
self.azure_deployment = azure_deployment
|
self.azure_deployment = azure_deployment
|
||||||
self.organization = organization
|
self.organization = organization
|
||||||
self.model = azure_deployment or "gpt-4o-mini"
|
self.model = azure_deployment or "gpt-4o-mini"
|
||||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
self.default_headers = default_headers or {}
|
self.default_headers = default_headers or {}
|
||||||
self.azure_ad_token_provider = azure_ad_token_provider
|
self.azure_ad_token_provider = azure_ad_token_provider
|
||||||
|
|
||||||
|
@ -146,9 +146,9 @@ class OpenAIChatGenerator:
|
|||||||
_check_duplicate_tool_names(tools)
|
_check_duplicate_tool_names(tools)
|
||||||
|
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
if max_retries is None:
|
if max_retries is None:
|
||||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
|
|
||||||
client_args: Dict[str, Any] = {
|
client_args: Dict[str, Any] = {
|
||||||
"api_key": api_key.resolve_value(),
|
"api_key": api_key.resolve_value(),
|
||||||
|
@ -112,9 +112,9 @@ class OpenAIGenerator:
|
|||||||
self.organization = organization
|
self.organization = organization
|
||||||
|
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
if max_retries is None:
|
if max_retries is None:
|
||||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=api_key.resolve_value(),
|
api_key=api_key.resolve_value(),
|
||||||
|
@ -71,8 +71,8 @@ class DALLEImageGenerator:
|
|||||||
self.api_base_url = api_base_url
|
self.api_base_url = api_base_url
|
||||||
self.organization = organization
|
self.organization = organization
|
||||||
|
|
||||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||||
|
|
||||||
self.client: Optional[OpenAI] = None
|
self.client: Optional[OpenAI] = None
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ class ConditionalRouter:
|
|||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
|
|
||||||
# add mandatory input types
|
# add mandatory input types
|
||||||
component.set_input_types(self, **{var: Any for var in mandatory_input_types})
|
component.set_input_types(self, **dict.fromkeys(mandatory_input_types, Any))
|
||||||
|
|
||||||
# now add optional input types
|
# now add optional input types
|
||||||
for optional_var_name in self.optional_variables:
|
for optional_var_name in self.optional_variables:
|
||||||
|
@ -96,7 +96,7 @@ class FileTypeRouter:
|
|||||||
component.set_output_types(
|
component.set_output_types(
|
||||||
self,
|
self,
|
||||||
unclassified=List[Union[str, Path, ByteStream]],
|
unclassified=List[Union[str, Path, ByteStream]],
|
||||||
**{mime_type: List[Union[str, Path, ByteStream]] for mime_type in mime_types},
|
**dict.fromkeys(mime_types, List[Union[str, Path, ByteStream]]),
|
||||||
)
|
)
|
||||||
self.mime_types = mime_types
|
self.mime_types = mime_types
|
||||||
self._additional_mimetypes = additional_mimetypes
|
self._additional_mimetypes = additional_mimetypes
|
||||||
|
@ -81,7 +81,7 @@ class MetadataRouter:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
|
"Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
|
||||||
)
|
)
|
||||||
component.set_output_types(self, unmatched=List[Document], **{edge: List[Document] for edge in rules})
|
component.set_output_types(self, unmatched=List[Document], **dict.fromkeys(rules, List[Document]))
|
||||||
|
|
||||||
def run(self, documents: List[Document]):
|
def run(self, documents: List[Document]):
|
||||||
"""
|
"""
|
||||||
|
@ -59,7 +59,7 @@ class TextLanguageRouter:
|
|||||||
if not languages:
|
if not languages:
|
||||||
languages = ["en"]
|
languages = ["en"]
|
||||||
self.languages = languages
|
self.languages = languages
|
||||||
component.set_output_types(self, unmatched=str, **{language: str for language in languages})
|
component.set_output_types(self, unmatched=str, **dict.fromkeys(languages, str))
|
||||||
|
|
||||||
def run(self, text: str) -> Dict[str, str]:
|
def run(self, text: str) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
|
@ -116,7 +116,7 @@ class TransformersTextRouter:
|
|||||||
self.labels = list(config.label2id.keys())
|
self.labels = list(config.label2id.keys())
|
||||||
else:
|
else:
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
component.set_output_types(self, **{label: str for label in self.labels})
|
component.set_output_types(self, **dict.fromkeys(self.labels, str))
|
||||||
|
|
||||||
self.pipeline = None
|
self.pipeline = None
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ class TransformersZeroShotTextRouter:
|
|||||||
self.token = token
|
self.token = token
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
self.multi_label = multi_label
|
self.multi_label = multi_label
|
||||||
component.set_output_types(self, **{label: str for label in labels})
|
component.set_output_types(self, **dict.fromkeys(labels, str))
|
||||||
|
|
||||||
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
|
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
|
||||||
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
|
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
|
||||||
|
@ -143,7 +143,7 @@ class AsyncPipeline(PipelineBase):
|
|||||||
# For quick lookup of downstream receivers
|
# For quick lookup of downstream receivers
|
||||||
ordered_names = sorted(self.graph.nodes.keys())
|
ordered_names = sorted(self.graph.nodes.keys())
|
||||||
cached_receivers = {n: self._find_receivers_from(n) for n in ordered_names}
|
cached_receivers = {n: self._find_receivers_from(n) for n in ordered_names}
|
||||||
component_visits = {component_name: 0 for component_name in ordered_names}
|
component_visits = dict.fromkeys(ordered_names, 0)
|
||||||
cached_topological_sort = None
|
cached_topological_sort = None
|
||||||
|
|
||||||
# We fill the queue once and raise if all components are BLOCKED
|
# We fill the queue once and raise if all components are BLOCKED
|
||||||
|
@ -202,7 +202,7 @@ class Pipeline(PipelineBase):
|
|||||||
ordered_component_names = sorted(self.graph.nodes.keys())
|
ordered_component_names = sorted(self.graph.nodes.keys())
|
||||||
|
|
||||||
# We track component visits to decide if a component can run.
|
# We track component visits to decide if a component can run.
|
||||||
component_visits = {component_name: 0 for component_name in ordered_component_names}
|
component_visits = dict.fromkeys(ordered_component_names, 0)
|
||||||
|
|
||||||
# We need to access a component's receivers multiple times during a pipeline run.
|
# We need to access a component's receivers multiple times during a pipeline run.
|
||||||
# We store them here for easy access.
|
# We store them here for easy access.
|
||||||
|
@ -214,7 +214,7 @@ def component_class( # pylint: disable=too-many-positional-arguments
|
|||||||
def run(self, **kwargs): # pylint: disable=unused-argument
|
def run(self, **kwargs): # pylint: disable=unused-argument
|
||||||
if output is not None:
|
if output is not None:
|
||||||
return output
|
return output
|
||||||
return {name: None for name in output_types.keys()}
|
return dict.fromkeys(output_types.keys())
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return default_to_dict(self)
|
return default_to_dict(self)
|
||||||
|
@ -18,7 +18,7 @@ class FString:
|
|||||||
self.variables = variables or []
|
self.variables = variables or []
|
||||||
if "template" in self.variables:
|
if "template" in self.variables:
|
||||||
raise ValueError("The variable name 'template' is reserved and cannot be used.")
|
raise ValueError("The variable name 'template' is reserved and cannot be used.")
|
||||||
component.set_input_types(self, **{variable: Any for variable in self.variables})
|
component.set_input_types(self, **dict.fromkeys(self.variables, Any))
|
||||||
|
|
||||||
@component.output_types(string=str)
|
@component.output_types(string=str)
|
||||||
def run(self, template: Optional[str] = None, **kwargs):
|
def run(self, template: Optional[str] = None, **kwargs):
|
||||||
|
@ -11,10 +11,10 @@ from haystack.core.component import component
|
|||||||
class Repeat:
|
class Repeat:
|
||||||
def __init__(self, outputs: List[str]):
|
def __init__(self, outputs: List[str]):
|
||||||
self._outputs = outputs
|
self._outputs = outputs
|
||||||
component.set_output_types(self, **{k: int for k in outputs})
|
component.set_output_types(self, **dict.fromkeys(outputs, int))
|
||||||
|
|
||||||
def run(self, value: int):
|
def run(self, value: int):
|
||||||
"""
|
"""
|
||||||
:param value: the value to repeat.
|
:param value: the value to repeat.
|
||||||
"""
|
"""
|
||||||
return {val: value for val in self._outputs}
|
return dict.fromkeys(self._outputs, value)
|
||||||
|
@ -64,7 +64,7 @@ dependencies = [
|
|||||||
installer = "uv"
|
installer = "uv"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"pre-commit",
|
"pre-commit",
|
||||||
"ruff<0.10.0",
|
"ruff",
|
||||||
"toml",
|
"toml",
|
||||||
"reno",
|
"reno",
|
||||||
# dulwich is a reno dependency, they pin it at >=0.15.0 so pip takes ton of time to resolve the dependency tree.
|
# dulwich is a reno dependency, they pin it at >=0.15.0 so pip takes ton of time to resolve the dependency tree.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user