feat: HuggingFaceAPIChatGenerator add token usage data (#8375)
* Ensure HuggingFaceAPIChatGenerator has token usage data
* Add reno note
* Fix release note
parent 0b7f1fd114
commit 09b95746a2
@@ -248,7 +248,14 @@ class HuggingFaceAPIChatGenerator:
             self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)

         message = ChatMessage.from_assistant(generated_text)
-        message.meta.update({"model": self._client.model, "finish_reason": finish_reason, "index": 0})
+        message.meta.update(
+            {
+                "model": self._client.model,
+                "finish_reason": finish_reason,
+                "index": 0,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
+            }
+        )
         return {"replies": [message]}

     def _run_non_streaming(
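Note on the streaming path: the Serverless Inference API does not return token counts while streaming, so the commit fills the new "usage" field with zeros. Below is a minimal sketch of what a caller observes after a streamed run; the callback, prompt, and Haystack 2.x import paths are assumptions for illustration, not part of this commit.

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import HFGenerationAPIType

# Illustrative callback: print each streamed chunk as it arrives.
def print_chunk(chunk):
    print(chunk.content, end="", flush=True)

# Assumes HF_API_TOKEN is exported, as in the integration tests below.
generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    streaming_callback=print_chunk,
)

reply = generator.run(messages=[ChatMessage.from_user("What is the capital of France?")])["replies"][0]
# Per the hunk above, streamed replies carry placeholder counts:
assert reply.meta["usage"] == {"prompt_tokens": 0, "completion_tokens": 0}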
@@ -257,11 +264,16 @@ class HuggingFaceAPIChatGenerator:
         chat_messages: List[ChatMessage] = []

         api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)

         for choice in api_chat_output.choices:
             message = ChatMessage.from_assistant(choice.message.content)
             message.meta.update(
-                {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
+                {
+                    "model": self._client.model,
+                    "finish_reason": choice.finish_reason,
+                    "index": choice.index,
+                    "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
+                }
             )
             chat_messages.append(message)

         return {"replies": chat_messages}
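On the non-streaming path the counts come straight from api_chat_output.usage, with a zeroed dict as a fallback when the server omits them. A minimal sketch of reading the field back, under the same assumptions as the sketch above:

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import HFGenerationAPIType

# Assumes HF_API_TOKEN is exported, as in the integration tests below.
generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
)

reply = generator.run(messages=[ChatMessage.from_user("What is the capital of France?")])["replies"][0]
usage = reply.meta["usage"]
# Dict-style access mirrors the assertions in the test hunk below.
print(usage["prompt_tokens"], usage["completion_tokens"])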
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Adds 'usage' meta field with 'prompt_tokens' and 'completion_tokens' keys to HuggingFaceAPIChatGenerator.
@@ -287,3 +287,31 @@ class TestHuggingFaceAPIGenerator:
         assert isinstance(response["replies"], list)
         assert len(response["replies"]) > 0
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert "usage" in response["replies"][0].meta
+        assert "prompt_tokens" in response["replies"][0].meta["usage"]
+        assert "completion_tokens" in response["replies"][0].meta["usage"]
+
+    @pytest.mark.flaky(reruns=5, reruns_delay=5)
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    def test_run_serverless_streaming(self):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"max_tokens": 20},
+            streaming_callback=streaming_callback_handler,
+        )
+
+        messages = [ChatMessage.from_user("What is the capital of France?")]
+        response = generator.run(messages=messages)
+
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) > 0
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert "usage" in response["replies"][0].meta
+        assert "prompt_tokens" in response["replies"][0].meta["usage"]
+        assert "completion_tokens" in response["replies"][0].meta["usage"]