From 9e6a2e3cf98a35eb030b2eba7206a85d1c7c5e07 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Tue, 6 Feb 2024 16:55:06 +0100
Subject: [PATCH] fix: HuggingFaceTGIGenerator gets stuck when model is not supported (#6915)

* HuggingFaceTGIGenerator/HuggingFaceTGIChatGenerator check if model is deployed on free-tier

---
 .../generators/chat/hugging_face_tgi.py       | 18 ++++-
 .../components/generators/hugging_face_tgi.py | 15 +++-
 haystack/utils/hf.py                          | 26 ++++++-
 ...-hf-free-tier-checks-99384060139d5d30.yaml |  7 ++
 .../generators/chat/test_hugging_face_tgi.py  | 69 ++++++++++++++++---
 .../generators/test_hugging_face_tgi.py       | 25 +++++--
 6 files changed, 141 insertions(+), 19 deletions(-)
 create mode 100644 releasenotes/notes/add-hf-free-tier-checks-99384060139d5d30.yaml

diff --git a/haystack/components/generators/chat/hugging_face_tgi.py b/haystack/components/generators/chat/hugging_face_tgi.py
index 83cea5d40..5fb1b7acf 100644
--- a/haystack/components/generators/chat/hugging_face_tgi.py
+++ b/haystack/components/generators/chat/hugging_face_tgi.py
@@ -8,7 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
-from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
+from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params, list_inference_deployed_models
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient
@@ -139,11 +139,25 @@ class HuggingFaceTGIChatGenerator:
 
     def warm_up(self) -> None:
         """
-        Load the tokenizer.
+        If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
+        Load the tokenizer
         """
+
+        # is this user using HF free tier inference API?
+        if self.model and not self.url:
+            deployed_models = list_inference_deployed_models()
+            # Determine if the specified model is deployed in the free tier.
+            if self.model not in deployed_models:
+                raise ValueError(
+                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
+                    "To use free tier models provide the model ID and the token. Valid models are: "
+                    f"{deployed_models}"
+                )
+
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model, token=self.token.resolve_value() if self.token else None
         )
+
+        # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
         chat_template = getattr(self.tokenizer, "chat_template", None)
         if not chat_template and not self.chat_template:
diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py
index d7499d586..bd328820e 100644
--- a/haystack/components/generators/hugging_face_tgi.py
+++ b/haystack/components/generators/hugging_face_tgi.py
@@ -8,7 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
 from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
-from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
+from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params, list_inference_deployed_models
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient
@@ -122,8 +122,21 @@ class HuggingFaceTGIGenerator:
 
     def warm_up(self) -> None:
         """
+        If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
         Load the tokenizer
         """
+
+        # is this user using HF free tier inference API?
+        if self.model and not self.url:
+            deployed_models = list_inference_deployed_models()
+            # Determine if the specified model is deployed in the free tier.
+            if self.model not in deployed_models:
+                raise ValueError(
+                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
+                    "To use free tier models provide the model ID and the token. Valid models are: "
+                    f"{deployed_models}"
+                )
+
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model, token=self.token.resolve_value() if self.token else None
         )
diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py
index 32d9be1bb..bc8a5e721 100644
--- a/haystack/utils/hf.py
+++ b/haystack/utils/hf.py
@@ -4,10 +4,12 @@ import logging
 from enum import Enum
 from typing import Any, Dict, Optional, List, Union, Callable
 
+import requests
+
 from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport
-from haystack.utils.device import ComponentDevice
 from haystack.utils.auth import Secret
+from haystack.utils.device import ComponentDevice
 
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
     import torch
@@ -92,6 +94,28 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
     return model_kwargs
 
 
+def list_inference_deployed_models(headers: Optional[Dict] = None) -> List[str]:
+    """
+    List all currently deployed models on HF TGI free tier
+
+    :param headers: Optional dictionary of headers to include in the request
+    :type headers: Optional[Dict]
+    :return: list of all currently deployed models
+    :raises Exception: If the request to the TGI API fails
+
+    """
+    resp = requests.get(
+        "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=10
+    )
+
+    payload = resp.json()
+    if resp.status_code != 200:
+        message = payload["error"] if "error" in payload else "Unknown TGI error"
+        error_type = payload["error_type"] if "error_type" in payload else "Unknown TGI error type"
+        raise Exception(f"Failed to fetch TGI deployed models: {message}. Error type: {error_type}")
+    return [model["model_id"] for model in payload]
+
+
 def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Secret]) -> None:
     """
     Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
diff --git a/releasenotes/notes/add-hf-free-tier-checks-99384060139d5d30.yaml b/releasenotes/notes/add-hf-free-tier-checks-99384060139d5d30.yaml
new file mode 100644
index 000000000..a56e4dd60
--- /dev/null
+++ b/releasenotes/notes/add-hf-free-tier-checks-99384060139d5d30.yaml
@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Resolves a bug where HuggingFaceTGIGenerator and HuggingFaceTGIChatGenerator would get stuck when given a valid
+    model that is not deployed on the rate-limited (free) tier of the HuggingFace Inference API. As described in
+    [GitHub issue #6816](https://github.com/deepset-ai/haystack/issues/6816), these components now check model
+    availability during warm_up and raise a clear error instead of hanging.
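For readers skimming the diff, the core of the change is the new list_inference_deployed_models helper above: it asks the Inference API which models are currently served by the free-tier TGI backend, and warm_up() refuses to proceed when the requested model is not in that list. The snippet below is an illustrative sketch only, not part of the patch; it assumes the patched haystack.utils.hf module is importable and that the endpoint is reachable.

# Illustrative sketch, not part of the patch: exercising the new helper directly.
# Assumes the patched haystack.utils.hf is importable and the HF endpoint is reachable.
from haystack.utils.hf import list_inference_deployed_models

deployed = list_inference_deployed_models()  # called without headers, as warm_up() does
print(f"{len(deployed)} models currently deployed on the free TGI tier")
print(deployed[:5])  # model IDs, e.g. "mistralai/Mistral-7B-v0.1"

The list changes over time, which is why the components query it at warm_up() rather than hard-coding supported models.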
diff --git a/test/components/generators/chat/test_hugging_face_tgi.py b/test/components/generators/chat/test_hugging_face_tgi.py
index 3029b8cbe..74e04a6da 100644
--- a/test/components/generators/chat/test_hugging_face_tgi.py
+++ b/test/components/generators/chat/test_hugging_face_tgi.py
@@ -1,4 +1,5 @@
 from unittest.mock import patch, MagicMock, Mock
+
 from haystack.utils.auth import Secret
 
 import pytest
@@ -10,6 +11,22 @@ from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
 from haystack.dataclasses import StreamingChunk, ChatMessage
 
 
+@pytest.fixture
+def mock_list_inference_deployed_models():
+    with patch(
+        "haystack.components.generators.chat.hugging_face_tgi.list_inference_deployed_models",
+        MagicMock(
+            return_value=[
+                "HuggingFaceH4/zephyr-7b-alpha",
+                "HuggingFaceH4/zephyr-7b-beta",
+                "mistralai/Mistral-7B-v0.1",
+                "meta-llama/Llama-2-13b-chat-hf",
+            ]
+        ),
+    ) as mock:
+        yield mock
+
+
 @pytest.fixture
 def mock_check_valid_model():
     with patch(
@@ -37,7 +54,9 @@ def streaming_callback_handler(x):
 
 
 class TestHuggingFaceTGIChatGenerator:
-    def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer):
+    def test_initialize_with_valid_model_and_generation_parameters(
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models
+    ):
         model = "HuggingFaceH4/zephyr-7b-alpha"
         generation_kwargs = {"n": 1}
         stop_words = ["stop"]
@@ -90,14 +109,16 @@ class TestHuggingFaceTGIChatGenerator:
         assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
         assert generator_2.streaming_callback is streaming_callback_handler
 
-    def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer):
+    def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
         generator = HuggingFaceTGIChatGenerator()
         generator.warm_up()
 
         # Assert that the tokenizer is now initialized
         assert generator.tokenizer is not None
 
-    def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog):
+    def test_warm_up_no_chat_template(
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models, caplog
+    ):
         generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
 
         # Set chat_template to None for this specific test
@@ -108,7 +129,12 @@ class TestHuggingFaceTGIChatGenerator:
         assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text
 
     def test_custom_chat_template(
-        self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
+        self,
+        chat_messages,
+        mock_check_valid_model,
+        mock_auto_tokenizer,
+        mock_text_generation,
+        mock_list_inference_deployed_models,
     ):
         custom_chat_template = "Here goes some Jinja template"
 
@@ -154,7 +180,12 @@ class TestHuggingFaceTGIChatGenerator:
             HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com")
 
     def test_generate_text_response_with_valid_prompt_and_generation_parameters(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
+        self,
+        mock_check_valid_model,
+        mock_auto_tokenizer,
+        mock_text_generation,
+        chat_messages,
+        mock_list_inference_deployed_models,
     ):
         model = "meta-llama/Llama-2-13b-chat-hf"
         generation_kwargs = {"n": 1}
@@ -183,7 +214,12 @@ class TestHuggingFaceTGIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
 
     def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
+        self,
+        mock_check_valid_model,
+        mock_auto_tokenizer,
+        mock_text_generation,
+        chat_messages,
+        mock_list_inference_deployed_models,
     ):
         model = "meta-llama/Llama-2-13b-chat-hf"
         token = None
@@ -214,7 +250,12 @@ class TestHuggingFaceTGIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
 
     def test_generate_text_with_stop_words(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
+        self,
+        mock_check_valid_model,
+        mock_auto_tokenizer,
+        mock_text_generation,
+        chat_messages,
+        mock_list_inference_deployed_models,
     ):
         generator = HuggingFaceTGIChatGenerator()
         generator.warm_up()
@@ -236,7 +277,12 @@ class TestHuggingFaceTGIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
 
     def test_generate_text_with_custom_generation_parameters(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
+        self,
+        mock_check_valid_model,
+        mock_auto_tokenizer,
+        mock_text_generation,
+        chat_messages,
+        mock_list_inference_deployed_models,
     ):
         # Create an instance of HuggingFaceRemoteGenerator with no generation parameters
         generator = HuggingFaceTGIChatGenerator()
@@ -258,7 +304,12 @@ class TestHuggingFaceTGIChatGenerator:
         assert response["replies"][0].content == "I'm fine, thanks."
 
     def test_generate_text_with_streaming_callback(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
+        self,
+        mock_check_valid_model,
+        mock_auto_tokenizer,
+        mock_text_generation,
+        chat_messages,
+        mock_list_inference_deployed_models,
     ):
         streaming_call_count = 0
 
diff --git a/test/components/generators/test_hugging_face_tgi.py b/test/components/generators/test_hugging_face_tgi.py
index 2a5543cc1..752dcd439 100644
--- a/test/components/generators/test_hugging_face_tgi.py
+++ b/test/components/generators/test_hugging_face_tgi.py
@@ -1,5 +1,4 @@
 from unittest.mock import patch, MagicMock, Mock
-from haystack.utils.auth import Secret
 
 import pytest
 from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@@ -7,6 +6,18 @@ from huggingface_hub.utils import RepositoryNotFoundError
 
 from haystack.components.generators import HuggingFaceTGIGenerator
 from haystack.dataclasses import StreamingChunk
+from haystack.utils.auth import Secret
+
+
+@pytest.fixture
+def mock_list_inference_deployed_models():
+    with patch(
+        "haystack.components.generators.hugging_face_tgi.list_inference_deployed_models",
+        MagicMock(
+            return_value=["HuggingFaceH4/zephyr-7b-alpha", "HuggingFaceH4/zephyr-7b-alpha", "mistralai/Mistral-7B-v0.1"]
+        ),
+    ) as mock:
+        yield mock
 
 
 @pytest.fixture
@@ -102,7 +113,7 @@ class TestHuggingFaceTGIGenerator:
             HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com")
 
     def test_generate_text_response_with_valid_prompt_and_generation_parameters(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
     ):
         model = "mistralai/Mistral-7B-v0.1"
 
@@ -136,7 +147,7 @@ class TestHuggingFaceTGIGenerator:
         assert [isinstance(reply, str) for reply in response["replies"]]
 
     def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
     ):
         model = "mistralai/Mistral-7B-v0.1"
         generation_kwargs = {"n": 3}
@@ -186,7 +197,9 @@ class TestHuggingFaceTGIGenerator:
             streaming_callback=streaming_callback,
         )
 
-    def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation):
+    def test_generate_text_with_stop_words(
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
+    ):
         generator = HuggingFaceTGIGenerator()
         generator.warm_up()
 
@@ -210,7 +223,7 @@ class TestHuggingFaceTGIGenerator:
         assert [isinstance(reply, dict) for reply in response["replies"]]
 
     def test_generate_text_with_custom_generation_parameters(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
     ):
         generator = HuggingFaceTGIGenerator()
         generator.warm_up()
 
@@ -236,7 +249,7 @@ class TestHuggingFaceTGIGenerator:
         assert [isinstance(reply, str) for reply in response["replies"]]
 
     def test_generate_text_with_streaming_callback(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
+        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
     ):
         streaming_call_count = 0
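Taken together, the tests above pin down the new user-visible behavior: with no url configured, warm_up() either finds the model in the free-tier deployment list and loads the tokenizer, or raises a ValueError that names the currently deployed models. Below is a minimal usage sketch, not part of the patch, assuming the patched package is installed and an HF token is available in the HF_API_TOKEN environment variable; the model ID is taken from the test fixtures and may or may not be deployed at any given moment.

# Minimal usage sketch (not part of the patch). Free-tier availability changes
# over time, so the ValueError branch may or may not trigger for this model.
from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.utils import Secret

generator = HuggingFaceTGIGenerator(
    model="mistralai/Mistral-7B-v0.1",          # no url -> free-tier Inference API
    token=Secret.from_env_var("HF_API_TOKEN"),  # assumes HF_API_TOKEN is set
)

try:
    generator.warm_up()  # checks the deployed-model list, then loads the tokenizer
    result = generator.run("What is the capital of France?")
    print(result["replies"][0])
except ValueError as err:
    print(err)  # error message lists the models currently deployed on the free tier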