feat: Add HuggingFaceTGIGenerator Haystack 2.x component (#6205)

* Add HuggingFaceTGIGenerator

* PR review

* PR feedback from Stefano

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Vladimir Blagojevic 2023-11-02 19:35:16 +01:00 committed by GitHub
parent 8511b8cd79
commit 6e2dbdc320
5 changed files with 535 additions and 2 deletions


@@ -1,4 +1,5 @@
 from haystack.preview.components.generators.openai.gpt import GPTGenerator
 from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator
+from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
 
-__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator"]
+__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator"]


@@ -5,7 +5,7 @@ from huggingface_hub import InferenceClient, HfApi
 from huggingface_hub.utils import RepositoryNotFoundError
 
-def check_generation_params(kwargs: Dict[str, Any], additional_accepted_params: Optional[List[str]] = None):
+def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
     """
     Check the provided generation parameters for validity.


@@ -0,0 +1,232 @@
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Iterable, Callable
from urllib.parse import urlparse

from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer

from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)


@component
class HuggingFaceTGIGenerator:
    """
    Enables text generation using non-chat LLMs hosted on the Hugging Face Hub. This component is designed to
    run inference seamlessly against models deployed on the Text Generation Inference (TGI) backend.

    You can use this component for LLMs hosted on the rate-limited Inference API tier:

    ```python
    from haystack.preview.components.generators import HuggingFaceTGIGenerator

    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="<your-token>")
    client.warm_up()
    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
    print(response)
    ```

    Or for LLMs hosted on paid Hugging Face Inference Endpoints (https://huggingface.co/inference-endpoints)
    and/or on your own custom TGI endpoint. In these two cases, you'll need to provide the URL of the endpoint
    as well as a valid token:

    ```python
    from haystack.preview.components.generators import HuggingFaceTGIGenerator

    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
                                     url="<your-tgi-endpoint-url>",
                                     token="<your-token>")
    client.warm_up()
    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
    print(response)
    ```

    Key Features and Compatibility:
        - **Primary Compatibility**: Designed to work seamlessly with any non-chat model deployed using the TGI
          framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference.
        - **Hugging Face Inference Endpoints**: Supports inference of TGI LLMs deployed on Hugging Face
          inference endpoints. For more details, refer to https://huggingface.co/inference-endpoints.
        - **Inference API Support**: Supports inference of TGI LLMs hosted on the rate-limited Inference
          API tier. Learn more about the Inference API at https://huggingface.co/inference-api.
          Discover available LLMs using the following command:
          ```
          wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference
          ```
          and simply use the model ID as the model parameter for this component. You'll also need to provide a valid
          Hugging Face API token as the token parameter.
        - **Custom TGI Endpoints**: Supports inference of LLMs deployed on custom TGI endpoints. Anyone can
          deploy their own TGI endpoint using the TGI framework. For more details, refer
          to https://huggingface.co/inference-endpoints.

    Input and Output Format:
        - **String Format**: This component uses the str format for structuring both input and output,
          ensuring coherent and contextually relevant responses in text generation scenarios.
    """

    def __init__(
        self,
        model: str = "mistralai/Mistral-7B-v0.1",
        url: Optional[str] = None,
        token: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceTGIGenerator instance.

        :param model: A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1".
        :param url: An optional string representing the URL of the TGI endpoint.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            You can find your HF token at https://huggingface.co/settings/tokens.
        :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
            Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`, ...
            See Hugging Face's documentation for more information at:
            https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        if url:
            r = urlparse(url)
            is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
            if not is_valid_url:
                raise ValueError(f"Invalid TGI endpoint URL provided: {url}")

        check_valid_model(model, token)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        check_generation_params(generation_kwargs, ["n"])
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.model = model
        self.url = url
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.client = InferenceClient(url or model, token=token)
        self.streaming_callback = streaming_callback
        self.tokenizer = None

    def warm_up(self) -> None:
        """
        Load the tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :return: A dictionary containing the serialized component.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model=self.model,
            url=self.url,
            token=self.token if not isinstance(self.token, str) else None,  # don't serialize valid tokens
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        # Don't send URL as it is sensitive information
        return {"model": self.model}

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference for the given prompt and generation parameters.

        :param prompt: A string representing the prompt.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :return: A dictionary containing the generated replies and metadata. Both are lists of length n.
            Replies are strings and metadata are dictionaries.
        """
        # check generation kwargs given as parameters to override the default ones
        additional_params = ["n", "stop_words"]
        check_generation_params(generation_kwargs, additional_params)

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
        num_responses = generation_kwargs.pop("n", 1)
        generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))

        if self.tokenizer is None:
            raise RuntimeError("Please call warm_up() before running LLM inference.")

        prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))

        if self.streaming_callback:
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")
            return self._run_streaming(prompt, prompt_token_count, generation_kwargs)

        return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)

    def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
        res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
            prompt, details=True, stream=True, **generation_kwargs
        )
        chunks: List[StreamingChunk] = []
        # pylint: disable=not-an-iterable
        for chunk in res_chunk:
            token: Token = chunk.token
            if token.special:
                continue
            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
            stream_chunk = StreamingChunk(token.text, chunk_metadata)
            chunks.append(stream_chunk)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
        metadata = {
            "finish_reason": chunks[-1].metadata.get("finish_reason", None),
            "model": self.client.model,
            "usage": {
                "completion_tokens": chunks[-1].metadata.get("generated_tokens", 0),
                "prompt_tokens": prompt_token_count,
                "total_tokens": prompt_token_count + chunks[-1].metadata.get("generated_tokens", 0),
            },
        }
        return {"replies": ["".join([chunk.content for chunk in chunks])], "metadata": [metadata]}

    def _run_non_streaming(
        self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
    ):
        responses: List[str] = []
        all_metadata: List[Dict[str, Any]] = []
        for _i in range(num_responses):
            tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs)
            all_metadata.append(
                {
                    "model": self.client.model,
                    "index": _i,
                    "finish_reason": tgr.details.finish_reason.value,
                    "usage": {
                        "completion_tokens": len(tgr.details.tokens),
                        "prompt_tokens": prompt_token_count,
                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
                    },
                }
            )
            responses.append(tgr.generated_text)
        return {"replies": responses, "metadata": all_metadata}


@@ -0,0 +1,5 @@
---
preview:
  - |
    Adds `HuggingFaceTGIGenerator` for text generation. This component supports remote inference for
    Hugging Face LLMs via the Text Generation Inference (TGI) protocol.
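A short sketch, not from the commit, of the serialization round-trip that the new component supports: `to_dict` deliberately drops string tokens and records the streaming callback by its import path, which `from_dict` resolves again. It assumes `my_callback` lives in an importable module and that the model id can be validated against the Hugging Face Hub (network access) when the generator is constructed.

```python
from haystack.preview.components.generators import HuggingFaceTGIGenerator


def my_callback(chunk):  # any importable module-level function can serve as a streaming callback
    return chunk


generator = HuggingFaceTGIGenerator(
    token="hf_example_token",             # placeholder; string tokens are never serialized
    generation_kwargs={"max_new_tokens": 64},
    stop_words=["###"],
    streaming_callback=my_callback,
)

data = generator.to_dict()
assert data["init_parameters"]["token"] is None                                    # token dropped on purpose
assert data["init_parameters"]["generation_kwargs"]["stop_sequences"] == ["###"]   # stop words merged in

restored = HuggingFaceTGIGenerator.from_dict(data)
assert restored.streaming_callback is my_callback  # callback resolved back from its import path
```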


@@ -0,0 +1,295 @@
from unittest.mock import patch, MagicMock, Mock

import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.preview.components.generators import HuggingFaceTGIGenerator
from haystack.preview.dataclasses import StreamingChunk


@pytest.fixture
def mock_check_valid_model():
    with patch(
        "haystack.preview.components.generators.hugging_face_tgi.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


@pytest.fixture
def mock_text_generation():
    with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
        mock_response = Mock()
        mock_response.generated_text = "I'm fine, thanks."
        details = Mock()
        details.finish_reason = MagicMock(field1="value")
        details.tokens = [1, 2, 3]
        mock_response.details = details
        mock_text_generation.return_value = mock_response
        yield mock_text_generation


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x


class TestHuggingFaceTGIGenerator:
    @pytest.mark.unit
    def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIGenerator(
            model=model,
            url=None,
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.model == model
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
        assert generator.tokenizer is None
        assert generator.client is not None
        assert generator.streaming_callback == streaming_callback

    @pytest.mark.unit
    def test_to_dict(self, mock_check_valid_model):
        # Initialize the HuggingFaceTGIGenerator object with valid parameters
        generator = HuggingFaceTGIGenerator(
            token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x
        )

        # Call the to_dict method
        result = generator.to_dict()
        init_params = result["init_parameters"]

        # Assert that the init_params dictionary contains the expected keys and values
        assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
        assert not init_params["token"]
        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}

    @pytest.mark.unit
    def test_from_dict(self, mock_check_valid_model):
        generator = HuggingFaceTGIGenerator(
            model="mistralai/Mistral-7B-v0.1",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        # Call the to_dict method
        result = generator.to_dict()

        # now deserialize, call from_dict
        generator_2 = HuggingFaceTGIGenerator.from_dict(result)
        assert generator_2.model == "mistralai/Mistral-7B-v0.1"
        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
        assert generator_2.streaming_callback is streaming_callback_handler

    @pytest.mark.unit
    def test_initialize_with_invalid_url(self, mock_check_valid_model):
        with pytest.raises(ValueError):
            HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url")

    @pytest.mark.unit
    def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
        # When a custom TGI endpoint is used via URL, a valid Hugging Face Hub model id must still be provided
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com")

    @pytest.mark.unit
    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        model = "mistralai/Mistral-7B-v0.1"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        prompt = "Hello, how are you?"
        response = generator.run(prompt)

        # check kwargs passed to text_generation
        # note how n was not passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop"]}

        assert isinstance(response, dict)
        assert "replies" in response
        assert "metadata" in response
        assert isinstance(response["replies"], list)
        assert isinstance(response["metadata"], list)
        assert len(response["replies"]) == 1
        assert len(response["metadata"]) == 1
        assert all(isinstance(reply, str) for reply in response["replies"])

    @pytest.mark.unit
    def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        model = "mistralai/Mistral-7B-v0.1"
        generation_kwargs = {"n": 3}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        prompt = "Hello, how are you?"
        response = generator.run(prompt)

        # check kwargs passed to text_generation
        # note how n was not passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop"]}

        assert isinstance(response, dict)
        assert "replies" in response
        assert "metadata" in response
        assert isinstance(response["replies"], list)
        assert all(isinstance(reply, str) for reply in response["replies"])
        assert isinstance(response["metadata"], list)
        assert len(response["replies"]) == 3
        assert len(response["metadata"]) == 3
        assert all(isinstance(metadata, dict) for metadata in response["metadata"])

    @pytest.mark.unit
    def test_initialize_with_invalid_model(self, mock_check_valid_model):
        model = "invalid_model"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        mock_check_valid_model.side_effect = ValueError("Invalid model path or url")

        with pytest.raises(ValueError):
            HuggingFaceTGIGenerator(
                model=model,
                generation_kwargs=generation_kwargs,
                stop_words=stop_words,
                streaming_callback=streaming_callback,
            )

    @pytest.mark.unit
    def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation):
        generator = HuggingFaceTGIGenerator()
        generator.warm_up()

        # Generate text response with stop words
        response = generator.run("How are you?", generation_kwargs={"stop_words": ["stop", "words"]})

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, str) for reply in response["replies"])

        # Assert that the response contains the metadata
        assert "metadata" in response
        assert isinstance(response["metadata"], list)
        assert len(response["metadata"]) > 0
        assert all(isinstance(metadata, dict) for metadata in response["metadata"])

    @pytest.mark.unit
    def test_generate_text_with_custom_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        generator = HuggingFaceTGIGenerator()
        generator.warm_up()

        generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
        response = generator.run("How are you?", generation_kwargs=generation_kwargs)

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}

        # Assert that the response contains the generated replies and the right response
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, str) for reply in response["replies"])
        assert response["replies"][0] == "I'm fine, thanks."

        # Assert that the response contains the metadata
        assert "metadata" in response
        assert isinstance(response["metadata"], list)
        assert len(response["metadata"]) > 0
        assert all(isinstance(metadata, dict) for metadata in response["metadata"])

    @pytest.mark.unit
    def test_generate_text_with_streaming_callback(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        # Create an instance of HuggingFaceTGIGenerator with the streaming callback
        generator = HuggingFaceTGIGenerator(streaming_callback=streaming_callback_fn)
        generator.warm_up()

        # Create a fake streamed response
        # Don't remove self
        def mock_iter(self):
            yield TextGenerationStreamResponse(
                generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
            )
            yield TextGenerationStreamResponse(
                generated_text=None,
                token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
                details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
            )

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_text_generation.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run("prompt")

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, str) for reply in response["replies"])

        # Assert that the response contains the metadata
        assert "metadata" in response
        assert isinstance(response["metadata"], list)
        assert len(response["metadata"]) > 0
        assert all(isinstance(metadata, dict) for metadata in response["metadata"])
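The tests above also rely on a `mock_auto_tokenizer` fixture that is not part of this diff; it presumably comes from a shared conftest.py. For readers wanting to run these tests in isolation, a minimal hypothetical stand-in could look like the sketch below (the fixture name and its use of `AutoTokenizer.from_pretrained` match the component's `warm_up`, but the exact implementation is an assumption).

```python
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def mock_auto_tokenizer():
    # Hypothetical stand-in for the fixture expected from a shared conftest.py (not in this diff)
    with patch("transformers.AutoTokenizer.from_pretrained") as mock_from_pretrained:
        mock_tokenizer = MagicMock()
        # run() only calls tokenizer.encode(...) to count prompt tokens
        mock_tokenizer.encode.return_value = [1, 2, 3]
        mock_from_pretrained.return_value = mock_tokenizer
        yield mock_tokenizer
```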