Add export of Pipeline YAML config (#1003)

oryx1729 2021-04-30 12:23:29 +02:00 committed by GitHub
parent a00703256f
commit 99990e7249
23 changed files with 245 additions and 27 deletions

View File

@ -94,6 +94,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
:param return_embedding: To return document embedding
"""
# save init parameters to enable export of component config as YAML
self.set_config(
host=host, port=port, username=username, password=password, api_key_id=api_key_id, api_key=api_key,
aws4auth=aws4auth, index=index, label_index=label_index, search_fields=search_fields, text_field=text_field,
name_field=name_field, embedding_field=embedding_field, embedding_dim=embedding_dim,
custom_mapping=custom_mapping, excluded_meta_data=excluded_meta_data, analyzer=analyzer, scheme=scheme,
ca_certs=ca_certs, verify_certs=verify_certs, create_index=create_index,
update_existing_documents=update_existing_documents, refresh_type=refresh_type, similarity=similarity,
timeout=timeout, return_embedding=return_embedding,
)
self.client = self._init_elastic_client(host=host, port=port, username=username, password=password,
api_key=api_key, api_key_id=api_key_id, aws4auth=aws4auth, scheme=scheme,

View File

@ -77,6 +77,15 @@ class FAISSDocumentStore(SQLDocumentStore):
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
sql_url=sql_url, vector_dim=vector_dim, faiss_index_factory_str=faiss_index_factory_str,
faiss_index=faiss_index, return_embedding=return_embedding,
update_existing_documents=update_existing_documents, index=index, similarity=similarity,
embedding_field=embedding_field, progress_bar=progress_bar
)
self.vector_dim = vector_dim
self.faiss_index_factory_str = faiss_index_factory_str
self.faiss_indexes: Dict[str, faiss.swigfaiss.Index] = {}

View File

@ -44,6 +44,13 @@ class InMemoryDocumentStore(BaseDocumentStore):
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
index=index, label_index=label_index, embedding_field=embedding_field, embedding_dim=embedding_dim,
return_embedding=return_embedding, similarity=similarity, progress_bar=progress_bar,
)
self.indexes: Dict[str, Dict] = defaultdict(dict)
self.index: str = index
self.label_index: str = label_index

View File

@ -94,6 +94,15 @@ class MilvusDocumentStore(SQLDocumentStore):
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
sql_url=sql_url, milvus_url=milvus_url, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
search_param=search_param, update_existing_documents=update_existing_documents,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
)
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
self.vector_dim = vector_dim
self.index_file_size = index_file_size

View File

@ -88,6 +88,12 @@ class SQLDocumentStore(BaseDocumentStore):
added already exists. Using this parameter could cause performance degradation
for document insertion.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
url=url, index=index, label_index=label_index, update_existing_documents=update_existing_documents
)
engine = create_engine(url)
ORMBase.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

View File

@ -27,6 +27,10 @@ class BaseConverter(BaseComponent):
not one of the valid languages, then it might likely be encoding error resulting
in garbled text.
"""
# save init parameters to enable export of component config as YAML
self.set_config(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
self.remove_numeric_tables = remove_numeric_tables
self.valid_languages = valid_languages

View File

@ -22,6 +22,10 @@ class PDFToTextConverter(BaseConverter):
not one of the valid languages, then it might likely be encoding error resulting
in garbled text.
"""
# save init parameters to enable export of component config as YAML
self.set_config(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
verify_installation = subprocess.run(["pdftotext -v"], shell=True)
if verify_installation.returncode == 127:
raise Exception(

View File

@ -58,6 +58,12 @@ class TikaConverter(BaseConverter):
not one of the valid languages, then it might likely be encoding error resulting
in garbled text.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
tika_url=tika_url, remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages
)
ping = requests.get(tika_url)
if ping.status_code != 200:
raise Exception(f"Apache Tika server is not reachable at the URL '{tika_url}'. To run it locally"

View File

@ -8,22 +8,6 @@ logger = logging.getLogger(__name__)
class TextConverter(BaseConverter):
def __init__(self, remove_numeric_tables: bool = False, valid_languages: Optional[List[str]] = None):
"""
:param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
The tabular structures in documents might be noise for the reader model if it
does not have table parsing capability for finding answers. However, tables
may also have long strings that could possible candidate for searching answers.
The rows containing strings are thus retained in this option.
:param valid_languages: validate languages from a list of languages specified in the ISO 639-1
(https://en.wikipedia.org/wiki/ISO_639-1) format.
This option can be used to add test for encoding errors. If the extracted text is
not one of the valid languages, then it might likely be encoding error resulting
in garbled text.
"""
super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
def convert(
self,
file_path: Path,

View File

@ -94,6 +94,13 @@ class RAGenerator(BaseGenerator):
:param use_gpu: Whether to use GPU (if available)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
model_name_or_path=model_name_or_path, model_version=model_version, retriever=retriever,
generator_type=generator_type, top_k=top_k, max_length=max_length, min_length=min_length,
num_beams=num_beams, embed_title=embed_title, prefix=prefix, use_gpu=use_gpu,
)
self.model_name_or_path = model_name_or_path
self.max_length = max_length
self.min_length = min_length

View File

@ -22,7 +22,10 @@ class Text2SparqlRetriever(BaseGraphRetriever):
:param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
:param top_k: How many SPARQL queries to generate per text query.
"""
# save init parameters to enable export of component config as YAML
self.set_config(knowledge_graph=knowledge_graph, model_name_or_path=model_name_or_path, top_k=top_k)
self.knowledge_graph = knowledge_graph
# TODO We should extend this to any seq2seq models and use the AutoModel class
self.model = BartForConditionalGeneration.from_pretrained(model_name_or_path, force_bos_token_to_be_generated=True)

View File

@ -33,7 +33,12 @@ class GraphDBKnowledgeGraph(BaseKnowledgeGraph):
:param prefixes: definitions of namespaces with a new line after each namespace, e.g., PREFIX hp: <https://deepset.ai/harry_potter/>
"""
# save init parameters to enable export of component config as YAML
self.set_config(
host=host, port=port, username=username, password=password, index=index, prefixes=prefixes
)
self.url = f"http://{host}:{port}"
self.index = index
self.username = username

View File

@ -1,3 +1,4 @@
import inspect
import logging
import os
import traceback
@ -180,7 +181,7 @@ class Pipeline:
Here's a sample configuration:
```yaml
- | version: '0.7'
+ | version: '0.8'
|
| components: # define all the building-blocks for Pipeline
| - name: MyReader # custom-name for the component; helpful for visualization & debugging
@ -291,6 +292,55 @@ class Pipeline:
param_name = key.replace(env_prefix, "").lower()
definition["params"][param_name] = value
def save_to_yaml(self, path: Path, return_defaults: bool = False):
"""
Save a YAML configuration for the Pipeline that can be used with `Pipeline.load_from_yaml()`.
:param path: path of the output YAML file.
:param return_defaults: whether to include parameters that still have their default values.
"""
nodes = self.graph.nodes
pipeline_name = self.pipeline_type.lower()
pipeline_type = self.pipeline_type
pipelines: dict = {pipeline_name: {"name": pipeline_name, "type": pipeline_type, "nodes": []}}
components = {}
for node in nodes:
if node == self.root_node_id:
continue
component_instance = self.graph.nodes.get(node)["component"]
component_type = component_instance.pipeline_config["type"]
component_params = component_instance.pipeline_config["params"]
components[node] = {"name": node, "type": component_type, "params": {}}
component_signature = inspect.signature(type(component_instance)).parameters
for key, value in component_params.items():
# A parameter of a Component could be another Component. For instance, a Retriever has
# the DocumentStore as a parameter.
# Component configs must be a dict with a "type" key; the "type" key distinguishes them from
# other dict parameters like "custom_mapping".
# This currently only handles the single-level nesting case, wherein a Component has another
# Component as a parameter. For deeper nesting, this function should be made recursive.
if isinstance(value, dict) and "type" in value.keys(): # the parameter is a Component
components[node]["params"][key] = value["type"]
sub_component_signature = inspect.signature(BaseComponent.subclasses[value["type"]]).parameters
params = {
k: v for k, v in value["params"].items()
if sub_component_signature[k].default != v or return_defaults is True
}
components[value["type"]] = {"name": value["type"], "type": value["type"], "params": params}
else:
if component_signature[key].default != value or return_defaults is True:
components[node]["params"][key] = value
# create the Pipeline definition describing how the Components are connected
pipelines[pipeline_name]["nodes"].append({"name": node, "inputs": list(self.graph.predecessors(node))})
config = {"components": list(components.values()), "pipelines": list(pipelines.values()), "version": "0.8"}
with open(path, 'w') as outfile:
yaml.dump(config, outfile, default_flow_style=False)
class BaseStandardPipeline(ABC):
pipeline: Pipeline
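To illustrate the new export end to end, here is a minimal round-trip sketch. It assumes the repository's sample config `samples/pipeline/test_pipeline.yaml`, a query pipeline named `query_pipeline` inside it, and a running Elasticsearch instance for the sample's `ElasticsearchDocumentStore`; the output file names are arbitrary.

```python
# Minimal round-trip sketch for Pipeline.save_to_yaml(); assumes the sample
# config from the repository and a running Elasticsearch on localhost:9200.
from pathlib import Path

from haystack.pipeline import Pipeline

# Build a pipeline from an existing YAML definition ("query_pipeline" is the
# assumed name of the query pipeline in the sample file) ...
pipeline = Pipeline.load_from_yaml(
    Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="query_pipeline"
)

# ... and export its configuration again. By default only parameters that
# differ from the component defaults are written; return_defaults=True keeps
# everything.
pipeline.save_to_yaml(Path("exported_pipeline.yaml"))
pipeline.save_to_yaml(Path("exported_pipeline_full.yaml"), return_defaults=True)

# The exported file works with the existing loader; the exported pipeline is
# named after the lower-cased pipeline type, i.e. "query" here.
reloaded = Pipeline.load_from_yaml(Path("exported_pipeline.yaml"), pipeline_name="query")
```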

View File

@ -44,6 +44,14 @@ class PreProcessor(BasePreProcessor):
to True, the individual split will always have complete sentences &
the number of words will be <= split_length.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
clean_whitespace=clean_whitespace, clean_header_footer=clean_header_footer,
clean_empty_lines=clean_empty_lines, split_by=split_by, split_length=split_length,
split_overlap=split_overlap, split_respect_sentence_boundary=split_respect_sentence_boundary,
)
try:
nltk.data.find('tokenizers/punkt')
except LookupError:

View File

@ -93,6 +93,14 @@ class FARMReader(BaseReader):
Can be helpful to disable in production deployments to keep the logs clean.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
model_name_or_path=model_name_or_path, model_version=model_version, context_window_size=context_window_size,
batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
)
self.return_no_answers = return_no_answer
self.top_k = top_k
self.top_k_per_candidate = top_k_per_candidate

View File

@ -57,6 +57,14 @@ class TransformersReader(BaseReader):
:param doc_stride: length of striding window for splitting long texts (used if len(text) > max_seq_len)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
context_window_size=context_window_size, use_gpu=use_gpu, top_k=top_k, doc_stride=doc_stride,
top_k_per_candidate=top_k_per_candidate, return_no_answers=return_no_answers, max_seq_len=max_seq_len,
)
self.model = pipeline('question-answering', model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version)
self.context_window_size = context_window_size
self.top_k = top_k

View File

@ -98,6 +98,16 @@ class DensePassageRetriever(BaseRetriever):
Can be helpful to disable in production deployments to keep the logs clean.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
document_store=document_store, query_embedding_model=query_embedding_model,
passage_embedding_model=passage_embedding_model, single_model_path=single_model_path,
model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
similarity_function=similarity_function, progress_bar=progress_bar,
)
self.document_store = document_store
self.batch_size = batch_size
self.progress_bar = progress_bar
@ -461,6 +471,14 @@ class EmbeddingRetriever(BaseRetriever):
Default: -1 (very last layer).
:param top_k: How many documents to return per query.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
document_store=document_store, embedding_model=embedding_model, model_version=model_version,
use_gpu=use_gpu, model_format=model_format, pooling_strategy=pooling_strategy,
emb_extraction_layer=emb_extraction_layer, top_k=top_k,
)
self.document_store = document_store
self.model_format = model_format
self.pooling_strategy = pooling_strategy

View File

@ -52,6 +52,10 @@ class ElasticsearchRetriever(BaseRetriever):
```
:param top_k: How many documents to return per query.
"""
# save init parameters to enable export of component config as YAML
self.set_config(document_store=document_store, top_k=top_k, custom_query=custom_query)
self.document_store: ElasticsearchDocumentStore = document_store
self.top_k = top_k
self.custom_query = custom_query
@ -118,6 +122,10 @@ class TfidfRetriever(BaseRetriever):
:param document_store: an instance of a DocumentStore to retrieve documents from.
:param top_k: How many documents to return per query.
"""
# save init parameters to enable export of component config as YAML
self.set_config(document_store=document_store, top_k=top_k)
self.vectorizer = TfidfVectorizer(
lowercase=True,
stop_words=None,

View File

@ -3,6 +3,7 @@ from uuid import uuid4
import numpy as np
from abc import abstractmethod
class Document:
def __init__(self, text: str,
id: Optional[str] = None,
@ -227,6 +228,7 @@ class BaseComponent:
outgoing_edges: int
subclasses: dict = {}
pipeline_config: dict = {}
def __init_subclass__(cls, **kwargs):
""" This automatically keeps track of all available subclasses.
@ -258,4 +260,19 @@ class BaseComponent:
:param kwargs:
:return:
"""
pass
def set_config(self, **kwargs):
"""
Save the init parameters of a component so that they can later be used to export
the YAML configuration of a Pipeline.
:param kwargs: all parameters passed to the __init__() of the Component.
"""
if not self.pipeline_config:
self.pipeline_config = {"params": {}, "type": type(self).__name__}
for k, v in kwargs.items():
if isinstance(v, BaseComponent):
self.pipeline_config["params"][k] = v.pipeline_config
elif v is not None:
self.pipeline_config["params"][k] = v
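As a rough, self-contained illustration of what `set_config()` captures, the sketch below mirrors the logic above using hypothetical `Toy*` stand-ins (they are not Haystack components), including how a nested component's config is embedded in its parent's:

```python
# Self-contained toy sketch of the set_config() mechanism; ToyBaseComponent,
# ToyDocumentStore and ToyRetriever are illustrative stand-ins only.

class ToyBaseComponent:
    pipeline_config: dict = {}

    def set_config(self, **kwargs):
        # Mirrors BaseComponent.set_config(): record non-None init parameters
        # and substitute nested components with their own pipeline_config.
        if not self.pipeline_config:
            self.pipeline_config = {"params": {}, "type": type(self).__name__}
        for k, v in kwargs.items():
            if isinstance(v, ToyBaseComponent):
                self.pipeline_config["params"][k] = v.pipeline_config
            elif v is not None:
                self.pipeline_config["params"][k] = v


class ToyDocumentStore(ToyBaseComponent):
    def __init__(self, index: str = "document"):
        self.set_config(index=index)
        self.index = index


class ToyRetriever(ToyBaseComponent):
    def __init__(self, document_store: ToyDocumentStore, top_k: int = 10):
        self.set_config(document_store=document_store, top_k=top_k)
        self.document_store = document_store
        self.top_k = top_k


store = ToyDocumentStore(index="my_docs")
retriever = ToyRetriever(document_store=store, top_k=5)
print(retriever.pipeline_config)
# prints (reformatted for readability):
# {'params': {'document_store': {'params': {'index': 'my_docs'},
#                                'type': 'ToyDocumentStore'},
#             'top_k': 5},
#  'type': 'ToyRetriever'}
```

This per-instance `pipeline_config` dict is what `Pipeline.save_to_yaml()` later walks to build the `components` section of the exported YAML.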

View File

@ -81,6 +81,14 @@ class TransformersSummarizer(BaseSummarizer):
Important: The summary will depend on the order of the supplied documents!
"""
# save init parameters to enable export of component config as YAML
self.set_config(
model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
max_length=max_length, min_length=min_length, use_gpu=use_gpu,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
separator_for_single_summary=separator_for_single_summary, generate_single_summary=generate_single_summary,
)
# TODO AutoModelForSeq2SeqLM is only necessary with transformers==4.1.1, with newer versions use the pipeline directly
if tokenizer is None:
tokenizer = model_name_or_path

View File

@ -1,10 +1,10 @@
- from abc import ABC, abstractmethod
+ from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Union
- from haystack import Document
+ from haystack import Document, BaseComponent
- class BaseTranslator(ABC):
+ class BaseTranslator(BaseComponent):
"""
Abstract class for a Translator component that translates either a query or a doc from language A to language B.
"""
@ -24,7 +24,7 @@ class BaseTranslator(ABC):
"""
pass
- def run(
+ def run(  # type: ignore
self,
query: Optional[str] = None,
documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,

View File

@ -56,6 +56,12 @@ class TransformersTranslator(BaseTranslator):
:param clean_up_tokenization_spaces: Whether or not to clean up the tokenization spaces. (default True)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
model_name_or_path=model_name_or_path, tokenizer_name=tokenizer_name, max_seq_len=max_seq_len,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
self.max_seq_len = max_seq_len
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
tokenizer_name = tokenizer_name or model_name_or_path

View File

@ -10,10 +10,9 @@ from haystack.retriever.sparse import ElasticsearchRetriever
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
- def test_load_yaml(document_store_with_docs):
+ def test_load_and_save_yaml(document_store_with_docs, tmp_path):
# test correct load of indexing pipeline from yaml
- pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"),
-                                    pipeline_name="indexing_pipeline")
+ pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="indexing_pipeline")
pipeline.run(file_path=Path("samples/pdf/sample_pdf_1.pdf"), top_k_retriever=10, top_k_reader=3)
# test correct load of query pipeline from yaml
@ -26,6 +25,40 @@ def test_load_yaml(document_store_with_docs):
with pytest.raises(Exception):
Pipeline.load_from_yaml(path=Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="invalid")
# test config export
pipeline.save_to_yaml(tmp_path / "test.yaml")
with open(tmp_path/"test.yaml", "r", encoding='utf-8') as stream:
saved_yaml = stream.read()
expected_yaml = '''
components:
- name: ESRetriever
params:
document_store: ElasticsearchDocumentStore
type: ElasticsearchRetriever
- name: ElasticsearchDocumentStore
params:
index: haystack_test_document
label_index: haystack_test_label
type: ElasticsearchDocumentStore
- name: Reader
params:
model_name_or_path: deepset/roberta-base-squad2
no_ans_boost: -10
type: FARMReader
pipelines:
- name: query
nodes:
- inputs:
- Query
name: ESRetriever
- inputs:
- ESRetriever
name: Reader
type: Query
version: '0.8'
'''
assert saved_yaml.replace(" ", "").replace("\n", "") == expected_yaml.replace(" ", "").replace("\n", "")
@pytest.mark.slow
@pytest.mark.elasticsearch