refactor: default for max_new_tokens to 512 in Hugging Face generators (#7370)

* set default for max_new_tokens to 512 in Hugging Face generators

* add release notes

* fix tests

* remove issues from release note

---------

Co-authored-by: christopherkeibel <christopher.keibel@karakun.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
Christopher Keibel 2024-03-19 08:47:53 +01:00 committed by GitHub
parent 280719339c
commit f69c3e5cd2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 39 additions and 26 deletions

View File

@ -123,6 +123,7 @@ class HuggingFaceTGIChatGenerator:
check_generation_params(generation_kwargs, ["n"]) check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or []) generation_kwargs["stop_sequences"].extend(stop_words or [])
generation_kwargs.setdefault("max_new_tokens", 512)
self.model = model self.model = model
self.url = url self.url = url

View File

@ -106,6 +106,7 @@ class HuggingFaceLocalGenerator:
"Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. " "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
"Please specify only one of them." "Please specify only one of them."
) )
generation_kwargs.setdefault("max_new_tokens", 512)
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.generation_kwargs = generation_kwargs self.generation_kwargs = generation_kwargs

View File

@ -111,6 +111,7 @@ class HuggingFaceTGIGenerator:
check_generation_params(generation_kwargs, ["n"]) check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or []) generation_kwargs["stop_sequences"].extend(stop_words or [])
generation_kwargs.setdefault("max_new_tokens", 512)
self.model = model self.model = model
self.url = url self.url = url

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Set max_new_tokens default to 512 in Hugging Face generators.

View File

@ -1,14 +1,12 @@
from unittest.mock import patch, MagicMock, Mock from unittest.mock import MagicMock, Mock, patch
from haystack.utils.auth import Secret
import pytest import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from haystack.components.generators.chat import HuggingFaceTGIChatGenerator from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import StreamingChunk, ChatMessage from haystack.utils.auth import Secret
@pytest.fixture @pytest.fixture
@ -70,7 +68,11 @@ class TestHuggingFaceTGIChatGenerator:
) )
generator.warm_up() generator.warm_up()
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} assert generator.generation_kwargs == {
**generation_kwargs,
**{"stop_sequences": ["stop"]},
**{"max_new_tokens": 512},
}
assert generator.tokenizer is not None assert generator.tokenizer is not None
assert generator.client is not None assert generator.client is not None
assert generator.streaming_callback == streaming_callback assert generator.streaming_callback == streaming_callback
@ -92,7 +94,7 @@ class TestHuggingFaceTGIChatGenerator:
# Assert that the init_params dictionary contains the expected keys and values # Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf" assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]} assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
def test_from_dict(self, mock_check_valid_model): def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIChatGenerator( generator = HuggingFaceTGIChatGenerator(
@ -106,7 +108,7 @@ class TestHuggingFaceTGIChatGenerator:
generator_2 = HuggingFaceTGIChatGenerator.from_dict(result) generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf" assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]} assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler assert generator_2.streaming_callback is streaming_callback_handler
def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models): def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
@ -205,7 +207,7 @@ class TestHuggingFaceTGIChatGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
# note how n because it is not text generation parameter was not passed to text_generation # note how n because it is not text generation parameter was not passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]} assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
assert isinstance(response, dict) assert isinstance(response, dict)
assert "replies" in response assert "replies" in response
@ -240,7 +242,7 @@ class TestHuggingFaceTGIChatGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]} assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
# note how n caused n replies to be generated # note how n caused n replies to be generated
assert isinstance(response, dict) assert isinstance(response, dict)
@ -268,7 +270,7 @@ class TestHuggingFaceTGIChatGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
# we translate stop_words to stop_sequences # we translate stop_words to stop_sequences
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]} assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
# Assert that the response contains the generated replies # Assert that the response contains the generated replies
assert "replies" in response assert "replies" in response
@ -343,7 +345,7 @@ class TestHuggingFaceTGIChatGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": [], "stream": True} assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}
# Assert that the streaming callback was called twice # Assert that the streaming callback was called twice
assert streaming_call_count == 2 assert streaming_call_count == 2

View File

