auto-registering more common response types

This commit is contained in:
Eric Froemling 2021-09-09 10:52:30 -05:00
parent 34df4658a6
commit cdb602b921
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
4 changed files with 26 additions and 15 deletions

View File

@ -1877,6 +1877,7 @@
<w>recv</w> <w>recv</w>
<w>redist</w> <w>redist</w>
<w>redistributables</w> <w>redistributables</w>
<w>regtp</w>
<w>relpath</w> <w>relpath</w>
<w>remainingchecks</w> <w>remainingchecks</w>
<w>remoteapp</w> <w>remoteapp</w>

View File

@ -870,6 +870,7 @@
<w>redundants</w> <w>redundants</w>
<w>refcounted</w> <w>refcounted</w>
<w>refl</w> <w>refl</w>
<w>regtp</w>
<w>rehel</w> <w>rehel</w>
<w>reloadmedia</w> <w>reloadmedia</w>
<w>rendererdata</w> <w>rendererdata</w>

View File

@ -14,7 +14,7 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender, from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver, EmptyResponse) MessageReceiver)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union, Optional from typing import List, Type, Any, Callable, Union, Optional
@ -176,9 +176,8 @@ TEST_PROTOCOL = MessageProtocol(
response_types={ response_types={
0: _TResponse1, 0: _TResponse1,
1: _TResponse2, 1: _TResponse2,
2: EmptyResponse,
}, },
trusted_client=True, trusted_sender=True,
log_remote_exceptions=False, log_remote_exceptions=False,
) )

View File

@ -88,9 +88,13 @@ class MessageProtocol:
type_key: Optional[str] = None, type_key: Optional[str] = None,
preserve_clean_errors: bool = True, preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True, log_remote_exceptions: bool = True,
trusted_client: bool = False) -> None: trusted_sender: bool = False) -> None:
"""Create a protocol with a given configuration. """Create a protocol with a given configuration.
Note that common response types are automatically registered
with (unchanging negative ids) so they don't need to be passed
explicitly (but can be if a different id is desired).
If 'type_key' is provided, the message type ID is stored as the If 'type_key' is provided, the message type ID is stored as the
provided key in the message dict; otherwise it will be stored as provided key in the message dict; otherwise it will be stored as
part of a top level dict with the message payload appearing as a part of a top level dict with the message payload appearing as a
@ -100,10 +104,10 @@ class MessageProtocol:
on the remote end will result in the same error raised locally. on the remote end will result in the same error raised locally.
All other Exception types come across as efro.error.RemoteError. All other Exception types come across as efro.error.RemoteError.
If 'trusted_client' is True, stringified remote stack traces will If 'trusted_sender' is True, stringified remote stack traces will
be included in the RemoteError. This should only be enabled in cases be included in the responses if errors occur.
where the client is trusted.
""" """
# pylint: disable=too-many-locals
self.message_types_by_id: Dict[int, Type[Message]] = {} self.message_types_by_id: Dict[int, Type[Message]] = {}
self.message_ids_by_type: Dict[Type[Message], int] = {} self.message_ids_by_type: Dict[Type[Message], int] = {}
self.response_types_by_id: Dict[int, Type[Response]] = {} self.response_types_by_id: Dict[int, Type[Response]] = {}
@ -129,12 +133,18 @@ class MessageProtocol:
self.response_types_by_id[r_id] = r_type self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id self.response_ids_by_type[r_type] = r_id
# If they didn't register ErrorResponse, do so with a special # Go ahead and auto-register a few common response types
# -1 id which ensures it will never conflict with user messages. # if the user has not done so explicitly. Use unique IDs which
if ErrorResponse not in self.response_ids_by_type: # will never change or overlap with user ids.
assert self.response_types_by_id.get(-1) is None def _reg(reg_tp: Type[Response], reg_id: int) -> None:
self.response_types_by_id[-1] = ErrorResponse if reg_tp in self.response_ids_by_type:
self.response_ids_by_type[ErrorResponse] = -1 return
assert self.response_types_by_id.get(reg_id) is None
self.response_types_by_id[reg_id] = reg_tp
self.response_ids_by_type[reg_tp] = reg_id
_reg(ErrorResponse, -1)
_reg(EmptyResponse, -2)
# Some extra-thorough validation in debug mode. # Some extra-thorough validation in debug mode.
if __debug__: if __debug__:
@ -167,7 +177,7 @@ class MessageProtocol:
self._type_key = type_key self._type_key = type_key
self.preserve_clean_errors = preserve_clean_errors self.preserve_clean_errors = preserve_clean_errors
self.log_remote_exceptions = log_remote_exceptions self.log_remote_exceptions = log_remote_exceptions
self.trusted_client = trusted_client self.trusted_sender = trusted_sender
def encode_message(self, message: Message) -> bytes: def encode_message(self, message: Message) -> bytes:
"""Encode a message to bytes for transport.""" """Encode a message to bytes for transport."""
@ -671,7 +681,7 @@ class MessageReceiver:
err_response = ErrorResponse( err_response = ErrorResponse(
error_message=(traceback.format_exc() error_message=(traceback.format_exc()
if self._protocol.trusted_client else if self._protocol.trusted_sender else
'An unknown error has occurred.'), 'An unknown error has occurred.'),
error_type=ErrorType.OTHER) error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response) return self._protocol.encode_response(err_response)