From f69c3e5cd26046b826927a39cad02af93b2ccbbf Mon Sep 17 00:00:00 2001
From: Christopher Keibel <55911084+CKeibel@users.noreply.github.com>
Date: Tue, 19 Mar 2024 08:47:53 +0100
Subject: [PATCH] refactor: default for max_new_tokens to 512 in Hugging Face
 generators (#7370)

* set default for max_new_tokens to 512 in Hugging Face generators

* add release notes

* fix tests

* remove issues from release note

---------

Co-authored-by: christopherkeibel
Co-authored-by: Julian Risch
---
 .../generators/chat/hugging_face_tgi.py       |  1 +
 .../generators/hugging_face_local.py          |  1 +
 .../components/generators/hugging_face_tgi.py |  1 +
 ...ggingface-generators-76d9aba116b65e70.yaml |  4 +++
 .../generators/chat/test_hugging_face_tgi.py  | 28 ++++++++++---------
 .../test_hugging_face_local_generator.py      |  8 +++---
 .../generators/test_hugging_face_tgi.py       | 22 +++++++++------
 7 files changed, 39 insertions(+), 26 deletions(-)
 create mode 100644 releasenotes/notes/adjust-max-new-tokens-to-512-in-huggingface-generators-76d9aba116b65e70.yaml

diff --git a/haystack/components/generators/chat/hugging_face_tgi.py b/haystack/components/generators/chat/hugging_face_tgi.py
index 4fbed43eb..3e388a008 100644
--- a/haystack/components/generators/chat/hugging_face_tgi.py
+++ b/haystack/components/generators/chat/hugging_face_tgi.py
@@ -123,6 +123,7 @@ class HuggingFaceTGIChatGenerator:
         check_generation_params(generation_kwargs, ["n"])
         generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
         generation_kwargs["stop_sequences"].extend(stop_words or [])
+        generation_kwargs.setdefault("max_new_tokens", 512)

         self.model = model
         self.url = url
diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py
index 90d0928cd..cbf9eabe2 100644
--- a/haystack/components/generators/hugging_face_local.py
+++ b/haystack/components/generators/hugging_face_local.py
@@ -106,6 +106,7 @@ class HuggingFaceLocalGenerator:
                 "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                 "Please specify only one of them."
             )
+        generation_kwargs.setdefault("max_new_tokens", 512)

         self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
         self.generation_kwargs = generation_kwargs
diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py
index be66f0d68..06b065be4 100644
--- a/haystack/components/generators/hugging_face_tgi.py
+++ b/haystack/components/generators/hugging_face_tgi.py
@@ -111,6 +111,7 @@ class HuggingFaceTGIGenerator:
         check_generation_params(generation_kwargs, ["n"])
         generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
         generation_kwargs["stop_sequences"].extend(stop_words or [])
+        generation_kwargs.setdefault("max_new_tokens", 512)

         self.model = model
         self.url = url
diff --git a/releasenotes/notes/adjust-max-new-tokens-to-512-in-huggingface-generators-76d9aba116b65e70.yaml b/releasenotes/notes/adjust-max-new-tokens-to-512-in-huggingface-generators-76d9aba116b65e70.yaml
new file mode 100644
index 000000000..1103cfb61
--- /dev/null
+++ b/releasenotes/notes/adjust-max-new-tokens-to-512-in-huggingface-generators-76d9aba116b65e70.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Set max_new_tokens default to 512 in Hugging Face generators.
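All three generator changes above use the same dict.setdefault pattern: the
512-token default is written only when the caller has not supplied
max_new_tokens themselves, so an explicit value always wins. A minimal
standalone sketch of that behavior (plain Python, not Haystack code):

    # Key absent: setdefault writes the default into the dict.
    generation_kwargs = {"temperature": 0.7}
    generation_kwargs.setdefault("max_new_tokens", 512)
    assert generation_kwargs == {"temperature": 0.7, "max_new_tokens": 512}

    # Key present: setdefault leaves the caller's value untouched.
    user_kwargs = {"max_new_tokens": 100}
    user_kwargs.setdefault("max_new_tokens", 512)
    assert user_kwargs == {"max_new_tokens": 100}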
diff --git a/test/components/generators/chat/test_hugging_face_tgi.py b/test/components/generators/chat/test_hugging_face_tgi.py
index 74e04a6da..75abeee2d 100644
--- a/test/components/generators/chat/test_hugging_face_tgi.py
+++ b/test/components/generators/chat/test_hugging_face_tgi.py
@@ -1,14 +1,12 @@
-from unittest.mock import patch, MagicMock, Mock
-
-from haystack.utils.auth import Secret
+from unittest.mock import MagicMock, Mock, patch

 import pytest
-from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
 from huggingface_hub.utils import RepositoryNotFoundError

 from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
-
-from haystack.dataclasses import StreamingChunk, ChatMessage
+from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.utils.auth import Secret


 @pytest.fixture
@@ -70,7 +68,11 @@ class TestHuggingFaceTGIChatGenerator:
         )
         generator.warm_up()

-        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+        assert generator.generation_kwargs == {
+            **generation_kwargs,
+            **{"stop_sequences": ["stop"]},
+            **{"max_new_tokens": 512},
+        }
         assert generator.tokenizer is not None
         assert generator.client is not None
         assert generator.streaming_callback == streaming_callback
@@ -92,7 +94,7 @@ class TestHuggingFaceTGIChatGenerator:
         # Assert that the init_params dictionary contains the expected keys and values
         assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

     def test_from_dict(self, mock_check_valid_model):
         generator = HuggingFaceTGIChatGenerator(
@@ -106,7 +108,7 @@ class TestHuggingFaceTGIChatGenerator:

         generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
         assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
-        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
         assert generator_2.streaming_callback is streaming_callback_handler

     def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
@@ -205,7 +207,7 @@ class TestHuggingFaceTGIChatGenerator:
         # check kwargs passed to text_generation
         # note that n was not passed to text_generation because it is not a text generation parameter
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

         assert isinstance(response, dict)
         assert "replies" in response
@@ -240,7 +242,7 @@ class TestHuggingFaceTGIChatGenerator:

         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

         # note how n caused n replies to be generated
         assert isinstance(response, dict)
@@ -268,7 +270,7 @@ class TestHuggingFaceTGIChatGenerator:
         # check kwargs passed to text_generation
         # we translate stop_words to stop_sequences
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

         # Assert that the response contains the generated replies
         assert "replies" in response
@@ -343,7 +345,7 @@ class TestHuggingFaceTGIChatGenerator:

         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+        assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}

         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py
index efd77934a..5f4450cee 100644
--- a/test/components/generators/test_hugging_face_local_generator.py
+++ b/test/components/generators/test_hugging_face_local_generator.py
@@ -4,10 +4,10 @@ from unittest.mock import Mock, patch
 import pytest
 import torch
 from transformers import PreTrainedTokenizerFast
-from haystack.utils.auth import Secret

 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
 from haystack.utils import ComponentDevice
+from haystack.utils.auth import Secret


 class TestHuggingFaceLocalGenerator:
@@ -23,7 +23,7 @@ class TestHuggingFaceLocalGenerator:
             "token": None,
             "device": ComponentDevice.resolve_device(None).to_hf(),
         }
-        assert generator.generation_kwargs == {}
+        assert generator.generation_kwargs == {"max_new_tokens": 512}
         assert generator.pipeline is None

     def test_init_custom_token(self):
@@ -124,7 +124,7 @@ class TestHuggingFaceLocalGenerator:
         """
         generator = HuggingFaceLocalGenerator(task="text-generation")

-        assert generator.generation_kwargs == {"return_full_text": False}
+        assert generator.generation_kwargs == {"max_new_tokens": 512, "return_full_text": False}

     def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
         with pytest.raises(
@@ -153,7 +153,7 @@ class TestHuggingFaceLocalGenerator:
                 "task": "text2text-generation",
                 "device": ComponentDevice.resolve_device(None).to_hf(),
             },
-            "generation_kwargs": {},
+            "generation_kwargs": {"max_new_tokens": 512},
             "stop_words": None,
         },
     }
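Because the 512-token default is injected in __init__, it becomes part of each
component's serialized state, which is why the to_dict/from_dict assertions in
these tests now expect "max_new_tokens": 512. A round-trip sketch under that
assumption (the task value mirrors the test above; requires a local Haystack
install with torch):

    from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator

    generator = HuggingFaceLocalGenerator(task="text2text-generation")
    data = generator.to_dict()

    # The injected default is captured at serialization time ...
    assert data["init_parameters"]["generation_kwargs"] == {"max_new_tokens": 512}

    # ... and survives deserialization.
    restored = HuggingFaceLocalGenerator.from_dict(data)
    assert restored.generation_kwargs == {"max_new_tokens": 512}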
diff --git a/test/components/generators/test_hugging_face_tgi.py b/test/components/generators/test_hugging_face_tgi.py
index 752dcd439..19588e938 100644
--- a/test/components/generators/test_hugging_face_tgi.py
+++ b/test/components/generators/test_hugging_face_tgi.py
@@ -1,7 +1,7 @@
-from unittest.mock import patch, MagicMock, Mock
+from unittest.mock import MagicMock, Mock, patch

 import pytest
-from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
 from huggingface_hub.utils import RepositoryNotFoundError

 from haystack.components.generators import HuggingFaceTGIGenerator
@@ -63,7 +63,11 @@ class TestHuggingFaceTGIGenerator:
         )

         assert generator.model == model
-        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+        assert generator.generation_kwargs == {
+            **generation_kwargs,
+            **{"stop_sequences": ["stop"]},
+            **{"max_new_tokens": 512},
+        }
         assert generator.tokenizer is None
         assert generator.client is not None
         assert generator.streaming_callback == streaming_callback
@@ -84,7 +88,7 @@ class TestHuggingFaceTGIGenerator:
         # Assert that the init_params dictionary contains the expected keys and values
         assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

     def test_from_dict(self, mock_check_valid_model):
         generator = HuggingFaceTGIGenerator(
@@ -99,7 +103,7 @@ class TestHuggingFaceTGIGenerator:
         # now deserialize, call from_dict
         generator_2 = HuggingFaceTGIGenerator.from_dict(result)
         assert generator_2.model == "mistralai/Mistral-7B-v0.1"
-        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
         assert generator_2.streaming_callback is streaming_callback_handler

     def test_initialize_with_invalid_url(self, mock_check_valid_model):
@@ -135,7 +139,7 @@ class TestHuggingFaceTGIGenerator:
         # check kwargs passed to text_generation
         # note how n was not passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

         assert isinstance(response, dict)
         assert "replies" in response
@@ -168,7 +172,7 @@ class TestHuggingFaceTGIGenerator:
         # check kwargs passed to text_generation
         # note how n was not passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

         assert isinstance(response, dict)
         assert "replies" in response
@@ -208,7 +212,7 @@ class TestHuggingFaceTGIGenerator:

         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

         # Assert that the response contains the generated replies
         assert "replies" in response
@@ -283,7 +287,7 @@ class TestHuggingFaceTGIGenerator:

         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+        assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}

         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
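For callers, the effect is that every Hugging Face generator now requests up
to 512 new tokens unless max_new_tokens is set explicitly. A usage sketch
(model name taken from the tests above; instantiation needs network access so
the model id can be validated against the Hugging Face Hub):

    from haystack.components.generators import HuggingFaceTGIGenerator

    # Picks up the new default of max_new_tokens=512.
    generator = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1")
    assert generator.generation_kwargs["max_new_tokens"] == 512

    # An explicit value is never overwritten by setdefault.
    short = HuggingFaceTGIGenerator(
        model="mistralai/Mistral-7B-v0.1",
        generation_kwargs={"max_new_tokens": 128},
    )
    assert short.generation_kwargs["max_new_tokens"] == 128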