Mirror of https://github.com/deepset-ai/haystack.git
Synced 2025-12-29 16:08:38 +00:00
feat: HuggingFaceAPIChatGenerator (#7480)
* draft
* docstrings and more tests
* deprecation; reno
* pydoc config
* better error messages
* wip
* add test
* better docstrings
* deprecation; reno
* pylint
* typo
* rm unneeded else
* rm unneeded else
* fixes from feedback
* docstring showing the enum
* improve docstring
* make params mandatory
* Apply suggestions from code review (Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>)
* document enum
* Update haystack/utils/hf.py (Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>)
* mandatory params
* fix test
* fix test

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
parent 1d083861ff
commit 0dbb98c0a0
pydoc configuration (file name not shown in this view):
@@ -11,6 +11,7 @@ loaders:
     "chat/azure",
     "chat/hugging_face_local",
     "chat/hugging_face_tgi",
+    "chat/hugging_face_api",
     "chat/openai",
   ]
   ignore_when_discovered: ["__init__"]
haystack/components/generators/chat/__init__.py
@@ -4,10 +4,12 @@ from haystack.components.generators.chat.openai import ( # noqa: I001 (otherwis
 from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator
 from haystack.components.generators.chat.hugging_face_local import HuggingFaceLocalChatGenerator
 from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
+from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator

 __all__ = [
     "HuggingFaceLocalChatGenerator",
     "HuggingFaceTGIChatGenerator",
+    "HuggingFaceAPIChatGenerator",
     "OpenAIChatGenerator",
     "AzureOpenAIChatGenerator",
 ]
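With the `__init__.py` change above, the new generator is importable from the public chat-generators namespace. A quick smoke-test sketch (not part of the commit; no network access is needed at import time):

```python
# Minimal import smoke test for the new public name.
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator

print(HuggingFaceAPIChatGenerator.__name__)  # -> "HuggingFaceAPIChatGenerator"
```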
haystack/components/generators/chat/hugging_face_api.py  (new file, 236 lines)
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.22.0\"'") as huggingface_hub_import:
    from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient


logger = logging.getLogger(__name__)

@component
class HuggingFaceAPIChatGenerator:
    """
    This component can be used to generate text using different Hugging Face APIs with the ChatMessage format:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    Input and Output Format:
    - ChatMessage Format: This component uses the ChatMessage format to structure both input and output,
      ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
      ChatMessage format can be found [here](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage).

    Example usage with the free Serverless Inference API:
    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api"  # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    Example usage with paid Inference Endpoints:
    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    Example usage with self-hosted Text Generation Inference:
    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
    """

    def __init__(
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary containing the following keys:
            - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
              `TEXT_GENERATION_INFERENCE`.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary containing keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`...
            See Hugging Face's documentation for more information at:
            [chat_completion](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                raise ValueError(
                    "To use Text Generation Inference or Inference Endpoints, "
                    "you need to specify the `url` parameter in `api_params`."
                )
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=self.api_type,
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :returns:
            The deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=List[ChatMessage])
    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage instances representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [m.to_openai_format() for m in messages]

        if self.streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs)

        return self._run_non_streaming(formatted_messages, generation_kwargs)

    def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        finish_reason = None  # stays None if the stream yields no chunks

        for chunk in api_output:  # pylint: disable=not-an-iterable
            text = chunk.choices[0].delta.content
            if text:
                generated_text += text
            finish_reason = chunk.choices[0].finish_reason

            meta = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            stream_chunk = StreamingChunk(text, meta)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)

        message = ChatMessage.from_assistant(generated_text)
        message.meta.update({"model": self._client.model, "finish_reason": finish_reason, "index": 0})
        return {"replies": [message]}

    def _run_non_streaming(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
    ) -> Dict[str, List[ChatMessage]]:
        chat_messages: List[ChatMessage] = []

        api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)

        for choice in api_chat_output.choices:
            message = ChatMessage.from_assistant(choice.message.content)
            message.meta.update(
                {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
            )
            chat_messages.append(message)
        return {"replies": chat_messages}
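Beyond the docstring examples above, the streaming path (`_run_streaming`) can be exercised like this. A minimal sketch, assuming a valid token in the `HF_API_TOKEN` environment variable and that the chosen model is available on the Serverless Inference API:

```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


def print_chunk(chunk: StreamingChunk):
    # print each token as it arrives, without waiting for the full reply
    print(chunk.content, end="", flush=True)


generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    streaming_callback=print_chunk,
)

result = generator.run([ChatMessage.from_user("What's Natural Language Processing?")])
# the full reply is still returned at the end, with model/finish_reason metadata
print(result["replies"][0].meta)
```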
haystack/components/generators/chat/hugging_face_tgi.py
@@ -1,3 +1,4 @@
+import warnings
 from dataclasses import asdict
 from typing import Any, Callable, Dict, Iterable, List, Optional
 from urllib.parse import urlparse
@@ -113,6 +114,11 @@ class HuggingFaceTGIChatGenerator:
         :param stop_words: An optional list of strings representing the stop words.
         :param streaming_callback: An optional callable for handling streaming responses.
         """
+        warnings.warn(
+            "`HuggingFaceTGIChatGenerator` is deprecated and will be removed in Haystack 2.3.0. "
+            "Use `HuggingFaceAPIChatGenerator` instead.",
+            DeprecationWarning,
+        )
         transformers_import.check()

         if url:
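For users migrating off the deprecated component, the mapping is direct. A hedged sketch (the `url=` keyword for the old generator is inferred from the `if url:` check above, and the endpoint URL is a placeholder):

```python
from haystack.components.generators.chat import (
    HuggingFaceAPIChatGenerator,
    HuggingFaceTGIChatGenerator,  # deprecated, removal planned for Haystack 2.3.0
)

# before: emits a DeprecationWarning as of this commit
old_generator = HuggingFaceTGIChatGenerator(url="http://localhost:8080")

# after: same TGI server, expressed via api_type/api_params
new_generator = HuggingFaceAPIChatGenerator(
    api_type="text_generation_inference",
    api_params={"url": "http://localhost:8080"},
)
```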
releasenotes/notes/hfapichatgenerator-51772e1f0d679b1c.yaml  (new file, 14 lines)

---
features:
  - |
    Introduce `HuggingFaceAPIChatGenerator`.
    This text-generation component uses the ChatMessage format and supports different Hugging Face APIs:
    - free Serverless Inference API
    - paid Inference Endpoints
    - self-hosted Text Generation Inference.

    This generator will replace the `HuggingFaceTGIChatGenerator` in the future.
deprecations:
  - |
    Deprecate `HuggingFaceTGIChatGenerator`. This component will be removed in Haystack 2.3.0.
    Use `HuggingFaceAPIChatGenerator` instead.
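Note that Python silences `DeprecationWarning` by default outside of `__main__` and most test runners, so downstream projects may want to opt in while migrating. A small sketch using only the standard library:

```python
import warnings

# surface deprecation warnings raised from haystack modules during the migration window
warnings.filterwarnings("always", category=DeprecationWarning, module="haystack")
```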
test/components/generators/chat/test_hugging_face_api.py  (new file, 256 lines)
from unittest.mock import MagicMock, Mock, patch

import pytest
from huggingface_hub import (
    ChatCompletionOutput,
    ChatCompletionOutputChoice,
    ChatCompletionOutputChoiceMessage,
    ChatCompletionStreamOutput,
    ChatCompletionStreamOutputChoice,
    ChatCompletionStreamOutputDelta,
)
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.auth import Secret
from haystack.utils.hf import HFGenerationAPIType


@pytest.fixture
def mock_check_valid_model():
    with patch(
        "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


@pytest.fixture
def mock_chat_completion():
    # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example
    with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
        completion = ChatCompletionOutput(
            choices=[
                ChatCompletionOutputChoice(
                    finish_reason="eos_token",
                    index=0,
                    message=ChatCompletionOutputChoiceMessage(
                        content="The capital of France is Paris.", role="assistant"
                    ),
                )
            ],
            created=1710498360,
        )

        mock_chat_completion.return_value = completion
        yield mock_chat_completion


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x

class TestHuggingFaceAPIGenerator:
    def test_init_invalid_api_type(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": model},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator.api_params == {"model": model}
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
        assert generator.streaming_callback == streaming_callback

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tgi(self):
        url = "https://some_model.com"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE,
            api_params={"url": url},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE
        assert generator.api_params == {"url": url}
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
        assert generator.streaming_callback == streaming_callback

    def test_init_tgi_invalid_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tgi_no_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_to_dict(self, mock_check_valid_model):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "mistralai/Mistral-7B-v0.1"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
        )

        result = generator.to_dict()
        init_params = result["init_parameters"]

        assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert init_params["api_params"] == {"model": "mistralai/Mistral-7B-v0.1"}
        assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
        assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

    def test_from_dict(self, mock_check_valid_model):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "mistralai/Mistral-7B-v0.1"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        result = generator.to_dict()

        # now deserialize, call from_dict
        generator_2 = HuggingFaceAPIChatGenerator.from_dict(result)
        assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator_2.api_params == {"model": "mistralai/Mistral-7B-v0.1"}
        assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
        assert generator_2.streaming_callback is streaming_callback_handler

    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_chat_completion, chat_messages
    ):
        # the `chat_messages` fixture is defined outside the excerpt shown here
        # (this diff view is truncated)
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = generator.run(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                created=1710498504,
            )

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_chat_completion.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.integration
    def test_run_serverless(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            generation_kwargs={"max_tokens": 20},
        )

        messages = [ChatMessage.from_user("What is the capital of France?")]
        response = generator.run(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
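To run these tests locally while skipping the final test, which hits the live Serverless Inference API, the `integration` marker can be deselected. A sketch, assuming pytest is configured with this marker as the decorator above suggests:

```python
import subprocess

# run only the mocked unit tests; "-m 'not integration'" deselects the live-API test
subprocess.run(
    ["pytest", "test/components/generators/chat/test_hugging_face_api.py", "-m", "not integration"],
    check=True,
)
```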