Mirror of https://github.com/deepset-ai/haystack.git
refactor: default for max_new_tokens to 512 in Hugging Face generators (#7370)
* set default for max_new_tokens to 512 in Hugging Face generators
* add release notes
* fix tests
* remove issues from release note

Co-authored-by: christopherkeibel <christopher.keibel@karakun.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Parent: 280719339c
Commit: f69c3e5cd2
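The whole change hinges on one `dict.setdefault` call added to each generator's `__init__`. A minimal standalone sketch of that semantics (plain Python, no Haystack imports; the kwarg values are illustrative only):

```python
# setdefault only fills in a key that the caller did not provide.
generation_kwargs = {"temperature": 0.7}           # user kwargs without a limit
generation_kwargs.setdefault("max_new_tokens", 512)
assert generation_kwargs["max_new_tokens"] == 512  # default applied

explicit = {"max_new_tokens": 128}                 # user set a limit explicitly
explicit.setdefault("max_new_tokens", 512)
assert explicit["max_new_tokens"] == 128           # user value is preserved
```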
@@ -123,6 +123,7 @@ class HuggingFaceTGIChatGenerator:
         check_generation_params(generation_kwargs, ["n"])
         generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
         generation_kwargs["stop_sequences"].extend(stop_words or [])
+        generation_kwargs.setdefault("max_new_tokens", 512)
 
         self.model = model
         self.url = url
@@ -106,6 +106,7 @@ class HuggingFaceLocalGenerator:
                 "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                 "Please specify only one of them."
             )
+        generation_kwargs.setdefault("max_new_tokens", 512)
 
         self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
         self.generation_kwargs = generation_kwargs
@@ -111,6 +111,7 @@ class HuggingFaceTGIGenerator:
         check_generation_params(generation_kwargs, ["n"])
         generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
         generation_kwargs["stop_sequences"].extend(stop_words or [])
+        generation_kwargs.setdefault("max_new_tokens", 512)
 
         self.model = model
         self.url = url
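With the three `__init__` hunks above applied, the default becomes visible on a freshly constructed component. A hedged usage sketch mirroring the updated tests further down (no model is downloaded until `warm_up`/`run` is called):

```python
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator

# No generation_kwargs given: the 512-token default is filled in at init time.
generator = HuggingFaceLocalGenerator()
assert generator.generation_kwargs == {"max_new_tokens": 512}

# An explicit value survives, because setdefault never overwrites it.
capped = HuggingFaceLocalGenerator(generation_kwargs={"max_new_tokens": 64})
assert capped.generation_kwargs == {"max_new_tokens": 64}
```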
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Set max_new_tokens default to 512 in Hugging Face generators.
@@ -1,14 +1,12 @@
-from unittest.mock import patch, MagicMock, Mock
+from unittest.mock import MagicMock, Mock, patch
 
-from haystack.utils.auth import Secret
-
 import pytest
-from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
 from huggingface_hub.utils import RepositoryNotFoundError
 
 from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
-from haystack.dataclasses import StreamingChunk, ChatMessage
+from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.utils.auth import Secret
 
 
 @pytest.fixture
@@ -70,7 +68,11 @@ class TestHuggingFaceTGIChatGenerator:
         )
         generator.warm_up()
 
-        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+        assert generator.generation_kwargs == {
+            **generation_kwargs,
+            **{"stop_sequences": ["stop"]},
+            **{"max_new_tokens": 512},
+        }
         assert generator.tokenizer is not None
         assert generator.client is not None
         assert generator.streaming_callback == streaming_callback
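The rewritten assertion assembles the expected dict with successive `**` unpackings; on a key collision the rightmost entry wins, though here the three sources are disjoint, so the merge is a plain union. A short illustration (plain Python; the `"n": 1` fixture value is hypothetical):

```python
# Dict unpacking merges left to right; later keys override earlier ones.
generation_kwargs = {"n": 1}  # hypothetical user kwargs from the test fixture
expected = {
    **generation_kwargs,
    **{"stop_sequences": ["stop"]},
    **{"max_new_tokens": 512},
}
assert expected == {"n": 1, "stop_sequences": ["stop"], "max_new_tokens": 512}
```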
@@ -92,7 +94,7 @@ class TestHuggingFaceTGIChatGenerator:
         # Assert that the init_params dictionary contains the expected keys and values
         assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
 
     def test_from_dict(self, mock_check_valid_model):
         generator = HuggingFaceTGIChatGenerator(
@@ -106,7 +108,7 @@ class TestHuggingFaceTGIChatGenerator:
 
         generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
         assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
-        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
         assert generator_2.streaming_callback is streaming_callback_handler
 
     def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
@@ -205,7 +207,7 @@ class TestHuggingFaceTGIChatGenerator:
         # check kwargs passed to text_generation
         # note how n because it is not text generation parameter was not passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
 
         assert isinstance(response, dict)
         assert "replies" in response
@@ -240,7 +242,7 @@ class TestHuggingFaceTGIChatGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
 
         # note how n caused n replies to be generated
         assert isinstance(response, dict)
@@ -268,7 +270,7 @@ class TestHuggingFaceTGIChatGenerator:
         # check kwargs passed to text_generation
         # we translate stop_words to stop_sequences
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
 
         # Assert that the response contains the generated replies
         assert "replies" in response
@@ -343,7 +345,7 @@ class TestHuggingFaceTGIChatGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+        assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
@@ -4,10 +4,10 @@ from unittest.mock import Mock, patch
 import pytest
 import torch
 from transformers import PreTrainedTokenizerFast
-from haystack.utils.auth import Secret
 
 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
 from haystack.utils import ComponentDevice
+from haystack.utils.auth import Secret
 
 
 class TestHuggingFaceLocalGenerator:
@@ -23,7 +23,7 @@ class TestHuggingFaceLocalGenerator:
             "token": None,
             "device": ComponentDevice.resolve_device(None).to_hf(),
         }
-        assert generator.generation_kwargs == {}
+        assert generator.generation_kwargs == {"max_new_tokens": 512}
         assert generator.pipeline is None
 
     def test_init_custom_token(self):
@@ -124,7 +124,7 @@ class TestHuggingFaceLocalGenerator:
         """
         generator = HuggingFaceLocalGenerator(task="text-generation")
 
-        assert generator.generation_kwargs == {"return_full_text": False}
+        assert generator.generation_kwargs == {"max_new_tokens": 512, "return_full_text": False}
 
     def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
         with pytest.raises(
@@ -153,7 +153,7 @@ class TestHuggingFaceLocalGenerator:
                     "task": "text2text-generation",
                     "device": ComponentDevice.resolve_device(None).to_hf(),
                 },
-                "generation_kwargs": {},
+                "generation_kwargs": {"max_new_tokens": 512},
                 "stop_words": None,
             },
         }
@@ -1,7 +1,7 @@
-from unittest.mock import patch, MagicMock, Mock
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
 from huggingface_hub.utils import RepositoryNotFoundError
 
 from haystack.components.generators import HuggingFaceTGIGenerator
@@ -63,7 +63,11 @@ class TestHuggingFaceTGIGenerator:
         )
 
         assert generator.model == model
-        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+        assert generator.generation_kwargs == {
+            **generation_kwargs,
+            **{"stop_sequences": ["stop"]},
+            **{"max_new_tokens": 512},
+        }
         assert generator.tokenizer is None
         assert generator.client is not None
         assert generator.streaming_callback == streaming_callback
@@ -84,7 +88,7 @@ class TestHuggingFaceTGIGenerator:
         # Assert that the init_params dictionary contains the expected keys and values
         assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
 
     def test_from_dict(self, mock_check_valid_model):
         generator = HuggingFaceTGIGenerator(
@@ -99,7 +103,7 @@ class TestHuggingFaceTGIGenerator:
         # now deserialize, call from_dict
         generator_2 = HuggingFaceTGIGenerator.from_dict(result)
         assert generator_2.model == "mistralai/Mistral-7B-v0.1"
-        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
         assert generator_2.streaming_callback is streaming_callback_handler
 
     def test_initialize_with_invalid_url(self, mock_check_valid_model):
@@ -135,7 +139,7 @@ class TestHuggingFaceTGIGenerator:
         # check kwargs passed to text_generation
         # note how n was not passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
 
         assert isinstance(response, dict)
         assert "replies" in response
@@ -168,7 +172,7 @@ class TestHuggingFaceTGIGenerator:
         # check kwargs passed to text_generation
         # note how n was not passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}
 
         assert isinstance(response, dict)
         assert "replies" in response
@@ -208,7 +212,7 @@ class TestHuggingFaceTGIGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
 
         # Assert that the response contains the generated replies
         assert "replies" in response
@@ -283,7 +287,7 @@ class TestHuggingFaceTGIGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+        assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
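None of this pins callers to 512 tokens. A hedged sketch of overriding the default, assuming a valid Hugging Face token is configured and that `run` accepts per-call `generation_kwargs` as in other Haystack 2.x generators (both assumptions, not shown in this diff):

```python
from haystack.components.generators import HuggingFaceTGIGenerator

# Init-time override: setdefault leaves this value alone.
generator = HuggingFaceTGIGenerator(
    model="mistralai/Mistral-7B-v0.1",
    generation_kwargs={"max_new_tokens": 128},
)
generator.warm_up()

# Per-call override (assumed signature): merged on top of the init-time kwargs.
result = generator.run("Summarize TGI in one sentence.", generation_kwargs={"max_new_tokens": 32})
print(result["replies"][0])
```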