ballistica/tools/efro/message/_receiver.py
Eric Froemling 02399815ed
v1.6.10
2022-03-18 22:02:10 -05:00

240 lines
9.1 KiB
Python

# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
import inspect
import logging
from typing import TYPE_CHECKING
from efro.message._message import (Message, Response, EmptyResponse,
ErrorResponse, UnregisteredMessageIDError)
if TYPE_CHECKING:
from typing import Any, Callable, Optional, Awaitable, Union
from efro.message._protocol import MessageProtocol
class MessageReceiver:
    """Facilitates receiving & responding to messages from a remote source.

    This is instantiated at the class level with unbound methods registered
    as handlers for different message types in the protocol.

    Example:

    class MyClass:
        receiver = MyMessageReceiver()

        # MyMessageReceiver fills out handler() overloads to ensure all
        # registered handlers have valid types/return-types.
        # Note that handlers must be declared as exactly (self, msg).
        @receiver.handler
        def handle_some_message_type(self, msg: SomeMsg) -> SomeResponse:
            # Deal with this message type here.

    # This will trigger the registered handler being called.
    obj = MyClass()
    obj.receiver.handle_raw_message(some_raw_data)

    Any unhandled Exception occurring during message handling will result in
    an Exception being raised on the sending end.
    """

    # Whether registered handlers must be coroutine functions.
    # register_handler() rejects handlers that do not match this flag, and
    # the two handle_raw_message variants assert on it.
    is_async = False

    def __init__(self, protocol: MessageProtocol) -> None:
        """Create a receiver with no handlers for the given protocol."""
        self.protocol = protocol
        # Maps each handled message type to its (unbound) handler call.
        self._handlers: dict[type[Message], Callable] = {}

    @staticmethod
    def _response_types_from_annotation(ret: Any) -> tuple[Any, ...]:
        """Return the types named by an evaluated 'return' annotation.

        Accepts either a single type or a union of types, spelled as
        typing.Union[A, B], typing.Optional[A] or (on Python versions
        supporting it) 'A | B'. NoneType entries are returned as-is;
        translating them is left to the caller.

        Raises TypeError if the annotation is neither a type nor a union
        consisting solely of types.
        """
        import types
        from typing import Union, get_args, get_origin

        origin = get_origin(ret)

        # types.UnionType (the runtime type of 'A | B') does not exist on
        # older Pythons, so look it up defensively.
        pep604_union = getattr(types, 'UnionType', None)
        is_union = origin is Union or (pep604_union is not None
                                       and origin is pep604_union)
        if is_union:
            targs = get_args(ret)
            if not all(isinstance(a, type) for a in targs):
                raise TypeError(f'expected only types for "return" annotation;'
                                f' got {targs}.')
            return targs

        if not isinstance(ret, type):
            raise TypeError(f'expected one or more types for'
                            f' "return" annotation; got a {type(ret)}.')
        return (ret, )

    def register_handler(
            self, call: Callable[[Any, Message], Optional[Response]]) -> None:
        """Register a handler call.

        The message type handled by the call is determined by its
        type annotation.

        Raises ValueError if the call does not have the exact signature
        (self, msg) or if its sync/async-ness does not match this receiver.
        Raises TypeError if its annotations are not usable, if the message
        type is not registered with the protocol, if the message type
        already has a handler, or if the annotated return types do not
        exactly match the message type's response types.
        """
        from typing import get_type_hints

        # The provided callable should be a method taking one 'msg' arg.
        sig = inspect.getfullargspec(call)
        expectedsig = ['self', 'msg']
        if sig.args != expectedsig:
            raise ValueError(f'Expected callable signature of {expectedsig};'
                             f' got {sig.args}')

        # Make sure we are only given async methods if we are an async handler
        # and sync ones otherwise.
        is_async = inspect.iscoroutinefunction(call)
        if self.is_async != is_async:
            msg = ('Expected a sync method; found an async one.' if is_async
                   else 'Expected an async method; found a sync one.')
            raise ValueError(msg)

        # Check annotation types to determine what message types we handle.
        anns = get_type_hints(call)
        msgtype = anns.get('msg')
        if not isinstance(msgtype, type):
            raise TypeError(
                f'expected a type for "msg" annotation; got {type(msgtype)}.')
        assert issubclass(msgtype, Message)

        # Return types can be a single type or a union of types.
        # A return type of None translates to EmptyResponse.
        responsetypes = tuple(
            EmptyResponse if r is type(None) else r
            for r in self._response_types_from_annotation(anns.get('return')))

        # Make sure our protocol has this message type registered and our
        # return types exactly match. (Technically we could return a subset
        # of the supported types; can allow this in the future if it makes
        # sense).
        if msgtype not in self.protocol.message_ids_by_type:
            raise TypeError(f'Message type {msgtype} is not registered'
                            f' in this Protocol.')

        if msgtype in self._handlers:
            raise TypeError(f'Message type {msgtype} already has a registered'
                            f' handler.')

        # Make sure the responses exactly matches what the message expects.
        if set(responsetypes) != set(msgtype.get_response_types()):
            raise TypeError(
                f'Provided response types {responsetypes} do not'
                f' match the set expected by message type {msgtype}: '
                f'({msgtype.get_response_types()})')

        # Ok; we're good!
        self._handlers[msgtype] = call

    def validate(self, log_only: bool = False) -> None:
        """Check for handler completeness, valid types, etc.

        Every non-Response type registered in the protocol must have a
        handler. A missing handler raises TypeError, or is only logged as
        an error if 'log_only' is True.
        """
        for msgtype in self.protocol.message_ids_by_type:
            # Response types are registered in the same table but are
            # never handled by a receiver.
            if issubclass(msgtype, Response):
                continue
            if msgtype in self._handlers:
                continue
            msg = (f'Protocol message type {msgtype} is not handled'
                   f' by receiver type {type(self)}.')
            if log_only:
                logging.error(msg)
            else:
                raise TypeError(msg)

    def _decode_incoming_message(self,
                                 msg: str) -> tuple[Message, type[Message]]:
        """Decode a raw message; return the message object and its type."""
        msg_decoded = self.protocol.decode_message(msg)
        msgtype = type(msg_decoded)
        assert issubclass(msgtype, Message)
        return msg_decoded, msgtype

    def _encode_response(self, response: Optional[Response],
                         msgtype: type[Message]) -> str:
        """Encode a handler's return value as a raw response string.

        'msgtype' is the type of the message being responded to; it is
        used to sanity-check that the response type is one it allows.
        """
        # A return value of None equals EmptyResponse.
        if response is None:
            response = EmptyResponse()
        assert isinstance(response, Response)

        # (user should never explicitly return error-responses)
        assert not isinstance(response, ErrorResponse)
        assert type(response) in msgtype.get_response_types()
        return self.protocol.encode_response(response)

    def handle_raw_message(self,
                           bound_obj: Any,
                           msg: str,
                           raise_unregistered: bool = False) -> str:
        """Decode, handle, and return a response for a message.

        'bound_obj' is passed to the handler as its 'self' argument.

        if 'raise_unregistered' is True, will raise an
        efro.message.UnregisteredMessageIDError for messages not handled by
        the protocol. In all other cases local errors will translate to
        error responses returned to the sender.
        """
        assert not self.is_async, "can't call sync handler on async receiver"
        try:
            msg_decoded, msgtype = self._decode_incoming_message(msg)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            result = handler(bound_obj, msg_decoded)
            return self._encode_response(result, msgtype)

        except Exception as exc:
            # Deliberately broad: any local failure is reported to the
            # sender as an error response instead of propagating.
            if (raise_unregistered
                    and isinstance(exc, UnregisteredMessageIDError)):
                raise
            return self.protocol.encode_error_response(exc)

    async def handle_raw_message_async(
            self,
            bound_obj: Any,
            msg: str,
            raise_unregistered: bool = False) -> str:
        """Should be called when the receiver gets a message.

        Async counterpart of handle_raw_message(); the handler is awaited.
        'bound_obj' is passed to the handler as its 'self' argument.
        The return value is the raw response to the message.

        if 'raise_unregistered' is True, an UnregisteredMessageIDError
        raised while handling is re-raised instead of being translated to
        an error response.
        """
        assert self.is_async, "can't call async handler on sync receiver"
        try:
            msg_decoded, msgtype = self._decode_incoming_message(msg)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            result = await handler(bound_obj, msg_decoded)
            return self._encode_response(result, msgtype)

        except Exception as exc:
            # Deliberately broad: any local failure is reported to the
            # sender as an error response instead of propagating.
            if (raise_unregistered
                    and isinstance(exc, UnregisteredMessageIDError)):
                raise
            return self.protocol.encode_error_response(exc)
class BoundMessageReceiver:
    """Base class pairing a MessageReceiver with a specific object."""

    def __init__(
        self,
        obj: Any,
        receiver: MessageReceiver,
    ) -> None:
        """Store the target object and the receiver it is bound to.

        'obj' must not be None.
        """
        assert obj is not None
        self._receiver, self._obj = receiver, obj

    @property
    def protocol(self) -> MessageProtocol:
        """The protocol of the receiver this object wraps."""
        wrapped = self._receiver
        return wrapped.protocol