feat: Add TTFT support in OpenAI chat generator (#8444)

* feat: Add TTFT support in OpenAI generators

* pylint fixes

* correct disable

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
Bohan Qu 2024-10-31 23:56:17 +08:00 committed by GitHub
parent 294a67e426
commit 2595e68050
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 3 deletions

View File

@ -5,6 +5,7 @@
import copy
import json
import os
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
from openai import OpenAI, Stream
@ -63,7 +64,7 @@ class OpenAIChatGenerator:
```
"""
def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-4o-mini",
@ -222,11 +223,15 @@ class OpenAIChatGenerator:
raise ValueError("Cannot stream multiple responses, please set n=1.")
chunks: List[StreamingChunk] = []
chunk = None
_first_token = True
# pylint: disable=not-an-iterable
for chunk in chat_completion:
if chunk.choices and streaming_callback:
chunk_delta: StreamingChunk = self._build_chunk(chunk)
if _first_token:
_first_token = False
chunk_delta.meta["completion_start_time"] = datetime.now().isoformat()
chunks.append(chunk_delta)
streaming_callback(chunk_delta) # invoke callback with the chunk_delta
completions = [self._connect_chunks(chunk, chunks)]
@ -280,7 +285,12 @@ class OpenAIChatGenerator:
payload["function"]["arguments"] += delta.arguments or ""
complete_response = ChatMessage.from_assistant(json.dumps(payloads))
else:
complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
total_content = ""
total_meta = {}
for streaming_chunk in chunks:
total_content += streaming_chunk.content
total_meta.update(streaming_chunk.meta)
complete_response = ChatMessage.from_assistant(total_content, meta=total_meta)
complete_response.meta.update(
{
"model": chunk.model,

View File

@ -0,0 +1,6 @@
---
features:
- |
Add TTFT (Time-to-First-Token) support for OpenAI generators. This
captures the time taken to generate the first token from the model and
can be used to analyze the latency of the application.

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from unittest.mock import patch
import pytest
from openai import OpenAIError
@ -219,7 +220,8 @@ class TestOpenAIChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
@patch("haystack.components.generators.chat.openai.datetime")
def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk):
streaming_callback_called = False
def streaming_callback(chunk: StreamingChunk) -> None:
@ -240,6 +242,13 @@ class TestOpenAIChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
assert hasattr(response["replies"][0], "meta")
assert isinstance(response["replies"][0].meta, dict)
assert (
response["replies"][0].meta["completion_start_time"]
== mock_datetime.now.return_value.isoformat.return_value
)
def test_check_abnormal_completions(self, caplog):
caplog.set_level(logging.INFO)
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))