mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-25 16:16:37 +00:00
Fix iostream on new thread (#2181)
* fixed get_stream in new thread by introducing a global default * fixed get_stream in new thread by introducing a global default --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
f467f21ec9
commit
21a7eb3115
@ -3,6 +3,7 @@ from .console import IOConsole
|
||||
from .websockets import IOWebsockets
|
||||
|
||||
# Set the default input/output stream to the console
|
||||
IOStream._default_io_stream.set(IOConsole())
|
||||
IOStream.set_global_default(IOConsole())
|
||||
IOStream.set_default(IOConsole())
|
||||
|
||||
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")
|
||||
|
@ -1,9 +1,12 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
import logging
|
||||
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
|
||||
|
||||
__all__ = ("OutputStream", "InputStream", "IOStream")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OutputStream(Protocol):
|
||||
@ -39,6 +42,31 @@ class InputStream(Protocol):
|
||||
class IOStream(InputStream, OutputStream, Protocol):
|
||||
"""A protocol for input/output streams."""
|
||||
|
||||
# ContextVar must be used in multithreaded or async environments
|
||||
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None)
|
||||
_default_io_stream.set(None)
|
||||
_global_default: Optional["IOStream"] = None
|
||||
|
||||
@staticmethod
|
||||
def set_global_default(stream: "IOStream") -> None:
|
||||
"""Set the default input/output stream.
|
||||
|
||||
Args:
|
||||
stream (IOStream): The input/output stream to set as the default.
|
||||
"""
|
||||
IOStream._global_default = stream
|
||||
|
||||
@staticmethod
|
||||
def get_global_default() -> "IOStream":
|
||||
"""Get the default input/output stream.
|
||||
|
||||
Returns:
|
||||
IOStream: The default input/output stream.
|
||||
"""
|
||||
if IOStream._global_default is None:
|
||||
raise RuntimeError("No global default IOStream has been set")
|
||||
return IOStream._global_default
|
||||
|
||||
@staticmethod
|
||||
def get_default() -> "IOStream":
|
||||
"""Get the default input/output stream.
|
||||
@ -48,13 +76,10 @@ class IOStream(InputStream, OutputStream, Protocol):
|
||||
"""
|
||||
iostream = IOStream._default_io_stream.get()
|
||||
if iostream is None:
|
||||
raise RuntimeError("No default IOStream has been set")
|
||||
logger.warning("No default IOStream has been set, defaulting to IOConsole.")
|
||||
return IOStream.get_global_default()
|
||||
return iostream
|
||||
|
||||
# ContextVar must be used in multithreaded or async environments
|
||||
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
|
||||
_default_io_stream.set(None)
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from threading import Thread
|
||||
from typing import Any, List
|
||||
|
||||
from autogen.io import IOConsole, IOStream, IOWebsockets
|
||||
|
||||
@ -26,3 +27,23 @@ class TestIOStream:
|
||||
assert isinstance(IOStream.get_default(), MyIOStream)
|
||||
|
||||
assert isinstance(IOStream.get_default(), IOConsole)
|
||||
|
||||
def test_get_default_on_new_thread(self) -> None:
|
||||
exceptions: List[Exception] = []
|
||||
|
||||
def on_new_thread(exceptions: List[Exception] = exceptions) -> None:
|
||||
try:
|
||||
assert isinstance(IOStream.get_default(), IOConsole)
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
|
||||
# create a new thread and run the function
|
||||
thread = Thread(target=on_new_thread)
|
||||
|
||||
thread.start()
|
||||
|
||||
# get exception from the thread
|
||||
thread.join()
|
||||
|
||||
if exceptions:
|
||||
raise exceptions[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user