Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-02 02:39:51 +00:00)
Pipeline's YAML: syntax validation (#2226)
* Add BasePipeline.validate_config, BasePipeline.validate_yaml, and some new custom exception classes
* Make error composition work properly
* Clarify typing
* Help mypy a bit more
* Update Documentation & Code Style
* Enable autogenerated docs for Milvus1 and 2 separately
* Revert "Enable autogenerated docs for Milvus1 and 2 separately". This reverts commit 282be4a78a6e95862a9b4c924fc3dea5ca71e28d.
* Update Documentation & Code Style
* Re-enable 'additionalProperties: False'
* Add pipeline.type to JSON Schema, was somehow forgotten
* Disable additionalProperties on the pipeline properties too
* Fix json-schemas for 1.1.0 and 1.2.0 (should not do it again in the future)
* Call super in PipelineValidationError
* Improve _read_pipeline_config_from_yaml's error handling
* Fix generate_json_schema.py to include document stores
* Fix json schemas (retro-fix 1.1.0 again)
* Improve custom errors printing, add link to docs
* Add function in BaseComponent to list its subclasses in a module
* Make some document stores base classes abstract
* Add marker 'integration' in pytest flags
* Slightly improve validation of pipelines at load
* Adding tests for YAML loading and validation
* Make custom_query Optional for validation issues
* Fix bug in _read_pipeline_config_from_yaml
* Improve error handling in BasePipeline and Pipeline and add DAG check
* Move json schema generation into haystack/nodes/_json_schema.py (useful for tests)
* Simplify errors slightly
* Add some YAML validation tests
* Remove load_from_config from BasePipeline, it was never used anyway
* Improve tests
* Include json-schemas in package
* Fix conftest imports
* Make BasePipeline abstract
* Improve mocking by making the test independent from the YAML version
* Add exportable_to_yaml decorator to forget about set_config on mock nodes
* Fix mypy errors
* Comment out one monkeypatch
* Fix typing again
* Improve error message for validation
* Add required properties to pipelines
* Fix YAML version for REST API YAMLs to 1.2.0
* Fix load_from_yaml call in load_from_deepset_cloud
* fix HaystackError.__getattr__
* Add super().__init__() in most nodes and docstore, comment set_config
* Remove type from REST API pipelines
* Remove useless init from doc2answers
* Call super in Seq3SeqGenerator
* Typo in deepsetcloud.py
* Fix rest api indexing error mismatch and mock version of JSON schema in all tests
* Working on pipeline tests
* Improve errors printing slightly
* Add back test_pipeline.yaml
* _json_schema.py supports different versions with identical schemas
* Add type to 0.7 schema for backwards compatibility
* Fix small bug in _json_schema.py
* Try alternative to generate json schemas on the CI
* Update Documentation & Code Style
* Make linux CI match autoformat CI
* Fix super-init-not-called
* Accidentally committed file
* Update Documentation & Code Style
* fix test_summarizer_translation.py's import
* Mock YAML in a few suites, split and simplify test_pipeline_debug_and_validation.py::test_invalid_run_args
* Fix json schema for ray tests too
* Update Documentation & Code Style
* Reintroduce validation
* Use unstable version in tests and rest api
* Make unstable support the latest versions
* Update Documentation & Code Style
* Remove needless fixture
* Make type in pipeline optional in the strings validation
* Fix schemas
* Fix string validation for pipeline type
* Improve validate_config_strings
* Remove type from test pipelines
* Update Documentation & Code Style
* Fix test_pipeline
* Removing more type from pipelines
* Temporary CI patch
* Fix issue with exportable_to_yaml never invoking the wrapped init
* rm stray file
* pipeline tests are green again
* Linux CI now needs .[all] to generate the schema
* Bugfixes, pipeline tests seem to be green
* Typo in version after merge
* Implement missing methods in Weaviate
* Trying to avoid FAISS tests from running in the Milvus1 test suite
* Fix some stray test paths and faiss index dumping
* Fix pytest markers list
* Temporarily disable cache to be able to see tests failures
* Fix pyproject.toml syntax
* Use only tmp_path
* Fix preprocessor signature after merge
* Fix faiss bug
* Fix Ray test
* Fix documentation issue by removing quotes from faiss type
* Update Documentation & Code Style
* use document properly in preprocessor tests
* Update Documentation & Code Style
* make preprocessor capable of handling documents
* import document
* Revert support for documents in preprocessor, do later
* Fix bug in _json_schema.py that was breaking validation
* re-enable cache
* Update Documentation & Code Style
* Simplify calling _json_schema.py from the CI
* Remove redundant ABC inheritance
* Ensure exportable_to_yaml works only on implementations
* Rename subclass to class_ in Meta
* Make run() and get_config() abstract in BasePipeline
* Revert unintended change in preprocessor
* Move outgoing_edges_input_node check inside try block
* Rename VALID_CODE_GEN_INPUT_REGEX into VALID_INPUT_REGEX
* Add check for a RecursionError on validate_config_strings
* Address usages of _pipeline_config in data silo and elasticsearch
* Rename _pipeline_config into _init_parameters
* Fix pytest marker and remove unused imports
* Remove most redundant ABCs
* Rename _init_parameters into _component_configuration
* Remove set_config and type from _component_configuration's dict
* Remove last instances of set_config and replace with super().__init__()
* Implement __init_subclass__ approach
* Simplify checks on the existence of _component_configuration
* Fix faiss issue
* Dynamic generation of node schemas & weed out old schemas
* Add debatable test
* Add docstring to debatable test
* Positive diff between schemas implemented
* Improve diff printing
* Rename REST API YAML files to trigger IDE validation
* Fix typing issues
* Fix more typing
* Typo in YAML filename
* Remove needless type:ignore
* Add tests
* Fix tests & validation feedback for accessory classes in custom nodes
* Refactor RAGeneratorType out
* Fix broken import in conftest
* Improve source error handling
* Remove unused import in test_eval.py breaking tests
* Fix changed error message in tests matches too
* Normalize generate_openapi_specs.py and generate_json_schema.py in the actions
* Fix path to generate_openapi_specs.py in autoformat.yml
* Update Documentation & Code Style
* Add test for FAISSDocumentStore-like situations (superclass with init params)
* Update Documentation & Code Style
* Fix indentation
* Remove commented set_config
* Store model_name_or_path in FARMReader to use in DistillationDataSilo
* Rename _component_configuration into _component_config
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
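For context, the validation added by this PR checks a pipeline YAML file against a versioned JSON schema. A minimal, illustrative sketch of that idea follows; it is not Haystack's internal implementation, and the file names and the use of the `jsonschema` package here are assumptions:

```python
# Illustrative only: validate a pipeline YAML against a published pipeline schema.
import json
import yaml
import jsonschema

with open("haystack-pipeline-1.2.0.schema.json") as schema_file:  # hypothetical local copy
    schema = json.load(schema_file)

with open("my_pipeline.yaml") as yaml_file:  # hypothetical pipeline definition
    config = yaml.safe_load(yaml_file)

# Raises jsonschema.ValidationError if the components/pipelines sections are malformed
jsonschema.validate(instance=config, schema=schema)
```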
This commit is contained in:
parent a1040a17b2
commit 11cf94a965
307  .github/utils/generate_json_schema.py (vendored)
@@ -1,307 +1,10 @@
import json
import sys
import logging
import subprocess
from pathlib import Path
from typing import Any, Dict, Optional, Set, Tuple

from haystack import __version__
import haystack.document_stores
import haystack.nodes
import pydantic.schema
from fastapi.dependencies.utils import get_typed_signature
from pydantic import BaseConfig, BaseSettings, Required, SecretStr, create_model
from pydantic.fields import ModelField
from pydantic.schema import SkipField, TypeModelOrEnum, TypeModelSet, encode_default
from pydantic.schema import field_singleton_schema as _field_singleton_schema
from pydantic.typing import is_callable_type
from pydantic.utils import lenient_issubclass

schema_version = __version__
filename = f"haystack-pipeline-{schema_version}.schema.json"
destination_path = Path(__file__).parent.parent.parent / "json-schemas" / filename
logging.basicConfig(level=logging.INFO)


class Settings(BaseSettings):
    input_token: SecretStr
    github_repository: str

sys.path.append(".")
from haystack.nodes._json_schema import update_json_schema


# Monkey patch Pydantic's field_singleton_schema to convert classes and functions to
# strings in JSON Schema
def field_singleton_schema(
    field: ModelField,
    *,
    by_alias: bool,
    model_name_map: Dict[TypeModelOrEnum, str],
    ref_template: str,
    schema_overrides: bool = False,
    ref_prefix: Optional[str] = None,
    known_models: TypeModelSet,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
    try:
        return _field_singleton_schema(
            field,
            by_alias=by_alias,
            model_name_map=model_name_map,
            ref_template=ref_template,
            schema_overrides=schema_overrides,
            ref_prefix=ref_prefix,
            known_models=known_models,
        )
    except (ValueError, SkipField):
        schema: Dict[str, Any] = {"type": "string"}

        if isinstance(field.default, type) or is_callable_type(field.default):
            default = field.default.__name__
        else:
            default = field.default
        if not field.required:
            schema["default"] = encode_default(default)
        return schema, {}, set()


# Monkeypatch Pydantic's field_singleton_schema
pydantic.schema.field_singleton_schema = field_singleton_schema


class Config(BaseConfig):
    extra = "forbid"


def get_json_schema():
    """
    Generate JSON schema for Haystack pipelines.
    """
    schema_definitions = {}
    additional_definitions = {}

    modules_with_nodes = [haystack.nodes, haystack.document_stores]
    possible_nodes = []
    for module in modules_with_nodes:
        for importable_name in dir(module):
            imported = getattr(module, importable_name)
            possible_nodes.append((module, imported))
    # TODO: decide if there's a better way to not include Base classes other than by
    # the prefix "Base" in the name. Maybe it could make sense to have a list of
    # all the valid nodes to include in the main source code and then using that here.
    for module, node in possible_nodes:
        if lenient_issubclass(node, haystack.nodes.BaseComponent) and not node.__name__.startswith("Base"):
            logging.info(f"Processing node: {node.__name__}")
            init_method = getattr(node, "__init__", None)
            if init_method:
                signature = get_typed_signature(init_method)
                param_fields = [
                    param
                    for param in signature.parameters.values()
                    if param.kind not in {param.VAR_POSITIONAL, param.VAR_KEYWORD}
                ]
                # Remove self parameter
                param_fields.pop(0)
                param_fields_kwargs: Dict[str, Any] = {}
                for param in param_fields:
                    logging.info(f"--- processing param: {param.name}")
                    annotation = Any
                    if param.annotation != param.empty:
                        annotation = param.annotation
                    default = Required
                    if param.default != param.empty:
                        default = param.default
                    param_fields_kwargs[param.name] = (annotation, default)
                model = create_model(f"{node.__name__}ComponentParams", __config__=Config, **param_fields_kwargs)
                model.update_forward_refs(**model.__dict__)
                params_schema = model.schema()
                params_schema["title"] = "Parameters"
                params_schema[
                    "description"
                ] = "Each parameter can reference other components defined in the same YAML file."
                if "definitions" in params_schema:
                    params_definitions = params_schema.pop("definitions")
                    additional_definitions.update(params_definitions)
                component_schema = {
                    "type": "object",
                    "properties": {
                        "name": {
                            "title": "Name",
                            "description": "Custom name for the component. Helpful for visualization and debugging.",
                            "type": "string",
                        },
                        "type": {
                            "title": "Type",
                            "description": "Haystack Class name for the component.",
                            "type": "string",
                            "const": f"{node.__name__}",
                        },
                        "params": params_schema,
                    },
                    "required": ["type", "name"],
                    "additionalProperties": False,
                }
                schema_definitions[f"{node.__name__}Component"] = component_schema

    all_definitions = {**schema_definitions, **additional_definitions}
    component_refs = [{"$ref": f"#/definitions/{name}"} for name in schema_definitions]
    pipeline_schema = {
        "$schema": "http://json-schema.org/draft-07/schema",
        "$id": f"https://haystack.deepset.ai/json-schemas/{filename}",
        "title": "Haystack Pipeline",
        "description": "Haystack Pipeline YAML file describing the nodes of the pipelines. For more info read the docs at: https://haystack.deepset.ai/components/pipelines#yaml-file-definitions",
        "type": "object",
        "properties": {
            "version": {
                "title": "Version",
                "description": "Version of the Haystack Pipeline file.",
                "type": "string",
                "const": schema_version,
            },
            "components": {
                "title": "Components",
                "description": "Component nodes and their configurations, to later be used in the pipelines section. Define here all the building blocks for the pipelines.",
                "type": "array",
                "items": {"anyOf": component_refs},
                "required": ["type", "name"],
                "additionalProperties": False,
            },
            "pipelines": {
                "title": "Pipelines",
                "description": "Multiple pipelines can be defined using the components from the same YAML file.",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": {"title": "Name", "description": "Name of the pipeline.", "type": "string"},
                        "nodes": {
                            "title": "Nodes",
                            "description": "Nodes to be used by this particular pipeline",
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "name": {
                                        "title": "Name",
                                        "description": "The name of this particular node in the pipeline. This should be one of the names from the components defined in the same file.",
                                        "type": "string",
                                    },
                                    "inputs": {
                                        "title": "Inputs",
                                        "description": "Input parameters for this node.",
                                        "type": "array",
                                        "items": {"type": "string"},
                                    },
                                },
                                "additionalProperties": False,
                            },
                            "required": ["name", "nodes"],
                            "additionalProperties": False,
                        },
                    },
                    "additionalProperties": False,
                },
            },
        },
        "required": ["version", "components", "pipelines"],
        "additionalProperties": False,
        "definitions": all_definitions,
    }
    return pipeline_schema


def list_indexed_versions(index):
    """
    Given the schema index as a parsed JSON,
    return a list of all the versions it contains.
    """
    indexed_versions = []
    for version_entry in index["oneOf"]:
        for property_entry in version_entry["allOf"]:
            if "properties" in property_entry.keys():
                indexed_versions.append(property_entry["properties"]["version"]["const"])
    return indexed_versions


def cleanup_rc_versions(index):
    """
    Given the schema index as a parsed JSON,
    removes any existing (unstable) rc version from it.
    """
    new_versions_list = []
    for version_entry in index["oneOf"]:
        for property_entry in version_entry["allOf"]:
            if "properties" in property_entry.keys():
                if "rc" not in property_entry["properties"]["version"]["const"]:
                    new_versions_list.append(version_entry)
                break
    index["oneOf"] = new_versions_list
    return index


def new_version_entry(version):
    """
    Returns a new entry for the version index JSON schema.
    """
    return {
        "allOf": [
            {"properties": {"version": {"const": version}}},
            {
                "$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/"
                f"haystack-pipeline-{version}.schema.json"
            },
        ]
    }


def generate_json_schema():
    # Create new schema file
    pipeline_schema = get_json_schema()
    destination_path.parent.mkdir(parents=True, exist_ok=True)
    destination_path.write_text(json.dumps(pipeline_schema, indent=2))

    # Update schema index
    index = []
    index_path = Path(__file__).parent.parent.parent / "json-schemas" / "haystack-pipeline.schema.json"
    with open(index_path, "r") as index_file:
        index = json.load(index_file)
    if index:
        index = cleanup_rc_versions(index)
        indexed_versions = list_indexed_versions(index)
        if not any(version == schema_version for version in indexed_versions):
            index["oneOf"].append(new_version_entry(schema_version))
            with open(index_path, "w") as index_file:
                json.dump(index, index_file, indent=4)


def main():
    from github import Github

    generate_json_schema()
    logging.basicConfig(level=logging.INFO)
    settings = Settings()
    logging.info(f"Using config: {settings.json()}")
    g = Github(settings.input_token.get_secret_value())
    repo = g.get_repo(settings.github_repository)

    logging.info("Setting up GitHub Actions git user")
    subprocess.run(["git", "config", "user.name", "github-actions"], check=True)
    subprocess.run(["git", "config", "user.email", "github-actions@github.com"], check=True)
    branch_name = "generate-json-schema"
    logging.info(f"Creating a new branch {branch_name}")
    subprocess.run(["git", "checkout", "-b", branch_name], check=True)
    logging.info("Adding updated file")
    subprocess.run(["git", "add", str(destination_path)], check=True)
    logging.info("Committing updated file")
    message = "⬆ Upgrade JSON Schema file"
    subprocess.run(["git", "commit", "-m", message], check=True)
    logging.info("Pushing branch")
    subprocess.run(["git", "push", "origin", branch_name], check=True)
    logging.info("Creating PR")
    pr = repo.create_pull(title=message, body=message, base="master", head=branch_name)
    logging.info(f"Created PR: {pr.number}")
    logging.info("Finished")


if __name__ == "__main__":
    # If you only want to generate the JSON Schema file without submitting a PR
    # uncomment this line:
    generate_json_schema()

    # and comment this line:
    # main()

update_json_schema(update_index=True)
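For reference, the index file that list_indexed_versions(), cleanup_rc_versions() and new_version_entry() manipulate (json-schemas/haystack-pipeline.schema.json) holds a "oneOf" list of version entries. An illustrative entry, reconstructed from the code above for an assumed version "1.2.0", looks like this:

```python
# Illustrative shape of one entry in the schema index, as built by new_version_entry("1.2.0")
example_entry = {
    "allOf": [
        {"properties": {"version": {"const": "1.2.0"}}},
        {
            "$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/"
            "haystack-pipeline-1.2.0.schema.json"
        },
    ]
}
```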
32  .github/utils/generate_openapi_specs.py (vendored, new file)
@@ -0,0 +1,32 @@
import json
from pathlib import Path
import os
import sys
import shutil

REST_PATH = Path("./rest_api").absolute()
PIPELINE_PATH = str(REST_PATH / "pipeline" / "pipeline_empty.haystack-pipeline.yml")
APP_PATH = str(REST_PATH / "application.py")
DOCS_PATH = Path("./docs") / "_src" / "api" / "openapi"

os.environ["PIPELINE_YAML_PATH"] = PIPELINE_PATH

print(f"Loading OpenAPI specs from {APP_PATH} with pipeline at {PIPELINE_PATH}")

sys.path.append(".")
from rest_api.application import get_openapi_specs, haystack_version

# Generate the openapi specs
specs = get_openapi_specs()

# Dump the specs into a JSON file
with open(DOCS_PATH / "openapi.json", "w") as f:
    json.dump(specs, f, indent=4)

# Remove rc versions of the specs from the folder
for specs_file in os.listdir():
    if os.path.isfile(specs_file) and "rc" in specs_file and Path(specs_file).suffix == ".json":
        os.remove(specs_file)

# Add versioned copy
shutil.copy(DOCS_PATH / "openapi.json", DOCS_PATH / f"openapi-{haystack_version}.json")
9  .github/workflows/autoformat.yml (vendored)
@@ -40,7 +40,7 @@ jobs:
      - name: Install Dependencies
        run: |
          pip install --upgrade pip
          pip install .[test]
          pip install .[all]
          pip install rest_api/
          pip install ui/
          pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
@@ -69,14 +69,11 @@ jobs:

      # Generates the OpenAPI specs file to be used on the documentation website
      - name: Generate OpenAPI Specs
        run: |
          pip install rest_api/
          cd docs/_src/api/openapi/
          python generate_openapi_specs.py
        run: python .github/utils/generate_openapi_specs.py

      # Generates a new JSON schema for the pipeline YAML validation
      - name: Generate JSON schema for pipelines
        run: python ./.github/utils/generate_json_schema.py
        run: python .github/utils/generate_json_schema.py

      # Commit the files to GitHub
      - name: Commit files
9  .github/workflows/linux_ci.yml (vendored)
@@ -193,14 +193,11 @@ jobs:

      # Generates the OpenAPI specs file to be used on the documentation website
      - name: Generate OpenAPI Specs
        run: |
          pip install rest_api/
          cd docs/_src/api/openapi/
          python generate_openapi_specs.py
        run: python .github/utils/generate_openapi_specs.py

      # Generates a new JSON schema for the pipeline YAML validation
      - name: Generate JSON schema for pipelines
        run: python ./.github/utils/generate_json_schema.py
        run: python .github/utils/generate_json_schema.py

      # If there is anything to commit, fail
      # Note: this CI action mirrors autoformat.yml, with the difference that it
@@ -287,7 +284,7 @@ jobs:
      if: steps.cache.outputs.cache-hit != 'true'
      run: |
        pip install --upgrade pip
        pip install .[test]
        pip install .[all]
        pip install rest_api/
        pip install ui/
        pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
@@ -2174,7 +2174,7 @@ the vector embeddings are indexed in a FAISS Index.

#### \_\_init\_\_

```python
def __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: "Optional[faiss.swigfaiss.Index]" = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
def __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
```

**Arguments**:
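As a quick illustration of the signature documented above, a FAISS store can be created with the documented defaults. This is only a sketch; the SQLite URL and the other values are just the defaults shown in the signature:

```python
from haystack.document_stores import FAISSDocumentStore

# Uses the documented defaults: a local SQLite database and a flat FAISS index
document_store = FAISSDocumentStore(
    sql_url="sqlite:///faiss_document_store.db",
    embedding_dim=768,
    faiss_index_factory_str="Flat",
    similarity="dot_product",
)
```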
@@ -3565,6 +3565,54 @@ operation.

None

<a id="weaviate.WeaviateDocumentStore.delete_labels"></a>

#### delete\_labels

```python
def delete_labels()
```

Implemented to respect BaseDocumentStore's contract.

Weaviate does not support labels (yet).

<a id="weaviate.WeaviateDocumentStore.get_all_labels"></a>

#### get\_all\_labels

```python
def get_all_labels()
```

Implemented to respect BaseDocumentStore's contract.

Weaviate does not support labels (yet).

<a id="weaviate.WeaviateDocumentStore.get_label_count"></a>

#### get\_label\_count

```python
def get_label_count()
```

Implemented to respect BaseDocumentStore's contract.

Weaviate does not support labels (yet).

<a id="weaviate.WeaviateDocumentStore.write_labels"></a>

#### write\_labels

```python
def write_labels()
```

Implemented to respect BaseDocumentStore's contract.

Weaviate does not support labels (yet).

<a id="graphdb"></a>

# Module graphdb
@@ -90,7 +90,7 @@ i.e. the model can easily adjust to domain documents even after training has fin

#### \_\_init\_\_

```python
def __init__(model_name_or_path: str = "facebook/rag-token-nq", model_version: Optional[str] = None, retriever: Optional[DensePassageRetriever] = None, generator_type: RAGeneratorType = RAGeneratorType.TOKEN, top_k: int = 2, max_length: int = 200, min_length: int = 2, num_beams: int = 2, embed_title: bool = True, prefix: Optional[str] = None, use_gpu: bool = True)
def __init__(model_name_or_path: str = "facebook/rag-token-nq", model_version: Optional[str] = None, retriever: Optional[DensePassageRetriever] = None, generator_type: str = "token", top_k: int = 2, max_length: int = 200, min_length: int = 2, num_beams: int = 2, embed_title: bool = True, prefix: Optional[str] = None, use_gpu: bool = True)
```

Load a RAG model from Transformers along with passage_embedding_model.
@@ -104,7 +104,7 @@ See https://huggingface.co/transformers/model_doc/rag.html for more details
See https://huggingface.co/models for full list of available models.
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
- `retriever`: `DensePassageRetriever` used to embedded passages for the docs passed to `predict()`. This is optional and is only needed if the docs you pass don't already contain embeddings in `Document.embedding`.
- `generator_type`: Which RAG generator implementation to use? RAG-TOKEN or RAG-SEQUENCE
- `generator_type`: Which RAG generator implementation to use ("token" or "sequence")
- `top_k`: Number of independently generated text to return
- `max_length`: Maximum length of generated text
- `min_length`: Minimum length of generated text
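With the change above, the generator variant is selected with a plain string instead of the removed RAGeneratorType enum. A minimal usage sketch, assuming the class documented here is haystack.nodes.RAGenerator:

```python
from haystack.nodes import RAGenerator

# "token" replaces the former RAGeneratorType.TOKEN enum value
generator = RAGenerator(model_name_or_path="facebook/rag-token-nq", generator_type="token", top_k=2)
```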
@@ -17,7 +17,7 @@ RootNode feeds inputs together with corresponding params to a Pipeline.

## BasePipeline

```python
class BasePipeline()
class BasePipeline(ABC)
```

Base class for pipelines, providing the most basic methods to load and save them in different ways.
@@ -28,10 +28,11 @@ See also the `Pipeline` class for the actual pipeline logic.

#### get\_config

```python
@abstractmethod
def get_config(return_defaults: bool = False) -> dict
```

Returns a configuration for the Pipeline that can be used with `BasePipeline.load_from_config()`.
Returns a configuration for the Pipeline that can be used with `Pipeline.load_from_config()`.

**Arguments**:

@@ -81,6 +82,7 @@ Default value is True.

```python
@classmethod
@abstractmethod
def load_from_config(cls, pipeline_config: Dict, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True)
```

@@ -137,6 +139,7 @@ variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an

```python
@classmethod
@abstractmethod
def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True)
```
@@ -519,6 +522,62 @@ Create a Graphviz visualization of the pipeline.

- `path`: the path to save the image.

<a id="base.Pipeline.load_from_yaml"></a>

#### load\_from\_yaml

```python
@classmethod
def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True)
```

Load Pipeline from a YAML file defining the individual components and how they're tied together to form

a Pipeline. A single YAML can declare multiple Pipelines, in which case an explicit `pipeline_name` must
be passed.

Here's a sample configuration:

```yaml
| version: '1.0'
|
| components: # define all the building-blocks for Pipeline
| - name: MyReader # custom-name for the component; helpful for visualization & debugging
|   type: FARMReader # Haystack Class name for the component
|   params:
|     no_ans_boost: -10
|     model_name_or_path: deepset/roberta-base-squad2
| - name: MyESRetriever
|   type: ElasticsearchRetriever
|   params:
|     document_store: MyDocumentStore # params can reference other components defined in the YAML
|     custom_query: null
| - name: MyDocumentStore
|   type: ElasticsearchDocumentStore
|   params:
|     index: haystack_test
|
| pipelines: # multiple Pipelines can be defined using the components from above
| - name: my_query_pipeline # a simple extractive-qa Pipeline
|   nodes:
|   - name: MyESRetriever
|     inputs: [Query]
|   - name: MyReader
|     inputs: [MyESRetriever]
```

Note that, in case of a mismatch in version between Haystack and the YAML, a warning will be printed.
If the pipeline loads correctly regardless, save again the pipeline using `Pipeline.save_to_yaml()` to remove the warning.

**Arguments**:

- `path`: path of the YAML file.
- `pipeline_name`: if the YAML contains multiple pipelines, the pipeline_name to load must be set.
- `overwrite_with_env_variables`: Overwrite the YAML configuration with environment variables. For example,
to change index name param for an ElasticsearchDocumentStore, an env
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
`_` sign must be used to specify nested hierarchical properties.
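For illustration, the sample file above could be loaded like this (a sketch; the file name is hypothetical):

```python
from pathlib import Path
from haystack.pipelines import Pipeline

# Loads "my_query_pipeline" from the sample YAML shown above
pipeline = Pipeline.load_from_yaml(Path("pipelines.yaml"), pipeline_name="my_query_pipeline")
```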
<a id="base.Pipeline.load_from_config"></a>

#### load\_from\_config
@@ -15,6 +15,7 @@ class BasePreProcessor(BaseComponent)

#### process

```python
@abstractmethod
def process(documents: Union[dict, List[dict]], clean_whitespace: Optional[bool] = True, clean_header_footer: Optional[bool] = False, clean_empty_lines: Optional[bool] = True, remove_substrings: List[str] = [], split_by: Optional[str] = "word", split_length: Optional[int] = 1000, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = True) -> List[dict]
```

@@ -107,7 +107,7 @@ class ElasticsearchRetriever(BaseRetriever)

#### \_\_init\_\_

```python
def __init__(document_store: KeywordDocumentStore, top_k: int = 10, custom_query: str = None)
def __init__(document_store: KeywordDocumentStore, top_k: int = 10, custom_query: Optional[str] = None)
```

**Arguments**:
@@ -1,31 +0,0 @@
import json
from pathlib import Path
import os
import sys
import shutil

sys.path.append("../../../../")

rest_path = Path("../../../../rest_api").absolute()
pipeline_path = str(rest_path / "pipeline" / "pipeline_empty.yaml")
app_path = str(rest_path / "application.py")
print(f"Loading OpenAPI specs from {app_path} with pipeline at {pipeline_path}")

os.environ["PIPELINE_YAML_PATH"] = pipeline_path

from rest_api.application import get_openapi_specs, haystack_version

# Generate the openapi specs
specs = get_openapi_specs()

# Dump the specs into a JSON file
with open("openapi.json", "w") as f:
    json.dump(specs, f, indent=4)

# Remove rc versions of the specs from the folder
for specs_file in os.listdir():
    if os.path.isfile(specs_file) and "rc" in specs_file and Path(specs_file).suffix == ".json":
        os.remove(specs_file)

# Add versioned copy
shutil.copy("openapi.json", f"openapi-{haystack_version}.json")
@@ -7,7 +7,7 @@ except (ModuleNotFoundError, ImportError):
    # Python <= 3.7
    import importlib_metadata as metadata  # type: ignore

__version__ = metadata.version("farm-haystack")
__version__: str = str(metadata.version("farm-haystack"))


# This configuration must be done before any import to apply to all submodules
@@ -65,14 +65,7 @@ class DeepsetCloudDocumentStore(KeywordDocumentStore):
                f"{indexing_info['pending_file_count']} files are pending to be indexed. Indexing status: {indexing_info['status']}"
            )

        self.set_config(
            workspace=workspace,
            index=index,
            duplicate_documents=duplicate_documents,
            api_endpoint=api_endpoint,
            similarity=similarity,
            return_embedding=return_embedding,
        )
        super().__init__()

    def get_all_documents(
        self,
@@ -140,41 +140,7 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
        :param use_system_proxy: Whether to use system proxy.

        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            host=host,
            port=port,
            username=username,
            password=password,
            api_key_id=api_key_id,
            api_key=api_key,
            aws4auth=aws4auth,
            index=index,
            label_index=label_index,
            search_fields=search_fields,
            content_field=content_field,
            name_field=name_field,
            embedding_field=embedding_field,
            embedding_dim=embedding_dim,
            custom_mapping=custom_mapping,
            excluded_meta_data=excluded_meta_data,
            analyzer=analyzer,
            scheme=scheme,
            ca_certs=ca_certs,
            verify_certs=verify_certs,
            create_index=create_index,
            duplicate_documents=duplicate_documents,
            refresh_type=refresh_type,
            similarity=similarity,
            timeout=timeout,
            return_embedding=return_embedding,
            index_type=index_type,
            scroll=scroll,
            skip_missing_embeddings=skip_missing_embeddings,
            synonyms=synonyms,
            synonym_type=synonym_type,
            use_system_proxy=use_system_proxy,
        )
        super().__init__()

        self.client = self._init_elastic_client(
            host=host,
@@ -352,11 +318,12 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
        if self.search_fields:
            for search_field in self.search_fields:
                if search_field in mapping["properties"] and mapping["properties"][search_field]["type"] != "text":
                    host_data = self.client.transport.hosts[0]
                    raise Exception(
                        f"The search_field '{search_field}' of index '{index_name}' with type '{mapping['properties'][search_field]['type']}' "
                        f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields. "
                        f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack."
                        f"In this case deleting the index with `curl -X DELETE \"{self.pipeline_config['params']['host']}:{self.pipeline_config['params']['port']}/{index_name}\"` will fix your environment. "
                        f"In this case deleting the index with `curl -X DELETE \"{host_data['host']}:{host_data['port']}/{index_name}\"` will fix your environment. "
                        f"Note, that all data stored in the index will be lost!"
                    )
        if self.embedding_field:
@@ -1823,11 +1790,12 @@ class OpenSearchDocumentStore(ElasticsearchDocumentStore):
                    search_field in mappings["properties"]
                    and mappings["properties"][search_field]["type"] != "text"
                ):
                    host_data = self.client.transport.hosts[0]
                    raise Exception(
                        f"The search_field '{search_field}' of index '{index_name}' with type '{mappings['properties'][search_field]['type']}' "
                        f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields. "
                        f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack."
                        f"In this case deleting the index with `curl -X DELETE \"{self.pipeline_config['params']['host']}:{self.pipeline_config['params']['port']}/{index_name}\"` will fix your environment. "
                        f"In this case deleting the index with `curl -X DELETE \"{host_data['host']}:{host_data['port']}/{index_name}\"` will fix your environment. "
                        f"Note, that all data stored in the index will be lost!"
                    )
@@ -1,15 +1,15 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union, List, Optional, Dict, Generator

if TYPE_CHECKING:
    from haystack.nodes.retriever import BaseRetriever

import json
import logging
from pathlib import Path
from typing import Union, List, Optional, Dict, Generator
from tqdm.auto import tqdm
import warnings
import numpy as np
from copy import deepcopy
from pathlib import Path
from tqdm.auto import tqdm
from inspect import Signature, signature

try:
@@ -22,7 +22,6 @@ except (ImportError, ModuleNotFoundError) as ie:

    _optional_component_not_installed(__name__, "faiss", ie)


from haystack.schema import Document
from haystack.document_stores.base import get_batches_from_generator

@@ -47,7 +46,7 @@ class FAISSDocumentStore(SQLDocumentStore):
        vector_dim: int = None,
        embedding_dim: int = 768,
        faiss_index_factory_str: str = "Flat",
        faiss_index: "Optional[faiss.swigfaiss.Index]" = None,
        faiss_index: Optional[faiss.swigfaiss.Index] = None,
        return_embedding: bool = False,
        index: str = "document",
        similarity: str = "dot_product",
@@ -112,21 +111,6 @@ class FAISSDocumentStore(SQLDocumentStore):
            self.__class__.__init__(self, **init_params)  # pylint: disable=non-parent-init-called
            return

        # save init parameters to enable export of component config as YAML
        self.set_config(
            sql_url=sql_url,
            vector_dim=vector_dim,
            embedding_dim=embedding_dim,
            faiss_index_factory_str=faiss_index_factory_str,
            return_embedding=return_embedding,
            duplicate_documents=duplicate_documents,
            index=index,
            similarity=similarity,
            embedding_field=embedding_field,
            progress_bar=progress_bar,
            isolation_level=isolation_level,
        )

        if similarity in ("dot_product", "cosine"):
            self.similarity = similarity
            self.metric_type = faiss.METRIC_INNER_PRODUCT
@@ -614,8 +598,15 @@ class FAISSDocumentStore(SQLDocumentStore):
        config_path = index_path.with_suffix(".json")

        faiss.write_index(self.faiss_indexes[self.index], str(index_path))

        config_to_save = deepcopy(self._component_config["params"])
        keys_to_remove = ["faiss_index", "faiss_index_path"]
        for key in keys_to_remove:
            if key in config_to_save.keys():
                del config_to_save[key]

        with open(config_path, "w") as ipp:
            json.dump(self.pipeline_config["params"], ipp)
            json.dump(config_to_save, ipp, default=str)

    def _load_init_params_from_config(
        self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None
@@ -38,8 +38,7 @@ class GraphDBKnowledgeGraph(BaseKnowledgeGraph):
        :param index: name of the index (also called repository) stored in the GraphDB instance
        :param prefixes: definitions of namespaces with a new line after each namespace, e.g., PREFIX hp: <https://deepset.ai/harry_potter/>
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(host=host, port=port, username=username, password=password, index=index, prefixes=prefixes)
        super().__init__()

        self.url = f"http://{host}:{port}"
        self.index = index
@@ -66,17 +66,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
            Since the data is originally stored in CPU memory there is little risk of overruning memory
            when running on CPU.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            index=index,
            label_index=label_index,
            embedding_field=embedding_field,
            embedding_dim=embedding_dim,
            return_embedding=return_embedding,
            similarity=similarity,
            progress_bar=progress_bar,
            duplicate_documents=duplicate_documents,
        )
        super().__init__()

        self.indexes: Dict[str, Dict] = defaultdict(dict)
        self.index: str = index
@@ -107,25 +107,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
            exists.
        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            sql_url=sql_url,
            milvus_url=milvus_url,
            connection_pool=connection_pool,
            index=index,
            vector_dim=vector_dim,
            embedding_dim=embedding_dim,
            index_file_size=index_file_size,
            similarity=similarity,
            index_type=index_type,
            index_param=index_param,
            search_param=search_param,
            duplicate_documents=duplicate_documents,
            return_embedding=return_embedding,
            embedding_field=embedding_field,
            progress_bar=progress_bar,
            isolation_level=isolation_level,
        )
        super().__init__()

        self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
@@ -126,29 +126,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
            exists.
        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
        """
        super().__init__()

        # save init parameters to enable export of component config as YAML
        self.set_config(
            sql_url=sql_url,
            host=host,
            port=port,
            connection_pool=connection_pool,
            index=index,
            vector_dim=vector_dim,
            embedding_dim=embedding_dim,
            index_file_size=index_file_size,
            similarity=similarity,
            index_type=index_type,
            index_param=index_param,
            search_param=search_param,
            duplicate_documents=duplicate_documents,
            id_field=id_field,
            return_embedding=return_embedding,
            embedding_field=embedding_field,
            progress_bar=progress_bar,
            custom_fields=custom_fields,
            isolation_level=isolation_level,
        )
        connections.add_connection(default={"host": host, "port": port})
        connections.connect()
@@ -134,15 +134,8 @@ class SQLDocumentStore(BaseDocumentStore):
        :param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
        """
        super().__init__()

        # save init parameters to enable export of component config as YAML
        self.set_config(
            url=url,
            index=index,
            label_index=label_index,
            duplicate_documents=duplicate_documents,
            check_same_thread=check_same_thread,
        )
        create_engine_params = {}
        if isolation_level:
            create_engine_params["isolation_level"] = isolation_level
@@ -105,25 +105,8 @@ class WeaviateDocumentStore(BaseDocumentStore):
        """
        if similarity != "cosine":
            raise ValueError(f"Weaviate only supports cosine similarity, but you provided {similarity}")
        # save init parameters to enable export of component config as YAML
        self.set_config(
            host=host,
            port=port,
            timeout_config=timeout_config,
            username=username,
            password=password,
            index=index,
            embedding_dim=embedding_dim,
            content_field=content_field,
            name_field=name_field,
            similarity=similarity,
            index_type=index_type,
            custom_schema=custom_schema,
            return_embedding=return_embedding,
            embedding_field=embedding_field,
            progress_bar=progress_bar,
            duplicate_documents=duplicate_documents,
        )

        super().__init__()

        # Connect to Weaviate server using python binding
        weaviate_url = f"{host}:{port}"
@@ -1162,3 +1145,35 @@ class WeaviateDocumentStore(BaseDocumentStore):
            docs_to_delete = [doc for doc in docs_to_delete if doc.id in ids]
        for doc in docs_to_delete:
            self.weaviate_client.data_object.delete(doc.id)

    def delete_labels(self):
        """
        Implemented to respect BaseDocumentStore's contract.

        Weaviate does not support labels (yet).
        """
        raise NotImplementedError("Weaviate does not support labels (yet).")

    def get_all_labels(self):
        """
        Implemented to respect BaseDocumentStore's contract.

        Weaviate does not support labels (yet).
        """
        raise NotImplementedError("Weaviate does not support labels (yet).")

    def get_label_count(self):
        """
        Implemented to respect BaseDocumentStore's contract.

        Weaviate does not support labels (yet).
        """
        raise NotImplementedError("Weaviate does not support labels (yet).")

    def write_labels(self):
        """
        Implemented to respect BaseDocumentStore's contract.

        Weaviate does not support labels (yet).
        """
        pass
@@ -1,8 +1,71 @@
# coding: utf8
"""Custom Errors for Haystack stacks"""
"""Custom Errors for Haystack"""

from typing import Optional


class DuplicateDocumentError(ValueError):
class HaystackError(Exception):
    """
    Any error generated by Haystack.

    This error wraps its source transparently in such a way that its attributes
    can be accessed directly: for example, if the original error has a `message` attribute,
    `HaystackError.message` will exist and have the expected content.
    """

    def __init__(self, message: Optional[str] = None, docs_link: Optional[str] = None):
        super().__init__()
        if message:
            self.message = message
        self.docs_link = None

    def __getattr__(self, attr):
        # If self.__cause__ is None, it will raise the expected AttributeError
        getattr(self.__cause__, attr)

    def __str__(self):
        if self.docs_link:
            docs_message = f"\n\nCheck out the documentation at {self.docs_link}"
            return self.message + docs_message
        return self.message

    def __repr__(self):
        return str(self)


class PipelineError(HaystackError):
    """Exception for issues raised within a pipeline"""

    def __init__(
        self, message: Optional[str] = None, docs_link: Optional[str] = "https://haystack.deepset.ai/pipelines"
    ):
        super().__init__(message=message, docs_link=docs_link)


class PipelineSchemaError(PipelineError):
    """Exception for issues arising when reading/building the JSON schema of pipelines"""

    pass


class PipelineConfigError(PipelineError):
    """Exception for issues raised within a pipeline's config file"""

    def __init__(
        self,
        message: Optional[str] = None,
        docs_link: Optional[str] = "https://haystack.deepset.ai/pipelines#yaml-file-definitions",
    ):
        super().__init__(message=message, docs_link=docs_link)


class DocumentStoreError(HaystackError):
    """Exception for issues that occur in a document store"""

    pass


class DuplicateDocumentError(DocumentStoreError, ValueError):
    """Exception for Duplicate document"""

    pass
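To show how the docs_link plumbing above surfaces to users, here is a small illustrative sketch of raising and printing one of these errors (the message text is made up):

```python
from haystack.errors import PipelineConfigError

try:
    raise PipelineConfigError("Invalid component definition: 'type' is missing.")
except PipelineConfigError as error:
    # __str__ appends the default docs link on a new paragraph, e.g.
    # "Check out the documentation at https://haystack.deepset.ai/pipelines#yaml-file-definitions"
    print(error)
```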
@@ -884,7 +884,7 @@ class DistillationDataSilo(DataSilo):
            "max_seq_len": self.processor.max_seq_len,
            "dev_split": self.processor.dev_split,
            "tasks": self.processor.tasks,
            "teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"],
            "teacher_name_or_path": self.teacher.model_name_or_path,
            "data_silo_type": self.__class__.__name__,
        }
        checksum = get_dict_checksum(payload_dict)
486
haystack/nodes/_json_schema.py
Normal file
486
haystack/nodes/_json_schema.py
Normal file
@ -0,0 +1,486 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
import pydantic.schema
|
||||
from pydantic import BaseConfig, BaseSettings, Required, SecretStr, create_model
|
||||
from pydantic.typing import ForwardRef, evaluate_forwardref, is_callable_type
|
||||
from pydantic.fields import ModelField
|
||||
from pydantic.schema import (
|
||||
SkipField,
|
||||
TypeModelOrEnum,
|
||||
TypeModelSet,
|
||||
encode_default,
|
||||
field_singleton_schema as _field_singleton_schema,
|
||||
)
|
||||
|
||||
from haystack import __version__ as haystack_version
|
||||
from haystack.errors import HaystackError, PipelineSchemaError
|
||||
from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
JSON_SCHEMAS_PATH = Path(__file__).parent.parent.parent / "json-schemas"
|
||||
SCHEMA_URL = "https://haystack.deepset.ai/json-schemas/"
|
||||
|
||||
# Allows accessory classes (like enums and helpers) to be registered as valid input for
|
||||
# custom node's init parameters. For now we disable this feature, but flipping this variables
|
||||
# re-enables it. Mind that string validation will still cut out most attempts to load anything
|
||||
# else than enums and class constants: see Pipeline.load_from_config()
|
||||
ALLOW_ACCESSORY_CLASSES = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
input_token: SecretStr
|
||||
github_repository: str
|
||||
|
||||
|
||||
# Monkey patch Pydantic's field_singleton_schema to convert classes and functions to
|
||||
# strings in JSON Schema
|
||||
def field_singleton_schema(
|
||||
field: ModelField,
|
||||
*,
|
||||
by_alias: bool,
|
||||
model_name_map: Dict[TypeModelOrEnum, str],
|
||||
ref_template: str,
|
||||
schema_overrides: bool = False,
|
||||
ref_prefix: Optional[str] = None,
|
||||
known_models: TypeModelSet,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
|
||||
try:
|
||||
return _field_singleton_schema(
|
||||
field,
|
||||
by_alias=by_alias,
|
||||
model_name_map=model_name_map,
|
||||
ref_template=ref_template,
|
||||
schema_overrides=schema_overrides,
|
||||
ref_prefix=ref_prefix,
|
||||
known_models=known_models,
|
||||
)
|
||||
except (ValueError, SkipField):
|
||||
schema: Dict[str, Any] = {"type": "string"}
|
||||
|
||||
if isinstance(field.default, type) or is_callable_type(field.default):
|
||||
default = field.default.__name__
|
||||
else:
|
||||
default = field.default
|
||||
if not field.required:
|
||||
schema["default"] = encode_default(default)
|
||||
return schema, {}, set()
|
||||
|
||||
|
||||
# Monkeypatch Pydantic's field_singleton_schema
|
||||
pydantic.schema.field_singleton_schema = field_singleton_schema
|
||||
|
||||
|
||||
# From FastAPI's internals
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name, kind=param.kind, default=param.default, annotation=get_typed_annotation(param, globalns)
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return typed_signature
|
||||
|
||||
|
||||
# From FastAPI's internals
|
||||
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
return annotation
|
||||
|
||||
|
||||
class Config(BaseConfig):
|
||||
extra = "forbid" # type: ignore
|
||||
|
||||
|
||||
def find_subclasses_in_modules(importable_modules: List[str], include_base_classes: bool = False):
|
||||
"""
|
||||
This function returns a list `(module, class)` of all the classes that can be imported
|
||||
dynamically, for example from a pipeline YAML definition or to generate documentation.
|
||||
|
||||
By default it won't include Base classes, which should be abstract.
|
||||
"""
|
||||
return [
|
||||
(module, clazz)
|
||||
for module in importable_modules
|
||||
for _, clazz in inspect.getmembers(sys.modules[module])
|
||||
if (
|
||||
inspect.isclass(clazz)
|
||||
and not inspect.isabstract(clazz)
|
||||
and issubclass(clazz, BaseComponent)
|
||||
and (include_base_classes or not clazz.__name__.startswith("Base"))
|
||||
)
|
||||
]


def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Create the JSON schema for a single BaseComponent subclass,
    including all accessory classes.

    :returns: the schema for the node and all accessory classes,
        and a dict with the reference to the node only.
    """
    if not hasattr(node, "__name__"):
        raise PipelineSchemaError(f"Node {node} has no __name__ attribute, cannot create a schema for it.")

    node_name = getattr(node, "__name__")

    logger.info(f"Processing node: {node_name}")

    # Read the relevant init parameters from __init__'s signature
    init_method = getattr(node, "__init__", None)
    if not init_method:
        raise PipelineSchemaError(f"Could not read the __init__ method of {node_name} to create its schema.")

    signature = get_typed_signature(init_method)
    param_fields = [
        param for param in signature.parameters.values() if param.kind not in {param.VAR_POSITIONAL, param.VAR_KEYWORD}
    ]
    # Remove self parameter
    param_fields.pop(0)
    param_fields_kwargs: Dict[str, Any] = {}

    # Read all the parameters extracted from the __init__ method with type and default value
    for param in param_fields:
        annotation = Any
        if param.annotation != param.empty:
            annotation = param.annotation
        default = Required
        if param.default != param.empty:
            default = param.default
        param_fields_kwargs[param.name] = (annotation, default)

    # Create the model with Pydantic and extract the schema
    model = create_model(f"{node_name}ComponentParams", __config__=Config, **param_fields_kwargs)
    model.update_forward_refs(**model.__dict__)
    params_schema = model.schema()
    params_schema["title"] = "Parameters"
    desc = "Each parameter can reference other components defined in the same YAML file."
    params_schema["description"] = desc

    # Definitions for accessory classes will show up here
    params_definitions = {}
    if "definitions" in params_schema:
        if ALLOW_ACCESSORY_CLASSES:
            params_definitions = params_schema.pop("definitions")
        else:
            raise PipelineSchemaError(
                f"Node {node_name} takes object instances as parameters "
                "in its __init__ function. This is currently not allowed: "
                "please use only Python primitives"
            )

    # Write out the schema and ref and return them
    component_name = f"{node_name}Component"
    component_schema = {
        component_name: {
            "type": "object",
            "properties": {
                "name": {
                    "title": "Name",
                    "description": "Custom name for the component. Helpful for visualization and debugging.",
                    "type": "string",
                },
                "type": {
                    "title": "Type",
                    "description": "Haystack Class name for the component.",
                    "type": "string",
                    "const": f"{node_name}",
                },
                "params": params_schema,
            },
            "required": ["type", "name"],
            "additionalProperties": False,
        },
        **params_definitions,
    }
    return component_schema, {"$ref": f"#/definitions/{component_name}"}
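
# A rough sketch, assuming a hypothetical node class named "MyCustomNode" with a single `top_k: int = 10`
# init parameter: create_schema_for_node(MyCustomNode) would return a definition shaped roughly like
#   {"MyCustomNodeComponent": {"type": "object",
#                              "properties": {"name": {...}, "type": {"const": "MyCustomNode"}, "params": {...}},
#                              "required": ["type", "name"], "additionalProperties": False}}
# together with the reference {"$ref": "#/definitions/MyCustomNodeComponent"}.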


def get_json_schema(
    filename: str, compatible_versions: List[str], modules: List[str] = ["haystack.document_stores", "haystack.nodes"]
):
    """
    Generate JSON schema for Haystack pipelines.
    """
    schema_definitions = {}  # All the schemas for the node and accessory classes
    node_refs = []  # References to the nodes only (accessory classes cannot be listed among the nodes in a config)

    # List all known nodes in the given modules
    possible_nodes = find_subclasses_in_modules(importable_modules=modules)

    # Build the definitions and refs for the nodes
    for _, node in possible_nodes:
        node_definition, node_ref = create_schema_for_node(node)
        schema_definitions.update(node_definition)
        node_refs.append(node_ref)

    pipeline_schema = {
        "$schema": "http://json-schema.org/draft-07/schema",
        "$id": f"{SCHEMA_URL}{filename}",
        "title": "Haystack Pipeline",
        "description": "Haystack Pipeline YAML file describing the nodes of the pipelines. For more info read the docs at: https://haystack.deepset.ai/components/pipelines#yaml-file-definitions",
        "type": "object",
        "properties": {
            "version": {
                "title": "Version",
                "description": "Version of the Haystack Pipeline file.",
                "type": "string",
                "oneOf": [{"const": version} for version in compatible_versions],
            },
            "components": {
                "title": "Components",
                "description": "Component nodes and their configurations, to later be used in the pipelines section. Define here all the building blocks for the pipelines.",
                "type": "array",
                "items": {"anyOf": node_refs},
                "required": ["type", "name"],
                "additionalProperties": True,  # To allow for custom components in IDEs - will be set to False at validation time.
            },
            "pipelines": {
                "title": "Pipelines",
                "description": "Multiple pipelines can be defined using the components from the same YAML file.",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": {"title": "Name", "description": "Name of the pipeline.", "type": "string"},
                        "nodes": {
                            "title": "Nodes",
                            "description": "Nodes to be used by this particular pipeline",
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "name": {
                                        "title": "Name",
                                        "description": "The name of this particular node in the pipeline. This should be one of the names from the components defined in the same file.",
                                        "type": "string",
                                    },
                                    "inputs": {
                                        "title": "Inputs",
                                        "description": "Input parameters for this node.",
                                        "type": "array",
                                        "items": {"type": "string"},
                                    },
                                },
                                "required": ["name", "inputs"],
                                "additionalProperties": False,
                            },
                            "required": ["name", "nodes"],
                            "additionalProperties": False,
                        },
                        "additionalProperties": False,
                    },
                    "additionalProperties": False,
                },
            },
        },
        "required": ["version", "components", "pipelines"],
        "additionalProperties": False,
        "definitions": schema_definitions,
    }
    return pipeline_schema
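
# For example (the file name and version are only illustrative):
#   schema = get_json_schema(filename="haystack-pipeline-X.Y.Z.schema.json", compatible_versions=["X.Y.Z"])
# builds a Draft-07 JSON schema whose "definitions" contain one entry per node and whose
# "version" property accepts only the listed compatible versions.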


def inject_definition_in_schema(node: BaseComponent, schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Given a node and a schema in dict form, injects the JSON schema for the new component
    so that pipelines containing such a node can be validated against it.

    :returns: the updated schema
    """
    schema_definition, node_ref = create_schema_for_node(node)
    schema["definitions"].update(schema_definition)
    schema["properties"]["components"]["items"]["anyOf"].append(node_ref)
    logger.info(f"Added definition for {getattr(node, '__name__')}")
    return schema


def natural_sort(list_to_sort: List[str]) -> List[str]:
    """Sorts a list keeping numbers in the correct numerical order"""
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanumeric_key = lambda key: [convert(c) for c in re.split("([0-9]+)", key)]
    return sorted(list_to_sort, key=alphanumeric_key)
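
# Behaviour sketch (hypothetical file names): natural_sort(["schema-10.json", "schema-2.json"])
# returns ["schema-2.json", "schema-10.json"], because numeric segments are compared as numbers, not strings.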


def load(path: Path) -> Dict[str, Any]:
    """Shorthand for loading a JSON"""
    with open(path, "r") as json_file:
        return json.load(json_file)


def dump(data: Dict[str, Any], path: Path) -> None:
    """Shorthand for dumping to JSON"""
    with open(path, "w") as json_file:
        json.dump(data, json_file, indent=2)


def new_version_entry(version):
    """
    Returns a new entry for the version index JSON schema.
    """
    return {
        "allOf": [
            {"properties": {"version": {"oneOf": [{"const": version}]}}},
            {
                "$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/"
                f"haystack-pipeline-{version}.schema.json"
            },
        ]
    }


def update_json_schema(
    update_index: bool,
    destination_path: Path = JSON_SCHEMAS_PATH,
    index_path: Path = JSON_SCHEMAS_PATH / "haystack-pipeline.schema.json",
):
    # Locate the latest schema's path
    latest_schema_path = destination_path / Path(
        natural_sort(os.listdir(destination_path))[-3]
    )  # -1 is index, -2 is unstable
    logger.info(f"Latest schema: {latest_schema_path}")
    latest_schema = load(latest_schema_path)

    # List the versions supported by the last schema
    supported_versions_block = deepcopy(latest_schema["properties"]["version"]["oneOf"])
    supported_versions = [entry["const"].replace('"', "") for entry in supported_versions_block]
    logger.info(f"Versions supported by this schema: {supported_versions}")

    # Create new schema with the same filename and versions embedded, to be identical to the latest one.
    new_schema = get_json_schema(latest_schema_path.name, supported_versions)

    # Check for backwards compatibility with difflib's SequenceMatcher
    # (https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher)
    # If the opcodes contain only "insert" and "equal", that means the new schema
    # only added lines and did not remove anything from the previous schema.
    # We decided that such additions alone preserve backwards compatibility.
    # Any other opcode ("replace", "delete") implies that something has been removed
    # in the new schema, which breaks backwards compatibility and means we should
    # store a new, separate schema.
    # People wishing to upgrade from the older schema version will have to change
    # the version in their YAML to avoid failing validation.
    latest_schema_string = json.dumps(latest_schema)
    new_schema_string = json.dumps(new_schema)
    matcher = SequenceMatcher(None, latest_schema_string, new_schema_string)
    schema_diff = matcher.get_opcodes()
    is_backwards_incompatible = any(opcode[0] not in ["insert", "equal"] for opcode in schema_diff)
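
    # For example (toy strings, purely illustrative):
    #   SequenceMatcher(None, '{"a": 1}', '{"a": 1, "b": 2}').get_opcodes()
    #   -> [("equal", 0, 7, 0, 7), ("insert", 7, 7, 7, 15), ("equal", 7, 8, 15, 16)]
    # contains only "equal" and "insert" opcodes, so such a change would count as backwards compatible.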

    unstable_versions_block = []

    # If the two schemas are incompatible, we need a new file.
    # Update the schema's filename and supported versions, then save it.
    if is_backwards_incompatible:

        # Print a quick diff to explain the differences
        logger.info("The schemas are NOT backwards compatible. This is the list of INCOMPATIBLE changes only:")
        for tag, i1, i2, j1, j2 in schema_diff:
            if tag not in ["equal", "insert"]:
                logger.info("{!r:>8} --> {!r}".format(latest_schema_string[i1:i2], new_schema_string[j1:j2]))

        filename = f"haystack-pipeline-{haystack_version}.schema.json"
        logger.info(f"Adding {filename} to the schema folder.")

        # Let's check if the schema changed without a version change
        if haystack_version in supported_versions and len(supported_versions) > 1:
            logger.info(
                f"Version {haystack_version} was supported by the latest schema "
                f"(supported versions: {supported_versions}). "
                f"Removing support for version {haystack_version} from it."
            )

            supported_versions_block = [
                entry for entry in supported_versions_block if entry["const"].replace('"', "") != haystack_version
            ]
            latest_schema["properties"]["version"]["oneOf"] = supported_versions_block
            dump(latest_schema, latest_schema_path)

            # Update the JSON schema index too
            if update_index:
                index = load(index_path)
                index["oneOf"][-1]["allOf"][0]["properties"]["version"]["oneOf"] = supported_versions_block
                dump(index, index_path)

        # Dump the new schema file
        new_schema["$id"] = f"{SCHEMA_URL}{filename}"
        unstable_versions_block = [{"const": haystack_version}]
        new_schema["properties"]["version"]["oneOf"] = [{"const": haystack_version}]
        dump(new_schema, destination_path / filename)

        # Update schema index with a whole new entry
        if update_index:
            index = load(index_path)
            new_entry = new_version_entry(haystack_version)
            if all(new_entry != entry for entry in index["oneOf"]):
                index["oneOf"].append(new_version_entry(haystack_version))
            dump(index, index_path)

    # If the two schemas are compatible, no need to write a new one:
    # Just add the new version to the list of versions supported by
    # the latest schema if it's not there yet
    else:

        # Print a quick diff to explain the differences
        if not schema_diff or all(tag[0] == "equal" for tag in schema_diff):
            logger.info("The schemas are identical, won't create a new file.")
        else:
            logger.info("The schemas are backwards compatible, overwriting the latest schema.")
            logger.info("This is the list of changes:")
            for tag, i1, i2, j1, j2 in schema_diff:
                if tag != "equal":
                    logger.info("{!r:>8} --> {!r}".format(latest_schema_string[i1:i2], new_schema_string[j1:j2]))

        # Overwrite the latest schema (safe to do for additions)
        dump(new_schema, latest_schema_path)

        if haystack_version in supported_versions:
            unstable_versions_block = supported_versions_block
            logger.info(
                f"Version {haystack_version} was already supported " f"(supported versions: {supported_versions})"
            )
        else:
            logger.info(
                f"This version ({haystack_version}) was not listed "
                f"(supported versions: {supported_versions}): "
                "updating the supported versions list."
            )

            # Updating the latest schema's list of supported versions
            supported_versions_block.append({"const": haystack_version})
            unstable_versions_block = supported_versions_block
            latest_schema["properties"]["version"]["oneOf"] = supported_versions_block
            dump(latest_schema, latest_schema_path)

            # Update the JSON schema index too
            if update_index:
                index = load(index_path)
                index["oneOf"][-1]["allOf"][0]["properties"]["version"]["oneOf"] = supported_versions_block
                dump(index, index_path)

    # Update the unstable schema (for tests and internal use).
    unstable_filename = "haystack-pipeline-unstable.schema.json"
    unstable_schema = deepcopy(new_schema)
    unstable_schema["$id"] = f"{SCHEMA_URL}{unstable_filename}"
    unstable_schema["properties"]["version"]["oneOf"] = [{"const": "unstable"}] + unstable_versions_block
    dump(unstable_schema, destination_path / unstable_filename)

@@ -23,11 +23,6 @@ from haystack.nodes.retriever.dense import DensePassageRetriever
logger = logging.getLogger(__name__)


class RAGeneratorType(Enum):
    TOKEN = (1,)
    SEQUENCE = 2


class RAGenerator(BaseGenerator):
    """
    Implementation of Facebook's Retrieval-Augmented Generator (https://arxiv.org/abs/2005.11401) based on
@@ -76,7 +71,7 @@ class RAGenerator(BaseGenerator):
        model_name_or_path: str = "facebook/rag-token-nq",
        model_version: Optional[str] = None,
        retriever: Optional[DensePassageRetriever] = None,
        generator_type: RAGeneratorType = RAGeneratorType.TOKEN,
        generator_type: str = "token",
        top_k: int = 2,
        max_length: int = 200,
        min_length: int = 2,
@@ -94,7 +89,7 @@ class RAGenerator(BaseGenerator):
                                   See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param retriever: `DensePassageRetriever` used to embed passages for the docs passed to `predict()`. This is optional and is only needed if the docs you pass don't already contain embeddings in `Document.embedding`.
        :param generator_type: Which RAG generator implementation to use? RAG-TOKEN or RAG-SEQUENCE
        :param generator_type: Which RAG generator implementation to use ("token" or "sequence")
        :param top_k: Number of independently generated text to return
        :param max_length: Maximum length of generated text
        :param min_length: Minimum length of generated text
@@ -103,21 +98,7 @@ class RAGenerator(BaseGenerator):
        :param prefix: The prefix used by the generator's tokenizer.
        :param use_gpu: Whether to use GPU. Falls back on CPU if no GPU is available.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            retriever=retriever,
            generator_type=generator_type,
            top_k=top_k,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            embed_title=embed_title,
            prefix=prefix,
            use_gpu=use_gpu,
        )
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
@@ -138,7 +119,7 @@ class RAGenerator(BaseGenerator):

        self.tokenizer = RagTokenizer.from_pretrained(model_name_or_path)

        if self.generator_type == RAGeneratorType.SEQUENCE:
        if self.generator_type == "sequence":
            raise NotImplementedError("RagSequenceForGeneration is not implemented yet")
            # TODO: Enable when transformers have it. Refer https://github.com/huggingface/transformers/issues/7905
            # Also refer https://github.com/huggingface/transformers/issues/7829
@@ -361,7 +342,7 @@ class Seq2SeqGenerator(BaseGenerator):
        :param num_beams: Number of beams for beam search. 1 means no beam search.
        :param use_gpu: Whether to use GPU or the CPU. Falls back on CPU if no GPU is available.
        """

        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
        self.min_length = min_length

@@ -1,42 +1,84 @@
from __future__ import annotations
from typing import Any, Callable, Optional, Dict, List, Tuple, Optional
from typing import Any, Optional, Dict, List, Tuple, Optional

import io
from functools import wraps
import sys
from copy import deepcopy
from abc import abstractmethod
from abc import ABC, abstractmethod
from functools import wraps
import inspect
import logging

from haystack.schema import Document, MultiLabel
from haystack.errors import HaystackError


logger = logging.getLogger(__name__)


class BaseComponent:
def exportable_to_yaml(init_func):
    """
    Decorator that saves the init parameters of a node so that they can later be used
    to export the YAML configuration of a Pipeline.
    """

    @wraps(init_func)
    def wrapper_exportable_to_yaml(self, *args, **kwargs):

        # Call the actual __init__ function with all the arguments
        init_func(self, *args, **kwargs)

        # Warn for unnamed input params - should be rare
        if args:
            logger.warning(
                "Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!"
            )
        # Create the configuration dictionary if it doesn't exist yet
        if not self._component_config:
            self._component_config = {"params": {}, "type": type(self).__name__}

        # Make sure it runs only on the __init__ of the implementations, not in superclasses
        if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}":

            # Store all the named input parameters in self._component_config
            for k, v in kwargs.items():
                if isinstance(v, BaseComponent):
                    self._component_config["params"][k] = v._component_config
                elif v is not None:
                    self._component_config["params"][k] = v

    return wrapper_exportable_to_yaml


class BaseComponent(ABC):
    """
    A base class for implementing nodes in a Pipeline.
    """

    outgoing_edges: int
    subclasses: dict = {}
    pipeline_config: dict = {}
    name: Optional[str] = None
    _subclasses: dict = {}
    _component_config: dict = {}

    # __init_subclass__ is invoked when a subclass of BaseComponent is _imported_
    # (not instantiated). It works approximately as a metaclass.
    def __init_subclass__(cls, **kwargs):
        """
        Automatically keeps track of all available subclasses.
        Enables generic load() for all specific component implementations.
        """

        super().__init_subclass__(**kwargs)
        cls.subclasses[cls.__name__] = cls

        # Automatically registers all the init parameters in
        # an instance attribute called `_component_config`,
        # used to save this component to YAML. See exportable_to_yaml()
        cls.__init__ = exportable_to_yaml(cls.__init__)

        # Keeps track of all available subclasses by name.
        # Enables generic load() for all specific component implementations.
        cls._subclasses[cls.__name__] = cls

    @classmethod
    def get_subclass(cls, component_type: str):
        if component_type not in cls.subclasses.keys():
            raise Exception(f"Haystack component with the name '{component_type}' does not exist.")
        subclass = cls.subclasses[component_type]
        if component_type not in cls._subclasses.keys():
            raise HaystackError(f"Haystack component with the name '{component_type}' does not exist.")
        subclass = cls._subclasses[component_type]
        return subclass
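
    # For example, BaseComponent.get_subclass("FARMReader") returns the FARMReader class itself,
    # which the generic loading code can then instantiate from a YAML component definition.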

    @classmethod
@@ -165,18 +207,3 @@ class BaseComponent:

        output["params"] = params
        return output, stream

    def set_config(self, **kwargs):
        """
        Save the init parameters of a component that later can be used with exporting
        YAML configuration of a Pipeline.

        :param kwargs: all parameters passed to the __init__() of the Component.
        """
        if not self.pipeline_config:
            self.pipeline_config = {"params": {}, "type": type(self).__name__}
        for k, v in kwargs.items():
            if isinstance(v, BaseComponent):
                self.pipeline_config["params"][k] = v.pipeline_config
            elif v is not None:
                self.pipeline_config["params"][k] = v

@@ -58,6 +58,8 @@ class Crawler(BaseComponent):
                                        All URLs not matching at least one of the regular expressions will be dropped.
        :param overwrite_existing_files: Whether to overwrite existing files in output_dir with new content
        """
        super().__init__()

        IN_COLAB = "google.colab" in sys.modules

        options = webdriver.chrome.options.Options()

@@ -101,18 +101,8 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
        :param batch_size: batch size to be processed at once
        :param classification_field: Name of Document's meta field to be used for classification. If left unset, Document.content is used by default.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            tokenizer=tokenizer,
            use_gpu=use_gpu,
            return_all_scores=return_all_scores,
            labels=labels,
            task=task,
            batch_size=batch_size,
            classification_field=classification_field,
        )
        super().__init__()

        if labels and task == "text-classification":
            logger.warning(
                f"Provided labels {labels} will be ignored for task text-classification. Set task to "

@@ -41,6 +41,7 @@ class EvalDocuments(BaseComponent):
            "EvalDocuments node is deprecated and will be removed in a future version. "
            "Please use pipeline.eval() instead."
        )
        super().__init__()
        self.init_counts()
        self.no_answer_warning = False
        self.debug = debug
@@ -205,6 +206,7 @@ class EvalAnswers(BaseComponent):
            "EvalAnswers node is deprecated and will be removed in a future version. "
            "Please use pipeline.eval() instead."
        )
        super().__init__()
        self.log: List = []
        self.debug = debug
        self.skip_incorrect_retrieval = skip_incorrect_retrieval

@@ -21,8 +21,8 @@ class EntityExtractor(BaseComponent):
    outgoing_edges = 1

    def __init__(self, model_name_or_path: str = "dslim/bert-base-NER", use_gpu: bool = True):
        super().__init__()

        self.set_config(model_name_or_path=model_name_or_path)
        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

@@ -30,7 +30,8 @@ class FileTypeClassifier(BaseComponent):
        if len(set(supported_types)) != len(supported_types):
            raise ValueError("supported_types can't contain duplicate values.")

        self.set_config(supported_types=supported_types)
        super().__init__()

        self.supported_types = supported_types

    def _get_extension(self, file_paths: List[Path]) -> str:

@@ -57,17 +57,7 @@ class AzureConverter(BaseConverter):
                                              This parameter lets you choose, whether to merge multiple column header
                                              rows to a single row.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            endpoint=endpoint,
            credential_key=credential_key,
            model_id=model_id,
            valid_languages=valid_languages,
            save_json=save_json,
            preceding_context_len=preceding_context_len,
            following_context_len=following_context_len,
            merge_multiple_column_headers=merge_multiple_column_headers,
        )
        super().__init__(valid_languages=valid_languages)

        self.document_analysis_client = DocumentAnalysisClient(
            endpoint=endpoint, credential=AzureKeyCredential(credential_key)
@@ -79,8 +69,6 @@ class AzureConverter(BaseConverter):
        self.following_context_len = following_context_len
        self.merge_multiple_column_headers = merge_multiple_column_headers

        super().__init__(valid_languages=valid_languages)

    def convert(
        self,
        file_path: Path,

@@ -27,9 +27,7 @@ class BaseConverter(BaseComponent):
                                not one of the valid languages, then it might likely be encoding error resulting
                                in garbled text.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
        super().__init__()

        self.remove_numeric_tables = remove_numeric_tables
        self.valid_languages = valid_languages

@@ -35,9 +35,7 @@ class ImageToTextConverter(BaseConverter):
                                # List of available languages
                                print(pytesseract.get_languages(config=''))
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
        super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)

        verify_installation = subprocess.run(["tesseract -v"], shell=True)
        if verify_installation.returncode == 127:

@@ -55,18 +55,7 @@ class ParsrConverter(BaseConverter):
                                not one of the valid languages, then it might likely be encoding error resulting
                                in garbled text.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            parsr_url=parsr_url,
            extractor=extractor,
            table_detection_mode=table_detection_mode,
            preceding_context_len=preceding_context_len,
            following_context_len=following_context_len,
            remove_page_headers=remove_page_headers,
            remove_page_footers=remove_page_footers,
            remove_table_of_contents=remove_table_of_contents,
            valid_languages=valid_languages,
        )
        super().__init__(valid_languages=valid_languages)

        try:
            ping = requests.get(parsr_url)

@@ -33,8 +33,7 @@ class PDFToTextConverter(BaseConverter):
                                not one of the valid languages, then it might likely be encoding error resulting
                                in garbled text.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
        super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)

        verify_installation = subprocess.run(["pdftotext -v"], shell=True)
        if verify_installation.returncode == 127:
@@ -170,8 +169,6 @@ class PDFToTextOCRConverter(BaseConverter):
        # init image to text instance
        self.image_2_text = ImageToTextConverter(remove_numeric_tables, valid_languages)

        # save init parameters to enable export of component config as YAML
        self.set_config(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
        super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)

    def convert(

@@ -59,9 +59,7 @@ class TikaConverter(BaseConverter):
                                not one of the valid languages, then it might likely be encoding error resulting
                                in garbled text.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(tika_url=tika_url, remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
        super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)

        ping = requests.get(tika_url)
        if ping.status_code != 200:

@@ -13,9 +13,6 @@ class Docs2Answers(BaseComponent):

    outgoing_edges = 1

    def __init__(self):
        self.set_config()

    def run(self, query: str, documents: List[Document]):  # type: ignore
        # conversion from Document -> Answer
        answers: List[Answer] = []

@@ -26,8 +26,7 @@ class JoinAnswers(BaseComponent):
            weights is not None and join_mode == "concatenate"
        ), "Weights are not compatible with 'concatenate' join_mode"

        # Save init parameters to enable export of component config as YAML
        self.set_config(join_mode=join_mode, weights=weights, top_k_join=top_k_join)
        super().__init__()

        self.join_mode = join_mode
        self.weights = [float(i) / sum(weights) for i in weights] if weights else None

@@ -39,8 +39,7 @@ class JoinDocuments(BaseComponent):
            weights is not None and join_mode == "concatenate"
        ), "Weights are not compatible with 'concatenate' join_mode."

        # save init parameters to enable export of component config as YAML
        self.set_config(join_mode=join_mode, weights=weights, top_k_join=top_k_join)
        super().__init__()

        self.join_mode = join_mode
        self.weights = [float(i) / sum(weights) for i in weights] if weights else None

@@ -32,8 +32,7 @@ class RouteDocuments(BaseComponent):
            "to group the documents to."
        )

        # Save init parameters to enable export of component config as YAML
        self.set_config(split_by=split_by, metadata_values=metadata_values)
        super().__init__()

        self.split_by = split_by
        self.metadata_values = metadata_values

@@ -1,11 +1,13 @@
from typing import List, Dict, Any, Optional, Union

from abc import abstractmethod
from haystack.nodes.base import BaseComponent


class BasePreProcessor(BaseComponent):
    outgoing_edges = 1

    @abstractmethod
    def process(
        self,
        documents: Union[dict, List[dict]],
@@ -23,6 +25,7 @@ class BasePreProcessor(BaseComponent):
        """
        raise NotImplementedError

    @abstractmethod
    def clean(
        self,
        document: dict,
@@ -33,6 +36,7 @@ class BasePreProcessor(BaseComponent):
    ) -> Dict[str, Any]:
        raise NotImplementedError

    @abstractmethod
    def split(
        self,
        document: dict,

@@ -72,18 +72,7 @@ class PreProcessor(BasePreProcessor):
                             the number of words will be <= split_length.
        :param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format. Available options: "en", "es", "de", "fr" & many more.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            clean_whitespace=clean_whitespace,
            clean_header_footer=clean_header_footer,
            clean_empty_lines=clean_empty_lines,
            remove_substrings=remove_substrings,
            split_by=split_by,
            split_length=split_length,
            split_overlap=split_overlap,
            split_respect_sentence_boundary=split_respect_sentence_boundary,
        )
        super().__init__()

        try:
            nltk.data.find("tokenizers/punkt")
@@ -131,9 +120,9 @@ class PreProcessor(BasePreProcessor):

        ret = []

        if type(documents) == dict:
        if isinstance(documents, dict):
            ret = self._process_single(document=documents, **kwargs)  # type: ignore
        elif type(documents) == list:
        elif isinstance(documents, list):
            ret = self._process_batch(documents=list(documents), **kwargs)

        else:

@@ -72,8 +72,7 @@ class SklearnQueryClassifier(BaseQueryClassifier):
        ):
            raise TypeError("model_name_or_path and vectorizer_name_or_path must either be of type Path or str")

        # save init parameters to enable export of component config as YAML
        self.set_config(model_name_or_path=model_name_or_path, vectorizer_name_or_path=vectorizer_name_or_path)
        super().__init__()

        if isinstance(model_name_or_path, Path):
            file_url = urllib.request.pathname2url(r"{}".format(model_name_or_path))

@@ -63,8 +63,8 @@ class TransformersQueryClassifier(BaseQueryClassifier):
        :param model_name_or_path: Transformer based fine tuned mini bert model for query classification
        :param use_gpu: Whether to use GPU (if available).
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(model_name_or_path=model_name_or_path)
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu)
        device = 0 if self.devices[0].type == "cuda" else -1


@@ -47,21 +47,11 @@ class QuestionGenerator(BaseComponent):
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param use_gpu: Whether to use GPU or the CPU. Falls back on CPU if no GPU is available.
        """
        super().__init__()
        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
        self.model.to(str(self.devices[0]))
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            max_length=max_length,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            split_length=split_length,
            split_overlap=split_overlap,
        )
        self.num_beams = num_beams
        self.max_length = max_length
        self.no_repeat_ngram_size = no_repeat_ngram_size

@@ -52,9 +52,7 @@ class SentenceTransformersRanker(BaseRanker):
        :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(model_name_or_path=model_name_or_path, model_version=model_version, top_k=top_k)
        super().__init__()

        self.top_k = top_k

@@ -113,30 +113,8 @@ class FARMReader(BaseReader):
                               the local token will be used, which must be previously created via `transformer-cli login`.
                               Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        super().__init__()

        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            context_window_size=context_window_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
            no_ans_boost=no_ans_boost,
            return_no_answer=return_no_answer,
            top_k=top_k,
            top_k_per_candidate=top_k_per_candidate,
            top_k_per_sample=top_k_per_sample,
            num_processes=num_processes,
            max_seq_len=max_seq_len,
            doc_stride=doc_stride,
            progress_bar=progress_bar,
            duplicate_filtering=duplicate_filtering,
            proxies=proxies,
            local_files_only=local_files_only,
            force_download=force_download,
            use_confidence_scores=use_confidence_scores,
            **kwargs,
        )
        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        self.return_no_answers = return_no_answer
@@ -175,6 +153,7 @@ class FARMReader(BaseReader):
        self.use_gpu = use_gpu
        self.progress_bar = progress_bar
        self.use_confidence_scores = use_confidence_scores
        self.model_name_or_path = model_name_or_path  # Used in distillation, see DistillationDataSilo._get_checksum()

    def _training_procedure(
        self,

@@ -95,17 +95,7 @@ class TableReader(BaseReader):
                            query + table exceed max_seq_len, the table will be truncated by removing rows until the
                            input size fits the model.
        """
        # Save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            tokenizer=tokenizer,
            use_gpu=use_gpu,
            top_k=top_k,
            top_k_per_candidate=top_k_per_candidate,
            return_no_answer=return_no_answer,
            max_seq_len=max_seq_len,
        )
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        config = TapasConfig.from_pretrained(model_name_or_path)
@@ -480,18 +470,7 @@ class RCIReader(BaseReader):
                            query + table exceed max_seq_len, the table will be truncated by removing rows until the
                            input size fits the model.
        """
        # Save init parameters to enable export of component config as YAML
        self.set_config(
            row_model_name_or_path=row_model_name_or_path,
            column_model_name_or_path=column_model_name_or_path,
            row_model_version=row_model_version,
            column_model_version=column_model_version,
            row_tokenizer=row_tokenizer,
            column_tokenizer=column_tokenizer,
            use_gpu=use_gpu,
            top_k=top_k,
            max_seq_len=max_seq_len,
        )
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.row_model = AutoModelForSequenceClassification.from_pretrained(

@@ -60,19 +60,7 @@ class TransformersReader(BaseReader):
        :param max_seq_len: max sequence length of one input text for the model
        :param doc_stride: length of striding window for splitting long texts (used if len(text) > max_seq_len)
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            tokenizer=tokenizer,
            context_window_size=context_window_size,
            use_gpu=use_gpu,
            top_k=top_k,
            doc_stride=doc_stride,
            top_k_per_candidate=top_k_per_candidate,
            return_no_answers=return_no_answers,
            max_seq_len=max_seq_len,
        )
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        device = 0 if self.devices[0].type == "cuda" else -1

@@ -8,6 +8,7 @@ from tqdm import tqdm
from copy import deepcopy

from haystack.schema import Document, MultiLabel
from haystack.errors import HaystackError
from haystack.nodes.base import BaseComponent
from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph

@@ -240,6 +241,10 @@ class BaseRetriever(BaseComponent):
        headers: Optional[Dict[str, str]] = None,
    ):
        if root_node == "Query":
            if not query:
                raise HaystackError(
                    "Must provide a 'query' parameter for retrievers in pipelines where Query is the root node."
                )
            self.query_count += 1
            run_query_timed = self.timing(self.run_query, "query_time")
            output, stream = run_query_timed(query=query, filters=filters, top_k=top_k, index=index, headers=headers)

@@ -108,25 +108,7 @@ class DensePassageRetriever(BaseRetriever):
                               the local token will be used, which must be previously created via `transformer-cli login`.
                               Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            document_store=document_store,
            query_embedding_model=query_embedding_model,
            passage_embedding_model=passage_embedding_model,
            model_version=model_version,
            max_seq_len_query=max_seq_len_query,
            max_seq_len_passage=max_seq_len_passage,
            top_k=top_k,
            use_gpu=use_gpu,
            batch_size=batch_size,
            embed_title=embed_title,
            use_fast_tokenizers=use_fast_tokenizers,
            infer_tokenizer_classes=infer_tokenizer_classes,
            similarity_function=similarity_function,
            progress_bar=progress_bar,
            devices=devices,
            use_auth_token=use_auth_token,
        )
        super().__init__()

        if devices is not None:
            self.devices = devices
@@ -606,27 +588,7 @@ class TableTextRetriever(BaseRetriever):
                               the local token will be used, which must be previously created via `transformer-cli login`.
                               Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            document_store=document_store,
            query_embedding_model=query_embedding_model,
            passage_embedding_model=passage_embedding_model,
            table_embedding_model=table_embedding_model,
            model_version=model_version,
            max_seq_len_query=max_seq_len_query,
            max_seq_len_passage=max_seq_len_passage,
            max_seq_len_table=max_seq_len_table,
            top_k=top_k,
            use_gpu=use_gpu,
            batch_size=batch_size,
            embed_meta_fields=embed_meta_fields,
            use_fast_tokenizers=use_fast_tokenizers,
            infer_tokenizer_classes=infer_tokenizer_classes,
            similarity_function=similarity_function,
            progress_bar=progress_bar,
            devices=devices,
            use_auth_token=use_auth_token,
        )
        super().__init__()

        if devices is not None:
            self.devices = devices
@@ -1145,19 +1107,7 @@ class EmbeddingRetriever(BaseRetriever):
                               the local token will be used, which must be previously created via `transformer-cli login`.
                               Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            document_store=document_store,
            embedding_model=embedding_model,
            model_version=model_version,
            use_gpu=use_gpu,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            model_format=model_format,
            pooling_strategy=pooling_strategy,
            emb_extraction_layer=emb_extraction_layer,
            top_k=top_k,
        )
        super().__init__()

        if devices is not None:
            self.devices = devices

@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)


class ElasticsearchRetriever(BaseRetriever):
    def __init__(self, document_store: KeywordDocumentStore, top_k: int = 10, custom_query: str = None):
    def __init__(self, document_store: KeywordDocumentStore, top_k: int = 10, custom_query: Optional[str] = None):
        """
        :param document_store: an instance of an ElasticsearchDocumentStore to retrieve documents from.
        :param custom_query: query string as per Elasticsearch DSL with a mandatory query placeholder(query).
@@ -87,8 +87,7 @@ class ElasticsearchRetriever(BaseRetriever):

        :param top_k: How many documents to return per query.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(document_store=document_store, top_k=top_k, custom_query=custom_query)
        super().__init__()
        self.document_store: KeywordDocumentStore = document_store
        self.top_k = top_k
        self.custom_query = custom_query
@@ -176,8 +175,7 @@ class TfidfRetriever(BaseRetriever):
        :param top_k: How many documents to return per query.
        :param auto_fit: Whether to automatically update tf-idf matrix by calling fit() after new documents have been added
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(document_store=document_store, top_k=top_k, auto_fit=auto_fit)
        super().__init__()

        self.vectorizer = TfidfVectorizer(
            lowercase=True, stop_words=None, token_pattern=r"(?u)\b\w\w+\b", ngram_range=(1, 1)

@@ -23,8 +23,7 @@ class Text2SparqlRetriever(BaseGraphRetriever):
        :param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
        :param top_k: How many SPARQL queries to generate per text query.
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(knowledge_graph=knowledge_graph, model_name_or_path=model_name_or_path, top_k=top_k)
        super().__init__()

        self.knowledge_graph = knowledge_graph
        # TODO We should extend this to any seq2seq models and use the AutoModel class

@@ -82,18 +82,7 @@ class TransformersSummarizer(BaseSummarizer):
                                        be summarized.
                                        Important: The summary will depend on the order of the supplied documents!
        """
        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            tokenizer=tokenizer,
            max_length=max_length,
            min_length=min_length,
            use_gpu=use_gpu,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            separator_for_single_summary=separator_for_single_summary,
            generate_single_summary=generate_single_summary,
        )
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu)
        device = 0 if self.devices[0].type == "cuda" else -1

@@ -60,14 +60,7 @@ class TransformersTranslator(BaseTranslator):
        :param clean_up_tokenization_spaces: Whether or not to clean up the tokenization spaces. (default True)
        :param use_gpu: Whether to use GPU or the CPU. Falls back on CPU if no GPU is available.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path,
            tokenizer_name=tokenizer_name,
            max_seq_len=max_seq_len,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.max_seq_len = max_seq_len

@@ -1,4 +1,5 @@
from __future__ import annotations
from os import pipe
from typing import Dict, List, Optional, Any

import copy
@@ -10,7 +11,12 @@ import numpy as np
import pandas as pd
from pathlib import Path
import networkx as nx
from abc import ABC, abstractmethod
from jsonschema import Draft7Validator
from jsonschema.exceptions import ValidationError
from jsonschema import _utils as jsonschema_utils
from pandas.core.frame import DataFrame
from transformers import pipelines
import yaml
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph
@@ -19,7 +25,14 @@ from haystack.nodes.evaluator.evaluator import (
    calculate_f1_str_multi,
    semantic_answer_similarity,
)
from haystack.pipelines.config import get_component_definitions, get_pipeline_definition, read_pipeline_config_from_yaml
from haystack.pipelines.config import (
    JSON_SCHEMAS_PATH,
    get_component_definitions,
    get_pipeline_definition,
    read_pipeline_config_from_yaml,
    validate_config_strings,
    validate_config,
)
from haystack.pipelines.utils import generate_code, print_eval_report
from haystack.utils import DeepsetCloud

@@ -32,6 +45,7 @@ except:

from haystack import __version__
from haystack.schema import EvaluationResult, MultiLabel, Document
from haystack.errors import PipelineError, PipelineConfigError
from haystack.nodes.base import BaseComponent
from haystack.nodes.retriever.base import BaseRetriever
from haystack.document_stores.base import BaseDocumentStore

@@ -55,22 +69,24 @@ class RootNode(BaseComponent):
        return {}, "output_1"


class BasePipeline:
class BasePipeline(ABC):
    """
    Base class for pipelines, providing the most basic methods to load and save them in different ways.
    See also the `Pipeline` class for the actual pipeline logic.
    """

    @abstractmethod
    def run(self, **kwargs):
        raise NotImplementedError
        raise NotImplementedError("This is an abstract method. Use Pipeline or RayPipeline instead.")

    @abstractmethod
    def get_config(self, return_defaults: bool = False) -> dict:
        """
        Returns a configuration for the Pipeline that can be used with `BasePipeline.load_from_config()`.
        Returns a configuration for the Pipeline that can be used with `Pipeline.load_from_config()`.

        :param return_defaults: whether to output parameters that have the default values.
        """
        raise NotImplementedError
        raise NotImplementedError("This is an abstract method. Use Pipeline or RayPipeline instead.")

    def to_code(
        self, pipeline_variable_name: str = "pipeline", generate_imports: bool = True, add_comment: bool = False
@@ -121,6 +137,7 @@ class BasePipeline:
            logger.error("Could not create notebook cell. Make sure you're running in a notebook environment.")

    @classmethod
    @abstractmethod
    def load_from_config(
        cls, pipeline_config: Dict, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True
    ):
@@ -169,26 +186,10 @@ class BasePipeline:
                                             variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
                                             `_` sign must be used to specify nested hierarchical properties.
        """
        pipeline_definition = get_pipeline_definition(pipeline_config=pipeline_config, pipeline_name=pipeline_name)
        if pipeline_definition["type"] == "Pipeline":
            return Pipeline.load_from_config(
                pipeline_config=pipeline_config,
                pipeline_name=pipeline_name,
                overwrite_with_env_variables=overwrite_with_env_variables,
            )
        elif pipeline_definition["type"] == "RayPipeline":
            return RayPipeline.load_from_config(
                pipeline_config=pipeline_config,
                pipeline_name=pipeline_name,
                overwrite_with_env_variables=overwrite_with_env_variables,
            )
        else:
            raise KeyError(
                f"Pipeline Type '{pipeline_definition['type']}' is not a valid. The available types are"
                f"'Pipeline' and 'RayPipeline'."
            )
        raise NotImplementedError("This is an abstract method. Use Pipeline or RayPipeline instead.")

    @classmethod
    @abstractmethod
    def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True):
        """
        Load Pipeline from a YAML file defining the individual components and how they're tied together to form
@@ -235,21 +236,7 @@ class BasePipeline:
                                             variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
                                             `_` sign must be used to specify nested hierarchical properties.
        """

        pipeline_config = read_pipeline_config_from_yaml(path)
        if pipeline_config["version"] != __version__:
            logger.warning(
                f"YAML version ({pipeline_config['version']}) does not match with Haystack version ({__version__}). "
                "Issues may occur during loading. "
                "To fix this warning, save again this pipeline with the current Haystack version using Pipeline.save_to_yaml(), "
                "check out our migration guide at https://haystack.deepset.ai/overview/migration "
                f"or downgrade to haystack version {__version__}."
            )
        return cls.load_from_config(
            pipeline_config=pipeline_config,
            pipeline_name=pipeline_name,
            overwrite_with_env_variables=overwrite_with_env_variables,
        )
        raise NotImplementedError("This is an abstract method. Use Pipeline or RayPipeline instead.")

    @classmethod
    def load_from_deepset_cloud(
@ -302,6 +289,7 @@ class BasePipeline:
|
||||
)
|
||||
component_config["params"] = params
|
||||
|
||||
del pipeline_config["name"] # Would fail validation otherwise
|
||||
pipeline = cls.load_from_config(
|
||||
pipeline_config=pipeline_config,
|
||||
pipeline_name=pipeline_name,
|
||||
@ -501,43 +489,64 @@ class Pipeline(BasePipeline):
|
||||
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
|
||||
must be specified explicitly as "QueryClassifier.output_2".
|
||||
"""
|
||||
valid_root_nodes = ["Query", "File"]
|
||||
if self.root_node is None:
|
||||
root_node = inputs[0]
|
||||
if root_node in ["Query", "File"]:
|
||||
if root_node in valid_root_nodes:
|
||||
self.root_node = root_node
|
||||
self.graph.add_node(root_node, component=RootNode())
|
||||
else:
|
||||
raise KeyError(f"Root node '{root_node}' is invalid. Available options are 'Query' and 'File'.")
|
||||
raise PipelineConfigError(
|
||||
f"Root node '{root_node}' is invalid. Available options are {valid_root_nodes}."
|
||||
)
|
||||
component.name = name
|
||||
self.graph.add_node(name, component=component, inputs=inputs)
|
||||
|
||||
if len(self.graph.nodes) == 2: # first node added; connect with Root
|
||||
assert len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node, (
|
||||
f"The '{name}' node can only input from {self.root_node}. "
|
||||
f"Set the 'inputs' parameter to ['{self.root_node}']"
|
||||
)
|
||||
if not len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node:
|
||||
raise PipelineConfigError(
|
||||
f"The '{name}' node can only input from {self.root_node}. "
|
||||
f"Set the 'inputs' parameter to ['{self.root_node}']"
|
||||
)
|
||||
self.graph.add_edge(self.root_node, name, label="output_1")
|
||||
return
|
||||
|
||||
for i in inputs:
|
||||
if "." in i:
|
||||
[input_node_name, input_edge_name] = i.split(".")
|
||||
assert "output_" in input_edge_name, f"'{input_edge_name}' is not a valid edge name."
|
||||
for input_node in inputs:
|
||||
if "." in input_node:
|
||||
[input_node_name, input_edge_name] = input_node.split(".")
|
||||
if not "output_" in input_edge_name:
|
||||
raise PipelineConfigError(f"'{input_edge_name}' is not a valid edge name.")
|
||||
|
||||
outgoing_edges_input_node = self.graph.nodes[input_node_name]["component"].outgoing_edges
|
||||
assert int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node, (
|
||||
f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
|
||||
f"{outgoing_edges_input_node} outgoing edge(s)."
|
||||
)
|
||||
if not int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node:
|
||||
raise PipelineConfigError(
|
||||
f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
|
||||
f"{outgoing_edges_input_node} outgoing edge(s)."
|
||||
)
|
||||
else:
|
||||
outgoing_edges_input_node = self.graph.nodes[i]["component"].outgoing_edges
|
||||
assert outgoing_edges_input_node == 1, (
|
||||
f"Adding an edge from {i} to {name} is ambiguous as {i} has {outgoing_edges_input_node} edges. "
|
||||
f"Please specify the output explicitly."
|
||||
)
|
||||
input_node_name = i
|
||||
try:
|
||||
outgoing_edges_input_node = self.graph.nodes[input_node]["component"].outgoing_edges
|
||||
if not outgoing_edges_input_node == 1:
|
||||
raise PipelineConfigError(
|
||||
f"Adding an edge from {input_node} to {name} is ambiguous as {input_node} has {outgoing_edges_input_node} edges. "
|
||||
f"Please specify the output explicitly."
|
||||
)
|
||||
|
||||
except KeyError as e:
|
||||
raise PipelineConfigError(
|
||||
f"Cannot find node '{input_node}'. Make sure you're not using more "
|
||||
f"than one root node ({valid_root_nodes}) in the same pipeline and that a node "
|
||||
f"called '{input_node}' is defined."
|
||||
) from e
|
||||
|
||||
input_node_name = input_node
|
||||
input_edge_name = "output_1"
|
||||
self.graph.add_edge(input_node_name, name, label=input_edge_name)
|
||||
|
||||
if not nx.is_directed_acyclic_graph(self.graph):
    self.graph.remove_node(name)
    raise PipelineConfigError(f"Cannot add '{name}': it will create a loop in the pipeline.")
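To make the new behaviour concrete, here is a minimal sketch of how these `add_node()` checks surface to a caller. The `NoOpNode` class and the node names are invented for illustration and are not part of this change:

```python
from haystack.errors import PipelineConfigError
from haystack.nodes.base import BaseComponent
from haystack.pipelines import Pipeline


class NoOpNode(BaseComponent):
    """Hypothetical do-nothing node, only used to exercise the wiring checks."""

    outgoing_edges = 1

    def run(self, **kwargs):
        return {}, "output_1"


pipeline = Pipeline()
pipeline.add_node(component=NoOpNode(), name="First", inputs=["Query"])

try:
    # "First" declares a single outgoing edge, so "First.output_2" cannot exist:
    pipeline.add_node(component=NoOpNode(), name="Second", inputs=["First.output_2"])
except PipelineConfigError as e:
    print(e)  # e.g. "Cannot connect 'output_2' from 'First' as it only has 1 outgoing edge(s)."
```

Invalid root nodes, references to undefined nodes, and edges that would close a loop now raise `PipelineConfigError` in the same way instead of failing an `assert`.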
|
||||
|
||||
def get_node(self, name: str) -> Optional[BaseComponent]:
|
||||
"""
|
||||
Get a node from the Pipeline.
|
||||
@ -968,6 +977,61 @@ class Pipeline(BasePipeline):
|
||||
graphviz.layout("dot")
|
||||
graphviz.draw(path)
|
||||
|
||||
@classmethod
|
||||
def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True):
|
||||
"""
|
||||
Load Pipeline from a YAML file defining the individual components and how they're tied together to form
|
||||
a Pipeline. A single YAML can declare multiple Pipelines, in which case an explicit `pipeline_name` must
|
||||
be passed.
|
||||
|
||||
Here's a sample configuration:
|
||||
|
||||
```yaml
|
||||
| version: '1.0'
|
||||
|
|
||||
| components: # define all the building-blocks for Pipeline
|
||||
| - name: MyReader # custom-name for the component; helpful for visualization & debugging
|
||||
| type: FARMReader # Haystack Class name for the component
|
||||
| params:
|
||||
| no_ans_boost: -10
|
||||
| model_name_or_path: deepset/roberta-base-squad2
|
||||
| - name: MyESRetriever
|
||||
| type: ElasticsearchRetriever
|
||||
| params:
|
||||
| document_store: MyDocumentStore # params can reference other components defined in the YAML
|
||||
| custom_query: null
|
||||
| - name: MyDocumentStore
|
||||
| type: ElasticsearchDocumentStore
|
||||
| params:
|
||||
| index: haystack_test
|
||||
|
|
||||
| pipelines: # multiple Pipelines can be defined using the components from above
|
||||
| - name: my_query_pipeline # a simple extractive-qa Pipeline
|
||||
| nodes:
|
||||
| - name: MyESRetriever
|
||||
| inputs: [Query]
|
||||
| - name: MyReader
|
||||
| inputs: [MyESRetriever]
|
||||
```
|
||||
|
||||
Note that if there is a version mismatch between the YAML and Haystack, a warning is printed.
If the pipeline loads correctly regardless, save the pipeline again using `Pipeline.save_to_yaml()` to remove the warning.
|
||||
|
||||
:param path: path of the YAML file.
|
||||
:param pipeline_name: if the YAML contains multiple pipelines, the pipeline_name to load must be set.
|
||||
:param overwrite_with_env_variables: Overwrite the YAML configuration with environment variables. For example,
|
||||
to change index name param for an ElasticsearchDocumentStore, an env
|
||||
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
|
||||
`_` sign must be used to specify nested hierarchical properties.
|
||||
"""
|
||||
|
||||
pipeline_config = read_pipeline_config_from_yaml(path)
|
||||
return cls.load_from_config(
|
||||
pipeline_config=pipeline_config,
|
||||
pipeline_name=pipeline_name,
|
||||
overwrite_with_env_variables=overwrite_with_env_variables,
|
||||
)
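A usage sketch for the docstring above; the file name, environment variable, and pipeline name are derived from the sample configuration and are assumptions rather than part of this change:

```python
import os
from pathlib import Path

from haystack.pipelines import Pipeline

# Environment variables follow the <COMPONENTNAME>_PARAMS_<PARAM> convention described above,
# so this overrides the `index` param of the component named "MyDocumentStore":
os.environ["MYDOCUMENTSTORE_PARAMS_INDEX"] = "documents-2021"

pipeline = Pipeline.load_from_yaml(Path("pipelines.haystack-pipeline.yml"), pipeline_name="my_query_pipeline")
prediction = pipeline.run(query="Who made the PDF specification?")
```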
|
||||
|
||||
@classmethod
|
||||
def load_from_config(
|
||||
cls, pipeline_config: Dict, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True
|
||||
@ -1017,6 +1081,8 @@ class Pipeline(BasePipeline):
|
||||
variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
|
||||
`_` sign must be used to specify nested hierarchical properties.
|
||||
"""
|
||||
validate_config(pipeline_config)
|
||||
|
||||
pipeline_definition = get_pipeline_definition(pipeline_config=pipeline_config, pipeline_name=pipeline_name)
|
||||
component_definitions = get_component_definitions(
|
||||
pipeline_config=pipeline_config, overwrite_with_env_variables=overwrite_with_env_variables
|
||||
@ -1063,8 +1129,16 @@ class Pipeline(BasePipeline):
|
||||
|
||||
instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
|
||||
components[name] = instance
|
||||
|
||||
except KeyError as ke:
|
||||
raise PipelineConfigError(
|
||||
f"Failed loading pipeline component '{name}': "
|
||||
"seems like the component does not exist. Did you spell its name correctly?"
|
||||
) from ke
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed loading pipeline component '{name}': {e}")
|
||||
raise PipelineConfigError(
    f"Failed loading pipeline component '{name}'. " "See the stacktrace above for more information."
) from e
|
||||
return instance
|
||||
|
||||
def save_to_yaml(self, path: Path, return_defaults: bool = False):
|
||||
@ -1085,15 +1159,16 @@ class Pipeline(BasePipeline):
|
||||
:param return_defaults: whether to output parameters that have the default values.
|
||||
"""
|
||||
pipeline_name = ROOT_NODE_TO_PIPELINE_NAME[self.root_node.lower()]
|
||||
pipelines: dict = {pipeline_name: {"name": pipeline_name, "type": self.__class__.__name__, "nodes": []}}
|
||||
pipelines: dict = {pipeline_name: {"name": pipeline_name, "nodes": []}}
|
||||
|
||||
components = {}
|
||||
for node in self.graph.nodes:
|
||||
if node == self.root_node:
|
||||
continue
|
||||
component_instance = self.graph.nodes.get(node)["component"]
|
||||
component_type = component_instance.pipeline_config["type"]
|
||||
component_params = component_instance.pipeline_config["params"]
|
||||
|
||||
component_type = component_instance._component_config["type"]
|
||||
component_params = component_instance._component_config["params"]
|
||||
components[node] = {"name": node, "type": component_type, "params": {}}
|
||||
|
||||
component_parent_classes = inspect.getmro(type(component_instance))
|
||||
@ -1112,7 +1187,7 @@ class Pipeline(BasePipeline):
|
||||
sub_component = param_value
|
||||
sub_component_type_name = sub_component["type"]
|
||||
sub_component_signature = inspect.signature(
|
||||
BaseComponent.subclasses[sub_component_type_name]
|
||||
BaseComponent._subclasses[sub_component_type_name]
|
||||
).parameters
|
||||
sub_component_params = {
|
||||
k: v
|
||||
@ -1313,14 +1388,6 @@ class RayPipeline(Pipeline):
|
||||
:param address: The IP address for the Ray cluster. If set to None, a local Ray instance is started.
|
||||
"""
|
||||
pipeline_config = read_pipeline_config_from_yaml(path)
|
||||
if pipeline_config["version"] != __version__:
|
||||
logger.warning(
|
||||
f"YAML version ({pipeline_config['version']}) does not match with Haystack version ({__version__}). "
|
||||
"Issues may occur during loading. "
|
||||
"To fix this warning, save again this pipeline with the current Haystack version using Pipeline.save_to_yaml(), "
|
||||
"check out our migration guide at https://haystack.deepset.ai/overview/migration "
|
||||
f"or downgrade to haystack version {__version__}."
|
||||
)
|
||||
return RayPipeline.load_from_config(
|
||||
pipeline_config=pipeline_config,
|
||||
pipeline_name=pipeline_name,
|
||||
|
||||
@ -1,17 +1,26 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import re
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
from networkx import DiGraph
|
||||
import yaml
|
||||
import json
|
||||
from jsonschema.validators import Draft7Validator
|
||||
from jsonschema.exceptions import ValidationError
|
||||
|
||||
from haystack import __version__
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes._json_schema import inject_definition_in_schema, JSON_SCHEMAS_PATH
|
||||
from haystack.errors import PipelineConfigError, PipelineSchemaError, HaystackError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
VALID_CODE_GEN_INPUT_REGEX = re.compile(r"^[-a-zA-Z0-9_/.:]+$")
|
||||
VALID_INPUT_REGEX = re.compile(r"^[-a-zA-Z0-9_/.:]+$")
|
||||
|
||||
|
||||
def get_pipeline_definition(pipeline_config: Dict[str, Any], pipeline_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
@ -26,11 +35,14 @@ def get_pipeline_definition(pipeline_config: Dict[str, Any], pipeline_name: Opti
|
||||
if len(pipeline_config["pipelines"]) == 1:
|
||||
pipeline_definition = pipeline_config["pipelines"][0]
|
||||
else:
|
||||
raise Exception("The YAML contains multiple pipelines. Please specify the pipeline name to load.")
|
||||
raise PipelineConfigError("The YAML contains multiple pipelines. Please specify the pipeline name to load.")
|
||||
else:
|
||||
pipelines_in_definitions = list(filter(lambda p: p["name"] == pipeline_name, pipeline_config["pipelines"]))
|
||||
if not pipelines_in_definitions:
|
||||
raise KeyError(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
|
||||
raise PipelineConfigError(
|
||||
f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file. "
|
||||
f"Existing pipelines: {[p['name'] for p in pipeline_config['pipelines']]}"
|
||||
)
|
||||
pipeline_definition = pipelines_in_definitions[0]
|
||||
|
||||
return pipeline_definition
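A toy illustration of this lookup, with an invented config dict:

```python
from haystack.pipelines.config import get_pipeline_definition

config = {"pipelines": [{"name": "query", "nodes": []}, {"name": "indexing", "nodes": []}]}

assert get_pipeline_definition(config, pipeline_name="indexing") == {"name": "indexing", "nodes": []}
# With several pipelines and no explicit name, a PipelineConfigError is raised instead.
```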
|
||||
@ -62,20 +74,29 @@ def read_pipeline_config_from_yaml(path: Path):
|
||||
return yaml.safe_load(stream)
|
||||
|
||||
|
||||
def validate_config(pipeline_config: Dict[str, Any]):
    for component in pipeline_config["components"]:
        _validate_user_input(component["name"])
        _validate_user_input(component["type"])
        for k, v in component.get("params", {}).items():
            _validate_user_input(k)
            _validate_user_input(v)
    for pipeline in pipeline_config["pipelines"]:
        _validate_user_input(pipeline["name"])
        _validate_user_input(pipeline["type"])
        for node in pipeline["nodes"]:
            _validate_user_input(node["name"])
            for input in node["inputs"]:
                _validate_user_input(input)
def validate_config_strings(pipeline_config: Any):
    """
    Ensures that strings used in the pipelines configuration
    contain only alphanumeric characters and basic punctuation.
    """
    try:
        if isinstance(pipeline_config, dict):
            for key, value in pipeline_config.items():
                validate_config_strings(key)
                validate_config_strings(value)

        elif isinstance(pipeline_config, list):
            for value in pipeline_config:
                validate_config_strings(value)

        else:
            if not VALID_INPUT_REGEX.match(str(pipeline_config)):
                raise PipelineConfigError(
                    f"'{pipeline_config}' is not a valid variable name or value. "
                    "Use alphanumeric characters or dash, underscore and colon only."
                )
    except RecursionError as e:
        raise PipelineConfigError("The given pipeline configuration is recursive, can't validate it.") from e
|
||||
|
||||
|
||||
def build_component_dependency_graph(
|
||||
@ -111,9 +132,96 @@ def build_component_dependency_graph(
|
||||
return graph
|
||||
|
||||
|
||||
def _validate_user_input(input: str):
    if isinstance(input, str) and not VALID_CODE_GEN_INPUT_REGEX.match(input):
        raise ValueError(f"'{input}' is not a valid config variable name. Use word characters only.")
def validate_yaml(path: Path):
    """
    Validates the given YAML file using the autogenerated JSON schema.

    :param path: path of the YAML file to validate
    :return: None if validation is successful
    :raise: `PipelineConfigError` in case of issues.
    """
    pipeline_config = read_pipeline_config_from_yaml(path)
    validate_config(pipeline_config=pipeline_config)
    logging.debug(f"'{path}' contains valid Haystack pipelines.")
|
||||
|
||||
|
||||
def validate_config(pipeline_config: Dict) -> None:
    """
    Validates the given configuration using the autogenerated JSON schema.

    :param pipeline_config: the configuration to validate
    :return: None if validation is successful
    :raise: `PipelineConfigError` in case of issues.
    """
    validate_config_strings(pipeline_config)

    with open(JSON_SCHEMAS_PATH / f"haystack-pipeline-unstable.schema.json", "r") as schema_file:
        schema = json.load(schema_file)

    compatible_versions = [version["const"].replace('"', "") for version in schema["properties"]["version"]["oneOf"]]
    loaded_custom_nodes = []

    while True:

        try:
            Draft7Validator(schema).validate(instance=pipeline_config)

            if pipeline_config["version"] == "unstable":
                logging.warning(
                    "You seem to be using the 'unstable' version of the schema to validate "
                    "your pipeline configuration.\n"
                    "This is NOT RECOMMENDED in production environments, as pipelines "
                    "might manage to load and then misbehave without warnings.\n"
                    f"Please pin your configurations to '{__version__}' to ensure stability."
                )

            elif pipeline_config["version"] not in compatible_versions:
                raise PipelineConfigError(
                    f"Cannot load pipeline configuration of version {pipeline_config['version']} "
                    f"in Haystack version {__version__} "
                    f"(only versions {compatible_versions} are compatible with this Haystack release).\n"
                    "Please check out the release notes (https://github.com/deepset-ai/haystack/releases/latest), "
                    "the documentation (https://haystack.deepset.ai/components/pipelines#yaml-file-definitions) "
                    "and fix your configuration accordingly."
                )
            break

        except ValidationError as validation:

            # If the validation comes from an unknown node, try to find it and retry:
            if list(validation.relative_schema_path) == ["properties", "components", "items", "anyOf"]:
                if validation.instance["type"] not in loaded_custom_nodes:

                    logger.info(
                        f"Missing definition for node of type {validation.instance['type']}. Looking into local classes..."
                    )
                    missing_component = BaseComponent.get_subclass(validation.instance["type"])
                    schema = inject_definition_in_schema(node=missing_component, schema=schema)
                    loaded_custom_nodes.append(validation.instance["type"])
                    continue

                # A node with the given name was imported, but something else is wrong with it.
                # Probably it references unknown classes in its init parameters.
                raise PipelineSchemaError(
                    f"Cannot process node of type {validation.instance['type']}. Make sure its __init__ function "
                    "does not reference external classes, but uses only Python primitive types."
                ) from validation

            # Format the error to make it as clear as possible
            error_path = [
                i
                for i in list(validation.relative_schema_path)[:-1]
                if repr(i) != "'items'" and repr(i) != "'properties'"
            ]
            error_location = "->".join(repr(index) for index in error_path)
            if error_location:
                error_location = f"The error is in {error_location}."

            raise PipelineConfigError(
                f"Validation failed. {validation.message}. {error_location} " "See the stacktrace for more information."
            ) from validation

    logging.debug(f"Pipeline configuration is valid.")
|
||||
|
||||
|
||||
def _overwrite_with_env_variables(component_definition: Dict[str, Any]):
|
||||
|
||||
(4 file diffs suppressed because they are too large)
@ -1,65 +1,69 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema",
|
||||
"$id": "https://haystack.deepset.ai/json-schemas/haystack-pipeline-1.1.0.schema.json",
|
||||
"title": "Haystack Pipeline",
|
||||
"description": "Haystack Pipeline YAML file describing the nodes of the pipelines. For more info read the docs at: https://haystack.deepset.ai/components/pipelines#yaml-file-definitions",
|
||||
"type": "object",
|
||||
"oneOf": [
|
||||
"$schema": "http://json-schema.org/draft-07/schema",
|
||||
"$id": "https://haystack.deepset.ai/json-schemas/haystack-pipeline-1.1.0.schema.json",
|
||||
"title": "Haystack Pipeline",
|
||||
"description": "Haystack Pipeline YAML file describing the nodes of the pipelines. For more info read the docs at: https://haystack.deepset.ai/components/pipelines#yaml-file-definitions",
|
||||
"type": "object",
|
||||
"oneOf": [
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"allOf": [
|
||||
"properties": {
|
||||
"version": {
|
||||
"oneOf": [
|
||||
{
|
||||
"properties": {
|
||||
"version": {
|
||||
"const": "0.7"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-0.7.schema.json"
|
||||
"const": "unstable"
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"version": {
|
||||
"const": "1.1.0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-1.1.0.schema.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"version": {
|
||||
"const": "1.2.0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-1.2.0.schema.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"version": {
|
||||
"const": "1.2.1rc0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-1.2.1rc0.schema.json"
|
||||
}
|
||||
]
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-unstable.schema.json"
|
||||
}
|
||||
]
|
||||
]
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"version": {
|
||||
"oneOf": [
|
||||
{
|
||||
"const": "1.0.0"
|
||||
},
|
||||
{
|
||||
"const": "1.1.0"
|
||||
},
|
||||
{
|
||||
"const": "1.2.0"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-1.0.0.schema.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"version": {
|
||||
"oneOf": [
|
||||
{
|
||||
"const": "1.2.1rc0"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "https://raw.githubusercontent.com/deepset-ai/haystack/master/json-schemas/haystack-pipeline-1.2.1rc0.schema.json"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -93,14 +93,18 @@ min-similarity-lines=6
|
||||
minversion = "6.0"
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"tika: marks tests which require tika container (deselect with '-m \"not tika\"')",
|
||||
"elasticsearch: marks tests which require elasticsearch container (deselect with '-m \"not elasticsearch\"')",
|
||||
"graphdb: marks tests which require graphdb container (deselect with '-m \"not graphdb\"')",
|
||||
"generator: marks generator tests (deselect with '-m \"not generator\"')",
|
||||
"pipeline: marks tests with pipeline",
|
||||
"summarizer: marks summarizer tests",
|
||||
"weaviate: marks tests that require weaviate container",
|
||||
"embedding_dim: marks usage of document store with non-default embedding dimension (e.g @pytest.mark.embedding_dim(128))",
|
||||
"integration: integration tests (deselect with '-m \"not integration\"')",
|
||||
"slow: slow tests (deselect with '-m \"not slow\"')",
|
||||
"tika: require tika container (deselect with '-m \"not tika\"')",
|
||||
"elasticsearch: require elasticsearch container (deselect with '-m \"not elasticsearch\"')",
|
||||
"graphdb: require graphdb container (deselect with '-m \"not graphdb\"')",
|
||||
"generator: generator tests (deselect with '-m \"not generator\"')",
|
||||
"pipeline: tests with pipelines",
|
||||
"summarizer: summarizer tests",
|
||||
"weaviate: require weaviate container",
|
||||
"faiss: uses FAISS",
|
||||
"milvus: requires a Milvus 2 setup",
|
||||
"milvus1: requires a Milvus 1 container",
|
||||
"embedding_dim: uses a document store with non-default embedding dimension (e.g @pytest.mark.embedding_dim(128))",
|
||||
]
|
||||
log_cli = true
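(For local runs, several of these markers are typically deselected together, for example with `pytest -m "not integration and not slow"`; the exact selection is up to the developer and not part of this change.)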
|
||||
@ -3,7 +3,7 @@ from pathlib import Path
|
||||
|
||||
|
||||
PIPELINE_YAML_PATH = os.getenv(
|
||||
"PIPELINE_YAML_PATH", str((Path(__file__).parent / "pipeline" / "pipelines.yaml").absolute())
|
||||
"PIPELINE_YAML_PATH", str((Path(__file__).parent / "pipeline" / "pipelines.haystack-pipeline.yml").absolute())
|
||||
)
|
||||
QUERY_PIPELINE_NAME = os.getenv("QUERY_PIPELINE_NAME", "query")
|
||||
INDEXING_PIPELINE_NAME = os.getenv("INDEXING_PIPELINE_NAME", "indexing")
|
||||
|
||||
@ -11,6 +11,7 @@ from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from haystack.pipelines.base import Pipeline
|
||||
from haystack.errors import PipelineConfigError
|
||||
from haystack.pipelines.config import get_component_definitions, get_pipeline_definition, read_pipeline_config_from_yaml
|
||||
from rest_api.config import PIPELINE_YAML_PATH, FILE_UPLOAD_PATH, INDEXING_PIPELINE_NAME
|
||||
from rest_api.controller.utils import as_form
|
||||
@ -43,10 +44,10 @@ try:
|
||||
INDEXING_PIPELINE = None
|
||||
else:
|
||||
INDEXING_PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=INDEXING_PIPELINE_NAME)
|
||||
except KeyError:
|
||||
INDEXING_PIPELINE = None
|
||||
logger.warning("Indexing Pipeline not found in the YAML configuration. File Upload API will not be available.")
|
||||
|
||||
except PipelineConfigError as e:
|
||||
INDEXING_PIPELINE = None
|
||||
logger.error(f"{e.message}. File Upload API will not be available.")
|
||||
|
||||
# create directory for uploading files
|
||||
os.makedirs(FILE_UPLOAD_PATH, exist_ok=True)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Dummy pipeline, used when the CI needs to load the REST API to extract the OpenAPI specs. DO NOT USE.
|
||||
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components:
|
||||
- name: FileTypeClassifier
|
||||
@ -8,8 +7,6 @@ components:
|
||||
|
||||
pipelines:
|
||||
- name: query
|
||||
type: Query
|
||||
nodes:
|
||||
- name: FileTypeClassifier
|
||||
inputs: [File]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components: # define all the building-blocks for Pipeline
|
||||
- name: DocumentStore
|
||||
@ -30,14 +30,12 @@ components: # define all the building-blocks for Pipeline
|
||||
|
||||
pipelines:
|
||||
- name: query # a sample extractive-qa Pipeline
|
||||
type: Query
|
||||
nodes:
|
||||
- name: Retriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [Retriever]
|
||||
- name: indexing
|
||||
type: Indexing
|
||||
nodes:
|
||||
- name: FileTypeClassifier
|
||||
inputs: [File]
|
||||
@ -1,4 +1,4 @@
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components: # define all the building-blocks for Pipeline
|
||||
- name: DocumentStore
|
||||
@ -30,14 +30,12 @@ components: # define all the building-blocks for Pipeline
|
||||
|
||||
pipelines:
|
||||
- name: query # a sample extractive-qa Pipeline
|
||||
type: Query
|
||||
nodes:
|
||||
- name: Retriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [Retriever]
|
||||
- name: indexing
|
||||
type: Indexing
|
||||
nodes:
|
||||
- name: FileTypeClassifier
|
||||
inputs: [File]
|
||||
@ -0,0 +1,47 @@
|
||||
version: '1.1.0'
|
||||
|
||||
components:
|
||||
- name: Reader
|
||||
type: FARMReader
|
||||
params:
|
||||
no_ans_boost: -10
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
num_processes: 0
|
||||
- name: ESRetriever
|
||||
type: ElasticsearchRetriever
|
||||
params:
|
||||
document_store: DocumentStore
|
||||
custom_query: null
|
||||
- name: DocumentStore
|
||||
type: ElasticsearchDocumentStore
|
||||
params:
|
||||
index: haystack_test
|
||||
label_index: haystack_test_label
|
||||
- name: Preprocessor
|
||||
type: PreProcessor
|
||||
params:
|
||||
clean_whitespace: true
|
||||
- name: PDFConverter
|
||||
type: PDFToTextConverter
|
||||
params:
|
||||
remove_numeric_tables: false
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: test-query
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: test-indexing
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
- name: Preprocessor
|
||||
inputs: [PDFConverter]
|
||||
- name: ESRetriever
|
||||
inputs: [Preprocessor]
|
||||
- name: DocumentStore
|
||||
inputs: [ESRetriever]
|
||||
@ -1,103 +0,0 @@
|
||||
version: '1.1.0'
|
||||
|
||||
components:
|
||||
- name: Reader
|
||||
type: FARMReader
|
||||
params:
|
||||
no_ans_boost: -10
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
num_processes: 0
|
||||
- name: ESRetriever
|
||||
type: ElasticsearchRetriever
|
||||
params:
|
||||
document_store: DocumentStore
|
||||
custom_query: null
|
||||
- name: DocumentStore
|
||||
type: ElasticsearchDocumentStore
|
||||
params:
|
||||
index: haystack_test
|
||||
label_index: haystack_test_label
|
||||
- name: PDFConverter
|
||||
type: PDFToTextConverter
|
||||
params:
|
||||
remove_numeric_tables: false
|
||||
- name: Preprocessor
|
||||
type: PreProcessor
|
||||
params:
|
||||
clean_whitespace: true
|
||||
- name: IndexTimeDocumentClassifier
|
||||
type: TransformersDocumentClassifier
|
||||
params:
|
||||
batch_size: 16
|
||||
use_gpu: -1
|
||||
- name: QueryTimeDocumentClassifier
|
||||
type: TransformersDocumentClassifier
|
||||
params:
|
||||
use_gpu: -1
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: query_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: ray_query_pipeline
|
||||
type: RayPipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
replicas: 2
|
||||
inputs: [ Query ]
|
||||
- name: Reader
|
||||
inputs: [ ESRetriever ]
|
||||
|
||||
- name: query_pipeline_with_document_classifier
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
- name: QueryTimeDocumentClassifier
|
||||
inputs: [ESRetriever]
|
||||
- name: Reader
|
||||
inputs: [QueryTimeDocumentClassifier]
|
||||
|
||||
- name: indexing_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
- name: Preprocessor
|
||||
inputs: [PDFConverter]
|
||||
- name: ESRetriever
|
||||
inputs: [Preprocessor]
|
||||
- name: DocumentStore
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: indexing_text_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: TextConverter
|
||||
inputs: [File]
|
||||
- name: Preprocessor
|
||||
inputs: [TextConverter]
|
||||
- name: ESRetriever
|
||||
inputs: [Preprocessor]
|
||||
- name: DocumentStore
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: indexing_pipeline_with_classifier
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
- name: Preprocessor
|
||||
inputs: [PDFConverter]
|
||||
- name: IndexTimeDocumentClassifier
|
||||
inputs: [Preprocessor]
|
||||
- name: ESRetriever
|
||||
inputs: [IndexTimeDocumentClassifier]
|
||||
- name: DocumentStore
|
||||
inputs: [ESRetriever]
|
||||
@ -1,31 +0,0 @@
|
||||
version: '1.1.0'
|
||||
|
||||
components:
|
||||
- name: DPRRetriever
|
||||
type: DensePassageRetriever
|
||||
params:
|
||||
document_store: NewFAISSDocumentStore
|
||||
- name: PDFConverter
|
||||
type: PDFToTextConverter
|
||||
params:
|
||||
remove_numeric_tables: false
|
||||
- name: Preprocessor
|
||||
type: PreProcessor
|
||||
params:
|
||||
clean_whitespace: true
|
||||
- name: NewFAISSDocumentStore
|
||||
type: FAISSDocumentStore
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: indexing_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
- name: Preprocessor
|
||||
inputs: [PDFConverter]
|
||||
- name: DPRRetriever
|
||||
inputs: [Preprocessor]
|
||||
- name: NewFAISSDocumentStore
|
||||
inputs: [DPRRetriever]
|
||||
@ -1,19 +0,0 @@
|
||||
version: '1.1.0'
|
||||
|
||||
components:
|
||||
- name: DPRRetriever
|
||||
type: DensePassageRetriever
|
||||
params:
|
||||
document_store: ExistingFAISSDocumentStore
|
||||
- name: ExistingFAISSDocumentStore
|
||||
type: FAISSDocumentStore
|
||||
params:
|
||||
faiss_index_path: 'existing_faiss_document_store'
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: query_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: DPRRetriever
|
||||
inputs: [Query]
|
||||
@ -1,25 +0,0 @@
|
||||
version: '1.1.0'
|
||||
|
||||
components:
|
||||
- name: Reader
|
||||
type: FARMReader
|
||||
params:
|
||||
no_ans_boost: -10
|
||||
model_name_or_path: deepset/minilm-uncased-squad2
|
||||
num_processes: 0
|
||||
- name: Retriever
|
||||
type: TfidfRetriever
|
||||
params:
|
||||
document_store: DocumentStore
|
||||
- name: DocumentStore
|
||||
type: InMemoryDocumentStore
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: query_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: Retriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [Retriever]
|
||||
@ -45,9 +45,10 @@ def exclude_no_answer(responses):
|
||||
@pytest.fixture()
|
||||
def client() -> TestClient:
|
||||
os.environ["PIPELINE_YAML_PATH"] = str(
|
||||
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
|
||||
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.haystack-pipeline.yml").absolute()
|
||||
)
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "test-indexing"
|
||||
os.environ["QUERY_PIPELINE_NAME"] = "test-query"
|
||||
client = TestClient(app)
|
||||
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
@ -217,15 +218,7 @@ def test_query_with_invalid_filter(populated_client: TestClient):
|
||||
assert len(response_json["answers"]) == 0
|
||||
|
||||
|
||||
def test_query_with_no_documents_and_no_answers():
|
||||
os.environ["PIPELINE_YAML_PATH"] = str(
|
||||
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
|
||||
)
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
|
||||
client = TestClient(app)
|
||||
|
||||
# Clean up to make sure the docstore is empty
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
def test_query_with_no_documents_and_no_answers(client: TestClient):
|
||||
query = {"query": "Who made the PDF specification?"}
|
||||
response = client.post(url="/query", json=query)
|
||||
assert 200 == response.status_code
|
||||
|
||||
@ -75,6 +75,9 @@ install_requires =
|
||||
# pip unfortunately backtracks into the databind direction ultimately getting lost.
|
||||
azure-core<1.23
|
||||
|
||||
# TEMPORARY!!!
|
||||
azure-core<1.23.0
|
||||
|
||||
# Preprocessing
|
||||
more_itertools # for windowing
|
||||
python-docx
|
||||
@ -106,6 +109,8 @@ exclude =
|
||||
test*
|
||||
tutorials*
|
||||
ui*
|
||||
include =
|
||||
json-schemas
|
||||
|
||||
|
||||
[options.extras_require]
|
||||
|
||||
0	test/__init__.py (new file)
111	test/conftest.py
@ -1,8 +1,9 @@
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
from subprocess import run
|
||||
from sys import platform
|
||||
import os
|
||||
import gc
|
||||
import uuid
|
||||
import logging
|
||||
@ -15,6 +16,8 @@ import psutil
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from haystack.nodes.base import BaseComponent, MultiLabel
|
||||
|
||||
try:
|
||||
from milvus import Milvus
|
||||
|
||||
@ -27,7 +30,6 @@ try:
|
||||
from elasticsearch import Elasticsearch
|
||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||||
import weaviate
|
||||
|
||||
from haystack.document_stores.weaviate import WeaviateDocumentStore
|
||||
from haystack.document_stores import MilvusDocumentStore
|
||||
from haystack.document_stores.graphdb import GraphDBKnowledgeGraph
|
||||
@ -39,19 +41,15 @@ except (ImportError, ModuleNotFoundError) as ie:
|
||||
|
||||
_optional_component_not_installed("test", "test", ie)
|
||||
|
||||
from haystack.document_stores import BaseDocumentStore, DeepsetCloudDocumentStore, InMemoryDocumentStore
|
||||
|
||||
from haystack.document_stores import DeepsetCloudDocumentStore, InMemoryDocumentStore
|
||||
|
||||
from haystack.nodes import BaseReader, BaseRetriever
|
||||
from haystack.nodes.answer_generator.transformers import Seq2SeqGenerator
|
||||
|
||||
from haystack.nodes.answer_generator.transformers import RAGenerator, RAGeneratorType
|
||||
from haystack.modeling.infer import Inferencer, QAInferencer
|
||||
from haystack.nodes.answer_generator.transformers import RAGenerator
|
||||
from haystack.nodes.ranker import SentenceTransformersRanker
|
||||
from haystack.nodes.document_classifier.transformers import TransformersDocumentClassifier
|
||||
from haystack.nodes.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
|
||||
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
|
||||
from haystack.schema import Document
|
||||
|
||||
from haystack.nodes.reader.farm import FARMReader
|
||||
from haystack.nodes.reader.transformers import TransformersReader
|
||||
from haystack.nodes.reader.table import TableReader, RCIReader
|
||||
@ -59,6 +57,10 @@ from haystack.nodes.summarizer.transformers import TransformersSummarizer
|
||||
from haystack.nodes.translator import TransformersTranslator
|
||||
from haystack.nodes.question_generator import QuestionGenerator
|
||||
|
||||
from haystack.modeling.infer import Inferencer, QAInferencer
|
||||
|
||||
from haystack.schema import Document
|
||||
|
||||
|
||||
# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below
|
||||
SQL_TYPE = "sqlite"
|
||||
@ -96,22 +98,28 @@ def pytest_collection_modifyitems(config, items):
|
||||
|
||||
# add pytest markers for tests that are not explicitly marked but include some keywords
|
||||
# in the test name (e.g. test_elasticsearch_client would get the "elasticsearch" marker)
|
||||
# TODO: evaluate if we need all of these (the non-document-store ones seem to be unused)
|
||||
if "generator" in item.nodeid:
|
||||
item.add_marker(pytest.mark.generator)
|
||||
elif "summarizer" in item.nodeid:
|
||||
item.add_marker(pytest.mark.summarizer)
|
||||
elif "tika" in item.nodeid:
|
||||
item.add_marker(pytest.mark.tika)
|
||||
elif "elasticsearch" in item.nodeid:
|
||||
item.add_marker(pytest.mark.elasticsearch)
|
||||
elif "graphdb" in item.nodeid:
|
||||
item.add_marker(pytest.mark.graphdb)
|
||||
elif "pipeline" in item.nodeid:
|
||||
item.add_marker(pytest.mark.pipeline)
|
||||
elif "slow" in item.nodeid:
|
||||
item.add_marker(pytest.mark.slow)
|
||||
elif "elasticsearch" in item.nodeid:
|
||||
item.add_marker(pytest.mark.elasticsearch)
|
||||
elif "graphdb" in item.nodeid:
|
||||
item.add_marker(pytest.mark.graphdb)
|
||||
elif "weaviate" in item.nodeid:
|
||||
item.add_marker(pytest.mark.weaviate)
|
||||
elif "faiss" in item.nodeid:
|
||||
item.add_marker(pytest.mark.faiss)
|
||||
elif "milvus" in item.nodeid:
|
||||
item.add_marker(pytest.mark.milvus)
|
||||
item.add_marker(pytest.mark.milvus1)
|
||||
|
||||
# if the cli argument "--document_store_type" is used, we want to skip all tests that have markers of other docstores
|
||||
# Example: pytest -v test_document_store.py --document_store_type="memory" => skip all tests marked with "elasticsearch"
|
||||
@ -139,6 +147,81 @@ def pytest_collection_modifyitems(config, items):
|
||||
item.add_marker(skip_milvus)
|
||||
|
||||
|
||||
#
|
||||
# Empty mocks, as a base for unit tests.
|
||||
#
|
||||
# Monkeypatch the methods you need with either a mock implementation
|
||||
# or a unittest.mock.MagicMock object (https://docs.python.org/3/library/unittest.mock.html)
|
||||
#
|
||||
|
||||
|
||||
class MockNode(BaseComponent):
|
||||
outgoing_edges = 1
|
||||
|
||||
def run(self, *a, **k):
|
||||
pass
|
||||
|
||||
|
||||
class MockDocumentStore(BaseDocumentStore):
|
||||
outgoing_edges = 1
|
||||
|
||||
def _create_document_field_map(self, *a, **k):
|
||||
pass
|
||||
|
||||
def delete_documents(self, *a, **k):
|
||||
pass
|
||||
|
||||
def delete_labels(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_all_documents(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_all_documents_generator(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_all_labels(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_document_by_id(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_document_count(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_documents_by_id(self, *a, **k):
|
||||
pass
|
||||
|
||||
def get_label_count(self, *a, **k):
|
||||
pass
|
||||
|
||||
def query_by_embedding(self, *a, **k):
|
||||
pass
|
||||
|
||||
def write_documents(self, *a, **k):
|
||||
pass
|
||||
|
||||
def write_labels(self, *a, **k):
|
||||
pass
|
||||
|
||||
|
||||
class MockRetriever(BaseRetriever):
|
||||
outgoing_edges = 1
|
||||
|
||||
def retrieve(self, query: str, top_k: int):
|
||||
pass
|
||||
|
||||
|
||||
class MockReader(BaseReader):
|
||||
outgoing_edges = 1
|
||||
|
||||
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
pass
|
||||
|
||||
def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
|
||||
pass
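As the comment block above suggests, these mocks are meant to be combined with monkeypatching or `unittest.mock.MagicMock`. An illustrative (invented) test using them might look like this, assuming `MockRetriever` is imported from this conftest:

```python
from unittest.mock import MagicMock


def test_with_a_faked_retriever(monkeypatch):
    retriever = MockRetriever()
    # Replace the stubbed method with a MagicMock so calls can be asserted on:
    monkeypatch.setattr(retriever, "retrieve", MagicMock(return_value=[]))

    assert retriever.retrieve(query="who?", top_k=3) == []
    retriever.retrieve.assert_called_once_with(query="who?", top_k=3)
```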
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def gc_cleanup(request):
|
||||
"""
|
||||
@ -295,7 +378,7 @@ def deepset_cloud_document_store(deepset_cloud_fixture):
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def rag_generator():
|
||||
return RAGenerator(model_name_or_path="facebook/rag-token-nq", generator_type=RAGeneratorType.TOKEN, max_length=20)
|
||||
return RAGenerator(model_name_or_path="facebook/rag-token-nq", generator_type="token", max_length=20)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
{
|
||||
"version": "0.9",
|
||||
"version": "unstable",
|
||||
"name": "document_retrieval_1",
|
||||
"components": [
|
||||
{
|
||||
@ -40,7 +40,6 @@
|
||||
"pipelines": [
|
||||
{
|
||||
"name": "query",
|
||||
"type": "Query",
|
||||
"nodes": [
|
||||
{
|
||||
"name": "Retriever",
|
||||
@ -52,7 +51,6 @@
|
||||
},
|
||||
{
|
||||
"name": "indexing",
|
||||
"type": "Indexing",
|
||||
"nodes": [
|
||||
{
|
||||
"name": "TextFileConverter",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components:
|
||||
- name: Reader
|
||||
@ -11,7 +11,6 @@ components:
|
||||
type: ElasticsearchRetriever
|
||||
params:
|
||||
document_store: DocumentStore
|
||||
custom_query: null
|
||||
- name: DocumentStore
|
||||
type: ElasticsearchDocumentStore
|
||||
params:
|
||||
@ -29,33 +28,22 @@ components:
|
||||
type: TransformersDocumentClassifier
|
||||
params:
|
||||
batch_size: 16
|
||||
use_gpu: -1
|
||||
use_gpu: false
|
||||
- name: QueryTimeDocumentClassifier
|
||||
type: TransformersDocumentClassifier
|
||||
params:
|
||||
use_gpu: -1
|
||||
use_gpu: false
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: query_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: ray_query_pipeline
|
||||
type: RayPipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
replicas: 2
|
||||
inputs: [ Query ]
|
||||
- name: Reader
|
||||
inputs: [ ESRetriever ]
|
||||
|
||||
- name: query_pipeline_with_document_classifier
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
@ -65,7 +53,6 @@ pipelines:
|
||||
inputs: [QueryTimeDocumentClassifier]
|
||||
|
||||
- name: indexing_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
@ -77,7 +64,6 @@ pipelines:
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: indexing_text_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: TextConverter
|
||||
inputs: [File]
|
||||
@ -89,7 +75,6 @@ pipelines:
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: indexing_pipeline_with_classifier
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components:
|
||||
- name: DPRRetriever
|
||||
@ -19,7 +19,6 @@ components:
|
||||
|
||||
pipelines:
|
||||
- name: indexing_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components:
|
||||
- name: DPRRetriever
|
||||
@ -13,7 +13,6 @@ components:
|
||||
|
||||
pipelines:
|
||||
- name: query_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: DPRRetriever
|
||||
inputs: [Query]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
version: '1.1.0'
|
||||
version: 'unstable'
|
||||
|
||||
components:
|
||||
- name: Reader
|
||||
@ -17,7 +17,6 @@ components:
|
||||
|
||||
pipelines:
|
||||
- name: query_pipeline
|
||||
type: Pipeline
|
||||
nodes:
|
||||
- name: Retriever
|
||||
inputs: [Query]
|
||||
|
||||
46	test/samples/pipeline/test_ray_pipeline.yaml (new file)
@ -0,0 +1,46 @@
|
||||
version: 'unstable'
|
||||
|
||||
components:
|
||||
- name: Reader
|
||||
type: FARMReader
|
||||
params:
|
||||
no_ans_boost: -10
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
num_processes: 0
|
||||
- name: ESRetriever
|
||||
type: ElasticsearchRetriever
|
||||
params:
|
||||
document_store: DocumentStore
|
||||
- name: DocumentStore
|
||||
type: ElasticsearchDocumentStore
|
||||
params:
|
||||
index: haystack_test
|
||||
label_index: haystack_test_label
|
||||
- name: PDFConverter
|
||||
type: PDFToTextConverter
|
||||
params:
|
||||
remove_numeric_tables: false
|
||||
- name: Preprocessor
|
||||
type: PreProcessor
|
||||
params:
|
||||
clean_whitespace: true
|
||||
- name: IndexTimeDocumentClassifier
|
||||
type: TransformersDocumentClassifier
|
||||
params:
|
||||
batch_size: 16
|
||||
use_gpu: false
|
||||
- name: QueryTimeDocumentClassifier
|
||||
type: TransformersDocumentClassifier
|
||||
params:
|
||||
use_gpu: false
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: ray_query_pipeline
|
||||
type: RayPipeline
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
replicas: 2
|
||||
inputs: [ Query ]
|
||||
- name: Reader
|
||||
inputs: [ ESRetriever ]
|
||||
@ -3,7 +3,7 @@ from haystack.nodes import FARMReader
|
||||
from haystack.modeling.data_handler.processor import UnlabeledTextProcessor
|
||||
import torch
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
def create_checkpoint(model):
|
||||
|
||||
@ -8,7 +8,7 @@ from unittest.mock import Mock
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.exceptions import RequestError
|
||||
|
||||
from conftest import (
|
||||
from .conftest import (
|
||||
deepset_cloud_fixture,
|
||||
get_document_store,
|
||||
MOCK_DC,
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.document_stores.memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.nodes.answer_generator.transformers import RAGenerator, RAGeneratorType
|
||||
from haystack.nodes.retriever.dense import EmbeddingRetriever
|
||||
from haystack.nodes.preprocessor import PreProcessor
|
||||
from haystack.nodes.evaluator import EvalAnswers, EvalDocuments
|
||||
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
|
||||
@ -19,10 +16,9 @@ from haystack.pipelines.standard_pipelines import (
|
||||
RetrieverQuestionGenerationPipeline,
|
||||
TranslationWrapperPipeline,
|
||||
)
|
||||
from haystack.nodes.summarizer.transformers import TransformersSummarizer
|
||||
from haystack.schema import Answer, Document, EvaluationResult, Label, MultiLabel, Span
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner")
|
||||
|
||||
@ -13,7 +13,7 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore
|
||||
from haystack.pipelines import Pipeline
|
||||
from haystack.nodes.retriever.dense import EmbeddingRetriever
|
||||
|
||||
from conftest import ensure_ids_are_correct_uuids
|
||||
from .conftest import ensure_ids_are_correct_uuids
|
||||
|
||||
|
||||
DOCUMENTS = [
|
||||
|
||||
@ -13,7 +13,7 @@ from haystack.nodes import (
|
||||
ParsrConverter,
|
||||
)
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
@pytest.mark.tika
|
||||
|
||||
@ -3,9 +3,9 @@ from pathlib import Path
|
||||
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
|
||||
|
||||
|
||||
def test_filetype_classifier_single_file(tmpdir):
|
||||
def test_filetype_classifier_single_file(tmp_path):
|
||||
node = FileTypeClassifier()
|
||||
test_files = [tmpdir / f"test.{extension}" for extension in DEFAULT_TYPES]
|
||||
test_files = [tmp_path / f"test.{extension}" for extension in DEFAULT_TYPES]
|
||||
|
||||
for edge_index, test_file in enumerate(test_files):
|
||||
output, edge = node.run(test_file)
|
||||
@ -13,35 +13,35 @@ def test_filetype_classifier_single_file(tmpdir):
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
|
||||
def test_filetype_classifier_many_files(tmpdir):
|
||||
def test_filetype_classifier_many_files(tmp_path):
|
||||
node = FileTypeClassifier()
|
||||
|
||||
for edge_index, extension in enumerate(DEFAULT_TYPES):
|
||||
test_files = [tmpdir / f"test_{idx}.{extension}" for idx in range(10)]
|
||||
test_files = [tmp_path / f"test_{idx}.{extension}" for idx in range(10)]
|
||||
|
||||
output, edge = node.run(test_files)
|
||||
assert edge == f"output_{edge_index+1}"
|
||||
assert output == {"file_paths": test_files}
|
||||
|
||||
|
||||
def test_filetype_classifier_many_files_mixed_extensions(tmpdir):
|
||||
def test_filetype_classifier_many_files_mixed_extensions(tmp_path):
|
||||
node = FileTypeClassifier()
|
||||
test_files = [tmpdir / f"test.{extension}" for extension in DEFAULT_TYPES]
|
||||
test_files = [tmp_path / f"test.{extension}" for extension in DEFAULT_TYPES]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
node.run(test_files)
|
||||
|
||||
|
||||
def test_filetype_classifier_unsupported_extension(tmpdir):
|
||||
def test_filetype_classifier_unsupported_extension(tmp_path):
|
||||
node = FileTypeClassifier()
|
||||
test_file = tmpdir / f"test.really_weird_extension"
|
||||
test_file = tmp_path / f"test.really_weird_extension"
|
||||
with pytest.raises(ValueError):
|
||||
node.run(test_file)
|
||||
|
||||
|
||||
def test_filetype_classifier_custom_extensions(tmpdir):
|
||||
def test_filetype_classifier_custom_extensions(tmp_path):
|
||||
node = FileTypeClassifier(supported_types=["my_extension"])
|
||||
test_file = tmpdir / f"test.my_extension"
|
||||
test_file = tmp_path / f"test.my_extension"
|
||||
output, edge = node.run(test_file)
|
||||
assert edge == f"output_1"
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
@ -9,7 +9,7 @@ from haystack.nodes.answer_generator import Seq2SeqGenerator
|
||||
from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline
|
||||
|
||||
|
||||
from conftest import DOCS_WITH_EMBEDDINGS
|
||||
from .conftest import DOCS_WITH_EMBEDDINGS
|
||||
|
||||
|
||||
# Keeping few (retriever,document_store) combination to reduce test time
|
||||
|
||||
@ -603,7 +603,7 @@ def test_dpr_context_only():
|
||||
assert tensor_names == ["passage_input_ids", "passage_segment_ids", "passage_attention_mask", "label_ids"]
|
||||
|
||||
|
||||
def test_dpr_processor_save_load():
|
||||
def test_dpr_processor_save_load(tmp_path):
|
||||
d = {
|
||||
"query": "big little lies season 2 how many episodes ?",
|
||||
"passages": [
|
||||
@ -646,9 +646,9 @@ def test_dpr_processor_save_load():
|
||||
metric="text_similarity_metric",
|
||||
shuffle_negatives=False,
|
||||
)
|
||||
processor.save(save_dir="testsave/dpr_processor")
|
||||
processor.save(save_dir=f"{tmp_path}/testsave/dpr_processor")
|
||||
dataset, tensor_names, _ = processor.dataset_from_dicts(dicts=[d], return_baskets=False)
|
||||
loadedprocessor = TextSimilarityProcessor.load_from_dir(load_dir="testsave/dpr_processor")
|
||||
loadedprocessor = TextSimilarityProcessor.load_from_dir(load_dir=f"{tmp_path}/testsave/dpr_processor")
|
||||
dataset2, tensor_names, _ = loadedprocessor.dataset_from_dicts(dicts=[d], return_baskets=False)
|
||||
assert np.array_equal(dataset.tensors[0], dataset2.tensors[0])
|
||||
|
||||
@ -667,7 +667,7 @@ def test_dpr_processor_save_load():
|
||||
{"query": "facebook/dpr-question_encoder-single-nq-base", "passage": "facebook/dpr-ctx_encoder-single-nq-base"},
|
||||
],
|
||||
)
|
||||
def test_dpr_processor_save_load_non_bert_tokenizer(query_and_passage_model):
|
||||
def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_model):
|
||||
"""
|
||||
This test compares 1) a model that was loaded from model hub with
|
||||
2) a model from model hub that was saved to disk and then loaded from disk and
|
||||
@ -729,7 +729,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(query_and_passage_model):
|
||||
model.connect_heads_with_processor(processor.tasks, require_labels=False)
|
||||
|
||||
# save model that was loaded from model hub to disk
|
||||
save_dir = "testsave/dpr_model"
|
||||
save_dir = f"{tmp_path}/testsave/dpr_model"
|
||||
query_encoder_dir = "query_encoder"
|
||||
passage_encoder_dir = "passage_encoder"
|
||||
model.save(Path(save_dir), lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir)
|
||||
@ -841,7 +841,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(query_and_passage_model):
|
||||
assert np.array_equal(all_embeddings["query"][0], all_embeddings2["query"][0])
|
||||
|
||||
# save the model that was loaded from disk to disk
|
||||
save_dir = "testsave/dpr_model"
|
||||
save_dir = f"{tmp_path}/testsave/dpr_model"
|
||||
query_encoder_dir = "query_encoder"
|
||||
passage_encoder_dir = "passage_encoder"
|
||||
loaded_model.save(Path(save_dir), lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir)
|
||||
|
||||
@ -6,7 +6,7 @@ from transformers import AutoTokenizer
|
||||
from haystack.modeling.data_handler.processor import SquadProcessor
|
||||
from haystack.modeling.model.tokenization import Tokenizer
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
# during inference (parameter return_baskets = False) we do not convert labels
|
||||
|
||||
@ -6,10 +6,10 @@ from haystack.modeling.model.tokenization import Tokenizer
|
||||
from haystack.modeling.utils import set_all_seeds
|
||||
import torch
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
def test_processor_saving_loading(caplog):
|
||||
def test_processor_saving_loading(tmp_path, caplog):
|
||||
if caplog is not None:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
|
||||
@ -31,7 +31,7 @@ def test_processor_saving_loading(caplog):
|
||||
dicts = processor.file_to_dicts(file=SAMPLES_PATH / "qa" / "dev-sample.json")
|
||||
data, tensor_names, _ = processor.dataset_from_dicts(dicts=dicts, indices=[1])
|
||||
|
||||
save_dir = Path("testsave/processor")
|
||||
save_dir = tmp_path / Path("testsave/processor")
|
||||
processor.save(save_dir)
|
||||
|
||||
processor = processor.load_from_dir(save_dir)
|
||||
|
||||
@ -9,6 +9,7 @@ import pandas as pd
|
||||
import pytest
|
||||
from requests import PreparedRequest
|
||||
import responses
|
||||
import logging
|
||||
import yaml
|
||||
|
||||
from haystack import __version__, Document, Answer, JoinAnswers
|
||||
@ -19,21 +20,33 @@ from haystack.nodes.other.join_docs import JoinDocuments
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.nodes.retriever.sparse import ElasticsearchRetriever
|
||||
from haystack.pipelines import Pipeline, DocumentSearchPipeline, RootNode, ExtractiveQAPipeline
|
||||
from haystack.pipelines.config import _validate_user_input, validate_config
|
||||
from haystack.pipelines import Pipeline, DocumentSearchPipeline, RootNode
|
||||
from haystack.pipelines.config import validate_config_strings
|
||||
from haystack.pipelines.utils import generate_code
|
||||
from haystack.errors import PipelineConfigError
|
||||
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever, RouteDocuments, PreProcessor, TextConverter
|
||||
|
||||
from conftest import MOCK_DC, DC_API_ENDPOINT, DC_API_KEY, DC_TEST_INDEX, SAMPLES_PATH, deepset_cloud_fixture
|
||||
from haystack.utils.deepsetcloud import DeepsetCloudError
|
||||
|
||||
from .conftest import (
|
||||
MOCK_DC,
|
||||
DC_API_ENDPOINT,
|
||||
DC_API_KEY,
|
||||
DC_TEST_INDEX,
|
||||
SAMPLES_PATH,
|
||||
MockDocumentStore,
|
||||
MockRetriever,
|
||||
deepset_cloud_fixture,
|
||||
)
|
||||
|
||||
|
||||
class ParentComponent(BaseComponent):
|
||||
outgoing_edges = 1
|
||||
|
||||
def __init__(self, dependent: BaseComponent) -> None:
|
||||
super().__init__()
|
||||
self.set_config(dependent=dependent)
|
||||
|
||||
def run(*args, **kwargs):
|
||||
logging.info("ParentComponent run() was called")
|
||||
|
||||
|
||||
class ParentComponent2(BaseComponent):
|
||||
@ -41,157 +54,38 @@ class ParentComponent2(BaseComponent):
|
||||
|
||||
def __init__(self, dependent: BaseComponent) -> None:
|
||||
super().__init__()
|
||||
self.set_config(dependent=dependent)
|
||||
|
||||
def run(*args, **kwargs):
|
||||
logging.info("ParentComponent2 run() was called")
|
||||
|
||||
|
||||
class ChildComponent(BaseComponent):
|
||||
def __init__(self, some_key: str = None) -> None:
|
||||
super().__init__()
|
||||
self.set_config(some_key=some_key)
|
||||
|
||||
def run(*args, **kwargs):
|
||||
logging.info("ChildComponent run() was called")
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
def test_load_and_save_yaml(document_store, tmp_path):
|
||||
# test correct load of indexing pipeline from yaml
|
||||
pipeline = Pipeline.load_from_yaml(
|
||||
SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="indexing_pipeline"
|
||||
)
|
||||
pipeline.run(file_paths=SAMPLES_PATH / "pdf" / "sample_pdf_1.pdf")
|
||||
# test correct load of query pipeline from yaml
|
||||
pipeline = Pipeline.load_from_yaml(SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="query_pipeline")
|
||||
prediction = pipeline.run(
|
||||
query="Who made the PDF specification?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}
|
||||
)
|
||||
assert prediction["query"] == "Who made the PDF specification?"
|
||||
assert prediction["answers"][0].answer == "Adobe Systems"
|
||||
assert "_debug" not in prediction.keys()
|
||||
class DummyRetriever(MockRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
# test invalid pipeline name
|
||||
with pytest.raises(Exception):
|
||||
Pipeline.load_from_yaml(path=SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="invalid")
|
||||
# test config export
|
||||
pipeline.save_to_yaml(tmp_path / "test.yaml")
|
||||
with open(tmp_path / "test.yaml", "r", encoding="utf-8") as stream:
|
||||
saved_yaml = stream.read()
|
||||
expected_yaml = f"""
|
||||
components:
|
||||
- name: ESRetriever
|
||||
params:
|
||||
document_store: ElasticsearchDocumentStore
|
||||
type: ElasticsearchRetriever
|
||||
- name: ElasticsearchDocumentStore
|
||||
params:
|
||||
index: haystack_test
|
||||
label_index: haystack_test_label
|
||||
type: ElasticsearchDocumentStore
|
||||
- name: Reader
|
||||
params:
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
no_ans_boost: -10
|
||||
num_processes: 0
|
||||
type: FARMReader
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- inputs:
|
||||
- Query
|
||||
name: ESRetriever
|
||||
- inputs:
|
||||
- ESRetriever
|
||||
name: Reader
|
||||
type: Pipeline
|
||||
version: {__version__}
|
||||
"""
|
||||
assert saved_yaml.replace(" ", "").replace("\n", "") == expected_yaml.replace(" ", "").replace("\n", "")
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
def test_load_and_save_yaml_prebuilt_pipelines(document_store, tmp_path):
|
||||
# populating index
|
||||
pipeline = Pipeline.load_from_yaml(
|
||||
SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="indexing_pipeline"
|
||||
)
|
||||
pipeline.run(file_paths=SAMPLES_PATH / "pdf" / "sample_pdf_1.pdf")
|
||||
# test correct load of query pipeline from yaml
|
||||
pipeline = ExtractiveQAPipeline.load_from_yaml(
|
||||
SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="query_pipeline"
|
||||
)
|
||||
prediction = pipeline.run(
|
||||
query="Who made the PDF specification?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}
|
||||
)
|
||||
assert prediction["query"] == "Who made the PDF specification?"
|
||||
assert prediction["answers"][0].answer == "Adobe Systems"
|
||||
assert "_debug" not in prediction.keys()
|
||||
|
||||
# test invalid pipeline name
|
||||
with pytest.raises(Exception):
|
||||
ExtractiveQAPipeline.load_from_yaml(
|
||||
path=SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="invalid"
|
||||
)
|
||||
# test config export
|
||||
pipeline.save_to_yaml(tmp_path / "test.yaml")
|
||||
with open(tmp_path / "test.yaml", "r", encoding="utf-8") as stream:
|
||||
saved_yaml = stream.read()
|
||||
expected_yaml = f"""
|
||||
components:
|
||||
- name: ESRetriever
|
||||
params:
|
||||
document_store: ElasticsearchDocumentStore
|
||||
type: ElasticsearchRetriever
|
||||
- name: ElasticsearchDocumentStore
|
||||
params:
|
||||
index: haystack_test
|
||||
label_index: haystack_test_label
|
||||
type: ElasticsearchDocumentStore
|
||||
- name: Reader
|
||||
params:
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
no_ans_boost: -10
|
||||
num_processes: 0
|
||||
type: FARMReader
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- inputs:
|
||||
- Query
|
||||
name: ESRetriever
|
||||
- inputs:
|
||||
- ESRetriever
|
||||
name: Reader
|
||||
type: Pipeline
|
||||
version: {__version__}
|
||||
"""
|
||||
assert saved_yaml.replace(" ", "").replace("\n", "") == expected_yaml.replace(" ", "").replace("\n", "")
|
||||
|
||||
|
||||
def test_load_tfidfretriever_yaml(tmp_path):
|
||||
documents = [
|
||||
{
|
||||
"content": "A Doc specifically talking about haystack. Haystack can be used to scale QA models to large document collections."
|
||||
}
|
||||
]
|
||||
pipeline = Pipeline.load_from_yaml(
|
||||
SAMPLES_PATH / "pipeline" / "test_pipeline_tfidfretriever.yaml", pipeline_name="query_pipeline"
|
||||
)
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
pipeline.run(
|
||||
query="What can be used to scale QA models to large document collections?",
|
||||
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}},
|
||||
)
|
||||
exception_raised = str(exc_info.value)
|
||||
assert "Retrieval requires dataframe df and tf-idf matrix" in exception_raised
|
||||
|
||||
pipeline.get_node(name="Retriever").document_store.write_documents(documents=documents)
|
||||
prediction = pipeline.run(
|
||||
query="What can be used to scale QA models to large document collections?",
|
||||
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}},
|
||||
)
|
||||
assert prediction["query"] == "What can be used to scale QA models to large document collections?"
|
||||
assert prediction["answers"][0].answer == "haystack"
|
||||
class JoinNode(RootNode):
|
||||
def run(self, output=None, inputs=None):
|
||||
if inputs:
|
||||
output = ""
|
||||
for input_dict in inputs:
|
||||
output += input_dict["output"]
|
||||
return {"output": output}, "output_1"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.elasticsearch
|
||||
def test_to_code_creates_same_pipelines():
|
||||
index_pipeline = Pipeline.load_from_yaml(
|
||||
@ -202,6 +96,7 @@ def test_to_code_creates_same_pipelines():
|
||||
)
|
||||
query_pipeline_code = query_pipeline.to_code(pipeline_variable_name="query_pipeline_from_code")
|
||||
index_pipeline_code = index_pipeline.to_code(pipeline_variable_name="index_pipeline_from_code")
|
||||
|
||||
exec(query_pipeline_code)
|
||||
exec(index_pipeline_code)
|
||||
assert locals()["query_pipeline_from_code"] is not None
|
||||
@ -216,8 +111,7 @@ def test_get_config_creates_dependent_component():
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=parent, name="parent", inputs=["Query"])
|
||||
|
||||
expected_pipelines = [{"name": "query", "type": "Pipeline", "nodes": [{"name": "parent", "inputs": ["Query"]}]}]
|
||||
|
||||
expected_pipelines = [{"name": "query", "nodes": [{"name": "parent", "inputs": ["Query"]}]}]
|
||||
expected_components = [
|
||||
{"name": "parent", "type": "ParentComponent", "params": {"dependent": "ChildComponent"}},
|
||||
{"name": "ChildComponent", "type": "ChildComponent", "params": {}},
|
||||
@ -249,7 +143,6 @@ def test_get_config_creates_only_one_dependent_component_referenced_by_multiple_
|
||||
expected_pipelines = [
|
||||
{
|
||||
"name": "query",
|
||||
"type": "Pipeline",
|
||||
"nodes": [
|
||||
{"name": "Parent1", "inputs": ["Query"]},
|
||||
{"name": "Parent2", "inputs": ["Query"]},
|
||||
@ -286,7 +179,6 @@ def test_get_config_creates_two_different_dependent_components_of_same_type():
|
||||
expected_pipelines = [
|
||||
{
|
||||
"name": "query",
|
||||
"type": "Pipeline",
|
||||
"nodes": [
|
||||
{"name": "ParentA", "inputs": ["Query"]},
|
||||
{"name": "ParentB", "inputs": ["Query"]},
|
||||
@ -302,8 +194,34 @@ def test_get_config_creates_two_different_dependent_components_of_same_type():
|
||||
assert expected_component in config["components"]
|
||||
|
||||
|
||||
def test_get_config_component_with_superclass_arguments():
|
||||
class CustomBaseDocumentStore(MockDocumentStore):
|
||||
def __init__(self, base_parameter: str):
|
||||
self.base_parameter = base_parameter
|
||||
|
||||
class CustomDocumentStore(CustomBaseDocumentStore):
|
||||
def __init__(self, sub_parameter: int):
|
||||
super().__init__(base_parameter="something")
|
||||
self.sub_parameter = sub_parameter
|
||||
|
||||
class CustomRetriever(MockRetriever):
|
||||
def __init__(self, document_store):
|
||||
super().__init__()
|
||||
self.document_store = document_store
|
||||
|
||||
document_store = CustomDocumentStore(sub_parameter=10)
|
||||
retriever = CustomRetriever(document_store=document_store)
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(retriever, name="Retriever", inputs=["Query"])
|
||||
|
||||
pipeline.get_config()
|
||||
assert pipeline.get_document_store().sub_parameter == 10
|
||||
assert pipeline.get_document_store().base_parameter == "something"
|
||||
|
||||
|
||||
def test_generate_code_simple_pipeline():
|
||||
config = {
|
||||
"version": "unstable",
|
||||
"components": [
|
||||
{
|
||||
"name": "retri",
|
||||
@ -316,7 +234,7 @@ def test_generate_code_simple_pipeline():
|
||||
"params": {"index": "my-index"},
|
||||
},
|
||||
],
|
||||
"pipelines": [{"name": "query", "type": "Pipeline", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
"pipelines": [{"name": "query", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
}
|
||||
|
||||
code = generate_code(pipeline_config=config, pipeline_variable_name="p", generate_imports=False)
|
||||
@ -331,15 +249,15 @@ def test_generate_code_simple_pipeline():
|
||||
|
||||
def test_generate_code_imports():
|
||||
pipeline_config = {
|
||||
"version": "unstable",
|
||||
"components": [
|
||||
{"name": "DocumentStore", "type": "ElasticsearchDocumentStore"},
|
||||
{"name": "retri", "type": "ElasticsearchRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
{"name": "retri2", "type": "EmbeddingRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
{"name": "retri2", "type": "TfidfRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
],
|
||||
"pipelines": [
|
||||
{
|
||||
"name": "Query",
|
||||
"type": "Pipeline",
|
||||
"nodes": [{"name": "retri", "inputs": ["Query"]}, {"name": "retri2", "inputs": ["Query"]}],
|
||||
}
|
||||
],
|
||||
@ -348,12 +266,12 @@ def test_generate_code_imports():
|
||||
code = generate_code(pipeline_config=pipeline_config, pipeline_variable_name="p", generate_imports=True)
|
||||
assert code == (
|
||||
"from haystack.document_stores import ElasticsearchDocumentStore\n"
|
||||
"from haystack.nodes import ElasticsearchRetriever, EmbeddingRetriever\n"
|
||||
"from haystack.nodes import ElasticsearchRetriever, TfidfRetriever\n"
|
||||
"from haystack.pipelines import Pipeline\n"
|
||||
"\n"
|
||||
"document_store = ElasticsearchDocumentStore()\n"
|
||||
"retri = ElasticsearchRetriever(document_store=document_store)\n"
|
||||
"retri_2 = EmbeddingRetriever(document_store=document_store)\n"
|
||||
"retri_2 = TfidfRetriever(document_store=document_store)\n"
|
||||
"\n"
|
||||
"p = Pipeline()\n"
|
||||
'p.add_node(component=retri, name="retri", inputs=["Query"])\n'
|
||||
@ -363,11 +281,12 @@ def test_generate_code_imports():
|
||||
|
||||
def test_generate_code_imports_no_pipeline_cls():
|
||||
pipeline_config = {
|
||||
"version": "unstable",
|
||||
"components": [
|
||||
{"name": "DocumentStore", "type": "ElasticsearchDocumentStore"},
|
||||
{"name": "retri", "type": "ElasticsearchRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
],
|
||||
"pipelines": [{"name": "Query", "type": "Pipeline", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
"pipelines": [{"name": "Query", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
}
|
||||
|
||||
code = generate_code(
|
||||
@ -390,11 +309,12 @@ def test_generate_code_imports_no_pipeline_cls():
|
||||
|
||||
def test_generate_code_comment():
|
||||
pipeline_config = {
|
||||
"version": "unstable",
|
||||
"components": [
|
||||
{"name": "DocumentStore", "type": "ElasticsearchDocumentStore"},
|
||||
{"name": "retri", "type": "ElasticsearchRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
],
|
||||
"pipelines": [{"name": "Query", "type": "Pipeline", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
"pipelines": [{"name": "Query", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
}
|
||||
|
||||
comment = "This is my comment\n...and here is a new line"
|
||||
@ -416,17 +336,17 @@ def test_generate_code_comment():
|
||||
|
||||
def test_generate_code_is_component_order_invariant():
|
||||
pipeline_config = {
|
||||
"version": "unstable",
|
||||
"pipelines": [
|
||||
{
|
||||
"name": "Query",
|
||||
"type": "Pipeline",
|
||||
"nodes": [
|
||||
{"name": "EsRetriever", "inputs": ["Query"]},
|
||||
{"name": "EmbeddingRetriever", "inputs": ["Query"]},
|
||||
{"name": "JoinResults", "inputs": ["EsRetriever", "EmbeddingRetriever"]},
|
||||
],
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
doc_store = {"name": "ElasticsearchDocumentStore", "type": "ElasticsearchDocumentStore"}
|
||||
@ -471,52 +391,45 @@ def test_generate_code_is_component_order_invariant():
|
||||
|
||||
@pytest.mark.parametrize("input", ["\btest", " test", "#test", "+test", "\ttest", "\ntest", "test()"])
|
||||
def test_validate_user_input_invalid(input):
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
_validate_user_input(input)
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings(input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input", ["test", "testName", "test_name", "test-name", "test-name1234", "http://localhost:8000/my-path"]
|
||||
)
|
||||
def test_validate_user_input_valid(input):
|
||||
_validate_user_input(input)
|
||||
validate_config_strings(input)
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_component_name():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config({"components": [{"name": "\btest"}]})
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings({"components": [{"name": "\btest"}]})
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_component_type():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config({"components": [{"name": "test", "type": "\btest"}]})
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings({"components": [{"name": "test", "type": "\btest"}]})
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_component_param():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config({"components": [{"name": "test", "type": "test", "params": {"key": "\btest"}}]})
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings({"components": [{"name": "test", "type": "test", "params": {"key": "\btest"}}]})
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_component_param_key():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config({"components": [{"name": "test", "type": "test", "params": {"\btest": "test"}}]})
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings({"components": [{"name": "test", "type": "test", "params": {"\btest": "test"}}]})
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_pipeline_name():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config({"components": [{"name": "test", "type": "test"}], "pipelines": [{"name": "\btest"}]})
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_pipeline_type():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config(
|
||||
{"components": [{"name": "test", "type": "test"}], "pipelines": [{"name": "test", "type": "\btest"}]}
|
||||
)
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings({"components": [{"name": "test", "type": "test"}], "pipelines": [{"name": "\btest"}]})
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_pipeline_node_name():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config(
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings(
|
||||
{
|
||||
"components": [{"name": "test", "type": "test"}],
|
||||
"pipelines": [{"name": "test", "type": "test", "nodes": [{"name": "\btest"}]}],
|
||||
@ -525,8 +438,8 @@ def test_validate_pipeline_config_invalid_pipeline_node_name():
|
||||
|
||||
|
||||
def test_validate_pipeline_config_invalid_pipeline_node_inputs():
|
||||
with pytest.raises(ValueError, match="is not a valid config variable name"):
|
||||
validate_config(
|
||||
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
|
||||
validate_config_strings(
|
||||
{
|
||||
"components": [{"name": "test", "type": "test"}],
|
||||
"pipelines": [{"name": "test", "type": "test", "nodes": [{"name": "test", "inputs": ["\btest"]}]}],
|
||||
@ -534,6 +447,15 @@ def test_validate_pipeline_config_invalid_pipeline_node_inputs():
|
||||
)
|
||||
|
||||
|
||||
def test_validate_pipeline_config_recursive_config():
|
||||
pipeline_config = {}
|
||||
node = {"config": pipeline_config}
|
||||
pipeline_config["node"] = node
|
||||
|
||||
with pytest.raises(PipelineConfigError, match="recursive"):
|
||||
validate_config_strings(pipeline_config)
|
||||
|
||||
|
||||
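The tests above pin down the behaviour of the new validate_config_strings() check, which replaces _validate_user_input()/validate_config(): it must reject values containing whitespace, control characters or call syntax, and it must fail fast on self-referencing configs instead of recursing forever. A minimal sketch of such a walk, assuming a simple allow-list regex and ancestor tracking (illustrative only, not the actual Haystack implementation):

# Illustrative sketch only, NOT haystack.pipelines.config.validate_config_strings.
import re
from typing import Any, FrozenSet

VALID_VALUE = re.compile(r"^[-\w/.:]*$")  # assumption: letters, digits, _, -, /, ., : are allowed


class ConfigError(ValueError):
    """Stand-in for haystack.errors.PipelineConfigError."""


def validate_strings(obj: Any, _ancestors: FrozenSet[int] = frozenset()) -> None:
    if id(obj) in _ancestors:
        raise ConfigError("The pipeline configuration is recursive: an object contains itself.")
    if isinstance(obj, dict):
        for key, value in obj.items():
            validate_strings(key, _ancestors | {id(obj)})
            validate_strings(value, _ancestors | {id(obj)})
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            validate_strings(item, _ancestors | {id(obj)})
    elif not VALID_VALUE.match(str(obj)):
        raise ConfigError(f"'{obj}' is not a valid variable name or value.")
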
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__)
|
||||
@responses.activate
|
||||
def test_load_from_deepset_cloud_query():
|
||||
@ -1212,46 +1134,42 @@ def test_undeploy_on_deepset_cloud_timeout():
|
||||
)
|
||||
|
||||
|
||||
# @pytest.mark.slow
# @pytest.mark.elasticsearch
# @pytest.mark.parametrize(
#     "retriever_with_docs, document_store_with_docs",
#     [("elasticsearch", "elasticsearch")],
#     indirect=True,
# )
@pytest.mark.parametrize(
    "retriever_with_docs,document_store_with_docs",
    [
        ("dpr", "elasticsearch"),
        ("dpr", "faiss"),
        ("dpr", "memory"),
        ("dpr", "milvus1"),
        ("embedding", "elasticsearch"),
        ("embedding", "faiss"),
        ("embedding", "memory"),
        ("embedding", "milvus1"),
        ("elasticsearch", "elasticsearch"),
        ("es_filter_only", "elasticsearch"),
        ("tfidf", "memory"),
    ],
    indirect=True,
)
def test_graph_creation(retriever_with_docs, document_store_with_docs):
def test_graph_creation_invalid_edge():
    docstore = MockDocumentStore()
    retriever = DummyRetriever(document_store=docstore)
    pipeline = Pipeline()
    pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])
    pipeline.add_node(name="DocStore", component=docstore, inputs=["Query"])

    with pytest.raises(AssertionError):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.output_2"])
    with pytest.raises(PipelineConfigError, match="'output_2' from 'DocStore'"):
        pipeline.add_node(name="Retriever", component=retriever, inputs=["DocStore.output_2"])

    with pytest.raises(AssertionError):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"])

    with pytest.raises(Exception):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])
def test_graph_creation_non_existing_edge():
    docstore = MockDocumentStore()
    retriever = DummyRetriever(document_store=docstore)
    pipeline = Pipeline()
    pipeline.add_node(name="DocStore", component=docstore, inputs=["Query"])

    with pytest.raises(Exception):
        pipeline = Pipeline()
        pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["InvalidNode"])
    with pytest.raises(PipelineConfigError, match="'wrong_edge_label' is not a valid edge name"):
        pipeline.add_node(name="Retriever", component=retriever, inputs=["DocStore.wrong_edge_label"])


def test_graph_creation_invalid_node():
    docstore = MockDocumentStore()
    retriever = DummyRetriever(document_store=docstore)
    pipeline = Pipeline()
    pipeline.add_node(name="DocStore", component=docstore, inputs=["Query"])

    with pytest.raises(PipelineConfigError, match="Cannot find node 'InvalidNode'"):
        pipeline.add_node(name="Retriever", component=retriever, inputs=["InvalidNode"])


def test_graph_creation_invalid_root_node():
    docstore = MockDocumentStore()
    pipeline = Pipeline()

    with pytest.raises(PipelineConfigError, match="Root node 'InvalidNode' is invalid"):
        pipeline.add_node(name="DocStore", component=docstore, inputs=["InvalidNode"])

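For contrast with the failure cases above, this is roughly what a well-formed wiring looks like: every node hangs off the root ("Query") or an existing node, and an explicit edge reference names an edge the upstream node actually exposes. A hypothetical companion test (not part of the PR), assuming the default single-output edge is called output_1 as the error messages above imply:

def test_graph_creation_valid_wiring_sketch():
    docstore = MockDocumentStore()
    retriever = DummyRetriever(document_store=docstore)
    pipeline = Pipeline()
    pipeline.add_node(name="DocStore", component=docstore, inputs=["Query"])
    # Naming the edge explicitly is equivalent to naming the node, since
    # MockDocumentStore has a single outgoing edge.
    pipeline.add_node(name="Retriever", component=retriever, inputs=["DocStore.output_1"])
    assert "Retriever" in pipeline.graph.nodes
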
def test_parallel_paths_in_pipeline_graph():
|
||||
@ -1414,10 +1332,7 @@ def test_pipeline_components():
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_components():
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
doc_store = DummyDocumentStore()
|
||||
doc_store = MockDocumentStore()
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=doc_store, inputs=["File"])
|
||||
|
||||
@ -1425,11 +1340,8 @@ def test_pipeline_get_document_store_from_components():
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_components_multiple_doc_stores():
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
doc_store_a = DummyDocumentStore()
|
||||
doc_store_b = DummyDocumentStore()
|
||||
doc_store_a = MockDocumentStore()
|
||||
doc_store_b = MockDocumentStore()
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=doc_store_a, inputs=["File"])
|
||||
pipeline.add_node(name="B", component=doc_store_b, inputs=["File"])
|
||||
@ -1439,18 +1351,7 @@ def test_pipeline_get_document_store_from_components_multiple_doc_stores():
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_retriever():
|
||||
class DummyRetriever(BaseRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
doc_store = DummyDocumentStore()
|
||||
doc_store = MockDocumentStore()
|
||||
retriever = DummyRetriever(document_store=doc_store)
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=retriever, inputs=["Query"])
|
||||
@ -1459,26 +1360,7 @@ def test_pipeline_get_document_store_from_retriever():
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_dual_retriever():
|
||||
class DummyRetriever(BaseRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
class JoinNode(RootNode):
|
||||
def run(self, output=None, inputs=None):
|
||||
if inputs:
|
||||
output = ""
|
||||
for input_dict in inputs:
|
||||
output += input_dict["output"]
|
||||
return {"output": output}, "output_1"
|
||||
|
||||
doc_store = DummyDocumentStore()
|
||||
doc_store = MockDocumentStore()
|
||||
retriever_a = DummyRetriever(document_store=doc_store)
|
||||
retriever_b = DummyRetriever(document_store=doc_store)
|
||||
pipeline = Pipeline()
|
||||
@ -1490,27 +1372,8 @@ def test_pipeline_get_document_store_from_dual_retriever():
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_multiple_doc_stores_from_dual_retriever():
|
||||
class DummyRetriever(BaseRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
class JoinNode(RootNode):
|
||||
def run(self, output=None, inputs=None):
|
||||
if inputs:
|
||||
output = ""
|
||||
for input_dict in inputs:
|
||||
output += input_dict["output"]
|
||||
return {"output": output}, "output_1"
|
||||
|
||||
doc_store_a = DummyDocumentStore()
|
||||
doc_store_b = DummyDocumentStore()
|
||||
doc_store_a = MockDocumentStore()
|
||||
doc_store_b = MockDocumentStore()
|
||||
retriever_a = DummyRetriever(document_store=doc_store_a)
|
||||
retriever_b = DummyRetriever(document_store=doc_store_b)
|
||||
pipeline = Pipeline()
|
||||
|
||||
@@ -6,7 +6,19 @@ import pytest
|
||||
from haystack.pipelines import Pipeline, RootNode
|
||||
from haystack.nodes import FARMReader, ElasticsearchRetriever
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH, MockRetriever as BaseMockRetriever, MockReader
|
||||
|
||||
|
||||
class MockRetriever(BaseMockRetriever):
|
||||
def retrieve(self, *args, **kwargs):
|
||||
top_k = None
|
||||
if "top_k" in kwargs.keys():
|
||||
top_k = kwargs["top_k"]
|
||||
elif len(args) > 0:
|
||||
top_k = args[-1]
|
||||
|
||||
if top_k and not isinstance(top_k, int):
|
||||
raise ValueError("TEST ERROR!")
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@ -132,19 +144,34 @@ def test_global_debug_attributes_override_node_ones(document_store_with_docs, tm
|
||||
assert prediction["_debug"]["Reader"]["output"]
|
||||
|
||||
|
||||
def test_invalid_run_args():
|
||||
pipeline = Pipeline.load_from_yaml(SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="query_pipeline")
|
||||
with pytest.raises(Exception) as exc:
|
||||
pipeline.run(params={"ESRetriever": {"top_k": 10}})
|
||||
assert "run() missing 1 required positional argument: 'query'" in str(exc.value)
|
||||
def test_missing_top_level_arg():
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=MockRetriever(), name="Retriever", inputs=["Query"])
|
||||
pipeline.add_node(component=MockReader(), name="Reader", inputs=["Retriever"])
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
pipeline.run(invalid_query="Who made the PDF specification?", params={"ESRetriever": {"top_k": 10}})
|
||||
pipeline.run(params={"Retriever": {"top_k": 10}})
|
||||
assert "Must provide a 'query' parameter" in str(exc.value)
|
||||
|
||||
|
||||
def test_unexpected_top_level_arg():
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=MockRetriever(), name="Retriever", inputs=["Query"])
|
||||
pipeline.add_node(component=MockReader(), name="Reader", inputs=["Retriever"])
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
pipeline.run(invalid_query="Who made the PDF specification?", params={"Retriever": {"top_k": 10}})
|
||||
assert "run() got an unexpected keyword argument 'invalid_query'" in str(exc.value)
|
||||
|
||||
|
||||
def test_unexpected_node_arg():
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=MockRetriever(), name="Retriever", inputs=["Query"])
|
||||
pipeline.add_node(component=MockReader(), name="Reader", inputs=["Retriever"])
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
pipeline.run(query="Who made the PDF specification?", params={"ESRetriever": {"invalid": 10}})
|
||||
assert "Invalid parameter 'invalid' for the node 'ESRetriever'" in str(exc.value)
|
||||
pipeline.run(query="Who made the PDF specification?", params={"Retriever": {"invalid": 10}})
|
||||
assert "Invalid parameter 'invalid' for the node 'Retriever'" in str(exc.value)
|
||||
|
||||
|
||||
def test_debug_info_propagation():

test/test_pipeline_yaml.py (new file, 670 lines)
@@ -0,0 +1,670 @@
import pytest
import json
import numpy as np
import networkx as nx
from enum import Enum
from pydantic.dataclasses import dataclass

import haystack
from haystack import Pipeline
from haystack import document_stores
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes import _json_schema
from haystack.nodes import FileTypeClassifier
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError

from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
from . import conftest


#
# Fixtures
#


@pytest.fixture(autouse=True)
def mock_json_schema(request, monkeypatch, tmp_path):
    """
    JSON schema with the unstable version and only mocked nodes.
    """
    # Do not patch integration tests
    if "integration" in request.keywords:
        return

    # Mock the subclasses list to make it very small, containing only mock nodes
    monkeypatch.setattr(
        haystack.nodes._json_schema,
        "find_subclasses_in_modules",
        lambda *a, **k: [(conftest, MockDocumentStore), (conftest, MockReader), (conftest, MockRetriever)],
    )
    # Point the JSON schema path to tmp_path
    monkeypatch.setattr(haystack.pipelines.config, "JSON_SCHEMAS_PATH", tmp_path)

    # Generate mock schema in tmp_path
    filename = f"haystack-pipeline-unstable.schema.json"
    test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"])

    with open(tmp_path / filename, "w") as schema_file:
        json.dump(test_schema, schema_file, indent=4)

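The fixture writes a plain JSON Schema file, so any config dict can also be checked against it directly with the third-party jsonschema package. Illustrative usage inside a test that receives tmp_path (an assumption for this sketch, not part of the suite):

import json
import jsonschema  # assumption: the jsonschema package is installed

with open(tmp_path / "haystack-pipeline-unstable.schema.json") as schema_file:
    schema = json.load(schema_file)

config = {
    "version": "unstable",
    "components": [{"name": "retriever", "type": "MockRetriever"}],
    "pipelines": [{"name": "query", "nodes": [{"name": "retriever", "inputs": ["Query"]}]}],
}
jsonschema.validate(instance=config, schema=schema)  # raises ValidationError if the config is invalid
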
#
|
||||
# Integration
|
||||
#
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.elasticsearch
|
||||
def test_load_and_save_from_yaml(tmp_path):
|
||||
config_path = SAMPLES_PATH / "pipeline" / "test_pipeline.yaml"
|
||||
|
||||
# Test the indexing pipeline:
|
||||
# Load it
|
||||
indexing_pipeline = Pipeline.load_from_yaml(path=config_path, pipeline_name="indexing_pipeline")
|
||||
|
||||
# Check if it works
|
||||
indexing_pipeline.get_document_store().delete_documents()
|
||||
assert indexing_pipeline.get_document_store().get_document_count() == 0
|
||||
indexing_pipeline.run(file_paths=SAMPLES_PATH / "pdf" / "sample_pdf_1.pdf")
|
||||
assert indexing_pipeline.get_document_store().get_document_count() > 0
|
||||
|
||||
# Save it
|
||||
new_indexing_config = tmp_path / "test_indexing.yaml"
|
||||
indexing_pipeline.save_to_yaml(new_indexing_config)
|
||||
|
||||
# Re-load it and compare the resulting pipelines
|
||||
new_indexing_pipeline = Pipeline.load_from_yaml(path=new_indexing_config)
|
||||
assert nx.is_isomorphic(new_indexing_pipeline.graph, indexing_pipeline.graph)
|
||||
|
||||
# Check that modifying a pipeline modifies the output YAML
|
||||
modified_indexing_pipeline = Pipeline.load_from_yaml(path=new_indexing_config)
|
||||
modified_indexing_pipeline.add_node(FileTypeClassifier(), name="file_classifier", inputs=["File"])
|
||||
assert not nx.is_isomorphic(new_indexing_pipeline.graph, modified_indexing_pipeline.graph)
|
||||
|
||||
# Test the query pipeline:
|
||||
# Load it
|
||||
query_pipeline = Pipeline.load_from_yaml(path=config_path, pipeline_name="query_pipeline")
|
||||
|
||||
# Check if it works
|
||||
prediction = query_pipeline.run(
|
||||
query="Who made the PDF specification?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}
|
||||
)
|
||||
assert prediction["query"] == "Who made the PDF specification?"
|
||||
assert prediction["answers"][0].answer == "Adobe Systems"
|
||||
assert "_debug" not in prediction.keys()
|
||||
|
||||
# Save it
|
||||
new_query_config = tmp_path / "test_query.yaml"
|
||||
query_pipeline.save_to_yaml(new_query_config)
|
||||
|
||||
# Re-load it and compare the resulting pipelines
|
||||
new_query_pipeline = Pipeline.load_from_yaml(path=new_query_config)
|
||||
assert nx.is_isomorphic(new_query_pipeline.graph, query_pipeline.graph)
|
||||
|
||||
# Check that different pipelines produce different files
|
||||
assert not nx.is_isomorphic(new_query_pipeline.graph, new_indexing_pipeline.graph)
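nx.is_isomorphic() compares graph topology only, not node identity, so the assertions above check that the reloaded pipeline has the same shape as the original rather than byte-identical YAML. A standalone illustration with plain networkx, independent of Haystack:

import networkx as nx

g1 = nx.DiGraph([("Query", "ESRetriever"), ("ESRetriever", "Reader")])
g2 = nx.DiGraph([("Query", "Retriever"), ("Retriever", "Reader")])  # same shape, different names
g3 = nx.DiGraph([("Query", "Retriever"), ("Query", "Reader")])      # different topology

assert nx.is_isomorphic(g1, g2)
assert not nx.is_isomorphic(g1, g3)
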
|
||||
|
||||
|
||||
#
|
||||
# Unit
|
||||
#
|
||||
|
||||
|
||||
def test_load_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
- name: reader
|
||||
type: MockReader
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- Query
|
||||
- name: reader
|
||||
inputs:
|
||||
- retriever
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert len(pipeline.graph.nodes) == 3
|
||||
assert isinstance(pipeline.get_node("retriever"), MockRetriever)
|
||||
assert isinstance(pipeline.get_node("reader"), MockReader)
|
||||
|
||||
|
||||
def test_load_yaml_non_existing_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
Pipeline.load_from_yaml(path=SAMPLES_PATH / "pipeline" / "I_dont_exist.yml")
|
||||
|
||||
|
||||
def test_load_yaml_invalid_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write("this is not valid YAML!")
|
||||
with pytest.raises(PipelineConfigError):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_missing_version(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
"""
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: docstore
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "version" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_non_existing_version(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
"""
|
||||
version: random
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: docstore
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "version" in str(e) and "random" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_incompatible_version(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
"""
|
||||
version: 1.1.0
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: docstore
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "version" in str(e) and "1.1.0" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_no_components(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "components" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_wrong_component(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: docstore
|
||||
type: ImaginaryDocumentStore
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: docstore
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(HaystackError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "ImaginaryDocumentStore" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_custom_component(tmp_path):
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, param: int):
|
||||
self.param = param
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
param: 1
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path):
|
||||
"""
|
||||
This test can work from the perspective of YAML schema validation:
|
||||
HelperClass is picked up correctly and everything gets loaded.
|
||||
|
||||
However, for now we decide to disable this feature.
|
||||
See haystack/_json_schema.py for details.
|
||||
"""
|
||||
|
||||
@dataclass # Makes this test class JSON serializable
|
||||
class HelperClass:
|
||||
def __init__(self, another_param: str):
|
||||
self.param = another_param
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: HelperClass = HelperClass(1)):
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineSchemaError, match="takes object instances as parameters in its __init__ function"):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_helper_class_in_yaml(tmp_path):
|
||||
"""
|
||||
This test can work from the perspective of YAML schema validation:
|
||||
HelperClass is picked up correctly and everything gets loaded.
|
||||
|
||||
However, for now we decide to disable this feature.
|
||||
See haystack/_json_schema.py for details.
|
||||
"""
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, another_param: str):
|
||||
self.param = another_param
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: HelperClass):
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
some_exotic_parameter: HelperClass("hello")
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError, match="not a valid variable name or value"):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_enum_in_init(tmp_path):
|
||||
"""
|
||||
This test can work from the perspective of YAML schema validation:
|
||||
Flags is picked up correctly and everything gets loaded.
|
||||
|
||||
However, for now we decide to disable this feature.
|
||||
See haystack/_json_schema.py for details.
|
||||
"""
|
||||
|
||||
class Flags(Enum):
|
||||
FIRST_VALUE = 1
|
||||
SECOND_VALUE = 2
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: Flags = None):
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineSchemaError, match="takes object instances as parameters in its __init__ function"):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_enum_in_yaml(tmp_path):
|
||||
"""
|
||||
This test can work from the perspective of YAML schema validation:
|
||||
Flags is picked up correctly and everything gets loaded.
|
||||
|
||||
However, for now we decide to disable this feature.
|
||||
See haystack/_json_schema.py for details.
|
||||
"""
|
||||
|
||||
class Flags(Enum):
|
||||
FIRST_VALUE = 1
|
||||
SECOND_VALUE = 2
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: Flags):
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
some_exotic_parameter: Flags.SECOND_VALUE
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineSchemaError, match="takes object instances as parameters in its __init__ function"):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
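The tests above document the design decision stated in their docstrings: components whose __init__ takes object instances (helper classes, Enums) are rejected at schema level. A hypothetical workaround that does fit the schema is to accept only primitives in __init__ and build the object internally (illustrative sketch, not from the PR; YamlFriendlyNode and flag_name are made-up names):

class YamlFriendlyNode(MockNode):
    def __init__(self, flag_name: str = "FIRST_VALUE"):
        super().__init__()
        # Only the primitive string is exposed to the YAML schema...
        self.flag_name = flag_name
        # ...while the enum member is resolved at runtime (Flags as defined above).
        self.flag = Flags[flag_name]

# In YAML this would be configured as:
#   params:
#     flag_name: SECOND_VALUE
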
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_external_constant(tmp_path):
|
||||
"""
|
||||
This is a potential pitfall. The code should work as described here.
|
||||
"""
|
||||
|
||||
class AnotherClass:
|
||||
CLASS_CONSTANT = "str"
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: str):
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
some_exotic_parameter: AnotherClass.CLASS_CONSTANT # Will *NOT* be resolved
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
node = pipeline.get_node("custom_node")
|
||||
assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_superclass(tmp_path):
|
||||
class BaseCustomNode(MockNode):
|
||||
pass
|
||||
|
||||
class CustomNode(BaseCustomNode):
|
||||
def __init__(self, some_exotic_parameter: str):
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
some_exotic_parameter: value
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_no_pipelines(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
pipelines:
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "pipeline" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_invalid_pipeline_name(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: docstore
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml", pipeline_name="invalid")
|
||||
assert "invalid" in str(e) and "pipeline" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_pipeline_with_wrong_nodes(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: not_existing_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "not_existing_node" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_pipeline_not_acyclic_graph(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
- name: reader
|
||||
type: MockRetriever
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- reader
|
||||
- name: reader
|
||||
inputs:
|
||||
- retriever
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "reader" in str(e) or "retriever" in str(e)
|
||||
assert "loop" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_wrong_root(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- Nothing
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "Nothing" in str(e)
|
||||
assert "root" in str(e).lower()
|
||||
|
||||
|
||||
def test_load_yaml_two_roots(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
- name: retriever_2
|
||||
type: MockRetriever
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- Query
|
||||
- name: retriever_2
|
||||
inputs:
|
||||
- File
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineConfigError) as e:
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert "File" in str(e) or "Query" in str(e)
|
||||
|
||||
|
||||
def test_load_yaml_disconnected_component(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: docstore
|
||||
type: MockDocumentStore
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: docstore
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert len(pipeline.graph.nodes) == 2
|
||||
assert isinstance(pipeline.get_document_store(), MockDocumentStore)
|
||||
assert not pipeline.get_node("retriever")
|
||||
|
||||
|
||||
def test_save_yaml(tmp_path):
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(MockRetriever(), name="retriever", inputs=["Query"])
|
||||
pipeline.save_to_yaml(tmp_path / "saved_pipeline.yml")
|
||||
|
||||
with open(tmp_path / "saved_pipeline.yml", "r") as saved_yaml:
|
||||
content = saved_yaml.read()
|
||||
|
||||
assert content.count("retriever") == 2
|
||||
assert "MockRetriever" in content
|
||||
assert "Query" in content
|
||||
assert f"version: {haystack.__version__}" in content
|
||||
|
||||
|
||||
def test_save_yaml_overwrite(tmp_path):
|
||||
pipeline = Pipeline()
|
||||
retriever = MockRetriever()
|
||||
pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
|
||||
|
||||
with open(tmp_path / "saved_pipeline.yml", "w") as _:
|
||||
pass
|
||||
|
||||
pipeline.save_to_yaml(tmp_path / "saved_pipeline.yml")
|
||||
|
||||
with open(tmp_path / "saved_pipeline.yml", "r") as saved_yaml:
|
||||
content = saved_yaml.read()
|
||||
assert content != ""
|
||||
@@ -4,7 +4,7 @@ from haystack import Document
|
||||
from haystack.nodes.file_converter.pdf import PDFToTextConverter
|
||||
from haystack.nodes.preprocessor.preprocessor import PreProcessor
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
TEXT = """
|
||||
This is a sample sentence in paragraph_1. This is a sample sentence in paragraph_1. This is a sample sentence in
|
||||
|
||||
@@ -5,23 +5,10 @@ import ray
|
||||
|
||||
from haystack.pipelines import RayPipeline
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_load_pipeline(document_store_with_docs):
|
||||
pipeline = RayPipeline.load_from_yaml(
|
||||
SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="ray_query_pipeline", num_cpus=8
|
||||
)
|
||||
prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}})
|
||||
|
||||
assert ray.serve.get_deployment(name="ESRetriever").num_replicas == 2
|
||||
assert ray.serve.get_deployment(name="Reader").num_replicas == 1
|
||||
assert prediction["query"] == "Who lives in Berlin?"
|
||||
assert prediction["answers"][0].answer == "Carla"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def shutdown_ray():
|
||||
yield
|
||||
try:
|
||||
@ -30,3 +17,17 @@ def shutdown_ray():
|
||||
ray.shutdown()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_load_pipeline(document_store_with_docs):
|
||||
pipeline = RayPipeline.load_from_yaml(
|
||||
SAMPLES_PATH / "pipeline" / "test_ray_pipeline.yaml", pipeline_name="ray_query_pipeline", num_cpus=8
|
||||
)
|
||||
prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}})
|
||||
|
||||
assert ray.serve.get_deployment(name="ESRetriever").num_replicas == 2
|
||||
assert ray.serve.get_deployment(name="Reader").num_replicas == 1
|
||||
assert prediction["query"] == "Who lives in Berlin?"
|
||||
assert prediction["answers"][0].answer == "Carla"
|
||||
|
||||
@@ -15,7 +15,7 @@ from haystack.nodes.retriever.dense import DensePassageRetriever, TableTextRetri
|
||||
from haystack.nodes.retriever.sparse import ElasticsearchRetriever, ElasticsearchFilterOnlyRetriever, TfidfRetriever
|
||||
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -244,8 +244,8 @@ def test_table_text_retriever_embedding(document_store, retriever, docs):
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
def test_dpr_saving_and_loading(retriever, document_store):
|
||||
retriever.save("test_dpr_save")
|
||||
def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
|
||||
retriever.save(f"{tmp_path}/test_dpr_save")
|
||||
|
||||
def sum_params(model):
|
||||
s = []
|
||||
@ -258,7 +258,7 @@ def test_dpr_saving_and_loading(retriever, document_store):
|
||||
original_sum_passage = sum_params(retriever.passage_encoder)
|
||||
del retriever
|
||||
|
||||
loaded_retriever = DensePassageRetriever.load("test_dpr_save", document_store)
|
||||
loaded_retriever = DensePassageRetriever.load(f"{tmp_path}/test_dpr_save", document_store)
|
||||
|
||||
loaded_sum_query = sum_params(loaded_retriever.query_encoder)
|
||||
loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
|
||||
@ -292,8 +292,8 @@ def test_dpr_saving_and_loading(retriever, document_store):
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
|
||||
@pytest.mark.embedding_dim(512)
|
||||
def test_table_text_retriever_saving_and_loading(retriever, document_store):
|
||||
retriever.save("test_table_text_retriever_save")
|
||||
def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_store):
|
||||
retriever.save(f"{tmp_path}/test_table_text_retriever_save")
|
||||
|
||||
def sum_params(model):
|
||||
s = []
|
||||
@ -307,7 +307,7 @@ def test_table_text_retriever_saving_and_loading(retriever, document_store):
|
||||
original_sum_table = sum_params(retriever.table_encoder)
|
||||
del retriever
|
||||
|
||||
loaded_retriever = TableTextRetriever.load("test_table_text_retriever_save", document_store)
|
||||
loaded_retriever = TableTextRetriever.load(f"{tmp_path}/test_table_text_retriever_save", document_store)
|
||||
|
||||
loaded_sum_query = sum_params(loaded_retriever.query_encoder)
|
||||
loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
|
||||
|
||||
@@ -16,7 +16,7 @@ from haystack.nodes import (
|
||||
)
|
||||
from haystack.schema import Document
|
||||
|
||||
from conftest import SAMPLES_PATH
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.