import base64
import json
import pickle
import sys
from datetime import UTC, datetime
from enum import StrEnum
from typing import Annotated as A
from typing import Any, Literal
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Discriminator,
Field,
Tag,
TypeAdapter,
WrapSerializer,
)
from pydantic_core.core_schema import SerializerFunctionWrapHandler
from noob.const import META_SIGNAL
from noob.event import Event, MetaSignal
from noob.types import Picklable
if sys.version_info < (3, 12):
from typing_extensions import TypedDict
else:
from typing import TypedDict
[docs]
class MessageType(StrEnum):
announce = "announce"
identify = "identify"
process = "process"
init = "init"
deinit = "deinit"
ping = "ping"
start = "start"
status = "status"
stop = "stop"
event = "event"
error = "error"
[docs]
class NodeStatus(StrEnum):
stopped = "stopped"
"""Node is deinitialized - does not have an instantiated node, etc., but is responsive."""
waiting = "waiting"
"""Node is waiting for its dependency nodes to be ready"""
ready = "ready"
"""Node is ready to process events"""
running = "running"
"""
Node is running in free-run mode.
Note that we do not update status for every process call at the moment,
as that level of granularity is not relevant to the command node when sending commands
"""
closed = "closed"
"""Node is permanently gone, should not be expected to respond to further messages."""
[docs]
class Message(BaseModel):
    """Base model shared by every message kind.

    Concrete subclasses narrow ``type_`` to a single :class:`MessageType`
    literal and give ``value`` a kind-specific payload type.
    """

    # Enum fields store plain string values; ``type_`` is read and written
    # under its alias ``"type"``.
    model_config = ConfigDict(
        use_enum_values=True,
        validate_by_alias=True,
        serialize_by_alias=True,
    )

    type_: MessageType = Field(..., alias="type")
    node_id: str
    # Defaults to the current, timezone-aware UTC time.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
    # Untyped here; subclasses narrow it.
    value: Any = None

    @classmethod
    def from_bytes(cls, msg: list[bytes]) -> "Message":
        """Decode a message from a list of byte frames.

        Only the last frame is parsed, as UTF-8 JSON; any earlier frames are
        ignored (NOTE(review): presumably routing/envelope frames - confirm).

        Returns:
            An instance of whichever message model ``MessageAdapter`` selects
            from the ``type`` field - not necessarily ``cls``.
        """
        *_, payload = msg
        text = payload.decode("utf-8")
        return MessageAdapter.validate_json(text)

    def to_bytes(self) -> bytes:
        """Serialize to UTF-8 encoded JSON, using field aliases (``type_`` -> ``type``)."""
        as_json = self.model_dump_json()
        return as_json.encode("utf-8")
[docs]
class IdentifyValue(TypedDict):
    """A node's self-description.

    It is the payload of ``IdentifyMsg`` and the per-node entry of ``AnnounceValue``.
    """

    node_id: str
    # NOTE(review): presumably the address of the node's outbox - format not visible here.
    outbox: str
    status: NodeStatus
    # Names of the node's signals and slots, or None.
    # NOTE(review): presumably emitted signals / accepted inputs - confirm.
    signals: list[str] | None
    slots: list[str] | None
[docs]
class AnnounceValue(TypedDict):
    """Payload of ``AnnounceMsg``: an inbox plus the identity of each known node."""

    # NOTE(review): presumably the command node's inbox address - confirm.
    inbox: str
    # One IdentifyValue per node, keyed by a string (presumably the node id).
    nodes: dict[str, IdentifyValue]
[docs]
class ErrorValue(TypedDict):
    """Payload of ``ErrorMsg``: what is needed to rebuild a remote exception locally."""

    # Exception class to instantiate on the receiving side.
    err_type: type[Exception]
    # Positional arguments that are passed back to ``err_type``.
    err_args: tuple
    # Traceback text from the originating process, attached as a note when rebuilt.
    traceback: str
[docs]
class ProcessValue(TypedDict):
epoch: int
input: dict | None
[docs]
class AnnounceMsg(Message):
    """Command node 'announces' the identities of other peers and the events they emit."""

    type_: Literal[MessageType.announce] = Field(default=MessageType.announce, alias="type")
    # Inbox string plus one IdentifyValue per known node.
    value: AnnounceValue
