more messages work

This commit is contained in:
Eric Froemling 2021-09-08 18:14:48 -05:00
parent 95bbb89d14
commit 9de3738953
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
4 changed files with 106 additions and 60 deletions

View File

@ -1992,6 +1992,7 @@
<w>selindex</w> <w>selindex</w>
<w>selwidget</w> <w>selwidget</w>
<w>selwidgets</w> <w>selwidgets</w>
<w>sendable</w>
<w>senze</w> <w>senze</w>
<w>seqtype</w> <w>seqtype</w>
<w>seqtypestr</w> <w>seqtypestr</w>

View File

@ -918,6 +918,7 @@
<w>selindex</w> <w>selindex</w>
<w>selwidget</w> <w>selwidget</w>
<w>selwidgets</w> <w>selwidgets</w>
<w>sendable</w>
<w>seqlen</w> <w>seqlen</w>
<w>seqtype</w> <w>seqtype</w>
<w>seqtypestr</w> <w>seqtypestr</w>

View File

@ -10,11 +10,11 @@ from dataclasses import dataclass
import pytest import pytest
from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender, from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver) MessageReceiver)
from efrotools.statictest import static_type_equals
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union from typing import List, Type, Any, Callable, Union
@ -151,10 +151,12 @@ class _BoundTestMessageReceiver:
TEST_PROTOCOL = MessageProtocol( TEST_PROTOCOL = MessageProtocol(
message_types={ message_types={
1: _TMessage1, 0: _TMessage1,
2: _TMessage2, 1: _TMessage2,
3: _TResponse1, },
4: _TResponse2, response_types={
0: _TResponse1,
1: _TResponse2,
}, },
trusted_client=True, trusted_client=True,
log_remote_exceptions=False, log_remote_exceptions=False,
@ -167,10 +169,16 @@ def test_protocol_creation() -> None:
# This should fail because _TMessage1 can return _TResponse1 which # This should fail because _TMessage1 can return _TResponse1 which
# is not given an id here. # is not given an id here.
with pytest.raises(ValueError): with pytest.raises(ValueError):
_protocol = MessageProtocol(message_types={1: _TMessage1}) _protocol = MessageProtocol(
message_types={0: _TMessage1},
response_types={0: _TResponse2},
)
# Now it should work. # Now it should work.
_protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1}) _protocol = MessageProtocol(
message_types={0: _TMessage1},
response_types={0: _TResponse1},
)
def test_sender_module_creation() -> None: def test_sender_module_creation() -> None:

View File

