added protocol attr to bound message senders/receivers

This commit is contained in:
Eric Froemling 2021-09-21 16:00:04 -05:00
parent 0f77fe19ce
commit f8bd041588
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
2 changed files with 40 additions and 14 deletions

View File

@ -123,6 +123,11 @@ class _BoundTestMessageSender:
"""Send a message asynchronously."""
return await self._sender.send_async(self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
# SEND_CODE_TEST_END
# RCVS_CODE_TEST_BEGIN
@ -183,6 +188,11 @@ class _BoundTestSyncMessageReceiver:
"""Synchronously handle a raw incoming message."""
return self._receiver.handle_raw_message(self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
# RCVS_CODE_TEST_END
# RCVA_CODE_TEST_BEGIN
@ -244,6 +254,11 @@ class _BoundTestAsyncMessageReceiver:
return await self._receiver.handle_raw_message_async(
self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
# RCVA_CODE_TEST_END

View File

@ -386,6 +386,11 @@ class MessageProtocol:
f' """Send a message {how}."""\n'
f' return {awt}self._sender.'
f'send{sfx}(self._obj, message)\n')
out += ('\n'
' @property\n'
' def protocol(self) -> MessageProtocol:\n'
' """Protocol associated with this sender."""\n'
' return self._sender.protocol\n')
return out
@ -501,6 +506,12 @@ class MessageProtocol:
' return self._receiver.handle_raw_message'
'(self._obj, message)\n')
out += ('\n'
' @property\n'
' def protocol(self) -> MessageProtocol:\n'
' """Protocol associated with this receiver."""\n'
' return self._receiver.protocol\n')
return out
@ -525,7 +536,7 @@ class MessageSender:
"""
def __init__(self, protocol: MessageProtocol) -> None:
self._protocol = protocol
self.protocol = protocol
self._send_raw_message_call: Optional[Callable[[Any, str], str]] = None
self._send_async_raw_message_call: Optional[Callable[
[Any, str], Awaitable[str]]] = None
@ -554,9 +565,9 @@ class MessageSender:
if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.')
msg_encoded = self._protocol.encode_message(message)
msg_encoded = self.protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.decode_response(response_encoded)
response = self.protocol.decode_response(response_encoded)
assert isinstance(response, (Response, type(None)))
assert (response is None
or type(response) in type(message).get_response_types())
@ -572,10 +583,10 @@ class MessageSender:
if self._send_async_raw_message_call is None:
raise RuntimeError('send_async() is unimplemented for this type.')
msg_encoded = self._protocol.encode_message(message)
msg_encoded = self.protocol.encode_message(message)
response_encoded = await self._send_async_raw_message_call(
bound_obj, msg_encoded)
response = self._protocol.decode_response(response_encoded)
response = self.protocol.decode_response(response_encoded)
assert isinstance(response, (Response, type(None)))
assert (response is None
or type(response) in type(message).get_response_types())
@ -610,7 +621,7 @@ class MessageReceiver:
is_async = False
def __init__(self, protocol: MessageProtocol) -> None:
self._protocol = protocol
self.protocol = protocol
self._handlers: Dict[Type[Message], Callable] = {}
# noinspection PyProtectedMember
@ -677,7 +688,7 @@ class MessageReceiver:
# return types exactly match. (Technically we could return a subset
# of the supported types; can allow this in the future if it makes
# sense).
registered_types = self._protocol.message_ids_by_type.keys()
registered_types = self.protocol.message_ids_by_type.keys()
if msgtype not in registered_types:
raise TypeError(f'Message type {msgtype} is not registered'
@ -699,7 +710,7 @@ class MessageReceiver:
def validate(self, warn_only: bool = False) -> None:
"""Check for handler completeness, valid types, etc."""
for msgtype in self._protocol.message_ids_by_type.keys():
for msgtype in self.protocol.message_ids_by_type.keys():
if issubclass(msgtype, Response):
continue
if msgtype not in self._handlers:
@ -712,7 +723,7 @@ class MessageReceiver:
def _decode_incoming_message(self,
msg: str) -> Tuple[Message, Type[Message]]:
# Decode the incoming message.
msg_decoded = self._protocol.decode_message(msg)
msg_decoded = self.protocol.decode_message(msg)
msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
return msg_decoded, msgtype
@ -729,24 +740,24 @@ class MessageReceiver:
# (user should never explicitly return these)
assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types()
return self._protocol.encode_response(response)
return self.protocol.encode_response(response)
def _handle_raw_message_error(self, exc: Exception) -> str:
if self._protocol.log_remote_exceptions:
if self.protocol.log_remote_exceptions:
logging.exception('Error handling message.')
# If anything goes wrong, return a ErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
and self.protocol.preserve_clean_errors):
err_response = ErrorResponse(error_message=str(exc),
error_type=ErrorType.CLEAN)
else:
err_response = ErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_sender else
if self.protocol.trusted_sender else
'An unknown error has occurred.'),
error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response)
return self.protocol.encode_response(err_response)
def handle_raw_message(self, bound_obj: Any, msg: str) -> str:
"""Decode, handle, and return an response for a message."""