feat: add Shaper (#3880)

* Shaper initial version

* Initial pydoc

* Add more unit tests

* Fix pydoc, expand Shaper pydoc with YAML example

* Minor fix

* Improve pydoc

* More unit tests with prompt node

* Describe Shaper functions in pydoc

* More pydoc

* Use pytest.raises instead of catching errors

* Improve test_function_invocation_order unit test

* pylint fixes

* Improve run_batch handling

* simpler version, initial stub

* stubbing tests

* promptnode compatibility

* add tests

* simplify

* fix promptnode tests

* pylint

* mypy

* fix corner case & mypy

* mypy

* review feedback

* tests

* Add lg updates

* add rename

* pylint

* Add complex unit test with two PNs and ICMs in between (#3921)

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>

* docstring

* fix tests

* add join_lists

* add documents_to_strings

* fix tests

* allow lists of input values

* doc review feedback

* do not use locals()

* Update with minor lg changes

* fix corner case in ICM

* fix merge

* review feedback

* answers conversions

* mypy

* add tests

* generative answers

* forgot to commit

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
Co-authored-by: agnieszka-m <amarzec13@gmail.com>
ZanSara 2023-02-01 18:36:13 +01:00 committed by GitHub
parent e8ff48094b
commit 9009a9ae58
5 changed files with 1538 additions and 17 deletions

haystack/nodes/__init__.py

@@ -23,7 +23,7 @@ from haystack.nodes.file_converter import (
 )
 from haystack.nodes.image_to_text import TransformersImageToText
 from haystack.nodes.label_generator import PseudoLabelGenerator
-from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger
+from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger, Shaper
 from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
 from haystack.nodes.prompt import PromptNode, PromptTemplate, PromptModel
 from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
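With this re-export, `Shaper` becomes importable from the top-level nodes package alongside the other utility nodes. A minimal sketch of what it enables (the `rename` configuration below is illustrative, not part of this diff):

```python
# Sketch: Shaper is now importable directly from haystack.nodes.
from haystack.nodes import Shaper

# Copy the invocation-context key `query` to a new key `question`.
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["question"])
```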

haystack/nodes/other/__init__.py

@@ -4,3 +4,4 @@ from haystack.nodes.other.route_documents import RouteDocuments
 from haystack.nodes.other.join_answers import JoinAnswers
 from haystack.nodes.other.join import JoinNode
 from haystack.nodes.other.document_merger import DocumentMerger
+from haystack.nodes.other.shaper import Shaper

haystack/nodes/other/shaper.py (new file)

@@ -0,0 +1,408 @@
from typing import Optional, List, Dict, Any, Tuple, Union, Callable

import logging

from haystack.nodes.base import BaseComponent
from haystack.schema import Document, Answer, MultiLabel

logger = logging.getLogger(__name__)


def rename(value: Any) -> Tuple[Any]:
    """
    Identity function. Can be used to rename values in the invocation context without changing them.

    Example:

    ```python
    assert rename(1) == (1, )
    ```
    """
    return (value,)


def value_to_list(value: Any, target_list: List[Any]) -> Tuple[List[Any]]:
    """
    Transforms a value into a list containing this value as many times as the length of the target list.

    Example:

    ```python
    assert value_to_list(value=1, target_list=list(range(5))) == ([1, 1, 1, 1, 1], )
    ```
    """
    return ([value] * len(target_list),)


def join_lists(lists: List[List[Any]]) -> Tuple[List[Any]]:
    """
    Joins the passed lists into a single one.

    Example:

    ```python
    assert join_lists(lists=[[1, 2, 3], [4, 5]]) == ([1, 2, 3, 4, 5], )
    ```
    """
    merged_list = []
    for inner_list in lists:
        merged_list += inner_list
    return (merged_list,)


def join_strings(strings: List[str], delimiter: str = " ") -> Tuple[str]:
    """
    Transforms a list of strings into a single string. The content of this string
    is the content of all original strings separated by the delimiter you specify.

    Example:

    ```python
    assert join_strings(strings=["first", "second", "third"], delimiter=" - ") == ("first - second - third", )
    ```
    """
    return (delimiter.join(strings),)


def join_documents(documents: List[Document], delimiter: str = " ") -> Tuple[List[Document]]:
    """
    Transforms a list of documents into a list containing a single Document. The content of this Document
    is the content of all original documents separated by the delimiter you specify.

    All metadata is dropped. (TODO: fix)

    Example:

    ```python
    assert join_documents(
        documents=[
            Document(content="first"),
            Document(content="second"),
            Document(content="third")
        ],
        delimiter=" - "
    ) == ([Document(content="first - second - third")], )
    ```
    """
    return ([Document(content=delimiter.join([d.content for d in documents]))],)


def strings_to_answers(strings: List[str]) -> Tuple[List[Answer]]:
    """
    Transforms a list of strings into a list of generative Answers.

    Example:

    ```python
    assert strings_to_answers(strings=["first", "second", "third"]) == ([
        Answer(answer="first", type="generative"),
        Answer(answer="second", type="generative"),
        Answer(answer="third", type="generative"),
    ], )
    ```
    """
    return ([Answer(answer=string, type="generative") for string in strings],)


def answers_to_strings(answers: List[Answer]) -> Tuple[List[str]]:
    """
    Extracts the answer field of Answers and returns a list of strings.

    Example:

    ```python
    assert answers_to_strings(
        answers=[
            Answer(answer="first"),
            Answer(answer="second"),
            Answer(answer="third")
        ]
    ) == (["first", "second", "third"],)
    ```
    """
    return ([answer.answer for answer in answers],)


def strings_to_documents(
    strings: List[str],
    meta: Union[List[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] = None,
    id_hash_keys: Optional[List[str]] = None,
) -> Tuple[List[Document]]:
    """
    Transforms a list of strings into a list of Documents. If you pass the metadata in a single
    dictionary, all Documents get the same metadata. If you pass the metadata as a list, the length of this list
    must be the same as the length of the list of strings, and each Document gets its own metadata.
    You can specify `id_hash_keys` only once and it gets assigned to all Documents.

    Example:

    ```python
    assert strings_to_documents(
        strings=["first", "second", "third"],
        meta=[{"position": i} for i in range(3)],
        id_hash_keys=['content', 'meta']
    ) == ([
        Document(content="first", meta={"position": 0}, id_hash_keys=['content', 'meta']),
        Document(content="second", meta={"position": 1}, id_hash_keys=['content', 'meta']),
        Document(content="third", meta={"position": 2}, id_hash_keys=['content', 'meta'])
    ], )
    ```
    """
    all_metadata: List[Optional[Dict[str, Any]]]
    if isinstance(meta, dict):
        all_metadata = [meta] * len(strings)
    elif isinstance(meta, list):
        if len(meta) != len(strings):
            raise ValueError(
                f"Mismatching lengths: strings_to_documents received {len(strings)} strings and {len(meta)} metadata dictionaries."
            )
        all_metadata = meta
    else:
        all_metadata = [None] * len(strings)

    return ([Document(content=string, meta=m, id_hash_keys=id_hash_keys) for string, m in zip(strings, all_metadata)],)


def documents_to_strings(documents: List[Document]) -> Tuple[List[str]]:
    """
    Extracts the content field of Documents and returns a list of strings.

    Example:

    ```python
    assert documents_to_strings(
        documents=[
            Document(content="first"),
            Document(content="second"),
            Document(content="third")
        ]
    ) == (["first", "second", "third"],)
    ```
    """
    return ([doc.content for doc in documents],)


REGISTERED_FUNCTIONS: Dict[str, Callable[..., Tuple[Any]]] = {
    "rename": rename,
    "value_to_list": value_to_list,
    "join_lists": join_lists,
    "join_strings": join_strings,
    "join_documents": join_documents,
    "strings_to_answers": strings_to_answers,
    "answers_to_strings": answers_to_strings,
    "strings_to_documents": strings_to_documents,
    "documents_to_strings": documents_to_strings,
}


class Shaper(BaseComponent):
    """
    Shaper is a component that can invoke arbitrary, registered functions on the invocation context
    (query, documents, and so on) of a pipeline. It then passes the new or modified variables further down the pipeline.

    Using YAML configuration, the Shaper component is initialized with functions to invoke on the pipeline's
    invocation context.

    For example, in the YAML snippet below:

    ```yaml
    components:
    - name: shaper
      type: Shaper
      params:
        func: value_to_list
        inputs:
          value: query
          target_list: documents
        outputs: [questions]
    ```

    the Shaper component is initialized with a directive to invoke the function `value_to_list` on the variable
    `query` and to store the result in the invocation context variable `questions`. All other invocation context
    variables are passed down the pipeline as they are.

    Shaper is especially useful for pipelines with PromptNodes, where we need to modify the invocation
    context to match the templates of PromptNodes.

    You can use multiple Shaper components in a pipeline to modify the invocation context as needed.

    `Shaper` supports the following functions:

    - `rename`
    - `value_to_list`
    - `join_lists`
    - `join_strings`
    - `join_documents`
    - `strings_to_answers`
    - `answers_to_strings`
    - `strings_to_documents`
    - `documents_to_strings`

    See their descriptions in the code for details about their inputs, outputs, and other parameters.
    """

    outgoing_edges = 1

    def __init__(
        self,
        func: str,
        outputs: List[str],
        inputs: Optional[Dict[str, Union[List[str], str]]] = None,
        params: Optional[Dict[str, Any]] = None,
    ):
"""
Initializes the Shaper component.
Some examples:
```yaml
- name: shaper
type: Shaper
params:
func: value_to_list
inputs:
value: query
target_list: documents
outputs:
- questions
```
This node takes the content of `query` and creates a list that contains the value of `query` `len(documents)` times.
This list is stored in the invocation context under the key `questions`.
```yaml
- name: shaper
type: Shaper
params:
func: join_documents
inputs:
value: documents
params:
delimiter: ' - '
outputs:
- documents
```
This node overwrites the content of `documents` in the invocation context with a list containing a single Document
whose content is the concatenation of all the original Documents. So if `documents` contained
`[Document("A"), Document("B"), Document("C")]`, this shaper overwrites it with `[Document("A - B - C")]`
```yaml
- name: shaper
type: Shaper
params:
func: join_strings
params:
strings: ['a', 'b', 'c']
delimiter: ' . '
outputs:
- single_string
- name: shaper
type: Shaper
params:
func: strings_to_documents
inputs:
strings: single_string
metadata:
name: 'my_file.txt'
outputs:
- single_document
```
These two nodes, executed one after the other, first add a key in the invocation context called `single_string`
that contains `a . b . c`, and then create another key called `single_document` that contains instead
`[Document(content="a . b . c", metadata={'name': 'my_file.txt'})]`.
:param func: The function to apply.
:param inputs: Maps the function's input kwargs to the key-value pairs in the invocation context.
For example, `value_to_list` expects the `value` and `target_list` parameters, so `inputs` might contain:
`{'value': 'query', 'target_list': 'documents'}`. It doesn't need to contain all keyword args, see `params`.
:param params: Maps the function's input kwargs to some fixed values. For example, `value_to_list` expects
`value` and `target_list` parameters, so `params` might contain
`{'value': 'A', 'target_list': [1, 1, 1, 1]}` and the node's output is `["A", "A", "A", "A"]`.
It doesn't need to contain all keyword args, see `inputs`.
You can use params to provide fallback values for arguments of `run` that you're not sure exist.
So if you need `query` to exist, you can provide a fallback value in the params, which will be used only if `query`
is not passed to this node by the pipeline.
:param outputs: THe key to store the outputs in the invocation context. The length of the outputs must match
the number of outputs produced by the function invoked.
"""
        super().__init__()
        self.function = REGISTERED_FUNCTIONS[func]
        self.outputs = outputs
        self.inputs = inputs or {}
        self.params = params or {}

    def run(  # type: ignore
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
        invocation_context: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict, str]:
        # Pipeline arguments are added to the invocation context unless it already defines them
        invocation_context = invocation_context or {}
        if query and "query" not in invocation_context.keys():
            invocation_context["query"] = query
        if file_paths and "file_paths" not in invocation_context.keys():
            invocation_context["file_paths"] = file_paths
        if labels and "labels" not in invocation_context.keys():
            invocation_context["labels"] = labels
        if documents and "documents" not in invocation_context.keys():
            invocation_context["documents"] = documents
        if meta and "meta" not in invocation_context.keys():
            invocation_context["meta"] = meta

        # Collect the function's arguments from the invocation context (see `inputs`) ...
        input_values: Dict[str, Any] = {}
        for key, value in self.inputs.items():
            if isinstance(value, list):
                input_values[key] = []
                for v in value:
                    if v in invocation_context.keys() and v is not None:
                        input_values[key].append(invocation_context[v])
            else:
                if value in invocation_context.keys() and value is not None:
                    input_values[key] = invocation_context[value]

        # ... and complement them with the fixed values (see `params`)
        input_values = {**self.params, **input_values}

        try:
            logger.debug(
                "Shaper is invoking this function: %s(%s)",
                self.function.__name__,
                ", ".join([f"{key}={value}" for key, value in input_values.items()]),
            )
            output_values = self.function(**input_values)
        except TypeError as e:
            raise ValueError(
                "Shaper couldn't apply the function to your inputs and parameters. "
                "Check the above stacktrace and make sure you provided all the correct inputs, parameters, "
                "and parameter types."
            ) from e

        for output_key, output_value in zip(self.outputs, output_values):
            invocation_context[output_key] = output_value

        results = {"invocation_context": invocation_context}

        # Note: `output_key` and `output_value` here refer to the last output produced in the loop above,
        # so only the last output is mirrored into the top-level results.
        if output_key in ["query", "file_paths", "labels", "documents", "meta"]:
            results[output_key] = output_value

        return results, "output_1"

    def run_batch(  # type: ignore
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
        invocation_context: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict, str]:
        return self.run(
            query=query,
            file_paths=file_paths,
            labels=labels,
            documents=documents,
            meta=meta,
            invocation_context=invocation_context,
        )

haystack/nodes/prompt/prompt_node.py

@@ -911,29 +911,42 @@ class PromptNode(BaseComponent):
         :param meta: The meta to be used for the prompt. Usually not used.
         :param invocation_context: The invocation context to be used for the prompt.
         """
-        # invocation_context is a dictionary that is passed from a pipeline node to a pipeline node and can be used
-        # to pass results from a pipeline node to any other downstream pipeline node.
-        invocation_context = invocation_context or {}
         # prompt_collector is an empty list, it's passed to the PromptNode that will fill it with the rendered prompts,
         # so that they can be returned by `run()` as part of the pipeline's debug output.
         prompt_collector: List[str] = []
-        results = self(
-            query=query,
-            labels=labels,
-            documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
-            prompt_collector=prompt_collector,
-            **invocation_context,
-        )
+        invocation_context = invocation_context or {}
+        if query and "query" not in invocation_context.keys():
+            invocation_context["query"] = query
+        if file_paths and "file_paths" not in invocation_context.keys():
+            invocation_context["file_paths"] = file_paths
+        if labels and "labels" not in invocation_context.keys():
+            invocation_context["labels"] = labels
+        if documents and "documents" not in invocation_context.keys():
+            invocation_context["documents"] = documents
+        if meta and "meta" not in invocation_context.keys():
+            invocation_context["meta"] = meta
+        if "documents" in invocation_context.keys():
+            for doc in invocation_context.get("documents", []):
+                if not isinstance(doc, str) and not isinstance(doc.content, str):
+                    raise ValueError("PromptNode only accepts text documents.")
+            invocation_context["documents"] = [
+                doc.content if isinstance(doc, Document) else doc for doc in invocation_context.get("documents", [])
+            ]
+        results = self(prompt_collector=prompt_collector, **invocation_context)
         final_result: Dict[str, Any] = {}
-        if self.output_variable:
-            invocation_context[self.output_variable] = results
-            final_result[self.output_variable] = results
+        output_variable = self.output_variable or "results"
+        if output_variable:
+            invocation_context[output_variable] = results
+            final_result[output_variable] = results
         final_result["results"] = results
         final_result["invocation_context"] = invocation_context
         final_result["_debug"] = {"prompts_used": prompt_collector}
         return final_result, "output_1"
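
Together, the two files make the following pattern possible: a Shaper reshapes the invocation context, and PromptNode now reads its arguments from that context. A minimal pipeline sketch, assuming the stock `question-answering` prompt template and PromptNode's default model (the names and values here are illustrative, not part of this diff):

```python
from haystack import Document
from haystack.nodes import PromptNode, Shaper
from haystack.pipelines import Pipeline

# The question-answering template consumes `documents` and `questions`,
# so the Shaper copies `query` into a `questions` list of matching length.
shaper = Shaper(
    func="value_to_list",
    inputs={"value": "query", "target_list": "documents"},
    outputs=["questions"],
)
prompt_node = PromptNode(default_prompt_template="question-answering")

pipeline = Pipeline()
pipeline.add_node(component=shaper, name="shaper", inputs=["Query"])
pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["shaper"])

result = pipeline.run(
    query="What's the capital of France?",
    documents=[Document(content="Paris is the capital of France.")],
)
print(result["results"])
```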

test/nodes/test_shaper.py (new file, 1099 lines)

File diff suppressed because it is too large.