diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index 2a24a008..ceddb6d4 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -123,6 +123,11 @@ class _BoundTestMessageSender: """Send a message asynchronously.""" return await self._sender.send_async(self._obj, message) + @property + def protocol(self) -> MessageProtocol: + """Protocol associated with this sender.""" + return self._sender.protocol + # SEND_CODE_TEST_END # RCVS_CODE_TEST_BEGIN @@ -183,6 +188,11 @@ class _BoundTestSyncMessageReceiver: """Synchronously handle a raw incoming message.""" return self._receiver.handle_raw_message(self._obj, message) + @property + def protocol(self) -> MessageProtocol: + """Protocol associated with this receiver.""" + return self._receiver.protocol + # RCVS_CODE_TEST_END # RCVA_CODE_TEST_BEGIN @@ -244,6 +254,11 @@ class _BoundTestAsyncMessageReceiver: return await self._receiver.handle_raw_message_async( self._obj, message) + @property + def protocol(self) -> MessageProtocol: + """Protocol associated with this receiver.""" + return self._receiver.protocol + # RCVA_CODE_TEST_END diff --git a/tools/efro/message.py b/tools/efro/message.py index 848f74ec..d5f47d21 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -386,6 +386,11 @@ class MessageProtocol: f' """Send a message {how}."""\n' f' return {awt}self._sender.' f'send{sfx}(self._obj, message)\n') + out += ('\n' + ' @property\n' + ' def protocol(self) -> MessageProtocol:\n' + ' """Protocol associated with this sender."""\n' + ' return self._sender.protocol\n') return out @@ -501,6 +506,12 @@ class MessageProtocol: ' return self._receiver.handle_raw_message' '(self._obj, message)\n') + out += ('\n' + ' @property\n' + ' def protocol(self) -> MessageProtocol:\n' + ' """Protocol associated with this receiver."""\n' + ' return self._receiver.protocol\n') + return out @@ -525,7 +536,7 @@ class MessageSender: """ def __init__(self, protocol: MessageProtocol) -> None: - self._protocol = protocol + self.protocol = protocol self._send_raw_message_call: Optional[Callable[[Any, str], str]] = None self._send_async_raw_message_call: Optional[Callable[ [Any, str], Awaitable[str]]] = None @@ -554,9 +565,9 @@ class MessageSender: if self._send_raw_message_call is None: raise RuntimeError('send() is unimplemented for this type.') - msg_encoded = self._protocol.encode_message(message) + msg_encoded = self.protocol.encode_message(message) response_encoded = self._send_raw_message_call(bound_obj, msg_encoded) - response = self._protocol.decode_response(response_encoded) + response = self.protocol.decode_response(response_encoded) assert isinstance(response, (Response, type(None))) assert (response is None or type(response) in type(message).get_response_types()) @@ -572,10 +583,10 @@ class MessageSender: if self._send_async_raw_message_call is None: raise RuntimeError('send_async() is unimplemented for this type.') - msg_encoded = self._protocol.encode_message(message) + msg_encoded = self.protocol.encode_message(message) response_encoded = await self._send_async_raw_message_call( bound_obj, msg_encoded) - response = self._protocol.decode_response(response_encoded) + response = self.protocol.decode_response(response_encoded) assert isinstance(response, (Response, type(None))) assert (response is None or type(response) in type(message).get_response_types()) @@ -610,7 +621,7 @@ class MessageReceiver: is_async = False def __init__(self, protocol: MessageProtocol) -> None: - self._protocol = protocol + self.protocol = protocol self._handlers: Dict[Type[Message], Callable] = {} # noinspection PyProtectedMember @@ -677,7 +688,7 @@ class MessageReceiver: # return types exactly match. (Technically we could return a subset # of the supported types; can allow this in the future if it makes # sense). - registered_types = self._protocol.message_ids_by_type.keys() + registered_types = self.protocol.message_ids_by_type.keys() if msgtype not in registered_types: raise TypeError(f'Message type {msgtype} is not registered' @@ -699,7 +710,7 @@ class MessageReceiver: def validate(self, warn_only: bool = False) -> None: """Check for handler completeness, valid types, etc.""" - for msgtype in self._protocol.message_ids_by_type.keys(): + for msgtype in self.protocol.message_ids_by_type.keys(): if issubclass(msgtype, Response): continue if msgtype not in self._handlers: @@ -712,7 +723,7 @@ class MessageReceiver: def _decode_incoming_message(self, msg: str) -> Tuple[Message, Type[Message]]: # Decode the incoming message. - msg_decoded = self._protocol.decode_message(msg) + msg_decoded = self.protocol.decode_message(msg) msgtype = type(msg_decoded) assert issubclass(msgtype, Message) return msg_decoded, msgtype @@ -729,24 +740,24 @@ class MessageReceiver: # (user should never explicitly return these) assert not isinstance(response, ErrorResponse) assert type(response) in msgtype.get_response_types() - return self._protocol.encode_response(response) + return self.protocol.encode_response(response) def _handle_raw_message_error(self, exc: Exception) -> str: - if self._protocol.log_remote_exceptions: + if self.protocol.log_remote_exceptions: logging.exception('Error handling message.') # If anything goes wrong, return a ErrorResponse instead. if (isinstance(exc, CleanError) - and self._protocol.preserve_clean_errors): + and self.protocol.preserve_clean_errors): err_response = ErrorResponse(error_message=str(exc), error_type=ErrorType.CLEAN) else: err_response = ErrorResponse( error_message=(traceback.format_exc() - if self._protocol.trusted_sender else + if self.protocol.trusted_sender else 'An unknown error has occurred.'), error_type=ErrorType.OTHER) - return self._protocol.encode_response(err_response) + return self.protocol.encode_response(err_response) def handle_raw_message(self, bound_obj: Any, msg: str) -> str: """Decode, handle, and return an response for a message."""