diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml
index ab1e2242..55028a7a 100644
--- a/.idea/dictionaries/ericf.xml
+++ b/.idea/dictionaries/ericf.xml
@@ -775,6 +775,7 @@
fileselector
filesize
filestates
+ filt
filterlines
filterpath
filterpaths
diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
index a21794a7..00f8f13b 100644
--- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml
+++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
@@ -363,6 +363,7 @@
fieldpath
fifteenbits
filelock
+ filt
filterstr
filterval
finishedptr
diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py
index 6a9451d4..07ac208b 100644
--- a/tests/test_efro/test_message.py
+++ b/tests/test_efro/test_message.py
@@ -14,10 +14,10 @@ from efrotools.statictest import static_type_equals
from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, Response, MessageProtocol, MessageSender,
- MessageReceiver)
+ MessageReceiver, EmptyResponse)
if TYPE_CHECKING:
- from typing import List, Type, Any, Callable, Union
+ from typing import List, Type, Any, Callable, Union, Optional
@ioprepped
@@ -42,6 +42,13 @@ class _TMessage2(Message):
return [_TResponse1, _TResponse2]
+@ioprepped
+@dataclass
+class _TMessage3(Message):
+ """Just testing."""
+ sval: str
+
+
@ioprepped
@dataclass
class _TResponse1(Response):
@@ -91,7 +98,11 @@ class _BoundTestMessageSender:
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
...
- def send(self, message: Message) -> Response:
+ @overload
+ def send(self, message: _TMessage3) -> None:
+ ...
+
+ def send(self, message: Message) -> Optional[Response]:
"""Send a message."""
return self._sender.send(self._obj, message)
@@ -124,6 +135,13 @@ class _TestMessageReceiver(MessageReceiver):
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
...
+ @overload
+ def handler(
+ self,
+ call: Callable[[Any, _TMessage3], None],
+ ) -> Callable[[Any, _TMessage3], None]:
+ ...
+
def handler(self, call: Callable) -> Callable:
"""Decorator to register message handlers."""
self.register_handler(call)
@@ -153,10 +171,12 @@ TEST_PROTOCOL = MessageProtocol(
message_types={
0: _TMessage1,
1: _TMessage2,
+ 2: _TMessage3,
},
response_types={
0: _TResponse1,
1: _TResponse2,
+ 2: EmptyResponse,
},
trusted_client=True,
log_remote_exceptions=False,
@@ -292,17 +312,28 @@ def test_message_sending() -> None:
del msg # Unused
return _TResponse2(fval=1.2)
+ @receiver.handler
+ def handle_test_message_3(self, msg: _TMessage3) -> None:
+ """Test."""
+ del msg # Unused
+
receiver.validate()
obj_r = TestClassR()
obj_s = TestClassS(target=obj_r)
+ response = obj_s.msg.send(_TMessage1(ival=0))
+ assert isinstance(response, _TResponse1)
+
+ response2 = obj_s.msg.send(_TMessage2(sval='rah'))
+ assert isinstance(response2, (_TResponse1, _TResponse2))
+
+ response3 = obj_s.msg.send(_TMessage3(sval='rah'))
+ assert response3 is None
+
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
- response = obj_s.msg.send(_TMessage1(ival=0))
- response2 = obj_s.msg.send(_TMessage2(sval='rah'))
assert static_type_equals(response, _TResponse1)
- assert isinstance(response, _TResponse1)
- assert isinstance(response2, (_TResponse1, _TResponse2))
+ assert static_type_equals(response3, None)
# Remote CleanErrors should come across locally as the same.
try:
diff --git a/tools/efro/message.py b/tools/efro/message.py
index fa063e99..99f619e1 100644
--- a/tools/efro/message.py
+++ b/tools/efro/message.py
@@ -28,22 +28,19 @@ if TYPE_CHECKING:
TM = TypeVar('TM', bound='MessageSender')
-class RemoteErrorType(Enum):
- """Type of error that occurred in remote message handling."""
- OTHER = 0
- CLEAN = 1
-
-
class Message:
"""Base class for messages."""
@classmethod
def get_response_types(cls) -> List[Type[Response]]:
"""Return all message types this Message can result in when sent.
- Note: RemoteErrorMessage is handled transparently and does not
+
+ The default implementation specifies EmptyResponse, so messages with
+ no particular response needs can leave this untouched.
+ Note that ErrorMessage is handled as a special case and does not
need to be specified here.
"""
- return []
+ return [EmptyResponse]
class Response:
@@ -53,16 +50,22 @@ class Response:
# Some standard response types:
+class ErrorType(Enum):
+ """Type of error that occurred in remote message handling."""
+ OTHER = 0
+ CLEAN = 1
+
+
@ioprepped
@dataclass
-class RemoteErrorResponse(Response):
+class ErrorResponse(Response):
"""Message saying some error has occurred on the other end.
This type is unique in that it is not returned to the user; it
instead results in a local exception being raised.
"""
error_message: Annotated[str, IOAttrs('m')]
- error_type: Annotated[RemoteErrorType, IOAttrs('t')]
+ error_type: Annotated[ErrorType, IOAttrs('t')]
@ioprepped
@@ -126,12 +129,12 @@ class MessageProtocol:
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
- # If they didn't register RemoteErrorResponse, do so with a special
+ # If they didn't register ErrorResponse, do so with a special
# -1 id which ensures it will never conflict with user messages.
- if RemoteErrorResponse not in self.response_ids_by_type:
+ if ErrorResponse not in self.response_ids_by_type:
assert self.response_types_by_id.get(-1) is None
- self.response_types_by_id[-1] = RemoteErrorResponse
- self.response_ids_by_type[RemoteErrorResponse] = -1
+ self.response_types_by_id[-1] = ErrorResponse
+ self.response_ids_by_type[ErrorResponse] = -1
# Some extra-thorough validation in debug mode.
if __debug__:
@@ -146,8 +149,8 @@ class MessageProtocol:
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes)
for cls in all_response_types:
- assert is_ioprepped_dataclass(cls) and issubclass(
- cls, Response)
+ assert is_ioprepped_dataclass(cls)
+ assert issubclass(cls, Response)
if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}'
f' was not included in response_types.')
@@ -202,10 +205,10 @@ class MessageProtocol:
assert isinstance(out, Message)
return out
- def decode_response(self, data: bytes) -> Response:
+ def decode_response(self, data: bytes) -> Optional[Response]:
"""Decode a response from bytes."""
out = self._decode(data, self.response_types_by_id, 'response')
- assert isinstance(out, Response)
+ assert isinstance(out, (Response, type(None)))
return out
def _decode(self, data: bytes, types_by_id: Dict[int, Type],
@@ -230,12 +233,16 @@ class MessageProtocol:
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
out = dataclass_from_dict(msgtype, msgdict)
+ # Special case: if we get EmptyResponse, we simply return None.
+ if isinstance(out, EmptyResponse):
+ return None
+
# Special case: a remote error occurred. Raise a local Exception
# instead of returning the message.
- if isinstance(out, RemoteErrorResponse):
+ if isinstance(out, ErrorResponse):
assert opname == 'response'
if (self.preserve_clean_errors
- and out.error_type is RemoteErrorType.CLEAN):
+ and out.error_type is ErrorType.CLEAN):
raise CleanError(out.error_message)
raise RemoteError(out.error_message)
@@ -327,22 +334,29 @@ class MessageProtocol:
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently we require at least 2'
' registered message types; found 1.')
+
+ def _filt_tp_name(rtype: Type[Response]) -> str:
+ # We accept None to equal EmptyResponse so reflect that
+ # in the type annotation.
+ return 'None' if rtype is EmptyResponse else rtype.__name__
+
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
- tps = ', '.join(t.__name__ for t in rtypes)
- responsetypevar = f'Union[{tps}]'
+ tps = ', '.join(_filt_tp_name(t) for t in rtypes)
+ rtypevar = f'Union[{tps}]'
else:
- responsetypevar = rtypes[0].__name__
+ rtypevar = _filt_tp_name(rtypes[0])
out += (f'\n'
f' @overload\n'
f' def send(self, message: {msgtypevar})'
- f' -> {responsetypevar}:\n'
+ f' -> {rtypevar}:\n'
f' ...\n')
out += ('\n'
- ' def send(self, message: Message) -> Response:\n'
+ ' def send(self, message: Message)'
+ ' -> Optional[Response]:\n'
' """Send a message."""\n'
' return self._sender.send(self._obj, message)\n')
@@ -384,15 +398,21 @@ class MessageProtocol:
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
' registered message types; found 1.')
+
+ def _filt_tp_name(rtype: Type[Response]) -> str:
+ # We accept None to equal EmptyResponse so reflect that
+ # in the type annotation.
+ return 'None' if rtype is EmptyResponse else rtype.__name__
+
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
- tps = ', '.join(t.__name__ for t in rtypes)
+ tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
- rtypevar = rtypes[0].__name__
+ rtypevar = _filt_tp_name(rtypes[0])
out += (
f'\n'
f' @overload\n'
@@ -463,7 +483,7 @@ class MessageSender:
self._send_raw_message_call = call
return call
- def send(self, bound_obj: Any, message: Message) -> Response:
+ def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
"""Send a message and receive a response.
Will encode the message for transport and call dispatch_raw_message()
@@ -474,8 +494,9 @@ class MessageSender:
msg_encoded = self._protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.decode_response(response_encoded)
- assert isinstance(response, Response)
- assert type(response) in type(message).get_response_types()
+ assert isinstance(response, (Response, type(None)))
+ assert (response is None
+ or type(response) in type(message).get_response_types())
return response
def send_bg(self, bound_obj: Any, message: Message) -> Message:
@@ -555,7 +576,7 @@ class MessageReceiver:
assert issubclass(msgtype, Message)
ret = anns.get('return')
- responsetypes: Tuple[Type, ...]
+ responsetypes: Tuple[Union[Type[Any], Type[None]], ...]
# Return types can be a single type or a union of types.
if isinstance(ret, _GenericAlias):
@@ -570,6 +591,10 @@ class MessageReceiver:
f' "return" annotation; got a {type(ret)}.')
responsetypes = (ret, )
+ # Return type of None translates to EmptyResponse.
+ responsetypes = tuple(EmptyResponse if r is type(None) else r
+ for r in responsetypes)
+
# Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset
# of the supported types; can allow this in the future if it makes
@@ -607,7 +632,7 @@ class MessageReceiver:
raise TypeError(msg)
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
- """Decode, handle, and return encoded response for a message."""
+ """Decode, handle, and return an encoded response for a message."""
try:
# Decode the incoming message.
msg_decoded = self._protocol.decode_message(msg)
@@ -620,8 +645,14 @@ class MessageReceiver:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
response = handler(bound_obj, msg_decoded)
+ # A return value of None equals EmptyResponse.
+ if response is None:
+ response = EmptyResponse()
+
# Re-encode the response.
assert isinstance(response, Response)
+ # (user should never explicitly return these)
+ assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types()
return self._protocol.encode_response(response)
@@ -630,19 +661,19 @@ class MessageReceiver:
if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.')
- # If anything goes wrong, return a RemoteErrorResponse instead.
+ # If anything goes wrong, return a ErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
- err_response = RemoteErrorResponse(
- error_message=str(exc), error_type=RemoteErrorType.CLEAN)
+ err_response = ErrorResponse(error_message=str(exc),
+ error_type=ErrorType.CLEAN)
else:
- err_response = RemoteErrorResponse(
+ err_response = ErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_client else
'An unknown error has occurred.'),
- error_type=RemoteErrorType.OTHER)
+ error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response)
async def handle_raw_message_async(self, msg: bytes) -> bytes:
diff --git a/tools/efrotools/code.py b/tools/efrotools/code.py
index 6decf388..8dc4a28e 100644
--- a/tools/efrotools/code.py
+++ b/tools/efrotools/code.py
@@ -189,6 +189,13 @@ def format_yapf(projroot: Path, full: bool) -> None:
flush=True)
+def format_yapf_text(projroot: Path, code: str) -> str:
+ """Run yapf formatting on the provided code."""
+ del projroot # Unused.
+ print('WOULD DO YAPF')
+ return code
+
+
def _should_include_script(fnamefull: str) -> bool:
fname = os.path.basename(fnamefull)
diff --git a/tools/efrotools/statictest.py b/tools/efrotools/statictest.py
index 0da0f0bb..50bbb17e 100644
--- a/tools/efrotools/statictest.py
+++ b/tools/efrotools/statictest.py
@@ -155,7 +155,7 @@ class StaticTestFile:
return '\n'.join(lines_out) + '\n'
-def static_type_equals(value: Any, statictype: Union[Type, str]) -> bool:
+def static_type_equals(value: Any, statictype: Union[Type, None, str]) -> bool:
"""Check a type statically using mypy.
If a string is passed as statictype, it is checked against the mypy