work on async messaging

This commit is contained in:
Eric Froemling 2021-09-21 12:12:12 -05:00
parent b46b1cbca8
commit c8a6c3733d
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
5 changed files with 421 additions and 159 deletions

View File

@ -159,6 +159,7 @@
<w>availmins</w>
<w>availplug</w>
<w>aval</w>
<w>awaitable</w>
<w>axismotion</w>
<w>bacfg</w>
<w>backgrounded</w>
@ -303,6 +304,7 @@
<w>capturetheflag</w>
<w>carentity</w>
<w>cashregistersound</w>
<w>cbgn</w>
<w>cbits</w>
<w>cbot</w>
<w>cbtn</w>
@ -313,6 +315,7 @@
<w>cdrk</w>
<w>cdull</w>
<w>cdval</w>
<w>cend</w>
<w>centeuro</w>
<w>centiseconds</w>
<w>cfconfig</w>
@ -1869,6 +1872,8 @@
<w>rawpaths</w>
<w>rcade</w>
<w>rcfile</w>
<w>rcva</w>
<w>rcvs</w>
<w>rdict</w>
<w>rdir</w>
<w>readline</w>
@ -1997,6 +2002,7 @@
<w>selwidget</w>
<w>selwidgets</w>
<w>sendable</w>
<w>sendmethod</w>
<w>senze</w>
<w>seqtype</w>
<w>seqtypestr</w>

View File

@ -69,6 +69,7 @@
<w>availmins</w>
<w>avel</w>
<w>avels</w>
<w>awaitable</w>
<w>axismotion</w>
<w>backgrounded</w>
<w>backgrounding</w>
@ -155,10 +156,12 @@
<w>cancelbtn</w>
<w>capitan</w>
<w>cargs</w>
<w>cbgn</w>
<w>cbtnoffs</w>
<w>ccdd</w>
<w>ccontext</w>
<w>ccylinder</w>
<w>cend</w>
<w>centiseconds</w>
<w>cfgdir</w>
<w>cfgpath</w>
@ -860,6 +863,8 @@
<w>rasterizer</w>
<w>rawkey</w>
<w>rcade</w>
<w>rcva</w>
<w>rcvs</w>
<w>reaaaly</w>
<w>readset</w>
<w>realloc</w>
@ -922,6 +927,7 @@
<w>selwidget</w>
<w>selwidgets</w>
<w>sendable</w>
<w>sendmethod</w>
<w>seqlen</w>
<w>seqtype</w>
<w>seqtypestr</w>

View File

@ -1,5 +1,5 @@
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
<h4><em>last updated on 2021-09-19 for Ballistica version 1.6.5 build 20393</em></h4>
<h4><em>last updated on 2021-09-21 for Ballistica version 1.6.5 build 20393</em></h4>
<p>This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
<hr>

View File

