mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-23 07:23:19 +08:00
generated senders can now support sync and/or async sends
This commit is contained in:
parent
91a40f1062
commit
46b02414fd
@ -72,10 +72,11 @@ class _TResp3(Message):
|
||||
fval: float
|
||||
|
||||
|
||||
# SEND_CODE_TEST_BEGIN
|
||||
# Generated sender supporting both sync and async sending:
|
||||
# SEND_SYNC_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestMessageSender(MessageSender):
|
||||
class _TestMessageSenderSync(MessageSender):
|
||||
"""Protocol-specific sender."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -84,11 +85,89 @@ class _TestMessageSender(MessageSender):
|
||||
|
||||
def __get__(self,
|
||||
obj: Any,
|
||||
type_in: Any = None) -> _BoundTestMessageSender:
|
||||
return _BoundTestMessageSender(obj, self)
|
||||
type_in: Any = None) -> _BoundTestMessageSenderSync:
|
||||
return _BoundTestMessageSenderSync(obj, self)
|
||||
|
||||
|
||||
class _BoundTestMessageSender(BoundMessageSender):
|
||||
class _BoundTestMessageSenderSync(BoundMessageSender):
|
||||
"""Protocol-specific bound sender."""
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMsg1) -> _TResp1:
|
||||
...
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMsg3) -> None:
|
||||
...
|
||||
|
||||
def send(self, message: Message) -> Optional[Response]:
|
||||
"""Send a message synchronously."""
|
||||
return self._sender.send(self._obj, message)
|
||||
|
||||
|
||||
# SEND_SYNC_CODE_TEST_END
|
||||
|
||||
# Generated sender supporting only async sending:
|
||||
# SEND_ASYNC_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestMessageSenderAsync(MessageSender):
|
||||
"""Protocol-specific sender."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
protocol = TEST_PROTOCOL
|
||||
super().__init__(protocol)
|
||||
|
||||
def __get__(self,
|
||||
obj: Any,
|
||||
type_in: Any = None) -> _BoundTestMessageSenderAsync:
|
||||
return _BoundTestMessageSenderAsync(obj, self)
|
||||
|
||||
|
||||
class _BoundTestMessageSenderAsync(BoundMessageSender):
|
||||
"""Protocol-specific bound sender."""
|
||||
|
||||
@overload
|
||||
async def send_async(self, message: _TMsg1) -> _TResp1:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def send_async(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def send_async(self, message: _TMsg3) -> None:
|
||||
...
|
||||
|
||||
async def send_async(self, message: Message) -> Optional[Response]:
|
||||
"""Send a message asynchronously."""
|
||||
return await self._sender.send_async(self._obj, message)
|
||||
|
||||
|
||||
# SEND_ASYNC_CODE_TEST_END
|
||||
|
||||
# Generated sender supporting both sync and async sending:
|
||||
# SEND_BOTH_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestMessageSenderBoth(MessageSender):
|
||||
"""Protocol-specific sender."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
protocol = TEST_PROTOCOL
|
||||
super().__init__(protocol)
|
||||
|
||||
def __get__(self,
|
||||
obj: Any,
|
||||
type_in: Any = None) -> _BoundTestMessageSenderBoth:
|
||||
return _BoundTestMessageSenderBoth(obj, self)
|
||||
|
||||
|
||||
class _BoundTestMessageSenderBoth(BoundMessageSender):
|
||||
"""Protocol-specific bound sender."""
|
||||
|
||||
@overload
|
||||
@ -124,8 +203,10 @@ class _BoundTestMessageSender(BoundMessageSender):
|
||||
return await self._sender.send_async(self._obj, message)
|
||||
|
||||
|
||||
# SEND_CODE_TEST_END
|
||||
# RCVS_CODE_TEST_BEGIN
|
||||
# SEND_BOTH_CODE_TEST_END
|
||||
|
||||
# Generated receiver supporting sync handling:
|
||||
# RCV_SYNC_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestSyncMessageReceiver(MessageReceiver):
|
||||
@ -179,8 +260,10 @@ class _BoundTestSyncMessageReceiver(BoundMessageReceiver):
|
||||
return self._receiver.handle_raw_message(self._obj, message)
|
||||
|
||||
|
||||
# RCVS_CODE_TEST_END
|
||||
# RCVA_CODE_TEST_BEGIN
|
||||
# RCV_SYNC_CODE_TEST_END
|
||||
|
||||
# Generated receiver supporting async handling:
|
||||
# RCV_ASYNC_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestAsyncMessageReceiver(MessageReceiver):
|
||||
@ -235,7 +318,7 @@ class _BoundTestAsyncMessageReceiver(BoundMessageReceiver):
|
||||
self._obj, message)
|
||||
|
||||
|
||||
# RCVA_CODE_TEST_END
|
||||
# RCV_ASYNC_CODE_TEST_END
|
||||
|
||||
TEST_PROTOCOL = MessageProtocol(
|
||||
message_types={
|
||||
@ -266,21 +349,23 @@ def test_protocol_creation() -> None:
|
||||
response_types={0: _TResp1})
|
||||
|
||||
|
||||
def test_sender_module_embedded() -> None:
|
||||
def test_sender_module_sync_embedded() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
# NOTE: Ideally we should be testing efro.message.create_sender_module()
|
||||
# here, but it requires us to pass code which imports this test module
|
||||
# to get at the protocol, and that currently fails in our static mypy
|
||||
# tests.
|
||||
smod = TEST_PROTOCOL.do_create_sender_module(
|
||||
'TestMessageSender',
|
||||
'protocol = TEST_PROTOCOL',
|
||||
'TestMessageSenderSync',
|
||||
protocol_create_code='protocol = TEST_PROTOCOL',
|
||||
enable_sync_sends=True,
|
||||
enable_async_sends=False,
|
||||
private=True,
|
||||
)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index('class _TestMessageSender(MessageSender):')
|
||||
classline = lines.index('class _TestMessageSenderSync(MessageSender):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
@ -288,7 +373,74 @@ def test_sender_module_embedded() -> None:
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# SEND_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# SEND_CODE_TEST_END\n'
|
||||
emb = (f'# SEND_SYNC_CODE_TEST_BEGIN'
|
||||
f'\n\n\n{clipped}\n\n\n# SEND_SYNC_CODE_TEST_END\n')
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError('Generated sender module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_sender_module_async_embedded() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
# NOTE: Ideally we should be testing efro.message.create_sender_module()
|
||||
# here, but it requires us to pass code which imports this test module
|
||||
# to get at the protocol, and that currently fails in our static mypy
|
||||
# tests.
|
||||
smod = TEST_PROTOCOL.do_create_sender_module(
|
||||
'TestMessageSenderAsync',
|
||||
protocol_create_code='protocol = TEST_PROTOCOL',
|
||||
enable_sync_sends=False,
|
||||
enable_async_sends=True,
|
||||
private=True,
|
||||
)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index('class _TestMessageSenderAsync(MessageSender):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
# If not then we need to update our embedded version.
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = (f'# SEND_ASYNC_CODE_TEST_BEGIN'
|
||||
f'\n\n\n{clipped}\n\n\n# SEND_ASYNC_CODE_TEST_END\n')
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError('Generated sender module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_sender_module_both_embedded() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
# NOTE: Ideally we should be testing efro.message.create_sender_module()
|
||||
# here, but it requires us to pass code which imports this test module
|
||||
# to get at the protocol, and that currently fails in our static mypy
|
||||
# tests.
|
||||
smod = TEST_PROTOCOL.do_create_sender_module(
|
||||
'TestMessageSenderBoth',
|
||||
protocol_create_code='protocol = TEST_PROTOCOL',
|
||||
enable_sync_sends=True,
|
||||
enable_async_sends=True,
|
||||
private=True,
|
||||
)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index('class _TestMessageSenderBoth(MessageSender):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
# If not then we need to update our embedded version.
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = (f'# SEND_BOTH_CODE_TEST_BEGIN'
|
||||
f'\n\n\n{clipped}\n\n\n# SEND_BOTH_CODE_TEST_END\n')
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError('Generated sender module does not match embedded;'
|
||||
@ -319,7 +471,8 @@ def test_receiver_module_sync_embedded() -> None:
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# RCVS_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVS_CODE_TEST_END\n'
|
||||
emb = (f'# RCV_SYNC_CODE_TEST_BEGIN'
|
||||
f'\n\n\n{clipped}\n\n\n# RCV_SYNC_CODE_TEST_END\n')
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED SYNC RECEIVER EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError(
|
||||
@ -352,7 +505,8 @@ def test_receiver_module_async_embedded() -> None:
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# RCVA_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVA_CODE_TEST_END\n'
|
||||
emb = (f'# RCV_ASYNC_CODE_TEST_BEGIN'
|
||||
f'\n\n\n{clipped}\n\n\n# RCV_ASYNC_CODE_TEST_END\n')
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED ASYNC RECEIVER EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError(
|
||||
@ -400,7 +554,7 @@ def test_full_pipeline() -> None:
|
||||
class TestClassS:
|
||||
"""Test class incorporating send functionality."""
|
||||
|
||||
msg = _TestMessageSender()
|
||||
msg = _TestMessageSenderBoth()
|
||||
|
||||
def __init__(self, target: Union[TestClassRSync,
|
||||
TestClassRAsync]) -> None:
|
||||
|
||||
@ -317,6 +317,8 @@ class MessageProtocol:
|
||||
def do_create_sender_module(self,
|
||||
basename: str,
|
||||
protocol_create_code: str,
|
||||
enable_sync_sends: bool,
|
||||
enable_async_sends: bool,
|
||||
private: bool = False) -> str:
|
||||
"""Used by create_sender_module(); do not call directly."""
|
||||
# pylint: disable=too-many-locals
|
||||
@ -361,6 +363,10 @@ class MessageProtocol:
|
||||
|
||||
if len(msgtypes) > 1:
|
||||
for async_pass in False, True:
|
||||
if async_pass and not enable_async_sends:
|
||||
continue
|
||||
if not async_pass and not enable_sync_sends:
|
||||
continue
|
||||
pfx = 'async ' if async_pass else ''
|
||||
sfx = '_async' if async_pass else ''
|
||||
awt = 'await ' if async_pass else ''
|
||||
@ -823,6 +829,8 @@ class BoundMessageReceiver:
|
||||
|
||||
def create_sender_module(basename: str,
|
||||
protocol_create_code: str,
|
||||
enable_sync_sends: bool,
|
||||
enable_async_sends: bool,
|
||||
private: bool = False) -> str:
|
||||
"""Create a Python module defining a MessageSender subclass.
|
||||
|
||||
@ -853,6 +861,8 @@ def create_sender_module(basename: str,
|
||||
return protocol.do_create_sender_module(
|
||||
basename=basename,
|
||||
protocol_create_code=protocol_create_code,
|
||||
enable_sync_sends=enable_sync_sends,
|
||||
enable_async_sends=enable_async_sends,
|
||||
private=private)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user