mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-26 17:03:14 +08:00
moved some bound-sender/receiver functionality to base classes instead of generated
This commit is contained in:
parent
f8bd041588
commit
08ea737144
@ -15,7 +15,8 @@ from efrotools.statictest import static_type_equals
|
||||
from efro.error import CleanError, RemoteError
|
||||
from efro.dataclassio import ioprepped
|
||||
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||
MessageReceiver)
|
||||
BoundMessageSender, MessageReceiver,
|
||||
BoundMessageReceiver)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import List, Type, Any, Callable, Union, Optional, Awaitable
|
||||
@ -83,14 +84,9 @@ class _TestMessageSender(MessageSender):
|
||||
return _BoundTestMessageSender(obj, self)
|
||||
|
||||
|
||||
class _BoundTestMessageSender:
|
||||
class _BoundTestMessageSender(BoundMessageSender):
|
||||
"""Protocol-specific bound sender."""
|
||||
|
||||
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._sender = sender
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMsg1) -> _TResp1:
|
||||
...
|
||||
@ -123,11 +119,6 @@ class _BoundTestMessageSender:
|
||||
"""Send a message asynchronously."""
|
||||
return await self._sender.send_async(self._obj, message)
|
||||
|
||||
@property
|
||||
def protocol(self) -> MessageProtocol:
|
||||
"""Protocol associated with this sender."""
|
||||
return self._sender.protocol
|
||||
|
||||
|
||||
# SEND_CODE_TEST_END
|
||||
# RCVS_CODE_TEST_BEGIN
|
||||
@ -172,27 +163,13 @@ class _TestSyncMessageReceiver(MessageReceiver):
|
||||
return call
|
||||
|
||||
|
||||
class _BoundTestSyncMessageReceiver:
|
||||
class _BoundTestSyncMessageReceiver(BoundMessageReceiver):
|
||||
"""Protocol-specific bound receiver."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Any,
|
||||
receiver: _TestSyncMessageReceiver,
|
||||
) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._receiver = receiver
|
||||
|
||||
def handle_raw_message(self, message: str) -> str:
|
||||
"""Synchronously handle a raw incoming message."""
|
||||
return self._receiver.handle_raw_message(self._obj, message)
|
||||
|
||||
@property
|
||||
def protocol(self) -> MessageProtocol:
|
||||
"""Protocol associated with this receiver."""
|
||||
return self._receiver.protocol
|
||||
|
||||
|
||||
# RCVS_CODE_TEST_END
|
||||
# RCVA_CODE_TEST_BEGIN
|
||||
@ -237,28 +214,14 @@ class _TestAsyncMessageReceiver(MessageReceiver):
|
||||
return call
|
||||
|
||||
|
||||
class _BoundTestAsyncMessageReceiver:
|
||||
class _BoundTestAsyncMessageReceiver(BoundMessageReceiver):
|
||||
"""Protocol-specific bound receiver."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Any,
|
||||
receiver: _TestAsyncMessageReceiver,
|
||||
) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._receiver = receiver
|
||||
|
||||
async def handle_raw_message(self, message: str) -> str:
|
||||
"""Asynchronously handle a raw incoming message."""
|
||||
return await self._receiver.handle_raw_message_async(
|
||||
self._obj, message)
|
||||
|
||||
@property
|
||||
def protocol(self) -> MessageProtocol:
|
||||
"""Protocol associated with this receiver."""
|
||||
return self._receiver.protocol
|
||||
|
||||
|
||||
# RCVA_CODE_TEST_END
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
|
||||
if TYPE_CHECKING:
|
||||
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
||||
Sequence, Union, Awaitable)
|
||||
from efro.error import CommunicationError
|
||||
|
||||
TM = TypeVar('TM', bound='MessageSender')
|
||||
|
||||
@ -278,10 +277,12 @@ class MessageProtocol:
|
||||
importlines2 += f' {line}\n'
|
||||
|
||||
if part == 'sender':
|
||||
importlines1 = 'from efro.message import MessageSender'
|
||||
importlines1 = (
|
||||
'from efro.message import MessageSender, BoundMessageSender')
|
||||
tpimportex = ''
|
||||
else:
|
||||
importlines1 = 'from efro.message import MessageReceiver'
|
||||
importlines1 = ('from efro.message import MessageReceiver,'
|
||||
' BoundMessageReceiver')
|
||||
tpimportex = ', Awaitable'
|
||||
|
||||
out = ('# Released under the MIT License. See LICENSE for details.\n'
|
||||
@ -335,14 +336,8 @@ class MessageProtocol:
|
||||
f'(obj, self)\n'
|
||||
f'\n'
|
||||
f'\n'
|
||||
f'class {ppre}Bound{basename}:\n'
|
||||
f' """Protocol-specific bound sender."""\n'
|
||||
f'\n'
|
||||
f' def __init__(self, obj: Any,'
|
||||
f' sender: {ppre}{basename}) -> None:\n'
|
||||
f' assert obj is not None\n'
|
||||
f' self._obj = obj\n'
|
||||
f' self._sender = sender\n')
|
||||
f'class {ppre}Bound{basename}(BoundMessageSender):\n'
|
||||
f' """Protocol-specific bound sender."""\n')
|
||||
|
||||
# Define handler() overloads for all registered message types.
|
||||
msgtypes = [
|
||||
@ -386,11 +381,6 @@ class MessageProtocol:
|
||||
f' """Send a message {how}."""\n'
|
||||
f' return {awt}self._sender.'
|
||||
f'send{sfx}(self._obj, message)\n')
|
||||
out += ('\n'
|
||||
' @property\n'
|
||||
' def protocol(self) -> MessageProtocol:\n'
|
||||
' """Protocol associated with this sender."""\n'
|
||||
' return self._sender.protocol\n')
|
||||
|
||||
return out
|
||||
|
||||
@ -478,17 +468,8 @@ class MessageProtocol:
|
||||
|
||||
out += (f'\n'
|
||||
f'\n'
|
||||
f'class {ppre}Bound{basename}:\n'
|
||||
f' """Protocol-specific bound receiver."""\n'
|
||||
f'\n'
|
||||
f' def __init__(\n'
|
||||
f' self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' receiver: {ppre}{basename},\n'
|
||||
f' ) -> None:\n'
|
||||
f' assert obj is not None\n'
|
||||
f' self._obj = obj\n'
|
||||
f' self._receiver = receiver\n')
|
||||
f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
|
||||
f' """Protocol-specific bound receiver."""\n')
|
||||
if is_async:
|
||||
out += (
|
||||
'\n'
|
||||
@ -506,12 +487,6 @@ class MessageProtocol:
|
||||
' return self._receiver.handle_raw_message'
|
||||
'(self._obj, message)\n')
|
||||
|
||||
out += ('\n'
|
||||
' @property\n'
|
||||
' def protocol(self) -> MessageProtocol:\n'
|
||||
' """Protocol associated with this receiver."""\n'
|
||||
' return self._receiver.protocol\n')
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -593,6 +568,20 @@ class MessageSender:
|
||||
return response
|
||||
|
||||
|
||||
class BoundMessageSender:
|
||||
"""Base class for bound senders."""
|
||||
|
||||
def __init__(self, obj: Any, sender: MessageSender) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._sender = sender
|
||||
|
||||
@property
|
||||
def protocol(self) -> MessageProtocol:
|
||||
"""Protocol associated with this sender."""
|
||||
return self._sender.protocol
|
||||
|
||||
|
||||
class MessageReceiver:
|
||||
"""Facilitates receiving & responding to messages from a remote source.
|
||||
|
||||
@ -742,7 +731,8 @@ class MessageReceiver:
|
||||
assert type(response) in msgtype.get_response_types()
|
||||
return self.protocol.encode_response(response)
|
||||
|
||||
def _handle_raw_message_error(self, exc: Exception) -> str:
|
||||
def raw_response_for_error(self, exc: Exception) -> str:
|
||||
"""Return a raw response for an error that occurred during handling."""
|
||||
if self.protocol.log_remote_exceptions:
|
||||
logging.exception('Error handling message.')
|
||||
|
||||
@ -771,7 +761,7 @@ class MessageReceiver:
|
||||
return self._encode_response(response, msgtype)
|
||||
|
||||
except Exception as exc:
|
||||
return self._handle_raw_message_error(exc)
|
||||
return self.raw_response_for_error(exc)
|
||||
|
||||
async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str:
|
||||
"""Should be called when the receiver gets a message.
|
||||
@ -788,4 +778,33 @@ class MessageReceiver:
|
||||
return self._encode_response(response, msgtype)
|
||||
|
||||
except Exception as exc:
|
||||
return self._handle_raw_message_error(exc)
|
||||
return self.raw_response_for_error(exc)
|
||||
|
||||
|
||||
class BoundMessageReceiver:
|
||||
"""Base bound receiver class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Any,
|
||||
receiver: MessageReceiver,
|
||||
) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._receiver = receiver
|
||||
|
||||
@property
|
||||
def protocol(self) -> MessageProtocol:
|
||||
"""Protocol associated with this receiver."""
|
||||
return self._receiver.protocol
|
||||
|
||||
def raw_response_for_error(self, exc: Exception) -> str:
|
||||
"""Return a raw response for an error that occurred during handling.
|
||||
|
||||
This is automatically called from standard handle_raw_message_x()
|
||||
calls but can be manually invoked if errors occur outside of there.
|
||||
This gives clients a better idea of what went wrong vs simply
|
||||
returning invalid data which they might dismiss as a connection
|
||||
related error.
|
||||
"""
|
||||
return self._receiver.raw_response_for_error(exc)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user