Refactoring the RayPipeline.run method - merging it with Pipeline.run (#2981)

* Refactoring the `RayPipeline.run` method - merging it with `Pipeline.run`

This is to fix #2968

* Bug: variable `i` was already in use

* Removing unused imports

* Removing unused import

* [EMPTY] Re-trigger CI

* Addressing concerns raised pre-review

- Removing the attempt to make it work without the need for `JoinDocuments` - it is okay to fail without `JoinDocuments` for certain pipelines.

* Refactoring based on reviews
This commit is contained in:
Zoltan Fedor 2022-08-11 03:50:14 -04:00 committed by GitHub
parent f4128d3581
commit aafa017c17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 66 deletions

View File

@ -435,6 +435,9 @@ class Pipeline:
""" """
self.graph.nodes[name]["component"] = component self.graph.nodes[name]["component"] = component
def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
return self.graph.nodes[node_id]["component"]._dispatch_run(**node_input)
def run( # type: ignore def run( # type: ignore
self, self,
query: Optional[str] = None, query: Optional[str] = None,
@ -506,7 +509,7 @@ class Pipeline:
if predecessors.isdisjoint(set(queue.keys())): # only execute if predecessor nodes are executed if predecessors.isdisjoint(set(queue.keys())): # only execute if predecessor nodes are executed
try: try:
logger.debug(f"Running node `{node_id}` with input `{node_input}`") logger.debug(f"Running node `{node_id}` with input `{node_input}`")
node_output, stream_id = self.graph.nodes[node_id]["component"]._dispatch_run(**node_input) node_output, stream_id = self._run_node(node_id, node_input)
except Exception as e: except Exception as e:
tb = traceback.format_exc() tb = traceback.format_exc()
raise Exception( raise Exception(
@ -1909,7 +1912,7 @@ class Pipeline:
not_a_node = set(params.keys()) - set(self.graph.nodes) not_a_node = set(params.keys()) - set(self.graph.nodes)
valid_global_params = set(["debug"]) # Debug will be picked up by _dispatch_run, see its code valid_global_params = set(["debug"]) # Debug will be picked up by _dispatch_run, see its code
for node_id in self.graph.nodes: for node_id in self.graph.nodes:
run_signature_args = inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys() run_signature_args = self._get_run_node_signature(node_id)
valid_global_params |= set(run_signature_args) valid_global_params |= set(run_signature_args)
invalid_keys = [key for key in not_a_node if key not in valid_global_params] invalid_keys = [key for key in not_a_node if key not in valid_global_params]
@ -1918,6 +1921,9 @@ class Pipeline:
f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline." f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
) )
def _get_run_node_signature(self, node_id: str):
return inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
def print_eval_report( def print_eval_report(
self, self,
eval_result: EvaluationResult, eval_result: EvaluationResult,

View File

@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional import inspect
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path from pathlib import Path
import networkx as nx
try: try:
from ray import serve from ray import serve
import ray import ray
@ -18,10 +17,8 @@ from haystack.pipelines.config import (
read_pipeline_config_from_yaml, read_pipeline_config_from_yaml,
validate_config, validate_config,
) )
from haystack.schema import MultiLabel, Document
from haystack.nodes.base import BaseComponent, RootNode from haystack.nodes.base import BaseComponent, RootNode
from haystack.pipelines.base import Pipeline from haystack.pipelines.base import Pipeline
from haystack.errors import PipelineError
class RayPipeline(Pipeline): class RayPipeline(Pipeline):
@ -219,61 +216,6 @@ class RayPipeline(Pipeline):
handle = RayDeployment.get_handle() handle = RayDeployment.get_handle()
return handle return handle
def run(  # type: ignore
    self,
    query: Optional[str] = None,
    file_paths: Optional[List[str]] = None,
    labels: Optional[MultiLabel] = None,
    documents: Optional[List[Document]] = None,
    meta: Optional[dict] = None,
    params: Optional[dict] = None,
):
    """
    Walk the pipeline graph from the root node, executing one node at a time.

    Each node is invoked through `ray.get(component.remote(**input))`; its output
    dict becomes the input of the next node. When a node fans out to several
    successors, those successors must all feed the same join node; their results
    are collected into `{"inputs": [...]}` and handed to that join node.

    :param query: Optional query; only added to the input when truthy.
    :param file_paths: Optional file paths; only added to the input when truthy.
    :param labels: Optional labels; only added to the input when truthy.
    :param documents: Optional documents; only added to the input when truthy.
    :param meta: Optional meta dict; only added to the input when truthy.
    :param params: Passed on under the `"params"` key (even when `None`).
    :return: The output dict of the last node executed via `.remote`.
    :raises PipelineError: If the pipeline has no root node.
    :raises NotImplementedError: If parallel branches do not converge on a single join node.
    """
    root_node = self.root_node
    if not root_node:
        raise PipelineError("Cannot run a pipeline with no nodes.")

    current_node_id: str = root_node
    input_dict: Dict[str, Any] = {"root_node": root_node, "params": params}

    # Only truthy optional inputs are forwarded (same order as the signature).
    optional_inputs = (
        ("query", query),
        ("file_paths", file_paths),
        ("labels", labels),
        ("documents", documents),
        ("meta", meta),
    )
    for key, value in optional_inputs:
        if value:
            input_dict[key] = value

    output_dict = None
    while True:
        component = self.graph.nodes[current_node_id]["component"]
        output_dict, stream_id = ray.get(component.remote(**input_dict))
        input_dict = output_dict
        next_nodes = self.get_next_nodes(current_node_id, stream_id)

        if len(next_nodes) == 0:
            # Leaf reached: nothing left to execute.
            break

        if len(next_nodes) == 1:
            current_node_id = next_nodes[0]
            continue

        # Fan-out: take the first successor's first neighbour as the join node and
        # require that the join node's predecessors are exactly the parallel nodes.
        join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
        if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
            raise NotImplementedError(
                "The current pipeline does not support multiple levels of parallel nodes."
            )

        # NOTE(review): the parallel nodes are called via `.run(...)` directly, not via
        # `ray.get(....remote(...))` like the sequential path above - confirm this is intended.
        branch_outputs = [self.graph.nodes[n_id]["component"].run(**input_dict) for n_id in next_nodes]
        input_dict = {"inputs": branch_outputs}
        current_node_id = join_node_id

    return output_dict
def add_node(self, component, name: str, inputs: List[str]): def add_node(self, component, name: str, inputs: List[str]):
raise NotImplementedError( raise NotImplementedError(
"The current implementation of RayPipeline only supports loading Pipelines from a YAML file." "The current implementation of RayPipeline only supports loading Pipelines from a YAML file."
@ -318,6 +260,12 @@ class RayPipeline(Pipeline):
input_edge_name = "output_1" input_edge_name = "output_1"
self.graph.add_edge(input_node_name, name, label=input_edge_name) self.graph.add_edge(input_node_name, name, label=input_edge_name)
def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
    """
    Execute a single node of the pipeline graph through Ray.

    Calls `.remote(**node_input)` on the component stored under `node_id` and
    resolves the returned reference with `ray.get`.

    :param node_id: Name of the node in `self.graph` to execute.
    :param node_input: Keyword arguments forwarded to the component's `remote` call.
    :return: The value produced by `ray.get`, annotated as `(node_output, stream_id)`.
    """
    component = self.graph.nodes[node_id]["component"]
    pending_result = component.remote(**node_input)
    return ray.get(pending_result)
def _get_run_node_signature(self, node_id: str):
return inspect.signature(self.graph.nodes[node_id]["component"].remote).parameters.keys()
class _RayDeploymentWrapper: class _RayDeploymentWrapper:
""" """

View File

@ -1023,8 +1023,9 @@ def test_save_yaml_overwrite(tmp_path):
assert content != "" assert content != ""
def test_load_yaml_ray_args_in_pipeline(tmp_path): @pytest.mark.parametrize("pipeline_file", ["ray.simple.haystack-pipeline.yml", "ray.advanced.haystack-pipeline.yml"])
def test_load_yaml_ray_args_in_pipeline(tmp_path, pipeline_file):
with pytest.raises(PipelineConfigError) as e: with pytest.raises(PipelineConfigError) as e:
pipeline = Pipeline.load_from_yaml( pipeline = Pipeline.load_from_yaml(
SAMPLES_PATH / "pipeline" / "ray.haystack-pipeline.yml", pipeline_name="ray_query_pipeline" SAMPLES_PATH / "pipeline" / pipeline_file, pipeline_name="ray_query_pipeline"
) )

View File

@ -25,12 +25,14 @@ def shutdown_ray():
@pytest.mark.parametrize("serve_detached", [True, False]) @pytest.mark.parametrize("serve_detached", [True, False])
def test_load_pipeline(document_store_with_docs, serve_detached): def test_load_pipeline(document_store_with_docs, serve_detached):
pipeline = RayPipeline.load_from_yaml( pipeline = RayPipeline.load_from_yaml(
SAMPLES_PATH / "pipeline" / "ray.haystack-pipeline.yml", SAMPLES_PATH / "pipeline" / "ray.simple.haystack-pipeline.yml",
pipeline_name="ray_query_pipeline", pipeline_name="ray_query_pipeline",
ray_args={"num_cpus": 8}, ray_args={"num_cpus": 8},
serve_args={"detached": serve_detached}, serve_args={"detached": serve_detached},
) )
prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}) prediction = pipeline.run(
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}
)
assert pipeline._serve_controller_client._detached == serve_detached assert pipeline._serve_controller_client._detached == serve_detached
assert ray.serve.get_deployment(name="ESRetriever").num_replicas == 2 assert ray.serve.get_deployment(name="ESRetriever").num_replicas == 2
@ -39,3 +41,30 @@ def test_load_pipeline(document_store_with_docs, serve_detached):
assert ray.serve.get_deployment(name="ESRetriever").ray_actor_options["num_cpus"] == 0.5 assert ray.serve.get_deployment(name="ESRetriever").ray_actor_options["num_cpus"] == 0.5
assert prediction["query"] == "Who lives in Berlin?" assert prediction["query"] == "Who lives in Berlin?"
assert prediction["answers"][0].answer == "Carla" assert prediction["answers"][0].answer == "Carla"
@pytest.mark.integration
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_load_advanced_pipeline(document_store_with_docs):
    """
    Load the advanced Ray pipeline sample, run a query through it and check both
    the Ray Serve deployment settings and the prediction.
    """
    config_path = SAMPLES_PATH / "pipeline" / "ray.advanced.haystack-pipeline.yml"
    pipeline = RayPipeline.load_from_yaml(
        config_path,
        pipeline_name="ray_query_pipeline",
        ray_args={"num_cpus": 8},
        serve_args={"detached": True},
    )

    run_params = {"ESRetriever1": {"top_k": 1}, "ESRetriever2": {"top_k": 2}, "Reader": {"top_k": 3}}
    prediction = pipeline.run(query="Who lives in Berlin?", params=run_params)

    # The serve args given at load time must reach the Serve controller client.
    assert pipeline._serve_controller_client._detached is True

    # Deployment settings expected for each node.
    get_deployment = ray.serve.get_deployment
    assert get_deployment(name="ESRetriever1").num_replicas == 2
    assert get_deployment(name="ESRetriever2").num_replicas == 2
    assert get_deployment(name="Reader").num_replicas == 1
    assert get_deployment(name="ESRetriever1").max_concurrent_queries == 17
    assert get_deployment(name="ESRetriever2").max_concurrent_queries == 15
    assert get_deployment(name="ESRetriever1").ray_actor_options["num_cpus"] == 0.25
    assert get_deployment(name="ESRetriever2").ray_actor_options["num_cpus"] == 0.25

    # The prediction echoes the query and contains more than one answer.
    assert prediction["query"] == "Who lives in Berlin?"
    answers = prediction["answers"]
    assert answers[0].answer == "Carla"
    assert len(answers) > 1

View File

@ -0,0 +1,73 @@
version: ignore
extras: ray
components:
- name: DocumentStore
type: ElasticsearchDocumentStore
params:
index: haystack_test
label_index: haystack_test_label
- name: ESRetriever1
type: BM25Retriever
params:
document_store: DocumentStore
- name: ESRetriever2
# type: TfidfRetriever # can't use TfidfRetriever until https://github.com/deepset-ai/haystack/pull/2984 is merged
type: BM25Retriever
params:
document_store: DocumentStore
- name: Reader
type: FARMReader
params:
no_ans_boost: -10
model_name_or_path: deepset/roberta-base-squad2
num_processes: 0
- name: PDFConverter
type: PDFToTextConverter
params:
remove_numeric_tables: false
- name: Preprocessor
type: PreProcessor
params:
clean_whitespace: true
- name: IndexTimeDocumentClassifier
type: TransformersDocumentClassifier
params:
batch_size: 16
use_gpu: false
- name: QueryTimeDocumentClassifier
type: TransformersDocumentClassifier
params:
use_gpu: false
- name: JoinDocuments
params: {}
type: JoinDocuments
pipelines:
- name: ray_query_pipeline
nodes:
- name: ESRetriever1
inputs: [ Query ]
serve_deployment_kwargs:
num_replicas: 2
version: Twenty
ray_actor_options:
# num_gpus: 0.25 # we have no GPU to test this
num_cpus: 0.25
max_concurrent_queries: 17
- name: ESRetriever2
inputs: [ Query ]
serve_deployment_kwargs:
num_replicas: 2
version: Twenty
ray_actor_options:
# num_gpus: 0.25 # we have no GPU to test this
num_cpus: 0.25
max_concurrent_queries: 15
- name: JoinDocuments
inputs:
- ESRetriever1
- ESRetriever2
- name: Reader
inputs: [ JoinDocuments ]