diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py
index f8a406004..c4c903089 100644
--- a/haystack/nodes/prompt/prompt_model.py
+++ b/haystack/nodes/prompt/prompt_model.py
@@ -111,6 +111,16 @@ class PromptModel(BaseComponent):
         output = self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
         return output
 
+    async def ainvoke(self, prompt: Union[str, List[str], List[Dict[str, str]]], **kwargs) -> List[str]:
+        """
+        Drop-in asyncio replacement for the `invoke` method; see `invoke` for documentation.
+        """
+        try:
+            return await self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
+        except TypeError:
+            # The `invoke` method of the underlying invocation layer doesn't support asyncio
+            return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
+
     @overload
     def _ensure_token_limit(self, prompt: str) -> str:
         ...
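For context on the fallback in `ainvoke` above: when the layer's `invoke` is a plain function, awaiting its non-awaitable return value raises `TypeError`, which is what the `except` branch keys on; by that point the sync `invoke` has already executed once, so the fallback calls it a second time. A minimal self-contained sketch of the same pattern (the two layer classes are hypothetical stand-ins, not real Haystack invocation layers):

```python
import asyncio

class SyncLayer:
    # Stand-in for an invocation layer whose `invoke` is a plain function.
    def invoke(self, prompt, **kwargs):
        return [f"sync: {prompt}"]

class AsyncLayer:
    # Stand-in for an invocation layer whose `invoke` is a coroutine function.
    async def invoke(self, prompt, **kwargs):
        return [f"async: {prompt}"]

async def ainvoke(layer, prompt, **kwargs):
    try:
        # Works when `invoke` returns an awaitable.
        return await layer.invoke(prompt=prompt, **kwargs)
    except TypeError:
        # `await` on a plain list raises TypeError; note that the sync
        # `invoke` above has already run once, so it runs a second time here.
        return layer.invoke(prompt=prompt, **kwargs)

print(asyncio.run(ainvoke(SyncLayer(), "hi")))   # ['sync: hi']
print(asyncio.run(ainvoke(AsyncLayer(), "hi")))  # ['async: hi']
```

An `inspect.isawaitable` check on the result would avoid the second call, but the try/except shape is what the tests further down exercise.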
diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index 060bec69c..2360d5b44 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -232,6 +232,37 @@ class PromptNode(BaseComponent):
             return list(template.prompt_params)
         return []
 
+    def _prepare(  # type: ignore
+        self, query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs
+    ) -> Dict:
+        """
+        Prepare the invocation context for a prompt invocation.
+        """
+        invocation_context = invocation_context or {}
+
+        if query and "query" not in invocation_context:
+            invocation_context["query"] = query
+
+        if file_paths and "file_paths" not in invocation_context:
+            invocation_context["file_paths"] = file_paths
+
+        if labels and "labels" not in invocation_context:
+            invocation_context["labels"] = labels
+
+        if documents and "documents" not in invocation_context:
+            invocation_context["documents"] = documents
+
+        if meta and "meta" not in invocation_context:
+            invocation_context["meta"] = meta
+
+        if "prompt_template" not in invocation_context:
+            invocation_context["prompt_template"] = self.get_prompt_template(prompt_template)
+
+        if generation_kwargs:
+            invocation_context.update(generation_kwargs)
+
+        return invocation_context
+
     def run(
         self,
         query: Optional[str] = None,
@@ -272,29 +303,86 @@
         # so that they can be returned by `run()` as part of the pipeline's debug output.
         prompt_collector: List[str] = []
 
-        invocation_context = invocation_context or {}
-        if query and "query" not in invocation_context.keys():
-            invocation_context["query"] = query
+        invocation_context = self._prepare(
+            query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs
+        )
 
-        if file_paths and "file_paths" not in invocation_context.keys():
-            invocation_context["file_paths"] = file_paths
+        results = self(**invocation_context, prompt_collector=prompt_collector)
 
-        if labels and "labels" not in invocation_context.keys():
-            invocation_context["labels"] = labels
+        prompt_template_resolved: PromptTemplate = invocation_context.pop("prompt_template")
 
-        if documents and "documents" not in invocation_context.keys():
-            invocation_context["documents"] = documents
+        try:
+            output_variable = self.output_variable or prompt_template_resolved.output_variable or "results"
+        except AttributeError:
+            output_variable = "results"
 
-        if meta and "meta" not in invocation_context.keys():
-            invocation_context["meta"] = meta
+        invocation_context[output_variable] = results
+        invocation_context["prompts"] = prompt_collector
+        final_result: Dict[str, Any] = {output_variable: results, "invocation_context": invocation_context}
 
-        if "prompt_template" not in invocation_context.keys():
-            invocation_context["prompt_template"] = self.get_prompt_template(prompt_template)
+        if self.debug:
+            final_result["_debug"] = {"prompts_used": prompt_collector}
 
-        if generation_kwargs:
-            invocation_context.update(generation_kwargs)
+        return final_result, "output_1"
 
-        results = self(prompt_collector=prompt_collector, **invocation_context)
+    async def _aprompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs):
+        """
+        Async version of the actual prompt invocation.
+        """
+        results = []
+        # we pop the prompt_collector kwarg to avoid passing it to the model
+        prompt_collector: List[Union[str, List[Dict[str, str]]]] = kwargs.pop("prompt_collector", [])
+
+        # kwargs override model kwargs
+        kwargs = {**self._prepare_model_kwargs(), **kwargs}
+        template_to_fill = self.get_prompt_template(prompt_template)
+        if template_to_fill:
+            # prompt template used, yield prompts from input args
+            for prompt in template_to_fill.fill(*args, **kwargs):
+                kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs))
+                # and pass the prepared prompt and kwargs copy to the model
+                prompt = self.prompt_model._ensure_token_limit(prompt)
+                prompt_collector.append(prompt)
+                logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy)
+                output = await self.prompt_model.ainvoke(prompt, **kwargs_copy)
+                results.extend(output)
+
+            kwargs["prompts"] = prompt_collector
+            results = template_to_fill.post_process(results, **kwargs)
+        else:
+            # straightforward prompt, no templates used
+            for prompt in list(args):
+                kwargs_copy = copy.copy(kwargs)
+                prompt = self.prompt_model._ensure_token_limit(prompt)
+                prompt_collector.append(prompt)
+                logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy)
+                output = await self.prompt_model.ainvoke(prompt, **kwargs_copy)
+                results.extend(output)
+        return results
+
+    async def arun(
+        self,
+        query: Optional[str] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[MultiLabel] = None,
+        documents: Optional[List[Document]] = None,
+        meta: Optional[dict] = None,
+        invocation_context: Optional[Dict[str, Any]] = None,
+        prompt_template: Optional[Union[str, PromptTemplate]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Dict, str]:
+        """
+        Drop-in asyncio replacement for the `run` method; see `run` for documentation.
+        """
+        prompt_collector: List[str] = []
+
+        invocation_context = self._prepare(
+            query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs
+        )
+
+        # We skip the call to __call__ because all it does is inject a prompt template
+        # if there isn't one, and we know for sure it's already in `invocation_context`.
+        results = await self._aprompt(prompt_collector=prompt_collector, **invocation_context)
 
         prompt_template_resolved: PromptTemplate = invocation_context.pop("prompt_template")
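The merge semantics of `_prepare` are worth spelling out: explicit arguments only fill keys that the caller's `invocation_context` doesn't already contain, while `generation_kwargs` are flattened into the context and can therefore shadow anything. A standalone sketch of just those rules, not the Haystack method itself:

```python
def prepare(query=None, documents=None, invocation_context=None, generation_kwargs=None):
    # Trimmed to two of the explicit arguments; the real method handles
    # file_paths, labels, meta, and prompt_template the same way.
    ctx = invocation_context or {}
    if query and "query" not in ctx:
        ctx["query"] = query
    if documents and "documents" not in ctx:
        ctx["documents"] = documents
    if generation_kwargs:
        ctx.update(generation_kwargs)  # flattened in, may shadow existing keys
    return ctx

# A key set by the caller wins over the explicit argument:
print(prepare(query="ignored", invocation_context={"query": "kept"}))
# {'query': 'kept'}
print(prepare(query="used", generation_kwargs={"top_k": 3}))
# {'query': 'used', 'top_k': 3}
```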
+ """ + results = [] + # we pop the prompt_collector kwarg to avoid passing it to the model + prompt_collector: List[Union[str, List[Dict[str, str]]]] = kwargs.pop("prompt_collector", []) + + # kwargs override model kwargs + kwargs = {**self._prepare_model_kwargs(), **kwargs} + template_to_fill = self.get_prompt_template(prompt_template) + if template_to_fill: + # prompt template used, yield prompts from inputs args + for prompt in template_to_fill.fill(*args, **kwargs): + kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs)) + # and pass the prepared prompt and kwargs copy to the model + prompt = self.prompt_model._ensure_token_limit(prompt) + prompt_collector.append(prompt) + logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy) + output = await self.prompt_model.ainvoke(prompt, **kwargs_copy) + results.extend(output) + + kwargs["prompts"] = prompt_collector + results = template_to_fill.post_process(results, **kwargs) + else: + # straightforward prompt, no templates used + for prompt in list(args): + kwargs_copy = copy.copy(kwargs) + prompt = self.prompt_model._ensure_token_limit(prompt) + prompt_collector.append(prompt) + logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s ", prompt, kwargs_copy) + output = await self.prompt_model.ainvoke(prompt, **kwargs_copy) + results.extend(output) + return results + + async def arun( + self, + query: Optional[str] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[MultiLabel] = None, + documents: Optional[List[Document]] = None, + meta: Optional[dict] = None, + invocation_context: Optional[Dict[str, Any]] = None, + prompt_template: Optional[Union[str, PromptTemplate]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Dict, str]: + """ + Drop-in replacement asyncio version of the `run` method, see there for documentation. + """ + prompt_collector: List[str] = [] + + invocation_context = self._prepare( + query, file_paths, labels, documents, meta, invocation_context, prompt_template, generation_kwargs + ) + + # Let's skip the call to __call__, because all it does is injecting a prompt template + # if there isn't any, while we know for sure it'll be in `invocation_context`. + results = await self._aprompt(prompt_collector=prompt_collector, **invocation_context) prompt_template_resolved: PromptTemplate = invocation_context.pop("prompt_template") diff --git a/releasenotes/notes/add-promptnode-arun-bc4c2bcc9c653015.yaml b/releasenotes/notes/add-promptnode-arun-bc4c2bcc9c653015.yaml new file mode 100644 index 000000000..e9e3305ea --- /dev/null +++ b/releasenotes/notes/add-promptnode-arun-bc4c2bcc9c653015.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + PromptNode can now be run asynchronously by calling the `arun` method. 
diff --git a/test/prompt/test_prompt_model.py b/test/prompt/test_prompt_model.py
index 1ac1702a1..9e7c2b0f5 100644
--- a/test/prompt/test_prompt_model.py
+++ b/test/prompt/test_prompt_model.py
@@ -1,4 +1,5 @@
-from unittest.mock import patch, Mock
+import asyncio
+from unittest.mock import patch, MagicMock
 
 import pytest
 
@@ -36,3 +37,24 @@ def test_construtor_with_custom_model():
 def test_constructor_with_no_supported_model():
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptModel("some-random-model")
+
+
+@pytest.mark.asyncio
+async def test_ainvoke():
+    def async_return(result):
+        f = asyncio.Future()
+        f.set_result(result)
+        return f
+
+    mock_layer = MagicMock()  # `invoke` returns an awaitable Future here, so `await` succeeds
+    mock_layer.return_value.invoke.return_value = async_return("Async Bar!")
+    model = PromptModel(invocation_layer_class=mock_layer)
+    assert await model.ainvoke("Foo") == "Async Bar!"
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_falls_back_to_sync():
+    mock_layer = MagicMock()  # no async-defined methods, await will fail and fall back to regular `invoke`
+    mock_layer.return_value.invoke.return_value = "Bar!"
+    model = PromptModel(invocation_layer_class=mock_layer)
+    assert await model.ainvoke("Foo") == "Bar!"
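The `async_return` helper above makes a plain `MagicMock` awaitable by handing back an already-resolved `Future`; the prompt-node tests below use `AsyncMock` instead, which wraps the return value in a coroutine automatically. A minimal illustration of the two setups (standalone, not tied to Haystack):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

async def demo():
    # Future-based: a sync mock whose return value happens to be awaitable.
    layer = MagicMock()
    future = asyncio.Future()
    future.set_result(["Bar!"])
    layer.invoke.return_value = future
    assert await layer.invoke(prompt="Foo") == ["Bar!"]

    # AsyncMock-based: the mocked method itself is a coroutine function.
    layer.ainvoke = AsyncMock(return_value=["Bar!"])
    assert await layer.ainvoke(prompt="Foo") == ["Bar!"]

asyncio.run(demo())
```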
diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py
index fe9d034f1..c0ae8480e 100644
--- a/test/prompt/test_prompt_node.py
+++ b/test/prompt/test_prompt_node.py
@@ -1,18 +1,15 @@
 import os
 import logging
 from typing import Optional, Union, List, Dict, Any, Tuple
-from unittest.mock import patch, Mock, MagicMock
+from unittest.mock import patch, Mock, MagicMock, AsyncMock
 
 import pytest
 from prompthub import Prompt
-from transformers import GenerationConfig, TextStreamer
 
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
 from haystack.nodes.prompt.invocation_layer import (
-    HFLocalInvocationLayer,
-    DefaultTokenStreamingHandler,
     AzureChatGPTInvocationLayer,
     AzureOpenAIInvocationLayer,
     OpenAIInvocationLayer,
@@ -1098,3 +1095,83 @@ def test_prompt_node_warns_about_missing_documents(mock_model, caplog):
         "Expected prompt parameter 'documents' to be provided but it is missing. "
         "Continuing with an empty list of documents."
         in caplog.text
     )
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+def test__prepare_invocation_context_is_empty(mock_model):
+    node = PromptNode()
+    node.get_prompt_template = MagicMock(return_value="Test Template")
+
+    kwargs = {
+        "query": "query",
+        "file_paths": ["foo", "bar"],
+        "labels": ["label", "another"],
+        "documents": ["A", "B"],
+        "meta": {"meta_key": "meta_value"},
+        "prompt_template": "my-test-prompt",
+        "invocation_context": None,
+        "generation_kwargs": {"gen_key": "gen_value"},
+    }
+
+    invocation_context = node._prepare(**kwargs)
+
+    node.get_prompt_template.assert_called_once_with("my-test-prompt")
+    assert invocation_context == {
+        "query": "query",
+        "file_paths": ["foo", "bar"],
+        "labels": ["label", "another"],
+        "documents": ["A", "B"],
+        "meta": {"meta_key": "meta_value"},
+        "prompt_template": "Test Template",
+        "gen_key": "gen_value",
+    }
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+def test__prepare_invocation_context_was_passed(mock_model):
+    node = PromptNode()
+
+    # Test invocation_context is left untouched
+    invocation_context = {
+        "query": "query",
+        "file_paths": ["foo", "bar"],
+        "labels": ["label", "another"],
+        "documents": ["A", "B"],
+        "meta": {"meta_key": "meta_value"},
+        "prompt_template": "my-test-prompt",
+        "invocation_context": None,
+    }
+    kwargs = {
+        "query": None,
+        "file_paths": None,
+        "labels": None,
+        "documents": None,
+        "meta": None,
+        "prompt_template": None,
+        "invocation_context": invocation_context,
+        "generation_kwargs": None,
+    }
+
+    assert node._prepare(**kwargs) == invocation_context
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+async def test_arun(mock_model):
+    node = PromptNode()
+    node._aprompt = AsyncMock()
+    await node.arun("a query")
+    node._aprompt.assert_awaited_once_with(prompt_collector=[], query="a query", prompt_template=None)
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+@patch("haystack.nodes.prompt.prompt_node.PromptModel")
+async def test_aprompt(mock_model):
+    node = PromptNode()
+    mock_model.return_value.ainvoke = AsyncMock()
+    await node._aprompt(PromptTemplate("test template"))
+    mock_model.return_value.ainvoke.assert_awaited_once()
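Beyond a drop-in `await`, the practical payoff of `arun` is concurrency: several invocations can be in flight at once via `asyncio.gather`. A sketch reusing the placeholder node from the earlier example; note that calls only truly overlap when the underlying invocation layer supports async, since `ainvoke` otherwise falls back to blocking calls:

```python
import asyncio
from haystack.nodes.prompt import PromptNode, PromptTemplate

async def main():
    node = PromptNode("gpt-3.5-turbo", api_key="YOUR_API_KEY")  # placeholders
    template = PromptTemplate("Answer the question: {query}")
    queries = ["What is 2 + 2?", "Name a prime number.", "What is the capital of Peru?"]
    # Each arun call is a coroutine, so gather can schedule them together.
    outputs = await asyncio.gather(
        *(node.arun(query=q, prompt_template=template) for q in queries)
    )
    for q, (result, _) in zip(queries, outputs):
        print(q, "->", result["results"])

asyncio.run(main())
```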