@ -5,6 +5,7 @@
from __future__ import annotations
import os
import asyncio
from typing import TYPE_CHECKING, overload
from dataclasses import dataclass
@ -17,55 +18,55 @@ from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver)
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union, Optional
from typing import List, Type, Any, Callable, Union, Optional, Awaitable
@ioprepped
@dataclass
class _TMessage1(Message):
class _TMsg1(Message):
"""Just testing."""
ival: int
@classmethod
def get_response_types(cls) -> List[Type[Response]]:
return [_TResponse1]
return [_TResp1]
@ioprepped
@dataclass
class _TMessage2(Message):
class _TMsg2(Message):
"""Just testing."""
sval: str
@classmethod
def get_response_types(cls) -> List[Type[Response]]:
return [_TResponse1, _TResponse2]
return [_TResp1, _TResp2]
@ioprepped
@dataclass
class _TMessage3(Message):
class _TMsg3(Message):
"""Just testing."""
sval: str
@ioprepped
@dataclass
class _TResponse1(Response):
class _TResp1(Response):
"""Just testing."""
bval: bool
@ioprepped
@dataclass
class _TResponse2(Response):
class _TResp2(Response):
"""Just testing."""
fval: float
@ioprepped
@dataclass
class _TResponse3(Message):
class _TResp3(Message):
"""Just testing."""
fval: float
@ -91,55 +92,73 @@ class _BoundTestMessageSender:
self._sender = sender
@overload
def send(self, message: _TMessage1) -> _TResponse1:
def send(self, message: _TMsg1) -> _TResp1:
...
@overload
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
def send(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
...
@overload
def send(self, message: _TMessage3) -> None:
def send(self, message: _TMsg3) -> None:
...
def send(self, message: Message) -> Optional[Response]:
"""Send a message."""
"""Send a message synchronously."""
return self._sender.send(self._obj, message)
@overload
async def send_async(self, message: _TMsg1) -> _TResp1:
...
@overload
async def send_async(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
...
@overload
async def send_async(self, message: _TMsg3) -> None:
...
async def send_async(self, message: Message) -> Optional[Response]:
"""Send a message asynchronously."""
return await self._sender.send_async(self._obj, message)
# SEND_CODE_TEST_END
# RCV_CODE_TEST_BEGIN
# RCVS_CODE_TEST_BEGIN
class _TestMessageReceiver(MessageReceiver):
"""Protocol-specific receiver."""
class _TestSyncMessageReceiver(MessageReceiver):
"""Protocol-specific synchronous receiver."""
is_async = False
def __get__(
self,
obj: Any,
type_in: Any = None,
) -> _BoundTestMessageReceiver:
return _BoundTestMessageReceiver(obj, self)
) -> _BoundTestSyncMessageReceiver:
return _BoundTestSyncMessageReceiver(obj, self)
@overload
def handler(
self,
call: Callable[[Any, _TMessage1], _TResponse1],
) -> Callable[[Any, _TMessage1], _TResponse1]:
call: Callable[[Any, _TMsg1], _TResp1],
) -> Callable[[Any, _TMsg1], _TResp1]:
...
@overload
def handler(
self,
call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]],
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
call: Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]],
) -> Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]]:
...
@overload
def handler(
self,
call: Callable[[Any, _TMessage3], None],
) -> Callable[[Any, _TMessage3], None]:
call: Callable[[Any, _TMsg3], None],
) -> Callable[[Any, _TMsg3], None]:
...
def handler(self, call: Callable) -> Callable:
@ -148,34 +167,104 @@ class _TestMessageReceiver(MessageReceiver):
return call
class _BoundTestMessageReceiver:
class _BoundTestSyncMessageReceiver:
"""Protocol-specific bound receiver."""
def __init__(
self,
obj: Any,
receiver: _TestMessageReceiver,
receiver: _TestSyncMessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
def handle_raw_message(self, message: bytes) -> bytes:
"""Handle a raw incoming message."""
"""Handle a raw incoming synchronous message."""
return self._receiver.handle_raw_message(self._obj, message)
async def handle_raw_message_async(self, message: bytes) -> bytes:
"""Handle a raw incoming asynchronous message."""
return await self._receiver.handle_raw_message_async(
self._obj, message)
# RCV_CODE_TEST_END
# RCVS_CODE_TEST_END
# RCVA_CODE_TEST_BEGIN
class _TestAsyncMessageReceiver(MessageReceiver):
"""Protocol-specific asynchronous receiver."""
is_async = True
def __get__(
self,
obj: Any,
type_in: Any = None,
) -> _BoundTestAsyncMessageReceiver:
return _BoundTestAsyncMessageReceiver(obj, self)
@overload
def handler(
self,
call: Callable[[Any, _TMsg1], Awaitable[_TResp1]],
) -> Callable[[Any, _TMsg1], Awaitable[_TResp1]]:
...
@overload
def handler(
self,
call: Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]],
) -> Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]]:
...
@overload
def handler(
self,
call: Callable[[Any, _TMsg3], Awaitable[None]],
) -> Callable[[Any, _TMsg3], Awaitable[None]]:
...
def handler(self, call: Callable) -> Callable:
"""Decorator to register message handlers."""
self.register_handler(call)
return call
class _BoundTestAsyncMessageReceiver:
"""Protocol-specific bound receiver."""
def __init__(
self,
obj: Any,
receiver: _TestAsyncMessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
def handle_raw_message(self, message: bytes) -> bytes:
"""Handle a raw incoming synchronous message."""
return self._receiver.handle_raw_message(self._obj, message)
async def handle_raw_message_async(self, message: bytes) -> bytes:
"""Handle a raw incoming asynchronous message."""
return await self._receiver.handle_raw_message_async(
self._obj, message)
# RCVA_CODE_TEST_END
TEST_PROTOCOL = MessageProtocol(
message_types={
0: _TMessage1,
1: _TMessage2,
2: _TMessage3,
0: _TMsg1,
1: _TMsg2,
2: _TMsg3,
},
response_types={
0: _TResponse1,
1: _TResponse2,
0: _TResp1,
1: _TResp2,
},
trusted_sender=True,
log_remote_exceptions=False,
@ -185,20 +274,21 @@ TEST_PROTOCOL = MessageProtocol(
def test_protocol_creation() -> None:
"""Test protocol creation."""
# This should fail because _TMessage1 can return _TResponse1 which
# This should fail because _TMsg1 can return _TResp1 which
# is not given an id here.
with pytest.raises(ValueError):
_protocol = MessageProtocol(message_types={0: _TMessage1},
response_types={0: _TResponse2})
_protocol = MessageProtocol(message_types={0: _TMsg1},
response_types={0: _TResp2})
# Now it should work.
_protocol = MessageProtocol(message_types={0: _TMessage1},
response_types={0: _TResponse1})
_protocol = MessageProtocol(message_types={0: _TMsg1},
response_types={0: _TResp1})
def test_sender_module_embedded() -> None:
"""Test generation of protocol-specific sender modules for typing/etc."""
smod = TEST_PROTOCOL.create_sender_module('Test', private=True)
smod = TEST_PROTOCOL.create_sender_module('TestMessageSender',
private=True)
# Clip everything up to our first class declaration.
lines = smod.splitlines()
@ -218,13 +308,15 @@ def test_sender_module_embedded() -> None:
' See test stdout for new code.')
def test_receiver_module_embedded() -> None:
def test_receiver_module_sync_embedded() -> None:
"""Test generation of protocol-specific sender modules for typing/etc."""
smod = TEST_PROTOCOL.create_receiver_module('Test', private=True)
smod = TEST_PROTOCOL.create_receiver_module('TestSyncMessageReceiver',
is_async=False,
private=True)
# Clip everything up to our first class declaration.
lines = smod.splitlines()
classline = lines.index('class _TestMessageReceiver(MessageReceiver):')
classline = lines.index('class _TestSyncMessageReceiver(MessageReceiver):')
clipped = '\n'.join(lines[classline:])
# This snippet should match what we've got embedded above;
@ -232,12 +324,39 @@ def test_receiver_module_embedded() -> None:
with open(__file__, encoding='utf-8') as infile:
ourcode = infile.read()
emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n'
emb = f'# RCVS_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVS_CODE_TEST_END\n'
if emb not in ourcode:
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
raise RuntimeError('Generated sender module does not match embedded;'
' test code needs to be updated.'
' See test stdout for new code.')
print(f'EXPECTED SYNC RECEIVER EMBEDDED CODE:\n{emb}')
raise RuntimeError(
'Generated sync receiver module does not match embedded;'
' test code needs to be updated.'
' See test stdout for new code.')
def test_receiver_module_async_embedded() -> None:
"""Test generation of protocol-specific sender modules for typing/etc."""
smod = TEST_PROTOCOL.create_receiver_module('TestAsyncMessageReceiver',
is_async=True,
private=True)
# Clip everything up to our first class declaration.
lines = smod.splitlines()
classline = lines.index(
'class _TestAsyncMessageReceiver(MessageReceiver):')
clipped = '\n'.join(lines[classline:])
# This snippet should match what we've got embedded above;
# If not then we need to update our embedded version.
with open(__file__, encoding='utf-8') as infile:
ourcode = infile.read()
emb = f'# RCVA_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVA_CODE_TEST_END\n'
if emb not in ourcode:
print(f'EXPECTED ASYNC RECEIVER EMBEDDED CODE:\n{emb}')
raise RuntimeError(
'Generated async receiver module does not match embedded;'
' test code needs to be updated.'
' See test stdout for new code.')
def test_receiver_creation() -> None:
@ -251,13 +370,13 @@ def test_receiver_creation() -> None:
class _TestClassR:
"""Test class incorporating receive functionality."""
receiver = _TestMessageReceiver(TEST_PROTOCOL)
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
@receiver.handler
def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2:
def handle_test_message_2(self, msg: _TMsg2) -> _TResp2:
"""Test."""
del msg # Unused
return _TResponse2(fval=1.2)
return _TResp2(fval=1.2)
# Validation should fail because not all message types in the
# protocol are handled.
@ -266,13 +385,13 @@ def test_receiver_creation() -> None:
class _TestClassR2:
"""Test class incorporating receive functionality."""
receiver = _TestMessageReceiver(TEST_PROTOCOL)
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
# Checks that we've added handlers for all message types, etc.
receiver.validate()
def test_synchronous_messaging() -> None:
def test_full_pipeline() -> None:
"""Test the full pipeline."""
# Define a class that can send messages and one that can receive them.
@ -281,66 +400,119 @@ def test_synchronous_messaging() -> None:
msg = _TestMessageSender(TEST_PROTOCOL)
def __init__(self, target: TestClassR) -> None:
def __init__(self, target: Union[TestClassRSync,
TestClassRAsync]) -> None:
self._target = target
@msg.send_raw_handler
@msg.send_method
def _send_raw_message(self, data: bytes) -> bytes:
"""Test."""
"""Handle synchronous sending of raw message data."""
# Just talk directly to the receiver for this example.
# (currently only support synchronous receivers)
assert isinstance(self._target, TestClassRSync)
return self._target.receiver.handle_raw_message(data)
class TestClassR:
"""Test class incorporating receive functionality."""
@msg.send_async_method
async def _send_raw_message_async(self, data: bytes) -> bytes:
"""Handle asynchronous sending of raw message data."""
# Just talk directly to the receiver for this example.
# (we can do sync or async receivers)
if isinstance(self._target, TestClassRSync):
return self._target.receiver.handle_raw_message(data)
return await self._target.receiver.handle_raw_message_async(data)
receiver = _TestMessageReceiver(TEST_PROTOCOL)
class TestClassRSync:
"""Test class incorporating synchronous receive functionality."""
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
@receiver.handler
def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1:
def handle_test_message_1(self, msg: _TMsg1) -> _TResp1:
"""Test."""
if msg.ival == 1:
raise CleanError('Testing Clean Error')
if msg.ival == 2:
raise RuntimeError('Testing Runtime Error')
return _TResponse1(bval=True)
return _TResp1(bval=True)
@receiver.handler
def handle_test_message_2(
self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]:
def handle_test_message_2(self,
msg: _TMsg2) -> Union[_TResp1, _TResp2]:
"""Test."""
del msg # Unused
return _TResponse2(fval=1.2)
return _TResp2(fval=1.2)
@receiver.handler
def handle_test_message_3(self, msg: _TMessage3) -> None:
def handle_test_message_3(self, msg: _TMsg3) -> None:
"""Test."""
del msg # Unused
receiver.validate()
obj_r = TestClassR()
obj = TestClassS(target=obj_r)
class TestClassRAsync:
"""Test class incorporating asynchronous receive functionality."""
response = obj.msg.send(_TMessage1(ival=0))
assert isinstance(response, _TResponse1)
receiver = _TestAsyncMessageReceiver(TEST_PROTOCOL)
response2 = obj.msg.send(_TMessage2(sval='rah'))
assert isinstance(response2, (_TResponse1, _TResponse2))
@receiver.handler
async def handle_test_message_1(self, msg: _TMsg1) -> _TResp1:
"""Test."""
if msg.ival == 1:
raise CleanError('Testing Clean Error')
if msg.ival == 2:
raise RuntimeError('Testing Runtime Error')
return _TResp1(bval=True)
response3 = obj.msg.send(_TMessage3(sval='rah'))
assert response3 is None
@receiver.handler
async def handle_test_message_2(
self, msg: _TMsg2) -> Union[_TResp1, _TResp2]:
"""Test."""
del msg # Unused
return _TResp2(fval=1.2)
# Make sure static typing lines up too.
@receiver.handler
async def handle_test_message_3(self, msg: _TMsg3) -> None:
"""Test."""
del msg # Unused
receiver.validate()
obj_r_sync = TestClassRSync()
obj_r_async = TestClassRAsync()
obj = TestClassS(target=obj_r_sync)
obj2 = TestClassS(target=obj_r_async)
# Test sends (of sync and async varieties).
response1 = obj.msg.send(_TMsg1(ival=0))
response2 = obj.msg.send(_TMsg2(sval='rah'))
response3 = obj.msg.send(_TMsg3(sval='rah'))
response4 = asyncio.run(obj.msg.send_async(_TMsg1(ival=0)))
# Make sure static typing lines up with what we expect.
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
assert static_type_equals(response, _TResponse1)
assert static_type_equals(response1, _TResp1)
assert static_type_equals(response3, None)
assert isinstance(response1, _TResp1)
assert isinstance(response2, (_TResp1, _TResp2))
assert response3 is None
assert isinstance(response4, _TResp1)
# Remote CleanErrors should come across locally as the same.
try:
_response4 = obj.msg.send(_TMessage1(ival=1))
_response5 = obj.msg.send(_TMsg1(ival=1))
except Exception as exc:
assert isinstance(exc, CleanError)
assert str(exc) == 'Testing Clean Error'
# Other remote errors should result in RemoteError.
with pytest.raises(RemoteError):
_response4 = obj.msg.send(_TMessage1(ival=2))
_response5 = obj.msg.send(_TMsg1(ival=2))
# Now test sends to async handlers.
response6 = asyncio.run(obj2.msg.send_async(_TMsg1(ival=0)))
assert isinstance(response6, _TResp1)
# Make sure static typing lines up with what we expect.
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
assert static_type_equals(response6, _TResp1)

View File

@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
if TYPE_CHECKING:
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
Sequence, Union)
Sequence, Union, Awaitable)
from efro.error import CommunicationError
TM = TypeVar('TM', bound='MessageSender')
@ -300,7 +300,7 @@ class MessageProtocol:
return out
def create_sender_module(self,
classname: str,
basename: str,
private: bool = False) -> str:
""""Create a Python module defining a MessageSender subclass.
@ -308,29 +308,35 @@ class MessageProtocol:
for the varieties of send calls for message/response types defined
in the protocol.
Class names are based on basename; a basename 'FooSender' will
result in classes FooSender and BoundFooSender.
If 'private' is True, class-names will be prefixed with an '_'.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
"""
# pylint: disable=too-many-locals
ppre = '_' if private else ''
out = self._get_module_header('sender')
out += (f'class {ppre}{classname}MessageSender(MessageSender):\n'
out += (f'class {ppre}{basename}(MessageSender):\n'
f' """Protocol-specific sender."""\n'
f'\n'
f' def __get__(self,\n'
f' obj: Any,\n'
f' type_in: Any = None)'
f' -> {ppre}Bound{classname}MessageSender:\n'
f' return {ppre}Bound{classname}MessageSender'
f' -> {ppre}Bound{basename}:\n'
f' return {ppre}Bound{basename}'
f'(obj, self)\n'
f'\n'
f'\n'
f'class {ppre}Bound{classname}MessageSender:\n'
f'class {ppre}Bound{basename}:\n'
f' """Protocol-specific bound sender."""\n'
f'\n'
f' def __init__(self, obj: Any,'
f' sender: {ppre}{classname}MessageSender) -> None:\n'
f' sender: {ppre}{basename}) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._sender = sender\n')
@ -352,29 +358,37 @@ class MessageProtocol:
return 'None' if rtype is EmptyResponse else rtype.__name__
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
for async_pass in False, True:
pfx = 'async ' if async_pass else ''
sfx = '_async' if async_pass else ''
awt = 'await ' if async_pass else ''
how = 'asynchronously' if async_pass else 'synchronously'
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
out += (f'\n'
f' @overload\n'
f' {pfx}def send{sfx}(self,'
f' message: {msgtypevar})'
f' -> {rtypevar}:\n'
f' ...\n')
out += (f'\n'
f' @overload\n'
f' def send(self, message: {msgtypevar})'
f' -> {rtypevar}:\n'
f' ...\n')
out += ('\n'
' def send(self, message: Message)'
' -> Optional[Response]:\n'
' """Send a message."""\n'
' return self._sender.send(self._obj, message)\n')
f' {pfx}def send{sfx}(self, message: Message)'
f' -> Optional[Response]:\n'
f' """Send a message {how}."""\n'
f' return {awt}self._sender.'
f'send{sfx}(self._obj, message)\n')
return out
def create_receiver_module(self,
classname: str,
basename: str,
is_async: bool,
private: bool = False) -> str:
""""Create a Python module defining a MessageReceiver subclass.
@ -382,21 +396,33 @@ class MessageProtocol:
for the register method for message/response types defined in
the protocol.
Class names are based on basename; a basename 'FooReceiver' will
result in FooReceiver and BoundFooReceiver.
If 'is_async' is True, handle_raw_message() will be an async method
and the @handler decorator will expect async methods.
If 'private' is True, class-names will be prefixed with an '_'.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
"""
# pylint: disable=too-many-locals
desc = 'asynchronous' if is_async else 'synchronous'
ppre = '_' if private else ''
out = self._get_module_header('receiver')
out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n'
f' """Protocol-specific receiver."""\n'
out += (f'class {ppre}{basename}(MessageReceiver):\n'
f' """Protocol-specific {desc} receiver."""\n'
f'\n'
f' is_async = {is_async}\n'
f'\n'
f' def __get__(\n'
f' self,\n'
f' obj: Any,\n'
f' type_in: Any = None,\n'
f' ) -> {ppre}Bound{classname}MessageReceiver:\n'
f' return {ppre}Bound{classname}MessageReceiver('
f' ) -> {ppre}Bound{basename}:\n'
f' return {ppre}Bound{basename}('
f'obj, self)\n')
# Define handler() overloads for all registered message types.
@ -416,6 +442,8 @@ class MessageProtocol:
return 'None' if rtype is EmptyResponse else rtype.__name__
if len(msgtypes) > 1:
cbgn = 'Awaitable[' if is_async else ''
cend = ']' if is_async else ''
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
@ -424,6 +452,7 @@ class MessageProtocol:
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
rtypevar = f'{cbgn}{rtypevar}{cend}'
out += (
f'\n'
f' @overload\n'
@ -441,22 +470,29 @@ class MessageProtocol:
out += (f'\n'
f'\n'
f'class {ppre}Bound{classname}MessageReceiver:\n'
f'class {ppre}Bound{basename}:\n'
f' """Protocol-specific bound receiver."""\n'
f'\n'
f' def __init__(\n'
f' self,\n'
f' obj: Any,\n'
f' receiver: {ppre}{classname}MessageReceiver,\n'
f' receiver: {ppre}{basename},\n'
f' ) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._receiver = receiver\n'
f'\n'
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
f' """Handle a raw incoming message."""\n'
f' """Handle a raw incoming synchronous message."""\n'
f' return self._receiver.handle_raw_message'
f'(self._obj, message)\n')
f'(self._obj, message)\n'
f'\n'
f' async def handle_raw_message_async(self, message: bytes)'
f' -> bytes:\n'
f' """Handle a raw incoming asynchronous message."""\n'
f' return await'
f' self._receiver.handle_raw_message_async(\n'
f' self._obj, message)\n')
return out
@ -471,7 +507,7 @@ class MessageSender:
class MyClass:
msg = MyMessageSender(some_protocol)
@msg.send_raw_handler
@msg.sendmethod
def send_raw_message(self, message: bytes) -> bytes:
# Actually send the message here.
@ -485,8 +521,10 @@ class MessageSender:
self._protocol = protocol
self._send_raw_message_call: Optional[Callable[[Any, bytes],
bytes]] = None
self._send_async_raw_message_call: Optional[Callable[
[Any, bytes], Awaitable[bytes]]] = None
def send_raw_handler(
def send_method(
self, call: Callable[[Any, bytes],
bytes]) -> Callable[[Any, bytes], bytes]:
"""Function decorator for setting raw send method."""
@ -494,6 +532,14 @@ class MessageSender:
self._send_raw_message_call = call
return call
def send_async_method(
self, call: Callable[[Any, bytes], Awaitable[bytes]]
) -> Callable[[Any, bytes], Awaitable[bytes]]:
"""Function decorator for setting raw send-async method."""
assert self._send_async_raw_message_call is None
self._send_async_raw_message_call = call
return call
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
"""Send a message and receive a response.
@ -510,21 +556,24 @@ class MessageSender:
or type(response) in type(message).get_response_types())
return response
def send_bg(self, bound_obj: Any, message: Message) -> Message:
"""Send a message asynchronously and receive a future.
The message will be encoded for transport and passed to
dispatch_raw_message from a background thread.
"""
raise RuntimeError('Unimplemented!')
def send_async(self, bound_obj: Any, message: Message) -> Message:
async def send_async(self, bound_obj: Any,
message: Message) -> Optional[Response]:
"""Send a message asynchronously using asyncio.
The message will be encoded for transport and passed to
dispatch_raw_message_async.
"""
raise RuntimeError('Unimplemented!')
if self._send_async_raw_message_call is None:
raise RuntimeError('send_async() is unimplemented for this type.')
msg_encoded = self._protocol.encode_message(message)
response_encoded = await self._send_async_raw_message_call(
bound_obj, msg_encoded)
response = self._protocol.decode_response(response_encoded)
assert isinstance(response, (Response, type(None)))
assert (response is None
or type(response) in type(message).get_response_types())
return response
class MessageReceiver:
@ -552,6 +601,8 @@ class MessageReceiver:
an Exception being raised on the sending end.
"""
is_async = False
def __init__(self, protocol: MessageProtocol) -> None:
self._protocol = protocol
self._handlers: Dict[Type[Message], Callable] = {}
@ -564,6 +615,7 @@ class MessageReceiver:
The message type handled by the call is determined by its
type annotation.
"""
# pylint: disable=too-many-locals
# TODO: can use types.GenericAlias in 3.9.
from typing import _GenericAlias # type: ignore
from typing import Union, get_type_hints, get_args
@ -576,10 +628,19 @@ class MessageReceiver:
raise ValueError(f'Expected callable signature of {expectedsig};'
f' got {sig.args}')
# Make sure we are only given async methods if we are an async handler
# and sync ones otherwise.
is_async = inspect.iscoroutinefunction(call)
if self.is_async != is_async:
msg = ('Expected a sync method; found an async one.' if is_async
else 'Expected an async method; found a sync one.')
raise ValueError(msg)
# Check annotation types to determine what message types we handle.
# Return-type annotation can be a Union, but we probably don't
# have it available at runtime. Explicitly pull it in.
anns = get_type_hints(call, localns={'Union': Union})
msgtype = anns.get('msg')
if not isinstance(msgtype, type):
raise TypeError(
@ -642,54 +703,71 @@ class MessageReceiver:
logging.warning(msg)
raise TypeError(msg)
def _decode_incoming_message(self,
msg: bytes) -> Tuple[Message, Type[Message]]:
# Decode the incoming message.
msg_decoded = self._protocol.decode_message(msg)
msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
return msg_decoded, msgtype
def _encode_response(self, response: Optional[Response],
msgtype: Type[Message]) -> bytes:
# A return value of None equals EmptyResponse.
if response is None:
response = EmptyResponse()
# Re-encode the response.
assert isinstance(response, Response)
# (user should never explicitly return these)
assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types()
return self._protocol.encode_response(response)
def _handle_raw_message_error(self, exc: Exception) -> bytes:
if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.')
# If anything goes wrong, return a ErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
err_response = ErrorResponse(error_message=str(exc),
error_type=ErrorType.CLEAN)
else:
err_response = ErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_sender else
'An unknown error has occurred.'),
error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response)
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
"""Decode, handle, and return an encoded response for a message."""
try:
# Decode the incoming message.
msg_decoded = self._protocol.decode_message(msg)
msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
# Call the proper handler.
msg_decoded, msgtype = self._decode_incoming_message(msg)
handler = self._handlers.get(msgtype)
if handler is None:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
response = handler(bound_obj, msg_decoded)
# A return value of None equals EmptyResponse.
if response is None:
response = EmptyResponse()
# Re-encode the response.
assert isinstance(response, Response)
# (user should never explicitly return these)
assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types()
return self._protocol.encode_response(response)
return self._encode_response(response, msgtype)
except Exception as exc:
return self._handle_raw_message_error(exc)
if self._protocol.log_remote_exceptions:
logging.exception('Error handling message.')
# If anything goes wrong, return a ErrorResponse instead.
if (isinstance(exc, CleanError)
and self._protocol.preserve_clean_errors):
err_response = ErrorResponse(error_message=str(exc),
error_type=ErrorType.CLEAN)
else:
err_response = ErrorResponse(
error_message=(traceback.format_exc()
if self._protocol.trusted_sender else
'An unknown error has occurred.'),
error_type=ErrorType.OTHER)
return self._protocol.encode_response(err_response)
async def handle_raw_message_async(self, msg: bytes) -> bytes:
async def handle_raw_message_async(self, bound_obj: Any,
msg: bytes) -> bytes:
"""Should be called when the receiver gets a message.
The return value is the raw response to the message.
"""
raise RuntimeError('Unimplemented!')
try:
msg_decoded, msgtype = self._decode_incoming_message(msg)
handler = self._handlers.get(msgtype)
if handler is None:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
response = await handler(bound_obj, msg_decoded)
return self._encode_response(response, msgtype)
except Exception as exc:
return self._handle_raw_message_error(exc)