mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 02:39:51 +00:00
fix: Improve robustness of get_task HF pipeline invocations (#5284)
* replace get_task method and change invocation layer order
* add test for invocation layer order
* add test documentation
* make invocation layer test more robust
* fix type annotation
* change hf timeout
* simplify timeout mock and add get_task exception cause

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
This commit is contained in:
parent
ac412193cc
commit
08f1865ddd
@ -3,10 +3,10 @@ from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLay
|
||||
from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.azure_chatgpt import AzureChatGPTInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
|
||||
|
||||
@ -21,7 +21,7 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
|
||||
GenerationConfig,
|
||||
Pipeline,
|
||||
)
|
||||
from transformers.pipelines import get_task
|
||||
from huggingface_hub import model_info
|
||||
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
|
||||
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
|
||||
|
||||
@ -43,6 +43,15 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
|
||||
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
|
||||
return any(all(stop_word) for stop_word in stop_result)
|
||||
|
||||
def get_task(model: str, use_auth_token: Optional[Union[str, bool]] = None, timeout: float = 3.0) -> Optional[str]:
    """
    Look up the pipeline task tag for a Hugging Face model.

    A slimmed-down replacement for ``transformers.pipelines.get_task`` that
    forwards a ``timeout`` to the Hub call, so a slow/unreachable Hub cannot
    hang the caller indefinitely.

    :param model: Model identifier on the Hugging Face Hub.
    :param use_auth_token: Token (or ``True`` to use the stored token) for private models.
    :param timeout: Seconds to wait for the Hub before giving up.
    :return: The model's ``pipeline_tag``, or ``None`` if the model has none.
    :raises RuntimeError: If the Hub lookup fails for any reason; the original
        error is chained as the cause.
    """
    try:
        # Both the Hub request and the attribute access stay inside the try,
        # so any failure is surfaced uniformly as a RuntimeError.
        info = model_info(model, token=use_auth_token, timeout=timeout)
        return info.pipeline_tag
    except Exception as e:
        raise RuntimeError(f"The task of {model} could not be checked because of the following error: {e}") from e
|
||||
|
||||
|
||||
class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
"""
|
||||
|
||||
@ -368,7 +368,9 @@ def test_supports(tmp_path):
|
||||
assert HFLocalInvocationLayer.supports("google/flan-t5-base")
|
||||
assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
|
||||
assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
|
||||
assert mock_get_task.call_count == 3
|
||||
mock_get_task.side_effect = RuntimeError
|
||||
assert not HFLocalInvocationLayer.supports("google/flan-t5-base")
|
||||
assert mock_get_task.call_count == 4
|
||||
|
||||
# some HF local model directory, let's use the one from test/prompt/invocation_layer
|
||||
assert HFLocalInvocationLayer.supports(str(tmp_path))
|
||||
|
||||
14
test/prompt/invocation_layer/test_invocation_layers.py
Normal file
14
test/prompt/invocation_layer/test_invocation_layers.py
Normal file
@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
|
||||
from haystack.nodes.prompt.prompt_model import PromptModelInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_invocation_layer_order():
    """
    Checks that the huggingface invocation layer is checked late because it can timeout/be slow to respond.
    """
    providers = PromptModelInvocationLayer.invocation_layer_providers
    # Only the tail of the provider list matters: slow layers must sit there.
    tail = set(providers[-5:])
    for slow_layer in (HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer):
        assert slow_layer in tail
|
||||
Loading…
x
Reference in New Issue
Block a user