Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-17 13:07:42 +00:00)
feat: Add HuggingFaceTGIGenerator Haystack 2.x component (#6205)

* Add HuggingFaceTGIGenerator
* PR review
* PR feedback from Stefano

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>

parent 8511b8cd79
commit 6e2dbdc320
```diff
@@ -1,4 +1,5 @@
 from haystack.preview.components.generators.openai.gpt import GPTGenerator
 from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator
+from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator

-__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator"]
+__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator"]
```
```diff
@@ -5,7 +5,7 @@ from huggingface_hub import InferenceClient, HfApi
 from huggingface_hub.utils import RepositoryNotFoundError


-def check_generation_params(kwargs: Dict[str, Any], additional_accepted_params: Optional[List[str]] = None):
+def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
     """
     Check the provided generation parameters for validity.
```
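The widened `Optional[Dict[str, Any]]` signature matters because `HuggingFaceTGIGenerator.run` (added below) forwards its possibly-`None` `generation_kwargs` straight into this check. A minimal sketch of that calling pattern, assuming the function simply skips validation when given `None`:

```python
from typing import Any, Dict, Optional

from haystack.preview.components.generators.hf_utils import check_generation_params

# No overrides supplied: with the Optional signature this is a no-op (assumption:
# the function treats None/empty as "nothing to validate").
overrides: Optional[Dict[str, Any]] = None
check_generation_params(overrides, additional_accepted_params=["n", "stop_words"])

# Overrides supplied: the keys are validated against the accepted generation parameters.
overrides = {"max_new_tokens": 120, "n": 2}
check_generation_params(overrides, additional_accepted_params=["n", "stop_words"])
```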
232 lines · haystack/preview/components/generators/hugging_face_tgi.py (new file)

````python
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Iterable, Callable
from urllib.parse import urlparse

from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer

from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)

@component
class HuggingFaceTGIGenerator:
    """
    Enables text generation using HuggingFace Hub hosted non-chat LLMs. This component is designed to seamlessly
    run inference against models deployed on the Text Generation Inference (TGI) backend.

    You can use this component for LLMs hosted on the rate-limited Inference API tier:

    ```python
    from haystack.preview.components.generators import HuggingFaceTGIGenerator
    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="<your-token>")
    client.warm_up()
    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
    print(response)
    ```

    Or for LLMs hosted on paid Hugging Face Inference Endpoints (https://huggingface.co/inference-endpoints)
    and/or your own custom TGI endpoint. In these two cases, you'll need to provide the URL of the endpoint
    as well as a valid token:

    ```python
    from haystack.preview.components.generators import HuggingFaceTGIGenerator
    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
                                     url="<your-tgi-endpoint-url>",
                                     token="<your-token>")
    client.warm_up()
    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
    print(response)
    ```

    Key Features and Compatibility:
        - **Primary Compatibility**: Designed to work seamlessly with any non-chat model deployed using the TGI
          framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference.
        - **Hugging Face Inference Endpoints**: Supports inference of TGI LLMs deployed on Hugging Face
          inference endpoints. For more details, refer to https://huggingface.co/inference-endpoints.
        - **Inference API Support**: Supports inference of TGI LLMs hosted on the rate-limited Inference
          API tier. Learn more about the Inference API at https://huggingface.co/inference-api.
          Discover available LLMs using the following command:
          ```
          wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference
          ```
          and simply use the model ID as the model parameter for this component. You'll also need to provide a
          valid Hugging Face API token as the token parameter.
        - **Custom TGI Endpoints**: Supports inference of LLMs deployed on custom TGI endpoints. Anyone can
          deploy their own TGI endpoint using the TGI framework. For more details, refer
          to https://huggingface.co/inference-endpoints.

    Input and Output Format:
        - **String Format**: This component uses the str format for structuring both input and output,
          ensuring coherent and contextually relevant responses in text generation scenarios.
    """
    def __init__(
        self,
        model: str = "mistralai/Mistral-7B-v0.1",
        url: Optional[str] = None,
        token: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceTGIGenerator instance.

        :param model: A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1".
        :param url: An optional string representing the URL of the TGI endpoint.
        :param token: The HuggingFace token to use as HTTP bearer authorization.
            You can find your HF token at https://huggingface.co/settings/tokens
        :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
            Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`, ...
            See Hugging Face's documentation for more information at:
            https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        if url:
            r = urlparse(url)
            is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
            if not is_valid_url:
                raise ValueError(f"Invalid TGI endpoint URL provided: {url}")

        check_valid_model(model, token)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        check_generation_params(generation_kwargs, ["n"])
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.model = model
        self.url = url
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.client = InferenceClient(url or model, token=token)
        self.streaming_callback = streaming_callback
        self.tokenizer = None

    def warm_up(self) -> None:
        """
        Load the tokenizer
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :return: A dictionary containing the serialized component.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model=self.model,
            url=self.url,
            token=self.token if not isinstance(self.token, str) else None,  # don't serialize valid tokens
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        # Don't send URL as it is sensitive information
        return {"model": self.model}

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference for the given prompt and generation parameters.

        :param prompt: A string representing the prompt.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :return: A dictionary containing the generated replies and metadata. Both are lists of length n.
            Replies are strings and metadata are dictionaries.
        """
        # check generation kwargs given as parameters to override the default ones
        additional_params = ["n", "stop_words"]
        check_generation_params(generation_kwargs, additional_params)

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
        num_responses = generation_kwargs.pop("n", 1)
        generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))

        if self.tokenizer is None:
            raise RuntimeError("Please call warm_up() before running LLM inference.")

        prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))

        if self.streaming_callback:
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")

            return self._run_streaming(prompt, prompt_token_count, generation_kwargs)

        return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)

    def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
        res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
            prompt, details=True, stream=True, **generation_kwargs
        )
        chunks: List[StreamingChunk] = []
        # pylint: disable=not-an-iterable
        for chunk in res_chunk:
            token: Token = chunk.token
            if token.special:
                continue
            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
            stream_chunk = StreamingChunk(token.text, chunk_metadata)
            chunks.append(stream_chunk)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
        metadata = {
            "finish_reason": chunks[-1].metadata.get("finish_reason", None),
            "model": self.client.model,
            "usage": {
                "completion_tokens": chunks[-1].metadata.get("generated_tokens", 0),
                "prompt_tokens": prompt_token_count,
                "total_tokens": prompt_token_count + chunks[-1].metadata.get("generated_tokens", 0),
            },
        }
        return {"replies": ["".join([chunk.content for chunk in chunks])], "metadata": [metadata]}

    def _run_non_streaming(
        self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
    ):
        responses: List[str] = []
        all_metadata: List[Dict[str, Any]] = []
        for _i in range(num_responses):
            tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs)
            all_metadata.append(
                {
                    "model": self.client.model,
                    "index": _i,
                    "finish_reason": tgr.details.finish_reason.value,
                    "usage": {
                        "completion_tokens": len(tgr.details.tokens),
                        "prompt_tokens": prompt_token_count,
                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
                    },
                }
            )
            responses.append(tgr.generated_text)
        return {"replies": responses, "metadata": all_metadata}
````
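The streaming path above emits `StreamingChunk` objects whose `content` field carries the token text. A minimal usage sketch wiring a callback that prints tokens as they arrive; the model, token, and prompt are placeholders, and the call needs access to the Hugging Face Inference API:

```python
from haystack.preview.components.generators import HuggingFaceTGIGenerator
from haystack.preview.dataclasses import StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # StreamingChunk.content holds the token text built in _run_streaming above
    print(chunk.content, end="", flush=True)


generator = HuggingFaceTGIGenerator(
    model="mistralai/Mistral-7B-v0.1",
    token="<your-token>",  # placeholder token
    streaming_callback=print_chunk,
)
generator.warm_up()  # loads the tokenizer used for prompt token counting
result = generator.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
print()
print(result["metadata"][0]["usage"])  # completion/prompt/total token counts
```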
New release note (@@ -0,0 +1,5 @@):

```yaml
---
preview:
  - |
    Adds `HuggingFaceTGIGenerator` for text generation. This component supports remote inference for
    Hugging Face LLMs via the text-generation-inference (TGI) protocol.
```
295 lines · test/preview/components/generators/test_hugging_face_tgi.py (new file)

```python
from unittest.mock import patch, MagicMock, Mock

import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.preview.components.generators import HuggingFaceTGIGenerator
from haystack.preview.dataclasses import StreamingChunk


@pytest.fixture
def mock_check_valid_model():
    with patch(
        "haystack.preview.components.generators.hugging_face_tgi.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


@pytest.fixture
def mock_text_generation():
    with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
        mock_response = Mock()
        mock_response.generated_text = "I'm fine, thanks."
        details = Mock()
        details.finish_reason = MagicMock(field1="value")
        details.tokens = [1, 2, 3]
        mock_response.details = details
        mock_text_generation.return_value = mock_response
        yield mock_text_generation


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x


class TestHuggingFaceTGIGenerator:
    @pytest.mark.unit
    def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIGenerator(
            model=model,
            url=None,
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.model == model
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
        assert generator.tokenizer is None
        assert generator.client is not None
        assert generator.streaming_callback == streaming_callback
    @pytest.mark.unit
    def test_to_dict(self, mock_check_valid_model):
        # Initialize the HuggingFaceTGIGenerator object with valid parameters
        generator = HuggingFaceTGIGenerator(
            token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x
        )

        # Call the to_dict method
        result = generator.to_dict()
        init_params = result["init_parameters"]

        # Assert that the init_params dictionary contains the expected keys and values
        assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
        assert not init_params["token"]
        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}

    @pytest.mark.unit
    def test_from_dict(self, mock_check_valid_model):
        generator = HuggingFaceTGIGenerator(
            model="mistralai/Mistral-7B-v0.1",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        # Call the to_dict method
        result = generator.to_dict()

        # now deserialize, call from_dict
        generator_2 = HuggingFaceTGIGenerator.from_dict(result)
        assert generator_2.model == "mistralai/Mistral-7B-v0.1"
        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
        assert generator_2.streaming_callback is streaming_callback_handler
    @pytest.mark.unit
    def test_initialize_with_invalid_url(self, mock_check_valid_model):
        with pytest.raises(ValueError):
            HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url")

    @pytest.mark.unit
    def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
        # When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com")

    @pytest.mark.unit
    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        model = "mistralai/Mistral-7B-v0.1"

        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        prompt = "Hello, how are you?"
        response = generator.run(prompt)

        # check kwargs passed to text_generation
        # note how n was not passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop"]}

        assert isinstance(response, dict)
        assert "replies" in response
        assert "metadata" in response
        assert isinstance(response["replies"], list)
        assert isinstance(response["metadata"], list)
        assert len(response["replies"]) == 1
        assert len(response["metadata"]) == 1
        assert [isinstance(reply, str) for reply in response["replies"]]

    @pytest.mark.unit
    def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        model = "mistralai/Mistral-7B-v0.1"
        generation_kwargs = {"n": 3}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        prompt = "Hello, how are you?"
        response = generator.run(prompt)

        # check kwargs passed to text_generation
        # note how n was not passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop"]}

        assert isinstance(response, dict)
        assert "replies" in response
        assert "metadata" in response
        assert isinstance(response["replies"], list)
        assert [isinstance(reply, str) for reply in response["replies"]]

        assert isinstance(response["metadata"], list)
        assert len(response["replies"]) == 3
        assert len(response["metadata"]) == 3
        assert [isinstance(reply, dict) for reply in response["metadata"]]

    @pytest.mark.unit
    def test_initialize_with_invalid_model(self, mock_check_valid_model):
        model = "invalid_model"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        mock_check_valid_model.side_effect = ValueError("Invalid model path or url")

        with pytest.raises(ValueError):
            HuggingFaceTGIGenerator(
                model=model,
                generation_kwargs=generation_kwargs,
                stop_words=stop_words,
                streaming_callback=streaming_callback,
            )
    @pytest.mark.unit
    def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation):
        generator = HuggingFaceTGIGenerator()
        generator.warm_up()

        # Generate text response with stop words
        response = generator.run("How are you?", generation_kwargs={"stop_words": ["stop", "words"]})

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert [isinstance(reply, str) for reply in response["replies"]]

        # Assert that the response contains the metadata
        assert "metadata" in response
        assert isinstance(response["metadata"], list)
        assert len(response["metadata"]) > 0
        assert [isinstance(meta, dict) for meta in response["metadata"]]
    @pytest.mark.unit
    def test_generate_text_with_custom_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        generator = HuggingFaceTGIGenerator()
        generator.warm_up()

        generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
        response = generator.run("How are you?", generation_kwargs=generation_kwargs)

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}

        # Assert that the response contains the generated replies and the right response
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert [isinstance(reply, str) for reply in response["replies"]]
        assert response["replies"][0] == "I'm fine, thanks."

        # Assert that the response contains the metadata
        assert "metadata" in response
        assert isinstance(response["metadata"], list)
        assert len(response["metadata"]) > 0
        assert [isinstance(meta, dict) for meta in response["metadata"]]
    @pytest.mark.unit
    def test_generate_text_with_streaming_callback(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        # Create an instance of HuggingFaceTGIGenerator
        generator = HuggingFaceTGIGenerator(streaming_callback=streaming_callback_fn)
        generator.warm_up()

        # Create a fake streamed response
        # Don't remove self
        def mock_iter(self):
            yield TextGenerationStreamResponse(
                generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
            )
            yield TextGenerationStreamResponse(
                generated_text=None,
                token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
                details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
            )

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_text_generation.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run("prompt")

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert [isinstance(reply, str) for reply in response["replies"]]

        # Assert that the response contains the metadata
        assert "metadata" in response
        assert isinstance(response["metadata"], list)
        assert len(response["metadata"]) > 0
        assert [isinstance(meta, dict) for meta in response["metadata"]]
```
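The tests above also rely on a `mock_auto_tokenizer` fixture that is not part of this diff (it presumably lives in the shared test conftest). A hypothetical sketch of such a fixture, assuming it only needs to stub `AutoTokenizer.from_pretrained` so `warm_up()` and prompt token counting work offline:

```python
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def mock_auto_tokenizer():
    # Hypothetical stand-in: patch AutoTokenizer.from_pretrained so no model files are downloaded;
    # encode() returns a fixed token list, which is all run() needs for prompt token counting.
    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
        tokenizer = MagicMock()
        tokenizer.encode.return_value = [1, 2, 3]
        mock_from_pretrained.return_value = tokenizer
        yield mock_from_pretrained
```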