diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml
index 2f3299c2..ab1e2242 100644
--- a/.idea/dictionaries/ericf.xml
+++ b/.idea/dictionaries/ericf.xml
@@ -1992,6 +1992,7 @@
selindex
selwidget
selwidgets
+ sendable
senze
seqtype
seqtypestr
diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
index 471e39a2..a21794a7 100644
--- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml
+++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
@@ -918,6 +918,7 @@
selindex
selwidget
selwidgets
+ sendable
seqlen
seqtype
seqtypestr
diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py
index d4801bd4..6a9451d4 100644
--- a/tests/test_efro/test_message.py
+++ b/tests/test_efro/test_message.py
@@ -10,11 +10,11 @@ from dataclasses import dataclass
import pytest
+from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver)
-from efrotools.statictest import static_type_equals
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union
@@ -151,10 +151,12 @@ class _BoundTestMessageReceiver:
TEST_PROTOCOL = MessageProtocol(
message_types={
- 1: _TMessage1,
- 2: _TMessage2,
- 3: _TResponse1,
- 4: _TResponse2,
+ 0: _TMessage1,
+ 1: _TMessage2,
+ },
+ response_types={
+ 0: _TResponse1,
+ 1: _TResponse2,
},
trusted_client=True,
log_remote_exceptions=False,
@@ -167,10 +169,16 @@ def test_protocol_creation() -> None:
# This should fail because _TMessage1 can return _TResponse1 which
# is not given an id here.
with pytest.raises(ValueError):
- _protocol = MessageProtocol(message_types={1: _TMessage1})
+ _protocol = MessageProtocol(
+ message_types={0: _TMessage1},
+ response_types={0: _TResponse2},
+ )
# Now it should work.
- _protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
+ _protocol = MessageProtocol(
+ message_types={0: _TMessage1},
+ response_types={0: _TResponse1},
+ )
def test_sender_module_creation() -> None:
diff --git a/tools/efro/message.py b/tools/efro/message.py
index b9c8f090..3f9439a6 100644
--- a/tools/efro/message.py
+++ b/tools/efro/message.py
@@ -50,14 +50,23 @@ class Response:
"""Base class for responses to messages."""
+# Some standard response types:
+
+
@ioprepped
@dataclass
-class RemoteErrorMessage(Response):
+class RemoteErrorResponse(Response):
"""Message saying some error has occurred on the other end."""
error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
+@ioprepped
+@dataclass
+class EmptyResponse(Response):
+ """The response equivalent of None."""
+
+
class MessageProtocol:
"""Wrangles a set of message types, formats, and response types.
Both endpoints must be using a compatible Protocol for communication
@@ -67,8 +76,8 @@ class MessageProtocol:
"""
def __init__(self,
- message_types: Dict[int, Union[Type[Message],
- Type[Response]]],
+ message_types: Dict[int, Type[Message]],
+ response_types: Dict[int, Type[Response]],
type_key: Optional[str] = None,
preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True,
@@ -88,10 +97,10 @@ class MessageProtocol:
be included in the RemoteError. This should only be enabled in cases
where the client is trusted.
"""
- self.message_types_by_id: Dict[int, Union[Type[Message],
- Type[Response]]] = {}
- self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
- int] = {}
+ self.message_types_by_id: Dict[int, Type[Message]] = {}
+ self.message_ids_by_type: Dict[Type[Message], int] = {}
+ self.response_types_by_id: Dict[int, Type[Response]] = {}
+ self.response_ids_by_type: Dict[Type[Response], int] = {}
for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each
@@ -99,30 +108,38 @@ class MessageProtocol:
assert isinstance(m_id, int)
assert m_id >= 0
assert (is_ioprepped_dataclass(m_type)
- and issubclass(m_type, (Message, Response)))
+ and issubclass(m_type, Message))
assert self.message_types_by_id.get(m_id) is None
-
self.message_types_by_id[m_id] = m_type
self.message_ids_by_type[m_type] = m_id
+ for r_id, r_type in response_types.items():
+ assert isinstance(r_id, int)
+ assert r_id >= 0
+ assert (is_ioprepped_dataclass(r_type)
+ and issubclass(r_type, Response))
+ assert self.response_types_by_id.get(r_id) is None
+ self.response_types_by_id[r_id] = r_type
+ self.response_ids_by_type[r_type] = r_id
+
# Some extra-thorough validation in debug mode.
if __debug__:
- # Make sure all return types are valid and have been assigned
- # an ID as well.
+ # Make sure all Message types' return types are valid
+ # and have been assigned an ID as well.
all_response_types: Set[Type[Response]] = set()
for m_id, m_type in message_types.items():
- if issubclass(m_type, Message):
- m_rtypes = m_type.get_response_types()
- assert isinstance(m_rtypes, list)
- assert m_rtypes # make sure not empty
- assert len(set(m_rtypes)) == len(m_rtypes) # check dups
- all_response_types.update(m_rtypes)
+ m_rtypes = m_type.get_response_types()
+ assert isinstance(m_rtypes, list)
+ assert m_rtypes, (
+ f'Message type {m_type} specifies no return types.')
+ assert len(set(m_rtypes)) == len(m_rtypes) # check dups
+ all_response_types.update(m_rtypes)
for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass(
- cls, (Message, Response))
- if cls not in self.message_ids_by_type:
+ cls, Response)
+ if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}'
- f' was not included in message_types.')
+ f' was not included in response_types.')
# Make sure all registered types have unique base names.
# We can take advantage of this to generate cleaner looking
@@ -138,23 +155,35 @@ class MessageProtocol:
self.log_remote_exceptions = log_remote_exceptions
self.trusted_client = trusted_client
- def message_encode(self,
- message: Union[Message, Response],
+ def encode_message(self,
+ message: Message,
is_error: bool = False) -> bytes:
"""Encode a message to bytes for transport."""
+ return self._encode(message, is_error, self.message_ids_by_type,
+ 'message')
+
+ def encode_response(self,
+ response: Response,
+ is_error: bool = False) -> bytes:
+ """Encode a response to bytes for transport."""
+ return self._encode(response, is_error, self.response_ids_by_type,
+ 'response')
+
+ def _encode(self, message: Any, is_error: bool,
+ ids_by_type: Dict[Type, int], opname: str) -> bytes:
+ """Encode a message to bytes for transport."""
m_id: Optional[int]
if is_error:
m_id = -1
else:
- m_id = self.message_ids_by_type.get(type(message))
+ m_id = ids_by_type.get(type(message))
if m_id is None:
- raise TypeError(
- f'Message/Response type is not registered in Protocol:'
- f' {type(message)}')
+ raise TypeError(f'{opname} type is not registered in protocol:'
+ f' {type(message)}')
msgdict = dataclass_to_dict(message)
- # Encode type as part of the message dict if desired
+ # Encode type as part of the message/response dict if desired
# (for legacy compatibility).
if self._type_key is not None:
if self._type_key in msgdict:
@@ -166,12 +195,21 @@ class MessageProtocol:
out = {'m': msgdict, 't': m_id}
return json.dumps(out, separators=(',', ':')).encode()
- def message_decode(self, data: bytes) -> Union[Message, Response]:
- """Decode a message from bytes.
+ def decode_message(self, data: bytes) -> Message:
+ """Decode a message from bytes."""
+ out = self._decode(data, self.message_types_by_id, 'message')
+ assert isinstance(out, Message)
+ return out
- If the message represents a remote error, an Exception will
- be raised.
- """
+ def decode_response(self, data: bytes) -> Response:
+ """Decode a response from bytes."""
+ out = self._decode(data, self.response_types_by_id, 'response')
+ assert isinstance(out, Response)
+ return out
+
+ def _decode(self, data: bytes, types_by_id: Dict[int, Type],
+ opname: str) -> Any:
+ """Decode a message from bytes."""
msgfull = json.loads(data.decode())
assert isinstance(msgfull, dict)
msgdict: Optional[dict]
@@ -187,21 +225,18 @@ class MessageProtocol:
# Special case: a remote error occurred. Raise a local Exception.
if m_id == -1:
- err = dataclass_from_dict(RemoteErrorMessage, msgdict)
+ assert opname == 'response'
+ err = dataclass_from_dict(RemoteErrorResponse, msgdict)
if (self.preserve_clean_errors
and err.error_type is RemoteErrorType.CLEAN):
raise CleanError(err.error_message)
raise RemoteError(err.error_message)
- # Decode this particular type and make sure its valid.
- msgtype = self.message_types_by_id.get(m_id)
+ # Decode this particular type.
+ msgtype = types_by_id.get(m_id)
if msgtype is None:
- raise TypeError(
- f'Got unregistered message/response type id of {m_id}.')
-
- out = dataclass_from_dict(msgtype, msgdict)
- assert isinstance(out, (Message, Response))
- return out
+ raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
+ return dataclass_from_dict(msgtype, msgdict)
def _get_module_header(self, part: str) -> str:
"""Return common parts of generated modules."""
@@ -287,8 +322,8 @@ class MessageProtocol:
# Ew; @overload requires at least 2 different signatures so
# we need to simply write a single function if we have < 2.
if len(msgtypes) == 1:
- raise RuntimeError('FIXME: currently require at least 2'
- ' message types.')
+ raise RuntimeError('FIXME: currently we require at least 2'
+ ' registered message types; found 1.')
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
@@ -345,7 +380,7 @@ class MessageProtocol:
# we need to simply write a single function if we have < 2.
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
- ' message types.')
+ ' registered message types; found 1.')
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
@@ -433,9 +468,9 @@ class MessageSender:
if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.')
- msg_encoded = self._protocol.message_encode(message)
+ msg_encoded = self._protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
- response = self._protocol.message_decode(response_encoded)
+ response = self._protocol.decode_response(response_encoded)
assert isinstance(response, Response)
assert type(response) in type(message).get_response_types()
return response
@@ -572,7 +607,7 @@ class MessageReceiver:
"""Decode, handle, and return encoded response for a message."""
try:
# Decode the incoming message.
- msg_decoded = self._protocol.message_decode(msg)
+ msg_decoded = self._protocol.decode_message(msg)
msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
@@ -585,26 +620,27 @@ class MessageReceiver:
# Re-encode the response.
assert isinstance(response, Response)
assert type(response) in msgtype.get_response_types()
- return self._protocol.message_encode(response)
+ return self._protocol.encode_response(response)
except Exception as exc:
if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.')
- # If anything goes wrong, return a RemoteErrorMessage instead.
+ # If anything goes wrong, return a RemoteErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
- response = RemoteErrorMessage(error_message=str(exc),
- error_type=RemoteErrorType.CLEAN)
+ err_response = RemoteErrorResponse(
+ error_message=str(exc), error_type=RemoteErrorType.CLEAN)
+
else:
- response = RemoteErrorMessage(
+ err_response = RemoteErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_client else
'An unknown error has occurred.'),
error_type=RemoteErrorType.OTHER)
- return self._protocol.message_encode(response, is_error=True)
+ return self._protocol.encode_response(err_response, is_error=True)
async def handle_raw_message_async(self, msg: bytes) -> bytes:
"""Should be called when the receiver gets a message.