From d033a086d08a7e1a6f05b2be2b86316c7c594cfd Mon Sep 17 00:00:00 2001
From: s_teja
Date: Wed, 26 Apr 2023 12:10:02 +0100
Subject: [PATCH] fix: loads local HF Models in PromptNode pipeline (#4670)

* bug: fix load local HF Models in PromptNode pipeline
* Update hugging_face.py remove duplicate validator
* update: black formatted
* update: update doc string, replace pop with get
* test HFLocalInvocationLayer with local model
---
 .../prompt/invocation_layer/hugging_face.py | 11 ++++++--
 test/prompt/test_prompt_model.py            | 27 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index 7a85becdc..9ac140d02 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -47,7 +47,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         :param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
         all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
         kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
-        includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
+        includes: task_name, trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about pipeline kwargs in general, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
 
@@ -119,8 +119,15 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
         if len(model_input_kwargs) > 0:
             logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
-        self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
+
+        # If task_name is not provided, get the task name from the model name or path (uses HFApi)
+        if "task_name" in kwargs:
+            self.task_name = kwargs.get("task_name")
+        else:
+            self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
+
         self.pipe = pipeline(
+            task=self.task_name,  # task_name is used to determine the pipeline type
             model=model_name_or_path,
             device=self.devices[0] if "device_map" not in model_input_kwargs else None,
             use_auth_token=self.use_auth_token,
diff --git a/test/prompt/test_prompt_model.py b/test/prompt/test_prompt_model.py
index 366efe30e..109b284dc 100644
--- a/test/prompt/test_prompt_model.py
+++ b/test/prompt/test_prompt_model.py
@@ -3,7 +3,7 @@ from unittest.mock import patch, Mock
 import pytest
 
 from haystack.nodes.prompt.prompt_model import PromptModel
-from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
+from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer, HFLocalInvocationLayer
 
 from .conftest import create_mock_layer_that_supports
 
@@ -36,3 +36,28 @@ def test_construtor_with_custom_model():
 def test_constructor_with_no_supported_model():
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptModel("some-random-model")
+
+
+def create_mock_pipeline(model_name_or_path=None, max_length=100):
+    return Mock(
+        **{"model_name_or_path": model_name_or_path},
+        return_value=Mock(**{"model_name_or_path": model_name_or_path, "tokenizer.model_max_length": max_length}),
+    )
+
+
+@pytest.mark.unit
+def test_hf_local_invocation_layer_with_task_name():
+    mock_pipeline = create_mock_pipeline()
+    mock_get_task = Mock(return_value="dummy_task")
+
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
+        with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline", mock_pipeline):
+            PromptModel(
+                model_name_or_path="local_model",
+                max_length=100,
+                model_kwargs={"task_name": "dummy_task"},
+                invocation_layer_class=HFLocalInvocationLayer,
+            )
+    # get_task should NOT be called when task_name is passed to the HFLocalInvocationLayer constructor
+    mock_get_task.assert_not_called()
+    mock_pipeline.assert_called_once()
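
Usage note (not part of the patch): with this change, a model that exists only on disk can be loaded by passing task_name through model_kwargs, so the invocation layer never calls get_task() against the Hugging Face Hub. Below is a minimal sketch, assuming a text2text-generation checkpoint saved under the hypothetical path ./local_models/flan-t5-base and the Haystack v1 PromptModel API exercised in the test above; the invoke() call at the end is likewise an assumption about how the model would then be used.

from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

prompt_model = PromptModel(
    model_name_or_path="./local_models/flan-t5-base",  # hypothetical local directory, not a Hub model id
    max_length=100,
    # task_name skips the get_task() Hub lookup, which would otherwise fail for purely local models
    model_kwargs={"task_name": "text2text-generation"},
    invocation_layer_class=HFLocalInvocationLayer,
)

# invoke() is assumed to return a list of generated strings
print(prompt_model.invoke("What is the capital of France?"))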