Mirror of https://github.com/deepset-ai/haystack.git
feat: Add TTFT support in OpenAI chat generator (#8444)
* feat: Add TTFT support in OpenAI generators
* pylint fixes
* correct disable

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
parent: 294a67e426
commit: 2595e68050
haystack/components/generators/chat/openai.py

```diff
@@ -5,6 +5,7 @@
 import copy
 import json
 import os
+from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
```
```diff
@@ -63,7 +64,7 @@ class OpenAIChatGenerator:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "gpt-4o-mini",
```
```diff
@@ -222,11 +223,15 @@
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
             chunks: List[StreamingChunk] = []
             chunk = None
+            _first_token = True
 
             # pylint: disable=not-an-iterable
             for chunk in chat_completion:
                 if chunk.choices and streaming_callback:
                     chunk_delta: StreamingChunk = self._build_chunk(chunk)
+                    if _first_token:
+                        _first_token = False
+                        chunk_delta.meta["completion_start_time"] = datetime.now().isoformat()
                     chunks.append(chunk_delta)
                     streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
             completions = [self._connect_chunks(chunk, chunks)]
```
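With this change, the first streamed chunk gets a `completion_start_time` ISO-8601 timestamp in its meta, which `_connect_chunks` later folds into the final reply. A minimal sketch of how a caller could read the new field, assuming haystack 2.x with `OPENAI_API_KEY` set in the environment:

```python
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


def on_chunk(chunk: StreamingChunk) -> None:
    # Invoked once per streamed delta; the very first chunk carries the timestamp.
    print(chunk.content, end="", flush=True)


generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=on_chunk)
result = generator.run(messages=[ChatMessage.from_user("Say hello.")])

# After the chunks are merged, the timestamp lands on the reply's meta.
print(result["replies"][0].meta.get("completion_start_time"))
```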
```diff
@@ -280,7 +285,12 @@
                     payload["function"]["arguments"] += delta.arguments or ""
             complete_response = ChatMessage.from_assistant(json.dumps(payloads))
         else:
-            complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
+            total_content = ""
+            total_meta = {}
+            for streaming_chunk in chunks:
+                total_content += streaming_chunk.content
+                total_meta.update(streaming_chunk.meta)
+            complete_response = ChatMessage.from_assistant(total_content, meta=total_meta)
         complete_response.meta.update(
             {
                 "model": chunk.model,
```
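The rewritten `else` branch concatenates chunk contents and merges chunk meta with `dict.update`, so a key set by an earlier chunk survives unless a later chunk overwrites it; since only the first chunk carries `completion_start_time`, it is preserved. A self-contained sketch of those merge semantics, using a hypothetical `FakeChunk` as a stand-in for `StreamingChunk`:

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class FakeChunk:
    # stand-in for haystack.dataclasses.StreamingChunk
    content: str
    meta: Dict[str, Any] = field(default_factory=dict)


chunks = [
    FakeChunk("Hel", {"completion_start_time": "2024-01-01T00:00:00", "index": 0}),
    FakeChunk("lo", {"index": 1}),
]

total_content = ""
total_meta: Dict[str, Any] = {}
for streaming_chunk in chunks:
    total_content += streaming_chunk.content
    total_meta.update(streaming_chunk.meta)  # later chunks overwrite shared keys

assert total_content == "Hello"
# completion_start_time appears only on the first chunk, so the merge keeps it:
assert total_meta == {"completion_start_time": "2024-01-01T00:00:00", "index": 1}
```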
releasenotes/notes/openai-ttft-42b1ad551b542930.yaml (new file, +6)
```diff
@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Add TTFT (Time-to-First-Token) support for OpenAI generators. This
+    captures the time taken to generate the first token from the model and
+    can be used to analyze the latency of the application.
```
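To turn the recorded timestamp into an actual TTFT measurement, a caller can record its own request start and subtract; a hedged sketch, reusing the `generator` from the sketch above:

```python
from datetime import datetime

request_start = datetime.now()
result = generator.run(messages=[ChatMessage.from_user("Say hello.")])

# completion_start_time is written with datetime.now().isoformat(), so it
# parses back with fromisoformat and is comparable to our local clock.
started = datetime.fromisoformat(result["replies"][0].meta["completion_start_time"])
print(f"time to first token: {(started - request_start).total_seconds():.3f}s")
```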
test/components/generators/chat/test_openai.py

```diff
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import os
+from unittest.mock import patch
 
 import pytest
 from openai import OpenAIError
```
```diff
@@ -219,7 +220,8 @@ class TestOpenAIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
         assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk
 
-    def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
+    @patch("haystack.components.generators.chat.openai.datetime")
+    def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk):
         streaming_callback_called = False
 
         def streaming_callback(chunk: StreamingChunk) -> None:
```
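The decorator targets `haystack.components.generators.chat.openai.datetime` rather than `datetime.datetime` because `mock.patch` must patch a name where it is looked up, and openai.py imports it via `from datetime import datetime`. A minimal sketch of the same freezing idiom outside pytest, assuming haystack is importable:

```python
from unittest.mock import patch

with patch("haystack.components.generators.chat.openai.datetime") as mock_datetime:
    mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T00:00:00"
    # Any datetime.now().isoformat() call inside the generator module now returns
    # the canned string, making completion_start_time assertions deterministic.
```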
```diff
@@ -240,6 +242,13 @@
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
         assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk
 
+        assert hasattr(response["replies"][0], "meta")
+        assert isinstance(response["replies"][0].meta, dict)
+        assert (
+            response["replies"][0].meta["completion_start_time"]
+            == mock_datetime.now.return_value.isoformat.return_value
+        )
+
     def test_check_abnormal_completions(self, caplog):
         caplog.set_level(logging.INFO)
         component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
```