@ -4,10 +4,10 @@ from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
from transformers import PreTrainedTokenizerFast from transformers import PreTrainedTokenizerFast
from haystack.utils.auth import Secret
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
from haystack.utils import ComponentDevice from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret
class TestHuggingFaceLocalGenerator: class TestHuggingFaceLocalGenerator:
@ -23,7 +23,7 @@ class TestHuggingFaceLocalGenerator:
"token": None, "token": None,
"device": ComponentDevice.resolve_device(None).to_hf(), "device": ComponentDevice.resolve_device(None).to_hf(),
} }
assert generator.generation_kwargs == {} assert generator.generation_kwargs == {"max_new_tokens": 512}
assert generator.pipeline is None assert generator.pipeline is None
def test_init_custom_token(self): def test_init_custom_token(self):
@ -124,7 +124,7 @@ class TestHuggingFaceLocalGenerator:
""" """
generator = HuggingFaceLocalGenerator(task="text-generation") generator = HuggingFaceLocalGenerator(task="text-generation")
assert generator.generation_kwargs == {"return_full_text": False} assert generator.generation_kwargs == {"max_new_tokens": 512, "return_full_text": False}
def test_init_fails_with_both_stopwords_and_stoppingcriteria(self): def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
with pytest.raises( with pytest.raises(
@ -153,7 +153,7 @@ class TestHuggingFaceLocalGenerator:
"task": "text2text-generation", "task": "text2text-generation",
"device": ComponentDevice.resolve_device(None).to_hf(), "device": ComponentDevice.resolve_device(None).to_hf(),
}, },
"generation_kwargs": {}, "generation_kwargs": {"max_new_tokens": 512},
"stop_words": None, "stop_words": None,
}, },
} }

View File

@ -1,7 +1,7 @@
from unittest.mock import patch, MagicMock, Mock from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from haystack.components.generators import HuggingFaceTGIGenerator from haystack.components.generators import HuggingFaceTGIGenerator
@ -63,7 +63,11 @@ class TestHuggingFaceTGIGenerator:
) )
assert generator.model == model assert generator.model == model
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} assert generator.generation_kwargs == {
**generation_kwargs,
**{"stop_sequences": ["stop"]},
**{"max_new_tokens": 512},
}
assert generator.tokenizer is None assert generator.tokenizer is None
assert generator.client is not None assert generator.client is not None
assert generator.streaming_callback == streaming_callback assert generator.streaming_callback == streaming_callback
@ -84,7 +88,7 @@ class TestHuggingFaceTGIGenerator:
# Assert that the init_params dictionary contains the expected keys and values # Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "mistralai/Mistral-7B-v0.1" assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]} assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
def test_from_dict(self, mock_check_valid_model): def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIGenerator( generator = HuggingFaceTGIGenerator(
@ -99,7 +103,7 @@ class TestHuggingFaceTGIGenerator:
# now deserialize, call from_dict # now deserialize, call from_dict
generator_2 = HuggingFaceTGIGenerator.from_dict(result) generator_2 = HuggingFaceTGIGenerator.from_dict(result)
assert generator_2.model == "mistralai/Mistral-7B-v0.1" assert generator_2.model == "mistralai/Mistral-7B-v0.1"
assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]} assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler assert generator_2.streaming_callback is streaming_callback_handler
def test_initialize_with_invalid_url(self, mock_check_valid_model): def test_initialize_with_invalid_url(self, mock_check_valid_model):
@ -135,7 +139,7 @@ class TestHuggingFaceTGIGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
# note how n was not passed to text_generation # note how n was not passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]} assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
assert isinstance(response, dict) assert isinstance(response, dict)
assert "replies" in response assert "replies" in response
@ -168,7 +172,7 @@ class TestHuggingFaceTGIGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
# note how n was not passed to text_generation # note how n was not passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]} assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
assert isinstance(response, dict) assert isinstance(response, dict)
assert "replies" in response assert "replies" in response
@ -208,7 +212,7 @@ class TestHuggingFaceTGIGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]} assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
# Assert that the response contains the generated replies # Assert that the response contains the generated replies
assert "replies" in response assert "replies" in response
@ -283,7 +287,7 @@ class TestHuggingFaceTGIGenerator:
# check kwargs passed to text_generation # check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args _, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": [], "stream": True} assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}
# Assert that the streaming callback was called twice # Assert that the streaming callback was called twice
assert streaming_call_count == 2 assert streaming_call_count == 2