mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-02-05 23:13:46 +08:00
more messages work
This commit is contained in:
parent
95bbb89d14
commit
9de3738953
1
.idea/dictionaries/ericf.xml
generated
1
.idea/dictionaries/ericf.xml
generated
@ -1992,6 +1992,7 @@
|
|||||||
<w>selindex</w>
|
<w>selindex</w>
|
||||||
<w>selwidget</w>
|
<w>selwidget</w>
|
||||||
<w>selwidgets</w>
|
<w>selwidgets</w>
|
||||||
|
<w>sendable</w>
|
||||||
<w>senze</w>
|
<w>senze</w>
|
||||||
<w>seqtype</w>
|
<w>seqtype</w>
|
||||||
<w>seqtypestr</w>
|
<w>seqtypestr</w>
|
||||||
|
|||||||
@ -918,6 +918,7 @@
|
|||||||
<w>selindex</w>
|
<w>selindex</w>
|
||||||
<w>selwidget</w>
|
<w>selwidget</w>
|
||||||
<w>selwidgets</w>
|
<w>selwidgets</w>
|
||||||
|
<w>sendable</w>
|
||||||
<w>seqlen</w>
|
<w>seqlen</w>
|
||||||
<w>seqtype</w>
|
<w>seqtype</w>
|
||||||
<w>seqtypestr</w>
|
<w>seqtypestr</w>
|
||||||
|
|||||||
@ -10,11 +10,11 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from efrotools.statictest import static_type_equals
|
||||||
from efro.error import CleanError, RemoteError
|
from efro.error import CleanError, RemoteError
|
||||||
from efro.dataclassio import ioprepped
|
from efro.dataclassio import ioprepped
|
||||||
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||||
MessageReceiver)
|
MessageReceiver)
|
||||||
from efrotools.statictest import static_type_equals
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import List, Type, Any, Callable, Union
|
from typing import List, Type, Any, Callable, Union
|
||||||
@ -151,10 +151,12 @@ class _BoundTestMessageReceiver:
|
|||||||
|
|
||||||
TEST_PROTOCOL = MessageProtocol(
|
TEST_PROTOCOL = MessageProtocol(
|
||||||
message_types={
|
message_types={
|
||||||
1: _TMessage1,
|
0: _TMessage1,
|
||||||
2: _TMessage2,
|
1: _TMessage2,
|
||||||
3: _TResponse1,
|
},
|
||||||
4: _TResponse2,
|
response_types={
|
||||||
|
0: _TResponse1,
|
||||||
|
1: _TResponse2,
|
||||||
},
|
},
|
||||||
trusted_client=True,
|
trusted_client=True,
|
||||||
log_remote_exceptions=False,
|
log_remote_exceptions=False,
|
||||||
@ -167,10 +169,16 @@ def test_protocol_creation() -> None:
|
|||||||
# This should fail because _TMessage1 can return _TResponse1 which
|
# This should fail because _TMessage1 can return _TResponse1 which
|
||||||
# is not given an id here.
|
# is not given an id here.
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_protocol = MessageProtocol(message_types={1: _TMessage1})
|
_protocol = MessageProtocol(
|
||||||
|
message_types={0: _TMessage1},
|
||||||
|
response_types={0: _TResponse2},
|
||||||
|
)
|
||||||
|
|
||||||
# Now it should work.
|
# Now it should work.
|
||||||
_protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
|
_protocol = MessageProtocol(
|
||||||
|
message_types={0: _TMessage1},
|
||||||
|
response_types={0: _TResponse1},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_sender_module_creation() -> None:
|
def test_sender_module_creation() -> None:
|
||||||
|
|||||||
@ -50,14 +50,23 @@ class Response:
|
|||||||
"""Base class for responses to messages."""
|
"""Base class for responses to messages."""
|
||||||
|
|
||||||
|
|
||||||
|
# Some standard response types:
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class RemoteErrorMessage(Response):
|
class RemoteErrorResponse(Response):
|
||||||
"""Message saying some error has occurred on the other end."""
|
"""Message saying some error has occurred on the other end."""
|
||||||
error_message: Annotated[str, IOAttrs('m')]
|
error_message: Annotated[str, IOAttrs('m')]
|
||||||
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
||||||
|
|
||||||
|
|
||||||
|
@ioprepped
|
||||||
|
@dataclass
|
||||||
|
class EmptyResponse(Response):
|
||||||
|
"""The response equivalent of None."""
|
||||||
|
|
||||||
|
|
||||||
class MessageProtocol:
|
class MessageProtocol:
|
||||||
"""Wrangles a set of message types, formats, and response types.
|
"""Wrangles a set of message types, formats, and response types.
|
||||||
Both endpoints must be using a compatible Protocol for communication
|
Both endpoints must be using a compatible Protocol for communication
|
||||||
@ -67,8 +76,8 @@ class MessageProtocol:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
message_types: Dict[int, Union[Type[Message],
|
message_types: Dict[int, Type[Message]],
|
||||||
Type[Response]]],
|
response_types: Dict[int, Type[Response]],
|
||||||
type_key: Optional[str] = None,
|
type_key: Optional[str] = None,
|
||||||
preserve_clean_errors: bool = True,
|
preserve_clean_errors: bool = True,
|
||||||
log_remote_exceptions: bool = True,
|
log_remote_exceptions: bool = True,
|
||||||
@ -88,10 +97,10 @@ class MessageProtocol:
|
|||||||
be included in the RemoteError. This should only be enabled in cases
|
be included in the RemoteError. This should only be enabled in cases
|
||||||
where the client is trusted.
|
where the client is trusted.
|
||||||
"""
|
"""
|
||||||
self.message_types_by_id: Dict[int, Union[Type[Message],
|
self.message_types_by_id: Dict[int, Type[Message]] = {}
|
||||||
Type[Response]]] = {}
|
self.message_ids_by_type: Dict[Type[Message], int] = {}
|
||||||
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
|
self.response_types_by_id: Dict[int, Type[Response]] = {}
|
||||||
int] = {}
|
self.response_ids_by_type: Dict[Type[Response], int] = {}
|
||||||
for m_id, m_type in message_types.items():
|
for m_id, m_type in message_types.items():
|
||||||
|
|
||||||
# Make sure only valid message types were passed and each
|
# Make sure only valid message types were passed and each
|
||||||
@ -99,30 +108,38 @@ class MessageProtocol:
|
|||||||
assert isinstance(m_id, int)
|
assert isinstance(m_id, int)
|
||||||
assert m_id >= 0
|
assert m_id >= 0
|
||||||
assert (is_ioprepped_dataclass(m_type)
|
assert (is_ioprepped_dataclass(m_type)
|
||||||
and issubclass(m_type, (Message, Response)))
|
and issubclass(m_type, Message))
|
||||||
assert self.message_types_by_id.get(m_id) is None
|
assert self.message_types_by_id.get(m_id) is None
|
||||||
|
|
||||||
self.message_types_by_id[m_id] = m_type
|
self.message_types_by_id[m_id] = m_type
|
||||||
self.message_ids_by_type[m_type] = m_id
|
self.message_ids_by_type[m_type] = m_id
|
||||||
|
|
||||||
|
for r_id, r_type in response_types.items():
|
||||||
|
assert isinstance(r_id, int)
|
||||||
|
assert r_id >= 0
|
||||||
|
assert (is_ioprepped_dataclass(r_type)
|
||||||
|
and issubclass(r_type, Response))
|
||||||
|
assert self.response_types_by_id.get(r_id) is None
|
||||||
|
self.response_types_by_id[r_id] = r_type
|
||||||
|
self.response_ids_by_type[r_type] = r_id
|
||||||
|
|
||||||
# Some extra-thorough validation in debug mode.
|
# Some extra-thorough validation in debug mode.
|
||||||
if __debug__:
|
if __debug__:
|
||||||
# Make sure all return types are valid and have been assigned
|
# Make sure all Message types' return types are valid
|
||||||
# an ID as well.
|
# and have been assigned an ID as well.
|
||||||
all_response_types: Set[Type[Response]] = set()
|
all_response_types: Set[Type[Response]] = set()
|
||||||
for m_id, m_type in message_types.items():
|
for m_id, m_type in message_types.items():
|
||||||
if issubclass(m_type, Message):
|
m_rtypes = m_type.get_response_types()
|
||||||
m_rtypes = m_type.get_response_types()
|
assert isinstance(m_rtypes, list)
|
||||||
assert isinstance(m_rtypes, list)
|
assert m_rtypes, (
|
||||||
assert m_rtypes # make sure not empty
|
f'Message type {m_type} specifies no return types.')
|
||||||
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
||||||
all_response_types.update(m_rtypes)
|
all_response_types.update(m_rtypes)
|
||||||
for cls in all_response_types:
|
for cls in all_response_types:
|
||||||
assert is_ioprepped_dataclass(cls) and issubclass(
|
assert is_ioprepped_dataclass(cls) and issubclass(
|
||||||
cls, (Message, Response))
|
cls, Response)
|
||||||
if cls not in self.message_ids_by_type:
|
if cls not in self.response_ids_by_type:
|
||||||
raise ValueError(f'Possible response type {cls}'
|
raise ValueError(f'Possible response type {cls}'
|
||||||
f' was not included in message_types.')
|
f' was not included in response_types.')
|
||||||
|
|
||||||
# Make sure all registered types have unique base names.
|
# Make sure all registered types have unique base names.
|
||||||
# We can take advantage of this to generate cleaner looking
|
# We can take advantage of this to generate cleaner looking
|
||||||
@ -138,23 +155,35 @@ class MessageProtocol:
|
|||||||
self.log_remote_exceptions = log_remote_exceptions
|
self.log_remote_exceptions = log_remote_exceptions
|
||||||
self.trusted_client = trusted_client
|
self.trusted_client = trusted_client
|
||||||
|
|
||||||
def message_encode(self,
|
def encode_message(self,
|
||||||
message: Union[Message, Response],
|
message: Message,
|
||||||
is_error: bool = False) -> bytes:
|
is_error: bool = False) -> bytes:
|
||||||
"""Encode a message to bytes for transport."""
|
"""Encode a message to bytes for transport."""
|
||||||
|
return self._encode(message, is_error, self.message_ids_by_type,
|
||||||
|
'message')
|
||||||
|
|
||||||
|
def encode_response(self,
|
||||||
|
response: Response,
|
||||||
|
is_error: bool = False) -> bytes:
|
||||||
|
"""Encode a response to bytes for transport."""
|
||||||
|
return self._encode(response, is_error, self.response_ids_by_type,
|
||||||
|
'response')
|
||||||
|
|
||||||
|
def _encode(self, message: Any, is_error: bool,
|
||||||
|
ids_by_type: Dict[Type, int], opname: str) -> bytes:
|
||||||
|
"""Encode a message to bytes for transport."""
|
||||||
|
|
||||||
m_id: Optional[int]
|
m_id: Optional[int]
|
||||||
if is_error:
|
if is_error:
|
||||||
m_id = -1
|
m_id = -1
|
||||||
else:
|
else:
|
||||||
m_id = self.message_ids_by_type.get(type(message))
|
m_id = ids_by_type.get(type(message))
|
||||||
if m_id is None:
|
if m_id is None:
|
||||||
raise TypeError(
|
raise TypeError(f'{opname} type is not registered in protocol:'
|
||||||
f'Message/Response type is not registered in Protocol:'
|
f' {type(message)}')
|
||||||
f' {type(message)}')
|
|
||||||
msgdict = dataclass_to_dict(message)
|
msgdict = dataclass_to_dict(message)
|
||||||
|
|
||||||
# Encode type as part of the message dict if desired
|
# Encode type as part of the message/response dict if desired
|
||||||
# (for legacy compatibility).
|
# (for legacy compatibility).
|
||||||
if self._type_key is not None:
|
if self._type_key is not None:
|
||||||
if self._type_key in msgdict:
|
if self._type_key in msgdict:
|
||||||
@ -166,12 +195,21 @@ class MessageProtocol:
|
|||||||
out = {'m': msgdict, 't': m_id}
|
out = {'m': msgdict, 't': m_id}
|
||||||
return json.dumps(out, separators=(',', ':')).encode()
|
return json.dumps(out, separators=(',', ':')).encode()
|
||||||
|
|
||||||
def message_decode(self, data: bytes) -> Union[Message, Response]:
|
def decode_message(self, data: bytes) -> Message:
|
||||||
"""Decode a message from bytes.
|
"""Decode a message from bytes."""
|
||||||
|
out = self._decode(data, self.message_types_by_id, 'message')
|
||||||
|
assert isinstance(out, Message)
|
||||||
|
return out
|
||||||
|
|
||||||
If the message represents a remote error, an Exception will
|
def decode_response(self, data: bytes) -> Response:
|
||||||
be raised.
|
"""Decode a response from bytes."""
|
||||||
"""
|
out = self._decode(data, self.response_types_by_id, 'response')
|
||||||
|
assert isinstance(out, Response)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
|
||||||
|
opname: str) -> Any:
|
||||||
|
"""Decode a message from bytes."""
|
||||||
msgfull = json.loads(data.decode())
|
msgfull = json.loads(data.decode())
|
||||||
assert isinstance(msgfull, dict)
|
assert isinstance(msgfull, dict)
|
||||||
msgdict: Optional[dict]
|
msgdict: Optional[dict]
|
||||||
@ -187,21 +225,18 @@ class MessageProtocol:
|
|||||||
|
|
||||||
# Special case: a remote error occurred. Raise a local Exception.
|
# Special case: a remote error occurred. Raise a local Exception.
|
||||||
if m_id == -1:
|
if m_id == -1:
|
||||||
err = dataclass_from_dict(RemoteErrorMessage, msgdict)
|
assert opname == 'response'
|
||||||
|
err = dataclass_from_dict(RemoteErrorResponse, msgdict)
|
||||||
if (self.preserve_clean_errors
|
if (self.preserve_clean_errors
|
||||||
and err.error_type is RemoteErrorType.CLEAN):
|
and err.error_type is RemoteErrorType.CLEAN):
|
||||||
raise CleanError(err.error_message)
|
raise CleanError(err.error_message)
|
||||||
raise RemoteError(err.error_message)
|
raise RemoteError(err.error_message)
|
||||||
|
|
||||||
# Decode this particular type and make sure its valid.
|
# Decode this particular type.
|
||||||
msgtype = self.message_types_by_id.get(m_id)
|
msgtype = types_by_id.get(m_id)
|
||||||
if msgtype is None:
|
if msgtype is None:
|
||||||
raise TypeError(
|
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
|
||||||
f'Got unregistered message/response type id of {m_id}.')
|
return dataclass_from_dict(msgtype, msgdict)
|
||||||
|
|
||||||
out = dataclass_from_dict(msgtype, msgdict)
|
|
||||||
assert isinstance(out, (Message, Response))
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _get_module_header(self, part: str) -> str:
|
def _get_module_header(self, part: str) -> str:
|
||||||
"""Return common parts of generated modules."""
|
"""Return common parts of generated modules."""
|
||||||
@ -287,8 +322,8 @@ class MessageProtocol:
|
|||||||
# Ew; @overload requires at least 2 different signatures so
|
# Ew; @overload requires at least 2 different signatures so
|
||||||
# we need to simply write a single function if we have < 2.
|
# we need to simply write a single function if we have < 2.
|
||||||
if len(msgtypes) == 1:
|
if len(msgtypes) == 1:
|
||||||
raise RuntimeError('FIXME: currently require at least 2'
|
raise RuntimeError('FIXME: currently we require at least 2'
|
||||||
' message types.')
|
' registered message types; found 1.')
|
||||||
if len(msgtypes) > 1:
|
if len(msgtypes) > 1:
|
||||||
for msgtype in msgtypes:
|
for msgtype in msgtypes:
|
||||||
msgtypevar = msgtype.__name__
|
msgtypevar = msgtype.__name__
|
||||||
@ -345,7 +380,7 @@ class MessageProtocol:
|
|||||||
# we need to simply write a single function if we have < 2.
|
# we need to simply write a single function if we have < 2.
|
||||||
if len(msgtypes) == 1:
|
if len(msgtypes) == 1:
|
||||||
raise RuntimeError('FIXME: currently require at least 2'
|
raise RuntimeError('FIXME: currently require at least 2'
|
||||||
' message types.')
|
' registered message types; found 1.')
|
||||||
if len(msgtypes) > 1:
|
if len(msgtypes) > 1:
|
||||||
for msgtype in msgtypes:
|
for msgtype in msgtypes:
|
||||||
msgtypevar = msgtype.__name__
|
msgtypevar = msgtype.__name__
|
||||||
@ -433,9 +468,9 @@ class MessageSender:
|
|||||||
if self._send_raw_message_call is None:
|
if self._send_raw_message_call is None:
|
||||||
raise RuntimeError('send() is unimplemented for this type.')
|
raise RuntimeError('send() is unimplemented for this type.')
|
||||||
|
|
||||||
msg_encoded = self._protocol.message_encode(message)
|
msg_encoded = self._protocol.encode_message(message)
|
||||||
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
||||||
response = self._protocol.message_decode(response_encoded)
|
response = self._protocol.decode_response(response_encoded)
|
||||||
assert isinstance(response, Response)
|
assert isinstance(response, Response)
|
||||||
assert type(response) in type(message).get_response_types()
|
assert type(response) in type(message).get_response_types()
|
||||||
return response
|
return response
|
||||||
@ -572,7 +607,7 @@ class MessageReceiver:
|
|||||||
"""Decode, handle, and return encoded response for a message."""
|
"""Decode, handle, and return encoded response for a message."""
|
||||||
try:
|
try:
|
||||||
# Decode the incoming message.
|
# Decode the incoming message.
|
||||||
msg_decoded = self._protocol.message_decode(msg)
|
msg_decoded = self._protocol.decode_message(msg)
|
||||||
msgtype = type(msg_decoded)
|
msgtype = type(msg_decoded)
|
||||||
assert issubclass(msgtype, Message)
|
assert issubclass(msgtype, Message)
|
||||||
|
|
||||||
@ -585,26 +620,27 @@ class MessageReceiver:
|
|||||||
# Re-encode the response.
|
# Re-encode the response.
|
||||||
assert isinstance(response, Response)
|
assert isinstance(response, Response)
|
||||||
assert type(response) in msgtype.get_response_types()
|
assert type(response) in msgtype.get_response_types()
|
||||||
return self._protocol.message_encode(response)
|
return self._protocol.encode_response(response)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
||||||
if self._protocol.log_remote_exceptions:
|
if self._protocol.log_remote_exceptions:
|
||||||
logging.exception('Error handling message.')
|
logging.exception('Error handling message.')
|
||||||
|
|
||||||
# If anything goes wrong, return a RemoteErrorMessage instead.
|
# If anything goes wrong, return a RemoteErrorResponse instead.
|
||||||
if (isinstance(exc, CleanError)
|
if (isinstance(exc, CleanError)
|
||||||
and self._protocol.preserve_clean_errors):
|
and self._protocol.preserve_clean_errors):
|
||||||
response = RemoteErrorMessage(error_message=str(exc),
|
err_response = RemoteErrorResponse(
|
||||||
error_type=RemoteErrorType.CLEAN)
|
error_message=str(exc), error_type=RemoteErrorType.CLEAN)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
response = RemoteErrorMessage(
|
err_response = RemoteErrorResponse(
|
||||||
error_message=(traceback.format_exc()
|
error_message=(traceback.format_exc()
|
||||||
if self._protocol.trusted_client else
|
if self._protocol.trusted_client else
|
||||||
'An unknown error has occurred.'),
|
'An unknown error has occurred.'),
|
||||||
error_type=RemoteErrorType.OTHER)
|
error_type=RemoteErrorType.OTHER)
|
||||||
return self._protocol.message_encode(response, is_error=True)
|
return self._protocol.encode_response(err_response, is_error=True)
|
||||||
|
|
||||||
async def handle_raw_message_async(self, msg: bytes) -> bytes:
|
async def handle_raw_message_async(self, msg: bytes) -> bytes:
|
||||||
"""Should be called when the receiver gets a message.
|
"""Should be called when the receiver gets a message.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user