feat: Add TTFT support in OpenAI chat generator (#8444)

* feat: Add TTFT support in OpenAI generators

* pylint fixes

* correct disable

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
Author: Bohan Qu
Date: 2024-10-31 23:56:17 +08:00 (committed by GitHub)
Parent: 294a67e426
Commit: 2595e68050
3 changed files with 28 additions and 3 deletions

File: haystack/components/generators/chat/openai.py

@@ -5,6 +5,7 @@
 import copy
 import json
 import os
+from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Union

 from openai import OpenAI, Stream
@@ -63,7 +64,7 @@ class OpenAIChatGenerator:
     ```
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "gpt-4o-mini",
@@ -222,11 +223,15 @@ class OpenAIChatGenerator:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
             chunks: List[StreamingChunk] = []
             chunk = None
+            _first_token = True

             # pylint: disable=not-an-iterable
             for chunk in chat_completion:
                 if chunk.choices and streaming_callback:
                     chunk_delta: StreamingChunk = self._build_chunk(chunk)
+                    if _first_token:
+                        _first_token = False
+                        chunk_delta.meta["completion_start_time"] = datetime.now().isoformat()
                     chunks.append(chunk_delta)
                     streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
             completions = [self._connect_chunks(chunk, chunks)]
@@ -280,7 +285,12 @@ class OpenAIChatGenerator:
                     payload["function"]["arguments"] += delta.arguments or ""
             complete_response = ChatMessage.from_assistant(json.dumps(payloads))
         else:
-            complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
+            total_content = ""
+            total_meta = {}
+            for streaming_chunk in chunks:
+                total_content += streaming_chunk.content
+                total_meta.update(streaming_chunk.meta)
+            complete_response = ChatMessage.from_assistant(total_content, meta=total_meta)
         complete_response.meta.update(
             {
                 "model": chunk.model,

File: release note (new YAML file)

@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Add TTFT (Time-to-First-Token) support for OpenAI generators. This
+    captures the time taken to generate the first token from the model and
+    can be used to analyze the latency of the application.

File: OpenAIChatGenerator tests

@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import os
+from unittest.mock import patch

 import pytest
 from openai import OpenAIError
@@ -219,7 +220,8 @@ class TestOpenAIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
         assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk

-    def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
+    @patch("haystack.components.generators.chat.openai.datetime")
+    def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk):
         streaming_callback_called = False

         def streaming_callback(chunk: StreamingChunk) -> None:
@@ -240,6 +242,13 @@ class TestOpenAIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
         assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk
+
+        assert hasattr(response["replies"][0], "meta")
+        assert isinstance(response["replies"][0].meta, dict)
+        assert (
+            response["replies"][0].meta["completion_start_time"]
+            == mock_datetime.now.return_value.isoformat.return_value
+        )

     def test_check_abnormal_completions(self, caplog):
         caplog.set_level(logging.INFO)
         component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
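
For reference, a minimal self-contained sketch of the patching pattern the new test relies on: datetime is patched in the module where it is looked up (the generator's module), so datetime.now().isoformat() returns a deterministic mock sentinel instead of the wall clock. This assumes haystack is importable in the test environment.

    from unittest.mock import patch

    with patch("haystack.components.generators.chat.openai.datetime") as mock_datetime:
        expected = mock_datetime.now.return_value.isoformat.return_value
        # Inside this block, any code in openai.py that calls
        # datetime.now().isoformat() gets `expected` back, which is what the
        # completion_start_time assertion compares against.
        assert mock_datetime.now().isoformat() is expected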