Refactoring the RayPipeline.run method - merging it with Pipeline.run (#2981)

* Refactoring the `RayPipeline.run` method - merging it with `Pipeline.run`

This is to fix #2968

* Bug: variable `i` was already in use

* Removing unused imports

* Removing unused import

* [EMPTY] Re-trigger CI

* Addressing concerns raised pre-review

- Removing the attempt to make it work without `JoinDocuments` - it is okay for certain pipelines to fail without `JoinDocuments`.

* Refactoring based on reviews
Zoltan Fedor 2022-08-11 03:50:14 -04:00 committed by GitHub
parent f4128d3581
commit aafa017c17
6 changed files with 123 additions and 66 deletions

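The substance of the change, visible in the two pipeline files below, is a template-method refactor: `Pipeline.run` keeps the graph-traversal loop and delegates actual node execution to a new overridable `_run_node` hook, so `RayPipeline` only has to override that hook (plus the matching signature lookup) instead of maintaining its own copy of the whole `run` loop. A minimal, self-contained sketch of the pattern, with illustrative class names that are not Haystack's:

```python
from typing import Any, Dict, Tuple


class LocalPipeline:
    """Owns the traversal loop; subclasses customize node execution only."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Any] = {}  # node_id -> component

    def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
        # Local execution: call the component in-process.
        return self.nodes[node_id].run(**node_input)

    def run(self, query: str) -> Dict:
        node_input: Dict[str, Any] = {"query": query}
        for node_id in self.nodes:  # stand-in for the real graph traversal
            node_input, _stream_id = self._run_node(node_id, node_input)
        return node_input


class RemotePipeline(LocalPipeline):
    """Same traversal loop, but each node runs somewhere else (e.g. on Ray)."""

    def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
        # Only the execution mechanism changes; run() is inherited untouched.
        return self.nodes[node_id].run_remote(**node_input)
```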

@@ -435,6 +435,9 @@ class Pipeline:
"""
self.graph.nodes[name]["component"] = component
+ def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
+ return self.graph.nodes[node_id]["component"]._dispatch_run(**node_input)
def run( # type: ignore
self,
query: Optional[str] = None,
@@ -506,7 +509,7 @@ class Pipeline:
if predecessors.isdisjoint(set(queue.keys())): # only execute if predecessor nodes are executed
try:
logger.debug(f"Running node `{node_id}` with input `{node_input}`")
- node_output, stream_id = self.graph.nodes[node_id]["component"]._dispatch_run(**node_input)
+ node_output, stream_id = self._run_node(node_id, node_input)
except Exception as e:
tb = traceback.format_exc()
raise Exception(
@@ -1909,7 +1912,7 @@ class Pipeline:
not_a_node = set(params.keys()) - set(self.graph.nodes)
valid_global_params = set(["debug"]) # Debug will be picked up by _dispatch_run, see its code
for node_id in self.graph.nodes:
- run_signature_args = inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
+ run_signature_args = self._get_run_node_signature(node_id)
valid_global_params |= set(run_signature_args)
invalid_keys = [key for key in not_a_node if key not in valid_global_params]
@@ -1918,6 +1921,9 @@
f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
)
+ def _get_run_node_signature(self, node_id: str):
+ return inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
def print_eval_report(
self,
eval_result: EvaluationResult,

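The companion hook `_get_run_node_signature` exists because the parameter-validation code shown above inspects each node's `run` signature to decide which keys in `params` count as valid global parameters, and a Ray deployment exposes that signature on `.remote` rather than on `.run`. A standalone illustration of what the inspection yields (not Haystack code):

```python
import inspect


def run(self, query: str, top_k: int = 10, debug: bool = False):
    ...


# The pipeline collects these names to decide which keys in `params`
# are valid global parameters for at least one node.
print(set(inspect.signature(run).parameters.keys()))
# e.g. {'self', 'query', 'top_k', 'debug'}
```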

@@ -1,10 +1,9 @@
from __future__ import annotations
- from typing import Any, Dict, List, Optional
+ import inspect
+ from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
- import networkx as nx
try:
from ray import serve
import ray
@@ -18,10 +17,8 @@ from haystack.pipelines.config import (
read_pipeline_config_from_yaml,
validate_config,
)
- from haystack.schema import MultiLabel, Document
from haystack.nodes.base import BaseComponent, RootNode
from haystack.pipelines.base import Pipeline
- from haystack.errors import PipelineError
class RayPipeline(Pipeline):
@@ -219,61 +216,6 @@ class RayPipeline(Pipeline):
handle = RayDeployment.get_handle()
return handle
- def run( # type: ignore
- self,
- query: Optional[str] = None,
- file_paths: Optional[List[str]] = None,
- labels: Optional[MultiLabel] = None,
- documents: Optional[List[Document]] = None,
- meta: Optional[dict] = None,
- params: Optional[dict] = None,
- ):
- has_next_node = True
- root_node = self.root_node
- if not root_node:
- raise PipelineError("Cannot run a pipeline with no nodes.")
- current_node_id: str = root_node
- input_dict: Dict[str, Any] = {"root_node": root_node, "params": params}
- if query:
- input_dict["query"] = query
- if file_paths:
- input_dict["file_paths"] = file_paths
- if labels:
- input_dict["labels"] = labels
- if documents:
- input_dict["documents"] = documents
- if meta:
- input_dict["meta"] = meta
- output_dict = None
- while has_next_node:
- output_dict, stream_id = ray.get(self.graph.nodes[current_node_id]["component"].remote(**input_dict))
- input_dict = output_dict
- next_nodes = self.get_next_nodes(current_node_id, stream_id)
- if len(next_nodes) > 1:
- join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
- if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
- raise NotImplementedError(
- "The current pipeline does not support multiple levels of parallel nodes."
- )
- inputs_for_join_node: dict = {"inputs": []}
- for n_id in next_nodes:
- output = self.graph.nodes[n_id]["component"].run(**input_dict)
- inputs_for_join_node["inputs"].append(output)
- input_dict = inputs_for_join_node
- current_node_id = join_node_id
- elif len(next_nodes) == 1:
- current_node_id = next_nodes[0]
- else:
- has_next_node = False
- return output_dict
def add_node(self, component, name: str, inputs: List[str]):
raise NotImplementedError(
"The current implementation of RayPipeline only supports loading Pipelines from a YAML file."
@@ -318,6 +260,12 @@ class RayPipeline(Pipeline):
input_edge_name = "output_1"
self.graph.add_edge(input_node_name, name, label=input_edge_name)
+ def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]:
+ return ray.get(self.graph.nodes[node_id]["component"].remote(**node_input))
+ def _get_run_node_signature(self, node_id: str):
+ return inspect.signature(self.graph.nodes[node_id]["component"].remote).parameters.keys()
class _RayDeploymentWrapper:
"""

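`RayPipeline._run_node` has to wrap the call in `ray.get` because invoking `.remote(...)` on a Ray handle returns an `ObjectRef` future immediately rather than the result, and the traversal loop needs the concrete output before it can route to the next node. A standalone illustration with a plain Ray task (the function here is a stand-in, not a Haystack component):

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def dispatch_run(query: str):
    # Stand-in for a component's _dispatch_run: (output dict, stream id).
    return {"query": query, "answers": []}, "output_1"


ref = dispatch_run.remote("Who lives in Berlin?")  # schedules work, returns a future
node_output, stream_id = ray.get(ref)  # blocks until the result is ready
print(stream_id)  # output_1
```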

@@ -1023,8 +1023,9 @@ def test_save_yaml_overwrite(tmp_path):
assert content != ""
- def test_load_yaml_ray_args_in_pipeline(tmp_path):
+ @pytest.mark.parametrize("pipeline_file", ["ray.simple.haystack-pipeline.yml", "ray.advanced.haystack-pipeline.yml"])
+ def test_load_yaml_ray_args_in_pipeline(tmp_path, pipeline_file):
with pytest.raises(PipelineConfigError) as e:
pipeline = Pipeline.load_from_yaml(
- SAMPLES_PATH / "pipeline" / "ray.haystack-pipeline.yml", pipeline_name="ray_query_pipeline"
+ SAMPLES_PATH / "pipeline" / pipeline_file, pipeline_name="ray_query_pipeline"
)

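With the second sample file added, the YAML test above is parametrized instead of duplicated: pytest generates one independent test case per listed file, so a regression in the advanced pipeline config cannot hide behind a passing simple one. The mechanism in isolation (hypothetical file names):

```python
import pytest


@pytest.mark.parametrize("pipeline_file", ["simple.yml", "advanced.yml"])
def test_config_filename(pipeline_file):
    # Runs twice, once per parameter value, each reported separately.
    assert pipeline_file.endswith(".yml")
```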

@@ -25,12 +25,14 @@ def shutdown_ray():
@pytest.mark.parametrize("serve_detached", [True, False])
def test_load_pipeline(document_store_with_docs, serve_detached):
pipeline = RayPipeline.load_from_yaml(
- SAMPLES_PATH / "pipeline" / "ray.haystack-pipeline.yml",
+ SAMPLES_PATH / "pipeline" / "ray.simple.haystack-pipeline.yml",
pipeline_name="ray_query_pipeline",
ray_args={"num_cpus": 8},
serve_args={"detached": serve_detached},
)
- prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}})
+ prediction = pipeline.run(
+ query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}
+ )
assert pipeline._serve_controller_client._detached == serve_detached
assert ray.serve.get_deployment(name="ESRetriever").num_replicas == 2
@@ -39,3 +41,30 @@ def test_load_pipeline(document_store_with_docs, serve_detached):
assert ray.serve.get_deployment(name="ESRetriever").ray_actor_options["num_cpus"] == 0.5
assert prediction["query"] == "Who lives in Berlin?"
assert prediction["answers"][0].answer == "Carla"
+ @pytest.mark.integration
+ @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+ def test_load_advanced_pipeline(document_store_with_docs):
+ pipeline = RayPipeline.load_from_yaml(
+ SAMPLES_PATH / "pipeline" / "ray.advanced.haystack-pipeline.yml",
+ pipeline_name="ray_query_pipeline",
+ ray_args={"num_cpus": 8},
+ serve_args={"detached": True},
+ )
+ prediction = pipeline.run(
+ query="Who lives in Berlin?",
+ params={"ESRetriever1": {"top_k": 1}, "ESRetriever2": {"top_k": 2}, "Reader": {"top_k": 3}},
+ )
+ assert pipeline._serve_controller_client._detached is True
+ assert ray.serve.get_deployment(name="ESRetriever1").num_replicas == 2
+ assert ray.serve.get_deployment(name="ESRetriever2").num_replicas == 2
+ assert ray.serve.get_deployment(name="Reader").num_replicas == 1
+ assert ray.serve.get_deployment(name="ESRetriever1").max_concurrent_queries == 17
+ assert ray.serve.get_deployment(name="ESRetriever2").max_concurrent_queries == 15
+ assert ray.serve.get_deployment(name="ESRetriever1").ray_actor_options["num_cpus"] == 0.25
+ assert ray.serve.get_deployment(name="ESRetriever2").ray_actor_options["num_cpus"] == 0.25
+ assert prediction["query"] == "Who lives in Berlin?"
+ assert prediction["answers"][0].answer == "Carla"
+ assert len(prediction["answers"]) > 1

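The assertions in the new advanced test rely on Ray Serve's introspection API: `serve.get_deployment(name)` returns the deployment object for a running deployment, carrying the options it was deployed with. The pattern in isolation, assuming a deployment named "ESRetriever1" is already serving (a sketch against the Ray Serve API of the time):

```python
from ray import serve

# Look up a running deployment by name and read back its options;
# the name matches the test above and must already be deployed.
deployment = serve.get_deployment(name="ESRetriever1")
print(deployment.num_replicas)                   # 2
print(deployment.max_concurrent_queries)         # 17
print(deployment.ray_actor_options["num_cpus"])  # 0.25
```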

@@ -0,0 +1,73 @@
version: ignore
extras: ray
components:
- name: DocumentStore
type: ElasticsearchDocumentStore
params:
index: haystack_test
label_index: haystack_test_label
- name: ESRetriever1
type: BM25Retriever
params:
document_store: DocumentStore
- name: ESRetriever2
# type: TfidfRetriever # can't use TfidfRetriever until https://github.com/deepset-ai/haystack/pull/2984 is merged
type: BM25Retriever
params:
document_store: DocumentStore
- name: Reader
type: FARMReader
params:
no_ans_boost: -10
model_name_or_path: deepset/roberta-base-squad2
num_processes: 0
- name: PDFConverter
type: PDFToTextConverter
params:
remove_numeric_tables: false
- name: Preprocessor
type: PreProcessor
params:
clean_whitespace: true
- name: IndexTimeDocumentClassifier
type: TransformersDocumentClassifier
params:
batch_size: 16
use_gpu: false
- name: QueryTimeDocumentClassifier
type: TransformersDocumentClassifier
params:
use_gpu: false
- name: JoinDocuments
params: {}
type: JoinDocuments
pipelines:
- name: ray_query_pipeline
nodes:
- name: ESRetriever1
inputs: [ Query ]
serve_deployment_kwargs:
num_replicas: 2
version: Twenty
ray_actor_options:
# num_gpus: 0.25 # we have no GPU to test this
num_cpus: 0.25
max_concurrent_queries: 17
- name: ESRetriever2
inputs: [ Query ]
serve_deployment_kwargs:
num_replicas: 2
version: Twenty
ray_actor_options:
# num_gpus: 0.25 # we have no GPU to test this
num_cpus: 0.25
max_concurrent_queries: 15
- name: JoinDocuments
inputs:
- ESRetriever1
- ESRetriever2
- name: Reader
inputs: [ JoinDocuments ]
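Each `serve_deployment_kwargs` block above is passed through to Ray Serve when the node is deployed, which is why the tests above can read the same values back with `serve.get_deployment`. Roughly how the equivalent options look when declared directly against Ray Serve (a sketch of the era's API, not Haystack's loader):

```python
from ray import serve


@serve.deployment(
    name="ESRetriever2",
    version="Twenty",                      # rolling-update version tag
    num_replicas=2,                        # two replicas of this node
    max_concurrent_queries=15,             # per-replica in-flight request cap
    ray_actor_options={"num_cpus": 0.25},  # fractional CPU per replica
)
class ESRetriever2Deployment:
    def __call__(self, request):
        ...  # would run the retriever on each request
```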