From 08ea73714411538b7f234cebe1d50e8c7bb20e1a Mon Sep 17 00:00:00 2001 From: Eric Froemling Date: Tue, 21 Sep 2021 16:50:29 -0500 Subject: [PATCH] moved some bound-sender/receiver functionality to base classes instead of generated --- tests/test_efro/test_message.py | 47 ++--------------- tools/efro/message.py | 91 ++++++++++++++++++++------------- 2 files changed, 60 insertions(+), 78 deletions(-) diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index ceddb6d4..1790b1f1 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -15,7 +15,8 @@ from efrotools.statictest import static_type_equals from efro.error import CleanError, RemoteError from efro.dataclassio import ioprepped from efro.message import (Message, Response, MessageProtocol, MessageSender, - MessageReceiver) + BoundMessageSender, MessageReceiver, + BoundMessageReceiver) if TYPE_CHECKING: from typing import List, Type, Any, Callable, Union, Optional, Awaitable @@ -83,14 +84,9 @@ class _TestMessageSender(MessageSender): return _BoundTestMessageSender(obj, self) -class _BoundTestMessageSender: +class _BoundTestMessageSender(BoundMessageSender): """Protocol-specific bound sender.""" - def __init__(self, obj: Any, sender: _TestMessageSender) -> None: - assert obj is not None - self._obj = obj - self._sender = sender - @overload def send(self, message: _TMsg1) -> _TResp1: ... @@ -123,11 +119,6 @@ class _BoundTestMessageSender: """Send a message asynchronously.""" return await self._sender.send_async(self._obj, message) - @property - def protocol(self) -> MessageProtocol: - """Protocol associated with this sender.""" - return self._sender.protocol - # SEND_CODE_TEST_END # RCVS_CODE_TEST_BEGIN @@ -172,27 +163,13 @@ class _TestSyncMessageReceiver(MessageReceiver): return call -class _BoundTestSyncMessageReceiver: +class _BoundTestSyncMessageReceiver(BoundMessageReceiver): """Protocol-specific bound receiver.""" - def __init__( - self, - obj: Any, - receiver: _TestSyncMessageReceiver, - ) -> None: - assert obj is not None - self._obj = obj - self._receiver = receiver - def handle_raw_message(self, message: str) -> str: """Synchronously handle a raw incoming message.""" return self._receiver.handle_raw_message(self._obj, message) - @property - def protocol(self) -> MessageProtocol: - """Protocol associated with this receiver.""" - return self._receiver.protocol - # RCVS_CODE_TEST_END # RCVA_CODE_TEST_BEGIN @@ -237,28 +214,14 @@ class _TestAsyncMessageReceiver(MessageReceiver): return call -class _BoundTestAsyncMessageReceiver: +class _BoundTestAsyncMessageReceiver(BoundMessageReceiver): """Protocol-specific bound receiver.""" - def __init__( - self, - obj: Any, - receiver: _TestAsyncMessageReceiver, - ) -> None: - assert obj is not None - self._obj = obj - self._receiver = receiver - async def handle_raw_message(self, message: str) -> str: """Asynchronously handle a raw incoming message.""" return await self._receiver.handle_raw_message_async( self._obj, message) - @property - def protocol(self) -> MessageProtocol: - """Protocol associated with this receiver.""" - return self._receiver.protocol - # RCVA_CODE_TEST_END diff --git a/tools/efro/message.py b/tools/efro/message.py index d5f47d21..f7b1a57f 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -23,7 +23,6 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs, if TYPE_CHECKING: from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set, Sequence, Union, Awaitable) - from efro.error import CommunicationError TM = TypeVar('TM', bound='MessageSender') @@ -278,10 +277,12 @@ class MessageProtocol: importlines2 += f' {line}\n' if part == 'sender': - importlines1 = 'from efro.message import MessageSender' + importlines1 = ( + 'from efro.message import MessageSender, BoundMessageSender') tpimportex = '' else: - importlines1 = 'from efro.message import MessageReceiver' + importlines1 = ('from efro.message import MessageReceiver,' + ' BoundMessageReceiver') tpimportex = ', Awaitable' out = ('# Released under the MIT License. See LICENSE for details.\n' @@ -335,14 +336,8 @@ class MessageProtocol: f'(obj, self)\n' f'\n' f'\n' - f'class {ppre}Bound{basename}:\n' - f' """Protocol-specific bound sender."""\n' - f'\n' - f' def __init__(self, obj: Any,' - f' sender: {ppre}{basename}) -> None:\n' - f' assert obj is not None\n' - f' self._obj = obj\n' - f' self._sender = sender\n') + f'class {ppre}Bound{basename}(BoundMessageSender):\n' + f' """Protocol-specific bound sender."""\n') # Define handler() overloads for all registered message types. msgtypes = [ @@ -386,11 +381,6 @@ class MessageProtocol: f' """Send a message {how}."""\n' f' return {awt}self._sender.' f'send{sfx}(self._obj, message)\n') - out += ('\n' - ' @property\n' - ' def protocol(self) -> MessageProtocol:\n' - ' """Protocol associated with this sender."""\n' - ' return self._sender.protocol\n') return out @@ -478,17 +468,8 @@ class MessageProtocol: out += (f'\n' f'\n' - f'class {ppre}Bound{basename}:\n' - f' """Protocol-specific bound receiver."""\n' - f'\n' - f' def __init__(\n' - f' self,\n' - f' obj: Any,\n' - f' receiver: {ppre}{basename},\n' - f' ) -> None:\n' - f' assert obj is not None\n' - f' self._obj = obj\n' - f' self._receiver = receiver\n') + f'class {ppre}Bound{basename}(BoundMessageReceiver):\n' + f' """Protocol-specific bound receiver."""\n') if is_async: out += ( '\n' @@ -506,12 +487,6 @@ class MessageProtocol: ' return self._receiver.handle_raw_message' '(self._obj, message)\n') - out += ('\n' - ' @property\n' - ' def protocol(self) -> MessageProtocol:\n' - ' """Protocol associated with this receiver."""\n' - ' return self._receiver.protocol\n') - return out @@ -593,6 +568,20 @@ class MessageSender: return response +class BoundMessageSender: + """Base class for bound senders.""" + + def __init__(self, obj: Any, sender: MessageSender) -> None: + assert obj is not None + self._obj = obj + self._sender = sender + + @property + def protocol(self) -> MessageProtocol: + """Protocol associated with this sender.""" + return self._sender.protocol + + class MessageReceiver: """Facilitates receiving & responding to messages from a remote source. @@ -742,7 +731,8 @@ class MessageReceiver: assert type(response) in msgtype.get_response_types() return self.protocol.encode_response(response) - def _handle_raw_message_error(self, exc: Exception) -> str: + def raw_response_for_error(self, exc: Exception) -> str: + """Return a raw response for an error that occurred during handling.""" if self.protocol.log_remote_exceptions: logging.exception('Error handling message.') @@ -771,7 +761,7 @@ class MessageReceiver: return self._encode_response(response, msgtype) except Exception as exc: - return self._handle_raw_message_error(exc) + return self.raw_response_for_error(exc) async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str: """Should be called when the receiver gets a message. @@ -788,4 +778,33 @@ class MessageReceiver: return self._encode_response(response, msgtype) except Exception as exc: - return self._handle_raw_message_error(exc) + return self.raw_response_for_error(exc) + + +class BoundMessageReceiver: + """Base bound receiver class.""" + + def __init__( + self, + obj: Any, + receiver: MessageReceiver, + ) -> None: + assert obj is not None + self._obj = obj + self._receiver = receiver + + @property + def protocol(self) -> MessageProtocol: + """Protocol associated with this receiver.""" + return self._receiver.protocol + + def raw_response_for_error(self, exc: Exception) -> str: + """Return a raw response for an error that occurred during handling. + + This is automatically called from standard handle_raw_message_x() + calls but can be manually invoked if errors occur outside of there. + This gives clients a better idea of what went wrong vs simply + returning invalid data which they might dismiss as a connection + related error. + """ + return self._receiver.raw_response_for_error(exc)