moved some bound-sender/receiver functionality to base classes instead of generated

This commit is contained in:
Eric Froemling 2021-09-21 16:50:29 -05:00
parent f8bd041588
commit 08ea737144
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
2 changed files with 60 additions and 78 deletions

View File

@ -15,7 +15,8 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver)
BoundMessageSender, MessageReceiver,
BoundMessageReceiver)
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union, Optional, Awaitable
@ -83,14 +84,9 @@ class _TestMessageSender(MessageSender):
return _BoundTestMessageSender(obj, self)
class _BoundTestMessageSender:
class _BoundTestMessageSender(BoundMessageSender):
"""Protocol-specific bound sender."""
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
assert obj is not None
self._obj = obj
self._sender = sender
@overload
def send(self, message: _TMsg1) -> _TResp1:
...
@ -123,11 +119,6 @@ class _BoundTestMessageSender:
"""Send a message asynchronously."""
return await self._sender.send_async(self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
# SEND_CODE_TEST_END
# RCVS_CODE_TEST_BEGIN
@ -172,27 +163,13 @@ class _TestSyncMessageReceiver(MessageReceiver):
return call
class _BoundTestSyncMessageReceiver:
class _BoundTestSyncMessageReceiver(BoundMessageReceiver):
"""Protocol-specific bound receiver."""
def __init__(
self,
obj: Any,
receiver: _TestSyncMessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
def handle_raw_message(self, message: str) -> str:
"""Synchronously handle a raw incoming message."""
return self._receiver.handle_raw_message(self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
# RCVS_CODE_TEST_END
# RCVA_CODE_TEST_BEGIN
@ -237,28 +214,14 @@ class _TestAsyncMessageReceiver(MessageReceiver):
return call
class _BoundTestAsyncMessageReceiver:
class _BoundTestAsyncMessageReceiver(BoundMessageReceiver):
"""Protocol-specific bound receiver."""
def __init__(
self,
obj: Any,
receiver: _TestAsyncMessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
async def handle_raw_message(self, message: str) -> str:
"""Asynchronously handle a raw incoming message."""
return await self._receiver.handle_raw_message_async(
self._obj, message)
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
# RCVA_CODE_TEST_END

View File

@ -23,7 +23,6 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
if TYPE_CHECKING:
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
Sequence, Union, Awaitable)
from efro.error import CommunicationError
TM = TypeVar('TM', bound='MessageSender')
@ -278,10 +277,12 @@ class MessageProtocol:
importlines2 += f' {line}\n'
if part == 'sender':
importlines1 = 'from efro.message import MessageSender'
importlines1 = (
'from efro.message import MessageSender, BoundMessageSender')
tpimportex = ''
else:
importlines1 = 'from efro.message import MessageReceiver'
importlines1 = ('from efro.message import MessageReceiver,'
' BoundMessageReceiver')
tpimportex = ', Awaitable'
out = ('# Released under the MIT License. See LICENSE for details.\n'
@ -335,14 +336,8 @@ class MessageProtocol:
f'(obj, self)\n'
f'\n'
f'\n'
f'class {ppre}Bound{basename}:\n'
f' """Protocol-specific bound sender."""\n'
f'\n'
f' def __init__(self, obj: Any,'
f' sender: {ppre}{basename}) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._sender = sender\n')
f'class {ppre}Bound{basename}(BoundMessageSender):\n'
f' """Protocol-specific bound sender."""\n')
# Define handler() overloads for all registered message types.
msgtypes = [
@ -386,11 +381,6 @@ class MessageProtocol:
f' """Send a message {how}."""\n'
f' return {awt}self._sender.'
f'send{sfx}(self._obj, message)\n')
out += ('\n'
' @property\n'
' def protocol(self) -> MessageProtocol:\n'
' """Protocol associated with this sender."""\n'
' return self._sender.protocol\n')
return out
@ -478,17 +468,8 @@ class MessageProtocol:
out += (f'\n'
f'\n'
f'class {ppre}Bound{basename}:\n'
f' """Protocol-specific bound receiver."""\n'
f'\n'
f' def __init__(\n'
f' self,\n'
f' obj: Any,\n'
f' receiver: {ppre}{basename},\n'
f' ) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._receiver = receiver\n')
f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
f' """Protocol-specific bound receiver."""\n')
if is_async:
out += (
'\n'
@ -506,12 +487,6 @@ class MessageProtocol:
' return self._receiver.handle_raw_message'
'(self._obj, message)\n')
out += ('\n'
' @property\n'
' def protocol(self) -> MessageProtocol:\n'
' """Protocol associated with this receiver."""\n'
' return self._receiver.protocol\n')
return out
@ -593,6 +568,20 @@ class MessageSender:
return response
class BoundMessageSender:
"""Base class for bound senders."""
def __init__(self, obj: Any, sender: MessageSender) -> None:
assert obj is not None
self._obj = obj
self._sender = sender
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
class MessageReceiver:
"""Facilitates receiving & responding to messages from a remote source.
@ -742,7 +731,8 @@ class MessageReceiver:
assert type(response) in msgtype.get_response_types()
return self.protocol.encode_response(response)
def _handle_raw_message_error(self, exc: Exception) -> str:
def raw_response_for_error(self, exc: Exception) -> str:
"""Return a raw response for an error that occurred during handling."""
if self.protocol.log_remote_exceptions:
logging.exception('Error handling message.')
@ -771,7 +761,7 @@ class MessageReceiver:
return self._encode_response(response, msgtype)
except Exception as exc:
return self._handle_raw_message_error(exc)
return self.raw_response_for_error(exc)
async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str:
"""Should be called when the receiver gets a message.
@ -788,4 +778,33 @@ class MessageReceiver:
return self._encode_response(response, msgtype)
except Exception as exc:
return self._handle_raw_message_error(exc)
return self.raw_response_for_error(exc)
class BoundMessageReceiver:
"""Base bound receiver class."""
def __init__(
self,
obj: Any,
receiver: MessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
def raw_response_for_error(self, exc: Exception) -> str:
"""Return a raw response for an error that occurred during handling.
This is automatically called from standard handle_raw_message_x()
calls but can be manually invoked if errors occur outside of there.
This gives clients a better idea of what went wrong vs simply
returning invalid data which they might dismiss as a connection
related error.
"""
return self._receiver.raw_response_for_error(exc)