more messages work

This commit is contained in:
Eric Froemling 2021-09-08 18:14:48 -05:00
parent 95bbb89d14
commit 9de3738953
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
4 changed files with 106 additions and 60 deletions

View File

@ -1992,6 +1992,7 @@
<w>selindex</w>
<w>selwidget</w>
<w>selwidgets</w>
<w>sendable</w>
<w>senze</w>
<w>seqtype</w>
<w>seqtypestr</w>

View File

@ -918,6 +918,7 @@
<w>selindex</w>
<w>selwidget</w>
<w>selwidgets</w>
<w>sendable</w>
<w>seqlen</w>
<w>seqtype</w>
<w>seqtypestr</w>

View File

@ -10,11 +10,11 @@ from dataclasses import dataclass
import pytest
from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver)
from efrotools.statictest import static_type_equals
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union
@ -151,10 +151,12 @@ class _BoundTestMessageReceiver:
TEST_PROTOCOL = MessageProtocol(
message_types={
1: _TMessage1,
2: _TMessage2,
3: _TResponse1,
4: _TResponse2,
0: _TMessage1,
1: _TMessage2,
},
response_types={
0: _TResponse1,
1: _TResponse2,
},
trusted_client=True,
log_remote_exceptions=False,
@ -167,10 +169,16 @@ def test_protocol_creation() -> None:
# This should fail because _TMessage1 can return _TResponse1 which
# is not given an id here.
with pytest.raises(ValueError):
_protocol = MessageProtocol(message_types={1: _TMessage1})
_protocol = MessageProtocol(
message_types={0: _TMessage1},
response_types={0: _TResponse2},
)
# Now it should work.
_protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
_protocol = MessageProtocol(
message_types={0: _TMessage1},
response_types={0: _TResponse1},
)
def test_sender_module_creation() -> None:

View File

@ -50,14 +50,23 @@ class Response:
"""Base class for responses to messages."""
# Some standard response types:
@ioprepped
@dataclass
class RemoteErrorMessage(Response):
class RemoteErrorResponse(Response):
"""Message saying some error has occurred on the other end."""
error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
@ioprepped
@dataclass
class EmptyResponse(Response):
"""The response equivalent of None."""
class MessageProtocol:
"""Wrangles a set of message types, formats, and response types.
Both endpoints must be using a compatible Protocol for communication
@ -67,8 +76,8 @@ class MessageProtocol:
"""
def __init__(self,
message_types: Dict[int, Union[Type[Message],
Type[Response]]],
message_types: Dict[int, Type[Message]],
response_types: Dict[int, Type[Response]],
type_key: Optional[str] = None,
preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True,
@ -88,10 +97,10 @@ class MessageProtocol:
be included in the RemoteError. This should only be enabled in cases
where the client is trusted.
"""
self.message_types_by_id: Dict[int, Union[Type[Message],
Type[Response]]] = {}
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
int] = {}
self.message_types_by_id: Dict[int, Type[Message]] = {}
self.message_ids_by_type: Dict[Type[Message], int] = {}
self.response_types_by_id: Dict[int, Type[Response]] = {}
self.response_ids_by_type: Dict[Type[Response], int] = {}
for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each
@ -99,30 +108,38 @@ class MessageProtocol:
assert isinstance(m_id, int)
assert m_id >= 0
assert (is_ioprepped_dataclass(m_type)
and issubclass(m_type, (Message, Response)))
and issubclass(m_type, Message))
assert self.message_types_by_id.get(m_id) is None
self.message_types_by_id[m_id] = m_type
self.message_ids_by_type[m_type] = m_id
for r_id, r_type in response_types.items():
assert isinstance(r_id, int)
assert r_id >= 0
assert (is_ioprepped_dataclass(r_type)
and issubclass(r_type, Response))
assert self.response_types_by_id.get(r_id) is None
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
# Some extra-thorough validation in debug mode.
if __debug__:
# Make sure all return types are valid and have been assigned
# an ID as well.
# Make sure all Message types' return types are valid
# and have been assigned an ID as well.
all_response_types: Set[Type[Response]] = set()
for m_id, m_type in message_types.items():
if issubclass(m_type, Message):
m_rtypes = m_type.get_response_types()
assert isinstance(m_rtypes, list)
assert m_rtypes # make sure not empty
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes)
m_rtypes = m_type.get_response_types()
assert isinstance(m_rtypes, list)
assert m_rtypes, (
f'Message type {m_type} specifies no return types.')
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes)
for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass(
cls, (Message, Response))
if cls not in self.message_ids_by_type:
cls, Response)
if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}'
f' was not included in message_types.')
f' was not included in response_types.')
# Make sure all registered types have unique base names.
# We can take advantage of this to generate cleaner looking
@ -138,23 +155,35 @@ class MessageProtocol:
self.log_remote_exceptions = log_remote_exceptions
self.trusted_client = trusted_client
def message_encode(self,
message: Union[Message, Response],
def encode_message(self,
message: Message,
is_error: bool = False) -> bytes:
"""Encode a message to bytes for transport."""
return self._encode(message, is_error, self.message_ids_by_type,
'message')
def encode_response(self,
response: Response,
is_error: bool = False) -> bytes:
"""Encode a response to bytes for transport."""
return self._encode(response, is_error, self.response_ids_by_type,
'response')
def _encode(self, message: Any, is_error: bool,
ids_by_type: Dict[Type, int], opname: str) -> bytes:
"""Encode a message to bytes for transport."""
m_id: Optional[int]
if is_error:
m_id = -1
else:
m_id = self.message_ids_by_type.get(type(message))
m_id = ids_by_type.get(type(message))
if m_id is None:
raise TypeError(
f'Message/Response type is not registered in Protocol:'
f' {type(message)}')
raise TypeError(f'{opname} type is not registered in protocol:'
f' {type(message)}')
msgdict = dataclass_to_dict(message)
# Encode type as part of the message dict if desired
# Encode type as part of the message/response dict if desired
# (for legacy compatibility).
if self._type_key is not None:
if self._type_key in msgdict:
@ -166,12 +195,21 @@ class MessageProtocol:
out = {'m': msgdict, 't': m_id}
return json.dumps(out, separators=(',', ':')).encode()
def message_decode(self, data: bytes) -> Union[Message, Response]:
"""Decode a message from bytes.
def decode_message(self, data: bytes) -> Message:
"""Decode a message from bytes."""
out = self._decode(data, self.message_types_by_id, 'message')
assert isinstance(out, Message)
return out
If the message represents a remote error, an Exception will
be raised.
"""
def decode_response(self, data: bytes) -> Response:
"""Decode a response from bytes."""
out = self._decode(data, self.response_types_by_id, 'response')
assert isinstance(out, Response)
return out
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
opname: str) -> Any:
"""Decode a message from bytes."""
msgfull = json.loads(data.decode())
assert isinstance(msgfull, dict)
msgdict: Optional[dict]
@ -187,21 +225,18 @@ class MessageProtocol:
# Special case: a remote error occurred. Raise a local Exception.
if m_id == -1:
err = dataclass_from_dict(RemoteErrorMessage, msgdict)
assert opname == 'response'
err = dataclass_from_dict(RemoteErrorResponse, msgdict)
if (self.preserve_clean_errors
and err.error_type is RemoteErrorType.CLEAN):
raise CleanError(err.error_message)
raise RemoteError(err.error_message)
# Decode this particular type and make sure its valid.
msgtype = self.message_types_by_id.get(m_id)
# Decode this particular type.
msgtype = types_by_id.get(m_id)
if msgtype is None:
raise TypeError(
f'Got unregistered message/response type id of {m_id}.')
out = dataclass_from_dict(msgtype, msgdict)
assert isinstance(out, (Message, Response))
return out
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
return dataclass_from_dict(msgtype, msgdict)
def _get_module_header(self, part: str) -> str:
"""Return common parts of generated modules."""
@ -287,8 +322,8 @@ class MessageProtocol:
# Ew; @overload requires at least 2 different signatures so
# we need to simply write a single function if we have < 2.
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
' message types.')
raise RuntimeError('FIXME: currently we require at least 2'
' registered message types; found 1.')
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
@ -345,7 +380,7 @@ class MessageProtocol:
# we need to simply write a single function if we have < 2.
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
' message types.')
' registered message types; found 1.')
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
@ -433,9 +468,9 @@ class MessageSender:
if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.')
msg_encoded = self._protocol.message_encode(message)
msg_encoded = self._protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.message_decode(response_encoded)
response = self._protocol.decode_response(response_encoded)
assert isinstance(response, Response)
assert type(response) in type(message).get_response_types()
return response
@ -572,7 +607,7 @@ class MessageReceiver:
"""Decode, handle, and return encoded response for a message."""
try:
# Decode the incoming message.
msg_decoded = self._protocol.message_decode(msg)
msg_decoded = self._protocol.decode_message(msg)
msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
@ -585,26 +620,27 @@ class MessageReceiver:
# Re-encode the response.
assert isinstance(response, Response)
assert type(response) in msgtype.get_response_types()
return self._protocol.message_encode(response)
return self._protocol.encode_response(response)
except Exception as exc:
if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.')
# If anything goes wrong, return a RemoteErrorMessage instead.
# If anything goes wrong, return a RemoteErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
response = RemoteErrorMessage(error_message=str(exc),
error_type=RemoteErrorType.CLEAN)
err_response = RemoteErrorResponse(
error_message=str(exc), error_type=RemoteErrorType.CLEAN)
else:
response = RemoteErrorMessage(
err_response = RemoteErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_client else
'An unknown error has occurred.'),
error_type=RemoteErrorType.OTHER)
return self._protocol.message_encode(response, is_error=True)
return self._protocol.encode_response(err_response, is_error=True)
async def handle_raw_message_async(self, msg: bytes) -> bytes:
"""Should be called when the receiver gets a message.