mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-26 00:47:10 +08:00
more messages work
This commit is contained in:
parent
95bbb89d14
commit
9de3738953
1
.idea/dictionaries/ericf.xml
generated
1
.idea/dictionaries/ericf.xml
generated
@ -1992,6 +1992,7 @@
|
||||
<w>selindex</w>
|
||||
<w>selwidget</w>
|
||||
<w>selwidgets</w>
|
||||
<w>sendable</w>
|
||||
<w>senze</w>
|
||||
<w>seqtype</w>
|
||||
<w>seqtypestr</w>
|
||||
|
||||
@ -918,6 +918,7 @@
|
||||
<w>selindex</w>
|
||||
<w>selwidget</w>
|
||||
<w>selwidgets</w>
|
||||
<w>sendable</w>
|
||||
<w>seqlen</w>
|
||||
<w>seqtype</w>
|
||||
<w>seqtypestr</w>
|
||||
|
||||
@ -10,11 +10,11 @@ from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from efrotools.statictest import static_type_equals
|
||||
from efro.error import CleanError, RemoteError
|
||||
from efro.dataclassio import ioprepped
|
||||
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||
MessageReceiver)
|
||||
from efrotools.statictest import static_type_equals
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import List, Type, Any, Callable, Union
|
||||
@ -151,10 +151,12 @@ class _BoundTestMessageReceiver:
|
||||
|
||||
TEST_PROTOCOL = MessageProtocol(
|
||||
message_types={
|
||||
1: _TMessage1,
|
||||
2: _TMessage2,
|
||||
3: _TResponse1,
|
||||
4: _TResponse2,
|
||||
0: _TMessage1,
|
||||
1: _TMessage2,
|
||||
},
|
||||
response_types={
|
||||
0: _TResponse1,
|
||||
1: _TResponse2,
|
||||
},
|
||||
trusted_client=True,
|
||||
log_remote_exceptions=False,
|
||||
@ -167,10 +169,16 @@ def test_protocol_creation() -> None:
|
||||
# This should fail because _TMessage1 can return _TResponse1 which
|
||||
# is not given an id here.
|
||||
with pytest.raises(ValueError):
|
||||
_protocol = MessageProtocol(message_types={1: _TMessage1})
|
||||
_protocol = MessageProtocol(
|
||||
message_types={0: _TMessage1},
|
||||
response_types={0: _TResponse2},
|
||||
)
|
||||
|
||||
# Now it should work.
|
||||
_protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
|
||||
_protocol = MessageProtocol(
|
||||
message_types={0: _TMessage1},
|
||||
response_types={0: _TResponse1},
|
||||
)
|
||||
|
||||
|
||||
def test_sender_module_creation() -> None:
|
||||
|
||||
@ -50,14 +50,23 @@ class Response:
|
||||
"""Base class for responses to messages."""
|
||||
|
||||
|
||||
# Some standard response types:
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class RemoteErrorMessage(Response):
|
||||
class RemoteErrorResponse(Response):
|
||||
"""Message saying some error has occurred on the other end."""
|
||||
error_message: Annotated[str, IOAttrs('m')]
|
||||
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class EmptyResponse(Response):
|
||||
"""The response equivalent of None."""
|
||||
|
||||
|
||||
class MessageProtocol:
|
||||
"""Wrangles a set of message types, formats, and response types.
|
||||
Both endpoints must be using a compatible Protocol for communication
|
||||
@ -67,8 +76,8 @@ class MessageProtocol:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
message_types: Dict[int, Union[Type[Message],
|
||||
Type[Response]]],
|
||||
message_types: Dict[int, Type[Message]],
|
||||
response_types: Dict[int, Type[Response]],
|
||||
type_key: Optional[str] = None,
|
||||
preserve_clean_errors: bool = True,
|
||||
log_remote_exceptions: bool = True,
|
||||
@ -88,10 +97,10 @@ class MessageProtocol:
|
||||
be included in the RemoteError. This should only be enabled in cases
|
||||
where the client is trusted.
|
||||
"""
|
||||
self.message_types_by_id: Dict[int, Union[Type[Message],
|
||||
Type[Response]]] = {}
|
||||
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
|
||||
int] = {}
|
||||
self.message_types_by_id: Dict[int, Type[Message]] = {}
|
||||
self.message_ids_by_type: Dict[Type[Message], int] = {}
|
||||
self.response_types_by_id: Dict[int, Type[Response]] = {}
|
||||
self.response_ids_by_type: Dict[Type[Response], int] = {}
|
||||
for m_id, m_type in message_types.items():
|
||||
|
||||
# Make sure only valid message types were passed and each
|
||||
@ -99,30 +108,38 @@ class MessageProtocol:
|
||||
assert isinstance(m_id, int)
|
||||
assert m_id >= 0
|
||||
assert (is_ioprepped_dataclass(m_type)
|
||||
and issubclass(m_type, (Message, Response)))
|
||||
and issubclass(m_type, Message))
|
||||
assert self.message_types_by_id.get(m_id) is None
|
||||
|
||||
self.message_types_by_id[m_id] = m_type
|
||||
self.message_ids_by_type[m_type] = m_id
|
||||
|
||||
for r_id, r_type in response_types.items():
|
||||
assert isinstance(r_id, int)
|
||||
assert r_id >= 0
|
||||
assert (is_ioprepped_dataclass(r_type)
|
||||
and issubclass(r_type, Response))
|
||||
assert self.response_types_by_id.get(r_id) is None
|
||||
self.response_types_by_id[r_id] = r_type
|
||||
self.response_ids_by_type[r_type] = r_id
|
||||
|
||||
# Some extra-thorough validation in debug mode.
|
||||
if __debug__:
|
||||
# Make sure all return types are valid and have been assigned
|
||||
# an ID as well.
|
||||
# Make sure all Message types' return types are valid
|
||||
# and have been assigned an ID as well.
|
||||
all_response_types: Set[Type[Response]] = set()
|
||||
for m_id, m_type in message_types.items():
|
||||
if issubclass(m_type, Message):
|
||||
m_rtypes = m_type.get_response_types()
|
||||
assert isinstance(m_rtypes, list)
|
||||
assert m_rtypes # make sure not empty
|
||||
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
||||
all_response_types.update(m_rtypes)
|
||||
m_rtypes = m_type.get_response_types()
|
||||
assert isinstance(m_rtypes, list)
|
||||
assert m_rtypes, (
|
||||
f'Message type {m_type} specifies no return types.')
|
||||
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
||||
all_response_types.update(m_rtypes)
|
||||
for cls in all_response_types:
|
||||
assert is_ioprepped_dataclass(cls) and issubclass(
|
||||
cls, (Message, Response))
|
||||
if cls not in self.message_ids_by_type:
|
||||
cls, Response)
|
||||
if cls not in self.response_ids_by_type:
|
||||
raise ValueError(f'Possible response type {cls}'
|
||||
f' was not included in message_types.')
|
||||
f' was not included in response_types.')
|
||||
|
||||
# Make sure all registered types have unique base names.
|
||||
# We can take advantage of this to generate cleaner looking
|
||||
@ -138,23 +155,35 @@ class MessageProtocol:
|
||||
self.log_remote_exceptions = log_remote_exceptions
|
||||
self.trusted_client = trusted_client
|
||||
|
||||
def message_encode(self,
|
||||
message: Union[Message, Response],
|
||||
def encode_message(self,
|
||||
message: Message,
|
||||
is_error: bool = False) -> bytes:
|
||||
"""Encode a message to bytes for transport."""
|
||||
return self._encode(message, is_error, self.message_ids_by_type,
|
||||
'message')
|
||||
|
||||
def encode_response(self,
|
||||
response: Response,
|
||||
is_error: bool = False) -> bytes:
|
||||
"""Encode a response to bytes for transport."""
|
||||
return self._encode(response, is_error, self.response_ids_by_type,
|
||||
'response')
|
||||
|
||||
def _encode(self, message: Any, is_error: bool,
|
||||
ids_by_type: Dict[Type, int], opname: str) -> bytes:
|
||||
"""Encode a message to bytes for transport."""
|
||||
|
||||
m_id: Optional[int]
|
||||
if is_error:
|
||||
m_id = -1
|
||||
else:
|
||||
m_id = self.message_ids_by_type.get(type(message))
|
||||
m_id = ids_by_type.get(type(message))
|
||||
if m_id is None:
|
||||
raise TypeError(
|
||||
f'Message/Response type is not registered in Protocol:'
|
||||
f' {type(message)}')
|
||||
raise TypeError(f'{opname} type is not registered in protocol:'
|
||||
f' {type(message)}')
|
||||
msgdict = dataclass_to_dict(message)
|
||||
|
||||
# Encode type as part of the message dict if desired
|
||||
# Encode type as part of the message/response dict if desired
|
||||
# (for legacy compatibility).
|
||||
if self._type_key is not None:
|
||||
if self._type_key in msgdict:
|
||||
@ -166,12 +195,21 @@ class MessageProtocol:
|
||||
out = {'m': msgdict, 't': m_id}
|
||||
return json.dumps(out, separators=(',', ':')).encode()
|
||||
|
||||
def message_decode(self, data: bytes) -> Union[Message, Response]:
|
||||
"""Decode a message from bytes.
|
||||
def decode_message(self, data: bytes) -> Message:
|
||||
"""Decode a message from bytes."""
|
||||
out = self._decode(data, self.message_types_by_id, 'message')
|
||||
assert isinstance(out, Message)
|
||||
return out
|
||||
|
||||
If the message represents a remote error, an Exception will
|
||||
be raised.
|
||||
"""
|
||||
def decode_response(self, data: bytes) -> Response:
|
||||
"""Decode a response from bytes."""
|
||||
out = self._decode(data, self.response_types_by_id, 'response')
|
||||
assert isinstance(out, Response)
|
||||
return out
|
||||
|
||||
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
|
||||
opname: str) -> Any:
|
||||
"""Decode a message from bytes."""
|
||||
msgfull = json.loads(data.decode())
|
||||
assert isinstance(msgfull, dict)
|
||||
msgdict: Optional[dict]
|
||||
@ -187,21 +225,18 @@ class MessageProtocol:
|
||||
|
||||
# Special case: a remote error occurred. Raise a local Exception.
|
||||
if m_id == -1:
|
||||
err = dataclass_from_dict(RemoteErrorMessage, msgdict)
|
||||
assert opname == 'response'
|
||||
err = dataclass_from_dict(RemoteErrorResponse, msgdict)
|
||||
if (self.preserve_clean_errors
|
||||
and err.error_type is RemoteErrorType.CLEAN):
|
||||
raise CleanError(err.error_message)
|
||||
raise RemoteError(err.error_message)
|
||||
|
||||
# Decode this particular type and make sure its valid.
|
||||
msgtype = self.message_types_by_id.get(m_id)
|
||||
# Decode this particular type.
|
||||
msgtype = types_by_id.get(m_id)
|
||||
if msgtype is None:
|
||||
raise TypeError(
|
||||
f'Got unregistered message/response type id of {m_id}.')
|
||||
|
||||
out = dataclass_from_dict(msgtype, msgdict)
|
||||
assert isinstance(out, (Message, Response))
|
||||
return out
|
||||
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
|
||||
return dataclass_from_dict(msgtype, msgdict)
|
||||
|
||||
def _get_module_header(self, part: str) -> str:
|
||||
"""Return common parts of generated modules."""
|
||||
@ -287,8 +322,8 @@ class MessageProtocol:
|
||||
# Ew; @overload requires at least 2 different signatures so
|
||||
# we need to simply write a single function if we have < 2.
|
||||
if len(msgtypes) == 1:
|
||||
raise RuntimeError('FIXME: currently require at least 2'
|
||||
' message types.')
|
||||
raise RuntimeError('FIXME: currently we require at least 2'
|
||||
' registered message types; found 1.')
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
@ -345,7 +380,7 @@ class MessageProtocol:
|
||||
# we need to simply write a single function if we have < 2.
|
||||
if len(msgtypes) == 1:
|
||||
raise RuntimeError('FIXME: currently require at least 2'
|
||||
' message types.')
|
||||
' registered message types; found 1.')
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
@ -433,9 +468,9 @@ class MessageSender:
|
||||
if self._send_raw_message_call is None:
|
||||
raise RuntimeError('send() is unimplemented for this type.')
|
||||
|
||||
msg_encoded = self._protocol.message_encode(message)
|
||||
msg_encoded = self._protocol.encode_message(message)
|
||||
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
||||
response = self._protocol.message_decode(response_encoded)
|
||||
response = self._protocol.decode_response(response_encoded)
|
||||
assert isinstance(response, Response)
|
||||
assert type(response) in type(message).get_response_types()
|
||||
return response
|
||||
@ -572,7 +607,7 @@ class MessageReceiver:
|
||||
"""Decode, handle, and return encoded response for a message."""
|
||||
try:
|
||||
# Decode the incoming message.
|
||||
msg_decoded = self._protocol.message_decode(msg)
|
||||
msg_decoded = self._protocol.decode_message(msg)
|
||||
msgtype = type(msg_decoded)
|
||||
assert issubclass(msgtype, Message)
|
||||
|
||||
@ -585,26 +620,27 @@ class MessageReceiver:
|
||||
# Re-encode the response.
|
||||
assert isinstance(response, Response)
|
||||
assert type(response) in msgtype.get_response_types()
|
||||
return self._protocol.message_encode(response)
|
||||
return self._protocol.encode_response(response)
|
||||
|
||||
except Exception as exc:
|
||||
|
||||
if self._protocol.log_remote_exceptions:
|
||||
logging.exception('Error handling message.')
|
||||
|
||||
# If anything goes wrong, return a RemoteErrorMessage instead.
|
||||
# If anything goes wrong, return a RemoteErrorResponse instead.
|
||||
if (isinstance(exc, CleanError)
|
||||
and self._protocol.preserve_clean_errors):
|
||||
response = RemoteErrorMessage(error_message=str(exc),
|
||||
error_type=RemoteErrorType.CLEAN)
|
||||
err_response = RemoteErrorResponse(
|
||||
error_message=str(exc), error_type=RemoteErrorType.CLEAN)
|
||||
|
||||
else:
|
||||
|
||||
response = RemoteErrorMessage(
|
||||
err_response = RemoteErrorResponse(
|
||||
error_message=(traceback.format_exc()
|
||||
if self._protocol.trusted_client else
|
||||
'An unknown error has occurred.'),
|
||||
error_type=RemoteErrorType.OTHER)
|
||||
return self._protocol.message_encode(response, is_error=True)
|
||||
return self._protocol.encode_response(err_response, is_error=True)
|
||||
|
||||
async def handle_raw_message_async(self, msg: bytes) -> bytes:
|
||||
"""Should be called when the receiver gets a message.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user