fix: Improve robustness of get_task HF pipeline invocations (#5284)

* replace get_task method and change invocation layer order

* add test for invocation layer order

* add test documentation

* make invocation layer test more robust

* fix type annotation

* change hf timeout

* simplify timeout mock and add get_task exception cause

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
This commit is contained in:
MichelBartels 2023-07-06 16:33:44 +02:00 committed by GitHub
parent ac412193cc
commit 08f1865ddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 4 deletions

View File

@@ -3,10 +3,10 @@ from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLay
from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.azure_chatgpt import AzureChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer

View File

@@ -21,7 +21,7 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
GenerationConfig,
Pipeline,
)
from transformers.pipelines import get_task
from huggingface_hub import model_info
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
@@ -43,6 +43,15 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
return any(all(stop_word) for stop_word in stop_result)
def get_task(model: str, use_auth_token: Optional[Union[str, bool]] = None, timeout: float = 3.0) -> Optional[str]:
    """
    Look up the pipeline task tag for a Hugging Face model.

    A slimmed-down replacement for ``transformers.pipelines.get_task`` that
    forwards a ``timeout`` to the Hub call so a slow/unreachable Hub cannot
    hang the caller.

    :param model: Model name or path on the Hugging Face Hub.
    :param use_auth_token: Token (or True/False flag) passed to the Hub API.
    :param timeout: Seconds to wait for the Hub before giving up.
    :return: The model's ``pipeline_tag``, or None if the model has none.
    :raises RuntimeError: If the Hub lookup fails for any reason; the original
        exception is chained as the cause.
    """
    try:
        info = model_info(model, token=use_auth_token, timeout=timeout)
        task = info.pipeline_tag
    except Exception as e:
        raise RuntimeError(f"The task of {model} could not be checked because of the following error: {e}") from e
    return task
class HFLocalInvocationLayer(PromptModelInvocationLayer):
"""

View File

@@ -368,7 +368,9 @@ def test_supports(tmp_path):
assert HFLocalInvocationLayer.supports("google/flan-t5-base")
assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
assert mock_get_task.call_count == 3
mock_get_task.side_effect = RuntimeError
assert not HFLocalInvocationLayer.supports("google/flan-t5-base")
assert mock_get_task.call_count == 4
# some HF local model directory, let's use the one from test/prompt/invocation_layer
assert HFLocalInvocationLayer.supports(str(tmp_path))

View File

@@ -0,0 +1,14 @@
import pytest
from haystack.nodes.prompt.prompt_model import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer
@pytest.mark.unit
def test_invocation_layer_order():
    """
    Verify that both Hugging Face invocation layers appear near the end of the
    provider list: probing them may hit the HF Hub, which can be slow or time
    out, so they must be tried after the cheaper providers.
    """
    providers = PromptModelInvocationLayer.invocation_layer_providers
    tail = providers[-5:]
    for hf_layer in (HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer):
        assert hf_layer in tail