mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-23 07:23:19 +08:00
now uses None in place of EmptyResponse for simplicity
This commit is contained in:
parent
6673f8dec6
commit
34df4658a6
1
.idea/dictionaries/ericf.xml
generated
1
.idea/dictionaries/ericf.xml
generated
@ -775,6 +775,7 @@
|
||||
<w>fileselector</w>
|
||||
<w>filesize</w>
|
||||
<w>filestates</w>
|
||||
<w>filt</w>
|
||||
<w>filterlines</w>
|
||||
<w>filterpath</w>
|
||||
<w>filterpaths</w>
|
||||
|
||||
@ -363,6 +363,7 @@
|
||||
<w>fieldpath</w>
|
||||
<w>fifteenbits</w>
|
||||
<w>filelock</w>
|
||||
<w>filt</w>
|
||||
<w>filterstr</w>
|
||||
<w>filterval</w>
|
||||
<w>finishedptr</w>
|
||||
|
||||
@ -14,10 +14,10 @@ from efrotools.statictest import static_type_equals
|
||||
from efro.error import CleanError, RemoteError
|
||||
from efro.dataclassio import ioprepped
|
||||
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||
MessageReceiver)
|
||||
MessageReceiver, EmptyResponse)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import List, Type, Any, Callable, Union
|
||||
from typing import List, Type, Any, Callable, Union, Optional
|
||||
|
||||
|
||||
@ioprepped
|
||||
@ -42,6 +42,13 @@ class _TMessage2(Message):
|
||||
return [_TResponse1, _TResponse2]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TMessage3(Message):
|
||||
"""Just testing."""
|
||||
sval: str
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TResponse1(Response):
|
||||
@ -91,7 +98,11 @@ class _BoundTestMessageSender:
|
||||
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||
...
|
||||
|
||||
def send(self, message: Message) -> Response:
|
||||
@overload
|
||||
def send(self, message: _TMessage3) -> None:
|
||||
...
|
||||
|
||||
def send(self, message: Message) -> Optional[Response]:
|
||||
"""Send a message."""
|
||||
return self._sender.send(self._obj, message)
|
||||
|
||||
@ -124,6 +135,13 @@ class _TestMessageReceiver(MessageReceiver):
|
||||
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMessage3], None],
|
||||
) -> Callable[[Any, _TMessage3], None]:
|
||||
...
|
||||
|
||||
def handler(self, call: Callable) -> Callable:
|
||||
"""Decorator to register message handlers."""
|
||||
self.register_handler(call)
|
||||
@ -153,10 +171,12 @@ TEST_PROTOCOL = MessageProtocol(
|
||||
message_types={
|
||||
0: _TMessage1,
|
||||
1: _TMessage2,
|
||||
2: _TMessage3,
|
||||
},
|
||||
response_types={
|
||||
0: _TResponse1,
|
||||
1: _TResponse2,
|
||||
2: EmptyResponse,
|
||||
},
|
||||
trusted_client=True,
|
||||
log_remote_exceptions=False,
|
||||
@ -292,17 +312,28 @@ def test_message_sending() -> None:
|
||||
del msg # Unused
|
||||
return _TResponse2(fval=1.2)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_3(self, msg: _TMessage3) -> None:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
|
||||
receiver.validate()
|
||||
|
||||
obj_r = TestClassR()
|
||||
obj_s = TestClassS(target=obj_r)
|
||||
|
||||
response = obj_s.msg.send(_TMessage1(ival=0))
|
||||
assert isinstance(response, _TResponse1)
|
||||
|
||||
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
|
||||
assert isinstance(response2, (_TResponse1, _TResponse2))
|
||||
|
||||
response3 = obj_s.msg.send(_TMessage3(sval='rah'))
|
||||
assert response3 is None
|
||||
|
||||
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||
response = obj_s.msg.send(_TMessage1(ival=0))
|
||||
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
|
||||
assert static_type_equals(response, _TResponse1)
|
||||
assert isinstance(response, _TResponse1)
|
||||
assert isinstance(response2, (_TResponse1, _TResponse2))
|
||||
assert static_type_equals(response3, None)
|
||||
|
||||
# Remote CleanErrors should come across locally as the same.
|
||||
try:
|
||||
|
||||
@ -28,22 +28,19 @@ if TYPE_CHECKING:
|
||||
TM = TypeVar('TM', bound='MessageSender')
|
||||
|
||||
|
||||
class RemoteErrorType(Enum):
|
||||
"""Type of error that occurred in remote message handling."""
|
||||
OTHER = 0
|
||||
CLEAN = 1
|
||||
|
||||
|
||||
class Message:
|
||||
"""Base class for messages."""
|
||||
|
||||
@classmethod
|
||||
def get_response_types(cls) -> List[Type[Response]]:
|
||||
"""Return all message types this Message can result in when sent.
|
||||
Note: RemoteErrorMessage is handled transparently and does not
|
||||
|
||||
The default implementation specifies EmptyResponse, so messages with
|
||||
no particular response needs can leave this untouched.
|
||||
Note that ErrorMessage is handled as a special case and does not
|
||||
need to be specified here.
|
||||
"""
|
||||
return []
|
||||
return [EmptyResponse]
|
||||
|
||||
|
||||
class Response:
|
||||
@ -53,16 +50,22 @@ class Response:
|
||||
# Some standard response types:
|
||||
|
||||
|
||||
class ErrorType(Enum):
|
||||
"""Type of error that occurred in remote message handling."""
|
||||
OTHER = 0
|
||||
CLEAN = 1
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class RemoteErrorResponse(Response):
|
||||
class ErrorResponse(Response):
|
||||
"""Message saying some error has occurred on the other end.
|
||||
|
||||
This type is unique in that it is not returned to the user; it
|
||||
instead results in a local exception being raised.
|
||||
"""
|
||||
error_message: Annotated[str, IOAttrs('m')]
|
||||
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
||||
error_type: Annotated[ErrorType, IOAttrs('t')]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@ -126,12 +129,12 @@ class MessageProtocol:
|
||||
self.response_types_by_id[r_id] = r_type
|
||||
self.response_ids_by_type[r_type] = r_id
|
||||
|
||||
# If they didn't register RemoteErrorResponse, do so with a special
|
||||
# If they didn't register ErrorResponse, do so with a special
|
||||
# -1 id which ensures it will never conflict with user messages.
|
||||
if RemoteErrorResponse not in self.response_ids_by_type:
|
||||
if ErrorResponse not in self.response_ids_by_type:
|
||||
assert self.response_types_by_id.get(-1) is None
|
||||
self.response_types_by_id[-1] = RemoteErrorResponse
|
||||
self.response_ids_by_type[RemoteErrorResponse] = -1
|
||||
self.response_types_by_id[-1] = ErrorResponse
|
||||
self.response_ids_by_type[ErrorResponse] = -1
|
||||
|
||||
# Some extra-thorough validation in debug mode.
|
||||
if __debug__:
|
||||
@ -146,8 +149,8 @@ class MessageProtocol:
|
||||
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
||||
all_response_types.update(m_rtypes)
|
||||
for cls in all_response_types:
|
||||
assert is_ioprepped_dataclass(cls) and issubclass(
|
||||
cls, Response)
|
||||
assert is_ioprepped_dataclass(cls)
|
||||
assert issubclass(cls, Response)
|
||||
if cls not in self.response_ids_by_type:
|
||||
raise ValueError(f'Possible response type {cls}'
|
||||
f' was not included in response_types.')
|
||||
@ -202,10 +205,10 @@ class MessageProtocol:
|
||||
assert isinstance(out, Message)
|
||||
return out
|
||||
|
||||
def decode_response(self, data: bytes) -> Response:
|
||||
def decode_response(self, data: bytes) -> Optional[Response]:
|
||||
"""Decode a response from bytes."""
|
||||
out = self._decode(data, self.response_types_by_id, 'response')
|
||||
assert isinstance(out, Response)
|
||||
assert isinstance(out, (Response, type(None)))
|
||||
return out
|
||||
|
||||
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
|
||||
@ -230,12 +233,16 @@ class MessageProtocol:
|
||||
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
|
||||
out = dataclass_from_dict(msgtype, msgdict)
|
||||
|
||||
# Special case: if we get EmptyResponse, we simply return None.
|
||||
if isinstance(out, EmptyResponse):
|
||||
return None
|
||||
|
||||
# Special case: a remote error occurred. Raise a local Exception
|
||||
# instead of returning the message.
|
||||
if isinstance(out, RemoteErrorResponse):
|
||||
if isinstance(out, ErrorResponse):
|
||||
assert opname == 'response'
|
||||
if (self.preserve_clean_errors
|
||||
and out.error_type is RemoteErrorType.CLEAN):
|
||||
and out.error_type is ErrorType.CLEAN):
|
||||
raise CleanError(out.error_message)
|
||||
raise RemoteError(out.error_message)
|
||||
|
||||
@ -327,22 +334,29 @@ class MessageProtocol:
|
||||
if len(msgtypes) == 1:
|
||||
raise RuntimeError('FIXME: currently we require at least 2'
|
||||
' registered message types; found 1.')
|
||||
|
||||
def _filt_tp_name(rtype: Type[Response]) -> str:
|
||||
# We accept None to equal EmptyResponse so reflect that
|
||||
# in the type annotation.
|
||||
return 'None' if rtype is EmptyResponse else rtype.__name__
|
||||
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
if len(rtypes) > 1:
|
||||
tps = ', '.join(t.__name__ for t in rtypes)
|
||||
responsetypevar = f'Union[{tps}]'
|
||||
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
|
||||
rtypevar = f'Union[{tps}]'
|
||||
else:
|
||||
responsetypevar = rtypes[0].__name__
|
||||
rtypevar = _filt_tp_name(rtypes[0])
|
||||
out += (f'\n'
|
||||
f' @overload\n'
|
||||
f' def send(self, message: {msgtypevar})'
|
||||
f' -> {responsetypevar}:\n'
|
||||
f' -> {rtypevar}:\n'
|
||||
f' ...\n')
|
||||
out += ('\n'
|
||||
' def send(self, message: Message) -> Response:\n'
|
||||
' def send(self, message: Message)'
|
||||
' -> Optional[Response]:\n'
|
||||
' """Send a message."""\n'
|
||||
' return self._sender.send(self._obj, message)\n')
|
||||
|
||||
@ -384,15 +398,21 @@ class MessageProtocol:
|
||||
if len(msgtypes) == 1:
|
||||
raise RuntimeError('FIXME: currently require at least 2'
|
||||
' registered message types; found 1.')
|
||||
|
||||
def _filt_tp_name(rtype: Type[Response]) -> str:
|
||||
# We accept None to equal EmptyResponse so reflect that
|
||||
# in the type annotation.
|
||||
return 'None' if rtype is EmptyResponse else rtype.__name__
|
||||
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
if len(rtypes) > 1:
|
||||
tps = ', '.join(t.__name__ for t in rtypes)
|
||||
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
|
||||
rtypevar = f'Union[{tps}]'
|
||||
else:
|
||||
rtypevar = rtypes[0].__name__
|
||||
rtypevar = _filt_tp_name(rtypes[0])
|
||||
out += (
|
||||
f'\n'
|
||||
f' @overload\n'
|
||||
@ -463,7 +483,7 @@ class MessageSender:
|
||||
self._send_raw_message_call = call
|
||||
return call
|
||||
|
||||
def send(self, bound_obj: Any, message: Message) -> Response:
|
||||
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
|
||||
"""Send a message and receive a response.
|
||||
|
||||
Will encode the message for transport and call dispatch_raw_message()
|
||||
@ -474,8 +494,9 @@ class MessageSender:
|
||||
msg_encoded = self._protocol.encode_message(message)
|
||||
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
||||
response = self._protocol.decode_response(response_encoded)
|
||||
assert isinstance(response, Response)
|
||||
assert type(response) in type(message).get_response_types()
|
||||
assert isinstance(response, (Response, type(None)))
|
||||
assert (response is None
|
||||
or type(response) in type(message).get_response_types())
|
||||
return response
|
||||
|
||||
def send_bg(self, bound_obj: Any, message: Message) -> Message:
|
||||
@ -555,7 +576,7 @@ class MessageReceiver:
|
||||
assert issubclass(msgtype, Message)
|
||||
|
||||
ret = anns.get('return')
|
||||
responsetypes: Tuple[Type, ...]
|
||||
responsetypes: Tuple[Union[Type[Any], Type[None]], ...]
|
||||
|
||||
# Return types can be a single type or a union of types.
|
||||
if isinstance(ret, _GenericAlias):
|
||||
@ -570,6 +591,10 @@ class MessageReceiver:
|
||||
f' "return" annotation; got a {type(ret)}.')
|
||||
responsetypes = (ret, )
|
||||
|
||||
# Return type of None translates to EmptyResponse.
|
||||
responsetypes = tuple(EmptyResponse if r is type(None) else r
|
||||
for r in responsetypes)
|
||||
|
||||
# Make sure our protocol has this message type registered and our
|
||||
# return types exactly match. (Technically we could return a subset
|
||||
# of the supported types; can allow this in the future if it makes
|
||||
@ -607,7 +632,7 @@ class MessageReceiver:
|
||||
raise TypeError(msg)
|
||||
|
||||
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
|
||||
"""Decode, handle, and return encoded response for a message."""
|
||||
"""Decode, handle, and return an encoded response for a message."""
|
||||
try:
|
||||
# Decode the incoming message.
|
||||
msg_decoded = self._protocol.decode_message(msg)
|
||||
@ -620,8 +645,14 @@ class MessageReceiver:
|
||||
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
|
||||
response = handler(bound_obj, msg_decoded)
|
||||
|
||||
# A return value of None equals EmptyResponse.
|
||||
if response is None:
|
||||
response = EmptyResponse()
|
||||
|
||||
# Re-encode the response.
|
||||
assert isinstance(response, Response)
|
||||
# (user should never explicitly return these)
|
||||
assert not isinstance(response, ErrorResponse)
|
||||
assert type(response) in msgtype.get_response_types()
|
||||
return self._protocol.encode_response(response)
|
||||
|
||||
@ -630,19 +661,19 @@ class MessageReceiver:
|
||||
if self._protocol.log_remote_exceptions:
|
||||
logging.exception('Error handling message.')
|
||||
|
||||
# If anything goes wrong, return a RemoteErrorResponse instead.
|
||||
# If anything goes wrong, return a ErrorResponse instead.
|
||||
if (isinstance(exc, CleanError)
|
||||
and self._protocol.preserve_clean_errors):
|
||||
err_response = RemoteErrorResponse(
|
||||
error_message=str(exc), error_type=RemoteErrorType.CLEAN)
|
||||
err_response = ErrorResponse(error_message=str(exc),
|
||||
error_type=ErrorType.CLEAN)
|
||||
|
||||
else:
|
||||
|
||||
err_response = RemoteErrorResponse(
|
||||
err_response = ErrorResponse(
|
||||
error_message=(traceback.format_exc()
|
||||
if self._protocol.trusted_client else
|
||||
'An unknown error has occurred.'),
|
||||
error_type=RemoteErrorType.OTHER)
|
||||
error_type=ErrorType.OTHER)
|
||||
return self._protocol.encode_response(err_response)
|
||||
|
||||
async def handle_raw_message_async(self, msg: bytes) -> bytes:
|
||||
|
||||
@ -189,6 +189,13 @@ def format_yapf(projroot: Path, full: bool) -> None:
|
||||
flush=True)
|
||||
|
||||
|
||||
def format_yapf_text(projroot: Path, code: str) -> str:
|
||||
"""Run yapf formatting on the provided code."""
|
||||
del projroot # Unused.
|
||||
print('WOULD DO YAPF')
|
||||
return code
|
||||
|
||||
|
||||
def _should_include_script(fnamefull: str) -> bool:
|
||||
fname = os.path.basename(fnamefull)
|
||||
|
||||
|
||||
@ -155,7 +155,7 @@ class StaticTestFile:
|
||||
return '\n'.join(lines_out) + '\n'
|
||||
|
||||
|
||||
def static_type_equals(value: Any, statictype: Union[Type, str]) -> bool:
|
||||
def static_type_equals(value: Any, statictype: Union[Type, None, str]) -> bool:
|
||||
"""Check a type statically using mypy.
|
||||
|
||||
If a string is passed as statictype, it is checked against the mypy
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user