diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml
index 55028a7a..8b5ad5e0 100644
--- a/.idea/dictionaries/ericf.xml
+++ b/.idea/dictionaries/ericf.xml
@@ -1877,6 +1877,7 @@
recv
redist
redistributables
+ regtp
relpath
remainingchecks
remoteapp
diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
index 00f8f13b..e8d994ab 100644
--- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml
+++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
@@ -870,6 +870,7 @@
redundants
refcounted
refl
+ regtp
rehel
reloadmedia
rendererdata
diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py
index 07ac208b..81161621 100644
--- a/tests/test_efro/test_message.py
+++ b/tests/test_efro/test_message.py
@@ -14,7 +14,7 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender,
- MessageReceiver, EmptyResponse)
+ MessageReceiver)
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union, Optional
@@ -176,9 +176,8 @@ TEST_PROTOCOL = MessageProtocol(
response_types={
0: _TResponse1,
1: _TResponse2,
- 2: EmptyResponse,
},
- trusted_client=True,
+ trusted_sender=True,
log_remote_exceptions=False,
)
diff --git a/tools/efro/message.py b/tools/efro/message.py
index 99f619e1..3b59f78c 100644
--- a/tools/efro/message.py
+++ b/tools/efro/message.py
@@ -88,9 +88,13 @@ class MessageProtocol:
type_key: Optional[str] = None,
preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True,
- trusted_client: bool = False) -> None:
+ trusted_sender: bool = False) -> None:
"""Create a protocol with a given configuration.
+ Note that common response types are automatically registered
+ with (unchanging negative ids) so they don't need to be passed
+ explicitly (but can be if a different id is desired).
+
If 'type_key' is provided, the message type ID is stored as the
provided key in the message dict; otherwise it will be stored as
part of a top level dict with the message payload appearing as a
@@ -100,10 +104,10 @@ class MessageProtocol:
on the remote end will result in the same error raised locally.
All other Exception types come across as efro.error.RemoteError.
- If 'trusted_client' is True, stringified remote stack traces will
- be included in the RemoteError. This should only be enabled in cases
- where the client is trusted.
+ If 'trusted_sender' is True, stringified remote stack traces will
+ be included in the responses if errors occur.
"""
+ # pylint: disable=too-many-locals
self.message_types_by_id: Dict[int, Type[Message]] = {}
self.message_ids_by_type: Dict[Type[Message], int] = {}
self.response_types_by_id: Dict[int, Type[Response]] = {}
@@ -129,12 +133,18 @@ class MessageProtocol:
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
- # If they didn't register ErrorResponse, do so with a special
- # -1 id which ensures it will never conflict with user messages.
- if ErrorResponse not in self.response_ids_by_type:
- assert self.response_types_by_id.get(-1) is None
- self.response_types_by_id[-1] = ErrorResponse
- self.response_ids_by_type[ErrorResponse] = -1
+ # Go ahead and auto-register a few common response types
+ # if the user has not done so explicitly. Use unique IDs which
+ # will never change or overlap with user ids.
+ def _reg(reg_tp: Type[Response], reg_id: int) -> None:
+ if reg_tp in self.response_ids_by_type:
+ return
+ assert self.response_types_by_id.get(reg_id) is None
+ self.response_types_by_id[reg_id] = reg_tp
+ self.response_ids_by_type[reg_tp] = reg_id
+
+ _reg(ErrorResponse, -1)
+ _reg(EmptyResponse, -2)
# Some extra-thorough validation in debug mode.
if __debug__:
@@ -167,7 +177,7 @@ class MessageProtocol:
self._type_key = type_key
self.preserve_clean_errors = preserve_clean_errors
self.log_remote_exceptions = log_remote_exceptions
- self.trusted_client = trusted_client
+ self.trusted_sender = trusted_sender
def encode_message(self, message: Message) -> bytes:
"""Encode a message to bytes for transport."""
@@ -671,7 +681,7 @@ class MessageReceiver:
err_response = ErrorResponse(
error_message=(traceback.format_exc()
- if self._protocol.trusted_client else
+ if self._protocol.trusted_sender else
'An unknown error has occurred.'),
error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response)