fix: HuggingFaceTGIGenerator gets stuck when model is not supported (#6915)

* HuggingFaceTGIGenerator/HuggingFaceTGIChatGenerator check if the model is deployed on the free tier
Vladimir Blagojevic 2024-02-06 16:55:06 +01:00 committed by GitHub
parent b875eda4af
commit 9e6a2e3cf9
6 changed files with 141 additions and 19 deletions

View File

@@ -8,7 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params, list_inference_deployed_models

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
    from huggingface_hub import InferenceClient
@@ -139,11 +139,25 @@ class HuggingFaceTGIChatGenerator:
    def warm_up(self) -> None:
        """
        Load the tokenizer.
        If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
        Load the tokenizer
        """
        # is this user using HF free tier inference API?
        if self.model and not self.url:
            deployed_models = list_inference_deployed_models()
            # Determine if the specified model is deployed in the free tier.
            if self.model not in deployed_models:
                raise ValueError(
                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
                    "To use free tier models provide the model ID and the token. Valid models are: "
                    f"{deployed_models}"
                )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model, token=self.token.resolve_value() if self.token else None
        )
        # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
        chat_template = getattr(self.tokenizer, "chat_template", None)
        if not chat_template and not self.chat_template:
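
For orientation, a minimal sketch (not part of the diff) of what the new check looks like from the caller's side; the model ID below is hypothetical and is assumed to exist on the Hub but not be deployed on the free tier:

```python
# Sketch, not part of the diff: the new warm_up() behavior for an unsupported model.
from haystack.components.generators.chat import HuggingFaceTGIChatGenerator

generator = HuggingFaceTGIChatGenerator(model="some-org/some-model")  # hypothetical model ID
try:
    generator.warm_up()
except ValueError as err:
    print(err)  # the message lists the model IDs currently deployed on the free tier
```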

View File

@@ -8,7 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params, list_inference_deployed_models

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
    from huggingface_hub import InferenceClient
@@ -122,8 +122,21 @@ class HuggingFaceTGIGenerator:
    def warm_up(self) -> None:
        """
        If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
        Load the tokenizer
        """
        # is this user using HF free tier inference API?
        if self.model and not self.url:
            deployed_models = list_inference_deployed_models()
            # Determine if the specified model is deployed in the free tier.
            if self.model not in deployed_models:
                raise ValueError(
                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
                    "To use free tier models provide the model ID and the token. Valid models are: "
                    f"{deployed_models}"
                )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model, token=self.token.resolve_value() if self.token else None
        )
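
The lookup is guarded by `if self.model and not self.url`, so a generator pointed at a self-hosted TGI endpoint never hits the free-tier check. A hedged sketch (not part of the diff), with a placeholder URL:

```python
# Sketch: with an explicit url, warm_up() skips the free-tier deployment lookup
# and only loads the tokenizer.
from haystack.components.generators import HuggingFaceTGIGenerator

generator = HuggingFaceTGIGenerator(
    model="mistralai/Mistral-7B-v0.1",
    url="http://localhost:8080",  # placeholder for a self-hosted TGI endpoint
)
generator.warm_up()  # no call to list_inference_deployed_models() here
```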

View File

@@ -4,10 +4,12 @@ import logging
from enum import Enum
from typing import Any, Dict, Optional, List, Union, Callable
import requests
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.device import ComponentDevice
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice

with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
    import torch
@@ -92,6 +94,28 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
    return model_kwargs


def list_inference_deployed_models(headers: Optional[Dict] = None) -> List[str]:
    """
    List all currently deployed models on HF TGI free tier
    :param headers: Optional dictionary of headers to include in the request
    :type headers: Optional[Dict]
    :return: list of all currently deployed models
    :raises Exception: If the request to the TGI API fails
    """
    resp = requests.get(
        "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=10
    )
    payload = resp.json()
    if resp.status_code != 200:
        message = payload["error"] if "error" in payload else "Unknown TGI error"
        error_type = payload["error_type"] if "error_type" in payload else "Unknown TGI error type"
        raise Exception(f"Failed to fetch TGI deployed models: {message}. Error type: {error_type}")
    return [model["model_id"] for model in payload]
def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Secret]) -> None:
"""
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
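
A short usage sketch of the new helper (not part of the diff). The Authorization header below is only an assumption about how a caller might pass a Hugging Face token; the function simply forwards whatever headers it receives:

```python
# Sketch: query the free-tier TGI deployment list directly.
from haystack.utils.hf import list_inference_deployed_models

# headers is optional; a bearer token header is shown purely as an example.
deployed = list_inference_deployed_models(headers={"Authorization": "Bearer <hf_token>"})
print(f"{len(deployed)} models deployed, e.g. {deployed[:3]}")
```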

View File

@@ -0,0 +1,7 @@
---
fixes:
  - |
    Fix a bug where HuggingFaceTGIGenerator and HuggingFaceTGIChatGenerator got stuck when given a valid
    Hugging Face model that is not deployed on the free (rate-limited) tier of the HF inference API
    (see [GitHub issue #6816](https://github.com/deepset-ai/haystack/issues/6816)). If no url is provided,
    warm_up now checks which models are deployed on the free tier and raises a ValueError listing valid models.

View File

@@ -1,4 +1,5 @@
from unittest.mock import patch, MagicMock, Mock
from haystack.utils.auth import Secret
import pytest
@@ -10,6 +11,22 @@ from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.dataclasses import StreamingChunk, ChatMessage


@pytest.fixture
def mock_list_inference_deployed_models():
    with patch(
        "haystack.components.generators.chat.hugging_face_tgi.list_inference_deployed_models",
        MagicMock(
            return_value=[
                "HuggingFaceH4/zephyr-7b-alpha",
                "HuggingFaceH4/zephyr-7b-beta",
                "mistralai/Mistral-7B-v0.1",
                "meta-llama/Llama-2-13b-chat-hf",
            ]
        ),
    ) as mock:
        yield mock


@pytest.fixture
def mock_check_valid_model():
    with patch(
@@ -37,7 +54,9 @@ def streaming_callback_handler(x):
class TestHuggingFaceTGIChatGenerator:
def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer):
def test_initialize_with_valid_model_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models
):
model = "HuggingFaceH4/zephyr-7b-alpha"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
@@ -90,14 +109,16 @@ class TestHuggingFaceTGIChatGenerator:
assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
assert generator_2.streaming_callback is streaming_callback_handler
def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer):
def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
# Assert that the tokenizer is now initialized
assert generator.tokenizer is not None
def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog):
def test_warm_up_no_chat_template(
self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models, caplog
):
generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Set chat_template to None for this specific test
@@ -108,7 +129,12 @@ class TestHuggingFaceTGIChatGenerator:
assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text
def test_custom_chat_template(
self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
self,
chat_messages,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
mock_list_inference_deployed_models,
):
custom_chat_template = "Here goes some Jinja template"
@@ -154,7 +180,12 @@ class TestHuggingFaceTGIChatGenerator:
HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com")
def test_generate_text_response_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
model = "meta-llama/Llama-2-13b-chat-hf"
generation_kwargs = {"n": 1}
@@ -183,7 +214,12 @@ class TestHuggingFaceTGIChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
model = "meta-llama/Llama-2-13b-chat-hf"
token = None
@@ -214,7 +250,12 @@ class TestHuggingFaceTGIChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
def test_generate_text_with_stop_words(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
@@ -236,7 +277,12 @@ class TestHuggingFaceTGIChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
def test_generate_text_with_custom_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
# Create an instance of HuggingFaceRemoteGenerator with no generation parameters
generator = HuggingFaceTGIChatGenerator()
@@ -258,7 +304,12 @@ class TestHuggingFaceTGIChatGenerator:
assert response["replies"][0].content == "I'm fine, thanks."
def test_generate_text_with_streaming_callback(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
streaming_call_count = 0

View File

@@ -1,5 +1,4 @@
from unittest.mock import patch, MagicMock, Mock
from haystack.utils.auth import Secret
import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@@ -7,6 +6,18 @@ from huggingface_hub.utils import RepositoryNotFoundError
from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret


@pytest.fixture
def mock_list_inference_deployed_models():
    with patch(
        "haystack.components.generators.hugging_face_tgi.list_inference_deployed_models",
        MagicMock(
            return_value=["HuggingFaceH4/zephyr-7b-alpha", "HuggingFaceH4/zephyr-7b-alpha", "mistralai/Mistral-7B-v0.1"]
        ),
    ) as mock:
        yield mock
@pytest.fixture
@@ -102,7 +113,7 @@ class TestHuggingFaceTGIGenerator:
HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com")
def test_generate_text_response_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
model = "mistralai/Mistral-7B-v0.1"
@@ -136,7 +147,7 @@ class TestHuggingFaceTGIGenerator:
assert [isinstance(reply, str) for reply in response["replies"]]
def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
model = "mistralai/Mistral-7B-v0.1"
generation_kwargs = {"n": 3}
@@ -186,7 +197,9 @@ class TestHuggingFaceTGIGenerator:
streaming_callback=streaming_callback,
)
def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation):
def test_generate_text_with_stop_words(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
generator = HuggingFaceTGIGenerator()
generator.warm_up()
@@ -210,7 +223,7 @@ class TestHuggingFaceTGIGenerator:
assert [isinstance(reply, dict) for reply in response["replies"]]
def test_generate_text_with_custom_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
generator = HuggingFaceTGIGenerator()
generator.warm_up()
@@ -236,7 +249,7 @@ class TestHuggingFaceTGIGenerator:
assert [isinstance(reply, str) for reply in response["replies"]]
def test_generate_text_with_streaming_callback(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
streaming_call_count = 0