mirror of
https://github.com/langgenius/dify.git
synced 2025-12-05 23:36:09 +00:00
This PR introduces a `BroadcastChannel` abstraction with broadcasting and at-most-once delivery semantics, serving as the communication component between the Celery worker and the API server. It also includes a reference implementation backed by Redis PubSub. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
201 lines
6.4 KiB
Python
201 lines
6.4 KiB
Python
import logging
|
|
import queue
|
|
import threading
|
|
import types
|
|
from collections.abc import Generator, Iterator
|
|
from typing import Self
|
|
|
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
|
from libs.broadcast_channel.exc import SubscriptionClosedError
|
|
from redis import Redis
|
|
from redis.client import PubSub
|
|
|
|
# Module-level logger, named after this module.
_logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BroadcastChannel:
    """
    Broadcast channel implemented on top of Redis Pub/Sub.

    Delivery follows "at most once" semantics; messages are pushed in real time
    through the Redis PUBLISH/SUBSCRIBE commands.

    The `redis_client` passed in must be configured with `decode_responses=False`
    so that message payloads arrive as raw bytes.
    """

    def __init__(self, redis_client: Redis):
        self._client = redis_client

    def topic(self, topic: str) -> "Topic":
        """Return a `Topic` handle bound to the Redis channel named `topic`."""
        client = self._client
        return Topic(client, topic)
|
|
|
|
|
|
class Topic:
    """A named Redis Pub/Sub channel that acts both as a producer and as a subscriber."""

    def __init__(self, redis_client: Redis, topic: str):
        self._client = redis_client
        self._topic = topic

    # ---- producer side -------------------------------------------------

    def as_producer(self) -> Producer:
        """Return this topic viewed through the `Producer` interface."""
        return self

    def publish(self, payload: bytes) -> None:
        """Send `payload` to this topic's Redis channel with PUBLISH."""
        client, channel = self._client, self._topic
        client.publish(channel, payload)

    # ---- subscriber side -----------------------------------------------

    def as_subscriber(self) -> Subscriber:
        """Return this topic viewed through the `Subscriber` interface."""
        return self

    def subscribe(self) -> Subscription:
        """Create a new subscription that owns a fresh `PubSub` object from the client.

        The returned subscription does not issue SUBSCRIBE until it is first used.
        """
        fresh_pubsub = self._client.pubsub()
        return _RedisSubscription(pubsub=fresh_pubsub, topic=self._topic)
