diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index 6646cb56..2ca55164 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -72,10 +72,11 @@ class _TResp3(Message): fval: float -# SEND_CODE_TEST_BEGIN +# Generated sender supporting both sync and async sending: +# SEND_SYNC_CODE_TEST_BEGIN -class _TestMessageSender(MessageSender): +class _TestMessageSenderSync(MessageSender): """Protocol-specific sender.""" def __init__(self) -> None: @@ -84,11 +85,89 @@ class _TestMessageSender(MessageSender): def __get__(self, obj: Any, - type_in: Any = None) -> _BoundTestMessageSender: - return _BoundTestMessageSender(obj, self) + type_in: Any = None) -> _BoundTestMessageSenderSync: + return _BoundTestMessageSenderSync(obj, self) -class _BoundTestMessageSender(BoundMessageSender): +class _BoundTestMessageSenderSync(BoundMessageSender): + """Protocol-specific bound sender.""" + + @overload + def send(self, message: _TMsg1) -> _TResp1: + ... + + @overload + def send(self, message: _TMsg2) -> Union[_TResp1, _TResp2]: + ... + + @overload + def send(self, message: _TMsg3) -> None: + ... + + def send(self, message: Message) -> Optional[Response]: + """Send a message synchronously.""" + return self._sender.send(self._obj, message) + + +# SEND_SYNC_CODE_TEST_END + +# Generated sender supporting only async sending: +# SEND_ASYNC_CODE_TEST_BEGIN + + +class _TestMessageSenderAsync(MessageSender): + """Protocol-specific sender.""" + + def __init__(self) -> None: + protocol = TEST_PROTOCOL + super().__init__(protocol) + + def __get__(self, + obj: Any, + type_in: Any = None) -> _BoundTestMessageSenderAsync: + return _BoundTestMessageSenderAsync(obj, self) + + +class _BoundTestMessageSenderAsync(BoundMessageSender): + """Protocol-specific bound sender.""" + + @overload + async def send_async(self, message: _TMsg1) -> _TResp1: + ... + + @overload + async def send_async(self, message: _TMsg2) -> Union[_TResp1, _TResp2]: + ... + + @overload + async def send_async(self, message: _TMsg3) -> None: + ... + + async def send_async(self, message: Message) -> Optional[Response]: + """Send a message asynchronously.""" + return await self._sender.send_async(self._obj, message) + + +# SEND_ASYNC_CODE_TEST_END + +# Generated sender supporting both sync and async sending: +# SEND_BOTH_CODE_TEST_BEGIN + + +class _TestMessageSenderBoth(MessageSender): + """Protocol-specific sender.""" + + def __init__(self) -> None: + protocol = TEST_PROTOCOL + super().__init__(protocol) + + def __get__(self, + obj: Any, + type_in: Any = None) -> _BoundTestMessageSenderBoth: + return _BoundTestMessageSenderBoth(obj, self) + + +class _BoundTestMessageSenderBoth(BoundMessageSender): """Protocol-specific bound sender.""" @overload @@ -124,8 +203,10 @@ class _BoundTestMessageSender(BoundMessageSender): return await self._sender.send_async(self._obj, message) -# SEND_CODE_TEST_END -# RCVS_CODE_TEST_BEGIN +# SEND_BOTH_CODE_TEST_END + +# Generated receiver supporting sync handling: +# RCV_SYNC_CODE_TEST_BEGIN class _TestSyncMessageReceiver(MessageReceiver): @@ -179,8 +260,10 @@ class _BoundTestSyncMessageReceiver(BoundMessageReceiver): return self._receiver.handle_raw_message(self._obj, message) -# RCVS_CODE_TEST_END -# RCVA_CODE_TEST_BEGIN +# RCV_SYNC_CODE_TEST_END + +# Generated receiver supporting async handling: +# RCV_ASYNC_CODE_TEST_BEGIN class _TestAsyncMessageReceiver(MessageReceiver): @@ -235,7 +318,7 @@ class _BoundTestAsyncMessageReceiver(BoundMessageReceiver): self._obj, message) -# RCVA_CODE_TEST_END +# RCV_ASYNC_CODE_TEST_END TEST_PROTOCOL = MessageProtocol( message_types={ @@ -266,21 +349,23 @@ def test_protocol_creation() -> None: response_types={0: _TResp1}) -def test_sender_module_embedded() -> None: +def test_sender_module_sync_embedded() -> None: """Test generation of protocol-specific sender modules for typing/etc.""" # NOTE: Ideally we should be testing efro.message.create_sender_module() # here, but it requires us to pass code which imports this test module # to get at the protocol, and that currently fails in our static mypy # tests. smod = TEST_PROTOCOL.do_create_sender_module( - 'TestMessageSender', - 'protocol = TEST_PROTOCOL', + 'TestMessageSenderSync', + protocol_create_code='protocol = TEST_PROTOCOL', + enable_sync_sends=True, + enable_async_sends=False, private=True, ) # Clip everything up to our first class declaration. lines = smod.splitlines() - classline = lines.index('class _TestMessageSender(MessageSender):') + classline = lines.index('class _TestMessageSenderSync(MessageSender):') clipped = '\n'.join(lines[classline:]) # This snippet should match what we've got embedded above; @@ -288,7 +373,74 @@ def test_sender_module_embedded() -> None: with open(__file__, encoding='utf-8') as infile: ourcode = infile.read() - emb = f'# SEND_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# SEND_CODE_TEST_END\n' + emb = (f'# SEND_SYNC_CODE_TEST_BEGIN' + f'\n\n\n{clipped}\n\n\n# SEND_SYNC_CODE_TEST_END\n') + if emb not in ourcode: + print(f'EXPECTED EMBEDDED CODE:\n{emb}') + raise RuntimeError('Generated sender module does not match embedded;' + ' test code needs to be updated.' + ' See test stdout for new code.') + + +def test_sender_module_async_embedded() -> None: + """Test generation of protocol-specific sender modules for typing/etc.""" + # NOTE: Ideally we should be testing efro.message.create_sender_module() + # here, but it requires us to pass code which imports this test module + # to get at the protocol, and that currently fails in our static mypy + # tests. + smod = TEST_PROTOCOL.do_create_sender_module( + 'TestMessageSenderAsync', + protocol_create_code='protocol = TEST_PROTOCOL', + enable_sync_sends=False, + enable_async_sends=True, + private=True, + ) + + # Clip everything up to our first class declaration. + lines = smod.splitlines() + classline = lines.index('class _TestMessageSenderAsync(MessageSender):') + clipped = '\n'.join(lines[classline:]) + + # This snippet should match what we've got embedded above; + # If not then we need to update our embedded version. + with open(__file__, encoding='utf-8') as infile: + ourcode = infile.read() + + emb = (f'# SEND_ASYNC_CODE_TEST_BEGIN' + f'\n\n\n{clipped}\n\n\n# SEND_ASYNC_CODE_TEST_END\n') + if emb not in ourcode: + print(f'EXPECTED EMBEDDED CODE:\n{emb}') + raise RuntimeError('Generated sender module does not match embedded;' + ' test code needs to be updated.' + ' See test stdout for new code.') + + +def test_sender_module_both_embedded() -> None: + """Test generation of protocol-specific sender modules for typing/etc.""" + # NOTE: Ideally we should be testing efro.message.create_sender_module() + # here, but it requires us to pass code which imports this test module + # to get at the protocol, and that currently fails in our static mypy + # tests. + smod = TEST_PROTOCOL.do_create_sender_module( + 'TestMessageSenderBoth', + protocol_create_code='protocol = TEST_PROTOCOL', + enable_sync_sends=True, + enable_async_sends=True, + private=True, + ) + + # Clip everything up to our first class declaration. + lines = smod.splitlines() + classline = lines.index('class _TestMessageSenderBoth(MessageSender):') + clipped = '\n'.join(lines[classline:]) + + # This snippet should match what we've got embedded above; + # If not then we need to update our embedded version. + with open(__file__, encoding='utf-8') as infile: + ourcode = infile.read() + + emb = (f'# SEND_BOTH_CODE_TEST_BEGIN' + f'\n\n\n{clipped}\n\n\n# SEND_BOTH_CODE_TEST_END\n') if emb not in ourcode: print(f'EXPECTED EMBEDDED CODE:\n{emb}') raise RuntimeError('Generated sender module does not match embedded;' @@ -319,7 +471,8 @@ def test_receiver_module_sync_embedded() -> None: with open(__file__, encoding='utf-8') as infile: ourcode = infile.read() - emb = f'# RCVS_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVS_CODE_TEST_END\n' + emb = (f'# RCV_SYNC_CODE_TEST_BEGIN' + f'\n\n\n{clipped}\n\n\n# RCV_SYNC_CODE_TEST_END\n') if emb not in ourcode: print(f'EXPECTED SYNC RECEIVER EMBEDDED CODE:\n{emb}') raise RuntimeError( @@ -352,7 +505,8 @@ def test_receiver_module_async_embedded() -> None: with open(__file__, encoding='utf-8') as infile: ourcode = infile.read() - emb = f'# RCVA_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVA_CODE_TEST_END\n' + emb = (f'# RCV_ASYNC_CODE_TEST_BEGIN' + f'\n\n\n{clipped}\n\n\n# RCV_ASYNC_CODE_TEST_END\n') if emb not in ourcode: print(f'EXPECTED ASYNC RECEIVER EMBEDDED CODE:\n{emb}') raise RuntimeError( @@ -400,7 +554,7 @@ def test_full_pipeline() -> None: class TestClassS: """Test class incorporating send functionality.""" - msg = _TestMessageSender() + msg = _TestMessageSenderBoth() def __init__(self, target: Union[TestClassRSync, TestClassRAsync]) -> None: diff --git a/tools/efro/message.py b/tools/efro/message.py index 0fcabcba..64924beb 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -317,6 +317,8 @@ class MessageProtocol: def do_create_sender_module(self, basename: str, protocol_create_code: str, + enable_sync_sends: bool, + enable_async_sends: bool, private: bool = False) -> str: """Used by create_sender_module(); do not call directly.""" # pylint: disable=too-many-locals @@ -361,6 +363,10 @@ class MessageProtocol: if len(msgtypes) > 1: for async_pass in False, True: + if async_pass and not enable_async_sends: + continue + if not async_pass and not enable_sync_sends: + continue pfx = 'async ' if async_pass else '' sfx = '_async' if async_pass else '' awt = 'await ' if async_pass else '' @@ -823,6 +829,8 @@ class BoundMessageReceiver: def create_sender_module(basename: str, protocol_create_code: str, + enable_sync_sends: bool, + enable_async_sends: bool, private: bool = False) -> str: """Create a Python module defining a MessageSender subclass. @@ -853,6 +861,8 @@ def create_sender_module(basename: str, return protocol.do_create_sender_module( basename=basename, protocol_create_code=protocol_create_code, + enable_sync_sends=enable_sync_sends, + enable_async_sends=enable_async_sends, private=private)