diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml index 2f3299c2..ab1e2242 100644 --- a/.idea/dictionaries/ericf.xml +++ b/.idea/dictionaries/ericf.xml @@ -1992,6 +1992,7 @@ selindex selwidget selwidgets + sendable senze seqtype seqtypestr diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml index 471e39a2..a21794a7 100644 --- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml +++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml @@ -918,6 +918,7 @@ selindex selwidget selwidgets + sendable seqlen seqtype seqtypestr diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index d4801bd4..6a9451d4 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -10,11 +10,11 @@ from dataclasses import dataclass import pytest +from efrotools.statictest import static_type_equals from efro.error import CleanError, RemoteError from efro.dataclassio import ioprepped from efro.message import (Message, Response, MessageProtocol, MessageSender, MessageReceiver) -from efrotools.statictest import static_type_equals if TYPE_CHECKING: from typing import List, Type, Any, Callable, Union @@ -151,10 +151,12 @@ class _BoundTestMessageReceiver: TEST_PROTOCOL = MessageProtocol( message_types={ - 1: _TMessage1, - 2: _TMessage2, - 3: _TResponse1, - 4: _TResponse2, + 0: _TMessage1, + 1: _TMessage2, + }, + response_types={ + 0: _TResponse1, + 1: _TResponse2, }, trusted_client=True, log_remote_exceptions=False, @@ -167,10 +169,16 @@ def test_protocol_creation() -> None: # This should fail because _TMessage1 can return _TResponse1 which # is not given an id here. with pytest.raises(ValueError): - _protocol = MessageProtocol(message_types={1: _TMessage1}) + _protocol = MessageProtocol( + message_types={0: _TMessage1}, + response_types={0: _TResponse2}, + ) # Now it should work. - _protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1}) + _protocol = MessageProtocol( + message_types={0: _TMessage1}, + response_types={0: _TResponse1}, + ) def test_sender_module_creation() -> None: diff --git a/tools/efro/message.py b/tools/efro/message.py index b9c8f090..3f9439a6 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -50,14 +50,23 @@ class Response: """Base class for responses to messages.""" +# Some standard response types: + + @ioprepped @dataclass -class RemoteErrorMessage(Response): +class RemoteErrorResponse(Response): """Message saying some error has occurred on the other end.""" error_message: Annotated[str, IOAttrs('m')] error_type: Annotated[RemoteErrorType, IOAttrs('t')] +@ioprepped +@dataclass +class EmptyResponse(Response): + """The response equivalent of None.""" + + class MessageProtocol: """Wrangles a set of message types, formats, and response types. Both endpoints must be using a compatible Protocol for communication @@ -67,8 +76,8 @@ class MessageProtocol: """ def __init__(self, - message_types: Dict[int, Union[Type[Message], - Type[Response]]], + message_types: Dict[int, Type[Message]], + response_types: Dict[int, Type[Response]], type_key: Optional[str] = None, preserve_clean_errors: bool = True, log_remote_exceptions: bool = True, @@ -88,10 +97,10 @@ class MessageProtocol: be included in the RemoteError. This should only be enabled in cases where the client is trusted. """ - self.message_types_by_id: Dict[int, Union[Type[Message], - Type[Response]]] = {} - self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]], - int] = {} + self.message_types_by_id: Dict[int, Type[Message]] = {} + self.message_ids_by_type: Dict[Type[Message], int] = {} + self.response_types_by_id: Dict[int, Type[Response]] = {} + self.response_ids_by_type: Dict[Type[Response], int] = {} for m_id, m_type in message_types.items(): # Make sure only valid message types were passed and each @@ -99,30 +108,38 @@ class MessageProtocol: assert isinstance(m_id, int) assert m_id >= 0 assert (is_ioprepped_dataclass(m_type) - and issubclass(m_type, (Message, Response))) + and issubclass(m_type, Message)) assert self.message_types_by_id.get(m_id) is None - self.message_types_by_id[m_id] = m_type self.message_ids_by_type[m_type] = m_id + for r_id, r_type in response_types.items(): + assert isinstance(r_id, int) + assert r_id >= 0 + assert (is_ioprepped_dataclass(r_type) + and issubclass(r_type, Response)) + assert self.response_types_by_id.get(r_id) is None + self.response_types_by_id[r_id] = r_type + self.response_ids_by_type[r_type] = r_id + # Some extra-thorough validation in debug mode. if __debug__: - # Make sure all return types are valid and have been assigned - # an ID as well. + # Make sure all Message types' return types are valid + # and have been assigned an ID as well. all_response_types: Set[Type[Response]] = set() for m_id, m_type in message_types.items(): - if issubclass(m_type, Message): - m_rtypes = m_type.get_response_types() - assert isinstance(m_rtypes, list) - assert m_rtypes # make sure not empty - assert len(set(m_rtypes)) == len(m_rtypes) # check dups - all_response_types.update(m_rtypes) + m_rtypes = m_type.get_response_types() + assert isinstance(m_rtypes, list) + assert m_rtypes, ( + f'Message type {m_type} specifies no return types.') + assert len(set(m_rtypes)) == len(m_rtypes) # check dups + all_response_types.update(m_rtypes) for cls in all_response_types: assert is_ioprepped_dataclass(cls) and issubclass( - cls, (Message, Response)) - if cls not in self.message_ids_by_type: + cls, Response) + if cls not in self.response_ids_by_type: raise ValueError(f'Possible response type {cls}' - f' was not included in message_types.') + f' was not included in response_types.') # Make sure all registered types have unique base names. # We can take advantage of this to generate cleaner looking @@ -138,23 +155,35 @@ class MessageProtocol: self.log_remote_exceptions = log_remote_exceptions self.trusted_client = trusted_client - def message_encode(self, - message: Union[Message, Response], + def encode_message(self, + message: Message, is_error: bool = False) -> bytes: """Encode a message to bytes for transport.""" + return self._encode(message, is_error, self.message_ids_by_type, + 'message') + + def encode_response(self, + response: Response, + is_error: bool = False) -> bytes: + """Encode a response to bytes for transport.""" + return self._encode(response, is_error, self.response_ids_by_type, + 'response') + + def _encode(self, message: Any, is_error: bool, + ids_by_type: Dict[Type, int], opname: str) -> bytes: + """Encode a message to bytes for transport.""" m_id: Optional[int] if is_error: m_id = -1 else: - m_id = self.message_ids_by_type.get(type(message)) + m_id = ids_by_type.get(type(message)) if m_id is None: - raise TypeError( - f'Message/Response type is not registered in Protocol:' - f' {type(message)}') + raise TypeError(f'{opname} type is not registered in protocol:' + f' {type(message)}') msgdict = dataclass_to_dict(message) - # Encode type as part of the message dict if desired + # Encode type as part of the message/response dict if desired # (for legacy compatibility). if self._type_key is not None: if self._type_key in msgdict: @@ -166,12 +195,21 @@ class MessageProtocol: out = {'m': msgdict, 't': m_id} return json.dumps(out, separators=(',', ':')).encode() - def message_decode(self, data: bytes) -> Union[Message, Response]: - """Decode a message from bytes. + def decode_message(self, data: bytes) -> Message: + """Decode a message from bytes.""" + out = self._decode(data, self.message_types_by_id, 'message') + assert isinstance(out, Message) + return out - If the message represents a remote error, an Exception will - be raised. - """ + def decode_response(self, data: bytes) -> Response: + """Decode a response from bytes.""" + out = self._decode(data, self.response_types_by_id, 'response') + assert isinstance(out, Response) + return out + + def _decode(self, data: bytes, types_by_id: Dict[int, Type], + opname: str) -> Any: + """Decode a message from bytes.""" msgfull = json.loads(data.decode()) assert isinstance(msgfull, dict) msgdict: Optional[dict] @@ -187,21 +225,18 @@ class MessageProtocol: # Special case: a remote error occurred. Raise a local Exception. if m_id == -1: - err = dataclass_from_dict(RemoteErrorMessage, msgdict) + assert opname == 'response' + err = dataclass_from_dict(RemoteErrorResponse, msgdict) if (self.preserve_clean_errors and err.error_type is RemoteErrorType.CLEAN): raise CleanError(err.error_message) raise RemoteError(err.error_message) - # Decode this particular type and make sure its valid. - msgtype = self.message_types_by_id.get(m_id) + # Decode this particular type. + msgtype = types_by_id.get(m_id) if msgtype is None: - raise TypeError( - f'Got unregistered message/response type id of {m_id}.') - - out = dataclass_from_dict(msgtype, msgdict) - assert isinstance(out, (Message, Response)) - return out + raise TypeError(f'Got unregistered {opname} type id of {m_id}.') + return dataclass_from_dict(msgtype, msgdict) def _get_module_header(self, part: str) -> str: """Return common parts of generated modules.""" @@ -287,8 +322,8 @@ class MessageProtocol: # Ew; @overload requires at least 2 different signatures so # we need to simply write a single function if we have < 2. if len(msgtypes) == 1: - raise RuntimeError('FIXME: currently require at least 2' - ' message types.') + raise RuntimeError('FIXME: currently we require at least 2' + ' registered message types; found 1.') if len(msgtypes) > 1: for msgtype in msgtypes: msgtypevar = msgtype.__name__ @@ -345,7 +380,7 @@ class MessageProtocol: # we need to simply write a single function if we have < 2. if len(msgtypes) == 1: raise RuntimeError('FIXME: currently require at least 2' - ' message types.') + ' registered message types; found 1.') if len(msgtypes) > 1: for msgtype in msgtypes: msgtypevar = msgtype.__name__ @@ -433,9 +468,9 @@ class MessageSender: if self._send_raw_message_call is None: raise RuntimeError('send() is unimplemented for this type.') - msg_encoded = self._protocol.message_encode(message) + msg_encoded = self._protocol.encode_message(message) response_encoded = self._send_raw_message_call(bound_obj, msg_encoded) - response = self._protocol.message_decode(response_encoded) + response = self._protocol.decode_response(response_encoded) assert isinstance(response, Response) assert type(response) in type(message).get_response_types() return response @@ -572,7 +607,7 @@ class MessageReceiver: """Decode, handle, and return encoded response for a message.""" try: # Decode the incoming message. - msg_decoded = self._protocol.message_decode(msg) + msg_decoded = self._protocol.decode_message(msg) msgtype = type(msg_decoded) assert issubclass(msgtype, Message) @@ -585,26 +620,27 @@ class MessageReceiver: # Re-encode the response. assert isinstance(response, Response) assert type(response) in msgtype.get_response_types() - return self._protocol.message_encode(response) + return self._protocol.encode_response(response) except Exception as exc: if self._protocol.log_remote_exceptions: logging.exception('Error handling message.') - # If anything goes wrong, return a RemoteErrorMessage instead. + # If anything goes wrong, return a RemoteErrorResponse instead. if (isinstance(exc, CleanError) and self._protocol.preserve_clean_errors): - response = RemoteErrorMessage(error_message=str(exc), - error_type=RemoteErrorType.CLEAN) + err_response = RemoteErrorResponse( + error_message=str(exc), error_type=RemoteErrorType.CLEAN) + else: - response = RemoteErrorMessage( + err_response = RemoteErrorResponse( error_message=(traceback.format_exc() if self._protocol.trusted_client else 'An unknown error has occurred.'), error_type=RemoteErrorType.OTHER) - return self._protocol.message_encode(response, is_error=True) + return self._protocol.encode_response(err_response, is_error=True) async def handle_raw_message_async(self, msg: bytes) -> bytes: """Should be called when the receiver gets a message.