diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml index 0d05bfb2..ea1b9b80 100644 --- a/.idea/dictionaries/ericf.xml +++ b/.idea/dictionaries/ericf.xml @@ -159,6 +159,7 @@ availmins availplug aval + awaitable axismotion bacfg backgrounded @@ -303,6 +304,7 @@ capturetheflag carentity cashregistersound + cbgn cbits cbot cbtn @@ -313,6 +315,7 @@ cdrk cdull cdval + cend centeuro centiseconds cfconfig @@ -1869,6 +1872,8 @@ rawpaths rcade rcfile + rcva + rcvs rdict rdir readline @@ -1997,6 +2002,7 @@ selwidget selwidgets sendable + sendmethod senze seqtype seqtypestr diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml index ef2bd0b6..69f73419 100644 --- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml +++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml @@ -69,6 +69,7 @@ availmins avel avels + awaitable axismotion backgrounded backgrounding @@ -155,10 +156,12 @@ cancelbtn capitan cargs + cbgn cbtnoffs ccdd ccontext ccylinder + cend centiseconds cfgdir cfgpath @@ -860,6 +863,8 @@ rasterizer rawkey rcade + rcva + rcvs reaaaly readset realloc @@ -922,6 +927,7 @@ selwidget selwidgets sendable + sendmethod seqlen seqtype seqtypestr diff --git a/docs/ba_module.md b/docs/ba_module.md index 27a61067..fdfc7837 100644 --- a/docs/ba_module.md +++ b/docs/ba_module.md @@ -1,5 +1,5 @@ -

last updated on 2021-09-19 for Ballistica version 1.6.5 build 20393

+

last updated on 2021-09-21 for Ballistica version 1.6.5 build 20393

This page documents the Python classes and functions in the 'ba' module, which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please let me know. Happy modding!


diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index 4b283487..cec32d31 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -5,6 +5,7 @@ from __future__ import annotations import os +import asyncio from typing import TYPE_CHECKING, overload from dataclasses import dataclass @@ -17,55 +18,55 @@ from efro.message import (Message, Response, MessageProtocol, MessageSender, MessageReceiver) if TYPE_CHECKING: - from typing import List, Type, Any, Callable, Union, Optional + from typing import List, Type, Any, Callable, Union, Optional, Awaitable @ioprepped @dataclass -class _TMessage1(Message): +class _TMsg1(Message): """Just testing.""" ival: int @classmethod def get_response_types(cls) -> List[Type[Response]]: - return [_TResponse1] + return [_TResp1] @ioprepped @dataclass -class _TMessage2(Message): +class _TMsg2(Message): """Just testing.""" sval: str @classmethod def get_response_types(cls) -> List[Type[Response]]: - return [_TResponse1, _TResponse2] + return [_TResp1, _TResp2] @ioprepped @dataclass -class _TMessage3(Message): +class _TMsg3(Message): """Just testing.""" sval: str @ioprepped @dataclass -class _TResponse1(Response): +class _TResp1(Response): """Just testing.""" bval: bool @ioprepped @dataclass -class _TResponse2(Response): +class _TResp2(Response): """Just testing.""" fval: float @ioprepped @dataclass -class _TResponse3(Message): +class _TResp3(Message): """Just testing.""" fval: float @@ -91,55 +92,73 @@ class _BoundTestMessageSender: self._sender = sender @overload - def send(self, message: _TMessage1) -> _TResponse1: + def send(self, message: _TMsg1) -> _TResp1: ... @overload - def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]: + def send(self, message: _TMsg2) -> Union[_TResp1, _TResp2]: ... @overload - def send(self, message: _TMessage3) -> None: + def send(self, message: _TMsg3) -> None: ... def send(self, message: Message) -> Optional[Response]: - """Send a message.""" + """Send a message synchronously.""" return self._sender.send(self._obj, message) + @overload + async def send_async(self, message: _TMsg1) -> _TResp1: + ... + + @overload + async def send_async(self, message: _TMsg2) -> Union[_TResp1, _TResp2]: + ... + + @overload + async def send_async(self, message: _TMsg3) -> None: + ... + + async def send_async(self, message: Message) -> Optional[Response]: + """Send a message asynchronously.""" + return await self._sender.send_async(self._obj, message) + # SEND_CODE_TEST_END -# RCV_CODE_TEST_BEGIN +# RCVS_CODE_TEST_BEGIN -class _TestMessageReceiver(MessageReceiver): - """Protocol-specific receiver.""" +class _TestSyncMessageReceiver(MessageReceiver): + """Protocol-specific synchronous receiver.""" + + is_async = False def __get__( self, obj: Any, type_in: Any = None, - ) -> _BoundTestMessageReceiver: - return _BoundTestMessageReceiver(obj, self) + ) -> _BoundTestSyncMessageReceiver: + return _BoundTestSyncMessageReceiver(obj, self) @overload def handler( self, - call: Callable[[Any, _TMessage1], _TResponse1], - ) -> Callable[[Any, _TMessage1], _TResponse1]: + call: Callable[[Any, _TMsg1], _TResp1], + ) -> Callable[[Any, _TMsg1], _TResp1]: ... @overload def handler( self, - call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]], - ) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]: + call: Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]], + ) -> Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]]: ... @overload def handler( self, - call: Callable[[Any, _TMessage3], None], - ) -> Callable[[Any, _TMessage3], None]: + call: Callable[[Any, _TMsg3], None], + ) -> Callable[[Any, _TMsg3], None]: ... def handler(self, call: Callable) -> Callable: @@ -148,34 +167,104 @@ class _TestMessageReceiver(MessageReceiver): return call -class _BoundTestMessageReceiver: +class _BoundTestSyncMessageReceiver: """Protocol-specific bound receiver.""" def __init__( self, obj: Any, - receiver: _TestMessageReceiver, + receiver: _TestSyncMessageReceiver, ) -> None: assert obj is not None self._obj = obj self._receiver = receiver def handle_raw_message(self, message: bytes) -> bytes: - """Handle a raw incoming message.""" + """Handle a raw incoming synchronous message.""" return self._receiver.handle_raw_message(self._obj, message) + async def handle_raw_message_async(self, message: bytes) -> bytes: + """Handle a raw incoming asynchronous message.""" + return await self._receiver.handle_raw_message_async( + self._obj, message) -# RCV_CODE_TEST_END + +# RCVS_CODE_TEST_END +# RCVA_CODE_TEST_BEGIN + + +class _TestAsyncMessageReceiver(MessageReceiver): + """Protocol-specific asynchronous receiver.""" + + is_async = True + + def __get__( + self, + obj: Any, + type_in: Any = None, + ) -> _BoundTestAsyncMessageReceiver: + return _BoundTestAsyncMessageReceiver(obj, self) + + @overload + def handler( + self, + call: Callable[[Any, _TMsg1], Awaitable[_TResp1]], + ) -> Callable[[Any, _TMsg1], Awaitable[_TResp1]]: + ... + + @overload + def handler( + self, + call: Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]], + ) -> Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]]: + ... + + @overload + def handler( + self, + call: Callable[[Any, _TMsg3], Awaitable[None]], + ) -> Callable[[Any, _TMsg3], Awaitable[None]]: + ... + + def handler(self, call: Callable) -> Callable: + """Decorator to register message handlers.""" + self.register_handler(call) + return call + + +class _BoundTestAsyncMessageReceiver: + """Protocol-specific bound receiver.""" + + def __init__( + self, + obj: Any, + receiver: _TestAsyncMessageReceiver, + ) -> None: + assert obj is not None + self._obj = obj + self._receiver = receiver + + def handle_raw_message(self, message: bytes) -> bytes: + """Handle a raw incoming synchronous message.""" + return self._receiver.handle_raw_message(self._obj, message) + + async def handle_raw_message_async(self, message: bytes) -> bytes: + """Handle a raw incoming asynchronous message.""" + return await self._receiver.handle_raw_message_async( + self._obj, message) + + +# RCVA_CODE_TEST_END TEST_PROTOCOL = MessageProtocol( message_types={ - 0: _TMessage1, - 1: _TMessage2, - 2: _TMessage3, + 0: _TMsg1, + 1: _TMsg2, + 2: _TMsg3, }, response_types={ - 0: _TResponse1, - 1: _TResponse2, + 0: _TResp1, + 1: _TResp2, }, trusted_sender=True, log_remote_exceptions=False, @@ -185,20 +274,21 @@ TEST_PROTOCOL = MessageProtocol( def test_protocol_creation() -> None: """Test protocol creation.""" - # This should fail because _TMessage1 can return _TResponse1 which + # This should fail because _TMsg1 can return _TResp1 which # is not given an id here. with pytest.raises(ValueError): - _protocol = MessageProtocol(message_types={0: _TMessage1}, - response_types={0: _TResponse2}) + _protocol = MessageProtocol(message_types={0: _TMsg1}, + response_types={0: _TResp2}) # Now it should work. - _protocol = MessageProtocol(message_types={0: _TMessage1}, - response_types={0: _TResponse1}) + _protocol = MessageProtocol(message_types={0: _TMsg1}, + response_types={0: _TResp1}) def test_sender_module_embedded() -> None: """Test generation of protocol-specific sender modules for typing/etc.""" - smod = TEST_PROTOCOL.create_sender_module('Test', private=True) + smod = TEST_PROTOCOL.create_sender_module('TestMessageSender', + private=True) # Clip everything up to our first class declaration. lines = smod.splitlines() @@ -218,13 +308,15 @@ def test_sender_module_embedded() -> None: ' See test stdout for new code.') -def test_receiver_module_embedded() -> None: +def test_receiver_module_sync_embedded() -> None: """Test generation of protocol-specific sender modules for typing/etc.""" - smod = TEST_PROTOCOL.create_receiver_module('Test', private=True) + smod = TEST_PROTOCOL.create_receiver_module('TestSyncMessageReceiver', + is_async=False, + private=True) # Clip everything up to our first class declaration. lines = smod.splitlines() - classline = lines.index('class _TestMessageReceiver(MessageReceiver):') + classline = lines.index('class _TestSyncMessageReceiver(MessageReceiver):') clipped = '\n'.join(lines[classline:]) # This snippet should match what we've got embedded above; @@ -232,12 +324,39 @@ def test_receiver_module_embedded() -> None: with open(__file__, encoding='utf-8') as infile: ourcode = infile.read() - emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n' + emb = f'# RCVS_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVS_CODE_TEST_END\n' if emb not in ourcode: - print(f'EXPECTED EMBEDDED CODE:\n{emb}') - raise RuntimeError('Generated sender module does not match embedded;' - ' test code needs to be updated.' - ' See test stdout for new code.') + print(f'EXPECTED SYNC RECEIVER EMBEDDED CODE:\n{emb}') + raise RuntimeError( + 'Generated sync receiver module does not match embedded;' + ' test code needs to be updated.' + ' See test stdout for new code.') + + +def test_receiver_module_async_embedded() -> None: + """Test generation of protocol-specific sender modules for typing/etc.""" + smod = TEST_PROTOCOL.create_receiver_module('TestAsyncMessageReceiver', + is_async=True, + private=True) + + # Clip everything up to our first class declaration. + lines = smod.splitlines() + classline = lines.index( + 'class _TestAsyncMessageReceiver(MessageReceiver):') + clipped = '\n'.join(lines[classline:]) + + # This snippet should match what we've got embedded above; + # If not then we need to update our embedded version. + with open(__file__, encoding='utf-8') as infile: + ourcode = infile.read() + + emb = f'# RCVA_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVA_CODE_TEST_END\n' + if emb not in ourcode: + print(f'EXPECTED ASYNC RECEIVER EMBEDDED CODE:\n{emb}') + raise RuntimeError( + 'Generated async receiver module does not match embedded;' + ' test code needs to be updated.' + ' See test stdout for new code.') def test_receiver_creation() -> None: @@ -251,13 +370,13 @@ def test_receiver_creation() -> None: class _TestClassR: """Test class incorporating receive functionality.""" - receiver = _TestMessageReceiver(TEST_PROTOCOL) + receiver = _TestSyncMessageReceiver(TEST_PROTOCOL) @receiver.handler - def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2: + def handle_test_message_2(self, msg: _TMsg2) -> _TResp2: """Test.""" del msg # Unused - return _TResponse2(fval=1.2) + return _TResp2(fval=1.2) # Validation should fail because not all message types in the # protocol are handled. @@ -266,13 +385,13 @@ def test_receiver_creation() -> None: class _TestClassR2: """Test class incorporating receive functionality.""" - receiver = _TestMessageReceiver(TEST_PROTOCOL) + receiver = _TestSyncMessageReceiver(TEST_PROTOCOL) # Checks that we've added handlers for all message types, etc. receiver.validate() -def test_synchronous_messaging() -> None: +def test_full_pipeline() -> None: """Test the full pipeline.""" # Define a class that can send messages and one that can receive them. @@ -281,66 +400,119 @@ def test_synchronous_messaging() -> None: msg = _TestMessageSender(TEST_PROTOCOL) - def __init__(self, target: TestClassR) -> None: + def __init__(self, target: Union[TestClassRSync, + TestClassRAsync]) -> None: self._target = target - @msg.send_raw_handler + @msg.send_method def _send_raw_message(self, data: bytes) -> bytes: - """Test.""" + """Handle synchronous sending of raw message data.""" + # Just talk directly to the receiver for this example. + # (currently only support synchronous receivers) + assert isinstance(self._target, TestClassRSync) return self._target.receiver.handle_raw_message(data) - class TestClassR: - """Test class incorporating receive functionality.""" + @msg.send_async_method + async def _send_raw_message_async(self, data: bytes) -> bytes: + """Handle asynchronous sending of raw message data.""" + # Just talk directly to the receiver for this example. + # (we can do sync or async receivers) + if isinstance(self._target, TestClassRSync): + return self._target.receiver.handle_raw_message(data) + return await self._target.receiver.handle_raw_message_async(data) - receiver = _TestMessageReceiver(TEST_PROTOCOL) + class TestClassRSync: + """Test class incorporating synchronous receive functionality.""" + + receiver = _TestSyncMessageReceiver(TEST_PROTOCOL) @receiver.handler - def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1: + def handle_test_message_1(self, msg: _TMsg1) -> _TResp1: """Test.""" if msg.ival == 1: raise CleanError('Testing Clean Error') if msg.ival == 2: raise RuntimeError('Testing Runtime Error') - return _TResponse1(bval=True) + return _TResp1(bval=True) @receiver.handler - def handle_test_message_2( - self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]: + def handle_test_message_2(self, + msg: _TMsg2) -> Union[_TResp1, _TResp2]: """Test.""" del msg # Unused - return _TResponse2(fval=1.2) + return _TResp2(fval=1.2) @receiver.handler - def handle_test_message_3(self, msg: _TMessage3) -> None: + def handle_test_message_3(self, msg: _TMsg3) -> None: """Test.""" del msg # Unused receiver.validate() - obj_r = TestClassR() - obj = TestClassS(target=obj_r) + class TestClassRAsync: + """Test class incorporating asynchronous receive functionality.""" - response = obj.msg.send(_TMessage1(ival=0)) - assert isinstance(response, _TResponse1) + receiver = _TestAsyncMessageReceiver(TEST_PROTOCOL) - response2 = obj.msg.send(_TMessage2(sval='rah')) - assert isinstance(response2, (_TResponse1, _TResponse2)) + @receiver.handler + async def handle_test_message_1(self, msg: _TMsg1) -> _TResp1: + """Test.""" + if msg.ival == 1: + raise CleanError('Testing Clean Error') + if msg.ival == 2: + raise RuntimeError('Testing Runtime Error') + return _TResp1(bval=True) - response3 = obj.msg.send(_TMessage3(sval='rah')) - assert response3 is None + @receiver.handler + async def handle_test_message_2( + self, msg: _TMsg2) -> Union[_TResp1, _TResp2]: + """Test.""" + del msg # Unused + return _TResp2(fval=1.2) - # Make sure static typing lines up too. + @receiver.handler + async def handle_test_message_3(self, msg: _TMsg3) -> None: + """Test.""" + del msg # Unused + + receiver.validate() + + obj_r_sync = TestClassRSync() + obj_r_async = TestClassRAsync() + obj = TestClassS(target=obj_r_sync) + obj2 = TestClassS(target=obj_r_async) + + # Test sends (of sync and async varieties). + response1 = obj.msg.send(_TMsg1(ival=0)) + response2 = obj.msg.send(_TMsg2(sval='rah')) + response3 = obj.msg.send(_TMsg3(sval='rah')) + response4 = asyncio.run(obj.msg.send_async(_TMsg1(ival=0))) + + # Make sure static typing lines up with what we expect. if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1': - assert static_type_equals(response, _TResponse1) + assert static_type_equals(response1, _TResp1) assert static_type_equals(response3, None) + assert isinstance(response1, _TResp1) + assert isinstance(response2, (_TResp1, _TResp2)) + assert response3 is None + assert isinstance(response4, _TResp1) + # Remote CleanErrors should come across locally as the same. try: - _response4 = obj.msg.send(_TMessage1(ival=1)) + _response5 = obj.msg.send(_TMsg1(ival=1)) except Exception as exc: assert isinstance(exc, CleanError) assert str(exc) == 'Testing Clean Error' # Other remote errors should result in RemoteError. with pytest.raises(RemoteError): - _response4 = obj.msg.send(_TMessage1(ival=2)) + _response5 = obj.msg.send(_TMsg1(ival=2)) + + # Now test sends to async handlers. + response6 = asyncio.run(obj2.msg.send_async(_TMsg1(ival=0))) + assert isinstance(response6, _TResp1) + + # Make sure static typing lines up with what we expect. + if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1': + assert static_type_equals(response6, _TResp1) diff --git a/tools/efro/message.py b/tools/efro/message.py index a7cba22e..f640ed54 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs, if TYPE_CHECKING: from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set, - Sequence, Union) + Sequence, Union, Awaitable) from efro.error import CommunicationError TM = TypeVar('TM', bound='MessageSender') @@ -300,7 +300,7 @@ class MessageProtocol: return out def create_sender_module(self, - classname: str, + basename: str, private: bool = False) -> str: """"Create a Python module defining a MessageSender subclass. @@ -308,29 +308,35 @@ class MessageProtocol: for the varieties of send calls for message/response types defined in the protocol. + Class names are based on basename; a basename 'FooSender' will + result in classes FooSender and BoundFooSender. + + If 'private' is True, class-names will be prefixed with an '_'. + Note that line lengths are not clipped, so output may need to be run through a formatter to prevent lint warnings about excessive line lengths. """ + # pylint: disable=too-many-locals ppre = '_' if private else '' out = self._get_module_header('sender') - out += (f'class {ppre}{classname}MessageSender(MessageSender):\n' + out += (f'class {ppre}{basename}(MessageSender):\n' f' """Protocol-specific sender."""\n' f'\n' f' def __get__(self,\n' f' obj: Any,\n' f' type_in: Any = None)' - f' -> {ppre}Bound{classname}MessageSender:\n' - f' return {ppre}Bound{classname}MessageSender' + f' -> {ppre}Bound{basename}:\n' + f' return {ppre}Bound{basename}' f'(obj, self)\n' f'\n' f'\n' - f'class {ppre}Bound{classname}MessageSender:\n' + f'class {ppre}Bound{basename}:\n' f' """Protocol-specific bound sender."""\n' f'\n' f' def __init__(self, obj: Any,' - f' sender: {ppre}{classname}MessageSender) -> None:\n' + f' sender: {ppre}{basename}) -> None:\n' f' assert obj is not None\n' f' self._obj = obj\n' f' self._sender = sender\n') @@ -352,29 +358,37 @@ class MessageProtocol: return 'None' if rtype is EmptyResponse else rtype.__name__ if len(msgtypes) > 1: - for msgtype in msgtypes: - msgtypevar = msgtype.__name__ - rtypes = msgtype.get_response_types() - if len(rtypes) > 1: - tps = ', '.join(_filt_tp_name(t) for t in rtypes) - rtypevar = f'Union[{tps}]' - else: - rtypevar = _filt_tp_name(rtypes[0]) + for async_pass in False, True: + pfx = 'async ' if async_pass else '' + sfx = '_async' if async_pass else '' + awt = 'await ' if async_pass else '' + how = 'asynchronously' if async_pass else 'synchronously' + for msgtype in msgtypes: + msgtypevar = msgtype.__name__ + rtypes = msgtype.get_response_types() + if len(rtypes) > 1: + tps = ', '.join(_filt_tp_name(t) for t in rtypes) + rtypevar = f'Union[{tps}]' + else: + rtypevar = _filt_tp_name(rtypes[0]) + out += (f'\n' + f' @overload\n' + f' {pfx}def send{sfx}(self,' + f' message: {msgtypevar})' + f' -> {rtypevar}:\n' + f' ...\n') out += (f'\n' - f' @overload\n' - f' def send(self, message: {msgtypevar})' - f' -> {rtypevar}:\n' - f' ...\n') - out += ('\n' - ' def send(self, message: Message)' - ' -> Optional[Response]:\n' - ' """Send a message."""\n' - ' return self._sender.send(self._obj, message)\n') + f' {pfx}def send{sfx}(self, message: Message)' + f' -> Optional[Response]:\n' + f' """Send a message {how}."""\n' + f' return {awt}self._sender.' + f'send{sfx}(self._obj, message)\n') return out def create_receiver_module(self, - classname: str, + basename: str, + is_async: bool, private: bool = False) -> str: """"Create a Python module defining a MessageReceiver subclass. @@ -382,21 +396,33 @@ class MessageProtocol: for the register method for message/response types defined in the protocol. + Class names are based on basename; a basename 'FooReceiver' will + result in FooReceiver and BoundFooReceiver. + + If 'is_async' is True, handle_raw_message() will be an async method + and the @handler decorator will expect async methods. + + If 'private' is True, class-names will be prefixed with an '_'. + Note that line lengths are not clipped, so output may need to be run through a formatter to prevent lint warnings about excessive line lengths. """ + # pylint: disable=too-many-locals + desc = 'asynchronous' if is_async else 'synchronous' ppre = '_' if private else '' out = self._get_module_header('receiver') - out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n' - f' """Protocol-specific receiver."""\n' + out += (f'class {ppre}{basename}(MessageReceiver):\n' + f' """Protocol-specific {desc} receiver."""\n' + f'\n' + f' is_async = {is_async}\n' f'\n' f' def __get__(\n' f' self,\n' f' obj: Any,\n' f' type_in: Any = None,\n' - f' ) -> {ppre}Bound{classname}MessageReceiver:\n' - f' return {ppre}Bound{classname}MessageReceiver(' + f' ) -> {ppre}Bound{basename}:\n' + f' return {ppre}Bound{basename}(' f'obj, self)\n') # Define handler() overloads for all registered message types. @@ -416,6 +442,8 @@ class MessageProtocol: return 'None' if rtype is EmptyResponse else rtype.__name__ if len(msgtypes) > 1: + cbgn = 'Awaitable[' if is_async else '' + cend = ']' if is_async else '' for msgtype in msgtypes: msgtypevar = msgtype.__name__ rtypes = msgtype.get_response_types() @@ -424,6 +452,7 @@ class MessageProtocol: rtypevar = f'Union[{tps}]' else: rtypevar = _filt_tp_name(rtypes[0]) + rtypevar = f'{cbgn}{rtypevar}{cend}' out += ( f'\n' f' @overload\n' @@ -441,22 +470,29 @@ class MessageProtocol: out += (f'\n' f'\n' - f'class {ppre}Bound{classname}MessageReceiver:\n' + f'class {ppre}Bound{basename}:\n' f' """Protocol-specific bound receiver."""\n' f'\n' f' def __init__(\n' f' self,\n' f' obj: Any,\n' - f' receiver: {ppre}{classname}MessageReceiver,\n' + f' receiver: {ppre}{basename},\n' f' ) -> None:\n' f' assert obj is not None\n' f' self._obj = obj\n' f' self._receiver = receiver\n' f'\n' f' def handle_raw_message(self, message: bytes) -> bytes:\n' - f' """Handle a raw incoming message."""\n' + f' """Handle a raw incoming synchronous message."""\n' f' return self._receiver.handle_raw_message' - f'(self._obj, message)\n') + f'(self._obj, message)\n' + f'\n' + f' async def handle_raw_message_async(self, message: bytes)' + f' -> bytes:\n' + f' """Handle a raw incoming asynchronous message."""\n' + f' return await' + f' self._receiver.handle_raw_message_async(\n' + f' self._obj, message)\n') return out @@ -471,7 +507,7 @@ class MessageSender: class MyClass: msg = MyMessageSender(some_protocol) - @msg.send_raw_handler + @msg.sendmethod def send_raw_message(self, message: bytes) -> bytes: # Actually send the message here. @@ -485,8 +521,10 @@ class MessageSender: self._protocol = protocol self._send_raw_message_call: Optional[Callable[[Any, bytes], bytes]] = None + self._send_async_raw_message_call: Optional[Callable[ + [Any, bytes], Awaitable[bytes]]] = None - def send_raw_handler( + def send_method( self, call: Callable[[Any, bytes], bytes]) -> Callable[[Any, bytes], bytes]: """Function decorator for setting raw send method.""" @@ -494,6 +532,14 @@ class MessageSender: self._send_raw_message_call = call return call + def send_async_method( + self, call: Callable[[Any, bytes], Awaitable[bytes]] + ) -> Callable[[Any, bytes], Awaitable[bytes]]: + """Function decorator for setting raw send-async method.""" + assert self._send_async_raw_message_call is None + self._send_async_raw_message_call = call + return call + def send(self, bound_obj: Any, message: Message) -> Optional[Response]: """Send a message and receive a response. @@ -510,21 +556,24 @@ class MessageSender: or type(response) in type(message).get_response_types()) return response - def send_bg(self, bound_obj: Any, message: Message) -> Message: - """Send a message asynchronously and receive a future. - - The message will be encoded for transport and passed to - dispatch_raw_message from a background thread. - """ - raise RuntimeError('Unimplemented!') - - def send_async(self, bound_obj: Any, message: Message) -> Message: + async def send_async(self, bound_obj: Any, + message: Message) -> Optional[Response]: """Send a message asynchronously using asyncio. The message will be encoded for transport and passed to dispatch_raw_message_async. """ - raise RuntimeError('Unimplemented!') + if self._send_async_raw_message_call is None: + raise RuntimeError('send_async() is unimplemented for this type.') + + msg_encoded = self._protocol.encode_message(message) + response_encoded = await self._send_async_raw_message_call( + bound_obj, msg_encoded) + response = self._protocol.decode_response(response_encoded) + assert isinstance(response, (Response, type(None))) + assert (response is None + or type(response) in type(message).get_response_types()) + return response class MessageReceiver: @@ -552,6 +601,8 @@ class MessageReceiver: an Exception being raised on the sending end. """ + is_async = False + def __init__(self, protocol: MessageProtocol) -> None: self._protocol = protocol self._handlers: Dict[Type[Message], Callable] = {} @@ -564,6 +615,7 @@ class MessageReceiver: The message type handled by the call is determined by its type annotation. """ + # pylint: disable=too-many-locals # TODO: can use types.GenericAlias in 3.9. from typing import _GenericAlias # type: ignore from typing import Union, get_type_hints, get_args @@ -576,10 +628,19 @@ class MessageReceiver: raise ValueError(f'Expected callable signature of {expectedsig};' f' got {sig.args}') + # Make sure we are only given async methods if we are an async handler + # and sync ones otherwise. + is_async = inspect.iscoroutinefunction(call) + if self.is_async != is_async: + msg = ('Expected a sync method; found an async one.' if is_async + else 'Expected an async method; found a sync one.') + raise ValueError(msg) + # Check annotation types to determine what message types we handle. # Return-type annotation can be a Union, but we probably don't # have it available at runtime. Explicitly pull it in. anns = get_type_hints(call, localns={'Union': Union}) + msgtype = anns.get('msg') if not isinstance(msgtype, type): raise TypeError( @@ -642,54 +703,71 @@ class MessageReceiver: logging.warning(msg) raise TypeError(msg) + def _decode_incoming_message(self, + msg: bytes) -> Tuple[Message, Type[Message]]: + # Decode the incoming message. + msg_decoded = self._protocol.decode_message(msg) + msgtype = type(msg_decoded) + assert issubclass(msgtype, Message) + return msg_decoded, msgtype + + def _encode_response(self, response: Optional[Response], + msgtype: Type[Message]) -> bytes: + + # A return value of None equals EmptyResponse. + if response is None: + response = EmptyResponse() + + # Re-encode the response. + assert isinstance(response, Response) + # (user should never explicitly return these) + assert not isinstance(response, ErrorResponse) + assert type(response) in msgtype.get_response_types() + return self._protocol.encode_response(response) + + def _handle_raw_message_error(self, exc: Exception) -> bytes: + if self._protocol.log_remote_exceptions: + logging.exception('Error handling message.') + + # If anything goes wrong, return a ErrorResponse instead. + if (isinstance(exc, CleanError) + and self._protocol.preserve_clean_errors): + err_response = ErrorResponse(error_message=str(exc), + error_type=ErrorType.CLEAN) + else: + err_response = ErrorResponse( + error_message=(traceback.format_exc() + if self._protocol.trusted_sender else + 'An unknown error has occurred.'), + error_type=ErrorType.OTHER) + return self._protocol.encode_response(err_response) + def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes: """Decode, handle, and return an encoded response for a message.""" try: - # Decode the incoming message. - msg_decoded = self._protocol.decode_message(msg) - msgtype = type(msg_decoded) - assert issubclass(msgtype, Message) - - # Call the proper handler. + msg_decoded, msgtype = self._decode_incoming_message(msg) handler = self._handlers.get(msgtype) if handler is None: raise RuntimeError(f'Got unhandled message type: {msgtype}.') response = handler(bound_obj, msg_decoded) - - # A return value of None equals EmptyResponse. - if response is None: - response = EmptyResponse() - - # Re-encode the response. - assert isinstance(response, Response) - # (user should never explicitly return these) - assert not isinstance(response, ErrorResponse) - assert type(response) in msgtype.get_response_types() - return self._protocol.encode_response(response) + return self._encode_response(response, msgtype) except Exception as exc: + return self._handle_raw_message_error(exc) - if self._protocol.log_remote_exceptions: - logging.exception('Error handling message.') - - # If anything goes wrong, return a ErrorResponse instead. - if (isinstance(exc, CleanError) - and self._protocol.preserve_clean_errors): - err_response = ErrorResponse(error_message=str(exc), - error_type=ErrorType.CLEAN) - - else: - - err_response = ErrorResponse( - error_message=(traceback.format_exc() - if self._protocol.trusted_sender else - 'An unknown error has occurred.'), - error_type=ErrorType.OTHER) - return self._protocol.encode_response(err_response) - - async def handle_raw_message_async(self, msg: bytes) -> bytes: + async def handle_raw_message_async(self, bound_obj: Any, + msg: bytes) -> bytes: """Should be called when the receiver gets a message. The return value is the raw response to the message. """ - raise RuntimeError('Unimplemented!') + try: + msg_decoded, msgtype = self._decode_incoming_message(msg) + handler = self._handlers.get(msgtype) + if handler is None: + raise RuntimeError(f'Got unhandled message type: {msgtype}.') + response = await handler(bound_obj, msg_decoded) + return self._encode_response(response, msgtype) + + except Exception as exc: + return self._handle_raw_message_error(exc)