feat: add async PromptNode run (#5890)

* add async promptnode

* Remove unnecessary calls to dict.keys()

---------

Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Massimiliano Pippi 2023-09-29 08:40:01 +02:00 committed by GitHub
parent 578f2b4bbf
commit 0947f59545
5 changed files with 222 additions and 21 deletions

File: haystack/nodes/prompt/prompt_model.py

@@ -111,6 +111,16 @@ class PromptModel(BaseComponent):
         output = self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
         return output
 
+    async def ainvoke(self, prompt: Union[str, List[str], List[Dict[str, str]]], **kwargs) -> List[str]:
+        """
+        Drop-in replacement asyncio version of the `invoke` method, see there for documentation.
+        """
+        try:
+            return await self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
+        except TypeError:
+            # The `invoke` method of the underlying invocation layer doesn't support asyncio
+            return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
+
     @overload
     def _ensure_token_limit(self, prompt: str) -> str:
         ...
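A note on the try/except above: `await` raises `TypeError` when applied to a value that is not awaitable, and that is the signal `ainvoke` uses to fall back to the synchronous path. A minimal sketch of the mechanism, with `SyncLayer` and `AsyncLayer` as hypothetical stand-ins for invocation layers (they are not Haystack classes):

import asyncio

class SyncLayer:
    def invoke(self, prompt, **kwargs):
        return [f"echo: {prompt}"]  # a plain list, not awaitable

class AsyncLayer:
    async def invoke(self, prompt, **kwargs):
        return [f"echo: {prompt}"]  # calling this returns an awaitable coroutine

async def ainvoke(layer, prompt):
    try:
        return await layer.invoke(prompt=prompt)  # TypeError if the result is not awaitable
    except TypeError:
        # fall back to the synchronous path; note `invoke` ends up being called twice
        return layer.invoke(prompt=prompt)

print(asyncio.run(ainvoke(SyncLayer(), "hi")))   # ['echo: hi']
print(asyncio.run(ainvoke(AsyncLayer(), "hi")))  # ['echo: hi']

One trade-off of this signal: a `TypeError` raised inside a genuinely asynchronous `invoke` would also be caught and would trigger a second, synchronous call.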

File: haystack/nodes/prompt/prompt_node.py

@@ -232,6 +232,37 @@ class PromptNode(BaseComponent):
             return list(template.prompt_params)
         return []
 
+    def _prepare(  # type: ignore
+        self, query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs
+    ) -> Dict:
+        """
+        Prepare prompt invocation.
+        """
+        invocation_context = invocation_context or {}
+
+        if query and "query" not in invocation_context:
+            invocation_context["query"] = query
+
+        if file_paths and "file_paths" not in invocation_context:
+            invocation_context["file_paths"] = file_paths
+
+        if labels and "labels" not in invocation_context:
+            invocation_context["labels"] = labels
+
+        if documents and "documents" not in invocation_context:
+            invocation_context["documents"] = documents
+
+        if meta and "meta" not in invocation_context:
+            invocation_context["meta"] = meta
+
+        if "prompt_template" not in invocation_context:
+            invocation_context["prompt_template"] = self.get_prompt_template(prompt_template)
+
+        if generation_kwargs:
+            invocation_context.update(generation_kwargs)
+
+        return invocation_context
+
     def run(
         self,
         query: Optional[str] = None,
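Two behaviors of `_prepare` are worth noting, shown in the standalone sketch below (`prepare_context` is a hypothetical re-implementation for illustration, not the Haystack API): keys already present in `invocation_context` take precedence over the corresponding keyword arguments, and the dict passed by the caller is updated in place rather than copied.

from typing import Any, Dict, Optional

def prepare_context(
    query: Optional[str] = None,
    invocation_context: Optional[Dict[str, Any]] = None,
    generation_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    ctx = invocation_context or {}
    if query and "query" not in ctx:
        ctx["query"] = query  # kwargs only fill keys that are missing
    if generation_kwargs:
        ctx.update(generation_kwargs)  # generation kwargs are merged last
    return ctx

ctx = {"query": "old"}
result = prepare_context(query="new", invocation_context=ctx, generation_kwargs={"top_k": 3})
print(result)         # {'query': 'old', 'top_k': 3} -- the existing key won
print(result is ctx)  # True -- the caller's dict was mutated, not copied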
@@ -272,29 +303,86 @@ class PromptNode(BaseComponent):
         # so that they can be returned by `run()` as part of the pipeline's debug output.
         prompt_collector: List[str] = []
 
-        invocation_context = invocation_context or {}
-        if query and "query" not in invocation_context.keys():
-            invocation_context["query"] = query
-
-        if file_paths and "file_paths" not in invocation_context.keys():
-            invocation_context["file_paths"] = file_paths
-
-        if labels and "labels" not in invocation_context.keys():
-            invocation_context["labels"] = labels
-
-        if documents and "documents" not in invocation_context.keys():
-            invocation_context["documents"] = documents
-
-        if meta and "meta" not in invocation_context.keys():
-            invocation_context["meta"] = meta
-
-        if "prompt_template" not in invocation_context.keys():
-            invocation_context["prompt_template"] = self.get_prompt_template(prompt_template)
-
-        if generation_kwargs:
-            invocation_context.update(generation_kwargs)
-
-        results = self(prompt_collector=prompt_collector, **invocation_context)
+        invocation_context = self._prepare(
+            query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs
+        )
+        results = self(**invocation_context, prompt_collector=prompt_collector)
+
+        prompt_template_resolved: PromptTemplate = invocation_context.pop("prompt_template")
+
+        try:
+            output_variable = self.output_variable or prompt_template_resolved.output_variable or "results"
+        except:
+            output_variable = "results"
+
+        invocation_context[output_variable] = results
+        invocation_context["prompts"] = prompt_collector
+
+        final_result: Dict[str, Any] = {output_variable: results, "invocation_context": invocation_context}
+
+        if self.debug:
+            final_result["_debug"] = {"prompts_used": prompt_collector}
+
+        return final_result, "output_1"
+
+    async def _aprompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs):
+        """
+        Async version of the actual prompt invocation.
+        """
+        results = []
+        # we pop the prompt_collector kwarg to avoid passing it to the model
+        prompt_collector: List[Union[str, List[Dict[str, str]]]] = kwargs.pop("prompt_collector", [])
+
+        # kwargs override model kwargs
+        kwargs = {**self._prepare_model_kwargs(), **kwargs}
+        template_to_fill = self.get_prompt_template(prompt_template)
+        if template_to_fill:
+            # prompt template used, yield prompts from inputs args
+            for prompt in template_to_fill.fill(*args, **kwargs):
+                kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs))
+                # and pass the prepared prompt and kwargs copy to the model
+                prompt = self.prompt_model._ensure_token_limit(prompt)
+                prompt_collector.append(prompt)
+                logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy)
+                output = await self.prompt_model.ainvoke(prompt, **kwargs_copy)
+                results.extend(output)
+
+            kwargs["prompts"] = prompt_collector
+            results = template_to_fill.post_process(results, **kwargs)
+        else:
+            # straightforward prompt, no templates used
+            for prompt in list(args):
+                kwargs_copy = copy.copy(kwargs)
+                prompt = self.prompt_model._ensure_token_limit(prompt)
+                prompt_collector.append(prompt)
+                logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s ", prompt, kwargs_copy)
+                output = await self.prompt_model.ainvoke(prompt, **kwargs_copy)
+                results.extend(output)
+        return results
+
+    async def arun(
+        self,
+        query: Optional[str] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[MultiLabel] = None,
+        documents: Optional[List[Document]] = None,
+        meta: Optional[dict] = None,
+        invocation_context: Optional[Dict[str, Any]] = None,
+        prompt_template: Optional[Union[str, PromptTemplate]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Dict, str]:
+        """
+        Drop-in replacement asyncio version of the `run` method, see there for documentation.
+        """
+        prompt_collector: List[str] = []
+
+        invocation_context = self._prepare(
+            query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs
+        )
+        # Let's skip the call to __call__, because all it does is injecting a prompt template
+        # if there isn't any, while we know for sure it'll be in `invocation_context`.
+        results = await self._aprompt(prompt_collector=prompt_collector, **invocation_context)
+
         prompt_template_resolved: PromptTemplate = invocation_context.pop("prompt_template")
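Putting the new pieces together, a minimal usage sketch of the async entry point (the model name, template name, and documents below are illustrative, not taken from this commit):

import asyncio

from haystack import Document
from haystack.nodes.prompt import PromptNode

async def main():
    # Illustrative model/template choice; any supported combination works the same way
    node = PromptNode("google/flan-t5-base", default_prompt_template="deepset/question-answering")
    docs = [Document(content="Paris is the capital of France.")]

    # `arun` mirrors `run`: it returns the results dict and the outgoing edge name,
    # but it can be awaited, so several invocations can be dispatched concurrently.
    outputs = await asyncio.gather(
        node.arun(query="What is the capital of France?", documents=docs),
        node.arun(query="Which country is Paris in?", documents=docs),
    )
    for result, edge in outputs:
        print(edge, result)

asyncio.run(main())

Keep in mind that the concurrency is only real when the underlying invocation layer's `invoke` is itself asynchronous; with a synchronous layer, `PromptModel.ainvoke` falls back to a blocking call.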

File: releasenotes/notes/… (new file)

@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    PromptNode can now be run asynchronously by calling the `arun` method.

File: test/prompt/test_prompt_model.py

@@ -1,4 +1,5 @@
-from unittest.mock import patch, Mock
+import asyncio
+from unittest.mock import patch, MagicMock
 
 import pytest
@@ -36,3 +37,24 @@ def test_construtor_with_custom_model():
 def test_constructor_with_no_supported_model():
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptModel("some-random-model")
+
+
+@pytest.mark.asyncio
+async def test_ainvoke():
+    def async_return(result):
+        f = asyncio.Future()
+        f.set_result(result)
+        return f
+
+    mock_layer = MagicMock()  # `invoke` returns an awaitable Future here, so `await` succeeds
+    mock_layer.return_value.invoke.return_value = async_return("Async Bar!")
+    model = PromptModel(invocation_layer_class=mock_layer)
+    assert await model.ainvoke("Foo") == "Async Bar!"
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_falls_back_to_sync():
+    mock_layer = MagicMock()  # no async-defined methods, await will fail and fall back to regular `invoke`
+    mock_layer.return_value.invoke.return_value = "Bar!"
+    model = PromptModel(invocation_layer_class=mock_layer)
+    assert await model.ainvoke("Foo") == "Bar!"

File: test/prompt/test_prompt_node.py

@@ -1,18 +1,15 @@
 import os
 import logging
 from typing import Optional, Union, List, Dict, Any, Tuple
-from unittest.mock import patch, Mock, MagicMock
+from unittest.mock import patch, Mock, MagicMock, AsyncMock
 
 import pytest
 from prompthub import Prompt
-from transformers import GenerationConfig, TextStreamer
 
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
 from haystack.nodes.prompt.invocation_layer import (
-    HFLocalInvocationLayer,
-    DefaultTokenStreamingHandler,
     AzureChatGPTInvocationLayer,
     AzureOpenAIInvocationLayer,
     OpenAIInvocationLayer,
@@ -1098,3 +1095,83 @@ def test_prompt_node_warns_about_missing_documents(mock_model, caplog):
         "Expected prompt parameter 'documents' to be provided but it is missing. "
         "Continuing with an empty list of documents." in caplog.text
     )
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+def test__prepare_invocation_context_is_empty(mock_model):
+    node = PromptNode()
+    node.get_prompt_template = MagicMock(return_value="Test Template")
+
+    kwargs = {
+        "query": "query",
+        "file_paths": ["foo", "bar"],
+        "labels": ["label", "another"],
+        "documents": ["A", "B"],
+        "meta": {"meta_key": "meta_value"},
+        "prompt_template": "my-test-prompt",
+        "invocation_context": None,
+        "generation_kwargs": {"gen_key": "gen_value"},
+    }
+
+    invocation_context = node._prepare(**kwargs)
+
+    node.get_prompt_template.assert_called_once_with("my-test-prompt")
+    assert invocation_context == {
+        "query": "query",
+        "file_paths": ["foo", "bar"],
+        "labels": ["label", "another"],
+        "documents": ["A", "B"],
+        "meta": {"meta_key": "meta_value"},
+        "prompt_template": "Test Template",
+        "gen_key": "gen_value",
+    }
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+def test__prepare_invocation_context_was_passed(mock_model):
+    node = PromptNode()
+
+    # Test invocation_context is left untouched
+    invocation_context = {
+        "query": "query",
+        "file_paths": ["foo", "bar"],
+        "labels": ["label", "another"],
+        "documents": ["A", "B"],
+        "meta": {"meta_key": "meta_value"},
+        "prompt_template": "my-test-prompt",
+        "invocation_context": None,
+    }
+
+    kwargs = {
+        "query": None,
+        "file_paths": None,
+        "labels": None,
+        "documents": None,
+        "meta": None,
+        "prompt_template": None,
+        "invocation_context": invocation_context,
+        "generation_kwargs": None,
+    }
+
+    assert node._prepare(**kwargs) == invocation_context
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+async def test_arun(mock_model):
+    node = PromptNode()
+    node._aprompt = AsyncMock()
+    await node.arun("a query")
+    node._aprompt.assert_awaited_once_with(prompt_collector=[], query="a query", prompt_template=None)
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+async def test_aprompt(mock_model):
+    node = PromptNode()
+    mock_model.return_value.ainvoke = AsyncMock()
+    await node._aprompt(PromptTemplate("test template"))
+    mock_model.return_value.ainvoke.assert_awaited_once()