[docs]
class IdentifyMsg(Message):
    """A node sends its configuration to the command node on initialization."""

    type_: Literal[MessageType.identify] = Field(default=MessageType.identify, alias="type")
    # The sending node's self-description.
    value: IdentifyValue
[docs]
class PingMsg(Message):
    """Request other nodes to identify themselves and report their status."""

    type_: Literal[MessageType.ping] = Field(default=MessageType.ping, alias="type")
    # Carries no payload.
    value: None = None
[docs]
class ProcessMsg(Message):
    """Process a single iteration of the graph."""

    type_: Literal[MessageType.process] = Field(default=MessageType.process, alias="type")
    value: ProcessValue
    """Any process-scoped input passed to the `process` call"""
[docs]
class InitMsg(Message):
    """Initialize nodes within node runners."""

    type_: Literal[MessageType.init] = Field(default=MessageType.init, alias="type")
    # Carries no payload.
    value: None = None
[docs]
class DeinitMsg(Message):
    """Deinitialize nodes within node runners."""

    type_: Literal[MessageType.deinit] = Field(default=MessageType.deinit, alias="type")
    # Carries no payload.
    value: None = None
[docs]
class StartMsg(Message):
    """Start free-running nodes."""

    type_: Literal[MessageType.start] = Field(default=MessageType.start, alias="type")
    # Optional integer argument; its meaning is not visible in this module.
    # NOTE(review): presumably a number of iterations to run - confirm.
    value: int | None = None
[docs]
class StatusMsg(Message):
    """A node reports its current status."""

    type_: Literal[MessageType.status] = Field(default=MessageType.status, alias="type")
    # The status being reported.
    value: NodeStatus
[docs]
class StopMsg(Message):
    """Stop processing."""

    type_: Literal[MessageType.stop] = Field(default=MessageType.stop, alias="type")
    # Carries no payload.
    value: None = None
[docs]
class ErrorMsg(Message):
    """An error occurred in one of the processing nodes."""

    type_: Literal[MessageType.error] = Field(MessageType.error, alias="type")
    # Wrapped in the project's ``Picklable`` type because the payload holds a
    # live exception class.
    value: Picklable[ErrorValue]

    # Pydantic merges this with the parent ``Message.model_config``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def to_exception(self) -> Exception:
        """Rebuild the exception raised in the remote node runner.

        The exception class is instantiated with its original ``args``. The
        remote traceback text is attached as a note (PEP 678), so it is printed
        after the local traceback when the returned exception is raised.

        Returns:
            A new instance of ``value["err_type"]``. It is returned, not raised.
        """
        err_type = self.value["err_type"]
        err_args = self.value["err_args"]
        try:
            err = err_type(*err_args)
        except Exception:
            # Some exception classes have an ``__init__`` whose parameters differ
            # from the ``args`` they end up storing (keyword-only or derived
            # arguments), so ``err_type(*err_args)`` fails. Allocating without
            # running ``__init__`` still yields the original type with its
            # ``args``. Callers can then catch the real exception type, and the
            # remote traceback is not lost behind an unrelated TypeError.
            err = err_type.__new__(err_type, *err_args)
        note = "".join(
            (
                "\nError re-raised from node runner process\n\n",
                "Original traceback:\n",
                "-" * 20 + "\n",
                self.value["traceback"],
            )
        )
        err.add_note(note)
        return err
def _to_json(val: Event, handler: SerializerFunctionWrapHandler) -> Any:
    """JSON-mode wrap serializer for a single :class:`Event`.

    On a meta-signal event, the ``MetaSignal.NoEvent`` sentinel is replaced
    by its ``.value`` before serialization; ``_from_json`` restores it. If
    pydantic cannot serialize the event, the event is pickled, base64-encoded
    and returned as a string prefixed with ``"pck__"``.

    Args:
        val: The event being serialized. It is not modified.
        handler: Pydantic's default serializer for ``Event``.

    Returns:
        The JSON-compatible output of ``handler``, or the ``"pck__"`` string.
    """
    if val["signal"] == META_SIGNAL and val["value"] is MetaSignal.NoEvent:
        # Work on a copy: serializing must not rewrite the event stored in the
        # model being dumped. Otherwise the sentinel is gone after model_dump_json().
        val = val.copy()
        val["value"] = MetaSignal.NoEvent.value
    try:
        return handler(val)
    except (TypeError, ValueError):
        # Pydantic reports unserializable values with PydanticSerializationError,
        # which subclasses ValueError (not TypeError). Both trigger the pickle
        # fallback.
        return "pck__" + base64.b64encode(pickle.dumps(val)).decode("utf-8")
