fix: load local HF models in PromptNode pipeline (#4670)

* bug: fix loading of local HF models in PromptNode pipeline

* Update hugging_face.py

remove duplicate validator

* update: black formatted

* update: update doc string, replace pop with get

* test HFLocalInvocationLayer with local model
Author: s_teja
Date: 2023-04-26 12:10:02 +01:00 (committed via GitHub)
Parent: 0d6fba14fd
Commit: d033a086d0
2 changed files with 35 additions and 3 deletions
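Context for the fix: the invocation layer previously always called transformers' get_task(), which infers the task by querying the Hugging Face Hub and therefore cannot resolve a checkpoint that exists only on disk. A minimal reproduction sketch of that failure mode (the local path is illustrative, not from this commit):

    from transformers.pipelines import get_task

    # A purely local directory has no Hub entry to query, so automatic
    # task inference fails; the change below lets callers bypass it.
    try:
        get_task("/path/to/local-model")  # hypothetical local checkpoint
    except Exception as err:
        print(f"Task inference failed: {err}")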

haystack/nodes/prompt/invocation_layer/hugging_face.py

@@ -47,7 +47,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         :param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
         all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
         kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
-        includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
+        includes: task_name, trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about pipeline kwargs in general, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
@@ -119,8 +119,15 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         if len(model_input_kwargs) > 0:
             logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
-        self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
+        # If task_name is not provided, get the task name from the model name or path (uses HFApi)
+        if "task_name" in kwargs:
+            self.task_name = kwargs.get("task_name")
+        else:
+            self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
         self.pipe = pipeline(
             task=self.task_name,  # task_name is used to determine the pipeline type
             model=model_name_or_path,
             device=self.devices[0] if "device_map" not in model_input_kwargs else None,
             use_auth_token=self.use_auth_token,
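With this change, a locally stored model can be used in a PromptNode pipeline without any Hub lookup by passing task_name through model_kwargs. A minimal usage sketch (the path and task value are illustrative):

    from haystack.nodes import PromptNode

    # task_name short-circuits get_task(), so no Hugging Face Hub API
    # call is made for a checkpoint that exists only on disk.
    prompt_node = PromptNode(
        model_name_or_path="/path/to/local-flan-t5",  # hypothetical local path
        model_kwargs={"task_name": "text2text-generation"},
    )
    print(prompt_node("What is the capital of Germany?"))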

test/prompt/test_prompt_model.py

@@ -3,7 +3,7 @@ from unittest.mock import patch, Mock
 import pytest

 from haystack.nodes.prompt.prompt_model import PromptModel
-from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
+from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer, HFLocalInvocationLayer

 from .conftest import create_mock_layer_that_supports
@@ -36,3 +36,28 @@ def test_construtor_with_custom_model():
 def test_constructor_with_no_supported_model():
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptModel("some-random-model")
+
+
+def create_mock_pipeline(model_name_or_path=None, max_length=100):
+    return Mock(
+        **{"model_name_or_path": model_name_or_path},
+        return_value=Mock(**{"model_name_or_path": model_name_or_path, "tokenizer.model_max_length": max_length}),
+    )
+
+
+@pytest.mark.unit
+def test_hf_local_invocation_layer_with_task_name():
+    mock_pipeline = create_mock_pipeline()
+    mock_get_task = Mock(return_value="dummy_task")
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
+        with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline", mock_pipeline):
+            PromptModel(
+                model_name_or_path="local_model",
+                max_length=100,
+                model_kwargs={"task_name": "dummy_task"},
+                invocation_layer_class=HFLocalInvocationLayer,
+            )
+    # check that get_task is NOT called when task_name is passed to the HFLocalInvocationLayer constructor
+    mock_get_task.assert_not_called()
+    mock_pipeline.assert_called_once()
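A side note on the test helper: passing **{"tokenizer.model_max_length": max_length} works because unittest.mock lets dotted keyword names configure attributes of child mocks. A standalone sketch of that behavior:

    from unittest.mock import Mock

    # Dotted names in Mock(**kwargs) are routed through configure_mock(),
    # which creates the intermediate child mock automatically.
    m = Mock(**{"tokenizer.model_max_length": 100})
    assert m.tokenizer.model_max_length == 100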