Source code for noob.network.loop

import asyncio
import sys
from collections import defaultdict
from collections.abc import Callable, Coroutine
from typing import Any

try:
    from zmq.asyncio import Context, Socket
except ImportError as e:
    raise ImportError(
        "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`",
    ) from e

if sys.version_info < (3, 12):
    from typing_extensions import TypedDict
else:
    from typing import TypedDict

from noob.logging import init_logger
from noob.network.message import Message
from noob.utils import iscoroutinefunction_partial


class _CallbackDict(TypedDict):
    """Callbacks registered for one receiver socket, grouped by how they are invoked."""

    # Plain callables; ``EventloopMixin._poll_receiver`` submits these to an executor.
    sync: list[Callable[[Message], Any]]
    # Coroutine functions; ``EventloopMixin._poll_receiver`` awaits these.
    asyncio: list[Callable[[Message], Coroutine]]

class EventloopMixin:
    """
    Mixin to provide common asyncio zmq scaffolding to networked classes.

    Inheriting classes should, in order

    * call the ``_init_loop`` method to bind the running eventloop, the zmq context,
      and the quitting event
    * populate the private ``_sockets`` and ``_receivers`` dicts
      (see :meth:`.register_socket`)
    * await the ``_poll_receivers`` method, which polls the receiver sockets
      until the quitting event is set by ``_stop_loop``.

    Inheriting classes **must** ensure that ``_init_loop`` is called in the thread
    it is intended to run in, and that thread must already have a running eventloop.
    asyncio eventloops (and most of asyncio) are **not** thread safe.

    To help avoid cross-threading issues, the :meth:`.context` and :meth:`.loop`
    properties do *not* automatically create the objects, raising a
    :class:`.RuntimeError` if they are accessed before ``_init_loop`` is called.
    """

    def __init__(self):
        self._context = None
        self._loop = None
        # Created in ``_init_loop``; stays ``None`` until then.
        self._quitting: asyncio.Event = None  # type: ignore[assignment]
        self._sockets: dict[str, Socket] = {}
        """
        All sockets, mapped from some common name to the socket.
        The same key used here should be shared between _receivers and _callbacks
        """
        self._receivers: dict[str, Socket] = {}
        """Sockets that should be polled for incoming messages"""
        self._callbacks: dict[str, _CallbackDict] = defaultdict(
            lambda: _CallbackDict(sync=[], asyncio=[])
        )
        """Callbacks for each receiver socket"""
        # Don't clobber a logger that the inheriting class has already set up.
        if not hasattr(self, "logger"):
            self.logger = init_logger("eventloop")

    @property
    def context(self) -> Context:
        """
        The zmq context bound by ``_init_loop``.

        Raises:
            RuntimeError: if ``_init_loop`` has not been called yet.
        """
        if self._context is None:
            raise RuntimeError("Loop has not been initialized with _init_loop!")
        return self._context

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """
        The eventloop bound by ``_init_loop``.

        Raises:
            RuntimeError: if ``_init_loop`` has not been called yet.
        """
        if self._loop is None:
            raise RuntimeError("Loop has not been initialized with _init_loop!")
        return self._loop

    @property
    def sockets(self) -> dict[str, Socket]:
        """All registered sockets, keyed by name (the live dict, not a copy)."""
        return self._sockets

    def register_socket(self, name: str, socket: Socket, receiver: bool = False) -> None:
        """
        Register a socket, optionally declaring it as a receiver socket to poll

        Args:
            name: Key the socket is stored under; also the key used for its callbacks.
            socket: The socket to register.
            receiver: If ``True``, the socket is also added to the polled receivers.

        Raises:
            KeyError: if a socket has already been registered under ``name``.
        """
        if name in self._sockets:
            raise KeyError(f"Socket {name} already declared!")
        self._sockets[name] = socket
        if receiver:
            self._receivers[name] = socket

    def add_callback(
        self, socket: str, callback: Callable[[Message], Any] | Callable[[Message], Coroutine]
    ) -> None:
        """
        Add a callback to be called when the socket receives a message.
        Callbacks are called in the order in which they are added.

        Coroutine functions (as detected by ``iscoroutinefunction_partial``) are stored
        as ``asyncio`` callbacks, everything else as ``sync`` callbacks.

        Args:
            socket: Name of a socket registered with ``receiver=True``.
            callback: Callable that accepts the decoded :class:`.Message`.

        Raises:
            KeyError: if ``socket`` is not a registered receiver socket.
        """
        if socket not in self._receivers:
            raise KeyError(f"Socket {socket} does not exist or is not a receiving socket")
        if iscoroutinefunction_partial(callback):
            self._callbacks[socket]["asyncio"].append(callback)
        else:
            self._callbacks[socket]["sync"].append(callback)

    def clear_callbacks(self) -> None:
        """Drop every registered callback for every socket."""
        self._callbacks = defaultdict(lambda: _CallbackDict(sync=[], asyncio=[]))

    def _init_loop(self) -> None:
        """
        Bind this object to the currently running eventloop and the zmq context,
        and create the quitting event.

        Must be called from inside a running eventloop, since it uses
        :func:`asyncio.get_running_loop`.
        """
        self._loop = asyncio.get_running_loop()
        self._context = Context.instance()
        self._quitting = asyncio.Event()

    def _stop_loop(self) -> None:
        """Set the quitting event; a no-op if ``_init_loop`` was never called."""
        if self._quitting is None:
            return
        self._quitting.set()

    async def _poll_receivers(self) -> None:
        """
        Rather than using the zmq.asyncio.Poller which wastes a ton of time,
        it turns out doing it this way is roughly 4x as fast:
        just manually poll the sockets, and if you have multiple sockets,
        gather multiple coroutines where you're polling the sockets.

        Raises:
            RuntimeError: if there are receivers to poll and ``_init_loop``
                has not been called.
        """
        if len(self._receivers) == 1:
            # single receiver: skip the gather machinery and poll it directly
            await self._poll_receiver(next(iter(self._receivers)))
        else:
            await asyncio.gather(*[self._poll_receiver(name) for name in self._receivers])

    async def _poll_receiver(self, name: str) -> None:
        """
        Receive messages from one receiver socket and dispatch them to its callbacks
        until the quitting event is set.

        For each message, the ``asyncio`` callbacks are awaited one at a time in the
        order they were added, then the ``sync`` callbacks are submitted to the loop's
        default executor. The executor futures are not awaited, so this coroutine
        neither waits for sync callbacks nor sees their exceptions.

        Messages that fail to decode are logged and skipped.

        NOTE: the quitting event is only checked between messages. A poller that is
        waiting in ``recv_multipart`` does not observe ``_stop_loop`` until another
        message arrives or the awaiting task is cancelled.

        Args:
            name: Key of the socket in ``_receivers``.

        Raises:
            RuntimeError: if ``_init_loop`` has not been called.
        """
        socket = self._receivers[name]
        # Fail the same way as the ``context``/``loop`` properties instead of with an
        # opaque AttributeError on ``None.is_set``.
        if self._quitting is None:
            raise RuntimeError("Loop has not been initialized with _init_loop!")
        while not self._quitting.is_set():
            msg_bytes = await socket.recv_multipart()
            try:
                msg = Message.from_bytes(msg_bytes)
            except Exception as e:
                self.logger.exception(
                    "Exception decoding message for socket %s: %s,  %s", name, msg_bytes, e
                )
                continue

            # purposely don't catch errors here because we want them to bubble up into the caller
            for acb in self._callbacks[name]["asyncio"]:
                await acb(msg)
            for cb in self._callbacks[name]["sync"]:
                self.loop.run_in_executor(None, cb, msg)