Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-02 10:49:30 +00:00)
feat: add Shaper (#3880)
* Shaper initial version
* Initial pydoc
* Add more unit tests
* Fix pydoc, expand Shaper pydoc with YAML example
* Minor fix
* Improve pydoc
* More unit tests with prompt node
* Describe Shaper functions in pydoc
* More pydoc
* Use pytest.raises instead of catching errors
* Improve test_function_invocation_order unit test
* pylint fixes
* Improve run_batch handling
* simpler version, initial stub
* stubbing tests
* promptnode compatibility
* add tests
* simplify
* fix promptnode tests
* pylint
* mypy
* fix corner case & mypy
* mypy
* review feedback
* tests
* Add lg updates
* add rename
* pylint
* Add complex unit test with two PNs and ICMs in between (#3921) (Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>)
* docstring
* fix tests
* add join_lists
* add documents_to_strings
* fix tests
* allow lists of input values
* doc review feedback
* do not use locals()
* Update with minor lg changes
* fix corner case in ICM
* fix merge
* review feedback
* answers conversions
* mypy
* add tests
* generative answers
* forgot to commit

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
Co-authored-by: agnieszka-m <amarzec13@gmail.com>
parent e8ff48094b, commit 9009a9ae58
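Before the diff: a minimal sketch of how the Shaper introduced in this commit can be run on its own, based on the `value_to_list` docstring below. The query and document contents are made up for illustration.

```python
# Minimal sketch of the Shaper added in this commit (illustrative values only).
from haystack.nodes.other.shaper import Shaper
from haystack.schema import Document

# Repeat the query once per document and store the result under "questions".
shaper = Shaper(
    func="value_to_list",
    inputs={"value": "query", "target_list": "documents"},
    outputs=["questions"],
)

results, edge = shaper.run(
    query="What is Haystack?",
    documents=[Document(content="doc one"), Document(content="doc two")],
)

# New or modified variables travel in the invocation context.
assert results["invocation_context"]["questions"] == ["What is Haystack?", "What is Haystack?"]
assert edge == "output_1"
```

The function's output lands in the invocation context under the key given in `outputs`, which is how downstream nodes pick it up.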
@@ -23,7 +23,7 @@ from haystack.nodes.file_converter import (
 )
 from haystack.nodes.image_to_text import TransformersImageToText
 from haystack.nodes.label_generator import PseudoLabelGenerator
-from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger
+from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger, Shaper
 from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
 from haystack.nodes.prompt import PromptNode, PromptTemplate, PromptModel
 from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
@@ -4,3 +4,4 @@ from haystack.nodes.other.route_documents import RouteDocuments
 from haystack.nodes.other.join_answers import JoinAnswers
 from haystack.nodes.other.join import JoinNode
 from haystack.nodes.other.document_merger import DocumentMerger
+from haystack.nodes.other.shaper import Shaper
haystack/nodes/other/shaper.py (new file, 408 lines)
@@ -0,0 +1,408 @@
from typing import Optional, List, Dict, Any, Tuple, Union, Callable

import logging

from haystack.nodes.base import BaseComponent
from haystack.schema import Document, Answer, MultiLabel


logger = logging.getLogger(__name__)


def rename(value: Any) -> Tuple[Any]:
    """
    Identity function. Can be used to rename values in the invocation context without changing them.

    Example:

    ```python
    assert rename(1) == (1, )
    ```
    """
    return (value,)


def value_to_list(value: Any, target_list: List[Any]) -> Tuple[List[Any]]:
    """
    Transforms a value into a list containing this value as many times as the length of the target list.

    Example:

    ```python
    assert value_to_list(value=1, target_list=list(range(5))) == ([1, 1, 1, 1, 1], )
    ```
    """
    return ([value] * len(target_list),)


def join_lists(lists: List[List[Any]]) -> Tuple[List[Any]]:
    """
    Joins the passed lists into a single one.

    Example:

    ```python
    assert join_lists(lists=[[1, 2, 3], [4, 5]]) == ([1, 2, 3, 4, 5], )
    ```
    """
    merged_list = []
    for inner_list in lists:
        merged_list += inner_list
    return (merged_list,)


def join_strings(strings: List[str], delimiter: str = " ") -> Tuple[str]:
    """
    Transforms a list of strings into a single string. The content of this string
    is the content of all original strings separated by the delimiter you specify.

    Example:

    ```python
    assert join_strings(strings=["first", "second", "third"], delimiter=" - ") == ("first - second - third", )
    ```
    """
    return (delimiter.join(strings),)

def join_documents(documents: List[Document], delimiter: str = " ") -> Tuple[List[Document]]:
    """
    Transforms a list of documents into a list containing a single Document. The content of this Document
    is the content of all the original documents, separated by the delimiter you specify.

    All metadata is dropped. (TODO: fix)

    Example:

    ```python
    assert join_documents(
        documents=[
            Document(content="first"),
            Document(content="second"),
            Document(content="third")
        ],
        delimiter=" - "
    ) == ([Document(content="first - second - third")], )
    ```
    """
    return ([Document(content=delimiter.join([d.content for d in documents]))],)


def strings_to_answers(strings: List[str]) -> Tuple[List[Answer]]:
    """
    Transforms a list of strings into a list of Answers.

    Example:

    ```python
    assert strings_to_answers(strings=["first", "second", "third"]) == ([
        Answer(answer="first", type="generative"),
        Answer(answer="second", type="generative"),
        Answer(answer="third", type="generative"),
    ], )
    ```
    """
    return ([Answer(answer=string, type="generative") for string in strings],)


def answers_to_strings(answers: List[Answer]) -> Tuple[List[str]]:
    """
    Extracts the answer field of Answers and returns a list of strings.

    Example:

    ```python
    assert answers_to_strings(
        answers=[
            Answer(answer="first"),
            Answer(answer="second"),
            Answer(answer="third")
        ]
    ) == (["first", "second", "third"],)
    ```
    """
    return ([answer.answer for answer in answers],)


def strings_to_documents(
    strings: List[str],
    meta: Union[List[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] = None,
    id_hash_keys: Optional[List[str]] = None,
) -> Tuple[List[Document]]:
    """
    Transforms a list of strings into a list of Documents. If you pass the metadata in a single
    dictionary, all Documents get the same metadata. If you pass the metadata as a list, the length of this list
    must be the same as the length of the list of strings, and each Document gets its own metadata.
    You can specify `id_hash_keys` only once and it gets assigned to all Documents.

    Example:

    ```python
    assert strings_to_documents(
        strings=["first", "second", "third"],
        meta=[{"position": i} for i in range(3)],
        id_hash_keys=['content', 'meta']
    ) == ([
        Document(content="first", meta={"position": 0}, id_hash_keys=['content', 'meta']),
        Document(content="second", meta={"position": 1}, id_hash_keys=['content', 'meta']),
        Document(content="third", meta={"position": 2}, id_hash_keys=['content', 'meta'])
    ], )
    ```
    """
    all_metadata: List[Optional[Dict[str, Any]]]
    if isinstance(meta, dict):
        all_metadata = [meta] * len(strings)
    elif isinstance(meta, list):
        if len(meta) != len(strings):
            raise ValueError(
                f"Not enough metadata dictionaries. strings_to_documents received {len(strings)} strings and {len(meta)} metadata dictionaries."
            )
        all_metadata = meta
    else:
        all_metadata = [None] * len(strings)

    return ([Document(content=string, meta=m, id_hash_keys=id_hash_keys) for string, m in zip(strings, all_metadata)],)

def documents_to_strings(documents: List[Document]) -> Tuple[List[str]]:
    """
    Extracts the content field of Documents and returns a list of strings.

    Example:

    ```python
    assert documents_to_strings(
        documents=[
            Document(content="first"),
            Document(content="second"),
            Document(content="third")
        ]
    ) == (["first", "second", "third"],)
    ```
    """
    return ([doc.content for doc in documents],)


REGISTERED_FUNCTIONS: Dict[str, Callable[..., Tuple[Any]]] = {
    "rename": rename,
    "value_to_list": value_to_list,
    "join_lists": join_lists,
    "join_strings": join_strings,
    "join_documents": join_documents,
    "strings_to_answers": strings_to_answers,
    "answers_to_strings": answers_to_strings,
    "strings_to_documents": strings_to_documents,
    "documents_to_strings": documents_to_strings,
}

class Shaper(BaseComponent):

    """
    Shaper is a component that can invoke arbitrary, registered functions on the invocation context
    (query, documents, and so on) of a pipeline. It then passes the new or modified variables further down the pipeline.

    Using YAML configuration, the Shaper component is initialized with the functions to invoke on the pipeline's
    invocation context.

    For example, with the YAML snippet below:
    ```yaml
    components:
    - name: shaper
      type: Shaper
      params:
        func: value_to_list
        inputs:
          value: query
          target_list: documents
        outputs: [questions]
    ```
    the Shaper component is initialized with a directive to invoke the function `value_to_list` on the variable `query` and to store
    the result in the invocation context variable `questions`. All other invocation context variables are passed down
    the pipeline as they are.

    Shaper is especially useful for pipelines with PromptNodes, where we need to modify the invocation
    context to match the templates of PromptNodes.

    You can use multiple Shaper components in a pipeline to modify the invocation context as needed.

    `Shaper` supports the following functions:

    - `rename`
    - `value_to_list`
    - `join_strings`
    - `join_documents`
    - `join_lists`
    - `strings_to_answers`
    - `answers_to_strings`
    - `strings_to_documents`
    - `documents_to_strings`

    See their descriptions in the code for details about their inputs, outputs, and other parameters.
    """

    outgoing_edges = 1

    def __init__(
        self,
        func: str,
        outputs: List[str],
        inputs: Optional[Dict[str, Union[List[str], str]]] = None,
        params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes the Shaper component.

        Some examples:

        ```yaml
        - name: shaper
          type: Shaper
          params:
            func: value_to_list
            inputs:
              value: query
              target_list: documents
            outputs:
              - questions
        ```
        This node takes the content of `query` and creates a list that contains the value of `query` `len(documents)` times.
        This list is stored in the invocation context under the key `questions`.

        ```yaml
        - name: shaper
          type: Shaper
          params:
            func: join_documents
            inputs:
              documents: documents
            params:
              delimiter: ' - '
            outputs:
              - documents
        ```
        This node overwrites the content of `documents` in the invocation context with a list containing a single Document
        whose content is the concatenation of all the original Documents. So if `documents` contained
        `[Document("A"), Document("B"), Document("C")]`, this shaper overwrites it with `[Document("A - B - C")]`.

        ```yaml
        - name: shaper
          type: Shaper
          params:
            func: join_strings
            params:
              strings: ['a', 'b', 'c']
              delimiter: ' . '
            outputs:
              - single_string

        - name: shaper
          type: Shaper
          params:
            func: strings_to_documents
            inputs:
              strings: [single_string]
            params:
              meta:
                name: 'my_file.txt'
            outputs:
              - single_document
        ```
        These two nodes, executed one after the other, first add a key to the invocation context called `single_string`
        that contains `a . b . c`, and then create another key called `single_document` that contains
        `[Document(content="a . b . c", meta={'name': 'my_file.txt'})]`.

        :param func: The function to apply.
        :param inputs: Maps the function's input kwargs to the key-value pairs in the invocation context.
            For example, `value_to_list` expects the `value` and `target_list` parameters, so `inputs` might contain:
            `{'value': 'query', 'target_list': 'documents'}`. It doesn't need to contain all keyword args, see `params`.
        :param params: Maps the function's input kwargs to some fixed values. For example, `value_to_list` expects
            the `value` and `target_list` parameters, so `params` might contain
            `{'value': 'A', 'target_list': [1, 1, 1, 1]}` and the node's output is `["A", "A", "A", "A"]`.
            It doesn't need to contain all keyword args, see `inputs`.
            You can also use `params` to provide fallback values for arguments of `run` that you're not sure exist.
            So if you need `query` to exist, you can provide a fallback value in `params`, which is used only if `query`
            is not passed to this node by the pipeline.
        :param outputs: The keys to store the outputs under in the invocation context. The number of keys must match
            the number of outputs produced by the invoked function.
        """
        super().__init__()
        self.function = REGISTERED_FUNCTIONS[func]
        self.outputs = outputs
        self.inputs = inputs or {}
        self.params = params or {}

    def run(  # type: ignore
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
        invocation_context: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict, str]:
        invocation_context = invocation_context or {}
        if query and "query" not in invocation_context.keys():
            invocation_context["query"] = query

        if file_paths and "file_paths" not in invocation_context.keys():
            invocation_context["file_paths"] = file_paths

        if labels and "labels" not in invocation_context.keys():
            invocation_context["labels"] = labels

        if documents and "documents" not in invocation_context.keys():
            invocation_context["documents"] = documents

        if meta and "meta" not in invocation_context.keys():
            invocation_context["meta"] = meta

        input_values: Dict[str, Any] = {}
        for key, value in self.inputs.items():
            if isinstance(value, list):
                input_values[key] = []
                for v in value:
                    if v in invocation_context.keys() and v is not None:
                        input_values[key].append(invocation_context[v])
            else:
                if value in invocation_context.keys() and value is not None:
                    input_values[key] = invocation_context[value]

        input_values = {**self.params, **input_values}
        try:
            logger.debug(
                "Shaper is invoking this function: %s(%s)",
                self.function.__name__,
                ", ".join([f"{key}={value}" for key, value in input_values.items()]),
            )
            output_values = self.function(**input_values)
        except TypeError as e:
            raise ValueError(
                "Shaper couldn't apply the function to your inputs and parameters. "
                "Check the above stacktrace and make sure you provided all the correct inputs, parameters, "
                "and parameter types."
            ) from e

        for output_key, output_value in zip(self.outputs, output_values):
            invocation_context[output_key] = output_value

        results = {"invocation_context": invocation_context}
        if output_key in ["query", "file_paths", "labels", "documents", "meta"]:
            results[output_key] = output_value

        return results, "output_1"

    def run_batch(  # type: ignore
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
        invocation_context: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict, str]:

        return self.run(
            query=query,
            file_paths=file_paths,
            labels=labels,
            documents=documents,
            meta=meta,
            invocation_context=invocation_context,
        )
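The Shaper docstrings above chain two nodes through the invocation context; the following sketch restates that example in plain Python against the code in this diff. The strings and the `my_file.txt` metadata come from the docstring example; wrapping the input name in a list relies on the list-input handling added in this commit.

```python
# Two Shapers chained by hand through the invocation context
# (mirrors the join_strings -> strings_to_documents example in the docstring above).
from haystack.nodes.other.shaper import Shaper

first = Shaper(
    func="join_strings",
    params={"strings": ["a", "b", "c"], "delimiter": " . "},
    outputs=["single_string"],
)
results, _ = first.run()
context = results["invocation_context"]
assert context["single_string"] == "a . b . c"

# Wrapping the input name in a list makes the Shaper collect the value into a list,
# which is what strings_to_documents expects.
second = Shaper(
    func="strings_to_documents",
    inputs={"strings": ["single_string"]},
    params={"meta": {"name": "my_file.txt"}},
    outputs=["single_document"],
)
results, _ = second.run(invocation_context=context)
document = results["invocation_context"]["single_document"][0]
assert document.content == "a . b . c"
assert document.meta == {"name": "my_file.txt"}
```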
@@ -911,29 +911,42 @@ class PromptNode(BaseComponent):
         :param meta: The meta to be used for the prompt. Usually not used.
         :param invocation_context: The invocation context to be used for the prompt.
         """
-        # invocation_context is a dictionary that is passed from a pipeline node to a pipeline node and can be used
-        # to pass results from a pipeline node to any other downstream pipeline node.
-        invocation_context = invocation_context or {}
-
         # prompt_collector is an empty list, it's passed to the PromptNode that will fill it with the rendered prompts,
         # so that they can be returned by `run()` as part of the pipeline's debug output.
         prompt_collector: List[str] = []

-        results = self(
-            query=query,
-            labels=labels,
-            documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
-            prompt_collector=prompt_collector,
-            **invocation_context,
-        )
+        invocation_context = invocation_context or {}
+        if query and "query" not in invocation_context.keys():
+            invocation_context["query"] = query
+
+        if file_paths and "file_paths" not in invocation_context.keys():
+            invocation_context["file_paths"] = file_paths
+
+        if labels and "labels" not in invocation_context.keys():
+            invocation_context["labels"] = labels
+
+        if documents and "documents" not in invocation_context.keys():
+            invocation_context["documents"] = documents
+
+        if meta and "meta" not in invocation_context.keys():
+            invocation_context["meta"] = meta
+
+        if "documents" in invocation_context.keys():
+            for doc in invocation_context.get("documents", []):
+                if not isinstance(doc, str) and not isinstance(doc.content, str):
+                    raise ValueError("PromptNode only accepts text documents.")
+            invocation_context["documents"] = [
+                doc.content if isinstance(doc, Document) else doc for doc in invocation_context.get("documents", [])
+            ]
+
+        results = self(prompt_collector=prompt_collector, **invocation_context)

         final_result: Dict[str, Any] = {}
-        if self.output_variable:
-            invocation_context[self.output_variable] = results
-            final_result[self.output_variable] = results
+        output_variable = self.output_variable or "results"
+        if output_variable:
+            invocation_context[output_variable] = results
+            final_result[output_variable] = results

         final_result["results"] = results
         final_result["invocation_context"] = invocation_context
         final_result["_debug"] = {"prompts_used": prompt_collector}
         return final_result, "output_1"
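For context on why `PromptNode.run` now merges its arguments into the invocation context, here is a hedged sketch of the kind of pipeline this change enables: a Shaper preparing `questions` for a PromptNode template. The model name and the `question-answering` template name are assumptions for illustration, not something this diff defines; only the Shaper wiring and the `results` key come from the code above.

```python
# Hedged sketch: a Shaper feeding a PromptNode through the invocation context.
from haystack import Pipeline
from haystack.nodes import PromptNode, Shaper
from haystack.schema import Document

# Turn the single query into a list of questions, one per document.
shaper = Shaper(
    func="value_to_list",
    inputs={"value": "query", "target_list": "documents"},
    outputs=["questions"],
)
prompt_node = PromptNode(
    model_name_or_path="google/flan-t5-base",       # assumed model, for illustration only
    default_prompt_template="question-answering",   # assumed out-of-the-box template name
)

pipeline = Pipeline()
pipeline.add_node(component=shaper, name="shaper", inputs=["Query"])
pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["shaper"])

result = pipeline.run(
    query="What does the Shaper do?",
    documents=[Document(content="Shaper reshapes the invocation context of a pipeline.")],
)
# With the PromptNode change above, the generated answers are exposed under "results".
print(result["results"])
```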
test/nodes/test_shaper.py (new file, 1099 lines)
File diff suppressed because it is too large.