feat: Add sanitization for Meta field during serialization (#9272)

* feat: Add sanitization for Meta field during serialization

* Revert "feat: Add sanitization for Meta field during serialization"

This reverts commit c529f7c25b69aed626bb2072c8bf171815b591cc.

* feat: add nested serialization in openai usage object

* add reno

* add nested serialization in OpenAIChatGenerator

* Update releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

* merge tests

* Adjust the test

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Committed by Mohammed Abdul Razak Wahab on 2025-04-26 15:34:02 +05:30 (via GitHub)
commit 53308a6294, parent 0fdb88424b
5 changed files with 67 additions and 24 deletions


@@ -515,7 +515,7 @@ class OpenAIChatGenerator:
             "index": 0,
             "finish_reason": finish_reason,
             "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": dict(last_chunk.usage or {}),  # last chunk has the final usage data if available
+            "usage": self._serialize_usage(last_chunk.usage),  # last chunk has the final usage data if available
         }
         return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
@@ -553,7 +553,7 @@ class OpenAIChatGenerator:
                 "model": completion.model,
                 "index": choice.index,
                 "finish_reason": choice.finish_reason,
-                "usage": dict(completion.usage or {}),
+                "usage": self._serialize_usage(completion.usage),
             }
         )
         return chat_message
@@ -587,3 +587,16 @@ class OpenAIChatGenerator:
             }
         )
         return chunk_message
+
+    def _serialize_usage(self, usage):
+        """Convert OpenAI usage object to serializable dict recursively"""
+        if hasattr(usage, "model_dump"):
+            return usage.model_dump()
+        elif hasattr(usage, "__dict__"):
+            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+        elif isinstance(usage, dict):
+            return {k: self._serialize_usage(v) for k, v in usage.items()}
+        elif isinstance(usage, list):
+            return [self._serialize_usage(item) for item in usage]
+        else:
+            return usage
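For readers skimming the diff, here is a standalone sketch of what this recursion handles. FakeUsage and FakeDetails are hypothetical stand-ins for the OpenAI SDK's CompletionUsage and CompletionTokensDetails pydantic models, and serialize_usage mirrors the _serialize_usage logic above at module level:

from dataclasses import dataclass, field

@dataclass
class FakeDetails:  # hypothetical stand-in for openai's CompletionTokensDetails
    reasoning_tokens: int = 0

@dataclass
class FakeUsage:  # hypothetical stand-in for openai's CompletionUsage
    completion_tokens: int = 7
    completion_tokens_details: FakeDetails = field(default_factory=FakeDetails)

def serialize_usage(usage):  # module-level mirror of _serialize_usage above
    if hasattr(usage, "model_dump"):
        return usage.model_dump()  # pydantic v2 objects dump themselves
    elif hasattr(usage, "__dict__"):
        # plain objects: recurse over public attributes
        return {k: serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
    elif isinstance(usage, dict):
        return {k: serialize_usage(v) for k, v in usage.items()}
    elif isinstance(usage, list):
        return [serialize_usage(item) for item in usage]
    return usage

print(serialize_usage(FakeUsage(completion_tokens_details=FakeDetails(reasoning_tokens=2))))
# {'completion_tokens': 7, 'completion_tokens_details': {'reasoning_tokens': 2}}

Note that dict(usage), which this commit replaces, would leave the nested FakeDetails object intact, which is exactly why nested pydantic objects previously survived into the message meta.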


@@ -252,9 +252,21 @@ class OpenAIGenerator:
         return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}

-    @staticmethod
+    def _serialize_usage(self, usage):
+        """Convert OpenAI usage object to serializable dict recursively"""
+        if hasattr(usage, "model_dump"):
+            return usage.model_dump()
+        elif hasattr(usage, "__dict__"):
+            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+        elif isinstance(usage, dict):
+            return {k: self._serialize_usage(v) for k, v in usage.items()}
+        elif isinstance(usage, list):
+            return [self._serialize_usage(item) for item in usage]
+        else:
+            return usage
+
     def _create_message_from_chunks(
-        completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
+        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
     ) -> ChatMessage:
         """
         Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
@@ -267,13 +279,12 @@ class OpenAIGenerator:
                 "index": 0,
                 "finish_reason": finish_reason,
                 "completion_start_time": streamed_chunks[0].meta.get("received_at"),  # first chunk received
-                "usage": dict(completion_chunk.usage or {}),
+                "usage": self._serialize_usage(completion_chunk.usage),
             }
         )
         return complete_response

-    @staticmethod
-    def _build_message(completion: Any, choice: Any) -> ChatMessage:
+    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
         """
         Converts the response from the OpenAI API to a ChatMessage.
@@ -292,7 +303,7 @@ class OpenAIGenerator:
                 "model": completion.model,
                 "index": choice.index,
                 "finish_reason": choice.finish_reason,
-                "usage": dict(completion.usage),
+                "usage": self._serialize_usage(completion.usage),
             }
         )
         return chat_message


@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Implement JSON-safe serialization for OpenAI usage data by converting
+    token counts and details (like CompletionTokensDetails and PromptTokensDetails) into plain dictionaries.
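A minimal illustration of the failure mode this note addresses. FakeDetails is a hypothetical stand-in for the pydantic detail objects the OpenAI SDK returns:

import json

class FakeDetails:  # hypothetical stand-in for openai's CompletionTokensDetails
    def __init__(self, reasoning_tokens):
        self.reasoning_tokens = reasoning_tokens

# Before this change, meta could carry live SDK objects and json.dumps raised:
before = {"usage": {"completion_tokens_details": FakeDetails(2)}}
try:
    json.dumps(before)
except TypeError as err:
    print(err)  # Object of type FakeDetails is not JSON serializable

# After recursive serialization the same data is plain dicts and dumps cleanly:
after = {"usage": {"completion_tokens_details": {"reasoning_tokens": 2}}}
print(json.dumps(after))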


@@ -926,25 +926,32 @@ class TestOpenAIChatGenerator:
         )
         results = component.run([ChatMessage.from_user("What's the capital of France?")])

+        # Basic response checks
+        assert "replies" in results
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.text
-        assert isinstance(message.meta, dict)
-        assert "gpt-4o" in message.meta["model"]
-        assert message.meta["finish_reason"] == "stop"
+
+        # Metadata checks
+        metadata = message.meta
+        assert "gpt-4o" in metadata["model"]
+        assert metadata["finish_reason"] == "stop"
+
+        # Usage information checks
+        assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
+        usage = metadata["usage"]
+        assert "prompt_tokens" in usage and usage["prompt_tokens"] > 0
+        assert "completion_tokens" in usage and usage["completion_tokens"] > 0
+
+        # Detailed token information checks
+        assert isinstance(usage.get("completion_tokens_details"), dict), "usage.completion_tokens_details not a dict"
+        assert isinstance(usage.get("prompt_tokens_details"), dict), "usage.prompt_tokens_details not a dict"
+
+        # Streaming callback verification
         assert callback.counter > 1
         assert "Paris" in callback.responses

         # check that the completion_start_time is set and valid ISO format
         assert "completion_start_time" in message.meta
         assert datetime.fromisoformat(message.meta["completion_start_time"]) <= datetime.now()
-        assert isinstance(message.meta["usage"], dict)
-        assert message.meta["usage"]["prompt_tokens"] > 0
-        assert message.meta["usage"]["completion_tokens"] > 0
-        assert message.meta["usage"]["total_tokens"] > 0

     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",


@@ -296,21 +296,28 @@ class TestOpenAIGenerator:
         )
         results = component.run("What's the capital of France?")

+        # Basic response validation
         assert len(results["replies"]) == 1
         assert len(results["meta"]) == 1
         response: str = results["replies"][0]
         assert "Paris" in response

+        # Metadata validation
         metadata = results["meta"][0]
         assert "gpt-4o-mini" in metadata["model"]
         assert metadata["finish_reason"] == "stop"
-        assert "usage" in metadata
-        assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
-        assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
-        assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
+
+        # Basic usage validation
+        assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
+        usage = metadata["usage"]
+        assert "prompt_tokens" in usage and usage["prompt_tokens"] > 0
+        assert "completion_tokens" in usage and usage["completion_tokens"] > 0
+
+        # Detailed token information validation
+        assert isinstance(usage.get("completion_tokens_details"), dict), "usage.completion_tokens_details not a dict"
+        assert isinstance(usage.get("prompt_tokens_details"), dict), "usage.prompt_tokens_details not a dict"
+
+        # Streaming callback validation
         assert callback.counter > 1
         assert "Paris" in callback.responses