now uses None in place of EmptyResponse for simplicity

This commit is contained in:
Eric Froemling 2021-09-09 10:07:28 -05:00
parent 6673f8dec6
commit 34df4658a6
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
6 changed files with 116 additions and 45 deletions

View File

@ -775,6 +775,7 @@
<w>fileselector</w> <w>fileselector</w>
<w>filesize</w> <w>filesize</w>
<w>filestates</w> <w>filestates</w>
<w>filt</w>
<w>filterlines</w> <w>filterlines</w>
<w>filterpath</w> <w>filterpath</w>
<w>filterpaths</w> <w>filterpaths</w>

View File

@ -363,6 +363,7 @@
<w>fieldpath</w> <w>fieldpath</w>
<w>fifteenbits</w> <w>fifteenbits</w>
<w>filelock</w> <w>filelock</w>
<w>filt</w>
<w>filterstr</w> <w>filterstr</w>
<w>filterval</w> <w>filterval</w>
<w>finishedptr</w> <w>finishedptr</w>

View File

@ -14,10 +14,10 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender, from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver) MessageReceiver, EmptyResponse)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union from typing import List, Type, Any, Callable, Union, Optional
@ioprepped @ioprepped
@ -42,6 +42,13 @@ class _TMessage2(Message):
return [_TResponse1, _TResponse2] return [_TResponse1, _TResponse2]
@ioprepped
@dataclass
class _TMessage3(Message):
"""Just testing."""
sval: str
@ioprepped @ioprepped
@dataclass @dataclass
class _TResponse1(Response): class _TResponse1(Response):
@ -91,7 +98,11 @@ class _BoundTestMessageSender:
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]: def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
... ...
def send(self, message: Message) -> Response: @overload
def send(self, message: _TMessage3) -> None:
...
def send(self, message: Message) -> Optional[Response]:
"""Send a message.""" """Send a message."""
return self._sender.send(self._obj, message) return self._sender.send(self._obj, message)
@ -124,6 +135,13 @@ class _TestMessageReceiver(MessageReceiver):
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]: ) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
... ...
@overload
def handler(
self,
call: Callable[[Any, _TMessage3], None],
) -> Callable[[Any, _TMessage3], None]:
...
def handler(self, call: Callable) -> Callable: def handler(self, call: Callable) -> Callable:
"""Decorator to register message handlers.""" """Decorator to register message handlers."""
self.register_handler(call) self.register_handler(call)
@ -153,10 +171,12 @@ TEST_PROTOCOL = MessageProtocol(
message_types={ message_types={
0: _TMessage1, 0: _TMessage1,
1: _TMessage2, 1: _TMessage2,
2: _TMessage3,
}, },
response_types={ response_types={
0: _TResponse1, 0: _TResponse1,
1: _TResponse2, 1: _TResponse2,
2: EmptyResponse,
}, },
trusted_client=True, trusted_client=True,
log_remote_exceptions=False, log_remote_exceptions=False,
@ -292,17 +312,28 @@ def test_message_sending() -> None:
del msg # Unused del msg # Unused
return _TResponse2(fval=1.2) return _TResponse2(fval=1.2)
@receiver.handler
def handle_test_message_3(self, msg: _TMessage3) -> None:
"""Test."""
del msg # Unused
receiver.validate() receiver.validate()
obj_r = TestClassR() obj_r = TestClassR()
obj_s = TestClassS(target=obj_r) obj_s = TestClassS(target=obj_r)
response = obj_s.msg.send(_TMessage1(ival=0))
assert isinstance(response, _TResponse1)
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
assert isinstance(response2, (_TResponse1, _TResponse2))
response3 = obj_s.msg.send(_TMessage3(sval='rah'))
assert response3 is None
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1': if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
response = obj_s.msg.send(_TMessage1(ival=0))
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
assert static_type_equals(response, _TResponse1) assert static_type_equals(response, _TResponse1)
assert isinstance(response, _TResponse1) assert static_type_equals(response3, None)
assert isinstance(response2, (_TResponse1, _TResponse2))
# Remote CleanErrors should come across locally as the same. # Remote CleanErrors should come across locally as the same.
try: try:

View File

@ -28,22 +28,19 @@ if TYPE_CHECKING:
TM = TypeVar('TM', bound='MessageSender') TM = TypeVar('TM', bound='MessageSender')
class RemoteErrorType(Enum):
"""Type of error that occurred in remote message handling."""
OTHER = 0
CLEAN = 1
class Message: class Message:
"""Base class for messages.""" """Base class for messages."""
@classmethod @classmethod
def get_response_types(cls) -> List[Type[Response]]: def get_response_types(cls) -> List[Type[Response]]:
"""Return all message types this Message can result in when sent. """Return all message types this Message can result in when sent.
Note: RemoteErrorMessage is handled transparently and does not
The default implementation specifies EmptyResponse, so messages with
no particular response needs can leave this untouched.
Note that ErrorMessage is handled as a special case and does not
need to be specified here. need to be specified here.
""" """
return [] return [EmptyResponse]
class Response: class Response:
@ -53,16 +50,22 @@ class Response:
# Some standard response types: # Some standard response types:
class ErrorType(Enum):
"""Type of error that occurred in remote message handling."""
OTHER = 0
CLEAN = 1
@ioprepped @ioprepped
@dataclass @dataclass
class RemoteErrorResponse(Response): class ErrorResponse(Response):
"""Message saying some error has occurred on the other end. """Message saying some error has occurred on the other end.
This type is unique in that it is not returned to the user; it This type is unique in that it is not returned to the user; it
instead results in a local exception being raised. instead results in a local exception being raised.
""" """
error_message: Annotated[str, IOAttrs('m')] error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[RemoteErrorType, IOAttrs('t')] error_type: Annotated[ErrorType, IOAttrs('t')]
@ioprepped @ioprepped
@ -126,12 +129,12 @@ class MessageProtocol:
self.response_types_by_id[r_id] = r_type self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id self.response_ids_by_type[r_type] = r_id
# If they didn't register RemoteErrorResponse, do so with a special # If they didn't register ErrorResponse, do so with a special
# -1 id which ensures it will never conflict with user messages. # -1 id which ensures it will never conflict with user messages.
if RemoteErrorResponse not in self.response_ids_by_type: if ErrorResponse not in self.response_ids_by_type:
assert self.response_types_by_id.get(-1) is None assert self.response_types_by_id.get(-1) is None
self.response_types_by_id[-1] = RemoteErrorResponse self.response_types_by_id[-1] = ErrorResponse
self.response_ids_by_type[RemoteErrorResponse] = -1 self.response_ids_by_type[ErrorResponse] = -1
# Some extra-thorough validation in debug mode. # Some extra-thorough validation in debug mode.
if __debug__: if __debug__:
@ -146,8 +149,8 @@ class MessageProtocol:
assert len(set(m_rtypes)) == len(m_rtypes) # check dups assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes) all_response_types.update(m_rtypes)
for cls in all_response_types: for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass( assert is_ioprepped_dataclass(cls)
cls, Response) assert issubclass(cls, Response)
if cls not in self.response_ids_by_type: if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}' raise ValueError(f'Possible response type {cls}'
f' was not included in response_types.') f' was not included in response_types.')
@ -202,10 +205,10 @@ class MessageProtocol:
assert isinstance(out, Message) assert isinstance(out, Message)
return out return out
def decode_response(self, data: bytes) -> Response: def decode_response(self, data: bytes) -> Optional[Response]:
"""Decode a response from bytes.""" """Decode a response from bytes."""
out = self._decode(data, self.response_types_by_id, 'response') out = self._decode(data, self.response_types_by_id, 'response')
assert isinstance(out, Response) assert isinstance(out, (Response, type(None)))
return out return out
def _decode(self, data: bytes, types_by_id: Dict[int, Type], def _decode(self, data: bytes, types_by_id: Dict[int, Type],
@ -230,12 +233,16 @@ class MessageProtocol:
raise TypeError(f'Got unregistered {opname} type id of {m_id}.') raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
out = dataclass_from_dict(msgtype, msgdict) out = dataclass_from_dict(msgtype, msgdict)
# Special case: if we get EmptyResponse, we simply return None.
if isinstance(out, EmptyResponse):
return None
# Special case: a remote error occurred. Raise a local Exception # Special case: a remote error occurred. Raise a local Exception
# instead of returning the message. # instead of returning the message.
if isinstance(out, RemoteErrorResponse): if isinstance(out, ErrorResponse):
assert opname == 'response' assert opname == 'response'
if (self.preserve_clean_errors if (self.preserve_clean_errors
and out.error_type is RemoteErrorType.CLEAN): and out.error_type is ErrorType.CLEAN):
raise CleanError(out.error_message) raise CleanError(out.error_message)
raise RemoteError(out.error_message) raise RemoteError(out.error_message)
@ -327,22 +334,29 @@ class MessageProtocol:
if len(msgtypes) == 1: if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently we require at least 2' raise RuntimeError('FIXME: currently we require at least 2'
' registered message types; found 1.') ' registered message types; found 1.')
def _filt_tp_name(rtype: Type[Response]) -> str:
# We accept None to equal EmptyResponse so reflect that
# in the type annotation.
return 'None' if rtype is EmptyResponse else rtype.__name__
if len(msgtypes) > 1: if len(msgtypes) > 1:
for msgtype in msgtypes: for msgtype in msgtypes:
msgtypevar = msgtype.__name__ msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types() rtypes = msgtype.get_response_types()
if len(rtypes) > 1: if len(rtypes) > 1:
tps = ', '.join(t.__name__ for t in rtypes) tps = ', '.join(_filt_tp_name(t) for t in rtypes)
responsetypevar = f'Union[{tps}]' rtypevar = f'Union[{tps}]'
else: else:
responsetypevar = rtypes[0].__name__ rtypevar = _filt_tp_name(rtypes[0])
out += (f'\n' out += (f'\n'
f' @overload\n' f' @overload\n'
f' def send(self, message: {msgtypevar})' f' def send(self, message: {msgtypevar})'
f' -> {responsetypevar}:\n' f' -> {rtypevar}:\n'
f' ...\n') f' ...\n')
out += ('\n' out += ('\n'
' def send(self, message: Message) -> Response:\n' ' def send(self, message: Message)'
' -> Optional[Response]:\n'
' """Send a message."""\n' ' """Send a message."""\n'
' return self._sender.send(self._obj, message)\n') ' return self._sender.send(self._obj, message)\n')
@ -384,15 +398,21 @@ class MessageProtocol:
if len(msgtypes) == 1: if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2' raise RuntimeError('FIXME: currently require at least 2'
' registered message types; found 1.') ' registered message types; found 1.')
def _filt_tp_name(rtype: Type[Response]) -> str:
# We accept None to equal EmptyResponse so reflect that
# in the type annotation.
return 'None' if rtype is EmptyResponse else rtype.__name__
if len(msgtypes) > 1: if len(msgtypes) > 1:
for msgtype in msgtypes: for msgtype in msgtypes:
msgtypevar = msgtype.__name__ msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types() rtypes = msgtype.get_response_types()
if len(rtypes) > 1: if len(rtypes) > 1:
tps = ', '.join(t.__name__ for t in rtypes) tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]' rtypevar = f'Union[{tps}]'
else: else:
rtypevar = rtypes[0].__name__ rtypevar = _filt_tp_name(rtypes[0])
out += ( out += (
f'\n' f'\n'
f' @overload\n' f' @overload\n'
@ -463,7 +483,7 @@ class MessageSender:
self._send_raw_message_call = call self._send_raw_message_call = call
return call return call
def send(self, bound_obj: Any, message: Message) -> Response: def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
"""Send a message and receive a response. """Send a message and receive a response.
Will encode the message for transport and call dispatch_raw_message() Will encode the message for transport and call dispatch_raw_message()
@ -474,8 +494,9 @@ class MessageSender:
msg_encoded = self._protocol.encode_message(message) msg_encoded = self._protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded) response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.decode_response(response_encoded) response = self._protocol.decode_response(response_encoded)
assert isinstance(response, Response) assert isinstance(response, (Response, type(None)))
assert type(response) in type(message).get_response_types() assert (response is None
or type(response) in type(message).get_response_types())
return response return response
def send_bg(self, bound_obj: Any, message: Message) -> Message: def send_bg(self, bound_obj: Any, message: Message) -> Message:
@ -555,7 +576,7 @@ class MessageReceiver:
assert issubclass(msgtype, Message) assert issubclass(msgtype, Message)
ret = anns.get('return') ret = anns.get('return')
responsetypes: Tuple[Type, ...] responsetypes: Tuple[Union[Type[Any], Type[None]], ...]
# Return types can be a single type or a union of types. # Return types can be a single type or a union of types.
if isinstance(ret, _GenericAlias): if isinstance(ret, _GenericAlias):
@ -570,6 +591,10 @@ class MessageReceiver:
f' "return" annotation; got a {type(ret)}.') f' "return" annotation; got a {type(ret)}.')
responsetypes = (ret, ) responsetypes = (ret, )
# Return type of None translates to EmptyResponse.
responsetypes = tuple(EmptyResponse if r is type(None) else r
for r in responsetypes)
# Make sure our protocol has this message type registered and our # Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset # return types exactly match. (Technically we could return a subset
# of the supported types; can allow this in the future if it makes # of the supported types; can allow this in the future if it makes
@ -607,7 +632,7 @@ class MessageReceiver:
raise TypeError(msg) raise TypeError(msg)
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes: def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
"""Decode, handle, and return encoded response for a message.""" """Decode, handle, and return an encoded response for a message."""
try: try:
# Decode the incoming message. # Decode the incoming message.
msg_decoded = self._protocol.decode_message(msg) msg_decoded = self._protocol.decode_message(msg)
@ -620,8 +645,14 @@ class MessageReceiver:
raise RuntimeError(f'Got unhandled message type: {msgtype}.') raise RuntimeError(f'Got unhandled message type: {msgtype}.')
response = handler(bound_obj, msg_decoded) response = handler(bound_obj, msg_decoded)
# A return value of None equals EmptyResponse.
if response is None:
response = EmptyResponse()
# Re-encode the response. # Re-encode the response.
assert isinstance(response, Response) assert isinstance(response, Response)
# (user should never explicitly return these)
assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types() assert type(response) in msgtype.get_response_types()
return self._protocol.encode_response(response) return self._protocol.encode_response(response)
@ -630,19 +661,19 @@ class MessageReceiver:
if self._protocol.log_remote_exceptions: if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.') logging.exception('Error handling message.')
# If anything goes wrong, return a RemoteErrorResponse instead. # If anything goes wrong, return a ErrorResponse instead.
if (isinstance(exc, CleanError) if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors): and self._protocol.preserve_clean_errors):
err_response = RemoteErrorResponse( err_response = ErrorResponse(error_message=str(exc),
error_message=str(exc), error_type=RemoteErrorType.CLEAN) error_type=ErrorType.CLEAN)
else: else:
err_response = RemoteErrorResponse( err_response = ErrorResponse(
error_message=(traceback.format_exc() error_message=(traceback.format_exc()
if self._protocol.trusted_client else if self._protocol.trusted_client else
'An unknown error has occurred.'), 'An unknown error has occurred.'),
error_type=RemoteErrorType.OTHER) error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response) return self._protocol.encode_response(err_response)
async def handle_raw_message_async(self, msg: bytes) -> bytes: async def handle_raw_message_async(self, msg: bytes) -> bytes:

View File

@ -189,6 +189,13 @@ def format_yapf(projroot: Path, full: bool) -> None:
flush=True) flush=True)
def format_yapf_text(projroot: Path, code: str) -> str:
"""Run yapf formatting on the provided code."""
del projroot # Unused.
print('WOULD DO YAPF')
return code
def _should_include_script(fnamefull: str) -> bool: def _should_include_script(fnamefull: str) -> bool:
fname = os.path.basename(fnamefull) fname = os.path.basename(fnamefull)

View File

@ -155,7 +155,7 @@ class StaticTestFile:
return '\n'.join(lines_out) + '\n' return '\n'.join(lines_out) + '\n'
def static_type_equals(value: Any, statictype: Union[Type, str]) -> bool: def static_type_equals(value: Any, statictype: Union[Type, None, str]) -> bool:
"""Check a type statically using mypy. """Check a type statically using mypy.
If a string is passed as statictype, it is checked against the mypy If a string is passed as statictype, it is checked against the mypy