mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 13:06:29 +00:00
feat: Add sanitization for Meta field during serialization (#9272)
* feat: Add sanitization for Meta field during serialization * Revert "feat: Add sanitization for Meta field during serialization" This reverts commit c529f7c25b69aed626bb2072c8bf171815b591cc. * feat: add nested serialization in openai usage object * add reno * add nested serialization in OpenAiChatGenerator * Update releasenotes/notes/nested-serialization-openai-usage-object-3817b07342999edf.yaml Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> * merge tests * Adjust the test --------- Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
parent
0fdb88424b
commit
53308a6294
@ -515,7 +515,7 @@ class OpenAIChatGenerator:
|
||||
"index": 0,
|
||||
"finish_reason": finish_reason,
|
||||
"completion_start_time": chunks[0].meta.get("received_at"), # first chunk received
|
||||
"usage": dict(last_chunk.usage or {}), # last chunk has the final usage data if available
|
||||
"usage": self._serialize_usage(last_chunk.usage), # last chunk has the final usage data if available
|
||||
}
|
||||
|
||||
return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
|
||||
@ -553,7 +553,7 @@ class OpenAIChatGenerator:
|
||||
"model": completion.model,
|
||||
"index": choice.index,
|
||||
"finish_reason": choice.finish_reason,
|
||||
"usage": dict(completion.usage or {}),
|
||||
"usage": self._serialize_usage(completion.usage),
|
||||
}
|
||||
)
|
||||
return chat_message
|
||||
@ -587,3 +587,16 @@ class OpenAIChatGenerator:
|
||||
}
|
||||
)
|
||||
return chunk_message
|
||||
|
||||
def _serialize_usage(self, usage):
|
||||
"""Convert OpenAI usage object to serializable dict recursively"""
|
||||
if hasattr(usage, "model_dump"):
|
||||
return usage.model_dump()
|
||||
elif hasattr(usage, "__dict__"):
|
||||
return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
|
||||
elif isinstance(usage, dict):
|
||||
return {k: self._serialize_usage(v) for k, v in usage.items()}
|
||||
elif isinstance(usage, list):
|
||||
return [self._serialize_usage(item) for item in usage]
|
||||
else:
|
||||
return usage
|
||||
|
||||
@ -252,9 +252,21 @@ class OpenAIGenerator:
|
||||
|
||||
return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_usage(self, usage):
|
||||
"""Convert OpenAI usage object to serializable dict recursively"""
|
||||
if hasattr(usage, "model_dump"):
|
||||
return usage.model_dump()
|
||||
elif hasattr(usage, "__dict__"):
|
||||
return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
|
||||
elif isinstance(usage, dict):
|
||||
return {k: self._serialize_usage(v) for k, v in usage.items()}
|
||||
elif isinstance(usage, list):
|
||||
return [self._serialize_usage(item) for item in usage]
|
||||
else:
|
||||
return usage
|
||||
|
||||
def _create_message_from_chunks(
|
||||
completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
|
||||
self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
|
||||
) -> ChatMessage:
|
||||
"""
|
||||
Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
|
||||
@ -267,13 +279,12 @@ class OpenAIGenerator:
|
||||
"index": 0,
|
||||
"finish_reason": finish_reason,
|
||||
"completion_start_time": streamed_chunks[0].meta.get("received_at"), # first chunk received
|
||||
"usage": dict(completion_chunk.usage or {}),
|
||||
"usage": self._serialize_usage(completion_chunk.usage),
|
||||
}
|
||||
)
|
||||
return complete_response
|
||||
|
||||
@staticmethod
|
||||
def _build_message(completion: Any, choice: Any) -> ChatMessage:
|
||||
def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
|
||||
"""
|
||||
Converts the response from the OpenAI API to a ChatMessage.
|
||||
|
||||
@ -292,7 +303,7 @@ class OpenAIGenerator:
|
||||
"model": completion.model,
|
||||
"index": choice.index,
|
||||
"finish_reason": choice.finish_reason,
|
||||
"usage": dict(completion.usage),
|
||||
"usage": self._serialize_usage(completion.usage),
|
||||
}
|
||||
)
|
||||
return chat_message
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Implement JSON-safe serialization for OpenAI usage data by converting
|
||||
token counts and details (like CompletionTokensDetails and PromptTokensDetails) into plain dictionaries.
|
||||
@ -926,25 +926,32 @@ class TestOpenAIChatGenerator:
|
||||
)
|
||||
results = component.run([ChatMessage.from_user("What's the capital of France?")])
|
||||
|
||||
# Basic response checks
|
||||
assert "replies" in results
|
||||
assert len(results["replies"]) == 1
|
||||
message: ChatMessage = results["replies"][0]
|
||||
assert "Paris" in message.text
|
||||
assert isinstance(message.meta, dict)
|
||||
|
||||
assert "gpt-4o" in message.meta["model"]
|
||||
assert message.meta["finish_reason"] == "stop"
|
||||
# Metadata checks
|
||||
metadata = message.meta
|
||||
assert "gpt-4o" in metadata["model"]
|
||||
assert metadata["finish_reason"] == "stop"
|
||||
|
||||
# Usage information checks
|
||||
assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
|
||||
usage = metadata["usage"]
|
||||
assert "prompt_tokens" in usage and usage["prompt_tokens"] > 0
|
||||
assert "completion_tokens" in usage and usage["completion_tokens"] > 0
|
||||
|
||||
# Detailed token information checks
|
||||
assert isinstance(usage.get("completion_tokens_details"), dict), "usage.completion_tokens_details not a dict"
|
||||
assert isinstance(usage.get("prompt_tokens_details"), dict), "usage.prompt_tokens_details not a dict"
|
||||
|
||||
# Streaming callback verification
|
||||
assert callback.counter > 1
|
||||
assert "Paris" in callback.responses
|
||||
|
||||
# check that the completion_start_time is set and valid ISO format
|
||||
assert "completion_start_time" in message.meta
|
||||
assert datetime.fromisoformat(message.meta["completion_start_time"]) <= datetime.now()
|
||||
|
||||
assert isinstance(message.meta["usage"], dict)
|
||||
assert message.meta["usage"]["prompt_tokens"] > 0
|
||||
assert message.meta["usage"]["completion_tokens"] > 0
|
||||
assert message.meta["usage"]["total_tokens"] > 0
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY", None),
|
||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||
|
||||
@ -296,21 +296,28 @@ class TestOpenAIGenerator:
|
||||
)
|
||||
results = component.run("What's the capital of France?")
|
||||
|
||||
# Basic response validation
|
||||
assert len(results["replies"]) == 1
|
||||
assert len(results["meta"]) == 1
|
||||
response: str = results["replies"][0]
|
||||
assert "Paris" in response
|
||||
|
||||
# Metadata validation
|
||||
metadata = results["meta"][0]
|
||||
|
||||
assert "gpt-4o-mini" in metadata["model"]
|
||||
assert metadata["finish_reason"] == "stop"
|
||||
|
||||
assert "usage" in metadata
|
||||
assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
|
||||
assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
|
||||
assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
|
||||
# Basic usage validation
|
||||
assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
|
||||
usage = metadata["usage"]
|
||||
assert "prompt_tokens" in usage and usage["prompt_tokens"] > 0
|
||||
assert "completion_tokens" in usage and usage["completion_tokens"] > 0
|
||||
|
||||
# Detailed token information validation
|
||||
assert isinstance(usage.get("completion_tokens_details"), dict), "usage.completion_tokens_details not a dict"
|
||||
assert isinstance(usage.get("prompt_tokens_details"), dict), "usage.prompt_tokens_details not a dict"
|
||||
|
||||
# Streaming callback validation
|
||||
assert callback.counter > 1
|
||||
assert "Paris" in callback.responses
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user