add LLMStreamStartEvent and LLMStreamEndEvent (#5890)

These changes are needed because there is currently no way to get
logging information about streaming LLM requests and responses.

I decided to emit the StreamStart event AFTER the first chunk is received so that
connection/auth failures don't produce false-positive start events.
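
For anyone consuming these events, here is a minimal sketch (not part of this change) of a
handler attached to EVENT_LOGGER_NAME that picks up the new stream events; the
StreamEventHandler name is illustrative only:

    import logging

    from autogen_core import EVENT_LOGGER_NAME
    from autogen_core.logging import LLMStreamEndEvent, LLMStreamStartEvent


    class StreamEventHandler(logging.Handler):
        """Collects the streaming events emitted by model clients."""

        def emit(self, record: logging.LogRecord) -> None:
            # logger.info(...) is called with the event object itself,
            # so it arrives on the record as record.msg.
            if isinstance(record.msg, (LLMStreamStartEvent, LLMStreamEndEvent)):
                print(record.msg)  # __str__ serializes the event to JSON


    logger = logging.getLogger(EVENT_LOGGER_NAME)
    logger.setLevel(logging.INFO)
    logger.addHandler(StreamEventHandler())

Because the events are passed directly to logger.info, they arrive on the LogRecord as
record.msg and can be inspected as objects or serialized via str().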

Closes #5730
---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Eitan Yarmush 2025-03-11 18:02:46 -04:00 committed by GitHub
parent 2cc8c73d3b
commit 817f728d04
10 changed files with 279 additions and 54 deletions

View File

@@ -29,11 +29,14 @@ class LLMCallEvent:
.. code-block:: python
import logging
from autogen_core import EVENT_LOGGER_NAME
from autogen_core.logging import LLMCallEvent
response = {"content": "Hello, world!"}
messages = [{"role": "user", "content": "Hello, world!"}]
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20))
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20, response=response, messages=messages))
"""
self.kwargs = kwargs
@@ -61,6 +64,99 @@ class LLMCallEvent:
return json.dumps(self.kwargs)
class LLMStreamStartEvent:
"""To be used by model clients to log the start of a stream.
Args:
messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable.
Example:
.. code-block:: python
import logging
from autogen_core import EVENT_LOGGER_NAME
from autogen_core.logging import LLMStreamStartEvent
messages = [{"role": "user", "content": "Hello, world!"}]
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.info(LLMStreamStartEvent(messages=messages))
"""
def __init__(
self,
*,
messages: List[Dict[str, Any]],
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["type"] = "LLMStreamStart"
self.kwargs["messages"] = messages
try:
agent_id = MessageHandlerContext.agent_id()
except RuntimeError:
agent_id = None
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class LLMStreamEndEvent:
def __init__(
self,
*,
response: Dict[str, Any],
prompt_tokens: int,
completion_tokens: int,
**kwargs: Any,
) -> None:
"""To be used by model to log the call to the LLM.
Args:
response (Dict[str, Any]): The response of the call. Must be json serializable.
prompt_tokens (int): Number of tokens used in the prompt.
completion_tokens (int): Number of tokens used in the completion.
Example:
.. code-block:: python
import logging
from autogen_core import EVENT_LOGGER_NAME
from autogen_core.logging import LLMStreamEndEvent
response = {"content": "Hello, world!"}
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.info(LLMStreamEndEvent(prompt_tokens=10, completion_tokens=20, response=response))
"""
self.kwargs = kwargs
self.kwargs["type"] = "LLMStreamEnd"
self.kwargs["response"] = response
self.kwargs["prompt_tokens"] = prompt_tokens
self.kwargs["completion_tokens"] = completion_tokens
try:
agent_id = MessageHandlerContext.agent_id()
except RuntimeError:
agent_id = None
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
@property
def prompt_tokens(self) -> int:
return cast(int, self.kwargs["prompt_tokens"])
@property
def completion_tokens(self) -> int:
return cast(int, self.kwargs["completion_tokens"])
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class ToolCallEvent:
def __init__(
self,

View File

@@ -45,7 +45,7 @@ from autogen_core import (
FunctionCall,
Image,
)
from autogen_core.logging import LLMCallEvent
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@@ -665,8 +665,18 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
output_tokens: int = 0
stop_reason: Optional[str] = None
first_chunk = True
# Process the stream
async for chunk in stream:
if first_chunk:
first_chunk = False
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=cast(List[Dict[str, Any]], anthropic_messages),
)
)
# Handle different event types
if chunk.type == "content_block_start":
if chunk.content_block.type == "tool_use":
@@ -761,6 +771,15 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
thought=thought,
)
# Emit the end event.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
)
)
# Update usage statistics
self._total_usage = _add_usage(self._total_usage, usage)
self._actual_usage = _add_usage(self._actual_usage, usage)

View File

@@ -6,7 +6,7 @@ from inspect import getfullargspec
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, Image
from autogen_core.logging import LLMCallEvent
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@@ -430,7 +430,16 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
completion_tokens = 0
chunk: Optional[StreamingChatCompletionsUpdate] = None
choice: Optional[StreamingChatChoiceUpdate] = None
first_chunk = True
async for chunk in await task: # type: ignore
if first_chunk:
first_chunk = False
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=[m.as_dict() for m in azure_messages],
)
)
assert isinstance(chunk, StreamingChatCompletionsUpdate)
choice = chunk.choices[0] if len(chunk.choices) > 0 else None
if choice and choice.finish_reason is not None:
@@ -499,6 +508,15 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
thought=thought,
)
# Log the end of the stream.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
)
)
self.add_usage(usage)
yield result

View File

@@ -27,7 +27,7 @@ from autogen_core import (
FunctionCall,
Image,
)
from autogen_core.logging import LLMCallEvent
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@@ -676,7 +676,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
content_chunks: List[str] = []
full_tool_calls: List[FunctionCall] = []
completion_tokens = 0
first_chunk = True
while True:
try:
chunk_future = asyncio.ensure_future(anext(stream))
@@ -684,6 +684,14 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
cancellation_token.link_future(chunk_future)
chunk = await chunk_future
if first_chunk:
first_chunk = False
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=cast(List[Dict[str, Any]], ollama_messages),
)
)
# set the stop_reason for the usage chunk to the prior stop_reason
stop_reason = chunk.done_reason if chunk.done and stop_reason is None else stop_reason
# First try get content
@@ -759,6 +767,15 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
logprobs=None,
)
# Emit the end event.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
)
)
self._total_usage = _add_usage(self._total_usage, usage)
self._actual_usage = _add_usage(self._actual_usage, usage)

View File

@@ -30,7 +30,7 @@ from autogen_core import (
FunctionCall,
Image,
)
from autogen_core.logging import LLMCallEvent
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@@ -748,8 +748,18 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
empty_chunk_warning_threshold: int = 10
empty_chunk_count = 0
first_chunk = True
# Process the stream of chunks.
async for chunk in chunks:
if first_chunk:
first_chunk = False
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=cast(List[Dict[str, Any]], oai_messages),
)
)
# Empty chunks have been observed when the endpoint is under heavy load.
# https://github.com/microsoft/autogen/issues/4213
if len(chunk.choices) == 0:
@@ -869,6 +879,15 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
thought=thought,
)
# Log the end of the stream.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
)
)
# Update the total usage.
self._total_usage = _add_usage(self._total_usage, usage)
self._actual_usage = _add_usage(self._actual_usage, usage)

View File

@@ -5,7 +5,7 @@ from typing import Any, Literal, Mapping, Optional, Sequence
from autogen_core import EVENT_LOGGER_NAME, FunctionCall
from autogen_core._cancellation_token import CancellationToken
from autogen_core.logging import LLMCallEvent
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
@@ -577,9 +577,19 @@ class SKChatCompletionAdapter(ChatCompletionClient):
# accumulating chunk arguments for that call if new items have id=None
last_function_call_id: Optional[str] = None
first_chunk = True
async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
chat_history, settings=settings, kernel=kernel
):
if first_chunk:
first_chunk = False
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=[msg.model_dump() for msg in chat_history],
)
)
for msg in streaming_messages:
# Track token usage
if msg.metadata and "usage" in msg.metadata:
@@ -659,7 +669,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
if isinstance(accumulated_text, str) and self._model_info["family"] == ModelFamily.R1:
thought, accumulated_text = parse_r1_content(accumulated_text)
yield CreateResult(
result = CreateResult(
content=accumulated_text,
finish_reason="stop",
usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
@@ -667,6 +677,17 @@ class SKChatCompletionAdapter(ChatCompletionClient):
thought=thought,
)
# Emit the end event.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
)
yield result
async def close(self) -> None:
pass # No explicit close method in SK client?

View File

@@ -73,7 +73,7 @@ async def test_anthropic_basic_completion(caplog: pytest.LogCaptureFixture) -> N
@pytest.mark.asyncio
async def test_anthropic_streaming() -> None:
async def test_anthropic_streaming(caplog: pytest.LogCaptureFixture) -> None:
"""Test streaming capabilities with Claude."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
@@ -86,20 +86,28 @@ async def test_anthropic_streaming() -> None:
# Test streaming completion
chunks: List[str | CreateResult] = []
async for chunk in client.create_stream(
messages=[
UserMessage(content="Count from 1 to 5. Each number on its own line.", source="user"),
]
):
chunks.append(chunk)
prompt = "Count from 1 to 5. Each number on its own line."
with caplog.at_level(logging.INFO):
async for chunk in client.create_stream(
messages=[
UserMessage(content=prompt, source="user"),
]
):
chunks.append(chunk)
# Verify we got multiple chunks
assert len(chunks) > 1
# Verify we got multiple chunks
assert len(chunks) > 1
# Check final result
final_result = chunks[-1]
assert isinstance(final_result, CreateResult)
assert final_result.finish_reason == "stop"
# Check final result
final_result = chunks[-1]
assert isinstance(final_result, CreateResult)
assert final_result.finish_reason == "stop"
assert "LLMStreamStart" in caplog.text
assert "LLMStreamEnd" in caplog.text
assert isinstance(final_result.content, str)
for i in range(1, 6):
assert str(i) in caplog.text
assert prompt in caplog.text
# Check content contains numbers 1-5
assert isinstance(final_result.content, str)

View File

@@ -181,10 +181,21 @@ async def test_azure_ai_chat_completion_client_create(
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream(azure_client: AzureAIChatCompletionClient) -> None:
chunks: List[str | CreateResult] = []
async for chunk in azure_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
chunks.append(chunk)
async def test_azure_ai_chat_completion_client_create_stream(
azure_client: AzureAIChatCompletionClient, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
chunks: List[str | CreateResult] = []
async for chunk in azure_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
chunks.append(chunk)
assert "LLMStreamStart" in caplog.text
assert "LLMStreamEnd" in caplog.text
final_result: str | CreateResult = chunks[-1]
assert isinstance(final_result, CreateResult)
assert isinstance(final_result.content, str)
assert final_result.content in caplog.text
assert chunks[0] == "Hello"
assert chunks[1] == " Another Hello"

View File

@@ -235,22 +235,31 @@ async def test_openai_chat_completion_client_create(
@pytest.mark.asyncio
async def test_openai_chat_completion_client_create_stream_with_usage(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_openai_chat_completion_client_create_stream_with_usage(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
chunks: List[str | CreateResult] = []
async for chunk in client.create_stream(
messages=[UserMessage(content="Hello", source="user")],
# include_usage not the default of the OPENAI API and must be explicitly set
extra_create_args={"stream_options": {"include_usage": True}},
):
chunks.append(chunk)
assert chunks[0] == "Hello"
assert chunks[1] == " Another Hello"
assert chunks[2] == " Yet Another Hello"
assert isinstance(chunks[-1], CreateResult)
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
assert chunks[-1].usage == RequestUsage(prompt_tokens=3, completion_tokens=3)
with caplog.at_level(logging.INFO):
async for chunk in client.create_stream(
messages=[UserMessage(content="Hello", source="user")],
# include_usage not the default of the OPENAI API and must be explicitly set
extra_create_args={"stream_options": {"include_usage": True}},
):
chunks.append(chunk)
assert "LLMStreamStart" in caplog.text
assert "LLMStreamEnd" in caplog.text
assert chunks[0] == "Hello"
assert chunks[1] == " Another Hello"
assert chunks[2] == " Yet Another Hello"
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
assert chunks[-1].content in caplog.text
assert chunks[-1].usage == RequestUsage(prompt_tokens=3, completion_tokens=3)
@pytest.mark.asyncio

View File

@@ -396,7 +396,9 @@ async def test_sk_chat_completion_stream_with_tools(sk_client: AzureChatCompleti
@pytest.mark.asyncio
async def test_sk_chat_completion_stream_without_tools(sk_client: AzureChatCompletion) -> None:
async def test_sk_chat_completion_stream_without_tools(
sk_client: AzureChatCompletion, caplog: pytest.LogCaptureFixture
) -> None:
# Create adapter and kernel
adapter = SKChatCompletionAdapter(sk_client)
kernel = Kernel(memory=NullMemory())
@@ -409,23 +411,28 @@ async def test_sk_chat_completion_stream_without_tools(sk_client: AzureChatCompl
# Call create_stream without tools
response_chunks: list[CreateResult | str] = []
async for chunk in adapter.create_stream(messages=messages, extra_create_args={"kernel": kernel}):
response_chunks.append(chunk)
with caplog.at_level(logging.INFO):
async for chunk in adapter.create_stream(messages=messages, extra_create_args={"kernel": kernel}):
response_chunks.append(chunk)
# Verify response
assert len(response_chunks) > 0
# All chunks except last should be strings
for chunk in response_chunks[:-1]:
assert isinstance(chunk, str)
assert "LLMStreamStart" in caplog.text
assert "LLMStreamEnd" in caplog.text
# Final chunk should be CreateResult
final_chunk = response_chunks[-1]
assert isinstance(final_chunk, CreateResult)
assert isinstance(final_chunk.content, str)
assert final_chunk.finish_reason == "stop"
assert final_chunk.usage.prompt_tokens >= 0
assert final_chunk.usage.completion_tokens >= 0
assert not final_chunk.cached
# Verify response
assert len(response_chunks) > 0
# All chunks except last should be strings
for chunk in response_chunks[:-1]:
assert isinstance(chunk, str)
# Final chunk should be CreateResult
final_chunk = response_chunks[-1]
assert isinstance(final_chunk, CreateResult)
assert isinstance(final_chunk.content, str)
assert final_chunk.finish_reason == "stop"
assert final_chunk.usage.prompt_tokens >= 0
assert final_chunk.usage.completion_tokens >= 0
assert not final_chunk.cached
assert final_chunk.content in caplog.text
@pytest.mark.asyncio