feat: enable passing use_fast to the underlying transformers' pipeline (#5655)

* copy instead of deepcopy

* fix pylint

* add use_fast

* add release note

* remove irrelevant changes

* black fix

* fix bug

* black

* bug fix
This commit is contained in:
Fanli Lin 2023-08-30 16:25:18 +08:00 committed by GitHub
parent b1daa7c647
commit 40d9f34e68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 6 deletions

View File

@ -208,6 +208,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
"torch_dtype": torch_dtype,
"model_kwargs": model_kwargs,
"pipeline_class": kwargs.get("pipeline_class", None),
"use_fast": kwargs.get("use_fast", True),
**hub_kwargs,
}
return pipeline_kwargs

View File

@ -0,0 +1,3 @@
---
enhancements:
- Enable passing `use_fast` through HFLocalInvocationLayer kwargs to the underlying transformers pipeline (defaults to True when not provided)

View File

@ -59,8 +59,8 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
assert kwargs["task"] == "text2text-generation"
assert kwargs["model"] == "google/flan-t5-base"
# no matter what kwargs we pass or don't pass, there are always 13 predefined kwargs passed to the pipeline
assert len(kwargs) == 13
# no matter what kwargs we pass or don't pass, there are always 14 predefined kwargs passed to the pipeline
assert len(kwargs) == 14
# and these kwargs are passed to the pipeline
assert list(kwargs.keys()) == [
@ -74,6 +74,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
"torch_dtype",
"model_kwargs",
"pipeline_class",
"use_fast",
"revision",
"use_auth_token",
"trust_remote_code",
@ -248,8 +249,8 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
# invalid kwargs are ignored and not passed to the pipeline
assert "some_invalid_kwarg" not in kwargs
# still our 13 kwargs passed to the pipeline
assert len(kwargs) == 13
# still our 14 kwargs passed to the pipeline
assert len(kwargs) == 14
@pytest.mark.unit
@ -287,8 +288,8 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
assert kwargs["device_map"] and kwargs["device_map"] == "auto"
assert kwargs["revision"] == "1.1"
# still on 13 kwargs passed to the pipeline
assert len(kwargs) == 13
# still our 14 kwargs passed to the pipeline
assert len(kwargs) == 14
@pytest.mark.integration