mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-02 02:39:51 +00:00
fix: loads local HF Models in PromptNode pipeline (#4670)
* bug: fix load local HF Models in PromptNode pipeline
* Update hugging_face.py: remove duplicate validator
* update: black formatted
* update: update doc string, replace pop with get
* test HFLocalInvocationLayer with local model
This commit is contained in:
parent 0d6fba14fd
commit d033a086d0
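For context, a minimal sketch of how the fix is used from the caller's side; the local path and task below are illustrative assumptions, not part of the commit:

from haystack.nodes import PromptModel, PromptNode

# Supplying task_name in model_kwargs lets HFLocalInvocationLayer skip the
# HfApi task lookup, which cannot resolve a model that exists only on disk.
prompt_model = PromptModel(
    model_name_or_path="./my_local_model",  # hypothetical local checkpoint
    model_kwargs={"task_name": "text2text-generation"},
)
prompt_node = PromptNode(prompt_model)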
@@ -47,7 +47,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         :param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
         all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
         kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
-        includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
+        includes: task_name, trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about pipeline kwargs in general, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
@@ -119,8 +119,15 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         if len(model_input_kwargs) > 0:
             logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
-        self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
+
+        # If task_name is not provided, get the task name from the model name or path (uses HFApi)
+        if "task_name" in kwargs:
+            self.task_name = kwargs.get("task_name")
+        else:
+            self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
 
         self.pipe = pipeline(
+            task=self.task_name,  # task_name is used to determine the pipeline type
             model=model_name_or_path,
             device=self.devices[0] if "device_map" not in model_input_kwargs else None,
             use_auth_token=self.use_auth_token,
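Why the guard matters: in the transformers versions this commit targets, get_task resolves the task by querying the Hugging Face Hub, so a purely local path with no matching Hub repo makes the lookup raise instead of returning a task name. A minimal sketch of that failure mode, assuming a hypothetical local-only path:

from transformers.pipelines import get_task

# The Hub lookup has no repo to resolve, so get_task raises a RuntimeError
# rather than inferring the task from the files on disk.
try:
    task = get_task("./my_local_model", use_auth_token=None)
except RuntimeError as err:
    print(f"Could not infer task for local model: {err}")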
@@ -3,7 +3,7 @@ from unittest.mock import patch, Mock
 import pytest
 
 from haystack.nodes.prompt.prompt_model import PromptModel
-from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
+from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer, HFLocalInvocationLayer
 
 from .conftest import create_mock_layer_that_supports
@@ -36,3 +36,28 @@ def test_construtor_with_custom_model():
 def test_constructor_with_no_supported_model():
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptModel("some-random-model")
+
+
+def create_mock_pipeline(model_name_or_path=None, max_length=100):
+    return Mock(
+        **{"model_name_or_path": model_name_or_path},
+        return_value=Mock(**{"model_name_or_path": model_name_or_path, "tokenizer.model_max_length": max_length}),
+    )
+
+
+@pytest.mark.unit
+def test_hf_local_invocation_layer_with_task_name():
+    mock_pipeline = create_mock_pipeline()
+    mock_get_task = Mock(return_value="dummy_task")
+
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
+        with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline", mock_pipeline):
+            PromptModel(
+                model_name_or_path="local_model",
+                max_length=100,
+                model_kwargs={"task_name": "dummy_task"},
+                invocation_layer_class=HFLocalInvocationLayer,
+            )
+    # get_task should not be called when task_name is passed to the HFLocalInvocationLayer constructor
+    mock_get_task.assert_not_called()
+    mock_pipeline.assert_called_once()
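A note on the test design: both get_task and pipeline are patched out, so the test runs fully offline, and mock_get_task.assert_not_called() is the crux, verifying that an explicit task_name short-circuits the Hub lookup. Since the test is marked @pytest.mark.unit, selecting it with something like pytest -m unit -k test_hf_local_invocation_layer_with_task_name should work, though the exact test-file path depends on the repo layout.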