mirror of https://github.com/deepset-ai/haystack.git
fix: HuggingFaceTGIGenerator gets stuck when model is not supported (#6915)
* HuggingFaceTGIGenerator/HuggingFaceTGIChatGenerator check if model is deployed on free-tier
This commit is contained in:
parent
b875eda4af
commit
9e6a2e3cf9
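For context (not part of the diff), here is a minimal sketch of the behavior this commit introduces: when no `url` is configured and the requested model is not served on the free tier of the HF Inference API, `warm_up()` now fails fast with a `ValueError` instead of getting stuck. The model ID below is a placeholder, assumed to exist on the Hub but not to be deployed on the free tier.

from haystack.components.generators import HuggingFaceTGIGenerator

# Placeholder model ID: assumed valid on the Hub but not deployed on the free-tier inference API.
generator = HuggingFaceTGIGenerator(model="some-org/some-very-large-model")
try:
    generator.warm_up()  # previously this could hang; after this commit it raises immediately
except ValueError as err:
    print(f"Not available on the free tier: {err}")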
@@ -8,7 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params, list_inference_deployed_models

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
    from huggingface_hub import InferenceClient
@@ -139,11 +139,25 @@ class HuggingFaceTGIChatGenerator:

    def warm_up(self) -> None:
        """
        Load the tokenizer.
        If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
        Load the tokenizer
        """

        # is this user using HF free tier inference API?
        if self.model and not self.url:
            deployed_models = list_inference_deployed_models()
            # Determine if the specified model is deployed in the free tier.
            if self.model not in deployed_models:
                raise ValueError(
                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
                    "To use free tier models provide the model ID and the token. Valid models are: "
                    f"{deployed_models}"
                )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model, token=self.token.resolve_value() if self.token else None
        )

        # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
        chat_template = getattr(self.tokenizer, "chat_template", None)
        if not chat_template and not self.chat_template:
@@ -8,7 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params, list_inference_deployed_models

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
    from huggingface_hub import InferenceClient
@@ -122,8 +122,21 @@ class HuggingFaceTGIGenerator:

    def warm_up(self) -> None:
        """
        If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
        Load the tokenizer
        """

        # is this user using HF free tier inference API?
        if self.model and not self.url:
            deployed_models = list_inference_deployed_models()
            # Determine if the specified model is deployed in the free tier.
            if self.model not in deployed_models:
                raise ValueError(
                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
                    "To use free tier models provide the model ID and the token. Valid models are: "
                    f"{deployed_models}"
                )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model, token=self.token.resolve_value() if self.token else None
        )
@@ -4,10 +4,12 @@ import logging
from enum import Enum
from typing import Any, Dict, Optional, List, Union, Callable

import requests

from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.device import ComponentDevice
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice

with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
    import torch
@@ -92,6 +94,28 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
    return model_kwargs


def list_inference_deployed_models(headers: Optional[Dict] = None) -> List[str]:
    """
    List all currently deployed models on HF TGI free tier

    :param headers: Optional dictionary of headers to include in the request
    :type headers: Optional[Dict]
    :return: list of all currently deployed models
    :raises Exception: If the request to the TGI API fails

    """
    resp = requests.get(
        "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=10
    )

    payload = resp.json()
    if resp.status_code != 200:
        message = payload["error"] if "error" in payload else "Unknown TGI error"
        error_type = payload["error_type"] if "error_type" in payload else "Unknown TGI error type"
        raise Exception(f"Failed to fetch TGI deployed models: {message}. Error type: {error_type}")
    return [model["model_id"] for model in payload]
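As an aside (not part of the diff), a small usage sketch of the helper defined above; it only assumes the public endpoint it queries is reachable, and optionally that you pass HF authentication via the headers argument.

from haystack.utils.hf import list_inference_deployed_models

# Query the free-tier TGI backend for currently served models; headers are optional.
deployed = list_inference_deployed_models()
print(f"{len(deployed)} models deployed, for example: {deployed[:3]}")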

def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Secret]) -> None:
    """
    Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
@@ -0,0 +1,7 @@
---
fixes:
  - |
    Resolves a bug where the HuggingFaceTGIGenerator and HuggingFaceTGIChatGenerator got stuck during warm-up when
    given valid models that are not available on the rate-limited (free) tier of the HuggingFace Inference API.
    As described in [GitHub issue #6816](https://github.com/deepset-ai/haystack/issues/6816) and its accompanying PR,
    these components now check model availability and raise a clear error instead of hanging.
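Note that the new check only runs when no endpoint URL is configured; pointing the component at a self-hosted TGI instance skips the free-tier lookup entirely. A minimal sketch, assuming a local TGI server at an illustrative URL and network access for the tokenizer download:

from haystack.components.generators import HuggingFaceTGIGenerator

# Illustrative self-hosted endpoint; with `url` set, warm_up() does not consult the free-tier model list.
generator = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="http://localhost:8080")
generator.warm_up()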
@@ -1,4 +1,5 @@
from unittest.mock import patch, MagicMock, Mock

from haystack.utils.auth import Secret

import pytest
@@ -10,6 +11,22 @@ from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.dataclasses import StreamingChunk, ChatMessage


@pytest.fixture
def mock_list_inference_deployed_models():
    with patch(
        "haystack.components.generators.chat.hugging_face_tgi.list_inference_deployed_models",
        MagicMock(
            return_value=[
                "HuggingFaceH4/zephyr-7b-alpha",
                "HuggingFaceH4/zephyr-7b-beta",
                "mistralai/Mistral-7B-v0.1",
                "meta-llama/Llama-2-13b-chat-hf",
            ]
        ),
    ) as mock:
        yield mock


@pytest.fixture
def mock_check_valid_model():
    with patch(
@@ -37,7 +54,9 @@ def streaming_callback_handler(x):


class TestHuggingFaceTGIChatGenerator:
    def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer):
    def test_initialize_with_valid_model_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models
    ):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
@@ -90,14 +109,16 @@ class TestHuggingFaceTGIChatGenerator:
        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
        assert generator_2.streaming_callback is streaming_callback_handler

    def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer):
    def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
        generator = HuggingFaceTGIChatGenerator()
        generator.warm_up()

        # Assert that the tokenizer is now initialized
        assert generator.tokenizer is not None

    def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog):
    def test_warm_up_no_chat_template(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models, caplog
    ):
        generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")

        # Set chat_template to None for this specific test
@@ -108,7 +129,12 @@ class TestHuggingFaceTGIChatGenerator:
        assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text

    def test_custom_chat_template(
        self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
        self,
        chat_messages,
        mock_check_valid_model,
        mock_auto_tokenizer,
        mock_text_generation,
        mock_list_inference_deployed_models,
    ):
        custom_chat_template = "Here goes some Jinja template"

@@ -154,7 +180,12 @@ class TestHuggingFaceTGIChatGenerator:
            HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com")

    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
        self,
        mock_check_valid_model,
        mock_auto_tokenizer,
        mock_text_generation,
        chat_messages,
        mock_list_inference_deployed_models,
    ):
        model = "meta-llama/Llama-2-13b-chat-hf"
        generation_kwargs = {"n": 1}
@@ -183,7 +214,12 @@ class TestHuggingFaceTGIChatGenerator:
        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

    def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
        self,
        mock_check_valid_model,
        mock_auto_tokenizer,
        mock_text_generation,
        chat_messages,
        mock_list_inference_deployed_models,
    ):
        model = "meta-llama/Llama-2-13b-chat-hf"
        token = None
@@ -214,7 +250,12 @@ class TestHuggingFaceTGIChatGenerator:
        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

    def test_generate_text_with_stop_words(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
        self,
        mock_check_valid_model,
        mock_auto_tokenizer,
        mock_text_generation,
        chat_messages,
        mock_list_inference_deployed_models,
    ):
        generator = HuggingFaceTGIChatGenerator()
        generator.warm_up()
@@ -236,7 +277,12 @@ class TestHuggingFaceTGIChatGenerator:
        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

    def test_generate_text_with_custom_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
        self,
        mock_check_valid_model,
        mock_auto_tokenizer,
        mock_text_generation,
        chat_messages,
        mock_list_inference_deployed_models,
    ):
        # Create an instance of HuggingFaceRemoteGenerator with no generation parameters
        generator = HuggingFaceTGIChatGenerator()
@@ -258,7 +304,12 @@ class TestHuggingFaceTGIChatGenerator:
        assert response["replies"][0].content == "I'm fine, thanks."

    def test_generate_text_with_streaming_callback(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
        self,
        mock_check_valid_model,
        mock_auto_tokenizer,
        mock_text_generation,
        chat_messages,
        mock_list_inference_deployed_models,
    ):
        streaming_call_count = 0
@@ -1,5 +1,4 @@
from unittest.mock import patch, MagicMock, Mock
from haystack.utils.auth import Secret

import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@@ -7,6 +6,18 @@ from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret


@pytest.fixture
def mock_list_inference_deployed_models():
    with patch(
        "haystack.components.generators.hugging_face_tgi.list_inference_deployed_models",
        MagicMock(
            return_value=["HuggingFaceH4/zephyr-7b-alpha", "HuggingFaceH4/zephyr-7b-alpha", "mistralai/Mistral-7B-v0.1"]
        ),
    ) as mock:
        yield mock


@pytest.fixture
@@ -102,7 +113,7 @@ class TestHuggingFaceTGIGenerator:
            HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com")

    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
    ):
        model = "mistralai/Mistral-7B-v0.1"

@@ -136,7 +147,7 @@ class TestHuggingFaceTGIGenerator:
        assert [isinstance(reply, str) for reply in response["replies"]]

    def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
    ):
        model = "mistralai/Mistral-7B-v0.1"
        generation_kwargs = {"n": 3}
@@ -186,7 +197,9 @@ class TestHuggingFaceTGIGenerator:
            streaming_callback=streaming_callback,
        )

    def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation):
    def test_generate_text_with_stop_words(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
    ):
        generator = HuggingFaceTGIGenerator()
        generator.warm_up()

@@ -210,7 +223,7 @@ class TestHuggingFaceTGIGenerator:
        assert [isinstance(reply, dict) for reply in response["replies"]]

    def test_generate_text_with_custom_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
    ):
        generator = HuggingFaceTGIGenerator()
        generator.warm_up()
@@ -236,7 +249,7 @@ class TestHuggingFaceTGIGenerator:
        assert [isinstance(reply, str) for reply in response["replies"]]

    def test_generate_text_with_streaming_callback(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
    ):
        streaming_call_count = 0