feat: HuggingFaceAPIChatGenerator add token usage data (#8375)

* Ensure HuggingFaceAPIChatGenerator has token usage data

* Add reno note

* Fix release note
This commit is contained in:
Vladimir Blagojevic 2024-09-23 15:40:50 +02:00 committed by GitHub
parent 0b7f1fd114
commit 09b95746a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 3 deletions

View File

@@ -248,7 +248,14 @@ class HuggingFaceAPIChatGenerator:
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
message = ChatMessage.from_assistant(generated_text)
message.meta.update({"model": self._client.model, "finish_reason": finish_reason, "index": 0})
message.meta.update(
{
"model": self._client.model,
"finish_reason": finish_reason,
"index": 0,
"usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming
}
)
return {"replies": [message]}
def _run_non_streaming(
@@ -257,11 +264,16 @@ class HuggingFaceAPIChatGenerator:
chat_messages: List[ChatMessage] = []
api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
for choice in api_chat_output.choices:
message = ChatMessage.from_assistant(choice.message.content)
message.meta.update(
{"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
{
"model": self._client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
"usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
}
)
chat_messages.append(message)
return {"replies": chat_messages}

View File

@@ -0,0 +1,4 @@
---
enhancements:
- |
Adds 'usage' meta field with 'prompt_tokens' and 'completion_tokens' keys to HuggingFaceAPIChatGenerator.

View File

@@ -287,3 +287,31 @@ class TestHuggingFaceAPIGenerator:
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "usage" in response["replies"][0].meta
assert "prompt_tokens" in response["replies"][0].meta["usage"]
assert "completion_tokens" in response["replies"][0].meta["usage"]
@pytest.mark.flaky(reruns=5, reruns_delay=5)
@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("HF_API_TOKEN", None),
    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
def test_run_serverless_streaming(self):
    """Integration test: a streaming run against the serverless inference API
    returns ChatMessage replies whose meta carries a 'usage' entry.

    Requires the HF_API_TOKEN environment variable (skipped otherwise).
    """
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
        generation_kwargs={"max_tokens": 20},
        streaming_callback=streaming_callback_handler,
    )
    messages = [ChatMessage.from_user("What is the capital of France?")]

    response = generator.run(messages=messages)

    assert "replies" in response
    assert isinstance(response["replies"], list)
    assert len(response["replies"]) > 0
    # Fix: the original asserted a bare list comprehension, which is truthy
    # for any non-empty list and therefore never checked the element types.
    # all() actually fails if any reply is not a ChatMessage.
    assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    # In streaming mode token counts are not available, but the 'usage'
    # structure must still be present (filled with zeros by the generator).
    assert "usage" in response["replies"][0].meta
    assert "prompt_tokens" in response["replies"][0].meta["usage"]
    assert "completion_tokens" in response["replies"][0].meta["usage"]