mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-14 08:37:54 +00:00
Fix iostream on new thread (#2181)
* fixed get_stream in new thread by introducing a global default * fixed get_stream in new thread by introducing a global default --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
f467f21ec9
commit
21a7eb3115
@ -3,6 +3,7 @@ from .console import IOConsole
|
|||||||
from .websockets import IOWebsockets
|
from .websockets import IOWebsockets
|
||||||
|
|
||||||
# Set the default input/output stream to the console
|
# Set the default input/output stream to the console
|
||||||
IOStream._default_io_stream.set(IOConsole())
|
IOStream.set_global_default(IOConsole())
|
||||||
|
IOStream.set_default(IOConsole())
|
||||||
|
|
||||||
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")
|
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
import logging
|
||||||
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
|
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
__all__ = ("OutputStream", "InputStream", "IOStream")
|
__all__ = ("OutputStream", "InputStream", "IOStream")
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class OutputStream(Protocol):
|
class OutputStream(Protocol):
|
||||||
@ -39,6 +42,31 @@ class InputStream(Protocol):
|
|||||||
class IOStream(InputStream, OutputStream, Protocol):
|
class IOStream(InputStream, OutputStream, Protocol):
|
||||||
"""A protocol for input/output streams."""
|
"""A protocol for input/output streams."""
|
||||||
|
|
||||||
|
# ContextVar must be used in multithreaded or async environments
|
||||||
|
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None)
|
||||||
|
_default_io_stream.set(None)
|
||||||
|
_global_default: Optional["IOStream"] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_global_default(stream: "IOStream") -> None:
|
||||||
|
"""Set the default input/output stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream (IOStream): The input/output stream to set as the default.
|
||||||
|
"""
|
||||||
|
IOStream._global_default = stream
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_global_default() -> "IOStream":
|
||||||
|
"""Get the default input/output stream.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IOStream: The default input/output stream.
|
||||||
|
"""
|
||||||
|
if IOStream._global_default is None:
|
||||||
|
raise RuntimeError("No global default IOStream has been set")
|
||||||
|
return IOStream._global_default
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default() -> "IOStream":
|
def get_default() -> "IOStream":
|
||||||
"""Get the default input/output stream.
|
"""Get the default input/output stream.
|
||||||
@ -48,13 +76,10 @@ class IOStream(InputStream, OutputStream, Protocol):
|
|||||||
"""
|
"""
|
||||||
iostream = IOStream._default_io_stream.get()
|
iostream = IOStream._default_io_stream.get()
|
||||||
if iostream is None:
|
if iostream is None:
|
||||||
raise RuntimeError("No default IOStream has been set")
|
logger.warning("No default IOStream has been set, defaulting to IOConsole.")
|
||||||
|
return IOStream.get_global_default()
|
||||||
return iostream
|
return iostream
|
||||||
|
|
||||||
# ContextVar must be used in multithreaded or async environments
|
|
||||||
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
|
|
||||||
_default_io_stream.set(None)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
|
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from typing import Any
|
from threading import Thread
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
from autogen.io import IOConsole, IOStream, IOWebsockets
|
from autogen.io import IOConsole, IOStream, IOWebsockets
|
||||||
|
|
||||||
@ -26,3 +27,23 @@ class TestIOStream:
|
|||||||
assert isinstance(IOStream.get_default(), MyIOStream)
|
assert isinstance(IOStream.get_default(), MyIOStream)
|
||||||
|
|
||||||
assert isinstance(IOStream.get_default(), IOConsole)
|
assert isinstance(IOStream.get_default(), IOConsole)
|
||||||
|
|
||||||
|
def test_get_default_on_new_thread(self) -> None:
|
||||||
|
exceptions: List[Exception] = []
|
||||||
|
|
||||||
|
def on_new_thread(exceptions: List[Exception] = exceptions) -> None:
|
||||||
|
try:
|
||||||
|
assert isinstance(IOStream.get_default(), IOConsole)
|
||||||
|
except Exception as e:
|
||||||
|
exceptions.append(e)
|
||||||
|
|
||||||
|
# create a new thread and run the function
|
||||||
|
thread = Thread(target=on_new_thread)
|
||||||
|
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# get exception from the thread
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
if exceptions:
|
||||||
|
raise exceptions[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user