Fix iostream on new thread (#2181)

* fixed get_stream in new thread by introducing a global default

* fixed get_stream in new thread by introducing a global default

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Davor Runje 2024-03-28 15:26:01 +01:00 committed by GitHub
parent f467f21ec9
commit 21a7eb3115
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 7 deletions

View File

@ -3,6 +3,7 @@ from .console import IOConsole
from .websockets import IOWebsockets
# Set the default input/output stream to the console
IOStream._default_io_stream.set(IOConsole())
IOStream.set_global_default(IOConsole())
IOStream.set_default(IOConsole())
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")

View File

@ -1,9 +1,12 @@
from contextlib import contextmanager
from contextvars import ContextVar
import logging
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
__all__ = ("OutputStream", "InputStream", "IOStream")
logger = logging.getLogger(__name__)
@runtime_checkable
class OutputStream(Protocol):
@ -39,6 +42,31 @@ class InputStream(Protocol):
class IOStream(InputStream, OutputStream, Protocol):
"""A protocol for input/output streams."""
# ContextVar must be used in multithreaded or async environments
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None)
_default_io_stream.set(None)
_global_default: Optional["IOStream"] = None
@staticmethod
def set_global_default(stream: "IOStream") -> None:
"""Set the default input/output stream.
Args:
stream (IOStream): The input/output stream to set as the default.
"""
IOStream._global_default = stream
@staticmethod
def get_global_default() -> "IOStream":
"""Get the default input/output stream.
Returns:
IOStream: The default input/output stream.
"""
if IOStream._global_default is None:
raise RuntimeError("No global default IOStream has been set")
return IOStream._global_default
@staticmethod
def get_default() -> "IOStream":
"""Get the default input/output stream.
@ -48,13 +76,10 @@ class IOStream(InputStream, OutputStream, Protocol):
"""
iostream = IOStream._default_io_stream.get()
if iostream is None:
raise RuntimeError("No default IOStream has been set")
logger.warning("No default IOStream has been set, defaulting to IOConsole.")
return IOStream.get_global_default()
return iostream
# ContextVar must be used in multithreaded or async environments
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
_default_io_stream.set(None)
@staticmethod
@contextmanager
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:

View File

@ -1,4 +1,5 @@
from typing import Any
from threading import Thread
from typing import Any, List
from autogen.io import IOConsole, IOStream, IOWebsockets
@ -26,3 +27,23 @@ class TestIOStream:
assert isinstance(IOStream.get_default(), MyIOStream)
assert isinstance(IOStream.get_default(), IOConsole)
def test_get_default_on_new_thread(self) -> None:
exceptions: List[Exception] = []
def on_new_thread(exceptions: List[Exception] = exceptions) -> None:
try:
assert isinstance(IOStream.get_default(), IOConsole)
except Exception as e:
exceptions.append(e)
# create a new thread and run the function
thread = Thread(target=on_new_thread)
thread.start()
# get exception from the thread
thread.join()
if exceptions:
raise exceptions[0]