Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-10-29 16:59:47 +00:00
Refactoring the `RayPipeline.run` method - merging it with `Pipeline.run` (#2981)
* Refactoring the `RayPipeline.run` method - merging it with `Pipeline.run`. This fixes #2968.
* Bug: variable `i` was already in use
* Removing unused imports
* [EMPTY] Re-trigger CI
* Addressing concerns raised pre-review - dropping the attempt to make the merged `run` work without `JoinDocuments`; it is okay for certain pipelines to fail when they lack a `JoinDocuments` node.
* Refactoring based on reviews
This commit is contained in:
parent f4128d3581
commit aafa017c17
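In short: the queue-based scheduling loop now lives only in `Pipeline.run`, and each subclass supplies node execution through two small hooks, `_run_node` (execute one node) and `_get_run_node_signature` (inspect a node's run parameters when validating global params). A minimal, runnable sketch of the pattern - illustrative stand-ins, not the real Haystack classes; the actual hook bodies appear in the diff below:

from typing import Any, Dict, List, Tuple


class Pipeline:
    """Owns the scheduling logic; subclasses only override how one node runs."""

    def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
        # Real code: self.graph.nodes[node_id]["component"]._dispatch_run(**node_input)
        return {**node_input, "ran_on": "local"}, "output_1"

    def run(self, node_ids: List[str], **node_input) -> Dict:
        # Stand-in for the real queue-based loop in Pipeline.run
        for node_id in node_ids:
            node_input, _stream_id = self._run_node(node_id, node_input)
        return node_input


class RayPipeline(Pipeline):
    """Inherits run() unchanged; only single-node execution goes through Ray."""

    def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
        # Real code: ray.get(self.graph.nodes[node_id]["component"].remote(**node_input))
        return {**node_input, "ran_on": "ray"}, "output_1"


print(RayPipeline().run(["ESRetriever", "Reader"], query="Who lives in Berlin?"))
# -> {'query': 'Who lives in Berlin?', 'ran_on': 'ray'}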
haystack/pipelines/base.py:

@@ -435,6 +435,9 @@ class Pipeline:
         """
         self.graph.nodes[name]["component"] = component
 
+    def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
+        return self.graph.nodes[node_id]["component"]._dispatch_run(**node_input)
+
     def run(  # type: ignore
         self,
         query: Optional[str] = None,
@@ -506,7 +509,7 @@ class Pipeline:
            if predecessors.isdisjoint(set(queue.keys())):  # only execute if predecessor nodes are executed
                try:
                    logger.debug(f"Running node `{node_id}` with input `{node_input}`")
-                    node_output, stream_id = self.graph.nodes[node_id]["component"]._dispatch_run(**node_input)
+                    node_output, stream_id = self._run_node(node_id, node_input)
                except Exception as e:
                    tb = traceback.format_exc()
                    raise Exception(
@@ -1909,7 +1912,7 @@ class Pipeline:
         not_a_node = set(params.keys()) - set(self.graph.nodes)
         valid_global_params = set(["debug"])  # Debug will be picked up by _dispatch_run, see its code
         for node_id in self.graph.nodes:
-            run_signature_args = inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
+            run_signature_args = self._get_run_node_signature(node_id)
             valid_global_params |= set(run_signature_args)
         invalid_keys = [key for key in not_a_node if key not in valid_global_params]
 
@@ -1918,6 +1921,9 @@ class Pipeline:
                 f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
             )
 
+    def _get_run_node_signature(self, node_id: str):
+        return inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
+
     def print_eval_report(
         self,
         eval_result: EvaluationResult,
haystack/pipelines/ray.py:

@@ -1,10 +1,9 @@
 from __future__ import annotations
-from typing import Any, Dict, List, Optional
+import inspect
+from typing import Any, Dict, List, Optional, Tuple
 
 from pathlib import Path
 
-import networkx as nx
-
 try:
     from ray import serve
     import ray
@@ -18,10 +17,8 @@ from haystack.pipelines.config import (
     read_pipeline_config_from_yaml,
     validate_config,
 )
-from haystack.schema import MultiLabel, Document
 from haystack.nodes.base import BaseComponent, RootNode
 from haystack.pipelines.base import Pipeline
-from haystack.errors import PipelineError
 
 
 class RayPipeline(Pipeline):
@@ -219,61 +216,6 @@ class RayPipeline(Pipeline):
         handle = RayDeployment.get_handle()
         return handle
 
-    def run(  # type: ignore
-        self,
-        query: Optional[str] = None,
-        file_paths: Optional[List[str]] = None,
-        labels: Optional[MultiLabel] = None,
-        documents: Optional[List[Document]] = None,
-        meta: Optional[dict] = None,
-        params: Optional[dict] = None,
-    ):
-        has_next_node = True
-
-        root_node = self.root_node
-        if not root_node:
-            raise PipelineError("Cannot run a pipeline with no nodes.")
-
-        current_node_id: str = root_node
-
-        input_dict: Dict[str, Any] = {"root_node": root_node, "params": params}
-        if query:
-            input_dict["query"] = query
-        if file_paths:
-            input_dict["file_paths"] = file_paths
-        if labels:
-            input_dict["labels"] = labels
-        if documents:
-            input_dict["documents"] = documents
-        if meta:
-            input_dict["meta"] = meta
-
-        output_dict = None
-
-        while has_next_node:
-            output_dict, stream_id = ray.get(self.graph.nodes[current_node_id]["component"].remote(**input_dict))
-            input_dict = output_dict
-            next_nodes = self.get_next_nodes(current_node_id, stream_id)
-
-            if len(next_nodes) > 1:
-                join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
-                if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
-                    raise NotImplementedError(
-                        "The current pipeline does not support multiple levels of parallel nodes."
-                    )
-                inputs_for_join_node: dict = {"inputs": []}
-                for n_id in next_nodes:
-                    output = self.graph.nodes[n_id]["component"].run(**input_dict)
-                    inputs_for_join_node["inputs"].append(output)
-                input_dict = inputs_for_join_node
-                current_node_id = join_node_id
-            elif len(next_nodes) == 1:
-                current_node_id = next_nodes[0]
-            else:
-                has_next_node = False
-
-        return output_dict
-
     def add_node(self, component, name: str, inputs: List[str]):
         raise NotImplementedError(
             "The current implementation of RayPipeline only supports loading Pipelines from a YAML file."
@@ -318,6 +260,12 @@ class RayPipeline(Pipeline):
                 input_edge_name = "output_1"
             self.graph.add_edge(input_node_name, name, label=input_edge_name)
 
+    def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
+        return ray.get(self.graph.nodes[node_id]["component"].remote(**node_input))
+
+    def _get_run_node_signature(self, node_id: str):
+        return inspect.signature(self.graph.nodes[node_id]["component"].remote).parameters.keys()
+
 
 class _RayDeploymentWrapper:
     """
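A note on what the deleted RayPipeline.run above was doing: it walked the graph with a single current_node_id, which is why fan-out needed the ad-hoc JoinDocuments handling the commit message mentions. The inherited Pipeline.run instead keeps a queue of pending nodes and executes a node only once none of its predecessors is still queued (the predecessors.isdisjoint(...) check in the first hunk). A minimal, self-contained sketch of that gating idea - illustrative only, not Haystack's actual scheduler:

# Toy DAG mirroring the advanced test pipeline below: two retrievers fan out
# from Query and are joined before the Reader. Maps node -> predecessors.
graph = {
    "Query": [],
    "ESRetriever1": ["Query"],
    "ESRetriever2": ["Query"],
    "JoinDocuments": ["ESRetriever1", "ESRetriever2"],
    "Reader": ["JoinDocuments"],
}

queue = {"Query": {"query": "Who lives in Berlin?"}}  # node -> accumulated input
while queue:
    for node, node_input in list(queue.items()):
        predecessors = set(graph[node])
        if predecessors.isdisjoint(queue.keys()):  # all predecessors have run
            del queue[node]
            print(f"running {node} with {node_input}")
            for successor, preds in graph.items():
                if node in preds:  # pass this node's output on to each successor
                    queue.setdefault(successor, {}).update(node_input)
            break  # queue changed, restart the scan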
test/test_pipeline_yaml.py:

@@ -1023,8 +1023,9 @@ def test_save_yaml_overwrite(tmp_path):
     assert content != ""
 
 
-def test_load_yaml_ray_args_in_pipeline(tmp_path):
+@pytest.mark.parametrize("pipeline_file", ["ray.simple.haystack-pipeline.yml", "ray.advanced.haystack-pipeline.yml"])
+def test_load_yaml_ray_args_in_pipeline(tmp_path, pipeline_file):
     with pytest.raises(PipelineConfigError) as e:
         pipeline = Pipeline.load_from_yaml(
-            SAMPLES_PATH / "pipeline" / "ray.haystack-pipeline.yml", pipeline_name="ray_query_pipeline"
+            SAMPLES_PATH / "pipeline" / pipeline_file, pipeline_name="ray_query_pipeline"
         )
test/test_ray.py:

@@ -25,12 +25,14 @@ def shutdown_ray():
 @pytest.mark.parametrize("serve_detached", [True, False])
 def test_load_pipeline(document_store_with_docs, serve_detached):
     pipeline = RayPipeline.load_from_yaml(
-        SAMPLES_PATH / "pipeline" / "ray.haystack-pipeline.yml",
+        SAMPLES_PATH / "pipeline" / "ray.simple.haystack-pipeline.yml",
         pipeline_name="ray_query_pipeline",
         ray_args={"num_cpus": 8},
         serve_args={"detached": serve_detached},
     )
-    prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}})
+    prediction = pipeline.run(
+        query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}
+    )
 
     assert pipeline._serve_controller_client._detached == serve_detached
     assert ray.serve.get_deployment(name="ESRetriever").num_replicas == 2
@@ -39,3 +41,30 @@ def test_load_pipeline(document_store_with_docs, serve_detached):
     assert ray.serve.get_deployment(name="ESRetriever").ray_actor_options["num_cpus"] == 0.5
     assert prediction["query"] == "Who lives in Berlin?"
     assert prediction["answers"][0].answer == "Carla"
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+def test_load_advanced_pipeline(document_store_with_docs):
+    pipeline = RayPipeline.load_from_yaml(
+        SAMPLES_PATH / "pipeline" / "ray.advanced.haystack-pipeline.yml",
+        pipeline_name="ray_query_pipeline",
+        ray_args={"num_cpus": 8},
+        serve_args={"detached": True},
+    )
+    prediction = pipeline.run(
+        query="Who lives in Berlin?",
+        params={"ESRetriever1": {"top_k": 1}, "ESRetriever2": {"top_k": 2}, "Reader": {"top_k": 3}},
+    )
+
+    assert pipeline._serve_controller_client._detached is True
+    assert ray.serve.get_deployment(name="ESRetriever1").num_replicas == 2
+    assert ray.serve.get_deployment(name="ESRetriever2").num_replicas == 2
+    assert ray.serve.get_deployment(name="Reader").num_replicas == 1
+    assert ray.serve.get_deployment(name="ESRetriever1").max_concurrent_queries == 17
+    assert ray.serve.get_deployment(name="ESRetriever2").max_concurrent_queries == 15
+    assert ray.serve.get_deployment(name="ESRetriever1").ray_actor_options["num_cpus"] == 0.25
+    assert ray.serve.get_deployment(name="ESRetriever2").ray_actor_options["num_cpus"] == 0.25
+    assert prediction["query"] == "Who lives in Berlin?"
+    assert prediction["answers"][0].answer == "Carla"
+    assert len(prediction["answers"]) > 1
test/samples/pipeline/ray.advanced.haystack-pipeline.yml (new file, 73 lines):

@@ -0,0 +1,73 @@
+version: ignore
+extras: ray
+
+components:
+  - name: DocumentStore
+    type: ElasticsearchDocumentStore
+    params:
+      index: haystack_test
+      label_index: haystack_test_label
+  - name: ESRetriever1
+    type: BM25Retriever
+    params:
+      document_store: DocumentStore
+  - name: ESRetriever2
+    # type: TfidfRetriever  # can't use TfidfRetriever until https://github.com/deepset-ai/haystack/pull/2984 is merged
+    type: BM25Retriever
+    params:
+      document_store: DocumentStore
+  - name: Reader
+    type: FARMReader
+    params:
+      no_ans_boost: -10
+      model_name_or_path: deepset/roberta-base-squad2
+      num_processes: 0
+  - name: PDFConverter
+    type: PDFToTextConverter
+    params:
+      remove_numeric_tables: false
+  - name: Preprocessor
+    type: PreProcessor
+    params:
+      clean_whitespace: true
+  - name: IndexTimeDocumentClassifier
+    type: TransformersDocumentClassifier
+    params:
+      batch_size: 16
+      use_gpu: false
+  - name: QueryTimeDocumentClassifier
+    type: TransformersDocumentClassifier
+    params:
+      use_gpu: false
+  - name: JoinDocuments
+    params: {}
+    type: JoinDocuments
+
+
+pipelines:
+  - name: ray_query_pipeline
+    nodes:
+      - name: ESRetriever1
+        inputs: [ Query ]
+        serve_deployment_kwargs:
+          num_replicas: 2
+          version: Twenty
+          ray_actor_options:
+            # num_gpus: 0.25  # we have no GPU to test this
+            num_cpus: 0.25
+          max_concurrent_queries: 17
+      - name: ESRetriever2
+        inputs: [ Query ]
+        serve_deployment_kwargs:
+          num_replicas: 2
+          version: Twenty
+          ray_actor_options:
+            # num_gpus: 0.25  # we have no GPU to test this
+            num_cpus: 0.25
+          max_concurrent_queries: 15
+      - name: JoinDocuments
+        inputs:
+          - ESRetriever1
+          - ESRetriever2
+      - name: Reader
+        inputs: [ JoinDocuments ]
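For context on serve_deployment_kwargs in the new YAML: RayPipeline forwards these values to Ray Serve when it wraps each node as a deployment, which is what the test assertions on num_replicas, max_concurrent_queries, and ray_actor_options check. A rough hand-written equivalent using the Ray 1.x-era Serve decorator - the wrapper class here is a hypothetical stand-in, not Haystack's actual _RayDeploymentWrapper internals:

from ray import serve

# Hypothetical stand-in showing where the YAML values end up; RayPipeline
# builds the real deployment from serve_deployment_kwargs at load time.
@serve.deployment(
    name="ESRetriever1",
    num_replicas=2,                        # get_deployment(...).num_replicas == 2
    version="Twenty",
    max_concurrent_queries=17,             # get_deployment(...).max_concurrent_queries == 17
    ray_actor_options={"num_cpus": 0.25},  # get_deployment(...).ray_actor_options["num_cpus"] == 0.25
)
class ESRetriever1Node:
    def __call__(self, **node_input):
        ...  # run the wrapped Haystack component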