diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py index 57ef37ef6..ce46eb9f8 100644 --- a/haystack/nodes/prompt/invocation_layer/hugging_face.py +++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py @@ -208,6 +208,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer): "torch_dtype": torch_dtype, "model_kwargs": model_kwargs, "pipeline_class": kwargs.get("pipeline_class", None), + "use_fast": kwargs.get("use_fast", True), **hub_kwargs, } return pipeline_kwargs diff --git a/releasenotes/notes/enable_pass_use_fast_to_transformers-b5fdf14d69aa58ec.yaml b/releasenotes/notes/enable_pass_use_fast_to_transformers-b5fdf14d69aa58ec.yaml new file mode 100644 index 000000000..33aacb398 --- /dev/null +++ b/releasenotes/notes/enable_pass_use_fast_to_transformers-b5fdf14d69aa58ec.yaml @@ -0,0 +1,3 @@ +--- +enhancements: + - enable passing use_fast to the underlying transformers' pipeline diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py index 9aeef2508..e377bb752 100644 --- a/test/prompt/invocation_layer/test_hugging_face.py +++ b/test/prompt/invocation_layer/test_hugging_face.py @@ -59,8 +59,8 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task): assert kwargs["task"] == "text2text-generation" assert kwargs["model"] == "google/flan-t5-base" - # no matter what kwargs we pass or don't pass, there are always 13 predefined kwargs passed to the pipeline - assert len(kwargs) == 13 + # no matter what kwargs we pass or don't pass, there are always 14 predefined kwargs passed to the pipeline + assert len(kwargs) == 14 # and these kwargs are passed to the pipeline assert list(kwargs.keys()) == [ @@ -74,6 +74,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task): "torch_dtype", "model_kwargs", "pipeline_class", + "use_fast", "revision", "use_auth_token", "trust_remote_code", @@ -248,8 +249,8 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task): # invalid kwargs are ignored and not passed to the pipeline assert "some_invalid_kwarg" not in kwargs - # still our 13 kwargs passed to the pipeline - assert len(kwargs) == 13 + # still our 14 kwargs passed to the pipeline + assert len(kwargs) == 14 @pytest.mark.unit @@ -287,8 +288,8 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task): assert kwargs["device_map"] and kwargs["device_map"] == "auto" assert kwargs["revision"] == "1.1" - # still on 13 kwargs passed to the pipeline - assert len(kwargs) == 13 + # still on 14 kwargs passed to the pipeline + assert len(kwargs) == 14 @pytest.mark.integration