fix: hf agent outputs the prompt text while the openai agent does not (#5461)

* add skip prompt

* fix formatting

* add release note

* add release note

* Update releasenotes/notes/add-skip-prompt-for-hf-model-agent-89aef2838edb907c.yaml

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Update haystack/nodes/prompt/invocation_layer/handlers.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/prompt/invocation_layer/handlers.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/prompt/invocation_layer/hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* add a unit test

* add a unit test2

* add skil prompt

* Revert "add skil prompt"

This reverts commit b1ba938c94b67a4fd636d321945990aabd2c5b2a.

* add unit test

---------

Co-authored-by: Daria Fokina <daria.f93@gmail.com>
Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
Fanli Lin 2023-08-02 22:34:33 +08:00 committed by GitHub
parent 73fa796735
commit 8d04f28e11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 1 deletions

View File

@ -47,7 +47,7 @@ class HFTokenStreamingHandler(TextStreamer): # pylint: disable=useless-object-i
stream_handler: "TokenStreamingHandler",
):
transformers_import.check()
super().__init__(tokenizer=tokenizer) # type: ignore
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
self.token_handler = stream_handler
def on_finalized_text(self, token: str, stream_end: bool = False):

View File

@ -0,0 +1,4 @@
---
fixes:
- |
    Fix a bug where the responses of Agents using local Hugging Face models contained the prompt text.

View File

@ -632,3 +632,17 @@ def test_tokenizer_loading_unsupported_model_with_tokenizer_class_in_config(
invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported_model", trust_remote_code=True)
assert not mock_tokenizer.called
assert not caplog.text
@pytest.mark.unit
def test_skip_prompt_is_set_in_hf_text_streamer(mock_pipeline, mock_get_task):
    """
    Test that skip_prompt is set in HFTokenStreamingHandler.

    Otherwise, the streamed output would echo the prompt text back to the user
    (the bug this commit fixes for local HF model agents).
    """
    layer = HFLocalInvocationLayer(stream=True)

    layer.invoke(prompt="Tell me hello")

    # Inspect the kwargs the invocation layer passed to the (mocked) HF pipeline:
    # the streamer must be our handler, created with skip_prompt=True.
    _, kwargs = layer.pipe.call_args
    assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)
    assert kwargs["streamer"].skip_prompt