moved some bound-sender/receiver functionality to base classes instead of generated

This commit is contained in:
Eric Froemling 2021-09-21 16:50:29 -05:00
parent f8bd041588
commit 08ea737144
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
2 changed files with 60 additions and 78 deletions

View File

@ -15,7 +15,8 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender, from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver) BoundMessageSender, MessageReceiver,
BoundMessageReceiver)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union, Optional, Awaitable from typing import List, Type, Any, Callable, Union, Optional, Awaitable
@ -83,14 +84,9 @@ class _TestMessageSender(MessageSender):
return _BoundTestMessageSender(obj, self) return _BoundTestMessageSender(obj, self)
class _BoundTestMessageSender: class _BoundTestMessageSender(BoundMessageSender):
"""Protocol-specific bound sender.""" """Protocol-specific bound sender."""
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
assert obj is not None
self._obj = obj
self._sender = sender
@overload @overload
def send(self, message: _TMsg1) -> _TResp1: def send(self, message: _TMsg1) -> _TResp1:
... ...
@ -123,11 +119,6 @@ class _BoundTestMessageSender:
"""Send a message asynchronously.""" """Send a message asynchronously."""
return await self._sender.send_async(self._obj, message) return await self._sender.send_async(self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
# SEND_CODE_TEST_END # SEND_CODE_TEST_END
# RCVS_CODE_TEST_BEGIN # RCVS_CODE_TEST_BEGIN
@ -172,27 +163,13 @@ class _TestSyncMessageReceiver(MessageReceiver):
return call return call
class _BoundTestSyncMessageReceiver: class _BoundTestSyncMessageReceiver(BoundMessageReceiver):
"""Protocol-specific bound receiver.""" """Protocol-specific bound receiver."""
def __init__(
self,
obj: Any,
receiver: _TestSyncMessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
def handle_raw_message(self, message: str) -> str: def handle_raw_message(self, message: str) -> str:
"""Synchronously handle a raw incoming message.""" """Synchronously handle a raw incoming message."""
return self._receiver.handle_raw_message(self._obj, message) return self._receiver.handle_raw_message(self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
# RCVS_CODE_TEST_END # RCVS_CODE_TEST_END
# RCVA_CODE_TEST_BEGIN # RCVA_CODE_TEST_BEGIN
@ -237,28 +214,14 @@ class _TestAsyncMessageReceiver(MessageReceiver):
return call return call
class _BoundTestAsyncMessageReceiver: class _BoundTestAsyncMessageReceiver(BoundMessageReceiver):
"""Protocol-specific bound receiver.""" """Protocol-specific bound receiver."""
def __init__(
self,
obj: Any,
receiver: _TestAsyncMessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
async def handle_raw_message(self, message: str) -> str: async def handle_raw_message(self, message: str) -> str:
"""Asynchronously handle a raw incoming message.""" """Asynchronously handle a raw incoming message."""
return await self._receiver.handle_raw_message_async( return await self._receiver.handle_raw_message_async(
self._obj, message) self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
# RCVA_CODE_TEST_END # RCVA_CODE_TEST_END

View File

@ -23,7 +23,6 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set, from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
Sequence, Union, Awaitable) Sequence, Union, Awaitable)
from efro.error import CommunicationError
TM = TypeVar('TM', bound='MessageSender') TM = TypeVar('TM', bound='MessageSender')
@ -278,10 +277,12 @@ class MessageProtocol:
importlines2 += f' {line}\n' importlines2 += f' {line}\n'
if part == 'sender': if part == 'sender':
importlines1 = 'from efro.message import MessageSender' importlines1 = (
'from efro.message import MessageSender, BoundMessageSender')
tpimportex = '' tpimportex = ''
else: else:
importlines1 = 'from efro.message import MessageReceiver' importlines1 = ('from efro.message import MessageReceiver,'
' BoundMessageReceiver')
tpimportex = ', Awaitable' tpimportex = ', Awaitable'
out = ('# Released under the MIT License. See LICENSE for details.\n' out = ('# Released under the MIT License. See LICENSE for details.\n'
@ -335,14 +336,8 @@ class MessageProtocol:
f'(obj, self)\n' f'(obj, self)\n'
f'\n' f'\n'
f'\n' f'\n'
f'class {ppre}Bound{basename}:\n' f'class {ppre}Bound{basename}(BoundMessageSender):\n'
f' """Protocol-specific bound sender."""\n' f' """Protocol-specific bound sender."""\n')
f'\n'
f' def __init__(self, obj: Any,'
f' sender: {ppre}{basename}) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._sender = sender\n')
# Define handler() overloads for all registered message types. # Define handler() overloads for all registered message types.
msgtypes = [ msgtypes = [
@ -386,11 +381,6 @@ class MessageProtocol:
f' """Send a message {how}."""\n' f' """Send a message {how}."""\n'
f' return {awt}self._sender.' f' return {awt}self._sender.'
f'send{sfx}(self._obj, message)\n') f'send{sfx}(self._obj, message)\n')
out += ('\n'
' @property\n'
' def protocol(self) -> MessageProtocol:\n'
' """Protocol associated with this sender."""\n'
' return self._sender.protocol\n')
return out return out
@ -478,17 +468,8 @@ class MessageProtocol:
out += (f'\n' out += (f'\n'
f'\n' f'\n'
f'class {ppre}Bound{basename}:\n' f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
f' """Protocol-specific bound receiver."""\n' f' """Protocol-specific bound receiver."""\n')
f'\n'
f' def __init__(\n'
f' self,\n'
f' obj: Any,\n'
f' receiver: {ppre}{basename},\n'
f' ) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._receiver = receiver\n')
if is_async: if is_async:
out += ( out += (
'\n' '\n'
@ -506,12 +487,6 @@ class MessageProtocol:
' return self._receiver.handle_raw_message' ' return self._receiver.handle_raw_message'
'(self._obj, message)\n') '(self._obj, message)\n')
out += ('\n'
' @property\n'
' def protocol(self) -> MessageProtocol:\n'
' """Protocol associated with this receiver."""\n'
' return self._receiver.protocol\n')
return out return out
@ -593,6 +568,20 @@ class MessageSender:
return response return response
class BoundMessageSender:
"""Base class for bound senders."""
def __init__(self, obj: Any, sender: MessageSender) -> None:
assert obj is not None
self._obj = obj
self._sender = sender
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
class MessageReceiver: class MessageReceiver:
"""Facilitates receiving & responding to messages from a remote source. """Facilitates receiving & responding to messages from a remote source.
@ -742,7 +731,8 @@ class MessageReceiver:
assert type(response) in msgtype.get_response_types() assert type(response) in msgtype.get_response_types()
return self.protocol.encode_response(response) return self.protocol.encode_response(response)
def _handle_raw_message_error(self, exc: Exception) -> str: def raw_response_for_error(self, exc: Exception) -> str:
"""Return a raw response for an error that occurred during handling."""
if self.protocol.log_remote_exceptions: if self.protocol.log_remote_exceptions:
logging.exception('Error handling message.') logging.exception('Error handling message.')
@ -771,7 +761,7 @@ class MessageReceiver:
return self._encode_response(response, msgtype) return self._encode_response(response, msgtype)
except Exception as exc: except Exception as exc:
return self._handle_raw_message_error(exc) return self.raw_response_for_error(exc)
async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str: async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str:
"""Should be called when the receiver gets a message. """Should be called when the receiver gets a message.
@ -788,4 +778,33 @@ class MessageReceiver:
return self._encode_response(response, msgtype) return self._encode_response(response, msgtype)
except Exception as exc: except Exception as exc:
return self._handle_raw_message_error(exc) return self.raw_response_for_error(exc)
class BoundMessageReceiver:
"""Base bound receiver class."""
def __init__(
self,
obj: Any,
receiver: MessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
def raw_response_for_error(self, exc: Exception) -> str:
"""Return a raw response for an error that occurred during handling.
This is automatically called from standard handle_raw_message_x()
calls but can be manually invoked if errors occur outside of there.
This gives clients a better idea of what went wrong vs simply
returning invalid data which they might dismiss as a connection
related error.
"""
return self._receiver.raw_response_for_error(exc)