diff --git a/haystack/nodes/prompt/invocation_layer/handlers.py b/haystack/nodes/prompt/invocation_layer/handlers.py index e764100d7..446ddedf0 100644 --- a/haystack/nodes/prompt/invocation_layer/handlers.py +++ b/haystack/nodes/prompt/invocation_layer/handlers.py @@ -47,7 +47,7 @@ class HFTokenStreamingHandler(TextStreamer): # pylint: disable=useless-object-i stream_handler: "TokenStreamingHandler", ): transformers_import.check() - super().__init__(tokenizer=tokenizer) # type: ignore + super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore self.token_handler = stream_handler def on_finalized_text(self, token: str, stream_end: bool = False): diff --git a/releasenotes/notes/add-skip-prompt-for-hf-model-agent-89aef2838edb907c.yaml b/releasenotes/notes/add-skip-prompt-for-hf-model-agent-89aef2838edb907c.yaml new file mode 100644 index 000000000..51760da92 --- /dev/null +++ b/releasenotes/notes/add-skip-prompt-for-hf-model-agent-89aef2838edb907c.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fix the bug that the responses of Agents using local HF models contain the prompt text. diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py index e2e721a7c..7c3c43a10 100644 --- a/test/prompt/invocation_layer/test_hugging_face.py +++ b/test/prompt/invocation_layer/test_hugging_face.py @@ -632,3 +632,17 @@ def test_tokenizer_loading_unsupported_model_with_tokenizer_class_in_config( invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported_model", trust_remote_code=True) assert not mock_tokenizer.called assert not caplog.text + + +@pytest.mark.unit +def test_skip_prompt_is_set_in_hf_text_streamer(mock_pipeline, mock_get_task): + """ + Test that skip_prompt is set in HFTextStreamingHandler. Otherwise, we will output prompt text. + """ + layer = HFLocalInvocationLayer(stream=True) + + layer.invoke(prompt="Tell me hello") + + _, kwargs = layer.pipe.call_args + assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler) + assert kwargs["streamer"].skip_prompt