Mirror of https://github.com/deepset-ai/haystack.git
feat: Add usage when using HuggingFaceAPIChatGenerator with streaming (#9371)
* Small fix and update tests
* Add usage support to streaming for HuggingFaceAPIChatGenerator
* Add reno
* try using provider='auto'
* Undo provider
* Fix unit tests
* Update releasenotes/notes/add-usage-hf-api-chat-streaming-91fd04705f45d5b3.yaml

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Parent: 9ae76e1653
Commit: af073852d0
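The change builds on a huggingface_hub streaming feature: when include_usage is requested, chat_completion emits one extra, final chunk that carries token counts and has an empty choices list. A minimal sketch of that mechanism outside Haystack (the model name and prompt are placeholders, not from this commit):

from huggingface_hub import ChatCompletionInputStreamOptions, InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")  # placeholder model

chunks = client.chat_completion(
    [{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options=ChatCompletionInputStreamOptions(include_usage=True),
    max_tokens=32,
)

text, usage = "", None
for chunk in chunks:
    if chunk.choices:  # the usage-only chunk arrives last, with no choices
        text += chunk.choices[0].delta.content or ""
    if chunk.usage:
        usage = chunk.usage

print(text)
if usage is not None:
    print(usage.prompt_tokens, usage.completion_tokens)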
@@ -27,6 +27,7 @@ with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"
 from huggingface_hub import (
     AsyncInferenceClient,
     ChatCompletionInputFunctionDefinition,
+    ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionOutput,
     ChatCompletionOutputToolCall,
@@ -396,37 +397,52 @@ class HuggingFaceAPIChatGenerator:
         self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
     ):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
-            messages, stream=True, **generation_kwargs
+            messages,
+            stream=True,
+            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
+            **generation_kwargs,
         )
 
         generated_text = ""
         first_chunk_time = None
+        finish_reason = None
+        usage = None
         meta: Dict[str, Any] = {}
 
         for chunk in api_output:
-            # n is unused, so the API always returns only one choice
-            # the argument is probably allowed for compatibility with OpenAI
-            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
-            choice = chunk.choices[0]
+            # The chunk with usage returns an empty array for choices
+            if len(chunk.choices) > 0:
+                # n is unused, so the API always returns only one choice
+                # the argument is probably allowed for compatibility with OpenAI
+                # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+                choice = chunk.choices[0]
 
-            text = choice.delta.content or ""
-            generated_text += text
+                text = choice.delta.content or ""
+                generated_text += text
 
-            finish_reason = choice.finish_reason
-            if finish_reason:
-                meta["finish_reason"] = finish_reason
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                stream_chunk = StreamingChunk(text, meta)
+                streaming_callback(stream_chunk)
+
+            if chunk.usage:
+                usage = chunk.usage
 
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
-            stream_chunk = StreamingChunk(text, meta)
-            streaming_callback(stream_chunk)
+        if usage:
+            usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
+        else:
+            usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}
 
         meta.update(
             {
                 "model": self._client.model,
                 "index": 0,
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
                 "finish_reason": finish_reason,
+                "usage": usage_dict,
                 "completion_start_time": first_chunk_time,
             }
         )
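From the caller's side, the effect is that the usage entry in the reply's meta now holds real token counts instead of zeros. A hedged usage sketch (the model name is the one used in the tests below; the constructor arguments and printed counts are illustrative assumptions):

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage

generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "microsoft/Phi-3.5-mini-instruct"},
    streaming_callback=print,  # receives each StreamingChunk as it arrives
)

result = generator.run(messages=[ChatMessage.from_user("What is streaming?")])
meta = result["replies"][0].meta
print(meta["usage"])  # e.g. {"prompt_tokens": 12, "completion_tokens": 40}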
@@ -477,34 +493,52 @@ class HuggingFaceAPIChatGenerator:
         self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
     ):
         api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
-            messages, stream=True, **generation_kwargs
+            messages,
+            stream=True,
+            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
+            **generation_kwargs,
         )
 
         generated_text = ""
         first_chunk_time = None
+        finish_reason = None
+        usage = None
         meta: Dict[str, Any] = {}
 
         async for chunk in api_output:
-            choice = chunk.choices[0]
+            # The chunk with usage returns an empty array for choices
+            if len(chunk.choices) > 0:
+                # n is unused, so the API always returns only one choice
+                # the argument is probably allowed for compatibility with OpenAI
+                # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+                choice = chunk.choices[0]
 
-            text = choice.delta.content or ""
-            generated_text += text
+                text = choice.delta.content or ""
+                generated_text += text
 
-            finish_reason = choice.finish_reason
-            if finish_reason:
-                meta["finish_reason"] = finish_reason
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                stream_chunk = StreamingChunk(text, meta)
+                await streaming_callback(stream_chunk)  # type: ignore
+
+            if chunk.usage:
+                usage = chunk.usage
 
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
-            stream_chunk = StreamingChunk(text, meta)
-            await streaming_callback(stream_chunk)  # type: ignore
+        if usage:
+            usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
+        else:
+            usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}
 
         meta.update(
             {
                 "model": self._async_client.model,
                 "index": 0,
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
                 "finish_reason": finish_reason,
+                "usage": usage_dict,
                 "completion_start_time": first_chunk_time,
             }
         )
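The async path mirrors the sync one; what changes for callers is an async streaming callback and an awaited run. A sketch under the assumption that the component exposes run_async (the async client and awaited callback do appear in the diff above):

import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

async def on_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)

async def main() -> None:
    generator = HuggingFaceAPIChatGenerator(
        api_type="serverless_inference_api",
        api_params={"model": "microsoft/Phi-3.5-mini-instruct"},
        streaming_callback=on_chunk,
    )
    result = await generator.run_async(messages=[ChatMessage.from_user("Hi!")])
    print(result["replies"][0].meta["usage"])

asyncio.run(main())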
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    When using HuggingFaceAPIChatGenerator with streaming, the returned ChatMessage now contains the number of prompt tokens and completion tokens in its metadata.
+    Internally, the HuggingFaceAPIChatGenerator requests an additional streaming chunk that contains usage data.
+    It then processes the usage streaming chunk to add usage metadata to the returned ChatMessage.
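The fold-in the release note describes reduces to a small fallback: if no usage chunk arrived (a provider may not send one), the meta still gets a well-formed dict of zeros. A standalone restatement of that logic from the diff, with a hypothetical helper name:

from typing import Any, Dict, Optional

def usage_to_meta(usage: Optional[Any]) -> Dict[str, int]:
    # Mirrors the usage_dict fallback in the diff: zeros when no usage chunk was received.
    if usage:
        return {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
    return {"prompt_tokens": 0, "completion_tokens": 0}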
@@ -22,6 +22,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutput,
     ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
+    ChatCompletionInputStreamOptions,
 )
 from huggingface_hub.errors import RepositoryNotFoundError
 
@@ -441,7 +442,12 @@ class TestHuggingFaceAPIChatGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        assert kwargs == {
+            "stop": [],
+            "stream": True,
+            "max_tokens": 512,
+            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
+        }
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
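The assertion pattern above relies on unittest.mock's call_args; a self-contained illustration of just that pattern (not the repo's fixture setup):

from unittest.mock import MagicMock

from huggingface_hub import ChatCompletionInputStreamOptions

mock_chat_completion = MagicMock()
mock_chat_completion(
    [{"role": "user", "content": "hi"}],
    stream=True,
    stream_options=ChatCompletionInputStreamOptions(include_usage=True),
    max_tokens=512,
)

_, kwargs = mock_chat_completion.call_args  # (positional args, keyword args)
assert kwargs["stream_options"] == ChatCompletionInputStreamOptions(include_usage=True)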
@@ -505,7 +511,12 @@ class TestHuggingFaceAPIChatGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        assert kwargs == {
+            "stop": [],
+            "stream": True,
+            "max_tokens": 512,
+            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
+        }
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
@@ -717,9 +728,9 @@ class TestHuggingFaceAPIChatGenerator:
         assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now()
         assert "usage" in response_meta
         assert "prompt_tokens" in response_meta["usage"]
-        assert response_meta["usage"]["prompt_tokens"] == 0
+        assert response_meta["usage"]["prompt_tokens"] > 0
         assert "completion_tokens" in response_meta["usage"]
-        assert response_meta["usage"]["completion_tokens"] == 0
+        assert response_meta["usage"]["completion_tokens"] > 0
         assert response_meta["model"] == "microsoft/Phi-3.5-mini-instruct"
         assert response_meta["finish_reason"] is not None
 
@@ -848,7 +859,12 @@ class TestHuggingFaceAPIChatGenerator:
 
         # check kwargs passed to chat_completion
         _, kwargs = mock_chat_completion_async.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        assert kwargs == {
+            "stop": [],
+            "stream": True,
+            "max_tokens": 512,
+            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
+        }
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2