@ -50,14 +50,23 @@ class Response:
"""Base class for responses to messages.""" """Base class for responses to messages."""
# Some standard response types:
@ioprepped @ioprepped
@dataclass @dataclass
class RemoteErrorMessage(Response): class RemoteErrorResponse(Response):
"""Message saying some error has occurred on the other end.""" """Message saying some error has occurred on the other end."""
error_message: Annotated[str, IOAttrs('m')] error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[RemoteErrorType, IOAttrs('t')] error_type: Annotated[RemoteErrorType, IOAttrs('t')]
@ioprepped
@dataclass
class EmptyResponse(Response):
"""The response equivalent of None."""
class MessageProtocol: class MessageProtocol:
"""Wrangles a set of message types, formats, and response types. """Wrangles a set of message types, formats, and response types.
Both endpoints must be using a compatible Protocol for communication Both endpoints must be using a compatible Protocol for communication
@ -67,8 +76,8 @@ class MessageProtocol:
""" """
def __init__(self, def __init__(self,
message_types: Dict[int, Union[Type[Message], message_types: Dict[int, Type[Message]],
Type[Response]]], response_types: Dict[int, Type[Response]],
type_key: Optional[str] = None, type_key: Optional[str] = None,
preserve_clean_errors: bool = True, preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True, log_remote_exceptions: bool = True,
@ -88,10 +97,10 @@ class MessageProtocol:
be included in the RemoteError. This should only be enabled in cases be included in the RemoteError. This should only be enabled in cases
where the client is trusted. where the client is trusted.
""" """
self.message_types_by_id: Dict[int, Union[Type[Message], self.message_types_by_id: Dict[int, Type[Message]] = {}
Type[Response]]] = {} self.message_ids_by_type: Dict[Type[Message], int] = {}
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]], self.response_types_by_id: Dict[int, Type[Response]] = {}
int] = {} self.response_ids_by_type: Dict[Type[Response], int] = {}
for m_id, m_type in message_types.items(): for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each # Make sure only valid message types were passed and each
@ -99,30 +108,38 @@ class MessageProtocol:
assert isinstance(m_id, int) assert isinstance(m_id, int)
assert m_id >= 0 assert m_id >= 0
assert (is_ioprepped_dataclass(m_type) assert (is_ioprepped_dataclass(m_type)
and issubclass(m_type, (Message, Response))) and issubclass(m_type, Message))
assert self.message_types_by_id.get(m_id) is None assert self.message_types_by_id.get(m_id) is None
self.message_types_by_id[m_id] = m_type self.message_types_by_id[m_id] = m_type
self.message_ids_by_type[m_type] = m_id self.message_ids_by_type[m_type] = m_id
for r_id, r_type in response_types.items():
assert isinstance(r_id, int)
assert r_id >= 0
assert (is_ioprepped_dataclass(r_type)
and issubclass(r_type, Response))
assert self.response_types_by_id.get(r_id) is None
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
# Some extra-thorough validation in debug mode. # Some extra-thorough validation in debug mode.
if __debug__: if __debug__:
# Make sure all return types are valid and have been assigned # Make sure all Message types' return types are valid
# an ID as well. # and have been assigned an ID as well.
all_response_types: Set[Type[Response]] = set() all_response_types: Set[Type[Response]] = set()
for m_id, m_type in message_types.items(): for m_id, m_type in message_types.items():
if issubclass(m_type, Message): m_rtypes = m_type.get_response_types()
m_rtypes = m_type.get_response_types() assert isinstance(m_rtypes, list)
assert isinstance(m_rtypes, list) assert m_rtypes, (
assert m_rtypes # make sure not empty f'Message type {m_type} specifies no return types.')
assert len(set(m_rtypes)) == len(m_rtypes) # check dups assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes) all_response_types.update(m_rtypes)
for cls in all_response_types: for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass( assert is_ioprepped_dataclass(cls) and issubclass(
cls, (Message, Response)) cls, Response)
if cls not in self.message_ids_by_type: if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}' raise ValueError(f'Possible response type {cls}'
f' was not included in message_types.') f' was not included in response_types.')
# Make sure all registered types have unique base names. # Make sure all registered types have unique base names.
# We can take advantage of this to generate cleaner looking # We can take advantage of this to generate cleaner looking
@ -138,23 +155,35 @@ class MessageProtocol:
self.log_remote_exceptions = log_remote_exceptions self.log_remote_exceptions = log_remote_exceptions
self.trusted_client = trusted_client self.trusted_client = trusted_client
def message_encode(self, def encode_message(self,
message: Union[Message, Response], message: Message,
is_error: bool = False) -> bytes: is_error: bool = False) -> bytes:
"""Encode a message to bytes for transport.""" """Encode a message to bytes for transport."""
return self._encode(message, is_error, self.message_ids_by_type,
'message')
def encode_response(self,
response: Response,
is_error: bool = False) -> bytes:
"""Encode a response to bytes for transport."""
return self._encode(response, is_error, self.response_ids_by_type,
'response')
def _encode(self, message: Any, is_error: bool,
ids_by_type: Dict[Type, int], opname: str) -> bytes:
"""Encode a message to bytes for transport."""
m_id: Optional[int] m_id: Optional[int]
if is_error: if is_error:
m_id = -1 m_id = -1
else: else:
m_id = self.message_ids_by_type.get(type(message)) m_id = ids_by_type.get(type(message))
if m_id is None: if m_id is None:
raise TypeError( raise TypeError(f'{opname} type is not registered in protocol:'
f'Message/Response type is not registered in Protocol:' f' {type(message)}')
f' {type(message)}')
msgdict = dataclass_to_dict(message) msgdict = dataclass_to_dict(message)
# Encode type as part of the message dict if desired # Encode type as part of the message/response dict if desired
# (for legacy compatibility). # (for legacy compatibility).
if self._type_key is not None: if self._type_key is not None:
if self._type_key in msgdict: if self._type_key in msgdict:
@ -166,12 +195,21 @@ class MessageProtocol:
out = {'m': msgdict, 't': m_id} out = {'m': msgdict, 't': m_id}
return json.dumps(out, separators=(',', ':')).encode() return json.dumps(out, separators=(',', ':')).encode()
def message_decode(self, data: bytes) -> Union[Message, Response]: def decode_message(self, data: bytes) -> Message:
"""Decode a message from bytes. """Decode a message from bytes."""
out = self._decode(data, self.message_types_by_id, 'message')
assert isinstance(out, Message)
return out
If the message represents a remote error, an Exception will def decode_response(self, data: bytes) -> Response:
be raised. """Decode a response from bytes."""
""" out = self._decode(data, self.response_types_by_id, 'response')
assert isinstance(out, Response)
return out
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
opname: str) -> Any:
"""Decode a message from bytes."""
msgfull = json.loads(data.decode()) msgfull = json.loads(data.decode())
assert isinstance(msgfull, dict) assert isinstance(msgfull, dict)
msgdict: Optional[dict] msgdict: Optional[dict]
@ -187,21 +225,18 @@ class MessageProtocol:
# Special case: a remote error occurred. Raise a local Exception. # Special case: a remote error occurred. Raise a local Exception.
if m_id == -1: if m_id == -1:
err = dataclass_from_dict(RemoteErrorMessage, msgdict) assert opname == 'response'
err = dataclass_from_dict(RemoteErrorResponse, msgdict)
if (self.preserve_clean_errors if (self.preserve_clean_errors
and err.error_type is RemoteErrorType.CLEAN): and err.error_type is RemoteErrorType.CLEAN):
raise CleanError(err.error_message) raise CleanError(err.error_message)
raise RemoteError(err.error_message) raise RemoteError(err.error_message)
# Decode this particular type and make sure its valid. # Decode this particular type.
msgtype = self.message_types_by_id.get(m_id) msgtype = types_by_id.get(m_id)
if msgtype is None: if msgtype is None:
raise TypeError( raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
f'Got unregistered message/response type id of {m_id}.') return dataclass_from_dict(msgtype, msgdict)
out = dataclass_from_dict(msgtype, msgdict)
assert isinstance(out, (Message, Response))
return out
def _get_module_header(self, part: str) -> str: def _get_module_header(self, part: str) -> str:
"""Return common parts of generated modules.""" """Return common parts of generated modules."""
@ -287,8 +322,8 @@ class MessageProtocol:
# Ew; @overload requires at least 2 different signatures so # Ew; @overload requires at least 2 different signatures so
# we need to simply write a single function if we have < 2. # we need to simply write a single function if we have < 2.
if len(msgtypes) == 1: if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2' raise RuntimeError('FIXME: currently we require at least 2'
' message types.') ' registered message types; found 1.')
if len(msgtypes) > 1: if len(msgtypes) > 1:
for msgtype in msgtypes: for msgtype in msgtypes:
msgtypevar = msgtype.__name__ msgtypevar = msgtype.__name__
@ -345,7 +380,7 @@ class MessageProtocol:
# we need to simply write a single function if we have < 2. # we need to simply write a single function if we have < 2.
if len(msgtypes) == 1: if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2' raise RuntimeError('FIXME: currently require at least 2'
' message types.') ' registered message types; found 1.')
if len(msgtypes) > 1: if len(msgtypes) > 1:
for msgtype in msgtypes: for msgtype in msgtypes:
msgtypevar = msgtype.__name__ msgtypevar = msgtype.__name__
@ -433,9 +468,9 @@ class MessageSender:
if self._send_raw_message_call is None: if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.') raise RuntimeError('send() is unimplemented for this type.')
msg_encoded = self._protocol.message_encode(message) msg_encoded = self._protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded) response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.message_decode(response_encoded) response = self._protocol.decode_response(response_encoded)
assert isinstance(response, Response) assert isinstance(response, Response)
assert type(response) in type(message).get_response_types() assert type(response) in type(message).get_response_types()
return response return response
@ -572,7 +607,7 @@ class MessageReceiver:
"""Decode, handle, and return encoded response for a message.""" """Decode, handle, and return encoded response for a message."""
try: try:
# Decode the incoming message. # Decode the incoming message.
msg_decoded = self._protocol.message_decode(msg) msg_decoded = self._protocol.decode_message(msg)
msgtype = type(msg_decoded) msgtype = type(msg_decoded)
assert issubclass(msgtype, Message) assert issubclass(msgtype, Message)
@ -585,26 +620,27 @@ class MessageReceiver:
# Re-encode the response. # Re-encode the response.
assert isinstance(response, Response) assert isinstance(response, Response)
assert type(response) in msgtype.get_response_types() assert type(response) in msgtype.get_response_types()
return self._protocol.message_encode(response) return self._protocol.encode_response(response)
except Exception as exc: except Exception as exc:
if self._protocol.log_remote_exceptions: if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.') logging.exception('Error handling message.')
# If anything goes wrong, return a RemoteErrorMessage instead. # If anything goes wrong, return a RemoteErrorResponse instead.
if (isinstance(exc, CleanError) if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors): and self._protocol.preserve_clean_errors):
response = RemoteErrorMessage(error_message=str(exc), err_response = RemoteErrorResponse(
error_type=RemoteErrorType.CLEAN) error_message=str(exc), error_type=RemoteErrorType.CLEAN)
else: else:
response = RemoteErrorMessage( err_response = RemoteErrorResponse(
error_message=(traceback.format_exc() error_message=(traceback.format_exc()
if self._protocol.trusted_client else if self._protocol.trusted_client else
'An unknown error has occurred.'), 'An unknown error has occurred.'),
error_type=RemoteErrorType.OTHER) error_type=RemoteErrorType.OTHER)
return self._protocol.message_encode(response, is_error=True) return self._protocol.encode_response(err_response, is_error=True)
async def handle_raw_message_async(self, msg: bytes) -> bytes: async def handle_raw_message_async(self, msg: bytes) -> bytes:
"""Should be called when the receiver gets a message. """Should be called when the receiver gets a message.