Add Ray integration for Pipelines (#1255)

oryx1729 2021-08-02 14:51:24 +02:00 committed by GitHub
parent 3eaf9dfbca
commit bafa1b46de
8 changed files with 453 additions and 71 deletions

View File

@@ -85,8 +85,11 @@ jobs:
- name: Run Apache Tika
run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m -JXmx128m" apache/tika:1.24.1
# - name: Run Ray
# run: RAY_DISABLE_MEMORY_MONITOR=1 ray start --head
- name: Install pdftotext
run: wget --no-check-certificate https://dl.xpdfreader.com/xpdf-tools-linux-4.03.tar.gz && tar -xvf xpdf-tools-linux-4.03.tar.gz && sudo cp xpdf-tools-linux-4.03/bin64/pdftotext /usr/local/bin
- name: Run tests
run: cd test && pytest ${{ matrix.test-path }}
run: cd test && pytest -s ${{ matrix.test-path }}

View File

@@ -1,3 +1,4 @@
import copy
import inspect
import logging
import os
@@ -10,6 +11,14 @@ import pickle
import urllib
from functools import wraps
try:
from ray import serve
import ray
except ImportError:  # Ray is an optional dependency
ray = None
serve = None
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import networkx as nx
@@ -30,7 +39,137 @@ from haystack.graph_retriever.base import BaseGraphRetriever
logger = logging.getLogger(__name__)
class Pipeline:
class BasePipeline:
def run(self, **kwargs):
raise NotImplementedError
@classmethod
def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True):
"""
Load Pipeline from a YAML file defining the individual components and how they're tied together to form
a Pipeline. A single YAML can declare multiple Pipelines, in which case an explicit `pipeline_name` must
be passed.
Here's a sample configuration:
```yaml
| version: '0.8'
|
| components: # define all the building-blocks for Pipeline
| - name: MyReader # custom-name for the component; helpful for visualization & debugging
| type: FARMReader # Haystack Class name for the component
| params:
| no_ans_boost: -10
| model_name_or_path: deepset/roberta-base-squad2
| - name: MyESRetriever
| type: ElasticsearchRetriever
| params:
| document_store: MyDocumentStore # params can reference other components defined in the YAML
| custom_query: null
| - name: MyDocumentStore
| type: ElasticsearchDocumentStore
| params:
| index: haystack_test
|
| pipelines: # multiple Pipelines can be defined using the components from above
| - name: my_query_pipeline # a simple extractive-qa Pipeline
| nodes:
| - name: MyESRetriever
| inputs: [Query]
| - name: MyReader
| inputs: [MyESRetriever]
```
:param path: path of the YAML file.
:param pipeline_name: if the YAML contains multiple pipelines, the pipeline_name to load must be set.
:param overwrite_with_env_variables: Overwrite the YAML configuration with environment variables. For example,
to change index name param for an ElasticsearchDocumentStore, an env
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
"""
pipeline_config = cls._get_pipeline_config_from_yaml(path=path, pipeline_name=pipeline_name)
if pipeline_config["type"] == "Pipeline":
return Pipeline.load_from_yaml(
path=path, pipeline_name=pipeline_name, overwrite_with_env_variables=overwrite_with_env_variables
)
elif pipeline_config["type"] == "RayPipeline":
return RayPipeline.load_from_yaml(
path=path, pipeline_name=pipeline_name, overwrite_with_env_variables=overwrite_with_env_variables
)
else:
raise KeyError(f"Pipeline Type '{pipeline_config['type']}' is not a valid. The available types are"
f"'Pipeline' and 'RayPipeline'.")
@classmethod
def _get_pipeline_config_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None):
"""
Get the definition of a Pipeline from a given YAML. If the YAML contains more than one Pipeline,
then the pipeline_name must be supplied.
:param path: Path of Pipeline YAML file.
:param pipeline_name: name of the Pipeline.
"""
with open(path, "r", encoding='utf-8') as stream:
data = yaml.safe_load(stream)
if pipeline_name is None:
if len(data["pipelines"]) == 1:
pipeline_config = data["pipelines"][0]
else:
raise Exception("The YAML contains multiple pipelines. Please specify the pipeline name to load.")
else:
pipelines_in_yaml = list(filter(lambda p: p["name"] == pipeline_name, data["pipelines"]))
if not pipelines_in_yaml:
raise KeyError(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
pipeline_config = pipelines_in_yaml[0]
return pipeline_config
@classmethod
def _read_yaml(cls, path: Path, pipeline_name: Optional[str], overwrite_with_env_variables: bool):
"""
Parse the YAML and return the full YAML config, pipeline_config, and definitions of all components.
:param path: path of the YAML file.
:param pipeline_name: if the YAML contains multiple pipelines, the pipeline_name to load must be set.
:param overwrite_with_env_variables: Overwrite the YAML configuration with environment variables. For example,
to change index name param for an ElasticsearchDocumentStore, an env
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
"""
with open(path, "r", encoding="utf-8") as stream:
data = yaml.safe_load(stream)
pipeline_config = cls._get_pipeline_config_from_yaml(path=path, pipeline_name=pipeline_name)
definitions = {} # definitions of each component from the YAML.
component_definitions = copy.deepcopy(data["components"])
for definition in component_definitions:
if overwrite_with_env_variables:
cls._overwrite_with_env_variables(definition)
name = definition.pop("name")
definitions[name] = definition
return data, pipeline_config, definitions
@classmethod
def _overwrite_with_env_variables(cls, definition: dict):
"""
Overwrite the YAML configuration with environment variables. For example, to change index name param for an
ElasticsearchDocumentStore, an env variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
:param definition: a dictionary containing the YAML definition of a component.
"""
env_prefix = f"{definition['name']}_params_".upper()
for key, value in os.environ.items():
if key.startswith(env_prefix):
param_name = key.replace(env_prefix, "").lower()
definition["params"][param_name] = value
class Pipeline(BasePipeline):
"""
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@@ -39,18 +178,9 @@ class Pipeline:
Reader from multiple Retrievers, or re-ranking of candidate documents.
"""
def __init__(self, pipeline_type: str = "Query"):
def __init__(self):
self.graph = DiGraph()
if pipeline_type == "Query":
self.root_node_id = "Query"
self.graph.add_node("Query", component=RootNode())
elif pipeline_type == "Indexing":
self.root_node_id = "File"
self.graph.add_node("File", component=RootNode())
else:
raise Exception(f"pipeline_type '{pipeline_type}' is not valid. Supported types are 'Query' & 'Indexing'.")
self.pipeline_type = pipeline_type
self.root_node = None
self.components: dict = {}
def add_node(self, component, name: str, inputs: List[str]):
@@ -68,13 +198,20 @@ class Pipeline:
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
must be specified explicitly as "QueryClassifier.output_2".
"""
if self.root_node is None:
root_node = inputs[0]
if root_node in ["Query", "File"]:
self.root_node = root_node
self.graph.add_node(root_node, component=RootNode())
else:
raise KeyError(f"Root node '{root_node}' is invalid. Available options are 'Query' and 'File'.")
self.graph.add_node(name, component=component, inputs=inputs)
if len(self.graph.nodes) == 2: # first node added; connect with Root
assert len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node_id, \
f"The '{name}' node can only input from {self.root_node_id}. " \
f"Set the 'inputs' parameter to ['{self.root_node_id}']"
self.graph.add_edge(self.root_node_id, name, label="output_1")
assert len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node, \
f"The '{name}' node can only input from {self.root_node}. " \
f"Set the 'inputs' parameter to ['{self.root_node}']"
self.graph.add_edge(self.root_node, name, label="output_1")
return
for i in inputs:
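The hunk above replaces the fixed `pipeline_type` root with one inferred from the first node's `inputs`. A minimal sketch, assuming `retriever` and `reader` instances already exist:

```python
from haystack.pipeline import Pipeline

pipeline = Pipeline()  # no pipeline_type argument anymore
# The first add_node() call sets root_node to "Query" because of inputs=["Query"];
# inputs=["File"] would make this an indexing pipeline instead.
pipeline.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"])
```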
@@ -118,7 +255,7 @@ class Pipeline:
def run(self, **kwargs):
node_output = None
queue = {
self.root_node_id: {"pipeline_type": self.pipeline_type, **kwargs}
self.root_node: {"root_node": self.root_node, **kwargs}
} # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
i = 0 # the first item is popped off the queue unless it is a "join" node with unprocessed predecessors
while queue:
@@ -221,28 +358,11 @@ class Pipeline:
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
"""
with open(path, "r", encoding='utf-8') as stream:
data = yaml.safe_load(stream)
data, pipeline_config, definitions = cls._read_yaml(
path=path, pipeline_name=pipeline_name, overwrite_with_env_variables=overwrite_with_env_variables
)
if pipeline_name is None:
if len(data["pipelines"]) == 1:
pipeline_config = data["pipelines"][0]
else:
raise Exception("The YAML contains multiple pipelines. Please specify the pipeline name to load.")
else:
pipelines_in_yaml = list(filter(lambda p: p["name"] == pipeline_name, data["pipelines"]))
if not pipelines_in_yaml:
raise KeyError(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
pipeline_config = pipelines_in_yaml[0]
definitions = {} # definitions of each component from the YAML.
for definition in data["components"]:
if overwrite_with_env_variables:
cls._overwrite_with_env_variables(definition)
name = definition.pop("name")
definitions[name] = definition
pipeline = cls(pipeline_type=pipeline_config["type"])
pipeline = cls()
components: dict = {} # instances of component objects.
for node_config in pipeline_config["nodes"]:
@@ -283,21 +403,6 @@ class Pipeline:
raise Exception(f"Failed loading pipeline component '{name}': {e}")
return instance
@classmethod
def _overwrite_with_env_variables(cls, definition: dict):
"""
Overwrite the YAML configuration with environment variables. For example, to change index name param for an
ElasticsearchDocumentStore, an env variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
:param definition: a dictionary containing the YAML definition of a component.
"""
env_prefix = f"{definition['name']}_params_".upper()
for key, value in os.environ.items():
if key.startswith(env_prefix):
param_name = key.replace(env_prefix, "").lower()
definition["params"][param_name] = value
def save_to_yaml(self, path: Path, return_defaults: bool = False):
"""
Save a YAML configuration for the Pipeline that can be used with `Pipeline.load_from_yaml()`.
@@ -307,13 +412,12 @@ class Pipeline:
"""
nodes = self.graph.nodes
pipeline_name = self.pipeline_type.lower()
pipeline_type = self.pipeline_type
pipelines: dict = {pipeline_name: {"name": pipeline_name, "type": pipeline_type, "nodes": []}}
pipeline_name = self.root_node.lower()
pipelines: dict = {pipeline_name: {"name": pipeline_name, "type": "Pipeline", "nodes": []}}
components = {}
for node in nodes:
if node == self.root_node_id:
if node == self.root_node:
continue
component_instance = self.graph.nodes.get(node)["component"]
component_type = component_instance.pipeline_config["type"]
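With the `pipeline_type` attribute gone, the saved YAML derives the pipeline name from the root node and always writes `type: Pipeline`. A round-trip sketch (path hypothetical, continuing the example above):

```python
from pathlib import Path

# A pipeline rooted at "Query" is saved under the name "query" with type "Pipeline".
pipeline.save_to_yaml(Path("exported_pipeline.yaml"))
restored = Pipeline.load_from_yaml(Path("exported_pipeline.yaml"), pipeline_name="query")
```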
@@ -656,7 +760,10 @@ class QuestionAnswerGenerationPipeline(BaseStandardPipeline):
return output
class RootNode:
class RootNode(BaseComponent):
"""
RootNode feeds inputs (`query` or `file`) together with the corresponding parameters to a Pipeline.
"""
outgoing_edges = 1
def run(self, **kwargs):
@@ -897,3 +1004,231 @@ class JoinDocuments(BaseComponent):
documents = documents[: self.top_k_join]
output = {"query": inputs[0]["query"], "documents": documents, "labels": inputs[0].get("labels", None)}
return output, "output_1"
class RayPipeline(Pipeline):
"""
Ray (https://ray.io) is a framework for distributed computing.
With Ray, the Pipeline nodes can be distributed across a cluster of machines.
This allows scaling individual nodes. For instance, an extractive QA Pipeline can run multiple replicas
of the Reader while keeping a single instance of the Retriever. It also enables efficient resource
utilization, as the load can be split across GPU and CPU machines.
In the current implementation, a Ray Pipeline can only be created with a YAML Pipeline config.
>>> from haystack.pipeline import RayPipeline
>>> pipeline = RayPipeline.load_from_yaml(path="my_pipelines.yaml", pipeline_name="my_query_pipeline")
>>> pipeline.run(query="What is the capital of Germany?")
By default, RayPipeline starts a local Ray Serve instance. To connect to an existing Ray cluster,
set the `address` parameter when creating the RayPipeline instance.
"""
def __init__(self, address: Optional[str] = None, **kwargs):
"""
:param address: The IP address for the Ray cluster. If set to None, a local Ray instance is started.
:param kwargs: Optional parameters for initializing Ray.
"""
ray.init(address=address, **kwargs)
serve.start()
super().__init__()
@classmethod
def load_from_yaml(
cls,
path: Path, pipeline_name: Optional[str] = None,
overwrite_with_env_variables: bool = True,
address: Optional[str] = None,
**kwargs,
):
"""
Load Pipeline from a YAML file defining the individual components and how they're tied together to form
a Pipeline. A single YAML can declare multiple Pipelines, in which case an explicit `pipeline_name` must
be passed.
Here's a sample configuration:
```yaml
| version: '0.8'
|
| components: # define all the building-blocks for Pipeline
| - name: MyReader # custom-name for the component; helpful for visualization & debugging
| type: FARMReader # Haystack Class name for the component
| params:
| no_ans_boost: -10
| model_name_or_path: deepset/roberta-base-squad2
| - name: MyESRetriever
| type: ElasticsearchRetriever
| params:
| document_store: MyDocumentStore # params can reference other components defined in the YAML
| custom_query: null
| - name: MyDocumentStore
| type: ElasticsearchDocumentStore
| params:
| index: haystack_test
|
| pipelines: # multiple Pipelines can be defined using the components from above
| - name: my_query_pipeline # a simple extractive-qa Pipeline
| nodes:
| - name: MyESRetriever
| inputs: [Query]
| - name: MyReader
| inputs: [MyESRetriever]
```
:param path: path of the YAML file.
:param pipeline_name: if the YAML contains multiple pipelines, the pipeline_name to load must be set.
:param overwrite_with_env_variables: Overwrite the YAML configuration with environment variables. For example,
to change index name param for an ElasticsearchDocumentStore, an env
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
:param address: The IP address for the Ray cluster. If set to None, a local Ray instance is started.
"""
data, pipeline_config, definitions = cls._read_yaml(
path=path, pipeline_name=pipeline_name, overwrite_with_env_variables=overwrite_with_env_variables
)
pipeline = cls(address=address, **kwargs)
for node_config in pipeline_config["nodes"]:
if pipeline.root_node is None:
root_node = node_config["inputs"][0]
if root_node in ["Query", "File"]:
pipeline.root_node = root_node
handle = cls._create_ray_deployment(component_name=root_node, pipeline_config=data)
pipeline._add_ray_deployment_in_graph(handle=handle, name=root_node, outgoing_edges=1, inputs=[])
else:
raise KeyError(f"Root node '{root_node}' is invalid. Available options are 'Query' and 'File'.")
name = node_config["name"]
component_type = definitions[name]["type"]
component_class = BaseComponent.get_subclass(component_type)
replicas = next(comp for comp in data["components"] if comp["name"] == name).get("replicas", 1)
handle = cls._create_ray_deployment(component_name=name, pipeline_config=data, replicas=replicas)
pipeline._add_ray_deployment_in_graph(
handle=handle,
name=name,
outgoing_edges=component_class.outgoing_edges,
inputs=node_config.get("inputs", []),
)
return pipeline
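The replica count for each Ray Serve deployment is read from the component's entry in the YAML `components` section, defaulting to 1. A hedged sketch of such an entry (names are hypothetical):

```yaml
components:
  - name: MyReader
    type: FARMReader
    replicas: 2        # Ray Serve starts two replicas of this node
    params:
      model_name_or_path: deepset/roberta-base-squad2
```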
@classmethod
def _create_ray_deployment(cls, component_name: str, pipeline_config: dict, replicas: int = 1):
"""
Create a Ray Deployment for the Component.
:param component_name: Name of the component, as defined in the Pipeline YAML.
:param pipeline_config: The Pipeline config YAML parsed as a dict.
:param replicas: By default, a single replica of the component is created. This can be
configured by setting the `replicas` parameter in the Pipeline YAML.
"""
RayDeployment = serve.deployment(_RayDeploymentWrapper, name=component_name, num_replicas=replicas)
RayDeployment.deploy(pipeline_config, component_name)
handle = RayDeployment.get_handle()
return handle
def run(self, **kwargs):
has_next_node = True
current_node_id = self.root_node
input_dict = {"root_node": self.root_node, **kwargs}
output_dict = None
while has_next_node:
output_dict, stream_id = ray.get(self.graph.nodes[current_node_id]["component"].remote(**input_dict))
input_dict = output_dict
next_nodes = self.get_next_nodes(current_node_id, stream_id)
if len(next_nodes) > 1:
join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
raise NotImplementedError(
"The current pipeline does not support multiple levels of parallel nodes."
)
inputs_for_join_node = {"inputs": []}
for n_id in next_nodes:
# Deployment handles expose .remote() rather than .run(); collect each branch's output for the join node.
output, _ = ray.get(self.graph.nodes[n_id]["component"].remote(**input_dict))
inputs_for_join_node["inputs"].append(output)
input_dict = inputs_for_join_node
current_node_id = join_node_id
elif len(next_nodes) == 1:
current_node_id = next_nodes[0]
else:
has_next_node = False
return output_dict
def add_node(self, component, name: str, inputs: List[str]):
raise NotImplementedError(
"The current implementation of RayPipeline only supports loading Pipelines from a YAML file."
)
def _add_ray_deployment_in_graph(self, handle, name: str, outgoing_edges: int, inputs: List[str]):
"""
Add the Ray deployment handle in the Pipeline Graph.
:param handle: Ray deployment `handle` to add to the Pipeline Graph. The handle allows calling a Ray deployment
from Python: https://docs.ray.io/en/master/serve/package-ref.html#servehandle-api.
:param name: The name for the node. It must not contain any dots.
:param inputs: A list of inputs to the node. If the predecessor node has a single outgoing edge, the name
of the node is sufficient. For instance, an 'ElasticsearchRetriever' node always outputs a single
edge with a list of documents. It can be represented as ["ElasticsearchRetriever"].
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
must be specified explicitly as "QueryClassifier.output_2".
"""
self.graph.add_node(name, component=handle, inputs=inputs, outgoing_edges=outgoing_edges)
if len(self.graph.nodes) == 2: # first node added; connect with Root
self.graph.add_edge(self.root_node, name, label="output_1")
return
for i in inputs:
if "." in i:
[input_node_name, input_edge_name] = i.split(".")
assert "output_" in input_edge_name, f"'{input_edge_name}' is not a valid edge name."
outgoing_edges_input_node = self.graph.nodes[input_node_name]["component"].outgoing_edges
assert int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node, (
f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
f"{outgoing_edges_input_node} outgoing edge(s)."
)
else:
outgoing_edges_input_node = self.graph.nodes[i]["outgoing_edges"]
assert outgoing_edges_input_node == 1, (
f"Adding an edge from {i} to {name} is ambiguous as {i} has {outgoing_edges_input_node} edges. "
f"Please specify the output explicitly."
)
input_node_name = i
input_edge_name = "output_1"
self.graph.add_edge(input_node_name, name, label=input_edge_name)
class _RayDeploymentWrapper:
"""
Ray Serve calls the __init__ method of a class to create "deployment" instances.
In Haystack, some Components, like Retrievers, have complex init methods that need objects
such as Document Stores.
This wrapper class encapsulates the initialization of Components. Given a Component Class
name, it creates an instance using the YAML Pipeline config.
"""
node: BaseComponent
def __init__(self, pipeline_config: dict, component_name: str):
"""
Create an instance of Component.
:param pipeline_config: Pipeline YAML parsed as a dict.
:param component_name: Component Class name.
"""
if component_name in ["Query", "File"]:
self.node = RootNode()
else:
self.node = BaseComponent.load_from_pipeline_config(pipeline_config, component_name)
def __call__(self, *args, **kwargs):
"""
Ray Serve routes each request to this method, which delegates to the corresponding component's run().
"""
return self.node.run(*args, **kwargs)
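A sketch of the wrapper's lifecycle, with a hypothetical `pipeline_config` dict (the parsed Pipeline YAML): Serve constructs one wrapper per replica, then routes every request through `__call__`.

```python
# Hypothetical: pipeline_config is the parsed Pipeline YAML as a dict,
# containing a component named "MyESRetriever".
wrapper = _RayDeploymentWrapper(pipeline_config, "MyESRetriever")

# __call__ delegates to the component's run(), returning its (output, stream_id) tuple.
output, stream_id = wrapper(root_node="Query", query="What is the capital of Germany?")
```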

View File

@@ -174,17 +174,17 @@ class BaseRetriever(BaseComponent):
else:
return metrics
def run(self, pipeline_type: str, **kwargs): # type: ignore
if pipeline_type == "Query":
def run(self, root_node: str, **kwargs): # type: ignore
if root_node == "Query":
self.query_count += 1
run_query_timed = self.timing(self.run_query, "query_time")
output, stream = run_query_timed(**kwargs)
elif pipeline_type == "Indexing":
elif root_node == "File":
self.index_count += len(kwargs["documents"])
run_indexing = self.timing(self.run_indexing, "index_time")
output, stream = run_indexing(**kwargs)
else:
raise Exception(f"Invalid pipeline_type '{pipeline_type}'.")
raise Exception(f"Invalid root_node '{root_node}'.")
return output, stream
def run_query(
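A minimal sketch, assuming a `retriever` instance and a hypothetical `docs` list: the node now branches on the root node's name instead of a separate `pipeline_type` string, so the same component serves both query and indexing pipelines.

```python
# Query pipeline path: increments query_count and times run_query().
output, stream = retriever.run(root_node="Query", query="Who lives in Berlin?")

# Indexing pipeline path: increments index_count and times run_indexing().
output, stream = retriever.run(root_node="File", documents=docs)
```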

View File

@@ -268,6 +268,13 @@ class BaseComponent:
super().__init_subclass__(**kwargs)
cls.subclasses[cls.__name__] = cls
@classmethod
def get_subclass(cls, component_type: str):
if component_type not in cls.subclasses.keys():
raise Exception(f"Haystack component with the name '{component_type}' does not exist.")
subclass = cls.subclasses[component_type]
return subclass
@classmethod
def load_from_args(cls, component_type: str, **kwargs):
"""
@@ -276,11 +283,33 @@ class BaseComponent:
:param component_type: name of the component class to load.
:param kwargs: parameters to pass to the __init__() for the component.
"""
if component_type not in cls.subclasses.keys():
raise Exception(f"Haystack component with the name '{component_type}' does not exist.")
instance = cls.subclasses[component_type](**kwargs)
subclass = cls.get_subclass(component_type)
instance = subclass(**kwargs)
return instance
@classmethod
def load_from_pipeline_config(cls, pipeline_config: dict, component_name: str):
"""
Load an individual component from a YAML config for Pipelines.
:param pipeline_config: the Pipelines YAML config parsed as a dict.
:param component_name: the name of the component to load.
"""
if pipeline_config:
all_component_configs = pipeline_config["components"]
all_component_names = [comp["name"] for comp in all_component_configs]
component_config = next(comp for comp in all_component_configs if comp["name"] == component_name)
component_params = component_config["params"]
for key, value in component_params.items():
if value in all_component_names: # check if the param value is a reference to another component
component_params[key] = cls.load_from_pipeline_config(pipeline_config, value)
component_instance = cls.load_from_args(component_config["type"], **component_params)
else:
component_instance = cls.load_from_args(component_name)
return component_instance
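A sketch of the reference resolution with a hypothetical two-component config (import path assumed): the `document_store` param value matches another component's name, so that component is instantiated recursively and passed in.

```python
from haystack.schema import BaseComponent  # import path assumed

pipeline_config = {
    "components": [
        {"name": "MyDocumentStore", "type": "ElasticsearchDocumentStore",
         "params": {"index": "haystack_test"}},
        {"name": "MyESRetriever", "type": "ElasticsearchRetriever",
         "params": {"document_store": "MyDocumentStore"}},
    ]
}
# "MyDocumentStore" is resolved into a live document store instance before
# the retriever is constructed.
retriever = BaseComponent.load_from_pipeline_config(pipeline_config, "MyESRetriever")
```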
@abstractmethod
def run(self, *args: Any, **kwargs: Any):
"""

View File

@@ -32,3 +32,4 @@ pymilvus
SPARQLWrapper
mmh3
weaviate-client==2.5.0
ray==1.5.0

View File

@@ -14,7 +14,7 @@ components:
- name: DocumentStore
type: ElasticsearchDocumentStore
params:
index: haystack_test_document
index: haystack_test
label_index: haystack_test_label
- name: PDFConverter
type: PDFToTextConverter

View File

@@ -18,8 +18,8 @@ from haystack.retriever.dense import DensePassageRetriever
from haystack.retriever.sparse import ElasticsearchRetriever
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_load_and_save_yaml(document_store_with_docs, tmp_path):
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_load_and_save_yaml(document_store, tmp_path):
# test correct load of indexing pipeline from yaml
pipeline = Pipeline.load_from_yaml(
Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="indexing_pipeline"
@@ -58,7 +58,7 @@ def test_load_and_save_yaml(document_store_with_docs, tmp_path):
type: ElasticsearchRetriever
- name: ElasticsearchDocumentStore
params:
index: haystack_test_document
index: haystack_test
label_index: haystack_test_label
type: ElasticsearchDocumentStore
- name: Reader
@@ -75,7 +75,7 @@ def test_load_and_save_yaml(document_store_with_docs, tmp_path):
- inputs:
- ESRetriever
name: Reader
type: Query
type: Pipeline
version: '0.8'
"""
assert saved_yaml.replace(" ", "").replace("\n", "") == expected_yaml.replace(

test/test_ray.py Normal file
View File

@@ -0,0 +1,14 @@
from pathlib import Path
import pytest
from haystack.pipeline import RayPipeline
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_load_pipeline(document_store_with_docs):
pipeline = RayPipeline.load_from_yaml(
Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="query_pipeline", num_cpus=8,
)
prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
assert prediction["query"] == "Who lives in Berlin?"
assert prediction["answers"][0]["answer"] == "Carla"