mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 18:06:17 +00:00
feat: enable passing use_fast
to the underlying transformers' pipeline (#5655)
* copy instead of deepcopy * fix pylint * add use_fast * add release note * remove unrelevant changes * black fix * fix bug * black * bug fix
This commit is contained in:
parent
b1daa7c647
commit
40d9f34e68
@ -208,6 +208,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
"torch_dtype": torch_dtype,
|
||||
"model_kwargs": model_kwargs,
|
||||
"pipeline_class": kwargs.get("pipeline_class", None),
|
||||
"use_fast": kwargs.get("use_fast", True),
|
||||
**hub_kwargs,
|
||||
}
|
||||
return pipeline_kwargs
|
||||
|
@ -0,0 +1,3 @@
|
||||
---
|
||||
enhancements:
|
||||
- enable passing use_fast to the underlying transformers' pipeline
|
@ -59,8 +59,8 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
|
||||
assert kwargs["task"] == "text2text-generation"
|
||||
assert kwargs["model"] == "google/flan-t5-base"
|
||||
|
||||
# no matter what kwargs we pass or don't pass, there are always 13 predefined kwargs passed to the pipeline
|
||||
assert len(kwargs) == 13
|
||||
# no matter what kwargs we pass or don't pass, there are always 14 predefined kwargs passed to the pipeline
|
||||
assert len(kwargs) == 14
|
||||
|
||||
# and these kwargs are passed to the pipeline
|
||||
assert list(kwargs.keys()) == [
|
||||
@ -74,6 +74,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
|
||||
"torch_dtype",
|
||||
"model_kwargs",
|
||||
"pipeline_class",
|
||||
"use_fast",
|
||||
"revision",
|
||||
"use_auth_token",
|
||||
"trust_remote_code",
|
||||
@ -248,8 +249,8 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
|
||||
# invalid kwargs are ignored and not passed to the pipeline
|
||||
assert "some_invalid_kwarg" not in kwargs
|
||||
|
||||
# still our 13 kwargs passed to the pipeline
|
||||
assert len(kwargs) == 13
|
||||
# still our 14 kwargs passed to the pipeline
|
||||
assert len(kwargs) == 14
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@ -287,8 +288,8 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
|
||||
assert kwargs["device_map"] and kwargs["device_map"] == "auto"
|
||||
assert kwargs["revision"] == "1.1"
|
||||
|
||||
# still on 13 kwargs passed to the pipeline
|
||||
assert len(kwargs) == 13
|
||||
# still on 14 kwargs passed to the pipeline
|
||||
assert len(kwargs) == 14
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
Loading…
x
Reference in New Issue
Block a user