now uses None in place of EmptyResponse for simplicity

This commit is contained in:
Eric Froemling 2021-09-09 10:07:28 -05:00
parent 6673f8dec6
commit 34df4658a6
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
6 changed files with 116 additions and 45 deletions

View File

@ -775,6 +775,7 @@
<w>fileselector</w>
<w>filesize</w>
<w>filestates</w>
<w>filt</w>
<w>filterlines</w>
<w>filterpath</w>
<w>filterpaths</w>

View File

@ -363,6 +363,7 @@
<w>fieldpath</w>
<w>fifteenbits</w>
<w>filelock</w>
<w>filt</w>
<w>filterstr</w>
<w>filterval</w>
<w>finishedptr</w>

View File

@ -14,10 +14,10 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver)
MessageReceiver, EmptyResponse)
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union
from typing import List, Type, Any, Callable, Union, Optional
@ioprepped
@ -42,6 +42,13 @@ class _TMessage2(Message):
return [_TResponse1, _TResponse2]
@ioprepped
@dataclass
class _TMessage3(Message):
"""Just testing."""
sval: str
@ioprepped
@dataclass
class _TResponse1(Response):
@ -91,7 +98,11 @@ class _BoundTestMessageSender:
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
...
def send(self, message: Message) -> Response:
@overload
def send(self, message: _TMessage3) -> None:
...
def send(self, message: Message) -> Optional[Response]:
"""Send a message."""
return self._sender.send(self._obj, message)
@ -124,6 +135,13 @@ class _TestMessageReceiver(MessageReceiver):
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
...
@overload
def handler(
self,
call: Callable[[Any, _TMessage3], None],
) -> Callable[[Any, _TMessage3], None]:
...
def handler(self, call: Callable) -> Callable:
"""Decorator to register message handlers."""
self.register_handler(call)
@ -153,10 +171,12 @@ TEST_PROTOCOL = MessageProtocol(
message_types={
0: _TMessage1,
1: _TMessage2,
2: _TMessage3,
},
response_types={
0: _TResponse1,
1: _TResponse2,
2: EmptyResponse,
},
trusted_client=True,
log_remote_exceptions=False,
@ -292,17 +312,28 @@ def test_message_sending() -> None:
del msg # Unused
return _TResponse2(fval=1.2)
@receiver.handler
def handle_test_message_3(self, msg: _TMessage3) -> None:
"""Test."""
del msg # Unused
receiver.validate()
obj_r = TestClassR()
obj_s = TestClassS(target=obj_r)
response = obj_s.msg.send(_TMessage1(ival=0))
assert isinstance(response, _TResponse1)
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
assert isinstance(response2, (_TResponse1, _TResponse2))
response3 = obj_s.msg.send(_TMessage3(sval='rah'))
assert response3 is None
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
response = obj_s.msg.send(_TMessage1(ival=0))
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
assert static_type_equals(response, _TResponse1)
assert isinstance(response, _TResponse1)
assert isinstance(response2, (_TResponse1, _TResponse2))
assert static_type_equals(response3, None)
# Remote CleanErrors should come across locally as the same.
try:

View File

@ -28,22 +28,19 @@ if TYPE_CHECKING:
TM = TypeVar('TM', bound='MessageSender')
class RemoteErrorType(Enum):
"""Type of error that occurred in remote message handling."""
OTHER = 0
CLEAN = 1
class Message:
"""Base class for messages."""
@classmethod
def get_response_types(cls) -> List[Type[Response]]:
"""Return all message types this Message can result in when sent.
Note: RemoteErrorMessage is handled transparently and does not
The default implementation specifies EmptyResponse, so messages with
no particular response needs can leave this untouched.
Note that ErrorMessage is handled as a special case and does not
need to be specified here.
"""
return []
return [EmptyResponse]
class Response:
@ -53,16 +50,22 @@ class Response:
# Some standard response types:
class ErrorType(Enum):
"""Type of error that occurred in remote message handling."""
OTHER = 0
CLEAN = 1
@ioprepped
@dataclass
class RemoteErrorResponse(Response):
class ErrorResponse(Response):
"""Message saying some error has occurred on the other end.
This type is unique in that it is not returned to the user; it
instead results in a local exception being raised.
"""
error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
error_type: Annotated[ErrorType, IOAttrs('t')]
@ioprepped
@ -126,12 +129,12 @@ class MessageProtocol:
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
# If they didn't register RemoteErrorResponse, do so with a special
# If they didn't register ErrorResponse, do so with a special
# -1 id which ensures it will never conflict with user messages.
if RemoteErrorResponse not in self.response_ids_by_type:
if ErrorResponse not in self.response_ids_by_type:
assert self.response_types_by_id.get(-1) is None
self.response_types_by_id[-1] = RemoteErrorResponse
self.response_ids_by_type[RemoteErrorResponse] = -1
self.response_types_by_id[-1] = ErrorResponse
self.response_ids_by_type[ErrorResponse] = -1
# Some extra-thorough validation in debug mode.
if __debug__:
@ -146,8 +149,8 @@ class MessageProtocol:
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes)
for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass(
cls, Response)
assert is_ioprepped_dataclass(cls)
assert issubclass(cls, Response)
if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}'
f' was not included in response_types.')
@ -202,10 +205,10 @@ class MessageProtocol:
assert isinstance(out, Message)
return out
def decode_response(self, data: bytes) -> Response:
def decode_response(self, data: bytes) -> Optional[Response]:
"""Decode a response from bytes."""
out = self._decode(data, self.response_types_by_id, 'response')
assert isinstance(out, Response)
assert isinstance(out, (Response, type(None)))
return out
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
@ -230,12 +233,16 @@ class MessageProtocol:
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
out = dataclass_from_dict(msgtype, msgdict)
# Special case: if we get EmptyResponse, we simply return None.
if isinstance(out, EmptyResponse):
return None
# Special case: a remote error occurred. Raise a local Exception
# instead of returning the message.
if isinstance(out, RemoteErrorResponse):
if isinstance(out, ErrorResponse):
assert opname == 'response'
if (self.preserve_clean_errors
and out.error_type is RemoteErrorType.CLEAN):
and out.error_type is ErrorType.CLEAN):
raise CleanError(out.error_message)
raise RemoteError(out.error_message)
@ -327,22 +334,29 @@ class MessageProtocol:
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently we require at least 2'
' registered message types; found 1.')
def _filt_tp_name(rtype: Type[Response]) -> str:
# We accept None to equal EmptyResponse so reflect that
# in the type annotation.
return 'None' if rtype is EmptyResponse else rtype.__name__
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(t.__name__ for t in rtypes)
responsetypevar = f'Union[{tps}]'
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
responsetypevar = rtypes[0].__name__
rtypevar = _filt_tp_name(rtypes[0])
out += (f'\n'
f' @overload\n'
f' def send(self, message: {msgtypevar})'
f' -> {responsetypevar}:\n'
f' -> {rtypevar}:\n'
f' ...\n')
out += ('\n'
' def send(self, message: Message) -> Response:\n'
' def send(self, message: Message)'
' -> Optional[Response]:\n'
' """Send a message."""\n'
' return self._sender.send(self._obj, message)\n')
@ -384,15 +398,21 @@ class MessageProtocol:
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
' registered message types; found 1.')
def _filt_tp_name(rtype: Type[Response]) -> str:
# We accept None to equal EmptyResponse so reflect that
# in the type annotation.
return 'None' if rtype is EmptyResponse else rtype.__name__
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(t.__name__ for t in rtypes)
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = rtypes[0].__name__
rtypevar = _filt_tp_name(rtypes[0])
out += (
f'\n'
f' @overload\n'
@ -463,7 +483,7 @@ class MessageSender:
self._send_raw_message_call = call
return call
def send(self, bound_obj: Any, message: Message) -> Response:
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
"""Send a message and receive a response.
Will encode the message for transport and call dispatch_raw_message()
@ -474,8 +494,9 @@ class MessageSender:
msg_encoded = self._protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.decode_response(response_encoded)
assert isinstance(response, Response)
assert type(response) in type(message).get_response_types()
assert isinstance(response, (Response, type(None)))
assert (response is None
or type(response) in type(message).get_response_types())
return response
def send_bg(self, bound_obj: Any, message: Message) -> Message:
@ -555,7 +576,7 @@ class MessageReceiver:
assert issubclass(msgtype, Message)
ret = anns.get('return')
responsetypes: Tuple[Type, ...]
responsetypes: Tuple[Union[Type[Any], Type[None]], ...]
# Return types can be a single type or a union of types.
if isinstance(ret, _GenericAlias):
@ -570,6 +591,10 @@ class MessageReceiver:
f' "return" annotation; got a {type(ret)}.')
responsetypes = (ret, )
# Return type of None translates to EmptyResponse.
responsetypes = tuple(EmptyResponse if r is type(None) else r
for r in responsetypes)
# Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset
# of the supported types; can allow this in the future if it makes
@ -607,7 +632,7 @@ class MessageReceiver:
raise TypeError(msg)
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
"""Decode, handle, and return encoded response for a message."""
"""Decode, handle, and return an encoded response for a message."""
try:
# Decode the incoming message.
msg_decoded = self._protocol.decode_message(msg)
@ -620,8 +645,14 @@ class MessageReceiver:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
response = handler(bound_obj, msg_decoded)
# A return value of None equals EmptyResponse.
if response is None:
response = EmptyResponse()
# Re-encode the response.
assert isinstance(response, Response)
# (user should never explicitly return these)
assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types()
return self._protocol.encode_response(response)
@ -630,19 +661,19 @@ class MessageReceiver:
if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.')
# If anything goes wrong, return a RemoteErrorResponse instead.
# If anything goes wrong, return a ErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
err_response = RemoteErrorResponse(
error_message=str(exc), error_type=RemoteErrorType.CLEAN)
err_response = ErrorResponse(error_message=str(exc),
error_type=ErrorType.CLEAN)
else:
err_response = RemoteErrorResponse(
err_response = ErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_client else
'An unknown error has occurred.'),
error_type=RemoteErrorType.OTHER)
error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response)
async def handle_raw_message_async(self, msg: bytes) -> bytes:

View File

@ -189,6 +189,13 @@ def format_yapf(projroot: Path, full: bool) -> None:
flush=True)
def format_yapf_text(projroot: Path, code: str) -> str:
"""Run yapf formatting on the provided code."""
del projroot # Unused.
print('WOULD DO YAPF')
return code
def _should_include_script(fnamefull: str) -> bool:
fname = os.path.basename(fnamefull)

View File

@ -155,7 +155,7 @@ class StaticTestFile:
return '\n'.join(lines_out) + '\n'
def static_type_equals(value: Any, statictype: Union[Type, str]) -> bool:
def static_type_equals(value: Any, statictype: Union[Type, None, str]) -> bool:
"""Check a type statically using mypy.
If a string is passed as statictype, it is checked against the mypy