This reverts commit 514f93a6eb575d376b21d22e32080fac62cf785f.
This commit is contained in:
parent 2bc7fe1a08, commit 5bb0a1f57a
@@ -195,6 +195,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         """
         output: List[Dict[str, str]] = []
         stop_words = kwargs.pop("stop_words", None)
+        top_k = kwargs.pop("top_k", None)
         # either stream is True (will use default handler) or stream_handler is provided for custom handler
         stream = kwargs.get("stream", self.stream)
         stream_handler = kwargs.get("stream_handler", self.stream_handler)
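
For context on the surrounding code: the stop_words popped here is later wrapped in a StopWordsCriteria (see the next hunk), which plugs into Hugging Face's stopping-criteria mechanism. A hedged sketch of how such a criterion can be built on transformers' public StoppingCriteria API (SimpleStopWords is illustrative only, it is not Haystack's StopWordsCriteria implementation):

# Illustrative sketch only; not Haystack's StopWordsCriteria.
import torch
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

class SimpleStopWords(StoppingCriteria):
    def __init__(self, tokenizer, stop_words):
        self.tokenizer = tokenizer
        self.stop_words = stop_words

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode what has been generated so far and stop as soon as
        # any stop word appears in the text.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(word in text for word in self.stop_words)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
criteria = StoppingCriteriaList([SimpleStopWords(tokenizer, ["Answer:"])])
# criteria can then be passed as stopping_criteria=criteria to model.generate(...).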
@@ -238,21 +239,12 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         if stop_words:
             sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words, device=self.pipe.device)
             model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
-        if "num_beams" in model_input_kwargs:
-            num_beams = model_input_kwargs["num_beams"]
-            if (
-                "num_return_sequences" in model_input_kwargs
-                and model_input_kwargs["num_return_sequences"] > num_beams
-            ):
-                num_return_sequences = model_input_kwargs["num_return_sequences"]
-                logger.warning(
-                    "num_return_sequences %s should not be larger than num_beams %s, hence setting it equal to num_beams",
-                    num_return_sequences,
-                    num_beams,
-                )
-                model_input_kwargs["num_return_sequences"] = num_beams
-
+        if top_k:
+            model_input_kwargs["num_return_sequences"] = top_k
+            if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
+                logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
+                model_input_kwargs["num_beams"] = top_k
 
         # max_new_tokens is used for text-generation and max_length for text2text-generation
         if is_text_generation:
             model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)
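
The restored branch maps PromptNode's top_k onto Hugging Face generation parameters: num_return_sequences requests top_k candidate outputs, and since beam search can only return as many sequences as it has beams, num_beams is raised to top_k when needed. A minimal, self-contained sketch of that mapping (apply_top_k is an illustrative helper name, not part of Haystack):

# Sketch of the restored top_k handling; apply_top_k is illustrative only.
import logging

logger = logging.getLogger(__name__)

def apply_top_k(model_input_kwargs: dict, top_k) -> dict:
    if top_k:
        # Ask generate() for top_k candidate sequences ...
        model_input_kwargs["num_return_sequences"] = top_k
        # ... and make sure beam search has at least top_k beams to fill them.
        if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
            logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
            model_input_kwargs["num_beams"] = top_k
    return model_input_kwargs

# num_beams is bumped up to top_k when too small, left alone otherwise:
assert apply_top_k({"num_beams": 2}, top_k=4) == {"num_return_sequences": 4, "num_beams": 4}
assert apply_top_k({"num_beams": 8}, top_k=4) == {"num_beams": 8, "num_return_sequences": 4}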
@@ -205,48 +205,6 @@ def test_ensure_token_limit_negative(caplog):
     assert caplog.records[0].message == expected_message
 
 
-@pytest.mark.unit
-def test_num_return_sequences_no_larger_than_num_beams(mock_pipeline, mock_get_task, caplog):
-    """
-    Test that num_return_sequences cannot be larger than num_beams and that a warning is logged
-    """
-
-    layer = HFLocalInvocationLayer("google/flan-t5-base")
-
-    with patch.object(layer.pipe, "run_single", MagicMock()):
-        layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 5, "num_return_sequences": 8})
-
-    expected_message = (
-        "num_return_sequences 8 should not be larger than num_beams 5, hence setting it equal to num_beams"
-    )
-    # check that the warning is logged
-    assert caplog.records[0].message == expected_message
-
-    # check that num_return_sequences is set to num_beams
-    _, kwargs = layer.pipe.call_args
-    assert kwargs["num_beams"] == 5
-    assert kwargs["num_return_sequences"] == 5
-
-
-@pytest.mark.unit
-def test_num_beams_larger_than_num_return_sequences(mock_pipeline, mock_get_task, caplog):
-    """
-    Test that num_beams can be larger than num_return_sequences and that no warning is logged
-    """
-    layer = HFLocalInvocationLayer("google/flan-t5-base")
-
-    with patch.object(layer.pipe, "run_single", MagicMock()):
-        layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 8, "num_return_sequences": 5})
-
-    # check that no warning is logged
-    assert not caplog.records
-
-    # check that num_return_sequences remains unchanged
-    _, kwargs = layer.pipe.call_args
-    assert kwargs["num_beams"] == 8
-    assert kwargs["num_return_sequences"] == 5
-
-
 @pytest.mark.unit
 def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     """
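
The two deleted tests follow a common pytest pattern: patch.object swaps the pipeline's inner call for a MagicMock so no model actually runs, and caplog captures the warning emitted by the invocation layer. A self-contained sketch of the same pattern (Greeter and its methods are invented for illustration, not Haystack code):

# Sketch of the patch.object + caplog pattern; Greeter is invented for illustration.
import logging
from unittest.mock import MagicMock, patch

logger = logging.getLogger(__name__)

class Greeter:
    def backend(self, name):
        return f"hello {name}"

    def greet(self, name):
        logger.warning("greeting %s", name)
        return self.backend(name)

def test_greet_warns_and_skips_backend(caplog):
    greeter = Greeter()
    with patch.object(greeter, "backend", MagicMock(return_value="stub")):
        # The real backend is never called; the mock answers instead.
        assert greeter.greet("42") == "stub"
    # caplog captured the warning emitted inside greet().
    assert caplog.records[0].message == "greeting 42"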