From 6e2dbdc3204f15d71dd05e10b2ac0f9d5d646b49 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 2 Nov 2023 19:35:16 +0100 Subject: [PATCH] feat: Add `HuggingFaceTGIGenerator` Haystack 2.x component (#6205) * Add HuggingFaceTGIGenerator * PR review * PR feedback from Stefano --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> --- .../preview/components/generators/__init__.py | 3 +- .../preview/components/generators/hf_utils.py | 2 +- .../components/generators/hugging_face_tgi.py | 232 ++++++++++++++ ...ngface-tgi-generator-9d7eed86f5246ea9.yaml | 5 + .../generators/test_hugging_face_tgi.py | 295 ++++++++++++++++++ 5 files changed, 535 insertions(+), 2 deletions(-) create mode 100644 haystack/preview/components/generators/hugging_face_tgi.py create mode 100644 releasenotes/notes/add-huggingface-tgi-generator-9d7eed86f5246ea9.yaml create mode 100644 test/preview/components/generators/test_hugging_face_tgi.py diff --git a/haystack/preview/components/generators/__init__.py b/haystack/preview/components/generators/__init__.py index d6bf5c3f2..331f49a0a 100644 --- a/haystack/preview/components/generators/__init__.py +++ b/haystack/preview/components/generators/__init__.py @@ -1,4 +1,5 @@ from haystack.preview.components.generators.openai.gpt import GPTGenerator from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator +from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator -__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator"] +__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator"] diff --git a/haystack/preview/components/generators/hf_utils.py b/haystack/preview/components/generators/hf_utils.py index 7a19d7abe..9107b4152 100644 --- a/haystack/preview/components/generators/hf_utils.py +++ b/haystack/preview/components/generators/hf_utils.py @@ -5,7 +5,7 @@ from huggingface_hub import InferenceClient, HfApi from huggingface_hub.utils import RepositoryNotFoundError -def check_generation_params(kwargs: Dict[str, Any], additional_accepted_params: Optional[List[str]] = None): +def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None): """ Check the provided generation parameters for validity. diff --git a/haystack/preview/components/generators/hugging_face_tgi.py b/haystack/preview/components/generators/hugging_face_tgi.py new file mode 100644 index 000000000..ed0d230ea --- /dev/null +++ b/haystack/preview/components/generators/hugging_face_tgi.py @@ -0,0 +1,232 @@ +import logging +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Iterable, Callable +from urllib.parse import urlparse + +from huggingface_hub import InferenceClient +from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token +from transformers import AutoTokenizer + +from haystack.preview import component, default_to_dict, default_from_dict +from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model +from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler +from haystack.preview.dataclasses import StreamingChunk + +logger = logging.getLogger(__name__) + + +@component +class HuggingFaceTGIGenerator: + """ + Enables text generation using HuggingFace Hub hosted non-chat LLMs. 
This component is designed to seamlessly
+    run inference on models deployed on the Text Generation Inference (TGI) backend.
+
+    You can use this component for LLMs hosted on Hugging Face inference endpoints, the rate-limited
+    Inference API tier:
+
+    ```python
+    from haystack.preview.components.generators import HuggingFaceTGIGenerator
+    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="")
+    client.warm_up()
+    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
+    print(response)
+    ```
+
+    Or for LLMs hosted on a paid Hugging Face Inference Endpoint (https://huggingface.co/inference-endpoints) and/or your own custom TGI endpoint.
+    In these two cases, you'll need to provide the URL of the endpoint as well as a valid token:
+
+    ```python
+    from haystack.preview.components.generators import HuggingFaceTGIGenerator
+    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
+                                     url="",
+                                     token="")
+    client.warm_up()
+    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
+    print(response)
+    ```
+
+
+    Key Features and Compatibility:
+     - **Primary Compatibility**: Designed to work seamlessly with any non-chat model deployed using the TGI
+       framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference.
+     - **Hugging Face Inference Endpoints**: Supports inference of TGI LLMs deployed on Hugging Face
+       inference endpoints. For more details refer to https://huggingface.co/inference-endpoints.
+     - **Inference API Support**: Supports inference of TGI LLMs hosted on the rate-limited Inference
+       API tier. Learn more about the Inference API at: https://huggingface.co/inference-api
+       Discover available LLMs using the following command:
+       ```
+       wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference
+       ```
+       And simply use the model ID as the model parameter for this component. You'll also need to provide a valid
+       Hugging Face API token as the token parameter.
+     - **Custom TGI Endpoints**: Supports inference of LLMs deployed on custom TGI endpoints. Anyone can
+       deploy their own TGI endpoint using the TGI framework. For more details refer
+       to https://huggingface.co/inference-endpoints.
+    Input and Output Format:
+      - **String Format**: This component uses the str format for structuring both input and output,
+        ensuring coherent and contextually relevant responses in text generation scenarios.
+    """
+
+    def __init__(
+        self,
+        model: str = "mistralai/Mistral-7B-v0.1",
+        url: Optional[str] = None,
+        token: Optional[str] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        stop_words: Optional[List[str]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
+        """
+        Initialize the HuggingFaceTGIGenerator instance.
+
+        :param model: A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1".
+        :param url: An optional string representing the URL of the TGI endpoint.
+        :param token: The HuggingFace token to use as HTTP bearer authorization.
+            You can find your HF token at https://huggingface.co/settings/tokens
+        :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
+            Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
+ See Hugging Face's documentation for more information at: + https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters + :param stop_words: An optional list of strings representing the stop words. + :param streaming_callback: An optional callable for handling streaming responses. + """ + if url: + r = urlparse(url) + is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) + if not is_valid_url: + raise ValueError(f"Invalid TGI endpoint URL provided: {url}") + + check_valid_model(model, token) + + # handle generation kwargs setup + generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} + check_generation_params(generation_kwargs, ["n"]) + generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) + generation_kwargs["stop_sequences"].extend(stop_words or []) + + self.model = model + self.url = url + self.token = token + self.generation_kwargs = generation_kwargs + self.client = InferenceClient(url or model, token=token) + self.streaming_callback = streaming_callback + self.tokenizer = None + + def warm_up(self) -> None: + """ + Load the tokenizer + """ + self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :return: A dictionary containing the serialized component. + """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + model=self.model, + url=self.url, + token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens + generation_kwargs=self.generation_kwargs, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIGenerator": + """ + Deserialize this component from a dictionary. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + # Don't send URL as it is sensitive information + return {"model": self.model} + + @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference for the given prompt and generation parameters. + + :param prompt: A string representing the prompt. + :param generation_kwargs: Additional keyword arguments for text generation. + :return: A dictionary containing the generated replies and metadata. Both are lists of length n. + Replies are strings and metadata are dictionaries. 
+        """
+        # check generation kwargs given as parameters to override the default ones
+        additional_params = ["n", "stop_words"]
+        check_generation_params(generation_kwargs, additional_params)
+
+        # update generation kwargs by merging with the default ones
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+        num_responses = generation_kwargs.pop("n", 1)
+        generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))
+
+        if self.tokenizer is None:
+            raise RuntimeError("Please call warm_up() before running LLM inference.")
+
+        prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))
+
+        if self.streaming_callback:
+            if num_responses > 1:
+                raise ValueError("Cannot stream multiple responses, please set n=1.")
+
+            return self._run_streaming(prompt, prompt_token_count, generation_kwargs)
+
+        return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)
+
+    def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
+        res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
+            prompt, details=True, stream=True, **generation_kwargs
+        )
+        chunks: List[StreamingChunk] = []
+        # pylint: disable=not-an-iterable
+        for chunk in res_chunk:
+            token: Token = chunk.token
+            if token.special:
+                continue
+            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
+            stream_chunk = StreamingChunk(token.text, chunk_metadata)
+            chunks.append(stream_chunk)
+            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+        metadata = {
+            "finish_reason": chunks[-1].metadata.get("finish_reason", None),
+            "model": self.client.model,
+            "usage": {
+                "completion_tokens": chunks[-1].metadata.get("generated_tokens", 0),
+                "prompt_tokens": prompt_token_count,
+                "total_tokens": prompt_token_count + chunks[-1].metadata.get("generated_tokens", 0),
+            },
+        }
+        return {"replies": ["".join([chunk.content for chunk in chunks])], "metadata": [metadata]}
+
+    def _run_non_streaming(
+        self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
+    ):
+        responses: List[str] = []
+        all_metadata: List[Dict[str, Any]] = []
+        for _i in range(num_responses):
+            tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs)
+            all_metadata.append(
+                {
+                    "model": self.client.model,
+                    "index": _i,
+                    "finish_reason": tgr.details.finish_reason.value,
+                    "usage": {
+                        "completion_tokens": len(tgr.details.tokens),
+                        "prompt_tokens": prompt_token_count,
+                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
+                    },
+                }
+            )
+            responses.append(tgr.generated_text)
+        return {"replies": responses, "metadata": all_metadata}
diff --git a/releasenotes/notes/add-huggingface-tgi-generator-9d7eed86f5246ea9.yaml b/releasenotes/notes/add-huggingface-tgi-generator-9d7eed86f5246ea9.yaml
new file mode 100644
index 000000000..838351608
--- /dev/null
+++ b/releasenotes/notes/add-huggingface-tgi-generator-9d7eed86f5246ea9.yaml
@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Adds `HuggingFaceTGIGenerator` for text generation. This component supports remote inference for
+    Hugging Face LLMs via the text-generation-inference (TGI) protocol.
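For quick reference, the snippet below sketches end-to-end usage of the new component, mirroring the docstring example above. It is a minimal sketch rather than part of the patch: the `hf_...` token is a placeholder, and the output keys (`replies`, `metadata`) follow the `run()` method introduced in this PR.

```python
# Minimal usage sketch (not part of the patch); assumes a valid Hugging Face API token.
from haystack.preview.components.generators import HuggingFaceTGIGenerator

generator = HuggingFaceTGIGenerator(
    model="mistralai/Mistral-7B-v0.1",   # free, rate-limited Inference API tier
    token="hf_...",                      # placeholder token
    stop_words=["###"],
    generation_kwargs={"max_new_tokens": 120, "temperature": 0.7},
)
generator.warm_up()  # loads the tokenizer used to count prompt tokens

result = generator.run("What's Natural Language Processing?")
print(result["replies"][0])              # generated text
print(result["metadata"][0]["usage"])    # prompt/completion/total token counts
```

Passing `generation_kwargs={"n": 3}` would return three replies (and three metadata entries) for the same prompt; streaming, exercised in the tests below, requires `n=1`.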
diff --git a/test/preview/components/generators/test_hugging_face_tgi.py b/test/preview/components/generators/test_hugging_face_tgi.py new file mode 100644 index 000000000..5fcbb304b --- /dev/null +++ b/test/preview/components/generators/test_hugging_face_tgi.py @@ -0,0 +1,295 @@ +from unittest.mock import patch, MagicMock, Mock + +import pytest +from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason +from huggingface_hub.utils import RepositoryNotFoundError + +from haystack.preview.components.generators import HuggingFaceTGIGenerator +from haystack.preview.dataclasses import StreamingChunk + + +@pytest.fixture +def mock_check_valid_model(): + with patch( + "haystack.preview.components.generators.hugging_face_tgi.check_valid_model", MagicMock(return_value=None) + ) as mock: + yield mock + + +@pytest.fixture +def mock_text_generation(): + with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation: + mock_response = Mock() + mock_response.generated_text = "I'm fine, thanks." + details = Mock() + details.finish_reason = MagicMock(field1="value") + details.tokens = [1, 2, 3] + mock_response.details = details + mock_text_generation.return_value = mock_response + yield mock_text_generation + + +# used to test serialization of streaming_callback +def streaming_callback_handler(x): + return x + + +class TestHuggingFaceTGIGenerator: + @pytest.mark.unit + def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceTGIGenerator( + model=model, + url=None, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + assert generator.model == model + assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} + assert generator.tokenizer is None + assert generator.client is not None + assert generator.streaming_callback == streaming_callback + + @pytest.mark.unit + def test_to_dict(self, mock_check_valid_model): + # Initialize the HuggingFaceRemoteGenerator object with valid parameters + generator = HuggingFaceTGIGenerator( + token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x + ) + + # Call the to_dict method + result = generator.to_dict() + init_params = result["init_parameters"] + + # Assert that the init_params dictionary contains the expected keys and values + assert init_params["model"] == "mistralai/Mistral-7B-v0.1" + assert not init_params["token"] + assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]} + + @pytest.mark.unit + def test_from_dict(self, mock_check_valid_model): + generator = HuggingFaceTGIGenerator( + model="mistralai/Mistral-7B-v0.1", + generation_kwargs={"n": 5}, + stop_words=["stop", "words"], + streaming_callback=streaming_callback_handler, + ) + # Call the to_dict method + result = generator.to_dict() + + # now deserialize, call from_dict + generator_2 = HuggingFaceTGIGenerator.from_dict(result) + assert generator_2.model == "mistralai/Mistral-7B-v0.1" + assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]} + assert generator_2.streaming_callback is streaming_callback_handler + + @pytest.mark.unit + def test_initialize_with_invalid_url(self, mock_check_valid_model): + with 
pytest.raises(ValueError): + HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url") + + @pytest.mark.unit + def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model): + # When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com") + + @pytest.mark.unit + def test_generate_text_response_with_valid_prompt_and_generation_parameters( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation + ): + model = "mistralai/Mistral-7B-v0.1" + + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceTGIGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + generator.warm_up() + + prompt = "Hello, how are you?" + response = generator.run(prompt) + + # check kwargs passed to text_generation + # note how n was not passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": ["stop"]} + + assert isinstance(response, dict) + assert "replies" in response + assert "metadata" in response + assert isinstance(response["replies"], list) + assert isinstance(response["metadata"], list) + assert len(response["replies"]) == 1 + assert len(response["metadata"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + @pytest.mark.unit + def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation + ): + model = "mistralai/Mistral-7B-v0.1" + generation_kwargs = {"n": 3} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceTGIGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + generator.warm_up() + + prompt = "Hello, how are you?" 
+ response = generator.run(prompt) + + # check kwargs passed to text_generation + # note how n was not passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": ["stop"]} + + assert isinstance(response, dict) + assert "replies" in response + assert "metadata" in response + assert isinstance(response["replies"], list) + assert [isinstance(reply, str) for reply in response["replies"]] + + assert isinstance(response["metadata"], list) + assert len(response["replies"]) == 3 + assert len(response["metadata"]) == 3 + assert [isinstance(reply, dict) for reply in response["metadata"]] + + @pytest.mark.unit + def test_initialize_with_invalid_model(self, mock_check_valid_model): + model = "invalid_model" + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + mock_check_valid_model.side_effect = ValueError("Invalid model path or url") + + with pytest.raises(ValueError): + HuggingFaceTGIGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + @pytest.mark.unit + def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation): + generator = HuggingFaceTGIGenerator() + generator.warm_up() + + # Generate text response with stop words + response = generator.run("How are you?", generation_kwargs={"stop_words": ["stop", "words"]}) + + # check kwargs passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]} + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Assert that the response contains the metadata + assert "metadata" in response + assert isinstance(response["metadata"], list) + assert len(response["metadata"]) > 0 + assert [isinstance(reply, dict) for reply in response["replies"]] + + @pytest.mark.unit + def test_generate_text_with_custom_generation_parameters( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation + ): + generator = HuggingFaceTGIGenerator() + generator.warm_up() + + generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} + response = generator.run("How are you?", generation_kwargs=generation_kwargs) + + # check kwargs passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8} + + # Assert that the response contains the generated replies and the right response + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, str) for reply in response["replies"]] + assert response["replies"][0] == "I'm fine, thanks." 
+ + # Assert that the response contains the metadata + assert "metadata" in response + assert isinstance(response["metadata"], list) + assert len(response["metadata"]) > 0 + assert [isinstance(reply, str) for reply in response["replies"]] + + @pytest.mark.unit + def test_generate_text_with_streaming_callback( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation + ): + streaming_call_count = 0 + + # Define the streaming callback function + def streaming_callback_fn(chunk: StreamingChunk): + nonlocal streaming_call_count + streaming_call_count += 1 + assert isinstance(chunk, StreamingChunk) + + # Create an instance of HuggingFaceRemoteGenerator + generator = HuggingFaceTGIGenerator(streaming_callback=streaming_callback_fn) + generator.warm_up() + + # Create a fake streamed response + # Don't remove self + def mock_iter(self): + yield TextGenerationStreamResponse( + generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False) + ) + yield TextGenerationStreamResponse( + generated_text=None, + token=Token(id=1, text="Ok bye", logprob=0.0, special=False), + details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5), + ) + + mock_response = Mock(**{"__iter__": mock_iter}) + mock_text_generation.return_value = mock_response + + # Generate text response with streaming callback + response = generator.run("prompt") + + # check kwargs passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": [], "stream": True} + + # Assert that the streaming callback was called twice + assert streaming_call_count == 2 + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Assert that the response contains the metadata + assert "metadata" in response + assert isinstance(response["metadata"], list) + assert len(response["metadata"]) > 0 + assert [isinstance(reply, dict) for reply in response["replies"]]
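As a complement to the streaming test above, here is one more minimal sketch (again not part of the patch) showing how a caller might consume streamed chunks with a custom callback; the token value is a placeholder, and printing each chunk is just one possible handler.

```python
# Minimal streaming sketch (not part of the patch); assumes a valid Hugging Face API token.
from haystack.preview.components.generators import HuggingFaceTGIGenerator
from haystack.preview.dataclasses import StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # chunk.content carries the token text; chunk.metadata carries token/finish details
    print(chunk.content, end="", flush=True)


generator = HuggingFaceTGIGenerator(
    model="mistralai/Mistral-7B-v0.1",
    token="hf_...",                    # placeholder
    streaming_callback=print_chunk,    # streaming requires n == 1 (the default)
)
generator.warm_up()
result = generator.run("What's Natural Language Processing?")
# The complete reply and its metadata are still returned once streaming finishes.
print("\n", result["metadata"][0]["finish_reason"])
```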