From 53308a62943095ca0d50e53e73575dfdcdab44f1 Mon Sep 17 00:00:00 2001
From: Mohammed Abdul Razak Wahab <60781022+mdrazak2001@users.noreply.github.com>
Date: Sat, 26 Apr 2025 15:34:02 +0530
Subject: [PATCH] feat: Add sanitization for Meta field during serialization
 (#9272)

* feat: Add sanitization for Meta field during serialization

* Revert "feat: Add sanitization for Meta field during serialization"

This reverts commit c529f7c25b69aed626bb2072c8bf171815b591cc.

* feat: add nested serialization in openai usage object

* add reno

* add nested serialization in OpenAiChatGenerator

* Update releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml

Co-authored-by: Amna Mubashar

* merge tests

* Adjust the test

---------

Co-authored-by: Amna Mubashar
Co-authored-by: Sebastian Husch Lee
---
 haystack/components/generators/chat/openai.py   | 17 +++++++++--
 haystack/components/generators/openai.py        | 23 +++++++++++----
 ...-openai-usage-object-3817b07342999edf.yaml   |  5 ++++
 test/components/generators/chat/test_openai.py  | 29 ++++++++++++-------
 test/components/generators/test_openai.py       | 17 +++++++----
 5 files changed, 67 insertions(+), 24 deletions(-)
 create mode 100644 releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml

diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index baab37f85..b5df7caaa 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -515,7 +515,7 @@ class OpenAIChatGenerator:
             "index": 0,
             "finish_reason": finish_reason,
             "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": dict(last_chunk.usage or {}),  # last chunk has the final usage data if available
+            "usage": self._serialize_usage(last_chunk.usage),  # last chunk has the final usage data if available
         }

         return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
@@ -553,7 +553,7 @@ class OpenAIChatGenerator:
                 "model": completion.model,
                 "index": choice.index,
                 "finish_reason": choice.finish_reason,
-                "usage": dict(completion.usage or {}),
+                "usage": self._serialize_usage(completion.usage),
             }
         )
         return chat_message
@@ -587,3 +587,16 @@ class OpenAIChatGenerator:
             }
         )
         return chunk_message
+
+    def _serialize_usage(self, usage):
+        """Convert OpenAI usage object to serializable dict recursively"""
+        if hasattr(usage, "model_dump"):
+            return usage.model_dump()
+        elif hasattr(usage, "__dict__"):
+            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+        elif isinstance(usage, dict):
+            return {k: self._serialize_usage(v) for k, v in usage.items()}
+        elif isinstance(usage, list):
+            return [self._serialize_usage(item) for item in usage]
+        else:
+            return usage
diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py
index 0e74ea208..72455a1b9 100644
--- a/haystack/components/generators/openai.py
+++ b/haystack/components/generators/openai.py
@@ -252,9 +252,21 @@ class OpenAIGenerator:

         return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}

-    @staticmethod
+    def _serialize_usage(self, usage):
+        """Convert OpenAI usage object to serializable dict recursively"""
+        if hasattr(usage, "model_dump"):
+            return usage.model_dump()
+        elif hasattr(usage, "__dict__"):
+            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+        elif isinstance(usage, dict):
+            return {k: self._serialize_usage(v) for k, v in usage.items()}
+        elif isinstance(usage, list):
+            return [self._serialize_usage(item) for item in usage]
+        else:
+            return usage
+
     def _create_message_from_chunks(
-        completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
+        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
     ) -> ChatMessage:
         """
         Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
@@ -267,13 +279,12 @@ class OpenAIGenerator:
                 "index": 0,
                 "finish_reason": finish_reason,
                 "completion_start_time": streamed_chunks[0].meta.get("received_at"),  # first chunk received
-                "usage": dict(completion_chunk.usage or {}),
+                "usage": self._serialize_usage(completion_chunk.usage),
             }
         )
         return complete_response

-    @staticmethod
-    def _build_message(completion: Any, choice: Any) -> ChatMessage:
+    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
         """
         Converts the response from the OpenAI API to a ChatMessage.

@@ -292,7 +303,7 @@ class OpenAIGenerator:
                 "model": completion.model,
                 "index": choice.index,
                 "finish_reason": choice.finish_reason,
-                "usage": dict(completion.usage),
+                "usage": self._serialize_usage(completion.usage),
             }
         )
         return chat_message
diff --git a/releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml b/releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml
new file mode 100644
index 000000000..6cec387cc
--- /dev/null
+++ b/releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Implement JSON-safe serialization for OpenAI usage data by converting
+    token counts and details (like CompletionTokensDetails and PromptTokensDetails) into plain dictionaries.
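For context on why the old `dict(completion.usage)` calls were not enough: the OpenAI SDK's usage objects are pydantic models, and `dict(model)` only converts the top level, leaving nested models such as CompletionTokensDetails as objects that `json.dumps` rejects. The following standalone sketch illustrates the failure and the fix; the two model classes are simplified stand-ins (not the SDK's own classes), with field names mirroring the SDK's.

import json

from pydantic import BaseModel


class CompletionTokensDetails(BaseModel):
    # Stand-in for the SDK's nested details model.
    reasoning_tokens: int = 0


class CompletionUsage(BaseModel):
    # Stand-in for the SDK's top-level usage model.
    prompt_tokens: int
    completion_tokens: int
    completion_tokens_details: CompletionTokensDetails


usage = CompletionUsage(
    prompt_tokens=12, completion_tokens=34, completion_tokens_details=CompletionTokensDetails()
)

# dict() is shallow: the nested model survives as an object and breaks json.dumps.
try:
    json.dumps(dict(usage))
except TypeError as err:
    print(f"shallow dict fails: {err}")

# model_dump() (the branch _serialize_usage takes for pydantic models) recurses
# into nested models and returns plain dicts all the way down.
print(json.dumps(usage.model_dump()))

The extra `__dict__`, dict, and list branches in `_serialize_usage` cover usage-like values that are not pydantic models, so the method degrades gracefully whatever shape the client returns.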
diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py
index 794a96887..59ec57214 100644
--- a/test/components/generators/chat/test_openai.py
+++ b/test/components/generators/chat/test_openai.py
@@ -926,25 +926,32 @@ class TestOpenAIChatGenerator:
         )
         results = component.run([ChatMessage.from_user("What's the capital of France?")])

+        # Basic response checks
+        assert "replies" in results
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.text
+
         assert isinstance(message.meta, dict)
-        assert "gpt-4o" in message.meta["model"]
-        assert message.meta["finish_reason"] == "stop"
+        # Metadata checks
+        metadata = message.meta
+        assert "gpt-4o" in metadata["model"]
+        assert metadata["finish_reason"] == "stop"

+        # Usage information checks
+        assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
+        usage = metadata["usage"]
+        assert "prompt_tokens" in usage and usage["prompt_tokens"] > 0
+        assert "completion_tokens" in usage and usage["completion_tokens"] > 0
+
+        # Detailed token information checks
+        assert isinstance(usage.get("completion_tokens_details"), dict), "usage.completion_tokens_details not a dict"
+        assert isinstance(usage.get("prompt_tokens_details"), dict), "usage.prompt_tokens_details not a dict"
+
+        # Streaming callback verification
         assert callback.counter > 1
         assert "Paris" in callback.responses

-        # check that the completion_start_time is set and valid ISO format
-        assert "completion_start_time" in message.meta
-        assert datetime.fromisoformat(message.meta["completion_start_time"]) <= datetime.now()
-
-        assert isinstance(message.meta["usage"], dict)
-        assert message.meta["usage"]["prompt_tokens"] > 0
-        assert message.meta["usage"]["completion_tokens"] > 0
-        assert message.meta["usage"]["total_tokens"] > 0
-
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py
index 4f3419f55..f0fe989ef 100644
--- a/test/components/generators/test_openai.py
+++ b/test/components/generators/test_openai.py
@@ -296,21 +296,28 @@ class TestOpenAIGenerator:
         )
         results = component.run("What's the capital of France?")

+        # Basic response validation
         assert len(results["replies"]) == 1
         assert len(results["meta"]) == 1
         response: str = results["replies"][0]
         assert "Paris" in response

+        # Metadata validation
         metadata = results["meta"][0]
-
         assert "gpt-4o-mini" in metadata["model"]
         assert metadata["finish_reason"] == "stop"

-        assert "usage" in metadata
-        assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
-        assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
-        assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
+        # Basic usage validation
+        assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
+        usage = metadata["usage"]
+        assert "prompt_tokens" in usage and usage["prompt_tokens"] > 0
+        assert "completion_tokens" in usage and usage["completion_tokens"] > 0

+        # Detailed token information validation
+        assert isinstance(usage.get("completion_tokens_details"), dict), "usage.completion_tokens_details not a dict"
+        assert isinstance(usage.get("prompt_tokens_details"), dict), "usage.prompt_tokens_details not a dict"
+
+        # Streaming callback validation
         assert callback.counter > 1
         assert "Paris" in callback.responses
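With the patch applied, the integration tests above can be mirrored by a quick manual check that the usage metadata now survives a JSON round trip. A minimal sketch (not part of the patch; it requires a valid OPENAI_API_KEY in the environment and network access, and the prompt is just the one the tests use):

import json

from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

# Reads OPENAI_API_KEY from the environment, as in the tests above.
component = OpenAIChatGenerator()
results = component.run([ChatMessage.from_user("What's the capital of France?")])

meta = results["replies"][0].meta
# Before this patch, nested objects like CompletionTokensDetails in
# meta["usage"] made json.dumps raise a TypeError; now they are plain dicts.
print(json.dumps(meta["usage"], indent=2))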