|
|
|
|
|
|
class _RedisSubscription(Subscription):
    """Subscription to a single Redis Pub/Sub channel.

    Messages are read by a background daemon thread and buffered in a bounded
    in-memory queue. When the queue is full, the oldest buffered message is
    discarded to make room, so slow consumers lose messages rather than blocking
    the listener ("at most once" delivery).

    Lifecycle:
      * SUBSCRIBE and the listener thread are started lazily, on the first call
        to ``__iter__``, ``receive`` or ``__enter__``.
      * ``close()`` signals the listener to stop. Once the listener is running,
        only the listener thread touches the ``PubSub`` object (``PubSub`` is not
        thread-safe), so the listener unsubscribes and closes it on exit.
      * If the listener fails unexpectedly, it logs the error, releases the
        ``PubSub`` and marks the subscription closed so consumers stop waiting.
    """

    # Upper bound on how long blocking loops wait before re-checking whether the
    # subscription has been closed.
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(
        self,
        pubsub: PubSub,
        topic: str,
    ):
        # The _pubsub is None only once the subscription is closed and the PubSub released.
        self._pubsub: PubSub | None = pubsub
        self._topic = topic
        self._closed = threading.Event()
        # Bounded buffer between the listener thread and consumers.
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
        # Messages evicted because the buffer was full; only updated by the listener thread.
        self._dropped_count = 0
        self._listener_thread: threading.Thread | None = None
        # Serialises start-up against close() so a closed subscription is never started.
        self._start_lock = threading.Lock()
        self._started = False

    def _start_if_needed(self) -> None:
        """Subscribe to the channel and start the listener thread exactly once.

        :raises SubscriptionClosedError: If the subscription is already closed.
        """
        with self._start_lock:
            if self._started:
                return
            if self._closed.is_set():
                raise SubscriptionClosedError("The Redis subscription is closed")
            if self._pubsub is None:
                raise SubscriptionClosedError("The Redis subscription has been cleaned up")

            self._pubsub.subscribe(self._topic)
            _logger.debug("Subscribed to channel %s", self._topic)

            self._listener_thread = threading.Thread(
                target=self._listen,
                name=f"redis-broadcast-{self._topic}",
                daemon=True,
            )
            self._listener_thread.start()
            self._started = True

    @staticmethod
    def _decode_channel_name(channel_field: object) -> str:
        """Normalise the ``channel`` field of a Pub/Sub message to ``str``."""
        if isinstance(channel_field, bytes):
            return channel_field.decode("utf-8")
        if isinstance(channel_field, str):
            return channel_field
        return str(channel_field)

    def _listen(self) -> None:
        """Listener-thread body: move messages from Redis into the local buffer.

        Runs until the subscription is closed. Any unexpected error (for example a
        broken Redis connection) is logged and closes the subscription, so that
        consumers are not left waiting forever. On exit the channel is always
        unsubscribed and the PubSub closed, because PubSub must not be used from
        several threads concurrently.
        """
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while starting listening."
        try:
            while not self._closed.is_set():
                raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=self._POLL_INTERVAL_SECONDS)
                if raw_message is None or raw_message.get("type") != "message":
                    continue

                channel_name = self._decode_channel_name(raw_message.get("channel"))
                if channel_name != self._topic:
                    _logger.warning("Ignoring message from unexpected channel %s", channel_name)
                    continue

                payload_bytes = raw_message.get("data")
                if not isinstance(payload_bytes, bytes):
                    _logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
                    continue

                self._enqueue_message(payload_bytes)
        except Exception:
            _logger.exception("Listener thread failed for channel %s, closing subscription", self._topic)
            # Consumers would otherwise keep waiting on a subscription that can no longer deliver.
            self._closed.set()
        finally:
            _logger.debug("Listener thread stopped for channel %s", self._topic)
            self._release_pubsub(pubsub)

    def _release_pubsub(self, pubsub: PubSub) -> None:
        """Unsubscribe from the topic and close ``pubsub``; runs on the listener thread."""
        self._pubsub = None
        try:
            pubsub.unsubscribe(self._topic)
        except Exception:
            # The connection may already be broken; closing below still releases it.
            _logger.warning("Failed to unsubscribe from channel %s", self._topic, exc_info=True)
        pubsub.close()
        _logger.debug("PubSub closed for topic %s", self._topic)

    def _enqueue_message(self, payload: bytes) -> None:
        """Buffer ``payload``, evicting the oldest buffered message while the buffer is full."""
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                pass

            try:
                self._queue.get_nowait()
            except queue.Empty:
                # A consumer drained the buffer concurrently; simply retry the put.
                continue

            self._dropped_count += 1
            _logger.debug(
                "Dropped message from Redis subscription, topic=%s, total_dropped=%d",
                self._topic,
                self._dropped_count,
            )

    def _message_iterator(self) -> Generator[bytes, None, None]:
        """Yield buffered messages until the subscription is closed."""
        while not self._closed.is_set():
            try:
                item = self._queue.get(timeout=self._POLL_INTERVAL_SECONDS)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        """Start the subscription if needed and iterate over incoming messages.

        Iteration ends once the subscription is closed.

        :raises SubscriptionClosedError: If the subscription is already closed.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis subscription is closed")
        self._start_if_needed()
        return iter(self._message_iterator())

    def receive(self, timeout: float | None = None) -> bytes | None:
        """Return the next buffered message.

        :param timeout: Maximum number of seconds to wait. ``None`` waits until a
            message arrives or the subscription is closed.
        :returns: The message payload, or ``None`` if ``timeout`` elapsed first.
        :raises SubscriptionClosedError: If the subscription is closed before the
            call or, when ``timeout`` is ``None``, while waiting.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis subscription is closed")
        self._start_if_needed()

        if timeout is not None:
            try:
                return self._queue.get(timeout=timeout)
            except queue.Empty:
                return None

        # An unbounded Queue.get() would never observe close(); wait in short slices instead.
        while not self._closed.is_set():
            try:
                return self._queue.get(timeout=self._POLL_INTERVAL_SECONDS)
            except queue.Empty:
                continue
        raise SubscriptionClosedError("The Redis subscription is closed")

    def __enter__(self) -> Self:
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        self.close()
        return None

    def close(self) -> None:
        """Close the subscription. Calling it again has no effect.

        If the listener thread is running, this signals it to stop and waits up
        to one second for it to unsubscribe and close the PubSub. If the
        subscription was never started, the PubSub is closed here directly.
        """
        with self._start_lock:
            if self._closed.is_set():
                return
            self._closed.set()
            listener = self._listener_thread
            started = self._started

        if not started:
            # No listener thread exists, so this thread is the only user of the PubSub.
            pubsub = self._pubsub
            self._pubsub = None
            if pubsub is not None:
                pubsub.close()
            return

        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
        # method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        if listener is not None:
            listener.join(timeout=1.0)
        self._listener_thread = None
|