diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml index ab1e2242..55028a7a 100644 --- a/.idea/dictionaries/ericf.xml +++ b/.idea/dictionaries/ericf.xml @@ -775,6 +775,7 @@ fileselector filesize filestates + filt filterlines filterpath filterpaths diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml index a21794a7..00f8f13b 100644 --- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml +++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml @@ -363,6 +363,7 @@ fieldpath fifteenbits filelock + filt filterstr filterval finishedptr diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index 6a9451d4..07ac208b 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -14,10 +14,10 @@ from efrotools.statictest import static_type_equals from efro.error import CleanError, RemoteError from efro.dataclassio import ioprepped from efro.message import (Message, Response, MessageProtocol, MessageSender, - MessageReceiver) + MessageReceiver, EmptyResponse) if TYPE_CHECKING: - from typing import List, Type, Any, Callable, Union + from typing import List, Type, Any, Callable, Union, Optional @ioprepped @@ -42,6 +42,13 @@ class _TMessage2(Message): return [_TResponse1, _TResponse2] +@ioprepped +@dataclass +class _TMessage3(Message): + """Just testing.""" + sval: str + + @ioprepped @dataclass class _TResponse1(Response): @@ -91,7 +98,11 @@ class _BoundTestMessageSender: def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]: ... - def send(self, message: Message) -> Response: + @overload + def send(self, message: _TMessage3) -> None: + ... + + def send(self, message: Message) -> Optional[Response]: """Send a message.""" return self._sender.send(self._obj, message) @@ -124,6 +135,13 @@ class _TestMessageReceiver(MessageReceiver): ) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]: ... + @overload + def handler( + self, + call: Callable[[Any, _TMessage3], None], + ) -> Callable[[Any, _TMessage3], None]: + ... + def handler(self, call: Callable) -> Callable: """Decorator to register message handlers.""" self.register_handler(call) @@ -153,10 +171,12 @@ TEST_PROTOCOL = MessageProtocol( message_types={ 0: _TMessage1, 1: _TMessage2, + 2: _TMessage3, }, response_types={ 0: _TResponse1, 1: _TResponse2, + 2: EmptyResponse, }, trusted_client=True, log_remote_exceptions=False, @@ -292,17 +312,28 @@ def test_message_sending() -> None: del msg # Unused return _TResponse2(fval=1.2) + @receiver.handler + def handle_test_message_3(self, msg: _TMessage3) -> None: + """Test.""" + del msg # Unused + receiver.validate() obj_r = TestClassR() obj_s = TestClassS(target=obj_r) + response = obj_s.msg.send(_TMessage1(ival=0)) + assert isinstance(response, _TResponse1) + + response2 = obj_s.msg.send(_TMessage2(sval='rah')) + assert isinstance(response2, (_TResponse1, _TResponse2)) + + response3 = obj_s.msg.send(_TMessage3(sval='rah')) + assert response3 is None + if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1': - response = obj_s.msg.send(_TMessage1(ival=0)) - response2 = obj_s.msg.send(_TMessage2(sval='rah')) assert static_type_equals(response, _TResponse1) - assert isinstance(response, _TResponse1) - assert isinstance(response2, (_TResponse1, _TResponse2)) + assert static_type_equals(response3, None) # Remote CleanErrors should come across locally as the same. try: diff --git a/tools/efro/message.py b/tools/efro/message.py index fa063e99..99f619e1 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -28,22 +28,19 @@ if TYPE_CHECKING: TM = TypeVar('TM', bound='MessageSender') -class RemoteErrorType(Enum): - """Type of error that occurred in remote message handling.""" - OTHER = 0 - CLEAN = 1 - - class Message: """Base class for messages.""" @classmethod def get_response_types(cls) -> List[Type[Response]]: """Return all message types this Message can result in when sent. - Note: RemoteErrorMessage is handled transparently and does not + + The default implementation specifies EmptyResponse, so messages with + no particular response needs can leave this untouched. + Note that ErrorMessage is handled as a special case and does not need to be specified here. """ - return [] + return [EmptyResponse] class Response: @@ -53,16 +50,22 @@ class Response: # Some standard response types: +class ErrorType(Enum): + """Type of error that occurred in remote message handling.""" + OTHER = 0 + CLEAN = 1 + + @ioprepped @dataclass -class RemoteErrorResponse(Response): +class ErrorResponse(Response): """Message saying some error has occurred on the other end. This type is unique in that it is not returned to the user; it instead results in a local exception being raised. """ error_message: Annotated[str, IOAttrs('m')] - error_type: Annotated[RemoteErrorType, IOAttrs('t')] + error_type: Annotated[ErrorType, IOAttrs('t')] @ioprepped @@ -126,12 +129,12 @@ class MessageProtocol: self.response_types_by_id[r_id] = r_type self.response_ids_by_type[r_type] = r_id - # If they didn't register RemoteErrorResponse, do so with a special + # If they didn't register ErrorResponse, do so with a special # -1 id which ensures it will never conflict with user messages. - if RemoteErrorResponse not in self.response_ids_by_type: + if ErrorResponse not in self.response_ids_by_type: assert self.response_types_by_id.get(-1) is None - self.response_types_by_id[-1] = RemoteErrorResponse - self.response_ids_by_type[RemoteErrorResponse] = -1 + self.response_types_by_id[-1] = ErrorResponse + self.response_ids_by_type[ErrorResponse] = -1 # Some extra-thorough validation in debug mode. if __debug__: @@ -146,8 +149,8 @@ class MessageProtocol: assert len(set(m_rtypes)) == len(m_rtypes) # check dups all_response_types.update(m_rtypes) for cls in all_response_types: - assert is_ioprepped_dataclass(cls) and issubclass( - cls, Response) + assert is_ioprepped_dataclass(cls) + assert issubclass(cls, Response) if cls not in self.response_ids_by_type: raise ValueError(f'Possible response type {cls}' f' was not included in response_types.') @@ -202,10 +205,10 @@ class MessageProtocol: assert isinstance(out, Message) return out - def decode_response(self, data: bytes) -> Response: + def decode_response(self, data: bytes) -> Optional[Response]: """Decode a response from bytes.""" out = self._decode(data, self.response_types_by_id, 'response') - assert isinstance(out, Response) + assert isinstance(out, (Response, type(None))) return out def _decode(self, data: bytes, types_by_id: Dict[int, Type], @@ -230,12 +233,16 @@ class MessageProtocol: raise TypeError(f'Got unregistered {opname} type id of {m_id}.') out = dataclass_from_dict(msgtype, msgdict) + # Special case: if we get EmptyResponse, we simply return None. + if isinstance(out, EmptyResponse): + return None + # Special case: a remote error occurred. Raise a local Exception # instead of returning the message. - if isinstance(out, RemoteErrorResponse): + if isinstance(out, ErrorResponse): assert opname == 'response' if (self.preserve_clean_errors - and out.error_type is RemoteErrorType.CLEAN): + and out.error_type is ErrorType.CLEAN): raise CleanError(out.error_message) raise RemoteError(out.error_message) @@ -327,22 +334,29 @@ class MessageProtocol: if len(msgtypes) == 1: raise RuntimeError('FIXME: currently we require at least 2' ' registered message types; found 1.') + + def _filt_tp_name(rtype: Type[Response]) -> str: + # We accept None to equal EmptyResponse so reflect that + # in the type annotation. + return 'None' if rtype is EmptyResponse else rtype.__name__ + if len(msgtypes) > 1: for msgtype in msgtypes: msgtypevar = msgtype.__name__ rtypes = msgtype.get_response_types() if len(rtypes) > 1: - tps = ', '.join(t.__name__ for t in rtypes) - responsetypevar = f'Union[{tps}]' + tps = ', '.join(_filt_tp_name(t) for t in rtypes) + rtypevar = f'Union[{tps}]' else: - responsetypevar = rtypes[0].__name__ + rtypevar = _filt_tp_name(rtypes[0]) out += (f'\n' f' @overload\n' f' def send(self, message: {msgtypevar})' - f' -> {responsetypevar}:\n' + f' -> {rtypevar}:\n' f' ...\n') out += ('\n' - ' def send(self, message: Message) -> Response:\n' + ' def send(self, message: Message)' + ' -> Optional[Response]:\n' ' """Send a message."""\n' ' return self._sender.send(self._obj, message)\n') @@ -384,15 +398,21 @@ class MessageProtocol: if len(msgtypes) == 1: raise RuntimeError('FIXME: currently require at least 2' ' registered message types; found 1.') + + def _filt_tp_name(rtype: Type[Response]) -> str: + # We accept None to equal EmptyResponse so reflect that + # in the type annotation. + return 'None' if rtype is EmptyResponse else rtype.__name__ + if len(msgtypes) > 1: for msgtype in msgtypes: msgtypevar = msgtype.__name__ rtypes = msgtype.get_response_types() if len(rtypes) > 1: - tps = ', '.join(t.__name__ for t in rtypes) + tps = ', '.join(_filt_tp_name(t) for t in rtypes) rtypevar = f'Union[{tps}]' else: - rtypevar = rtypes[0].__name__ + rtypevar = _filt_tp_name(rtypes[0]) out += ( f'\n' f' @overload\n' @@ -463,7 +483,7 @@ class MessageSender: self._send_raw_message_call = call return call - def send(self, bound_obj: Any, message: Message) -> Response: + def send(self, bound_obj: Any, message: Message) -> Optional[Response]: """Send a message and receive a response. Will encode the message for transport and call dispatch_raw_message() @@ -474,8 +494,9 @@ class MessageSender: msg_encoded = self._protocol.encode_message(message) response_encoded = self._send_raw_message_call(bound_obj, msg_encoded) response = self._protocol.decode_response(response_encoded) - assert isinstance(response, Response) - assert type(response) in type(message).get_response_types() + assert isinstance(response, (Response, type(None))) + assert (response is None + or type(response) in type(message).get_response_types()) return response def send_bg(self, bound_obj: Any, message: Message) -> Message: @@ -555,7 +576,7 @@ class MessageReceiver: assert issubclass(msgtype, Message) ret = anns.get('return') - responsetypes: Tuple[Type, ...] + responsetypes: Tuple[Union[Type[Any], Type[None]], ...] # Return types can be a single type or a union of types. if isinstance(ret, _GenericAlias): @@ -570,6 +591,10 @@ class MessageReceiver: f' "return" annotation; got a {type(ret)}.') responsetypes = (ret, ) + # Return type of None translates to EmptyResponse. + responsetypes = tuple(EmptyResponse if r is type(None) else r + for r in responsetypes) + # Make sure our protocol has this message type registered and our # return types exactly match. (Technically we could return a subset # of the supported types; can allow this in the future if it makes @@ -607,7 +632,7 @@ class MessageReceiver: raise TypeError(msg) def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes: - """Decode, handle, and return encoded response for a message.""" + """Decode, handle, and return an encoded response for a message.""" try: # Decode the incoming message. msg_decoded = self._protocol.decode_message(msg) @@ -620,8 +645,14 @@ class MessageReceiver: raise RuntimeError(f'Got unhandled message type: {msgtype}.') response = handler(bound_obj, msg_decoded) + # A return value of None equals EmptyResponse. + if response is None: + response = EmptyResponse() + # Re-encode the response. assert isinstance(response, Response) + # (user should never explicitly return these) + assert not isinstance(response, ErrorResponse) assert type(response) in msgtype.get_response_types() return self._protocol.encode_response(response) @@ -630,19 +661,19 @@ class MessageReceiver: if self._protocol.log_remote_exceptions: logging.exception('Error handling message.') - # If anything goes wrong, return a RemoteErrorResponse instead. + # If anything goes wrong, return a ErrorResponse instead. if (isinstance(exc, CleanError) and self._protocol.preserve_clean_errors): - err_response = RemoteErrorResponse( - error_message=str(exc), error_type=RemoteErrorType.CLEAN) + err_response = ErrorResponse(error_message=str(exc), + error_type=ErrorType.CLEAN) else: - err_response = RemoteErrorResponse( + err_response = ErrorResponse( error_message=(traceback.format_exc() if self._protocol.trusted_client else 'An unknown error has occurred.'), - error_type=RemoteErrorType.OTHER) + error_type=ErrorType.OTHER) return self._protocol.encode_response(err_response) async def handle_raw_message_async(self, msg: bytes) -> bytes: diff --git a/tools/efrotools/code.py b/tools/efrotools/code.py index 6decf388..8dc4a28e 100644 --- a/tools/efrotools/code.py +++ b/tools/efrotools/code.py @@ -189,6 +189,13 @@ def format_yapf(projroot: Path, full: bool) -> None: flush=True) +def format_yapf_text(projroot: Path, code: str) -> str: + """Run yapf formatting on the provided code.""" + del projroot # Unused. + print('WOULD DO YAPF') + return code + + def _should_include_script(fnamefull: str) -> bool: fname = os.path.basename(fnamefull) diff --git a/tools/efrotools/statictest.py b/tools/efrotools/statictest.py index 0da0f0bb..50bbb17e 100644 --- a/tools/efrotools/statictest.py +++ b/tools/efrotools/statictest.py @@ -155,7 +155,7 @@ class StaticTestFile: return '\n'.join(lines_out) + '\n' -def static_type_equals(value: Any, statictype: Union[Type, str]) -> bool: +def static_type_equals(value: Any, statictype: Union[Type, None, str]) -> bool: """Check a type statically using mypy. If a string is passed as statictype, it is checked against the mypy