Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-02 18:59:28 +00:00
chore: cleanup unused code (#6804)
* remove validation module
* remove unused code
* adjust imports
* sort imports
This commit is contained in:
parent f44f123b3f
commit df2a23dfa5
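As the first hunk below shows, find_pipeline_inputs, the only name pipeline.py still pulled from the deleted validation module, now comes from haystack.core.pipeline.descriptions. A minimal sketch of the same adjustment for any other caller of the old module (assuming find_pipeline_inputs is the only symbol it used):

```python
# Before this commit: the helper lived in the (now deleted) validation module.
# from haystack.core.pipeline.validation import find_pipeline_inputs

# After this commit: import it from the descriptions module instead.
from haystack.core.pipeline.descriptions import find_pipeline_inputs
```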
haystack/core/pipeline/pipeline.py

@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
-import datetime
 import importlib
 import logging
 from collections import defaultdict
@@ -13,16 +12,9 @@ import networkx  # type:ignore
 
 from haystack.core.component import Component, InputSocket, OutputSocket, component
 from haystack.core.component.connection import Connection, parse_connect_string
-from haystack.core.errors import (
-    PipelineConnectError,
-    PipelineError,
-    PipelineMaxLoops,
-    PipelineRuntimeError,
-    PipelineValidationError,
-)
-from haystack.core.pipeline.descriptions import find_pipeline_outputs
+from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError
+from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
 from haystack.core.pipeline.draw.draw import RenderingEngines, _draw
-from haystack.core.pipeline.validation import find_pipeline_inputs
 from haystack.core.serialization import component_from_dict, component_to_dict
 from haystack.core.type_utils import _type_name
 
||||
@ -679,152 +671,3 @@ class Pipeline:
|
||||
to_run.append((name, comp))
|
||||
|
||||
return final_outputs
|
||||
|
||||
def _record_pipeline_step(
|
||||
self, step, components_queue, mandatory_values_buffer, optional_values_buffer, pipeline_output
|
||||
):
|
||||
"""
|
||||
Stores a snapshot of this step into the self.debug dictionary of the pipeline.
|
||||
"""
|
||||
self._debug[step] = {
|
||||
"time": datetime.datetime.now(),
|
||||
"components_queue": components_queue,
|
||||
"mandatory_values_buffer": mandatory_values_buffer,
|
||||
"optional_values_buffer": optional_values_buffer,
|
||||
"pipeline_output": pipeline_output,
|
||||
}
|
||||
|
||||
def _clear_visits_count(self):
|
||||
"""
|
||||
Make sure all nodes's visits count is zero.
|
||||
"""
|
||||
for node in self.graph.nodes:
|
||||
self.graph.nodes[node]["visits"] = 0
|
||||
|
||||
def _check_max_loops(self, component_name: str):
|
||||
"""
|
||||
Verify whether this component run too many times.
|
||||
"""
|
||||
if self.graph.nodes[component_name]["visits"] > self.max_loops_allowed:
|
||||
raise PipelineMaxLoops(
|
||||
f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{component_name}'."
|
||||
)
|
||||
|
||||
def _add_value_to_buffers(
|
||||
self,
|
||||
value: Any,
|
||||
connection: Connection,
|
||||
components_queue: List[str],
|
||||
mandatory_values_buffer: Dict[Connection, Any],
|
||||
optional_values_buffer: Dict[Connection, Any],
|
||||
):
|
||||
"""
|
||||
Given a value and the connection it is being sent on, it updates the buffers and the components queue.
|
||||
"""
|
||||
if connection.is_mandatory:
|
||||
mandatory_values_buffer[connection] = value
|
||||
if connection.receiver and connection.receiver not in components_queue:
|
||||
components_queue.append(connection.receiver)
|
||||
else:
|
||||
optional_values_buffer[connection] = value
|
||||
|
||||
def _ready_to_run(
|
||||
self, component_name: str, mandatory_values_buffer: Dict[Connection, Any], components_queue: List[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if a component is ready to run, False otherwise.
|
||||
"""
|
||||
connections_with_value = {conn for conn in mandatory_values_buffer.keys() if conn.receiver == component_name}
|
||||
expected_connections = set(self._mandatory_connections[component_name])
|
||||
if expected_connections.issubset(connections_with_value):
|
||||
logger.debug("Component '%s' is ready to run. All mandatory values were received.", component_name)
|
||||
return True
|
||||
|
||||
# Collect the missing values still being computed we need to wait for
|
||||
missing_connections: Set[Connection] = expected_connections - connections_with_value
|
||||
connections_to_wait = []
|
||||
for missing_conn in missing_connections:
|
||||
if any(
|
||||
networkx.has_path(self.graph, component_to_run, missing_conn.sender)
|
||||
for component_to_run in components_queue
|
||||
):
|
||||
connections_to_wait.append(missing_conn)
|
||||
|
||||
if not connections_to_wait:
|
||||
# None of the missing values are needed to visit this part of the graph: we can run the component
|
||||
logger.debug(
|
||||
"Component '%s' is ready to run. A variadic input parameter received all the expected values.",
|
||||
component_name,
|
||||
)
|
||||
return True
|
||||
|
||||
# Component can't run, waiting for the values needed by `connections_to_wait`
|
||||
logger.debug(
|
||||
"Component '%s' is not ready to run, some values are still missing: %s", component_name, connections_to_wait
|
||||
)
|
||||
# Put the component back in the queue
|
||||
components_queue.append(component_name)
|
||||
return False
|
||||
|
||||
def _extract_inputs_from_buffer(self, component_name: str, buffer: Dict[Connection, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract a component's input values from one of the value buffers.
|
||||
"""
|
||||
inputs = defaultdict(list)
|
||||
connections: List[Connection] = []
|
||||
|
||||
for connection in buffer.keys():
|
||||
if connection.receiver == component_name:
|
||||
connections.append(connection)
|
||||
|
||||
for key in connections:
|
||||
value = buffer.pop(key)
|
||||
if key.receiver_socket:
|
||||
if key.receiver_socket.is_variadic:
|
||||
inputs[key.receiver_socket.name].append(value)
|
||||
else:
|
||||
inputs[key.receiver_socket.name] = value
|
||||
return inputs
|
||||
|
||||
def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Once we're confident this component is ready to run, run it and collect the output.
|
||||
"""
|
||||
self.graph.nodes[name]["visits"] += 1
|
||||
instance = self.graph.nodes[name]["instance"]
|
||||
try:
|
||||
logger.info("* Running %s", name)
|
||||
logger.debug(" '%s' inputs: %s", name, inputs)
|
||||
|
||||
outputs = instance.run(**inputs)
|
||||
|
||||
# Unwrap the output
|
||||
logger.debug(" '%s' outputs: %s\n", name, outputs)
|
||||
|
||||
# Make sure the component returned a dictionary
|
||||
if not isinstance(outputs, dict):
|
||||
raise PipelineRuntimeError(
|
||||
f"Component '{name}' returned a value of type '{_type_name(type(outputs))}' instead of a dict. "
|
||||
"Components must always return dictionaries: check the the documentation."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PipelineRuntimeError(
|
||||
f"{name} raised '{e.__class__.__name__}: {e}' \nInputs: {inputs}\n\n"
|
||||
"See the stacktrace above for more information."
|
||||
) from e
|
||||
|
||||
return outputs
|
||||
|
||||
def _collect_targets(self, component_name: str, socket_name: str) -> List[Connection]:
|
||||
"""
|
||||
Given a component and an output socket name, return a list of Connections
|
||||
for which they represent the sender. Used to route output.
|
||||
"""
|
||||
return [
|
||||
connection
|
||||
for connection in self._connections
|
||||
if connection.sender == component_name
|
||||
and connection.sender_socket
|
||||
and connection.sender_socket.name == socket_name
|
||||
]
|
||||
|
||||
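For readers skimming the deleted block above: the most involved piece is _ready_to_run, which decided whether a queued component could execute. Below is a self-contained restatement of that rule, a sketch only: plain (sender, receiver) string tuples stand in for Haystack's Connection objects, a bare networkx graph stands in for the pipeline's internal state, and the component names in the demo are made up.

```python
from typing import Any, Dict, List, Set, Tuple

import networkx

Connection = Tuple[str, str]  # (sender, receiver): simplified stand-in for haystack's Connection class


def ready_to_run(
    graph: networkx.MultiDiGraph,
    component_name: str,
    mandatory_connections: Dict[str, List[Connection]],
    mandatory_values_buffer: Dict[Connection, Any],
    components_queue: List[str],
) -> bool:
    """A component may run once every mandatory connection delivered a value, or once none of the
    still-missing senders is reachable from anything left in the queue (so no value is coming)."""
    received = {conn for conn in mandatory_values_buffer if conn[1] == component_name}
    expected = set(mandatory_connections.get(component_name, []))
    if expected.issubset(received):
        return True  # every mandatory value has arrived

    # A missing value is worth waiting for only if a queued component can still reach its sender.
    missing: Set[Connection] = expected - received
    must_wait = [
        conn
        for conn in missing
        if any(networkx.has_path(graph, queued, conn[0]) for queued in components_queue)
    ]
    if not must_wait:
        return True  # the missing senders can never fire on this run
    components_queue.append(component_name)  # put the component back in the queue and retry later
    return False


# Tiny demo: "joiner" expects values from both "a" and "b"; only "a" has delivered,
# and "b" is still queued, so "joiner" must wait.
g = networkx.MultiDiGraph()
g.add_edges_from([("a", "joiner"), ("b", "joiner")])
print(ready_to_run(g, "joiner",
                   {"joiner": [("a", "joiner"), ("b", "joiner")]},
                   {("a", "joiner"): 1},
                   ["b"]))  # -> False
```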
haystack/core/pipeline/validation.py (deleted)

@@ -1,78 +0,0 @@
-# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
-#
-# SPDX-License-Identifier: Apache-2.0
-from typing import Dict, Any
-import logging
-
-import networkx  # type:ignore
-
-from haystack.core.errors import PipelineValidationError
-from haystack.core.component.sockets import InputSocket
-from haystack.core.pipeline.descriptions import find_pipeline_inputs, describe_pipeline_inputs_as_string
-
-
-logger = logging.getLogger(__name__)
-
-
-def validate_pipeline_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Make sure the pipeline is properly built and that the input received makes sense.
-    Returns the input values, validated and updated as needed.
-    """
-    if not any(sockets for sockets in find_pipeline_inputs(graph).values()):
-        raise PipelineValidationError("This pipeline has no inputs.")
-
-    # Make sure the input keys are all nodes of the pipeline
-    unknown_components = [key for key in input_values.keys() if key not in graph.nodes]
-    if unknown_components:
-        all_inputs = describe_pipeline_inputs_as_string(graph)
-        raise ValueError(
-            f"Pipeline received data for unknown component(s): {', '.join(unknown_components)}\n\n{all_inputs}"
-        )
-
-    # Make sure all necessary sockets are connected
-    _validate_input_sockets_are_connected(graph, input_values)
-
-    # Make sure that the pipeline input is only sent to nodes that won't receive data from other nodes
-    _validate_nodes_receive_only_expected_input(graph, input_values)
-
-    return input_values
-
-
-def _validate_input_sockets_are_connected(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]):
-    """
-    Make sure all the input nodes are receiving all the values they need, either from the Pipeline's input or from
-    other nodes.
-    """
-    valid_inputs = find_pipeline_inputs(graph)
-    for node, sockets in valid_inputs.items():
-        for socket in sockets:
-            inputs_for_node = input_values.get(node, {})
-            missing_input_value = (
-                inputs_for_node is None
-                or socket.name not in inputs_for_node.keys()
-                or inputs_for_node.get(socket.name, None) is None
-            )
-            if missing_input_value and socket.is_mandatory and not socket.is_variadic:
-                all_inputs = describe_pipeline_inputs_as_string(graph)
-                raise ValueError(f"Missing input: {node}.{socket.name}\n\n{all_inputs}")
-
-
-def _validate_nodes_receive_only_expected_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]):
-    """
-    Make sure that every input node is only receiving input values from EITHER the pipeline's input or another node,
-    but never from both.
-    """
-    for node, input_data in input_values.items():
-        for socket_name in input_data.keys():
-            if input_data.get(socket_name, None) is None:
-                continue
-            if socket_name not in graph.nodes[node]["input_sockets"].keys():
-                all_inputs = describe_pipeline_inputs_as_string(graph)
-                raise ValueError(
-                    f"Component {node} is not expecting any input value called {socket_name}.\n\n{all_inputs}"
-                )
-
-            input_socket: InputSocket = graph.nodes[node]["input_sockets"][socket_name]
-            if input_socket.senders and not input_socket.is_variadic:
-                raise ValueError(f"The input {socket_name} of {node} is already sent by: {input_socket.senders}")
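The whole file above is deleted. As a quick reference for what it enforced, here is a condensed, self-contained sketch of its unknown-component, unknown-socket, and already-connected-socket checks; the socket representation (plain dicts), the node attributes, and the component names in the demo are simplified assumptions, not Haystack's actual InputSocket model:

```python
from typing import Any, Dict

import networkx


def check_pipeline_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Dict[str, Any]]) -> None:
    """Reject pipeline input that names unknown components or sockets, or that targets a
    non-variadic socket already fed by another component (mirrors the deleted checks)."""
    for node, data in input_values.items():
        if node not in graph.nodes:  # input for a component that is not part of the pipeline
            raise ValueError(f"Pipeline received data for unknown component(s): {node}")
        sockets = graph.nodes[node]["input_sockets"]
        for socket_name, value in data.items():
            if value is None:
                continue
            if socket_name not in sockets:  # no such input socket on this component
                raise ValueError(f"Component {node} is not expecting any input value called {socket_name}")
            socket = sockets[socket_name]
            if socket["senders"] and not socket["is_variadic"]:  # socket already wired to another component
                raise ValueError(f"The input {socket_name} of {node} is already sent by: {socket['senders']}")


# Demo with made-up components: "llm.prompt" is already fed by "prompt_builder",
# so passing it directly as pipeline input is rejected.
g = networkx.MultiDiGraph()
g.add_node("llm", input_sockets={"prompt": {"senders": ["prompt_builder"], "is_variadic": False}})
try:
    check_pipeline_input(g, {"llm": {"prompt": "Hello"}})
except ValueError as err:
    print(err)
```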