def _from_json(val: Any) -> Event:
    """Before-validator that undoes ``_to_json`` for a single :class:`Event`.

    Accepted inputs:

    * ``"pck__<base64>"``: a pickled event, i.e. the serializer's fallback form;
    * any other ``str``: a JSON-encoded event object;
    * a ``dict``: an event object. This is what validation receives for the
      serializer's normal output once the JSON has been parsed.

    For each of these, a meta-signal event whose value equals
    ``MetaSignal.NoEvent.value`` gets its value restored to the
    ``MetaSignal.NoEvent`` member. A caller-supplied dict is never mutated.
    Any other input is returned unchanged for pydantic to validate.
    """
    if isinstance(val, str):
        if val.startswith("pck__"):
            # SECURITY: unpickling can execute arbitrary code. This is only safe
            # if every peer that can send messages here is trusted.
            evt = pickle.loads(base64.b64decode(val[5:]))
        else:
            evt = Event(**json.loads(val))  # type: ignore[typeddict-item]
    elif isinstance(val, dict):
        evt = val
    else:
        return val
    # ``.get`` lets a malformed event fall through to pydantic's validation
    # error instead of raising KeyError here.
    if evt.get("signal") == META_SIGNAL and evt.get("value") == MetaSignal.NoEvent.value:
        evt = evt.copy()
        evt["value"] = MetaSignal.NoEvent
    return evt
# An ``Event`` that survives a JSON round trip. ``_to_json`` runs only when
# serializing to JSON; ``_from_json`` runs before validation.
SerializableEvent = A[
    Event,
    WrapSerializer(_to_json, when_used="json"),
    BeforeValidator(_from_json),
]
[docs]
class EventMsg(Message):
    """Carries a list of :class:`Event` objects, each serialized as ``SerializableEvent``."""

    type_: Literal[MessageType.event] = Field(default=MessageType.event, alias="type")
    value: list[SerializableEvent]
def _type_discriminator(v: dict | Message) -> str:
    """Choose the ``MessageUnion`` tag for raw input or an existing message.

    Returns the message's ``type`` when it names a :class:`MessageType`
    member. Otherwise it returns ``"any"``, so the input is validated against
    the generic :class:`Message` model, which then reports any remaining
    errors as a normal validation error.
    """
    if isinstance(v, dict):
        typ = v.get("type", "any")
    else:
        # Objects without ``type_`` fall back to "any" instead of raising AttributeError.
        typ = getattr(v, "type_", "any")
    # The isinstance guard matters because malformed JSON can supply a
    # non-string ``type`` (e.g. a list). That value is unhashable and would make
    # the membership test raise TypeError. MessageType member names equal their
    # values, so testing names is equivalent to testing values.
    if isinstance(typ, str) and typ in MessageType.__members__:
        return typ
    return "any"
# Every concrete message model, tagged with its MessageType value.
# ``_type_discriminator`` selects the tag. An unknown or missing ``type``
# selects "any", which validates against the generic ``Message`` model.
MessageUnion = A[
    A[AnnounceMsg, Tag("announce")]
    | A[IdentifyMsg, Tag("identify")]
    | A[ProcessMsg, Tag("process")]
    | A[InitMsg, Tag("init")]
    | A[DeinitMsg, Tag("deinit")]
    | A[PingMsg, Tag("ping")]
    | A[StartMsg, Tag("start")]
    | A[StatusMsg, Tag("status")]
    | A[StopMsg, Tag("stop")]
    | A[EventMsg, Tag("event")]
    | A[ErrorMsg, Tag("error")]
    | A[Message, Tag("any")],
    Discriminator(_type_discriminator),
]

# Module-level adapter that decodes input into the most specific message model
# (used by ``Message.from_bytes``).
MessageAdapter = TypeAdapter[MessageUnion](MessageUnion)